LiamKhoaLe commited on
Commit
f89165d
·
1 Parent(s): 20851fb

Upd ASR and TTS

Browse files
Files changed (2) hide show
  1. app.py +168 -6
  2. requirements.txt +4 -1
app.py CHANGED
@@ -36,6 +36,11 @@ from langdetect import detect, LangDetectException
36
  from duckduckgo_search import DDGS
37
  import requests
38
  from bs4 import BeautifulSoup
 
 
 
 
 
39
 
40
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
  logging.basicConfig(level=logging.INFO)
@@ -51,6 +56,8 @@ MEDSWIN_MODELS = {
51
  }
52
  DEFAULT_MEDICAL_MODEL = "MedSwin SFT"
53
  EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model
 
 
54
  HF_TOKEN = os.environ.get("HF_TOKEN")
55
  if not HF_TOKEN:
56
  raise ValueError("HF_TOKEN not found in environment variables")
@@ -161,6 +168,8 @@ global_translation_tokenizer = None
161
  global_medical_models = {}
162
  global_medical_tokenizers = {}
163
  global_file_info = {}
 
 
164
 
165
  def initialize_translation_model():
166
  """Initialize DeepSeek-R1 model for translation purposes"""
@@ -196,6 +205,82 @@ def initialize_medical_model(model_name: str):
196
  logger.info(f"Medical model {model_name} initialized successfully")
197
  return global_medical_models[model_name], global_medical_tokenizers[model_name]
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def detect_language(text: str) -> str:
200
  """Detect language of input text"""
201
  try:
@@ -964,16 +1049,27 @@ def stream_chat(
964
  if needs_translation and partial_response:
965
  logger.info(f"Translating response back to {original_lang}...")
966
  translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
967
- updated_history[-1]["content"] = translated_response
968
- yield updated_history
969
- else:
970
- yield updated_history
 
 
 
 
971
 
972
  except GeneratorExit:
973
  stop_event.set()
974
  thread.join()
975
  raise
976
 
 
 
 
 
 
 
 
977
  def create_demo():
978
  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
979
  gr.HTML(TITLE)
@@ -1011,15 +1107,75 @@ def create_demo():
1011
  type="messages"
1012
  )
1013
  with gr.Row(elem_classes="input-row"):
 
 
 
 
 
 
 
 
1014
  message_input = gr.Textbox(
1015
  placeholder="Type your medical question here...",
1016
  show_label=False,
1017
  container=False,
1018
  lines=1,
1019
- scale=8
1020
  )
1021
  submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
1022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1023
  with gr.Accordion("⚙️ Advanced Settings", open=False):
1024
  with gr.Row():
1025
  use_rag = gr.Checkbox(
@@ -1143,8 +1299,14 @@ def create_demo():
1143
  return demo
1144
 
1145
  if __name__ == "__main__":
1146
- # Initialize default medical model
 
1147
  logger.info("Initializing default medical model (MedSwin SFT)...")
1148
  initialize_medical_model(DEFAULT_MEDICAL_MODEL)
 
 
 
 
 
1149
  demo = create_demo()
1150
  demo.launch()
 
36
  from duckduckgo_search import DDGS
37
  import requests
38
  from bs4 import BeautifulSoup
39
+ import whisper
40
+ from TTS.api import TTS
41
+ import numpy as np
42
+ import soundfile as sf
43
+ import tempfile
44
 
45
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
46
  logging.basicConfig(level=logging.INFO)
 
56
  }
57
  DEFAULT_MEDICAL_MODEL = "MedSwin SFT"
58
  EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model
59
+ WHISPER_MODEL = "openai/whisper-large-v3-turbo"
60
+ TTS_MODEL = "maya-research/maya1"
61
  HF_TOKEN = os.environ.get("HF_TOKEN")
62
  if not HF_TOKEN:
63
  raise ValueError("HF_TOKEN not found in environment variables")
 
168
  global_medical_models = {}
169
  global_medical_tokenizers = {}
170
  global_file_info = {}
171
+ global_whisper_model = None
172
+ global_tts_model = None
173
 
174
  def initialize_translation_model():
175
  """Initialize DeepSeek-R1 model for translation purposes"""
 
205
  logger.info(f"Medical model {model_name} initialized successfully")
206
  return global_medical_models[model_name], global_medical_tokenizers[model_name]
207
 
