Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
f89165d
1
Parent(s):
20851fb
Upd ASR and TTS
Browse files- app.py +168 -6
- requirements.txt +4 -1
app.py
CHANGED
|
@@ -36,6 +36,11 @@ from langdetect import detect, LangDetectException
|
|
| 36 |
from duckduckgo_search import DDGS
|
| 37 |
import requests
|
| 38 |
from bs4 import BeautifulSoup
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 41 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -51,6 +56,8 @@ MEDSWIN_MODELS = {
|
|
| 51 |
}
|
| 52 |
DEFAULT_MEDICAL_MODEL = "MedSwin SFT"
|
| 53 |
EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model
|
|
|
|
|
|
|
| 54 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 55 |
if not HF_TOKEN:
|
| 56 |
raise ValueError("HF_TOKEN not found in environment variables")
|
|
@@ -161,6 +168,8 @@ global_translation_tokenizer = None
|
|
| 161 |
global_medical_models = {}
|
| 162 |
global_medical_tokenizers = {}
|
| 163 |
global_file_info = {}
|
|
|
|
|
|
|
| 164 |
|
| 165 |
def initialize_translation_model():
|
| 166 |
"""Initialize DeepSeek-R1 model for translation purposes"""
|
|
@@ -196,6 +205,82 @@ def initialize_medical_model(model_name: str):
|
|
| 196 |
logger.info(f"Medical model {model_name} initialized successfully")
|
| 197 |
return global_medical_models[model_name], global_medical_tokenizers[model_name]
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
def detect_language(text: str) -> str:
|
| 200 |
"""Detect language of input text"""
|
| 201 |
try:
|
|
@@ -964,16 +1049,27 @@ def stream_chat(
|
|
| 964 |
if needs_translation and partial_response:
|
| 965 |
logger.info(f"Translating response back to {original_lang}...")
|
| 966 |
translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
|
| 972 |
except GeneratorExit:
|
| 973 |
stop_event.set()
|
| 974 |
thread.join()
|
| 975 |
raise
|
| 976 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 977 |
def create_demo():
|
| 978 |
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
|
| 979 |
gr.HTML(TITLE)
|
|
@@ -1011,15 +1107,75 @@ def create_demo():
|
|
| 1011 |
type="messages"
|
| 1012 |
)
|
| 1013 |
with gr.Row(elem_classes="input-row"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1014 |
message_input = gr.Textbox(
|
| 1015 |
placeholder="Type your medical question here...",
|
| 1016 |
show_label=False,
|
| 1017 |
container=False,
|
| 1018 |
lines=1,
|
| 1019 |
-
scale=
|
| 1020 |
)
|
| 1021 |
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
|
| 1022 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1023 |
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 1024 |
with gr.Row():
|
| 1025 |
use_rag = gr.Checkbox(
|
|
@@ -1143,8 +1299,14 @@ def create_demo():
|
|
| 1143 |
return demo
|
| 1144 |
|
| 1145 |
if __name__ == "__main__":
|
| 1146 |
-
#
|
|
|
|
| 1147 |
logger.info("Initializing default medical model (MedSwin SFT)...")
|
| 1148 |
initialize_medical_model(DEFAULT_MEDICAL_MODEL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1149 |
demo = create_demo()
|
| 1150 |
demo.launch()
|
|
|
|
| 36 |
from duckduckgo_search import DDGS
|
| 37 |
import requests
|
| 38 |
from bs4 import BeautifulSoup
|
| 39 |
+
import whisper
|
| 40 |
+
from TTS.api import TTS
|
| 41 |
+
import numpy as np
|
| 42 |
+
import soundfile as sf
|
| 43 |
+
import tempfile
|
| 44 |
|
| 45 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 46 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 56 |
}
|
| 57 |
DEFAULT_MEDICAL_MODEL = "MedSwin SFT"
|
| 58 |
EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model
|
| 59 |
+
WHISPER_MODEL = "openai/whisper-large-v3-turbo"
|
| 60 |
+
TTS_MODEL = "maya-research/maya1"
|
| 61 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 62 |
if not HF_TOKEN:
|
| 63 |
raise ValueError("HF_TOKEN not found in environment variables")
|
|
|
|
| 168 |
global_medical_models = {}
|
| 169 |
global_medical_tokenizers = {}
|
| 170 |
global_file_info = {}
|
| 171 |
+
global_whisper_model = None
|
| 172 |
+
global_tts_model = None
|
| 173 |
|
| 174 |
def initialize_translation_model():
|
| 175 |
"""Initialize DeepSeek-R1 model for translation purposes"""
|
|
|
|
| 205 |
logger.info(f"Medical model {model_name} initialized successfully")
|
| 206 |
return global_medical_models[model_name], global_medical_tokenizers[model_name]
|
| 207 |
|
| 208 |
+
def initialize_whisper_model():
|
| 209 |
+
"""Initialize Whisper model for speech-to-text"""
|
| 210 |
+
global global_whisper_model
|
| 211 |
+
if global_whisper_model is None:
|
| 212 |
+
logger.info("Initializing Whisper model for speech transcription...")
|
| 213 |
+
try:
|
| 214 |
+
# Try loading from HuggingFace
|
| 215 |
+
global_whisper_model = whisper.load_model("large-v3-turbo")
|
| 216 |
+
except:
|
| 217 |
+
# Fallback to base model
|
| 218 |
+
global_whisper_model = whisper.load_model("base")
|
| 219 |
+
logger.info("Whisper model initialized successfully")
|
| 220 |
+
return global_whisper_model
|
| 221 |
+
|
| 222 |
+
def initialize_tts_model():
|
| 223 |
+
"""Initialize TTS model for text-to-speech"""
|
| 224 |
+
global global_tts_model
|
| 225 |
+
if global_tts_model is None:
|
| 226 |
+
logger.info("Initializing TTS model for voice generation...")
|
| 227 |
+
global_tts_model = TTS(model_name=TTS_MODEL, progress_bar=False)
|
| 228 |
+
logger.info("TTS model initialized successfully")
|
| 229 |
+
return global_tts_model
|
| 230 |
+
|
| 231 |
+
def transcribe_audio(audio):
|
| 232 |
+
"""Transcribe audio to text using Whisper"""
|
| 233 |
+
global global_whisper_model
|
| 234 |
+
if global_whisper_model is None:
|
| 235 |
+
initialize_whisper_model()
|
| 236 |
+
|
| 237 |
+
if audio is None:
|
| 238 |
+
return ""
|
| 239 |
+
|
| 240 |
+
try:
|
| 241 |
+
# Handle file path (Gradio Audio component returns file path)
|
| 242 |
+
if isinstance(audio, str):
|
| 243 |
+
audio_path = audio
|
| 244 |
+
elif isinstance(audio, tuple):
|
| 245 |
+
# Handle tuple format (sample_rate, audio_data)
|
| 246 |
+
sample_rate, audio_data = audio
|
| 247 |
+
# Save to temp file
|
| 248 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
| 249 |
+
sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
|
| 250 |
+
audio_path = tmp_file.name
|
| 251 |
+
else:
|
| 252 |
+
audio_path = audio
|
| 253 |
+
|
| 254 |
+
# Transcribe
|
| 255 |
+
result = global_whisper_model.transcribe(audio_path, language="en")
|
| 256 |
+
transcribed_text = result["text"].strip()
|
| 257 |
+
logger.info(f"Transcribed: {transcribed_text}")
|
| 258 |
+
return transcribed_text
|
| 259 |
+
except Exception as e:
|
| 260 |
+
logger.error(f"Transcription error: {e}")
|
| 261 |
+
return ""
|
| 262 |
+
|
| 263 |
+
def generate_speech(text: str):
|
| 264 |
+
"""Generate speech from text using TTS model"""
|
| 265 |
+
global global_tts_model
|
| 266 |
+
if global_tts_model is None:
|
| 267 |
+
initialize_tts_model()
|
| 268 |
+
|
| 269 |
+
if not text or len(text.strip()) == 0:
|
| 270 |
+
return None
|
| 271 |
+
|
| 272 |
+
try:
|
| 273 |
+
# Generate audio
|
| 274 |
+
wav = global_tts_model.tts(text)
|
| 275 |
+
|
| 276 |
+
# Save to temporary file
|
| 277 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
| 278 |
+
sf.write(tmp_file.name, wav, samplerate=22050)
|
| 279 |
+
return tmp_file.name
|
| 280 |
+
except Exception as e:
|
| 281 |
+
logger.error(f"TTS error: {e}")
|
| 282 |
+
return None
|
| 283 |
+
|
| 284 |
def detect_language(text: str) -> str:
|
| 285 |
"""Detect language of input text"""
|
| 286 |
try:
|
|
|
|
| 1049 |
if needs_translation and partial_response:
|
| 1050 |
logger.info(f"Translating response back to {original_lang}...")
|
| 1051 |
translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
|
| 1052 |
+
partial_response = translated_response
|
| 1053 |
+
|
| 1054 |
+
# Add speaker icon to assistant message
|
| 1055 |
+
speaker_icon = ' 🔊'
|
| 1056 |
+
partial_response_with_speaker = partial_response + speaker_icon
|
| 1057 |
+
updated_history[-1]["content"] = partial_response_with_speaker
|
| 1058 |
+
|
| 1059 |
+
yield updated_history
|
| 1060 |
|
| 1061 |
except GeneratorExit:
|
| 1062 |
stop_event.set()
|
| 1063 |
thread.join()
|
| 1064 |
raise
|
| 1065 |
|
| 1066 |
+
def generate_speech_for_message(text: str):
|
| 1067 |
+
"""Generate speech for a message and return audio file"""
|
| 1068 |
+
audio_path = generate_speech(text)
|
| 1069 |
+
if audio_path:
|
| 1070 |
+
return audio_path
|
| 1071 |
+
return None
|
| 1072 |
+
|
| 1073 |
def create_demo():
|
| 1074 |
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
|
| 1075 |
gr.HTML(TITLE)
|
|
|
|
| 1107 |
type="messages"
|
| 1108 |
)
|
| 1109 |
with gr.Row(elem_classes="input-row"):
|
| 1110 |
+
with gr.Column(scale=1, min_width=50):
|
| 1111 |
+
mic_button = gr.Audio(
|
| 1112 |
+
sources=["microphone"],
|
| 1113 |
+
type="filepath",
|
| 1114 |
+
label="",
|
| 1115 |
+
show_label=False,
|
| 1116 |
+
container=False
|
| 1117 |
+
)
|
| 1118 |
message_input = gr.Textbox(
|
| 1119 |
placeholder="Type your medical question here...",
|
| 1120 |
show_label=False,
|
| 1121 |
container=False,
|
| 1122 |
lines=1,
|
| 1123 |
+
scale=7
|
| 1124 |
)
|
| 1125 |
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
|
| 1126 |
|
| 1127 |
+
# Handle microphone transcription
|
| 1128 |
+
def handle_transcription(audio):
|
| 1129 |
+
if audio is None:
|
| 1130 |
+
return ""
|
| 1131 |
+
transcribed = transcribe_audio(audio)
|
| 1132 |
+
return transcribed
|
| 1133 |
+
|
| 1134 |
+
mic_button.stop_recording(
|
| 1135 |
+
fn=handle_transcription,
|
| 1136 |
+
inputs=[mic_button],
|
| 1137 |
+
outputs=[message_input]
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
# TTS component for generating speech from messages
|
| 1141 |
+
with gr.Row(visible=False) as tts_row:
|
| 1142 |
+
tts_text = gr.Textbox(visible=False)
|
| 1143 |
+
tts_audio = gr.Audio(label="Generated Speech", visible=False)
|
| 1144 |
+
|
| 1145 |
+
# Function to generate speech when speaker icon is clicked
|
| 1146 |
+
def generate_speech_from_chat(history):
|
| 1147 |
+
"""Extract last assistant message and generate speech"""
|
| 1148 |
+
if not history or len(history) == 0:
|
| 1149 |
+
return None
|
| 1150 |
+
last_msg = history[-1]
|
| 1151 |
+
if last_msg.get("role") == "assistant":
|
| 1152 |
+
text = last_msg.get("content", "").replace(" 🔊", "").strip()
|
| 1153 |
+
if text:
|
| 1154 |
+
audio_path = generate_speech(text)
|
| 1155 |
+
return audio_path
|
| 1156 |
+
return None
|
| 1157 |
+
|
| 1158 |
+
# Add TTS button that appears when assistant responds
|
| 1159 |
+
tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
|
| 1160 |
+
|
| 1161 |
+
# Update TTS button visibility and generate speech
|
| 1162 |
+
def update_tts_button(history):
|
| 1163 |
+
if history and len(history) > 0 and history[-1].get("role") == "assistant":
|
| 1164 |
+
return gr.update(visible=True)
|
| 1165 |
+
return gr.update(visible=False)
|
| 1166 |
+
|
| 1167 |
+
chatbot.change(
|
| 1168 |
+
fn=update_tts_button,
|
| 1169 |
+
inputs=[chatbot],
|
| 1170 |
+
outputs=[tts_button]
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
tts_button.click(
|
| 1174 |
+
fn=generate_speech_from_chat,
|
| 1175 |
+
inputs=[chatbot],
|
| 1176 |
+
outputs=[tts_audio]
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 1180 |
with gr.Row():
|
| 1181 |
use_rag = gr.Checkbox(
|
|
|
|
| 1299 |
return demo
|
| 1300 |
|
| 1301 |
if __name__ == "__main__":
|
| 1302 |
+
# Preload models on startup
|
| 1303 |
+
logger.info("Preloading models on startup...")
|
| 1304 |
logger.info("Initializing default medical model (MedSwin SFT)...")
|
| 1305 |
initialize_medical_model(DEFAULT_MEDICAL_MODEL)
|
| 1306 |
+
logger.info("Preloading Whisper model...")
|
| 1307 |
+
initialize_whisper_model()
|
| 1308 |
+
logger.info("Preloading TTS model...")
|
| 1309 |
+
initialize_tts_model()
|
| 1310 |
+
logger.info("All models preloaded successfully!")
|
| 1311 |
demo = create_demo()
|
| 1312 |
demo.launch()
|
requirements.txt
CHANGED
|
@@ -16,4 +16,7 @@ requests
|
|
| 16 |
beautifulsoup4
|
| 17 |
duckduckgo-search
|
| 18 |
gradio
|
| 19 |
-
spaces
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
beautifulsoup4
|
| 17 |
duckduckgo-search
|
| 18 |
gradio
|
| 19 |
+
spaces
|
| 20 |
+
openai-whisper
|
| 21 |
+
TTS
|
| 22 |
+
soundfile
|