gpaasch committed on
Commit
3b5fe24
·
1 Parent(s): d364129

app.py wrapper for Gradio was a very bad idea; reorganizing project for clarity, utils folder will be very important for separation of concerns

Files changed (4)
  1. app.py +296 -2
  2. src/app.py +0 -625
  3. utils/model_configuration_utils.py +126 -0
  4. utils/voice_input_utils.py +193 -0
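
The reorganization moves model setup and voice handling into utils/. A minimal sketch of how the new root app.py (shown in the diff below) is meant to consume those modules, using its own mc/viu aliases; all names come from this commit:

    from utils import model_configuration_utils as mc   # hardware probe + GGUF download helpers
    from utils import voice_input_utils as viu           # Whisper ASR pipeline + chat formatting helpers

    MODEL_NAME, REPO_ID = mc.select_best_model()   # choose a model tier from detected RAM / GPU VRAM
    model_path = mc.ensure_model()                 # download the GGUF file, or reuse the cached copy
    asr = viu.get_asr_pipeline()                   # lazily construct the Whisper ASR pipeline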
app.py CHANGED
@@ -1,5 +1,299 @@
1
- # app.py at repo root
2
- from src.app import demo
3
 
4
  if __name__ == "__main__":
5
  demo.launch(
 
1
+ from huggingface_hub import hf_hub_download
2
+ import gradio as gr
3
+ from llama_index.core import Settings
4
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
5
+ from llama_index.llms.llama_cpp import LlamaCPP
6
+ from src.parse_tabular import create_symptom_index
7
+ from utils import model_configuration_utils as mc
8
+ from utils import voice_input_utils as viu
9
+ import json
10
+ import torch
11
+ import torchaudio.transforms as T
12
+
13
+ # Set up model paths
14
+ MODEL_NAME, REPO_ID = mc.select_best_model()
15
+
16
+ # Ensure model is downloaded
17
+ model_path = mc.ensure_model()
18
+
19
+ # Configure local LLM with LlamaCPP
20
+ print("\nInitializing LLM...")
21
+ llm = LlamaCPP(
22
+ model_path=model_path,
23
+ temperature=0.7,
24
+ max_new_tokens=256,
25
+ context_window=2048,
26
+ verbose=False # Reduce logging
27
+ # n_batch and n_threads are not valid parameters for LlamaCPP and should not be used.
28
+ # If you encounter segmentation faults, try reducing context_window or check your system resources.
29
+ )
30
+ print("LLM initialized successfully")
31
+
32
+ # Configure global settings
33
+ print("\nConfiguring settings...")
34
+ Settings.llm = llm
35
+ Settings.embed_model = HuggingFaceEmbedding(
36
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
37
+ )
38
+ print("Settings configured")
39
+
40
+ # Create the index at startup
41
+ print("\nCreating symptom index...")
42
+ symptom_index = create_symptom_index()
43
+ print("Index created successfully")
44
+ print("Loaded symptom_index:", type(symptom_index))
45
+
46
+ # --- System prompt ---
47
+ SYSTEM_PROMPT = """
48
+ You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
49
+ At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
50
+ or, if you have enough info, output a final JSON with fields:
51
+ {"diagnoses":[…], "confidences":[…]}.
52
+ """
53
+
54
+ # Build enhanced Gradio interface
55
+ with gr.Blocks(theme="default") as demo:
56
+ gr.Markdown("""
57
+ # 🏥 Medical Symptom to ICD-10 Code Assistant
58
+
59
+ ## About
60
+ This application is part of the Agents+MCP Hackathon. It helps medical professionals
61
+ and patients understand potential diagnoses based on described symptoms.
62
+
63
+ ### How it works:
64
+ 1. Either click the record button and describe your symptoms or type them into the textbox
65
+ 2. The AI will analyze your description and suggest possible diagnoses
66
+ 3. Answer follow-up questions to refine the diagnosis
67
+ """)
68
+
69
+ with gr.Row():
70
+ with gr.Column(scale=2):
71
+ # Add text input above microphone
72
+ with gr.Row():
73
+ text_input = gr.Textbox(
74
+ label="Type your symptoms",
75
+ placeholder="Or type your symptoms here...",
76
+ lines=3
77
+ )
78
+ submit_btn = gr.Button("Submit", variant="primary")
79
+
80
+ # Existing microphone row
81
+ with gr.Row():
82
+ microphone = gr.Audio(
83
+ sources=["microphone"],
84
+ streaming=True,
85
+ type="numpy",
86
+ label="Describe your symptoms"
87
+ )
88
+ transcript_box = gr.Textbox(
89
+ label="Transcribed Text",
90
+ interactive=False,
91
+ show_label=True
92
+ )
93
+ clear_btn = gr.Button("Clear Chat", variant="secondary")
94
+
95
+ chatbot = gr.Chatbot(
96
+ label="Medical Consultation",
97
+ height=500,
98
+ container=True,
99
+ type="messages" # This is now properly supported by our message format
100
+ )
101
+
102
+ with gr.Column(scale=1):
103
+ with gr.Accordion("Enter an API Key to give it more power!", open=False):
104
+ api_key = gr.Textbox(
105
+ label="OpenAI API Key (optional)",
106
+ type="password",
107
+ placeholder="sk-..."
108
+ )
109
+
110
+ with gr.Row():
111
+ with gr.Column():
112
+ modal_key = gr.Textbox(
113
+ label="Modal Labs API Key",
114
+ type="password",
115
+ placeholder="mk-..."
116
+ )
117
+ anthropic_key = gr.Textbox(
118
+ label="Anthropic API Key",
119
+ type="password",
120
+ placeholder="sk-ant-..."
121
+ )
122
+ mistral_key = gr.Textbox(
123
+ label="MistralAI API Key",
124
+ type="password",
125
+ placeholder="..."
126
+ )
127
+
128
+ with gr.Column():
129
+ nebius_key = gr.Textbox(
130
+ label="Nebius API Key",
131
+ type="password",
132
+ placeholder="..."
133
+ )
134
+ hyperbolic_key = gr.Textbox(
135
+ label="Hyperbolic Labs API Key",
136
+ type="password",
137
+ placeholder="hyp-..."
138
+ )
139
+ sambanova_key = gr.Textbox(
140
+ label="SambaNova API Key",
141
+ type="password",
142
+ placeholder="..."
143
+ )
144
+
145
+ with gr.Row():
146
+ model_selector = gr.Dropdown(
147
+ choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
148
+ value="OpenAI",
149
+ label="Model Provider"
150
+ )
151
+ temperature = gr.Slider(
152
+ minimum=0,
153
+ maximum=1,
154
+ value=0.7,
155
+ label="Temperature"
156
+ )
157
+ # self promotion at bottom of page
158
+ gr.Markdown("""
159
+ ---
160
+ ### 👋 About the Creator
161
+
162
+ Hi! I'm Graham Paasch, an experienced technology professional!
163
+
164
+ 🎥 **Check out my YouTube channel** for more tech content:
165
+ [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)
166
+
167
+ 💼 **Looking for a skilled developer?**
168
+ I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)
169
+
170
+ ⭐ If you found this tool helpful, please consider:
171
+ - Subscribing to my YouTube channel
172
+ - Connecting on LinkedIn
173
+ - Sharing this tool with others in healthcare tech
174
+ """)
175
+
176
+ # Event handlers
177
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
178
+
179
+ microphone.stream(
180
+ # enhanced_process_speech takes symptom_index as its second argument, so bind it here
+ fn=lambda audio, history, key, provider, temp: viu.enhanced_process_speech(audio, symptom_index, history, key, provider, temp),
181
+ inputs=[microphone, chatbot, api_key, model_selector, temperature],
182
+ outputs=chatbot,
183
+ show_progress="hidden",
184
+ api_name=False,
185
+ queue=True # Enable queuing for better stream handling
186
+ )
187
+
188
+ def process_audio(audio_array, sample_rate):
189
+ """Pre-process audio for Whisper."""
190
+ if audio_array.ndim > 1:
191
+ audio_array = audio_array.mean(axis=1)
192
+
193
+ # Convert to tensor for resampling
194
+ audio_tensor = torch.FloatTensor(audio_array)
195
+
196
+ # Resample to 16kHz if needed
197
+ if sample_rate != 16000:
198
+ resampler = T.Resample(sample_rate, 16000)
199
+ audio_tensor = resampler(audio_tensor)
200
+
201
+ # Normalize
202
+ audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
203
+
204
+ # Convert back to numpy array and return in correct format
205
+ return {
206
+ "raw": audio_tensor.numpy(), # Key must be "raw"
207
+ "sampling_rate": 16000 # Key must be "sampling_rate"
208
+ }
209
+
210
+ # Update transcription handler
211
+ def update_live_transcription(audio):
212
+ """Real-time transcription updates."""
213
+ if not audio or not isinstance(audio, tuple):
214
+ return ""
215
+
216
+ try:
217
+ sample_rate, audio_array = audio
218
+ features = process_audio(audio_array, sample_rate)
219
+
220
+ asr = viu.get_asr_pipeline()
221
+ result = asr(features)
222
+
223
+ return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
224
+ except Exception as e:
225
+ print(f"Transcription error: {str(e)}")
226
+ return ""
227
+
228
+ microphone.stream(
229
+ fn=update_live_transcription,
230
+ inputs=[microphone],
231
+ outputs=transcript_box,
232
+ show_progress="hidden",
233
+ queue=True
234
+ )
235
+
236
+ clear_btn.click(
237
+ fn=lambda: (None, "", ""),
238
+ outputs=[chatbot, transcript_box, text_input],
239
+ queue=False
240
+ )
241
+
242
+ def cleanup_memory():
243
+ """Release unused memory (placeholder for future memory management)."""
244
+ import gc
245
+ gc.collect()
246
+ if torch.cuda.is_available():
247
+ torch.cuda.empty_cache()
248
+
249
+ def process_text_input(text, history):
250
+ """Process text input with memory management."""
251
+
252
+ print("process_text_input received:", text)
253
+
254
+ if not text:
255
+ return history, "" # Return tuple to clear input
256
+
257
+ # Process the symptoms using the configured LLM
258
+ prompt = f"""Given these symptoms: '{text}'
259
+ Please provide:
260
+ 1. Most likely ICD-10 codes
261
+ 2. Confidence levels for each diagnosis
262
+ 3. Key follow-up questions
263
+
264
+ Format as JSON with diagnoses, confidences, and follow_up fields."""
265
+
266
+ response = llm.complete(prompt)
267
+
268
+ try:
269
+ # Try to parse as JSON first
270
+ result = json.loads(response.text)
271
+ except json.JSONDecodeError:
272
+ # If not JSON, wrap in our format
273
+ result = {
274
+ "diagnoses": [],
275
+ "confidences": [],
276
+ "follow_up": str(response.text)[:1000] # Limit response length
277
+ }
278
+
279
+ new_history = history + [
280
+ {"role": "user", "content": text},
281
+ {"role": "assistant", "content": viu.format_response_for_user(result)}
282
+ ]
283
+ return new_history, "" # Return empty string to clear input
284
+
285
+ # Update the submit button handler
286
+ submit_btn.click(
287
+ fn=process_text_input,
288
+ inputs=[text_input, chatbot],
289
+ outputs=[chatbot, text_input],
290
+ queue=True
291
+ ).success( # Changed from .then to .success for better error handling
292
+ fn=cleanup_memory,
293
+ inputs=None,
294
+ outputs=None,
295
+ queue=False
296
+ )
297
 
298
  if __name__ == "__main__":
299
  demo.launch(
src/app.py DELETED
@@ -1,625 +0,0 @@
1
- import os
2
- from pathlib import Path
3
- from huggingface_hub import hf_hub_download
4
- import gradio as gr
5
- from llama_index.core import Settings
6
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
7
- from llama_index.llms.llama_cpp import LlamaCPP
8
- from .parse_tabular import create_symptom_index # Use relative import
9
- import json
10
- import psutil
11
- from typing import Tuple, Dict
12
- import torch
13
- from gtts import gTTS
14
- import io
15
- import base64
16
- import numpy as np
17
- from transformers.pipelines import pipeline # Changed from transformers import pipeline
18
- from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
19
- import torchaudio
20
- import torchaudio.transforms as T
21
-
22
- # Model options mapped to their requirements
23
- MODEL_OPTIONS = {
24
- "tiny": {
25
- "name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf",
26
- "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
27
- "vram_req": 2, # GB
28
- "ram_req": 4 # GB
29
- },
30
- "small": {
31
- "name": "phi-2.Q4_K_M.gguf",
32
- "repo": "TheBloke/phi-2-GGUF",
33
- "vram_req": 4,
34
- "ram_req": 8
35
- },
36
- "medium": {
37
- "name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf",
38
- "repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
39
- "vram_req": 6,
40
- "ram_req": 16
41
- }
42
- }
43
-
44
- # Initialize Whisper components globally (these are lightweight)
45
- feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
46
- tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
47
- processor = WhisperProcessor(feature_extractor, tokenizer)
48
-
49
- def get_asr_pipeline():
50
- """Lazy load ASR pipeline with proper configuration."""
51
- global transcriber
52
- if "transcriber" not in globals():
53
- transcriber = pipeline(
54
- "automatic-speech-recognition",
55
- model="openai/whisper-base.en",
56
- chunk_length_s=30,
57
- stride_length_s=5,
58
- device="cpu",
59
- torch_dtype=torch.float32
60
- )
61
- return transcriber
62
-
63
- # Audio preprocessing function
64
- def process_audio(audio_array, sample_rate):
65
- """Pre-process audio for Whisper."""
66
- if audio_array.ndim > 1:
67
- audio_array = audio_array.mean(axis=1)
68
-
69
- # Convert to tensor for resampling
70
- audio_tensor = torch.FloatTensor(audio_array)
71
-
72
- # Resample to 16kHz if needed
73
- if sample_rate != 16000:
74
- resampler = T.Resample(sample_rate, 16000)
75
- audio_tensor = resampler(audio_tensor)
76
-
77
- # Normalize
78
- audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
79
-
80
- # Convert back to numpy array and return in correct format
81
- return {
82
- "raw": audio_tensor.numpy(), # Key must be "raw"
83
- "sampling_rate": 16000 # Key must be "sampling_rate"
84
- }
85
-
86
- def get_system_specs() -> Dict[str, float]:
87
- """Get system specifications."""
88
- # Get RAM
89
- ram_gb = psutil.virtual_memory().total / (1024**3)
90
-
91
- # Get GPU info if available
92
- gpu_vram_gb = 0
93
- if torch.cuda.is_available():
94
- try:
95
- # Query GPU memory in bytes and convert to GB
96
- gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
97
- except Exception as e:
98
- print(f"Warning: Could not get GPU memory: {e}")
99
-
100
- return {
101
- "ram_gb": ram_gb,
102
- "gpu_vram_gb": gpu_vram_gb
103
- }
104
-
105
- def select_best_model() -> Tuple[str, str]:
106
- """Select the best model based on system specifications."""
107
- specs = get_system_specs()
108
- print(f"\nSystem specifications:")
109
- print(f"RAM: {specs['ram_gb']:.1f} GB")
110
- print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")
111
-
112
- # Prioritize GPU if available
113
- if specs['gpu_vram_gb'] >= 4: # You have 6GB, so this should work
114
- model_tier = "small" # phi-2 should work well on RTX 2060
115
- elif specs['ram_gb'] >= 8:
116
- model_tier = "small"
117
- else:
118
- model_tier = "tiny"
119
-
120
- selected = MODEL_OPTIONS[model_tier]
121
- print(f"\nSelected model tier: {model_tier}")
122
- print(f"Model: {selected['name']}")
123
-
124
- return selected['name'], selected['repo']
125
-
126
- # Set up model paths
127
- MODEL_NAME, REPO_ID = select_best_model()
128
- BASE_DIR = os.path.dirname(os.path.dirname(__file__))
129
- MODEL_DIR = os.path.join(BASE_DIR, "models")
130
- MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)
131
-
132
- from typing import Optional
133
-
134
- def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str:
135
- """Ensures model is available, downloading only if needed."""
136
-
137
- # Determine environment and set cache directory
138
- if os.path.exists("/home/user"):
139
- # HF Space environment
140
- cache_dir = "/home/user/.cache/models"
141
- else:
142
- # Local development environment
143
- cache_dir = os.path.join(BASE_DIR, "models")
144
-
145
- # Create cache directory if it doesn't exist
146
- try:
147
- os.makedirs(cache_dir, exist_ok=True)
148
- except Exception as e:
149
- print(f"Warning: Could not create cache directory {cache_dir}: {e}")
150
- # Fall back to temporary directory if needed
151
- cache_dir = os.path.join("/tmp", "models")
152
- os.makedirs(cache_dir, exist_ok=True)
153
-
154
- # Get model details
155
- if not model_name or not repo_id:
156
- model_option = MODEL_OPTIONS["small"] # default to small model
157
- model_name = model_option["name"]
158
- repo_id = model_option["repo"]
159
-
160
- # Ensure model_name and repo_id are not None
161
- if model_name is None:
162
- raise ValueError("model_name cannot be None")
163
- if repo_id is None:
164
- raise ValueError("repo_id cannot be None")
165
- # Check if model already exists in cache
166
- model_path = os.path.join(cache_dir, model_name)
167
- if os.path.exists(model_path):
168
- print(f"\nUsing cached model: {model_path}")
169
- return model_path
170
-
171
- print(f"\nDownloading model {model_name} from {repo_id}...")
172
- try:
173
- model_path = hf_hub_download(
174
- repo_id=repo_id,
175
- filename=model_name,
176
- cache_dir=cache_dir,
177
- local_dir=cache_dir
178
- )
179
- print(f"Model downloaded successfully to {model_path}")
180
- return model_path
181
- except Exception as e:
182
- print(f"Error downloading model: {str(e)}")
183
- raise
184
-
185
- # Ensure model is downloaded
186
- model_path = ensure_model()
187
-
188
- # Configure local LLM with LlamaCPP
189
- print("\nInitializing LLM...")
190
- llm = LlamaCPP(
191
- model_path=model_path,
192
- temperature=0.7,
193
- max_new_tokens=256,
194
- context_window=2048,
195
- verbose=False # Reduce logging
196
- # n_batch and n_threads are not valid parameters for LlamaCPP and should not be used.
197
- # If you encounter segmentation faults, try reducing context_window or check your system resources.
198
- )
199
- print("LLM initialized successfully")
200
-
201
- # Configure global settings
202
- print("\nConfiguring settings...")
203
- Settings.llm = llm
204
- Settings.embed_model = HuggingFaceEmbedding(
205
- model_name="sentence-transformers/all-MiniLM-L6-v2"
206
- )
207
- print("Settings configured")
208
-
209
- # Create the index at startup
210
- print("\nCreating symptom index...")
211
- symptom_index = create_symptom_index()
212
- print("Index created successfully")
213
- print("Loaded symptom_index:", type(symptom_index))
214
-
215
- # --- System prompt ---
216
- SYSTEM_PROMPT = """
217
- You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
218
- At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
219
- or, if you have enough info, output a final JSON with fields:
220
- {"diagnoses":[…], "confidences":[…]}.
221
- """
222
-
223
- def process_speech(audio_data, history):
224
- """Process speech input and convert to text."""
225
- try:
226
- if not audio_data:
227
- return []
228
-
229
- if isinstance(audio_data, tuple) and len(audio_data) == 2:
230
- sample_rate, audio_array = audio_data
231
-
232
- # Audio preprocessing
233
- if audio_array.ndim > 1:
234
- audio_array = audio_array.mean(axis=1)
235
- audio_array = audio_array.astype(np.float32)
236
- audio_array /= np.max(np.abs(audio_array))
237
-
238
- # Ensure correct sampling rate
239
- if sample_rate != 16000:
240
- resampler = T.Resample(sample_rate, 16000)
241
- audio_tensor = torch.FloatTensor(audio_array)
242
- audio_tensor = resampler(audio_tensor)
243
- audio_array = audio_tensor.numpy()
244
- sample_rate = 16000
245
-
246
- # Transcribe with error handling
247
-
248
- # Format dictionary correctly with required keys
249
- input_features = {
250
- "raw": audio_array,
251
- "sampling_rate": sample_rate
252
- }
253
-
254
- result = transcriber(input_features)
255
-
256
- # Handle different result types
257
- if isinstance(result, dict) and "text" in result:
258
- transcript = result["text"].strip()
259
- elif isinstance(result, str):
260
- transcript = result.strip()
261
- else:
262
- print(f"Unexpected transcriber result type: {type(result)}")
263
- return []
264
-
265
- if not transcript:
266
- print("No transcription generated")
267
- return []
268
-
269
- # Query symptoms with transcribed text
270
- diagnosis_query = f"""
271
- Given these symptoms: '{transcript}'
272
- Identify the most likely ICD-10 diagnoses and key questions.
273
- Focus on clinical implications.
274
- """
275
-
276
- response = symptom_index.as_query_engine().query(diagnosis_query)
277
-
278
- return [
279
- {"role": "user", "content": transcript},
280
- {"role": "assistant", "content": json.dumps({
281
- "diagnoses": [],
282
- "confidences": [],
283
- "follow_up": str(response)
284
- })}
285
- ]
286
-
287
- else:
288
- print(f"Invalid audio format: {type(audio_data)}")
289
- return []
290
-
291
- except Exception as e:
292
- print(f"Processing error: {str(e)}")
293
- return []
294
-
295
- # Build enhanced Gradio interface
296
- with gr.Blocks(theme="default") as demo:
297
- gr.Markdown("""
298
- # 🏥 Medical Symptom to ICD-10 Code Assistant
299
-
300
- ## About
301
- This application is part of the Agents+MCP Hackathon. It helps medical professionals
302
- and patients understand potential diagnoses based on described symptoms.
303
-
304
- ### How it works:
305
- 1. Either click the record button and describe your symptoms or type them into the textbox
306
- 2. The AI will analyze your description and suggest possible diagnoses
307
- 3. Answer follow-up questions to refine the diagnosis
308
- """)
309
-
310
- with gr.Row():
311
- with gr.Column(scale=2):
312
- # Add text input above microphone
313
- with gr.Row():
314
- text_input = gr.Textbox(
315
- label="Type your symptoms",
316
- placeholder="Or type your symptoms here...",
317
- lines=3
318
- )
319
- submit_btn = gr.Button("Submit", variant="primary")
320
-
321
- # Existing microphone row
322
- with gr.Row():
323
- microphone = gr.Audio(
324
- sources=["microphone"],
325
- streaming=True,
326
- type="numpy",
327
- label="Describe your symptoms"
328
- )
329
- transcript_box = gr.Textbox(
330
- label="Transcribed Text",
331
- interactive=False,
332
- show_label=True
333
- )
334
- clear_btn = gr.Button("Clear Chat", variant="secondary")
335
-
336
- chatbot = gr.Chatbot(
337
- label="Medical Consultation",
338
- height=500,
339
- container=True,
340
- type="messages" # This is now properly supported by our message format
341
- )
342
-
343
- with gr.Column(scale=1):
344
- with gr.Accordion("Advanced Settings", open=False):
345
- api_key = gr.Textbox(
346
- label="OpenAI API Key (optional)",
347
- type="password",
348
- placeholder="sk-..."
349
- )
350
-
351
- with gr.Row():
352
- with gr.Column():
353
- modal_key = gr.Textbox(
354
- label="Modal Labs API Key",
355
- type="password",
356
- placeholder="mk-..."
357
- )
358
- anthropic_key = gr.Textbox(
359
- label="Anthropic API Key",
360
- type="password",
361
- placeholder="sk-ant-..."
362
- )
363
- mistral_key = gr.Textbox(
364
- label="MistralAI API Key",
365
- type="password",
366
- placeholder="..."
367
- )
368
-
369
- with gr.Column():
370
- nebius_key = gr.Textbox(
371
- label="Nebius API Key",
372
- type="password",
373
- placeholder="..."
374
- )
375
- hyperbolic_key = gr.Textbox(
376
- label="Hyperbolic Labs API Key",
377
- type="password",
378
- placeholder="hyp-..."
379
- )
380
- sambanova_key = gr.Textbox(
381
- label="SambaNova API Key",
382
- type="password",
383
- placeholder="..."
384
- )
385
-
386
- with gr.Row():
387
- model_selector = gr.Dropdown(
388
- choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
389
- value="OpenAI",
390
- label="Model Provider"
391
- )
392
- temperature = gr.Slider(
393
- minimum=0,
394
- maximum=1,
395
- value=0.7,
396
- label="Temperature"
397
- )
398
- # self promotion at bottom of page
399
- gr.Markdown("""
400
- ---
401
- ### 👋 About the Creator
402
-
403
- Hi! I'm Graham Paasch, an experienced technology professional!
404
-
405
- 🎥 **Check out my YouTube channel** for more tech content:
406
- [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)
407
-
408
- 💼 **Looking for a skilled developer?**
409
- I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)
410
-
411
- ⭐ If you found this tool helpful, please consider:
412
- - Subscribing to my YouTube channel
413
- - Connecting on LinkedIn
414
- - Sharing this tool with others in healthcare tech
415
- """)
416
-
417
- # Event handlers
418
- clear_btn.click(lambda: None, None, chatbot, queue=False)
419
-
420
- def format_response_for_user(response_dict):
421
- """Format the assistant's response dictionary into a user-friendly string."""
422
- diagnoses = response_dict.get("diagnoses", [])
423
- confidences = response_dict.get("confidences", [])
424
- follow_up = response_dict.get("follow_up", "")
425
- result = ""
426
- if diagnoses:
427
- result += "Possible Diagnoses:\n"
428
- for i, diag in enumerate(diagnoses):
429
- conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else ""
430
- result += f"- {diag}{conf}\n"
431
- if follow_up:
432
- result += f"\nFollow-up: {follow_up}"
433
- return result.strip()
434
-
435
- def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7):
436
- """Handle streaming speech processing and chat updates."""
437
-
438
- transcriber = get_asr_pipeline()
439
-
440
- if not audio_path:
441
- return history
442
-
443
- try:
444
- if isinstance(audio_path, tuple) and len(audio_path) == 2:
445
- sample_rate, audio_array = audio_path
446
-
447
- # Audio preprocessing
448
- if audio_array.ndim > 1:
449
- audio_array = audio_array.mean(axis=1)
450
- audio_array = audio_array.astype(np.float32)
451
- audio_array /= np.max(np.abs(audio_array))
452
-
453
- # Ensure correct sampling rate
454
- if sample_rate != 16000:
455
- resampler = T.Resample(
456
- orig_freq=sample_rate,
457
- new_freq=16000
458
- )
459
- audio_tensor = torch.FloatTensor(audio_array)
460
- audio_tensor = resampler(audio_tensor)
461
- audio_array = audio_tensor.numpy()
462
- sample_rate = 16000
463
-
464
- # Format input dictionary exactly as required
465
- transcriber_input = {
466
- "raw": audio_array,
467
- "sampling_rate": sample_rate
468
- }
469
-
470
- # Get transcription from Whisper
471
- result = transcriber(transcriber_input)
472
-
473
- # Extract text from result
474
- transcript = ""
475
- if isinstance(result, dict):
476
- transcript = result.get("text", "").strip()
477
- elif isinstance(result, str):
478
- transcript = result.strip()
479
-
480
- if not transcript:
481
- return history
482
-
483
- # Process the symptoms
484
- diagnosis_query = f"""
485
- Based on these symptoms: '{transcript}'
486
- Provide relevant ICD-10 codes and diagnostic questions.
487
- """
488
- response = symptom_index.as_query_engine().query(diagnosis_query)
489
-
490
- # Format and return chat messages
491
- return history + [
492
- {"role": "user", "content": transcript},
493
- {"role": "assistant", "content": format_response_for_user({
494
- "diagnoses": [],
495
- "confidences": [],
496
- "follow_up": str(response)
497
- })}
498
- ]
499
-
500
- except Exception as e:
501
- print(f"Streaming error: {str(e)}")
502
- return history
503
-
504
- microphone.stream(
505
- fn=enhanced_process_speech,
506
- inputs=[microphone, chatbot, api_key, model_selector, temperature],
507
- outputs=chatbot,
508
- show_progress="hidden",
509
- api_name=False,
510
- queue=True # Enable queuing for better stream handling
511
- )
512
-
513
- def process_audio(audio_array, sample_rate):
514
- """Pre-process audio for Whisper."""
515
- if audio_array.ndim > 1:
516
- audio_array = audio_array.mean(axis=1)
517
-
518
- # Convert to tensor for resampling
519
- audio_tensor = torch.FloatTensor(audio_array)
520
-
521
- # Resample to 16kHz if needed
522
- if sample_rate != 16000:
523
- resampler = T.Resample(sample_rate, 16000)
524
- audio_tensor = resampler(audio_tensor)
525
-
526
- # Normalize
527
- audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
528
-
529
- # Convert back to numpy array and return in correct format
530
- return {
531
- "raw": audio_tensor.numpy(), # Key must be "raw"
532
- "sampling_rate": 16000 # Key must be "sampling_rate"
533
- }
534
-
535
- # Update transcription handler
536
- def update_live_transcription(audio):
537
- """Real-time transcription updates."""
538
- if not audio or not isinstance(audio, tuple):
539
- return ""
540
-
541
- try:
542
- sample_rate, audio_array = audio
543
- features = process_audio(audio_array, sample_rate)
544
-
545
- asr = get_asr_pipeline()
546
- result = asr(features)
547
-
548
- return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
549
- except Exception as e:
550
- print(f"Transcription error: {str(e)}")
551
- return ""
552
-
553
- microphone.stream(
554
- fn=update_live_transcription,
555
- inputs=[microphone],
556
- outputs=transcript_box,
557
- show_progress="hidden",
558
- queue=True
559
- )
560
-
561
- clear_btn.click(
562
- fn=lambda: (None, "", ""),
563
- outputs=[chatbot, transcript_box, text_input],
564
- queue=False
565
- )
566
-
567
- def cleanup_memory():
568
- """Release unused memory (placeholder for future memory management)."""
569
- import gc
570
- gc.collect()
571
- if torch.cuda.is_available():
572
- torch.cuda.empty_cache()
573
-
574
- def process_text_input(text, history):
575
- """Process text input with memory management."""
576
-
577
- print("process_text_input received:", text)
578
-
579
- if not text:
580
- return history, "" # Return tuple to clear input
581
-
582
- try:
583
- # Process the symptoms using the configured LLM
584
- prompt = f"""Given these symptoms: '{text}'
585
- Please provide:
586
- 1. Most likely ICD-10 codes
587
- 2. Confidence levels for each diagnosis
588
- 3. Key follow-up questions
589
-
590
- Format as JSON with diagnoses, confidences, and follow_up fields."""
591
-
592
- response = llm.complete(prompt)
593
-
594
- try:
595
- # Try to parse as JSON first
596
- result = json.loads(response.text)
597
- except json.JSONDecodeError:
598
- # If not JSON, wrap in our format
599
- result = {
600
- "diagnoses": [],
601
- "confidences": [],
602
- "follow_up": str(response.text)[:1000] # Limit response length
603
- }
604
-
605
- new_history = history + [
606
- {"role": "user", "content": text},
607
- {"role": "assistant", "content": format_response_for_user(result)}
608
- ]
609
- return new_history, "" # Return empty string to clear input
610
- except Exception as e:
611
- print(f"Error processing text: {str(e)}")
612
- return history, text # Keep text on error
613
-
614
- # Update the submit button handler
615
- submit_btn.click(
616
- fn=process_text_input,
617
- inputs=[text_input, chatbot],
618
- outputs=[chatbot, text_input],
619
- queue=True
620
- ).success( # Changed from .then to .success for better error handling
621
- fn=cleanup_memory,
622
- inputs=None,
623
- outputs=None,
624
- queue=False
625
- )
utils/model_configuration_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ '''Defines available model configurations.
2
+
3
+ Maps three tiers (“tiny”, “small”, “medium”) to their model filename, Hugging Face repo, required GPU VRAM, and required system RAM.
4
+
5
+ get_system_specs() uses psutil to compute total system RAM in GB and torch.cuda to query GPU VRAM in GB (zero if no CUDA device).
6
+
7
+ select_best_model() prints detected RAM and GPU VRAM, chooses “small” if GPU VRAM ≥ 4 GB or if RAM ≥ 8 GB, otherwise “tiny”, prints the chosen tier and model name, and returns the model filename and repo string.
8
+
+ ensure_model() picks a cache directory (the HF Space cache or the local models folder, falling back to /tmp/models), downloads the GGUF file from the Hugging Face Hub via hf_hub_download if it is not already cached, and returns the local model path.
+ '''
9
+ import os
10
+ import psutil
11
+ from typing import Tuple, Dict
12
+ import torch
13
+ import torchaudio.transforms as T
14
+ from huggingface_hub import hf_hub_download
15
+ from typing import Optional
16
+
17
+ # Model options mapped to their requirements
18
+ MODEL_OPTIONS = {
19
+ "tiny": {
20
+ "name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf",
21
+ "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
22
+ "vram_req": 2, # GB
23
+ "ram_req": 4 # GB
24
+ },
25
+ "small": {
26
+ "name": "phi-2.Q4_K_M.gguf",
27
+ "repo": "TheBloke/phi-2-GGUF",
28
+ "vram_req": 4,
29
+ "ram_req": 8
30
+ },
31
+ "medium": {
32
+ "name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf",
33
+ "repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
34
+ "vram_req": 6,
35
+ "ram_req": 16
36
+ }
37
+ }
38
+
39
+ def get_system_specs() -> Dict[str, float]:
40
+ """Get system specifications."""
41
+ # Get RAM
42
+ ram_gb = psutil.virtual_memory().total / (1024**3)
43
+
44
+ # Get GPU info if available
45
+ gpu_vram_gb = 0
46
+ if torch.cuda.is_available():
47
+ try:
48
+ # Query GPU memory in bytes and convert to GB
49
+ gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
50
+ except Exception as e:
51
+ print(f"Warning: Could not get GPU memory: {e}")
52
+
53
+ return {
54
+ "ram_gb": ram_gb,
55
+ "gpu_vram_gb": gpu_vram_gb
56
+ }
57
+
58
+ def select_best_model() -> Tuple[str, str]:
59
+ """Select the best model based on system specifications."""
60
+ specs = get_system_specs()
61
+ print(f"\nSystem specifications:")
62
+ print(f"RAM: {specs['ram_gb']:.1f} GB")
63
+ print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")
64
+
65
+ # Prioritize GPU if available
66
+ if specs['gpu_vram_gb'] >= 4: # You have 6GB, so this should work
67
+ model_tier = "small" # phi-2 should work well on RTX 2060
68
+ elif specs['ram_gb'] >= 8:
69
+ model_tier = "small"
70
+ else:
71
+ model_tier = "tiny"
72
+
73
+ selected = MODEL_OPTIONS[model_tier]
74
+ print(f"\nSelected model tier: {model_tier}")
75
+ print(f"Model: {selected['name']}")
76
+
77
+ return selected['name'], selected['repo']
78
+
79
+ def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str:
80
+ """Ensures model is available, downloading only if needed."""
81
+ BASE_DIR = os.path.dirname(os.path.dirname(__file__))
82
+
83
+ # Determine environment and set cache directory
84
+ if os.path.exists("/home/user"):
85
+ # HF Space environment
86
+ cache_dir = "/home/user/.cache/models"
87
+ else:
88
+ # Local development environment
89
+ cache_dir = os.path.join(BASE_DIR, "models")
90
+
91
+ # Create cache directory if it doesn't exist
92
+ try:
93
+ os.makedirs(cache_dir, exist_ok=True)
94
+ except Exception as e:
95
+ print(f"Warning: Could not create cache directory {cache_dir}: {e}")
96
+ # Fall back to temporary directory if needed
97
+ cache_dir = os.path.join("/tmp", "models")
98
+ os.makedirs(cache_dir, exist_ok=True)
99
+
100
+ # Get model details
101
+ if not model_name or not repo_id:
102
+ model_option = MODEL_OPTIONS["small"] # default to small model
103
+ model_name = model_option["name"]
104
+ repo_id = model_option["repo"]
105
+
106
+ # Ensure model_name and repo_id are not None
107
+ if model_name is None:
108
+ raise ValueError("model_name cannot be None")
109
+ if repo_id is None:
110
+ raise ValueError("repo_id cannot be None")
111
+ # Check if model already exists in cache
112
+ model_path = os.path.join(cache_dir, model_name)
113
+ if os.path.exists(model_path):
114
+ print(f"\nUsing cached model: {model_path}")
115
+ return model_path
116
+
117
+ print(f"\nDownloading model {model_name} from {repo_id}...")
118
+
119
+ model_path = hf_hub_download(
120
+ repo_id=repo_id,
121
+ filename=model_name,
122
+ cache_dir=cache_dir,
123
+ local_dir=cache_dir
124
+ )
125
+ print(f"Model downloaded successfully to {model_path}")
126
+ return model_path
utils/voice_input_utils.py ADDED
@@ -0,0 +1,193 @@
1
+ from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
2
+ from transformers.pipelines import pipeline
3
+ import torch
4
+ import torchaudio.transforms as T
5
+ import numpy as np
6
+ import json
7
+
8
+ # Initialize Whisper components globally (these are lightweight)
9
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
10
+ tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
11
+ processor = WhisperProcessor(feature_extractor, tokenizer)
12
+
13
+ def get_asr_pipeline():
14
+ """Lazy load ASR pipeline with proper configuration."""
15
+ global transcriber
16
+ if "transcriber" not in globals():
17
+ transcriber = pipeline(
18
+ "automatic-speech-recognition",
19
+ model="openai/whisper-base.en",
20
+ chunk_length_s=30,
21
+ stride_length_s=5,
22
+ device="cpu",
23
+ torch_dtype=torch.float32
24
+ )
25
+ return transcriber
26
+
27
+ def process_audio(audio_array, sample_rate):
28
+ """Pre-process audio for Whisper."""
29
+ if audio_array.ndim > 1:
30
+ audio_array = audio_array.mean(axis=1)
31
+
32
+ # Convert to tensor for resampling
33
+ audio_tensor = torch.FloatTensor(audio_array)
34
+
35
+ # Resample to 16kHz if needed
36
+ if sample_rate != 16000:
37
+ resampler = T.Resample(sample_rate, 16000)
38
+ audio_tensor = resampler(audio_tensor)
39
+
40
+ # Normalize
41
+ audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
42
+
43
+ # Convert back to numpy array and return in correct format
44
+ return {
45
+ "raw": audio_tensor.numpy(), # Key must be "raw"
46
+ "sampling_rate": 16000 # Key must be "sampling_rate"
47
+ }
48
+
49
+ def process_speech(audio_data, symptom_index):
50
+ """Process speech input and convert to text."""
51
+ if not audio_data:
52
+ return []
53
+
54
+ if isinstance(audio_data, tuple) and len(audio_data) == 2:
55
+ sample_rate, audio_array = audio_data
56
+
57
+ # Audio preprocessing
58
+ if audio_array.ndim > 1:
59
+ audio_array = audio_array.mean(axis=1)
60
+ audio_array = audio_array.astype(np.float32)
61
+ audio_array /= np.max(np.abs(audio_array))
62
+
63
+ # Ensure correct sampling rate
64
+ if sample_rate != 16000:
65
+ resampler = T.Resample(sample_rate, 16000)
66
+ audio_tensor = torch.FloatTensor(audio_array)
67
+ audio_tensor = resampler(audio_tensor)
68
+ audio_array = audio_tensor.numpy()
69
+ sample_rate = 16000
70
+
71
+ # Transcribe the preprocessed audio
72
+
73
+ # Format dictionary correctly with required keys
74
+ input_features = {
75
+ "raw": audio_array,
76
+ "sampling_rate": sample_rate
77
+ }
78
+
79
+ transcriber = get_asr_pipeline()
+ result = transcriber(input_features)
80
+
81
+ # Handle different result types
82
+ if isinstance(result, dict) and "text" in result:
83
+ transcript = result["text"].strip()
84
+ elif isinstance(result, str):
85
+ transcript = result.strip()
86
+ else:
87
+ print(f"Unexpected transcriber result type: {type(result)}")
88
+ return []
89
+
90
+ if not transcript:
91
+ print("No transcription generated")
92
+ return []
93
+
94
+ # Query symptoms with transcribed text
95
+ diagnosis_query = f"""
96
+ Given these symptoms: '{transcript}'
97
+ Identify the most likely ICD-10 diagnoses and key questions.
98
+ Focus on clinical implications.
99
+ """
100
+
101
+ response = symptom_index.as_query_engine().query(diagnosis_query)
102
+
103
+ return [
104
+ {"role": "user", "content": transcript},
105
+ {"role": "assistant", "content": json.dumps({
106
+ "diagnoses": [],
107
+ "confidences": [],
108
+ "follow_up": str(response)
109
+ })}
110
+ ]
111
+
112
+ else:
113
+ print(f"Invalid audio format: {type(audio_data)}")
114
+ return []
115
+
116
+ def format_response_for_user(response_dict):
117
+ """Format the assistant's response dictionary into a user-friendly string."""
118
+ diagnoses = response_dict.get("diagnoses", [])
119
+ confidences = response_dict.get("confidences", [])
120
+ follow_up = response_dict.get("follow_up", "")
121
+ result = ""
122
+ if diagnoses:
123
+ result += "Possible Diagnoses:\n"
124
+ for i, diag in enumerate(diagnoses):
125
+ conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else ""
126
+ result += f"- {diag}{conf}\n"
127
+ if follow_up:
128
+ result += f"\nFollow-up: {follow_up}"
129
+ return result.strip()
130
+
131
+ def enhanced_process_speech(audio_path, symptom_index, history, api_key=None, model_tier="small", temp=0.7):
132
+ """Handle streaming speech processing and chat updates."""
133
+
134
+ transcriber = get_asr_pipeline()
135
+
136
+ if not audio_path:
137
+ return history
138
+
139
+ if isinstance(audio_path, tuple) and len(audio_path) == 2:
140
+ sample_rate, audio_array = audio_path
141
+
142
+ # Audio preprocessing
143
+ if audio_array.ndim > 1:
144
+ audio_array = audio_array.mean(axis=1)
145
+ audio_array = audio_array.astype(np.float32)
146
+ audio_array /= np.max(np.abs(audio_array))
147
+
148
+ # Ensure correct sampling rate
149
+ if sample_rate != 16000:
150
+ resampler = T.Resample(
151
+ orig_freq=sample_rate,
152
+ new_freq=16000
153
+ )
154
+ audio_tensor = torch.FloatTensor(audio_array)
155
+ audio_tensor = resampler(audio_tensor)
156
+ audio_array = audio_tensor.numpy()
157
+ sample_rate = 16000
158
+
159
+ # Format input dictionary exactly as required
160
+ transcriber_input = {
161
+ "raw": audio_array,
162
+ "sampling_rate": sample_rate
163
+ }
164
+
165
+ # Get transcription from Whisper
166
+ result = transcriber(transcriber_input)
167
+
168
+ # Extract text from result
169
+ transcript = ""
170
+ if isinstance(result, dict):
171
+ transcript = result.get("text", "").strip()
172
+ elif isinstance(result, str):
173
+ transcript = result.strip()
174
+
175
+ if not transcript:
176
+ return history
177
+
178
+ # Process the symptoms
179
+ diagnosis_query = f"""
180
+ Based on these symptoms: '{transcript}'
181
+ Provide relevant ICD-10 codes and diagnostic questions.
182
+ """
183
+ response = symptom_index.as_query_engine().query(diagnosis_query)
184
+
185
+ # Format and return chat messages
186
+ return history + [
187
+ {"role": "user", "content": transcript},
188
+ {"role": "assistant", "content": format_response_for_user({
189
+ "diagnoses": [],
190
+ "confidences": [],
191
+ "follow_up": str(response)
192
+ })}
193
+ ]