208
def initialize_whisper_model():
    """Lazily initialize and cache the Whisper speech-to-text model.

    Attempts to load the "large-v3-turbo" checkpoint first; if that fails
    (e.g. download error or insufficient memory) it falls back to the much
    smaller "base" model so transcription still works.

    Returns:
        The loaded Whisper model, cached in ``global_whisper_model`` so
        subsequent calls are free.
    """
    global global_whisper_model
    if global_whisper_model is None:
        logger.info("Initializing Whisper model for speech transcription...")
        try:
            # Try loading from HuggingFace
            global_whisper_model = whisper.load_model("large-v3-turbo")
        except Exception as e:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # are not swallowed; log why we degraded to the fallback model.
            logger.warning(f"Failed to load large-v3-turbo ({e}); falling back to base model")
            global_whisper_model = whisper.load_model("base")
        logger.info("Whisper model initialized successfully")
    return global_whisper_model
221
+
222
def initialize_tts_model():
    """Lazily initialize the text-to-speech model.

    Returns:
        The ``TTS`` instance cached in ``global_tts_model``; the model is
        only constructed on the first call.
    """
    global global_tts_model
    if global_tts_model is not None:
        return global_tts_model
    logger.info("Initializing TTS model for voice generation...")
    global_tts_model = TTS(model_name=TTS_MODEL, progress_bar=False)
    logger.info("TTS model initialized successfully")
    return global_tts_model
230
+
231
def transcribe_audio(audio):
    """Transcribe recorded audio to English text using Whisper.

    Args:
        audio: Either a file path (str) as produced by a Gradio Audio
            component with ``type="filepath"``, a ``(sample_rate, data)``
            tuple (numpy format), or None.

    Returns:
        The stripped transcription text, or "" when no audio was provided
        or transcription failed (best-effort: errors are logged, not raised).
    """
    global global_whisper_model
    if global_whisper_model is None:
        initialize_whisper_model()

    if audio is None:
        return ""

    tmp_path = None  # track any temp wav we create so it is always removed
    try:
        # Handle file path (Gradio Audio component returns file path)
        if isinstance(audio, str):
            audio_path = audio
        elif isinstance(audio, tuple):
            # Handle tuple format (sample_rate, audio_data): persist to a
            # temp wav so Whisper can read it from disk.
            sample_rate, audio_data = audio
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
                tmp_path = tmp_file.name
            audio_path = tmp_path
        else:
            audio_path = audio

        # Transcribe (language pinned to English)
        result = global_whisper_model.transcribe(audio_path, language="en")
        transcribed_text = result["text"].strip()
        logger.info(f"Transcribed: {transcribed_text}")
        return transcribed_text
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        return ""
    finally:
        # Fix: the original leaked the NamedTemporaryFile(delete=False)
        # wav on every tuple-input call; clean it up unconditionally.
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
262
+
263
def generate_speech(text: str):
    """Synthesize speech for *text* with the TTS model.

    Args:
        text: Text to speak; empty/whitespace-only input is rejected.

    Returns:
        Path to a temporary ``.wav`` file holding the audio, or None when
        the input is empty or synthesis fails. The caller owns the temp
        file (Gradio serves it), so it is intentionally not deleted here.
    """
    global global_tts_model
    if global_tts_model is None:
        initialize_tts_model()

    if not text or not text.strip():
        return None

    try:
        # Generate audio samples
        wav = global_tts_model.tts(text)

        # Fix: writing at a hard-coded 22050 Hz pitch-shifts the audio when
        # the model's native rate differs; prefer the synthesizer's actual
        # output rate and keep 22050 only as a fallback.
        synthesizer = getattr(global_tts_model, "synthesizer", None)
        sample_rate = getattr(synthesizer, "output_sample_rate", None) or 22050

        # Save to temporary file for the UI to play back
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, wav, samplerate=sample_rate)
            return tmp_file.name
    except Exception as e:
        logger.error(f"TTS error: {e}")
        return None
283
+
284
  def detect_language(text: str) -> str:
285
  """Detect language of input text"""
286
  try:
 
1049
  if needs_translation and partial_response:
1050
  logger.info(f"Translating response back to {original_lang}...")
1051
  translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
1052
+ partial_response = translated_response
1053
+
1054
+ # Add speaker icon to assistant message
1055
+ speaker_icon = ' 🔊'
1056
+ partial_response_with_speaker = partial_response + speaker_icon
1057
+ updated_history[-1]["content"] = partial_response_with_speaker
1058
+
1059
+ yield updated_history
1060
 
