Y Phung Nguyen commited on
Commit
98c58ec
Β·
1 Parent(s): faa95c5

Upd ASR loader

Browse files
Files changed (2) hide show
  1. config.py +2 -2
  2. voice.py +93 -31
config.py CHANGED
@@ -54,8 +54,8 @@ DESCRIPTION = """
54
  <p>πŸ“„ <strong>Document RAG:</strong> Answer based on uploaded medical documents</p>
55
  <p>🌐 <strong>Web Search:</strong> Fetch knowledge from reliable online medical resources</p>
56
  <p>🌍 <strong>Multi-language:</strong> Automatic translation for non-English queries</p>
57
- <p>Tips: Customise configurations, system prompt to see the magic happens!</p>
58
- <p>Note: Case GPU aborted or MedSwin not ready, please select another model!</p>
59
  </center>
60
  """
61
  CSS = """
 
54
  <p>πŸ“„ <strong>Document RAG:</strong> Answer based on uploaded medical documents</p>
55
  <p>🌐 <strong>Web Search:</strong> Fetch knowledge from reliable online medical resources</p>
56
  <p>🌍 <strong>Multi-language:</strong> Automatic translation for non-English queries</p>
57
+ <p><strong>Tips:</strong> Customise configurations & system prompt to see the magic!</p>
58
+ <p><strong>Note:</strong> Case GPU aborted or MedSwin not ready, please try another model!</p>
59
  </center>
60
  """
