| """Audio transcription and text-to-speech functions""" | |
| import os | |
| import asyncio | |
| import tempfile | |
| import soundfile as sf | |
| import torch | |
| import numpy as np | |
| from logger import logger | |
| from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools | |
| import config | |
| from models import TTS_AVAILABLE, SNAC_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model | |
| # Maya1 constants (from maya1 docs) | |
| CODE_START_TOKEN_ID = 128257 | |
| CODE_END_TOKEN_ID = 128258 | |
| CODE_TOKEN_OFFSET = 128266 | |
| SNAC_MIN_ID = 128266 | |
| SNAC_MAX_ID = 156937 | |
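# Note (added comment): the SNAC ID range spans SNAC_MAX_ID - SNAC_MIN_ID + 1 = 28672
# token IDs, i.e. 7 slots per frame x 4096 codes per slot, matching the 7-token frame
# layout handled by unpack_snac_from_7() below.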
SOH_ID = 128259
EOH_ID = 128260
SOA_ID = 128261
TEXT_EOT_ID = 128009
AUDIO_SAMPLE_RATE = 24000

# Default voice description for Maya1 - female, soft and bright voice
DEFAULT_VOICE_DESCRIPTION = "Realistic female voice in her 20s with an American accent. High pitch, bright timbre, conversational pacing, warm tone delivery at medium intensity, podcast domain, narrator role, friendly delivery"

# Chunking configuration
MAX_CHUNK_LENGTH = 600  # Maximum characters per chunk for TTS
MIN_CHUNK_LENGTH = 100  # Minimum characters per chunk (to avoid too many tiny chunks)

import spaces

try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None


async def transcribe_audio_gemini(audio_path: str) -> str:
    """Transcribe audio using the Gemini MCP transcribe_audio tool."""
    if not MCP_AVAILABLE:
        return ""
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for transcription")
            return ""
        tools = await get_cached_mcp_tools()
        transcribe_tool = None
        for tool in tools:
            if tool.name == "transcribe_audio":
                transcribe_tool = tool
                logger.info(f"Found MCP transcribe_audio tool: {tool.name}")
                break
        if not transcribe_tool:
            logger.warning("transcribe_audio MCP tool not found, falling back to generate_content")
            # Fallback to using generate_content
            audio_path_abs = os.path.abspath(audio_path)
            files = [{"path": audio_path_abs}]
            system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
            user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, format the transcript with proper punctuation and paragraphs, remove mumbling, and ignore non-verbal noises."
            result = await call_agent(
                user_prompt=user_prompt,
                system_prompt=system_prompt,
                files=files,
                model=config.GEMINI_MODEL_LITE,
                temperature=0.2
            )
            return result.strip()
        # Use the transcribe_audio tool
        audio_path_abs = os.path.abspath(audio_path)
        result = await session.call_tool(
            transcribe_tool.name,
            arguments={"audio_path": audio_path_abs}
        )
        if hasattr(result, 'content') and result.content:
            for item in result.content:
                if hasattr(item, 'text'):
                    transcribed_text = item.text.strip()
                    if transcribed_text:
                        logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...")
                        return transcribed_text
        logger.warning("MCP transcribe_audio returned empty result")
        return ""
    except Exception as e:
        logger.error(f"Gemini transcription error: {e}")
        return ""


def transcribe_audio_whisper(audio_path: str) -> str:
    """Transcribe audio using Whisper model from Hugging Face"""
    if not WHISPER_AVAILABLE:
        logger.warning("[ASR] Whisper not available for transcription")
        return ""
    try:
        logger.info(f"[ASR] Starting Whisper transcription for: {audio_path}")
        if config.global_whisper_model is None:
            logger.info("[ASR] Whisper model not loaded, initializing now (on-demand)...")
            try:
                initialize_whisper_model()
                if config.global_whisper_model is None:
                    logger.error("[ASR] Failed to initialize Whisper model - check logs for errors")
                    return ""
                else:
                    logger.info("[ASR] ✅ Whisper model loaded successfully on-demand!")
            except Exception as e:
                logger.error(f"[ASR] Error initializing Whisper model: {e}")
                import traceback
                logger.error(f"[ASR] Initialization traceback: {traceback.format_exc()}")
                return ""
        if config.global_whisper_model is None:
            logger.error("[ASR] Whisper model is still None after initialization attempt")
            return ""
        # Extract processor and model from stored dict
        processor = config.global_whisper_model["processor"]
        model = config.global_whisper_model["model"]
        logger.info("[ASR] Loading audio file...")
        # Check if audio file exists
        if not os.path.exists(audio_path):
            logger.error(f"[ASR] Audio file not found: {audio_path}")
            return ""
        try:
            # Use soundfile to load audio (more reliable, doesn't require torchcodec)
            logger.info(f"[ASR] Loading audio with soundfile: {audio_path}")
            audio_data, sample_rate = sf.read(audio_path, dtype='float32')
            logger.info(f"[ASR] Loaded audio with soundfile: shape={audio_data.shape}, sample_rate={sample_rate}, dtype={audio_data.dtype}")
            # Convert to torch tensor and ensure it's 2D (channels, samples)
            if len(audio_data.shape) == 1:
                # Mono audio - add channel dimension
                waveform = torch.from_numpy(audio_data).unsqueeze(0)
            else:
                # Multi-channel - transpose to (channels, samples)
                waveform = torch.from_numpy(audio_data).T
            logger.info(f"[ASR] Converted to tensor: shape={waveform.shape}, dtype={waveform.dtype}")
            # Ensure audio is mono (single channel)
            if waveform.shape[0] > 1:
                logger.info(f"[ASR] Converting {waveform.shape[0]}-channel audio to mono")
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            # Resample to 16kHz if needed (Whisper expects 16kHz)
            if sample_rate != 16000:
                logger.info(f"[ASR] Resampling from {sample_rate}Hz to 16000Hz")
                # Use scipy for resampling if available, otherwise use simple interpolation
                try:
                    from scipy import signal
                    # Resample using scipy
                    num_samples = int(len(waveform[0]) * 16000 / sample_rate)
                    resampled = signal.resample(waveform[0].numpy(), num_samples)
                    waveform = torch.from_numpy(resampled).unsqueeze(0)
                    sample_rate = 16000
                    logger.info(f"[ASR] Resampled using scipy: new shape={waveform.shape}")
                except ImportError:
                    # Fallback: simple linear interpolation (scipy not available)
                    logger.info("[ASR] scipy not available, using simple linear interpolation for resampling")
                    num_samples = int(len(waveform[0]) * 16000 / sample_rate)
                    waveform_1d = waveform[0].numpy()
                    indices = np.linspace(0, len(waveform_1d) - 1, num_samples)
                    resampled = np.interp(indices, np.arange(len(waveform_1d)), waveform_1d)
                    waveform = torch.from_numpy(resampled).unsqueeze(0)
                    sample_rate = 16000
                    logger.info(f"[ASR] Resampled using simple interpolation: new shape={waveform.shape}")
            logger.info(f"[ASR] Audio ready: shape={waveform.shape}, sample_rate={sample_rate}")
            logger.info("[ASR] Processing audio with Whisper processor...")
            # Process audio - convert to numpy and ensure it's the right shape
            audio_array = waveform.squeeze().numpy()
            logger.info(f"[ASR] Audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")
            # Process audio
            inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
            logger.info(f"[ASR] Processor inputs: {list(inputs.keys())}")
            # Move inputs to same device as model
            device = next(model.parameters()).device
            logger.info(f"[ASR] Model device: {device}")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            logger.info("[ASR] Running Whisper model.generate()...")
            # Generate transcription with proper parameters
            # Whisper expects input_features as the main parameter
            if "input_features" not in inputs:
                logger.error(f"[ASR] Missing input_features in processor output. Keys: {list(inputs.keys())}")
                return ""
            input_features = inputs["input_features"]
            logger.info(f"[ASR] Input features shape: {input_features.shape}, dtype: {input_features.dtype}")
            # Convert input features to match model dtype (float16)
            model_dtype = next(model.parameters()).dtype
            if input_features.dtype != model_dtype:
                logger.info(f"[ASR] Converting input features from {input_features.dtype} to {model_dtype} to match model")
                input_features = input_features.to(dtype=model_dtype)
                logger.info(f"[ASR] Converted input features dtype: {input_features.dtype}")
            with torch.no_grad():
                try:
                    # Whisper generate with proper parameters
                    generated_ids = model.generate(
                        input_features,
                        max_length=448,  # Whisper default max length
                        num_beams=5,
                        language=None,  # Auto-detect language
                        task="transcribe",
                        return_timestamps=False
                    )
                    logger.info(f"[ASR] Generated IDs shape: {generated_ids.shape}, dtype: {generated_ids.dtype}")
                    logger.info(f"[ASR] Generated IDs sample: {generated_ids[0][:20] if len(generated_ids) > 0 else 'empty'}")
                except Exception as gen_error:
                    logger.error(f"[ASR] Error in model.generate(): {gen_error}")
                    import traceback
                    logger.error(f"[ASR] Generate traceback: {traceback.format_exc()}")
                    # Try simpler generation without optional parameters
                    logger.info("[ASR] Retrying with minimal parameters...")
                    try:
                        # Ensure dtype is correct for retry too
                        if input_features.dtype != model_dtype:
                            input_features = input_features.to(dtype=model_dtype)
                        generated_ids = model.generate(input_features)
                        logger.info(f"[ASR] Retry successful, generated IDs shape: {generated_ids.shape}")
                    except Exception as retry_error:
                        logger.error(f"[ASR] Retry also failed: {retry_error}")
                        return ""
            logger.info("[ASR] Decoding transcription...")
            # Decode transcription
            transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            if transcribed_text:
                logger.info(f"[ASR] ✅ Transcription successful: {transcribed_text[:100]}...")
                logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
            else:
                logger.warning("[ASR] Whisper returned empty transcription")
                logger.warning(f"[ASR] Generated IDs: {generated_ids}")
                logger.warning(f"[ASR] Decoded (before strip): {processor.batch_decode(generated_ids, skip_special_tokens=False)[0]}")
            return transcribed_text
        except Exception as audio_error:
            logger.error(f"[ASR] Error processing audio file: {audio_error}")
            import traceback
            logger.error(f"[ASR] Audio processing traceback: {traceback.format_exc()}")
            return ""
    except Exception as e:
        logger.error(f"[ASR] Whisper transcription error: {e}")
        import traceback
        logger.error(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""


def transcribe_audio(audio):
    """Transcribe audio to text using Whisper (primary) or Gemini MCP (fallback)."""
    if audio is None:
        logger.warning("[ASR] No audio provided")
        return ""
    try:
        # Convert audio input to a file path
        if isinstance(audio, str):
            audio_path = audio
        elif isinstance(audio, tuple):
            sample_rate, audio_data = audio
            logger.info(f"[ASR] Processing audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape if hasattr(audio_data, 'shape') else 'unknown'}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
                audio_path = tmp_file.name
            logger.info(f"[ASR] Created temporary audio file: {audio_path}")
        else:
            audio_path = audio
        logger.info("[ASR] Attempting transcription with Whisper (primary method)...")
        # Try Whisper first (primary method)
        if WHISPER_AVAILABLE:
            try:
                transcribed = transcribe_audio_whisper(audio_path)
                if transcribed:
                    logger.info(f"[ASR] ✅ Successfully transcribed via Whisper: {transcribed[:50]}...")
                    # Clean up temp file if we created it
                    if isinstance(audio, tuple) and os.path.exists(audio_path):
                        try:
                            os.unlink(audio_path)
                        except OSError:
                            pass
                    return transcribed
                else:
                    logger.warning("[ASR] Whisper transcription returned empty, trying fallback...")
            except Exception as e:
                logger.error(f"[ASR] Whisper transcription failed: {e}, trying fallback...")
        else:
            logger.warning("[ASR] Whisper not available, trying Gemini fallback...")
        # Fallback to Gemini MCP if Whisper fails or is unavailable
        if MCP_AVAILABLE:
            try:
                logger.info("[ASR] Attempting transcription with Gemini MCP (fallback)...")
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    if nest_asyncio:
                        # nest_asyncio.apply() patches the running loop so it can be re-entered
                        nest_asyncio.apply()
                        transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
                        if transcribed:
                            logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
                            # Clean up temp file if we created it
                            if isinstance(audio, tuple) and os.path.exists(audio_path):
                                try:
                                    os.unlink(audio_path)
                                except OSError:
                                    pass
                            return transcribed
                    else:
                        logger.error("[ASR] nest_asyncio not available for nested async transcription")
                else:
                    transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
                    if transcribed:
                        logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
                        # Clean up temp file if we created it
                        if isinstance(audio, tuple) and os.path.exists(audio_path):
                            try:
                                os.unlink(audio_path)
                            except OSError:
                                pass
                        return transcribed
            except Exception as e:
                logger.error(f"[ASR] Gemini MCP transcription error: {e}")
        # Clean up temp file if we created it
        if isinstance(audio, tuple) and os.path.exists(audio_path):
            try:
                os.unlink(audio_path)
            except OSError:
                pass
        logger.warning("[ASR] All transcription methods failed")
        return ""
    except Exception as e:
        logger.error(f"[ASR] Transcription error: {e}")
        import traceback
        logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""


async def generate_speech_mcp(text: str) -> str:
    """Generate speech using MCP text_to_speech tool (fallback path)."""
    if not MCP_AVAILABLE:
        return None
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for TTS")
            return None
        tools = await get_cached_mcp_tools()
        tts_tool = None
        for tool in tools:
            if tool.name == "text_to_speech":
                tts_tool = tool
                logger.info(f"Found MCP text_to_speech tool: {tool.name}")
                break
        if not tts_tool:
            # Fallback: search for any TTS-related tool
            for tool in tools:
                tool_name_lower = tool.name.lower()
                if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
                    tts_tool = tool
                    logger.info(f"Found MCP TTS tool (fallback): {tool.name}")
                    break
        if tts_tool:
            result = await session.call_tool(
                tts_tool.name,
                arguments={"text": text, "language": "en"}
            )
            if hasattr(result, 'content') and result.content:
                for item in result.content:
                    if hasattr(item, 'text'):
                        text_result = item.text
                        # Check if it's a signal to use local TTS
                        if text_result == "USE_LOCAL_TTS":
                            logger.info("MCP TTS tool indicates client-side TTS should be used")
                            return None  # Return None to trigger client-side TTS
                        elif os.path.exists(text_result):
                            return text_result
                    elif hasattr(item, 'data') and item.data:
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                            tmp_file.write(item.data)
                            return tmp_file.name
        return None
    except Exception as e:
        logger.warning(f"MCP TTS error: {e}")
        return None


def _generate_speech_via_mcp(text: str):
    """Helper to generate speech via MCP in a synchronous context."""
    if not MCP_AVAILABLE:
        return None
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio:
                # nest_asyncio.apply() patches the running loop so it can be re-entered
                nest_asyncio.apply()
                audio_path = loop.run_until_complete(generate_speech_mcp(text))
            else:
                logger.error("nest_asyncio not available for nested async TTS via MCP")
                return None
        else:
            audio_path = loop.run_until_complete(generate_speech_mcp(text))
        if audio_path:
            logger.info("Generated speech via MCP")
        return audio_path
    except Exception as e:
        logger.warning(f"MCP TTS error (sync wrapper): {e}")
        return None


def preprocess_text_for_tts(text: str) -> str:
    """Remove titles and introductory paragraphs from text.

    Removes:
    - Short lines that are all caps or end with a colon (likely titles or section headers)
    - Short paragraphs starting with common intro phrases like "here's", "this is", etc.
    - The short paragraph that immediately follows a removed title or intro
    """
    if not text:
        return text
    lines = text.split('\n')
    processed_lines = []
    skip_next_paragraph = False
    for line in lines:
        line_stripped = line.strip()
        # Keep empty lines as paragraph separators
        if not line_stripped:
            processed_lines.append('')
            continue
        # Very short lines (less than 30 chars) are likely titles
        if len(line_stripped) < 30:
            # Skip if it's all caps or ends with a colon (likely a title or section header)
            if line_stripped.isupper() or (line_stripped.endswith(':') and len(line_stripped) < 50):
                logger.debug(f"[TTS] Skipping title: {line_stripped[:50]}")
                skip_next_paragraph = True
                continue
        # Check for introductory phrases at the start of paragraphs
        line_lower = line_stripped.lower()
        intro_phrases = [
            "here's", "here is", "this is", "this was", "let me", "let's",
            "i'll", "i will", "i'm going to", "i want to", "i'd like to",
            "in this", "in the following", "below is", "below are"
        ]
        # Check if this line starts with an intro phrase
        starts_with_intro = any(line_lower.startswith(phrase) for phrase in intro_phrases)
        # If it's a short paragraph starting with an intro phrase, skip it
        if starts_with_intro and len(line_stripped) < 150:
            logger.debug(f"[TTS] Skipping intro paragraph: {line_stripped[:80]}...")
            skip_next_paragraph = True
            continue
        # If we're skipping the next paragraph and this one is short, skip it
        if skip_next_paragraph and len(line_stripped) < 200:
            logger.debug(f"[TTS] Skipping paragraph after title/intro: {line_stripped[:80]}...")
            skip_next_paragraph = False
            continue
        skip_next_paragraph = False
        processed_lines.append(line)
    result = '\n'.join(processed_lines)
    # Clean up multiple consecutive newlines
    import re
    result = re.sub(r'\n{3,}', '\n\n', result)
    return result.strip()
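

# Illustrative example (added sketch): shows which lines preprocess_text_for_tts() is
# expected to drop - the all-caps title and the "Here's ..." intro - while the long body
# paragraph (over 200 characters) is kept.
def _example_preprocessing() -> str:
    raw = (
        "MY PODCAST EPISODE\n"
        "Here's a quick summary of what we cover today.\n"
        "The main story begins on a quiet morning, when the first measurements came in "
        "and the team realized that the sensor data did not match the forecast at all, "
        "so they spent the rest of the week re-running the calibration and comparing "
        "notes with the neighbouring station.\n"
    )
    cleaned = preprocess_text_for_tts(raw)
    logger.debug(f"[TTS] Example preprocessing result: {cleaned!r}")
    return cleaned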


def chunk_text_for_tts(text: str, max_length: int = MAX_CHUNK_LENGTH, min_length: int = MIN_CHUNK_LENGTH) -> list[str]:
    """Split text into chunks suitable for TTS generation.

    Tries to split at sentence boundaries first, then at paragraph boundaries,
    and finally at word boundaries if needed.
    """
    if len(text) <= max_length:
        return [text]
    chunks = []
    remaining = text
    while len(remaining) > max_length:
        # Try to find a good split point
        chunk = remaining[:max_length]
        # First, try to split at sentence boundary (., !, ?)
        sentence_end = max(
            chunk.rfind('. '),
            chunk.rfind('! '),
            chunk.rfind('? '),
            chunk.rfind('.\n'),
            chunk.rfind('!\n'),
            chunk.rfind('?\n')
        )
        if sentence_end > min_length:
            chunk = remaining[:sentence_end + 1]
            remaining = remaining[sentence_end + 1:].lstrip()
        else:
            # Try paragraph boundary
            para_end = chunk.rfind('\n\n')
            if para_end > min_length:
                chunk = remaining[:para_end]
                remaining = remaining[para_end:].lstrip()
            else:
                # Try word boundary
                word_end = chunk.rfind(' ')
                if word_end > min_length:
                    chunk = remaining[:word_end]
                    remaining = remaining[word_end:].lstrip()
                else:
                    # Force split at max_length
                    chunk = remaining[:max_length]
                    remaining = remaining[max_length:]
        if chunk.strip():
            chunks.append(chunk.strip())
    # Add remaining text
    if remaining.strip():
        chunks.append(remaining.strip())
    return chunks
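

# Illustrative example (added sketch): chunk_text_for_tts() keeps every chunk at or below
# MAX_CHUNK_LENGTH and prefers sentence boundaries, so the original wording is preserved
# across chunks (modulo the whitespace stripped at the split points).
def _example_chunking() -> list[str]:
    long_text = " ".join(f"Sentence number {i} ends here." for i in range(1, 60))
    chunks = chunk_text_for_tts(long_text)
    assert all(len(chunk) <= MAX_CHUNK_LENGTH for chunk in chunks)
    logger.debug(f"[TTS] Example chunking produced {len(chunks)} chunks")
    return chunks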


def build_maya1_prompt(tokenizer, description: str, text: str) -> str:
    """Build formatted prompt for Maya1.

    The description is used only for voice characteristics and should not be spoken.
    Only the text after the description tag should be synthesized.
    """
    soh_token = tokenizer.decode([SOH_ID])
    eoh_token = tokenizer.decode([EOH_ID])
    soa_token = tokenizer.decode([SOA_ID])
    sos_token = tokenizer.decode([CODE_START_TOKEN_ID])
    eot_token = tokenizer.decode([TEXT_EOT_ID])
    bos_token = tokenizer.bos_token
    # Ensure description is only metadata - add newline after description tag
    # to clearly separate it from the text to be spoken
    formatted_text = f'<description="{description}">\n{text}'
    prompt = (
        soh_token + bos_token + formatted_text + eot_token +
        eoh_token + soa_token + sos_token
    )
    # Log the prompt structure for debugging (without the actual description text)
    logger.debug(f"[TTS] Prompt structure: <description=\"...\">\\n[text to speak] (text length: {len(text)} chars)")
    return prompt
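

# Layout sketch (added for illustration): build_maya1_prompt() produces
#   <SOH><BOS><description="...">\n{text}<EOT><EOH><SOA><SOS>
# where the special tokens are decoded from the Maya1 IDs defined at the top of this
# module; only the {text} part after the description attribute is meant to be spoken.
def _example_prompt_layout(tokenizer) -> str:
    prompt = build_maya1_prompt(tokenizer, DEFAULT_VOICE_DESCRIPTION, "Hello there!")
    logger.debug(f"[TTS] Example Maya1 prompt: {prompt!r}")
    return prompt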


def unpack_snac_from_7(snac_tokens: list) -> list:
    """Unpack 7-token SNAC frames to 3 hierarchical levels."""
    if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
        snac_tokens = snac_tokens[:-1]
    frames = len(snac_tokens) // 7
    snac_tokens = snac_tokens[:frames * 7]
    if frames == 0:
        return [[], [], []]
    l1, l2, l3 = [], [], []
    for i in range(frames):
        slots = snac_tokens[i*7:(i+1)*7]
        l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
        l2.extend([
            (slots[1] - CODE_TOKEN_OFFSET) % 4096,
            (slots[4] - CODE_TOKEN_OFFSET) % 4096,
        ])
        l3.extend([
            (slots[2] - CODE_TOKEN_OFFSET) % 4096,
            (slots[3] - CODE_TOKEN_OFFSET) % 4096,
            (slots[5] - CODE_TOKEN_OFFSET) % 4096,
            (slots[6] - CODE_TOKEN_OFFSET) % 4096,
        ])
    return [l1, l2, l3]
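

# Worked example (added sketch): one 7-token frame [a, b, c, d, e, f, g] (all offset by
# CODE_TOKEN_OFFSET) unpacks to level 1 = [a], level 2 = [b, e], level 3 = [c, d, f, g],
# matching the slot order used in unpack_snac_from_7() above.
def _example_snac_unpacking() -> list:
    frame = [CODE_TOKEN_OFFSET + code for code in (10, 20, 30, 40, 50, 60, 70)]
    levels = unpack_snac_from_7(frame + [CODE_END_TOKEN_ID])
    assert levels == [[10], [20, 50], [30, 40, 60, 70]]
    return levels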


def _generate_speech_with_gpu(text: str, description: str = None):
    """Internal Maya1 TTS generation, invoked via the GPU wrapper below."""
    if config.global_tts_model is None:
        logger.info("[TTS] TTS model not loaded, initializing...")
        initialize_tts_model()
    if config.global_tts_model is None:
        logger.error("[TTS] TTS model not available. Please check dependencies.")
        return None
    # Check if it's the new Maya1 format (dictionary) or the old format
    if not isinstance(config.global_tts_model, dict):
        logger.error("[TTS] TTS model format is incorrect. Expected dictionary with model, tokenizer, snac_model.")
        return None
    try:
        model = config.global_tts_model["model"]
        tokenizer = config.global_tts_model["tokenizer"]
        snac_model = config.global_tts_model["snac_model"]
        # Use default description if not provided
        if description is None:
            description = DEFAULT_VOICE_DESCRIPTION
        logger.info("[TTS] Running Maya1 TTS generation...")
        logger.debug(f"[TTS] Voice description (metadata only, not spoken): {description[:80]}...")
        logger.debug(f"[TTS] Text to speak: {text[:100]}...")
        # Build prompt - description is metadata, only text should be spoken
        prompt = build_maya1_prompt(tokenizer, description, text)
        # Sanity check: the description should only appear inside the description attribute,
        # not in the text that will be spoken
        if description.lower() in text.lower():
            logger.warning("[TTS] Warning: Description text also appears in the text to be spoken")
        inputs = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        # Generate tokens
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1500,
                min_new_tokens=28,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_TOKEN_ID,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Extract SNAC tokens
        generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
        # Find EOS and extract SNAC codes
        eos_idx = generated_ids.index(CODE_END_TOKEN_ID) if CODE_END_TOKEN_ID in generated_ids else len(generated_ids)
        snac_tokens = [t for t in generated_ids[:eos_idx] if SNAC_MIN_ID <= t <= SNAC_MAX_ID]
        if len(snac_tokens) < 7:
            logger.error(f"[TTS] Not enough tokens generated ({len(snac_tokens)}). Try different text or increase max_tokens.")
            return None
        # Unpack and decode
        levels = unpack_snac_from_7(snac_tokens)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        codes_tensor = [torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) for level in levels]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
        # Trim warmup
        if len(audio) > 2048:
            audio = audio[2048:]
        # Convert to WAV and save to temporary file
        audio_int16 = (audio * 32767).astype(np.int16)
        # Create temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_path = tmp_file.name
        # Save audio
        sf.write(tmp_path, audio_int16, AUDIO_SAMPLE_RATE)
        duration = len(audio) / AUDIO_SAMPLE_RATE
        logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_path} ({duration:.2f}s)")
        return tmp_path
    except Exception as e:
        logger.error(f"[TTS] TTS error (local maya1): {e}")
        import traceback
        logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
        return None


# ZeroGPU entry point: spaces.GPU requests a GPU for the duration of this call
# (the spaces package is imported above).
@spaces.GPU
def _generate_speech_gpu_wrapper(text: str):
    """GPU wrapper for TTS generation - only called when TTS is available."""
    return _generate_speech_with_gpu(text)


def concatenate_audio_files(audio_files: list[str], output_path: str) -> str:
    """Concatenate multiple audio files into one.

    Args:
        audio_files: List of paths to audio files to concatenate
        output_path: Path to save the concatenated audio

    Returns:
        Path to the concatenated audio file
    """
    if not audio_files:
        return None
    if len(audio_files) == 1:
        # Just return the single file
        return audio_files[0]
    try:
        # Load all audio files
        audio_data_list = []
        sample_rate = None
        for audio_file in audio_files:
            if not os.path.exists(audio_file):
                logger.warning(f"[TTS] Audio file not found: {audio_file}, skipping...")
                continue
            data, sr = sf.read(audio_file, dtype='float32')
            if sample_rate is None:
                sample_rate = sr
            elif sample_rate != sr:
                logger.warning(f"[TTS] Sample rate mismatch: {sr} vs {sample_rate}, resampling...")
                # Resample to match first file
                try:
                    from scipy import signal
                    num_samples = int(len(data) * sample_rate / sr)
                    data = signal.resample(data, num_samples)
                except ImportError:
                    # Fallback: simple linear interpolation if scipy not available
                    logger.warning("[TTS] scipy not available, using simple interpolation for resampling")
                    num_samples = int(len(data) * sample_rate / sr)
                    indices = np.linspace(0, len(data) - 1, num_samples)
                    data = np.interp(indices, np.arange(len(data)), data)
            audio_data_list.append(data)
        if not audio_data_list:
            logger.error("[TTS] No valid audio files to concatenate")
            return None
        # Concatenate all audio
        concatenated = np.concatenate(audio_data_list)
        # Save concatenated audio
        sf.write(output_path, concatenated, sample_rate)
        logger.info(f"[TTS] Concatenated {len(audio_data_list)} audio chunks into {output_path}")
        # Clean up individual chunk files
        for audio_file in audio_files:
            try:
                if os.path.exists(audio_file) and audio_file != output_path:
                    os.unlink(audio_file)
            except Exception as e:
                logger.debug(f"[TTS] Could not delete temp file {audio_file}: {e}")
        return output_path
    except Exception as e:
        logger.error(f"[TTS] Error concatenating audio files: {e}")
        import traceback
        logger.debug(f"[TTS] Concatenation traceback: {traceback.format_exc()}")
        return None
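

# Illustrative example (added sketch): concatenates two short generated WAV files; the
# helper returns the output path and deletes the per-chunk files it merged.
def _example_concatenation() -> str:
    paths = []
    for freq in (220.0, 440.0):
        t = np.linspace(0, 0.5, int(AUDIO_SAMPLE_RATE * 0.5), endpoint=False)
        tone = (0.1 * np.sin(2 * np.pi * freq * t)).astype(np.float32)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, tone, samplerate=AUDIO_SAMPLE_RATE)
            paths.append(tmp_file.name)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out_file:
        output_path = out_file.name
    return concatenate_audio_files(paths, output_path)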


def generate_speech(text: str):
    """Generate speech from text using local maya1 TTS model (with MCP fallback).

    The primary path uses the local TTS model (maya-research/maya1). MCP-based
    TTS is only used as a last-resort fallback if the local model is unavailable
    or fails.

    This function checks TTS availability before attempting GPU allocation.
    For long texts, it automatically chunks the text and concatenates the audio.
    """
    if not text or len(text.strip()) == 0:
        logger.warning("[TTS] Empty text provided")
        return None
    # Preprocess text: remove titles and intro paragraphs
    processed_text = preprocess_text_for_tts(text)
    if not processed_text or len(processed_text.strip()) == 0:
        logger.warning("[TTS] Text is empty after preprocessing")
        return None
    logger.info(f"[TTS] Generating speech for text (original: {len(text)} chars, processed: {len(processed_text)} chars)")
    # Check TTS availability first - avoid GPU allocation if not available
    # Use SNAC_AVAILABLE for Maya1, but keep TTS_AVAILABLE check for backward compatibility
    if not SNAC_AVAILABLE:
        logger.warning("[TTS] SNAC library not installed (required for Maya1). Trying MCP fallback...")
        # Try MCP-based TTS if available (doesn't require GPU)
        audio_path = _generate_speech_via_mcp(processed_text)
        if audio_path:
            logger.info(f"[TTS] ✅ Generated via MCP fallback: {audio_path}")
            return audio_path
        else:
            logger.error("[TTS] ❌ SNAC library not installed and MCP fallback failed. Please install: pip install snac")
            return None
    # Chunk text if it's too long
    chunks = chunk_text_for_tts(processed_text)
    logger.info(f"[TTS] Split text into {len(chunks)} chunk(s)")
    if len(chunks) == 1:
        # Single chunk - process normally
        try:
            audio_path = _generate_speech_gpu_wrapper(chunks[0])
            if audio_path:
                return audio_path
            else:
                # GPU generation failed, try MCP fallback
                logger.warning("[TTS] Local TTS generation failed, trying MCP fallback...")
                return _generate_speech_via_mcp(processed_text)
        except Exception as e:
            logger.error(f"[TTS] GPU TTS generation error: {e}")
            import traceback
            logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
            # Try MCP fallback on error
            logger.info("[TTS] Attempting MCP fallback after error...")
            return _generate_speech_via_mcp(processed_text)
    else:
        # Multiple chunks - process each and concatenate
        logger.info(f"[TTS] Processing {len(chunks)} chunks...")
        audio_files = []
        for i, chunk in enumerate(chunks):
            logger.info(f"[TTS] Processing chunk {i+1}/{len(chunks)} ({len(chunk)} chars)...")
            try:
                chunk_audio = _generate_speech_gpu_wrapper(chunk)
                if chunk_audio and os.path.exists(chunk_audio):
                    audio_files.append(chunk_audio)
                    logger.info(f"[TTS] ✅ Chunk {i+1}/{len(chunks)} generated successfully")
                else:
                    logger.warning(f"[TTS] ⚠️ Chunk {i+1}/{len(chunks)} generation failed, skipping...")
            except Exception as e:
                logger.error(f"[TTS] Error generating chunk {i+1}/{len(chunks)}: {e}")
                # Continue with other chunks
        if not audio_files:
            logger.error("[TTS] ❌ All chunks failed to generate. Trying MCP fallback...")
            return _generate_speech_via_mcp(processed_text)
        # Concatenate all audio chunks
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            output_path = tmp_file.name
        final_audio = concatenate_audio_files(audio_files, output_path)
        if final_audio:
            logger.info(f"[TTS] ✅ Successfully generated and concatenated {len(audio_files)} chunks")
            return final_audio
        else:
            logger.error("[TTS] ❌ Failed to concatenate audio chunks. Trying MCP fallback...")
            return _generate_speech_via_mcp(processed_text)
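

# Manual end-to-end check (added sketch): running this module directly exercises the TTS
# pipeline with the default voice. The printed path is a temporary WAV file, or None if
# neither the local model nor the MCP fallback is available; transcription is not
# exercised here because it needs a real recording.
if __name__ == "__main__":
    demo_text = "This is a short end-to-end test of the local text to speech pipeline."
    wav_path = generate_speech(demo_text)
    print(f"Generated audio: {wav_path}")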