"""Audio transcription and text-to-speech functions"""
import os
import asyncio
import tempfile
import soundfile as sf
import torch
from logger import logger
from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools
import config
from models import TTS_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model
import spaces
try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None
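
# Note: nest_asyncio exposes apply(), not a run() helper. When an event loop is
# already running (as under Gradio), the sync wrappers below patch the loop with
# nest_asyncio.apply(loop) and then call loop.run_until_complete(...).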

async def transcribe_audio_gemini(audio_path: str) -> str:
    """Transcribe audio using the Gemini MCP transcribe_audio tool."""
    if not MCP_AVAILABLE:
        return ""
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for transcription")
            return ""

        tools = await get_cached_mcp_tools()
        transcribe_tool = None
        for tool in tools:
            if tool.name == "transcribe_audio":
                transcribe_tool = tool
                logger.info(f"Found MCP transcribe_audio tool: {tool.name}")
                break

        if not transcribe_tool:
            logger.warning("transcribe_audio MCP tool not found, falling back to generate_content")
            # Fall back to generate_content with the audio attached as a file
            audio_path_abs = os.path.abspath(audio_path)
            files = [{"path": audio_path_abs}]
            system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
            user_prompt = (
                "Please transcribe this audio file. Include speaker identification if "
                "multiple speakers are present, format the transcript with proper "
                "punctuation and paragraphs, omit mumbled or filler speech, and ignore "
                "non-verbal noises."
            )
            result = await call_agent(
                user_prompt=user_prompt,
                system_prompt=system_prompt,
                files=files,
                model=config.GEMINI_MODEL_LITE,
                temperature=0.2,
            )
            return result.strip()

        # Use the dedicated transcribe_audio tool
        audio_path_abs = os.path.abspath(audio_path)
        result = await session.call_tool(
            transcribe_tool.name,
            arguments={"audio_path": audio_path_abs},
        )
        if hasattr(result, "content") and result.content:
            for item in result.content:
                if hasattr(item, "text"):
                    transcribed_text = item.text.strip()
                    if transcribed_text:
                        logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...")
                        return transcribed_text

        logger.warning("MCP transcribe_audio returned empty result")
        return ""
    except Exception as e:
        logger.error(f"Gemini transcription error: {e}")
        return ""

@spaces.GPU(duration=60)
def transcribe_audio_whisper(audio_path: str) -> str:
    """Transcribe audio using the Whisper model from Hugging Face."""
    if not WHISPER_AVAILABLE:
        logger.warning("[ASR] Whisper not available for transcription")
        return ""
    try:
        logger.info(f"[ASR] Starting Whisper transcription for: {audio_path}")
        if config.global_whisper_model is None:
            logger.info("[ASR] Whisper model not loaded, initializing...")
            initialize_whisper_model()
        if config.global_whisper_model is None:
            logger.error("[ASR] Failed to initialize Whisper model")
            return ""

        # Extract processor and model from the stored dict
        processor = config.global_whisper_model["processor"]
        model = config.global_whisper_model["model"]

        logger.info("[ASR] Loading audio file...")
        # Load audio using torchaudio (imported from models)
        from models import torchaudio
        if torchaudio is None:
            logger.error("[ASR] torchaudio not available")
            return ""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix multi-channel audio to mono, then resample to the 16 kHz
        # rate Whisper expects
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        logger.info("[ASR] Processing audio with Whisper...")
        inputs = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt")
        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info("[ASR] Running Whisper transcription...")
        with torch.no_grad():
            generated_ids = model.generate(**inputs)

        # Decode the generated token IDs into text
        transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if transcribed_text:
            logger.info(f"[ASR] ✅ Transcription successful: {transcribed_text[:100]}...")
            logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
        else:
            logger.warning("[ASR] Whisper returned empty transcription")
        return transcribed_text
    except Exception as e:
        logger.error(f"[ASR] Whisper transcription error: {e}")
        import traceback
        logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""

def transcribe_audio(audio):
    """Transcribe audio to text using Whisper (primary) or Gemini MCP (fallback)."""
    if audio is None:
        logger.warning("[ASR] No audio provided")
        return ""
    try:
        # Normalize the audio input to a file path
        created_temp_file = False
        if isinstance(audio, str):
            audio_path = audio
        elif isinstance(audio, tuple):
            sample_rate, audio_data = audio
            logger.info(f"[ASR] Processing audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape if hasattr(audio_data, 'shape') else 'unknown'}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
                audio_path = tmp_file.name
            created_temp_file = True
            logger.info(f"[ASR] Created temporary audio file: {audio_path}")
        else:
            audio_path = audio

        try:
            # Try Whisper first (primary method)
            logger.info("[ASR] Attempting transcription with Whisper (primary method)...")
            if WHISPER_AVAILABLE:
                try:
                    transcribed = transcribe_audio_whisper(audio_path)
                    if transcribed:
                        logger.info(f"[ASR] ✅ Successfully transcribed via Whisper: {transcribed[:50]}...")
                        return transcribed
                    logger.warning("[ASR] Whisper transcription returned empty, trying fallback...")
                except Exception as e:
                    logger.error(f"[ASR] Whisper transcription failed: {e}, trying fallback...")
            else:
                logger.warning("[ASR] Whisper not available, trying Gemini fallback...")

            # Fall back to Gemini MCP if Whisper fails or is unavailable
            if MCP_AVAILABLE:
                try:
                    logger.info("[ASR] Attempting transcription with Gemini MCP (fallback)...")
                    loop = asyncio.get_event_loop()
                    if loop.is_running():
                        if nest_asyncio is None:
                            logger.error("[ASR] nest_asyncio not available for nested async transcription")
                            return ""
                        # Patch the running loop so it can be re-entered with
                        # run_until_complete
                        nest_asyncio.apply(loop)
                    transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
                    if transcribed:
                        logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
                        return transcribed
                except Exception as e:
                    logger.error(f"[ASR] Gemini MCP transcription error: {e}")

            logger.warning("[ASR] All transcription methods failed")
            return ""
        finally:
            # Clean up the temp file if we created it
            if created_temp_file and os.path.exists(audio_path):
                try:
                    os.unlink(audio_path)
                except OSError:
                    pass
    except Exception as e:
        logger.error(f"[ASR] Transcription error: {e}")
        import traceback
        logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""

async def generate_speech_mcp(text: str):
    """Generate speech using the MCP text_to_speech tool (fallback path).

    Returns a path to a WAV file on success, or None.
    """
    if not MCP_AVAILABLE:
        return None
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for TTS")
            return None

        tools = await get_cached_mcp_tools()
        tts_tool = None
        for tool in tools:
            if tool.name == "text_to_speech":
                tts_tool = tool
                logger.info(f"Found MCP text_to_speech tool: {tool.name}")
                break
        if not tts_tool:
            # Fallback: search for any TTS-related tool by name
            for tool in tools:
                tool_name_lower = tool.name.lower()
                if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
                    tts_tool = tool
                    logger.info(f"Found MCP TTS tool (fallback): {tool.name}")
                    break

        if tts_tool:
            result = await session.call_tool(
                tts_tool.name,
                arguments={"text": text, "language": "en"},
            )
            if hasattr(result, "content") and result.content:
                for item in result.content:
                    if hasattr(item, "text"):
                        text_result = item.text
                        # A sentinel value signals that client-side TTS should be used
                        if text_result == "USE_LOCAL_TTS":
                            logger.info("MCP TTS tool indicates client-side TTS should be used")
                            return None  # Trigger client-side TTS
                        elif os.path.exists(text_result):
                            return text_result
                    elif hasattr(item, "data") and item.data:
                        # Binary audio payload: persist it to a temp WAV file
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                            tmp_file.write(item.data)
                            return tmp_file.name
        return None
    except Exception as e:
        logger.warning(f"MCP TTS error: {e}")
        return None
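
# For context: in the MCP Python SDK, call_tool() returns a CallToolResult whose
# .content is a list of content items such as TextContent (with a .text field).
# The loop above treats a text payload as either the USE_LOCAL_TTS sentinel or a
# server-side file path, and assumes any item exposing raw .data carries WAV bytes.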

def _generate_speech_via_mcp(text: str):
    """Helper to generate speech via MCP from a synchronous context."""
    if not MCP_AVAILABLE:
        return None
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio is None:
                logger.error("nest_asyncio not available for nested async TTS via MCP")
                return None
            # Patch the running loop so it can be re-entered
            nest_asyncio.apply(loop)
        audio_path = loop.run_until_complete(generate_speech_mcp(text))
        if audio_path:
            logger.info("Generated speech via MCP")
        return audio_path
    except Exception as e:
        logger.warning(f"MCP TTS error (sync wrapper): {e}")
        return None

@spaces.GPU(duration=60)
def generate_speech(text: str):
    """Generate speech from text using the local maya1 TTS model (with MCP fallback).

    The primary path uses the local TTS model (maya-research/maya1). MCP-based
    TTS is only used as a last-resort fallback if the local model is unavailable
    or fails.
    """
    if not text or len(text.strip()) == 0:
        logger.warning("[TTS] Empty text provided")
        return None
    logger.info(f"[TTS] Generating speech for text: {text[:50]}...")

    if not TTS_AVAILABLE:
        logger.error("[TTS] TTS library not installed. Please install TTS to use voice generation.")
        # As a last resort, try MCP-based TTS if available
        return _generate_speech_via_mcp(text)

    if config.global_tts_model is None:
        logger.info("[TTS] TTS model not loaded, initializing...")
        initialize_tts_model()
    if config.global_tts_model is None:
        logger.error("[TTS] TTS model not available. Please check dependencies.")
        return _generate_speech_via_mcp(text)

    try:
        logger.info("[TTS] Running TTS generation...")
        wav = config.global_tts_model.tts(text)
        # Persist the waveform to a temp WAV file at the model's 22.05 kHz output rate
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, wav, samplerate=22050)
            logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_file.name}")
            return tmp_file.name
    except Exception as e:
        logger.error(f"[TTS] TTS error (local maya1): {e}")
        import traceback
        logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
        return _generate_speech_via_mcp(text)
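
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not an app entry point): transcribe a local
    # file and synthesize a short reply. "sample.wav" is a hypothetical path.
    if os.path.exists("sample.wav"):
        print("Transcript:", transcribe_audio("sample.wav") or "<empty>")
    print("Speech file:", generate_speech("Hello from MedLLM-Agent."))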