61
  CSS = """
voice.py CHANGED
@@ -92,7 +92,7 @@ def transcribe_audio_whisper(audio_path: str) -> str:
92
  except Exception as e:
93
  logger.error(f"[ASR] Error initializing Whisper model: {e}")
94
  import traceback
95
- logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
96
  return ""
97
 
98
  if config.global_whisper_model is None:
@@ -106,44 +106,106 @@ def transcribe_audio_whisper(audio_path: str) -> str:
106
  logger.info("[ASR] Loading audio file...")
107
  # Load audio using torchaudio (imported from models)
108
  from models import torchaudio
 
109
  if torchaudio is None:
110
  logger.error("[ASR] torchaudio not available")
111
  return ""
112
 
113
- waveform, sample_rate = torchaudio.load(audio_path)
114
- # Resample to 16kHz if needed (Whisper expects 16kHz)
115
- if sample_rate != 16000:
116
- resampler = torchaudio.transforms.Resample(sample_rate, 16000)
117
- waveform = resampler(waveform)
118
- sample_rate = 16000
119
-
120
- logger.info("[ASR] Processing audio with Whisper...")
121
- # Process audio
122
- inputs = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt")
123
-
124
- # Move inputs to same device as model
125
- device = next(model.parameters()).device
126
- inputs = {k: v.to(device) for k, v in inputs.items()}
127
-
128
- logger.info("[ASR] Running Whisper transcription...")
129
- # Generate transcription
130
- with torch.no_grad():
131
- generated_ids = model.generate(**inputs)
132
-
133
- # Decode transcription
134
- transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
135
-
136
- if transcribed_text:
137
- logger.info(f"[ASR] βœ… Transcription successful: {transcribed_text[:100]}...")
138
- logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
139
- else:
140
- logger.warning("[ASR] Whisper returned empty transcription")
141
 
142
- return transcribed_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  except Exception as e:
144
  logger.error(f"[ASR] Whisper transcription error: {e}")
145
  import traceback
146
- logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
147
  return ""
148
 
149
  def transcribe_audio(audio):
 
92
  except Exception as e:
93
  logger.error(f"[ASR] Error initializing Whisper model: {e}")
94
  import traceback
95
+ logger.error(f"[ASR] Initialization traceback: {traceback.format_exc()}")
96
  return ""
97
 
98
  if config.global_whisper_model is None:
 
106
  logger.info("[ASR] Loading audio file...")
107
  # Load audio using torchaudio (imported from models)
108
  from models import torchaudio
109
+ import torch
110
  if torchaudio is None:
111
  logger.error("[ASR] torchaudio not available")
112
  return ""
113
 
114
+ # Check if audio file exists
115
+ if not os.path.exists(audio_path):
116
+ logger.error(f"[ASR] Audio file not found: {audio_path}")
117
+ return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ try:
120
+ waveform, sample_rate = torchaudio.load(audio_path)
121
+ logger.info(f"[ASR] Loaded audio: shape={waveform.shape}, sample_rate={sample_rate}")
122
+
123
+ # Ensure audio is mono (single channel)
124
+ if waveform.shape[0] > 1:
125
+ logger.info(f"[ASR] Converting {waveform.shape[0]}-channel audio to mono")
126
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
127
+
128
+ # Resample to 16kHz if needed (Whisper expects 16kHz)
129
+ if sample_rate != 16000:
130
+ logger.info(f"[ASR] Resampling from {sample_rate}Hz to 16000Hz")
131
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
132
+ waveform = resampler(waveform)
133
+ sample_rate = 16000
134
+
135
+ logger.info(f"[ASR] Audio ready: shape={waveform.shape}, sample_rate={sample_rate}")
136
+
137
+ logger.info("[ASR] Processing audio with Whisper processor...")
138
+ # Process audio - convert to numpy and ensure it's the right shape
139
+ audio_array = waveform.squeeze().numpy()
140
+ logger.info(f"[ASR] Audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")
141
+
142
+ # Process audio
143
+ inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
144
+ logger.info(f"[ASR] Processor inputs: {list(inputs.keys())}")
145
+
146
+ # Move inputs to same device as model
147
+ device = next(model.parameters()).device
148
+ logger.info(f"[ASR] Model device: {device}")
149
+ inputs = {k: v.to(device) for k, v in inputs.items()}
150
+
151
+ logger.info("[ASR] Running Whisper model.generate()...")
152
+ # Generate transcription with proper parameters
153
+ # Whisper expects input_features as the main parameter
154
+ if "input_features" not in inputs:
155
+ logger.error(f"[ASR] Missing input_features in processor output. Keys: {list(inputs.keys())}")
156
+ return ""
157
+
158
+ input_features = inputs["input_features"]
159
+ logger.info(f"[ASR] Input features shape: {input_features.shape}, dtype: {input_features.dtype}")
160
+
161
+ with torch.no_grad():
162
+ try:
163
+ # Whisper generate with proper parameters
164
+ generated_ids = model.generate(
165
+ input_features,
166
+ max_length=448, # Whisper default max length
167
+ num_beams=5,
168
+ language=None, # Auto-detect language
169
+ task="transcribe",
170
+ return_timestamps=False
171
+ )
172
+ logger.info(f"[ASR] Generated IDs shape: {generated_ids.shape}, dtype: {generated_ids.dtype}")
173
+ logger.info(f"[ASR] Generated IDs sample: {generated_ids[0][:20] if len(generated_ids) > 0 else 'empty'}")
174
+ except Exception as gen_error:
175
+ logger.error(f"[ASR] Error in model.generate(): {gen_error}")
176
+ import traceback
177
+ logger.error(f"[ASR] Generate traceback: {traceback.format_exc()}")
178
+ # Try simpler generation without optional parameters
179
+ logger.info("[ASR] Retrying with minimal parameters...")
180
+ try:
181
+ generated_ids = model.generate(input_features)
182
+ logger.info(f"[ASR] Retry successful, generated IDs shape: {generated_ids.shape}")
183
+ except Exception as retry_error:
184
+ logger.error(f"[ASR] Retry also failed: {retry_error}")
185
+ return ""
186
+
187
+ logger.info("[ASR] Decoding transcription...")
188
+ # Decode transcription
189
+ transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
190
+
191
+ if transcribed_text:
192
+ logger.info(f"[ASR] βœ… Transcription successful: {transcribed_text[:100]}...")
193
+ logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
194
+ else:
195
+ logger.warning("[ASR] Whisper returned empty transcription")
196
+ logger.warning(f"[ASR] Generated IDs: {generated_ids}")
197
+ logger.warning(f"[ASR] Decoded (before strip): {processor.batch_decode(generated_ids, skip_special_tokens=False)[0]}")
198
+
199
+ return transcribed_text
200
+ except Exception as audio_error:
201
+ logger.error(f"[ASR] Error processing audio file: {audio_error}")
202
+ import traceback
203
+ logger.error(f"[ASR] Audio processing traceback: {traceback.format_exc()}")
204
+ return ""
205
  except Exception as e:
206
  logger.error(f"[ASR] Whisper transcription error: {e}")
207
  import traceback
208
+ logger.error(f"[ASR] Full traceback: {traceback.format_exc()}")
209
  return ""
210
 
211
  def transcribe_audio(audio):