MedLLM-Agent / voice.py
"""Audio transcription and text-to-speech functions"""
import os
import asyncio
import tempfile
import soundfile as sf
import torch
import numpy as np
from logger import logger
from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools
import config
from models import TTS_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model
import spaces
try:
import nest_asyncio
except ImportError:
nest_asyncio = None
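# nest_asyncio is optional: when installed it lets the synchronous wrappers below re-enter an
# already-running event loop (e.g. when these functions are called from inside the app's loop).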
async def transcribe_audio_gemini(audio_path: str) -> str:
"""Transcribe audio using Gemini MCP transcribe_audio tool"""
if not MCP_AVAILABLE:
return ""
try:
session = await get_mcp_session()
if session is None:
logger.warning("MCP session not available for transcription")
return ""
tools = await get_cached_mcp_tools()
transcribe_tool = None
for tool in tools:
if tool.name == "transcribe_audio":
transcribe_tool = tool
logger.info(f"Found MCP transcribe_audio tool: {tool.name}")
break
if not transcribe_tool:
logger.warning("transcribe_audio MCP tool not found, falling back to generate_content")
# Fallback to using generate_content
audio_path_abs = os.path.abspath(audio_path)
files = [{"path": audio_path_abs}]
system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
result = await call_agent(
user_prompt=user_prompt,
system_prompt=system_prompt,
files=files,
model=config.GEMINI_MODEL_LITE,
temperature=0.2
)
return result.strip()
# Use the transcribe_audio tool
audio_path_abs = os.path.abspath(audio_path)
result = await session.call_tool(
transcribe_tool.name,
arguments={"audio_path": audio_path_abs}
)
if hasattr(result, 'content') and result.content:
for item in result.content:
if hasattr(item, 'text'):
transcribed_text = item.text.strip()
if transcribed_text:
logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...")
return transcribed_text
logger.warning("MCP transcribe_audio returned empty result")
return ""
except Exception as e:
logger.error(f"Gemini transcription error: {e}")
return ""
@spaces.GPU(max_duration=60)
def transcribe_audio_whisper(audio_path: str) -> str:
"""Transcribe audio using Whisper model from Hugging Face"""
if not WHISPER_AVAILABLE:
logger.warning("[ASR] Whisper not available for transcription")
return ""
try:
logger.info(f"[ASR] Starting Whisper transcription for: {audio_path}")
if config.global_whisper_model is None:
logger.info("[ASR] Whisper model not loaded, initializing now (on-demand)...")
try:
initialize_whisper_model()
if config.global_whisper_model is None:
logger.error("[ASR] Failed to initialize Whisper model - check logs for errors")
return ""
else:
logger.info("[ASR] ✅ Whisper model loaded successfully on-demand!")
except Exception as e:
logger.error(f"[ASR] Error initializing Whisper model: {e}")
import traceback
logger.error(f"[ASR] Initialization traceback: {traceback.format_exc()}")
return ""
if config.global_whisper_model is None:
logger.error("[ASR] Whisper model is still None after initialization attempt")
return ""
# Extract processor and model from stored dict
processor = config.global_whisper_model["processor"]
model = config.global_whisper_model["model"]
logger.info("[ASR] Loading audio file...")
# Check if audio file exists
if not os.path.exists(audio_path):
logger.error(f"[ASR] Audio file not found: {audio_path}")
return ""
try:
# Use soundfile to load audio (more reliable, doesn't require torchcodec)
logger.info(f"[ASR] Loading audio with soundfile: {audio_path}")
audio_data, sample_rate = sf.read(audio_path, dtype='float32')
logger.info(f"[ASR] Loaded audio with soundfile: shape={audio_data.shape}, sample_rate={sample_rate}, dtype={audio_data.dtype}")
# Convert to torch tensor and ensure it's 2D (channels, samples)
if len(audio_data.shape) == 1:
# Mono audio - add channel dimension
waveform = torch.from_numpy(audio_data).unsqueeze(0)
else:
# Multi-channel - transpose to (channels, samples)
waveform = torch.from_numpy(audio_data).T
logger.info(f"[ASR] Converted to tensor: shape={waveform.shape}, dtype={waveform.dtype}")
# Ensure audio is mono (single channel)
if waveform.shape[0] > 1:
logger.info(f"[ASR] Converting {waveform.shape[0]}-channel audio to mono")
waveform = torch.mean(waveform, dim=0, keepdim=True)
# Resample to 16kHz if needed (Whisper expects 16kHz)
if sample_rate != 16000:
logger.info(f"[ASR] Resampling from {sample_rate}Hz to 16000Hz")
                # Use scipy for resampling if available, otherwise fall back to simple linear interpolation
try:
from scipy import signal
# Resample using scipy
num_samples = int(len(waveform[0]) * 16000 / sample_rate)
resampled = signal.resample(waveform[0].numpy(), num_samples)
waveform = torch.from_numpy(resampled).unsqueeze(0)
sample_rate = 16000
logger.info(f"[ASR] Resampled using scipy: new shape={waveform.shape}")
except ImportError:
# Fallback: simple linear interpolation (scipy not available)
logger.info("[ASR] scipy not available, using simple linear interpolation for resampling")
num_samples = int(len(waveform[0]) * 16000 / sample_rate)
waveform_1d = waveform[0].numpy()
indices = np.linspace(0, len(waveform_1d) - 1, num_samples)
resampled = np.interp(indices, np.arange(len(waveform_1d)), waveform_1d)
waveform = torch.from_numpy(resampled).unsqueeze(0)
sample_rate = 16000
logger.info(f"[ASR] Resampled using simple interpolation: new shape={waveform.shape}")
logger.info(f"[ASR] Audio ready: shape={waveform.shape}, sample_rate={sample_rate}")
logger.info("[ASR] Processing audio with Whisper processor...")
# Process audio - convert to numpy and ensure it's the right shape
audio_array = waveform.squeeze().numpy()
logger.info(f"[ASR] Audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")
# Process audio
inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
logger.info(f"[ASR] Processor inputs: {list(inputs.keys())}")
# Move inputs to same device as model
device = next(model.parameters()).device
logger.info(f"[ASR] Model device: {device}")
inputs = {k: v.to(device) for k, v in inputs.items()}
logger.info("[ASR] Running Whisper model.generate()...")
# Generate transcription with proper parameters
# Whisper expects input_features as the main parameter
if "input_features" not in inputs:
logger.error(f"[ASR] Missing input_features in processor output. Keys: {list(inputs.keys())}")
return ""
input_features = inputs["input_features"]
logger.info(f"[ASR] Input features shape: {input_features.shape}, dtype: {input_features.dtype}")
# Convert input features to match model dtype (float16)
model_dtype = next(model.parameters()).dtype
if input_features.dtype != model_dtype:
logger.info(f"[ASR] Converting input features from {input_features.dtype} to {model_dtype} to match model")
input_features = input_features.to(dtype=model_dtype)
logger.info(f"[ASR] Converted input features dtype: {input_features.dtype}")
with torch.no_grad():
try:
# Whisper generate with proper parameters
generated_ids = model.generate(
input_features,
max_length=448, # Whisper default max length
num_beams=5,
language=None, # Auto-detect language
task="transcribe",
return_timestamps=False
)
logger.info(f"[ASR] Generated IDs shape: {generated_ids.shape}, dtype: {generated_ids.dtype}")
logger.info(f"[ASR] Generated IDs sample: {generated_ids[0][:20] if len(generated_ids) > 0 else 'empty'}")
except Exception as gen_error:
logger.error(f"[ASR] Error in model.generate(): {gen_error}")
import traceback
logger.error(f"[ASR] Generate traceback: {traceback.format_exc()}")
# Try simpler generation without optional parameters
logger.info("[ASR] Retrying with minimal parameters...")
try:
# Ensure dtype is correct for retry too
if input_features.dtype != model_dtype:
input_features = input_features.to(dtype=model_dtype)
generated_ids = model.generate(input_features)
logger.info(f"[ASR] Retry successful, generated IDs shape: {generated_ids.shape}")
except Exception as retry_error:
logger.error(f"[ASR] Retry also failed: {retry_error}")
return ""
logger.info("[ASR] Decoding transcription...")
# Decode transcription
transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
if transcribed_text:
logger.info(f"[ASR] ✅ Transcription successful: {transcribed_text[:100]}...")
logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
else:
logger.warning("[ASR] Whisper returned empty transcription")
logger.warning(f"[ASR] Generated IDs: {generated_ids}")
logger.warning(f"[ASR] Decoded (before strip): {processor.batch_decode(generated_ids, skip_special_tokens=False)[0]}")
return transcribed_text
except Exception as audio_error:
logger.error(f"[ASR] Error processing audio file: {audio_error}")
import traceback
logger.error(f"[ASR] Audio processing traceback: {traceback.format_exc()}")
return ""
except Exception as e:
logger.error(f"[ASR] Whisper transcription error: {e}")
import traceback
logger.error(f"[ASR] Full traceback: {traceback.format_exc()}")
return ""
def transcribe_audio(audio):
"""Transcribe audio to text using Whisper (primary) or Gemini MCP (fallback)"""
if audio is None:
logger.warning("[ASR] No audio provided")
return ""
try:
# Convert audio input to file path
if isinstance(audio, str):
audio_path = audio
elif isinstance(audio, tuple):
sample_rate, audio_data = audio
logger.info(f"[ASR] Processing audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape if hasattr(audio_data, 'shape') else 'unknown'}")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
audio_path = tmp_file.name
logger.info(f"[ASR] Created temporary audio file: {audio_path}")
else:
audio_path = audio
logger.info(f"[ASR] Attempting transcription with Whisper (primary method)...")
# Try Whisper first (primary method)
if WHISPER_AVAILABLE:
try:
transcribed = transcribe_audio_whisper(audio_path)
if transcribed:
logger.info(f"[ASR] ✅ Successfully transcribed via Whisper: {transcribed[:50]}...")
# Clean up temp file if we created it
if isinstance(audio, tuple) and os.path.exists(audio_path):
try:
os.unlink(audio_path)
except:
pass
return transcribed
else:
logger.warning("[ASR] Whisper transcription returned empty, trying fallback...")
except Exception as e:
logger.error(f"[ASR] Whisper transcription failed: {e}, trying fallback...")
else:
logger.warning("[ASR] Whisper not available, trying Gemini fallback...")
# Fallback to Gemini MCP if Whisper fails or is unavailable
if MCP_AVAILABLE:
try:
logger.info("[ASR] Attempting transcription with Gemini MCP (fallback)...")
loop = asyncio.get_event_loop()
if loop.is_running():
                    if nest_asyncio:
                        # nest_asyncio exposes apply(), not run(): patch the running loop so it
                        # can be re-entered, then drive the coroutine to completion on it.
                        nest_asyncio.apply()
                        transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
if transcribed:
logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
# Clean up temp file if we created it
if isinstance(audio, tuple) and os.path.exists(audio_path):
try:
os.unlink(audio_path)
except:
pass
return transcribed
else:
logger.error("[ASR] nest_asyncio not available for nested async transcription")
else:
transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
if transcribed:
logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
# Clean up temp file if we created it
if isinstance(audio, tuple) and os.path.exists(audio_path):
try:
os.unlink(audio_path)
except:
pass
return transcribed
except Exception as e:
logger.error(f"[ASR] Gemini MCP transcription error: {e}")
# Clean up temp file if we created it
if isinstance(audio, tuple) and os.path.exists(audio_path):
try:
os.unlink(audio_path)
except:
pass
logger.warning("[ASR] All transcription methods failed")
return ""
except Exception as e:
logger.error(f"[ASR] Transcription error: {e}")
import traceback
logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
return ""
async def generate_speech_mcp(text: str) -> str:
"""Generate speech using MCP text_to_speech tool (fallback path)."""
if not MCP_AVAILABLE:
return None
try:
session = await get_mcp_session()
if session is None:
logger.warning("MCP session not available for TTS")
return None
tools = await get_cached_mcp_tools()
tts_tool = None
for tool in tools:
if tool.name == "text_to_speech":
tts_tool = tool
logger.info(f"Found MCP text_to_speech tool: {tool.name}")
break
if not tts_tool:
# Fallback: search for any TTS-related tool
for tool in tools:
tool_name_lower = tool.name.lower()
if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
tts_tool = tool
logger.info(f"Found MCP TTS tool (fallback): {tool.name}")
break
if tts_tool:
result = await session.call_tool(
tts_tool.name,
arguments={"text": text, "language": "en"}
)
if hasattr(result, 'content') and result.content:
for item in result.content:
if hasattr(item, 'text'):
text_result = item.text
# Check if it's a signal to use local TTS
if text_result == "USE_LOCAL_TTS":
logger.info("MCP TTS tool indicates client-side TTS should be used")
return None # Return None to trigger client-side TTS
elif os.path.exists(text_result):
return text_result
elif hasattr(item, 'data') and item.data:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
tmp_file.write(item.data)
return tmp_file.name
return None
except Exception as e:
logger.warning(f"MCP TTS error: {e}")
return None
def _generate_speech_via_mcp(text: str):
"""Helper to generate speech via MCP in a synchronous context."""
if not MCP_AVAILABLE:
return None
try:
loop = asyncio.get_event_loop()
if loop.is_running():
            if nest_asyncio:
                # nest_asyncio exposes apply(), not run(): patch the running loop before re-entering it.
                nest_asyncio.apply()
                audio_path = loop.run_until_complete(generate_speech_mcp(text))
else:
logger.error("nest_asyncio not available for nested async TTS via MCP")
return None
else:
audio_path = loop.run_until_complete(generate_speech_mcp(text))
if audio_path:
logger.info("Generated speech via MCP")
return audio_path
except Exception as e:
logger.warning(f"MCP TTS error (sync wrapper): {e}")
return None
@spaces.GPU(max_duration=120)
def generate_speech(text: str):
"""Generate speech from text using local maya1 TTS model (with MCP fallback).
The primary path uses the local TTS model (maya-research/maya1). MCP-based
TTS is only used as a last-resort fallback if the local model is unavailable
or fails.
"""
if not text or len(text.strip()) == 0:
logger.warning("[TTS] Empty text provided")
return None
logger.info(f"[TTS] Generating speech for text: {text[:50]}...")
if not TTS_AVAILABLE:
logger.error("[TTS] TTS library not installed. Please install TTS to use voice generation.")
# As a last resort, try MCP-based TTS if available
return _generate_speech_via_mcp(text)
if config.global_tts_model is None:
logger.info("[TTS] TTS model not loaded, initializing...")
initialize_tts_model()
if config.global_tts_model is None:
logger.error("[TTS] TTS model not available. Please check dependencies.")
return _generate_speech_via_mcp(text)
try:
logger.info("[TTS] Running TTS generation...")
wav = config.global_tts_model.tts(text)
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
sf.write(tmp_file.name, wav, samplerate=22050)
logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_file.name}")
return tmp_file.name
except Exception as e:
logger.error(f"[TTS] TTS error (local maya1): {e}")
import traceback
logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
return _generate_speech_via_mcp(text)