| """Audio transcription and text-to-speech functions""" | |
| import os | |
| import asyncio | |
| import tempfile | |
| import soundfile as sf | |
| import torch | |
| from logger import logger | |
| from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools | |
| import config | |
| from models import TTS_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model | |
| import spaces | |
| try: | |
| import nest_asyncio | |
| except ImportError: | |
| nest_asyncio = None | |
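# Note: nest_asyncio exposes apply(), not run(). The synchronous wrappers
# below therefore use the documented pattern:
#
#     nest_asyncio.apply()                       # make the running loop re-entrant
#     loop.run_until_complete(some_coroutine())  # safe even inside a live loop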
async def transcribe_audio_gemini(audio_path: str) -> str:
    """Transcribe audio using the Gemini MCP transcribe_audio tool."""
    if not MCP_AVAILABLE:
        return ""
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for transcription")
            return ""

        tools = await get_cached_mcp_tools()
        transcribe_tool = None
        for tool in tools:
            if tool.name == "transcribe_audio":
                transcribe_tool = tool
                logger.info(f"Found MCP transcribe_audio tool: {tool.name}")
                break

        if not transcribe_tool:
            logger.warning("transcribe_audio MCP tool not found, falling back to generate_content")
            # Fallback: send the audio file through generate_content instead
            audio_path_abs = os.path.abspath(audio_path)
            files = [{"path": audio_path_abs}]
            system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
            user_prompt = (
                "Please transcribe this audio file. Include speaker identification if multiple "
                "speakers are present, format the transcript with proper punctuation and "
                "paragraphs, remove mumbling, and ignore non-verbal noises."
            )
            result = await call_agent(
                user_prompt=user_prompt,
                system_prompt=system_prompt,
                files=files,
                model=config.GEMINI_MODEL_LITE,
                temperature=0.2,
            )
            return result.strip()

        # Use the dedicated transcribe_audio tool
        audio_path_abs = os.path.abspath(audio_path)
        result = await session.call_tool(
            transcribe_tool.name,
            arguments={"audio_path": audio_path_abs},
        )
        if hasattr(result, 'content') and result.content:
            for item in result.content:
                if hasattr(item, 'text'):
                    transcribed_text = item.text.strip()
                    if transcribed_text:
                        logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...")
                        return transcribed_text

        logger.warning("MCP transcribe_audio returned empty result")
        return ""
    except Exception as e:
        logger.error(f"Gemini transcription error: {e}")
        return ""
def transcribe_audio_whisper(audio_path: str) -> str:
    """Transcribe audio using the Whisper model from Hugging Face."""
    if not WHISPER_AVAILABLE:
        logger.warning("[ASR] Whisper not available for transcription")
        return ""
    try:
        logger.info(f"[ASR] Starting Whisper transcription for: {audio_path}")
        if config.global_whisper_model is None:
            logger.info("[ASR] Whisper model not loaded, initializing...")
            initialize_whisper_model()
        if config.global_whisper_model is None:
            logger.error("[ASR] Failed to initialize Whisper model")
            return ""

        # Extract processor and model from the stored dict
        processor = config.global_whisper_model["processor"]
        model = config.global_whisper_model["model"]

        logger.info("[ASR] Loading audio file...")
        # Load audio using torchaudio (imported from models)
        from models import torchaudio
        if torchaudio is None:
            logger.error("[ASR] torchaudio not available")
            return ""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix multi-channel audio to mono (Whisper expects a single channel)
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to 16 kHz if needed (Whisper expects 16 kHz input)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        logger.info("[ASR] Processing audio with Whisper...")
        inputs = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt")

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info("[ASR] Running Whisper transcription...")
        with torch.no_grad():
            generated_ids = model.generate(**inputs)

        # Decode the generated token IDs back to text
        transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        if transcribed_text:
            logger.info(f"[ASR] ✅ Transcription successful: {transcribed_text[:100]}...")
            logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
        else:
            logger.warning("[ASR] Whisper returned empty transcription")
        return transcribed_text
    except Exception as e:
        logger.error(f"[ASR] Whisper transcription error: {e}")
        import traceback
        logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""
def transcribe_audio(audio):
    """Transcribe audio to text using Whisper (primary) or Gemini MCP (fallback)."""
    if audio is None:
        logger.warning("[ASR] No audio provided")
        return ""
    try:
        # Normalize the audio input to a file path
        if isinstance(audio, str):
            audio_path = audio
        elif isinstance(audio, tuple):
            # Gradio-style (sample_rate, ndarray) input: write it to a temp WAV file
            sample_rate, audio_data = audio
            logger.info(f"[ASR] Processing audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape if hasattr(audio_data, 'shape') else 'unknown'}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
                audio_path = tmp_file.name
            logger.info(f"[ASR] Created temporary audio file: {audio_path}")
        else:
            audio_path = audio

        def _cleanup_temp_file():
            # Remove the temp file only if we created it from a tuple input
            if isinstance(audio, tuple) and os.path.exists(audio_path):
                try:
                    os.unlink(audio_path)
                except OSError:
                    pass

        logger.info("[ASR] Attempting transcription with Whisper (primary method)...")
        if WHISPER_AVAILABLE:
            try:
                transcribed = transcribe_audio_whisper(audio_path)
                if transcribed:
                    logger.info(f"[ASR] ✅ Successfully transcribed via Whisper: {transcribed[:50]}...")
                    _cleanup_temp_file()
                    return transcribed
                else:
                    logger.warning("[ASR] Whisper transcription returned empty, trying fallback...")
            except Exception as e:
                logger.error(f"[ASR] Whisper transcription failed: {e}, trying fallback...")
        else:
            logger.warning("[ASR] Whisper not available, trying Gemini fallback...")

        # Fall back to Gemini MCP if Whisper fails or is unavailable
        if MCP_AVAILABLE:
            try:
                logger.info("[ASR] Attempting transcription with Gemini MCP (fallback)...")
                transcribed = ""
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    if nest_asyncio:
                        # nest_asyncio has no run(); apply() patches the loop so
                        # run_until_complete can be nested inside a running loop
                        nest_asyncio.apply()
                        transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
                    else:
                        logger.error("[ASR] nest_asyncio not available for nested async transcription")
                else:
                    transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
                if transcribed:
                    logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
                    _cleanup_temp_file()
                    return transcribed
            except Exception as e:
                logger.error(f"[ASR] Gemini MCP transcription error: {e}")

        _cleanup_temp_file()
        logger.warning("[ASR] All transcription methods failed")
        return ""
    except Exception as e:
        logger.error(f"[ASR] Transcription error: {e}")
        import traceback
        logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""
async def generate_speech_mcp(text: str) -> str | None:
    """Generate speech using the MCP text_to_speech tool (fallback path)."""
    if not MCP_AVAILABLE:
        return None
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for TTS")
            return None

        tools = await get_cached_mcp_tools()
        tts_tool = None
        for tool in tools:
            if tool.name == "text_to_speech":
                tts_tool = tool
                logger.info(f"Found MCP text_to_speech tool: {tool.name}")
                break
        if not tts_tool:
            # Fallback: search for any TTS-related tool by name
            for tool in tools:
                tool_name_lower = tool.name.lower()
                if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
                    tts_tool = tool
                    logger.info(f"Found MCP TTS tool (fallback): {tool.name}")
                    break

        if tts_tool:
            result = await session.call_tool(
                tts_tool.name,
                arguments={"text": text, "language": "en"},
            )
            if hasattr(result, 'content') and result.content:
                for item in result.content:
                    if hasattr(item, 'text'):
                        text_result = item.text
                        # Sentinel value signaling that local TTS should be used instead
                        if text_result == "USE_LOCAL_TTS":
                            logger.info("MCP TTS tool indicates client-side TTS should be used")
                            return None  # Return None to trigger client-side TTS
                        elif os.path.exists(text_result):
                            return text_result
                    elif hasattr(item, 'data') and item.data:
                        # Binary audio payload: persist it to a temp WAV file
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                            tmp_file.write(item.data)
                            return tmp_file.name
        return None
    except Exception as e:
        logger.warning(f"MCP TTS error: {e}")
        return None
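# Example usage from an async context (assumes the MCP server exposes a TTS
# tool; a None result means client-side/local TTS should be used instead):
#
#     wav_path = await generate_speech_mcp("Hello from the fallback path")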
def _generate_speech_via_mcp(text: str):
    """Helper to generate speech via MCP from a synchronous context."""
    if not MCP_AVAILABLE:
        return None
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio:
                # nest_asyncio has no run(); apply() makes the running loop re-entrant
                nest_asyncio.apply()
                audio_path = loop.run_until_complete(generate_speech_mcp(text))
            else:
                logger.error("nest_asyncio not available for nested async TTS via MCP")
                return None
        else:
            audio_path = loop.run_until_complete(generate_speech_mcp(text))
        if audio_path:
            logger.info("Generated speech via MCP")
            return audio_path
    except Exception as e:
        logger.warning(f"MCP TTS error (sync wrapper): {e}")
    return None
def generate_speech(text: str):
    """Generate speech from text using the local maya1 TTS model (with MCP fallback).

    The primary path uses the local TTS model (maya-research/maya1). MCP-based
    TTS is only used as a last-resort fallback if the local model is unavailable
    or fails.
    """
    if not text or len(text.strip()) == 0:
        logger.warning("[TTS] Empty text provided")
        return None
    logger.info(f"[TTS] Generating speech for text: {text[:50]}...")

    if not TTS_AVAILABLE:
        logger.error("[TTS] TTS library not installed. Please install TTS to use voice generation.")
        # As a last resort, try MCP-based TTS if available
        return _generate_speech_via_mcp(text)

    if config.global_tts_model is None:
        logger.info("[TTS] TTS model not loaded, initializing...")
        initialize_tts_model()
    if config.global_tts_model is None:
        logger.error("[TTS] TTS model not available. Please check dependencies.")
        return _generate_speech_via_mcp(text)

    try:
        logger.info("[TTS] Running TTS generation...")
        wav = config.global_tts_model.tts(text)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, wav, samplerate=22050)
            logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_file.name}")
            return tmp_file.name
    except Exception as e:
        logger.error(f"[TTS] TTS error (local maya1): {e}")
        import traceback
        logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
        return _generate_speech_via_mcp(text)
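

if __name__ == "__main__":
    # Minimal smoke test (a sketch: the sample string is arbitrary, and either
    # the local maya1 model or an MCP TTS tool must be available for it to pass).
    sample = "Hello, this is a short synthesis test."
    wav_path = generate_speech(sample)
    if wav_path:
        print(f"Generated speech at: {wav_path}")
        print(f"Round-trip transcription: {transcribe_audio(wav_path)!r}")
    else:
        print("Speech generation failed (no local TTS model and no MCP fallback).")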