1061
  except GeneratorExit:
1062
  stop_event.set()
1063
  thread.join()
1064
  raise
1065
 
1066
def generate_speech_for_message(text: str):
    """Generate speech for a chat message.

    Thin wrapper around :func:`generate_speech`; returns the audio file
    path, or None when synthesis was skipped or failed.
    """
    # generate_speech already yields None on empty input or failure, so the
    # original's `if audio_path: ... return None` dance is unnecessary.
    return generate_speech(text)
1072
+
1073
  def create_demo():
1074
  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
1075
  gr.HTML(TITLE)
 
1107
  type="messages"
1108
  )
1109
  with gr.Row(elem_classes="input-row"):
1110
+ with gr.Column(scale=1, min_width=50):
1111
+ mic_button = gr.Audio(
1112
+ sources=["microphone"],
1113
+ type="filepath",
1114
+ label="",
1115
+ show_label=False,
1116
+ container=False
1117
+ )
1118
  message_input = gr.Textbox(
1119
  placeholder="Type your medical question here...",
1120
  show_label=False,
1121
  container=False,
1122
  lines=1,
1123
+ scale=7
1124
  )
1125
  submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
1126
 
1127
+ # Handle microphone transcription
1128
+ def handle_transcription(audio):
1129
+ if audio is None:
1130
+ return ""
1131
+ transcribed = transcribe_audio(audio)
1132
+ return transcribed
1133
+
1134
+ mic_button.stop_recording(
1135
+ fn=handle_transcription,
1136
+ inputs=[mic_button],
1137
+ outputs=[message_input]
1138
+ )
1139
+
1140
+ # TTS component for generating speech from messages
1141
+ with gr.Row(visible=False) as tts_row:
1142
+ tts_text = gr.Textbox(visible=False)
1143
+ tts_audio = gr.Audio(label="Generated Speech", visible=False)
1144
+
1145
+ # Function to generate speech when speaker icon is clicked
1146
+ def generate_speech_from_chat(history):
1147
+ """Extract last assistant message and generate speech"""
1148
+ if not history or len(history) == 0:
1149
+ return None
1150
+ last_msg = history[-1]
1151
+ if last_msg.get("role") == "assistant":
1152
+ text = last_msg.get("content", "").replace(" 🔊", "").strip()
1153
+ if text:
1154
+ audio_path = generate_speech(text)
1155
+ return audio_path
1156
+ return None
1157
+
1158
+ # Add TTS button that appears when assistant responds
1159
+ tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
1160
+
1161
+ # Update TTS button visibility and generate speech
1162
+ def update_tts_button(history):
1163
+ if history and len(history) > 0 and history[-1].get("role") == "assistant":
1164
+ return gr.update(visible=True)
1165
+ return gr.update(visible=False)
1166
+
1167
+ chatbot.change(
1168
+ fn=update_tts_button,
1169
+ inputs=[chatbot],
1170
+ outputs=[tts_button]
1171
+ )
1172
+
1173
+ tts_button.click(
1174
+ fn=generate_speech_from_chat,
1175
+ inputs=[chatbot],
1176
+ outputs=[tts_audio]
1177
+ )
1178
+
1179
  with gr.Accordion("⚙️ Advanced Settings", open=False):
1180
  with gr.Row():
1181
  use_rag = gr.Checkbox(
 
1299
  return demo
1300
 
1301
  if __name__ == "__main__":
1302
+ # Preload models on startup
1303
+ logger.info("Preloading models on startup...")
1304
  logger.info("Initializing default medical model (MedSwin SFT)...")
1305
  initialize_medical_model(DEFAULT_MEDICAL_MODEL)
1306
+ logger.info("Preloading Whisper model...")
1307
+ initialize_whisper_model()
1308
+ logger.info("Preloading TTS model...")
1309
+ initialize_tts_model()
1310
+ logger.info("All models preloaded successfully!")
1311
  demo = create_demo()
1312
  demo.launch()
requirements.txt CHANGED
@@ -16,4 +16,7 @@ requests
16
  beautifulsoup4
17
  duckduckgo-search
18
  gradio
19
- spaces
 
 
 
 
16
  beautifulsoup4
17
  duckduckgo-search
18
  gradio
19
+ spaces
20
+ openai-whisper
21
+ TTS
22
+ soundfile