"""Audio transcription and text-to-speech functions""" import os import asyncio import tempfile import soundfile as sf import torch from logger import logger from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools import config from models import TTS_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model import spaces try: import nest_asyncio except ImportError: nest_asyncio = None async def transcribe_audio_gemini(audio_path: str) -> str: """Transcribe audio using Gemini MCP transcribe_audio tool""" if not MCP_AVAILABLE: return "" try: session = await get_mcp_session() if session is None: logger.warning("MCP session not available for transcription") return "" tools = await get_cached_mcp_tools() transcribe_tool = None for tool in tools: if tool.name == "transcribe_audio": transcribe_tool = tool logger.info(f"Found MCP transcribe_audio tool: {tool.name}") break if not transcribe_tool: logger.warning("transcribe_audio MCP tool not found, falling back to generate_content") # Fallback to using generate_content audio_path_abs = os.path.abspath(audio_path) files = [{"path": audio_path_abs}] system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts." user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises." result = await call_agent( user_prompt=user_prompt, system_prompt=system_prompt, files=files, model=config.GEMINI_MODEL_LITE, temperature=0.2 ) return result.strip() # Use the transcribe_audio tool audio_path_abs = os.path.abspath(audio_path) result = await session.call_tool( transcribe_tool.name, arguments={"audio_path": audio_path_abs} ) if hasattr(result, 'content') and result.content: for item in result.content: if hasattr(item, 'text'): transcribed_text = item.text.strip() if transcribed_text: logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...") return transcribed_text logger.warning("MCP transcribe_audio returned empty result") return "" except Exception as e: logger.error(f"Gemini transcription error: {e}") return "" @spaces.GPU(max_duration=60) def transcribe_audio_whisper(audio_path: str) -> str: """Transcribe audio using Whisper model from Hugging Face""" if not WHISPER_AVAILABLE: logger.warning("[ASR] Whisper not available for transcription") return "" try: logger.info(f"[ASR] Starting Whisper transcription for: {audio_path}") if config.global_whisper_model is None: logger.info("[ASR] Whisper model not loaded, initializing...") initialize_whisper_model() if config.global_whisper_model is None: logger.error("[ASR] Failed to initialize Whisper model") return "" # Extract processor and model from stored dict processor = config.global_whisper_model["processor"] model = config.global_whisper_model["model"] logger.info("[ASR] Loading audio file...") # Load audio using torchaudio (imported from models) from models import torchaudio if torchaudio is None: logger.error("[ASR] torchaudio not available") return "" waveform, sample_rate = torchaudio.load(audio_path) # Resample to 16kHz if needed (Whisper expects 16kHz) if sample_rate != 16000: resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) sample_rate = 16000 logger.info("[ASR] Processing audio with Whisper...") # Process audio inputs = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, 
return_tensors="pt") # Move inputs to same device as model device = next(model.parameters()).device inputs = {k: v.to(device) for k, v in inputs.items()} logger.info("[ASR] Running Whisper transcription...") # Generate transcription with torch.no_grad(): generated_ids = model.generate(**inputs) # Decode transcription transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() if transcribed_text: logger.info(f"[ASR] ✅ Transcription successful: {transcribed_text[:100]}...") logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters") else: logger.warning("[ASR] Whisper returned empty transcription") return transcribed_text except Exception as e: logger.error(f"[ASR] Whisper transcription error: {e}") import traceback logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}") return "" def transcribe_audio(audio): """Transcribe audio to text using Whisper (primary) or Gemini MCP (fallback)""" if audio is None: logger.warning("[ASR] No audio provided") return "" try: # Convert audio input to file path if isinstance(audio, str): audio_path = audio elif isinstance(audio, tuple): sample_rate, audio_data = audio logger.info(f"[ASR] Processing audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape if hasattr(audio_data, 'shape') else 'unknown'}") with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: sf.write(tmp_file.name, audio_data, samplerate=sample_rate) audio_path = tmp_file.name logger.info(f"[ASR] Created temporary audio file: {audio_path}") else: audio_path = audio logger.info(f"[ASR] Attempting transcription with Whisper (primary method)...") # Try Whisper first (primary method) if WHISPER_AVAILABLE: try: transcribed = transcribe_audio_whisper(audio_path) if transcribed: logger.info(f"[ASR] ✅ Successfully transcribed via Whisper: {transcribed[:50]}...") # Clean up temp file if we created it if isinstance(audio, tuple) and os.path.exists(audio_path): try: os.unlink(audio_path) except: pass return transcribed else: logger.warning("[ASR] Whisper transcription returned empty, trying fallback...") except Exception as e: logger.error(f"[ASR] Whisper transcription failed: {e}, trying fallback...") else: logger.warning("[ASR] Whisper not available, trying Gemini fallback...") # Fallback to Gemini MCP if Whisper fails or is unavailable if MCP_AVAILABLE: try: logger.info("[ASR] Attempting transcription with Gemini MCP (fallback)...") loop = asyncio.get_event_loop() if loop.is_running(): if nest_asyncio: transcribed = nest_asyncio.run(transcribe_audio_gemini(audio_path)) if transcribed: logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...") # Clean up temp file if we created it if isinstance(audio, tuple) and os.path.exists(audio_path): try: os.unlink(audio_path) except: pass return transcribed else: logger.error("[ASR] nest_asyncio not available for nested async transcription") else: transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path)) if transcribed: logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...") # Clean up temp file if we created it if isinstance(audio, tuple) and os.path.exists(audio_path): try: os.unlink(audio_path) except: pass return transcribed except Exception as e: logger.error(f"[ASR] Gemini MCP transcription error: {e}") # Clean up temp file if we created it if isinstance(audio, tuple) and os.path.exists(audio_path): try: os.unlink(audio_path) except: pass logger.warning("[ASR] All transcription methods failed") 
return "" except Exception as e: logger.error(f"[ASR] Transcription error: {e}") import traceback logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}") return "" async def generate_speech_mcp(text: str) -> str: """Generate speech using MCP text_to_speech tool (fallback path).""" if not MCP_AVAILABLE: return None try: session = await get_mcp_session() if session is None: logger.warning("MCP session not available for TTS") return None tools = await get_cached_mcp_tools() tts_tool = None for tool in tools: if tool.name == "text_to_speech": tts_tool = tool logger.info(f"Found MCP text_to_speech tool: {tool.name}") break if not tts_tool: # Fallback: search for any TTS-related tool for tool in tools: tool_name_lower = tool.name.lower() if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower: tts_tool = tool logger.info(f"Found MCP TTS tool (fallback): {tool.name}") break if tts_tool: result = await session.call_tool( tts_tool.name, arguments={"text": text, "language": "en"} ) if hasattr(result, 'content') and result.content: for item in result.content: if hasattr(item, 'text'): text_result = item.text # Check if it's a signal to use local TTS if text_result == "USE_LOCAL_TTS": logger.info("MCP TTS tool indicates client-side TTS should be used") return None # Return None to trigger client-side TTS elif os.path.exists(text_result): return text_result elif hasattr(item, 'data') and item.data: with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: tmp_file.write(item.data) return tmp_file.name return None except Exception as e: logger.warning(f"MCP TTS error: {e}") return None def _generate_speech_via_mcp(text: str): """Helper to generate speech via MCP in a synchronous context.""" if not MCP_AVAILABLE: return None try: loop = asyncio.get_event_loop() if loop.is_running(): if nest_asyncio: audio_path = nest_asyncio.run(generate_speech_mcp(text)) else: logger.error("nest_asyncio not available for nested async TTS via MCP") return None else: audio_path = loop.run_until_complete(generate_speech_mcp(text)) if audio_path: logger.info("Generated speech via MCP") return audio_path except Exception as e: logger.warning(f"MCP TTS error (sync wrapper): {e}") return None @spaces.GPU(max_duration=60) def generate_speech(text: str): """Generate speech from text using local maya1 TTS model (with MCP fallback). The primary path uses the local TTS model (maya-research/maya1). MCP-based TTS is only used as a last-resort fallback if the local model is unavailable or fails. """ if not text or len(text.strip()) == 0: logger.warning("[TTS] Empty text provided") return None logger.info(f"[TTS] Generating speech for text: {text[:50]}...") if not TTS_AVAILABLE: logger.error("[TTS] TTS library not installed. Please install TTS to use voice generation.") # As a last resort, try MCP-based TTS if available return _generate_speech_via_mcp(text) if config.global_tts_model is None: logger.info("[TTS] TTS model not loaded, initializing...") initialize_tts_model() if config.global_tts_model is None: logger.error("[TTS] TTS model not available. 
@spaces.GPU(duration=60)
def generate_speech(text: str) -> Optional[str]:
    """Generate speech from text using the local maya1 TTS model (with MCP fallback).

    The primary path uses the local TTS model (maya-research/maya1). MCP-based
    TTS is only used as a last-resort fallback if the local model is
    unavailable or fails.
    """
    if not text or len(text.strip()) == 0:
        logger.warning("[TTS] Empty text provided")
        return None
    logger.info(f"[TTS] Generating speech for text: {text[:50]}...")

    if not TTS_AVAILABLE:
        logger.error("[TTS] TTS library not installed. Please install TTS to use voice generation.")
        # As a last resort, try MCP-based TTS if available
        return _generate_speech_via_mcp(text)

    if config.global_tts_model is None:
        logger.info("[TTS] TTS model not loaded, initializing...")
        initialize_tts_model()
        if config.global_tts_model is None:
            logger.error("[TTS] TTS model not available. Please check dependencies.")
            return _generate_speech_via_mcp(text)

    try:
        logger.info("[TTS] Running TTS generation...")
        wav = config.global_tts_model.tts(text)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, wav, samplerate=22050)
        logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_file.name}")
        return tmp_file.name
    except Exception as e:
        logger.error(f"[TTS] TTS error (local maya1): {e}")
        import traceback
        logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
        return _generate_speech_via_mcp(text)
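

# Minimal manual smoke test, a sketch only: "sample.wav" is a hypothetical
# path, not a file shipped with the project.
if __name__ == "__main__":
    demo_path = "sample.wav"  # hypothetical local recording
    if os.path.exists(demo_path):
        print("Transcript:", transcribe_audio(demo_path))
        speech_file = generate_speech("Hello from the audio pipeline.")
        print("Generated speech at:", speech_file)
    else:
        print(f"Place a short WAV file at {demo_path} to run this smoke test.")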