import os
from pathlib import Path
from huggingface_hub import hf_hub_download
import gradio as gr
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from .parse_tabular import create_symptom_index  # Use relative import
import json
import psutil
from typing import Dict, Optional, Tuple
import torch
from gtts import gTTS
import io
import base64
import numpy as np
from transformers.pipelines import pipeline  # Changed from transformers import pipeline
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
import torchaudio
import torchaudio.transforms as T

# Model options mapped to their requirements
MODEL_OPTIONS = {
    "tiny": {
        "name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf",
        "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
        "vram_req": 2,  # GB
        "ram_req": 4    # GB
    },
    "small": {
        "name": "phi-2.Q4_K_M.gguf",
        "repo": "TheBloke/phi-2-GGUF",
        "vram_req": 4,
        "ram_req": 8
    },
    "medium": {
        "name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf",
        "repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
        "vram_req": 6,
        "ram_req": 16
    }
}
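
# Note: the VRAM/RAM figures above are rough planning estimates for the
# Q4_K_M GGUF quantisations, not measured limits for any specific hardware.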

# Initialize Whisper components globally (these are lightweight)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
processor = WhisperProcessor(feature_extractor, tokenizer)

def get_asr_pipeline():
    """Lazy load ASR pipeline with proper configuration."""
    global transcriber
    if "transcriber" not in globals():
        transcriber = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-base.en",
            chunk_length_s=30,
            stride_length_s=5,
            device="cpu",
            torch_dtype=torch.float32
        )
    return transcriber
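
# get_asr_pipeline() backs the live-transcription handler defined inside the
# Gradio Blocks below; the module-level `transcriber` configured further down
# is what the chat handlers call.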

# Audio preprocessing function
def process_audio(audio_array, sample_rate):
    """Pre-process audio for Whisper."""
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Normalize audio (guard against silent, all-zero input)
    audio_array = audio_array.astype(np.float32)
    peak = np.max(np.abs(audio_array))
    if peak > 0:
        audio_array /= peak

    # Resample to 16kHz if needed
    if sample_rate != 16000:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
        audio_tensor = torch.FloatTensor(audio_array)
        audio_tensor = resampler(audio_tensor)
        audio_array = audio_tensor.numpy()

    # Process with correct input format
    inputs = processor(
        audio_array,
        sampling_rate=16000,
        return_tensors="pt"
    )

    return {
        "input_features": inputs.input_features,
        # may be None unless the feature extractor is configured to return it
        "attention_mask": inputs.get("attention_mask")
    }

# Update transcriber configuration
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    chunk_length_s=30,
    stride_length_s=5,
    device="cpu",
    torch_dtype=torch.float32,
    feature_extractor=feature_extractor,
    generate_kwargs={
        "use_cache": True,
        "return_timestamps": True
    }
)
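
# The ASR pipeline accepts raw audio either as a bare numpy array or as a
# dict of the form {"raw": array, "sampling_rate": int}; the speech handlers
# below use the dict form.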

def get_system_specs() -> Dict[str, float]:
    """Get system specifications."""
    # Get RAM
    ram_gb = psutil.virtual_memory().total / (1024**3)

    # Get GPU info if available
    gpu_vram_gb = 0
    if torch.cuda.is_available():
        try:
            # Query GPU memory in bytes and convert to GB
            gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception as e:
            print(f"Warning: Could not get GPU memory: {e}")

    return {
        "ram_gb": ram_gb,
        "gpu_vram_gb": gpu_vram_gb
    }

def select_best_model() -> Tuple[str, str]:
    """Select the best model based on system specifications."""
    specs = get_system_specs()
    print("\nSystem specifications:")
    print(f"RAM: {specs['ram_gb']:.1f} GB")
    print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")

    # Prioritize GPU if available
    if specs['gpu_vram_gb'] >= 4:  # a 6 GB card comfortably clears this bar
        model_tier = "small"       # phi-2 runs well on mid-range GPUs
    elif specs['ram_gb'] >= 8:
        model_tier = "small"
    else:
        model_tier = "tiny"

    selected = MODEL_OPTIONS[model_tier]
    print(f"\nSelected model tier: {model_tier}")
    print(f"Model: {selected['name']}")

    return selected['name'], selected['repo']
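
# Example: on a machine reporting at least 4 GB of GPU VRAM this returns
# ("phi-2.Q4_K_M.gguf", "TheBloke/phi-2-GGUF").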

# Set up model paths
MODEL_NAME, REPO_ID = select_best_model()
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str:
| """Ensures model is available, downloading only if needed.""" | |
| # Determine environment and set cache directory | |
| if os.path.exists("/home/user"): | |
| # HF Space environment | |
| cache_dir = "/home/user/.cache/models" | |
| else: | |
| # Local development environment | |
| cache_dir = os.path.join(BASE_DIR, "models") | |
| # Create cache directory if it doesn't exist | |
| try: | |
| os.makedirs(cache_dir, exist_ok=True) | |
| except Exception as e: | |
| print(f"Warning: Could not create cache directory {cache_dir}: {e}") | |
| # Fall back to temporary directory if needed | |
| cache_dir = os.path.join("/tmp", "models") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| # Get model details | |
| if not model_name or not repo_id: | |
| model_option = MODEL_OPTIONS["small"] # default to small model | |
| model_name = model_option["name"] | |
| repo_id = model_option["repo"] | |
| # Ensure model_name and repo_id are not None | |
| if model_name is None: | |
| raise ValueError("model_name cannot be None") | |
| if repo_id is None: | |
| raise ValueError("repo_id cannot be None") | |
| # Check if model already exists in cache | |
| model_path = os.path.join(cache_dir, model_name) | |
| if os.path.exists(model_path): | |
| print(f"\nUsing cached model: {model_path}") | |
| return model_path | |
| print(f"\nDownloading model {model_name} from {repo_id}...") | |
| try: | |
| model_path = hf_hub_download( | |
| repo_id=repo_id, | |
| filename=model_name, | |
| cache_dir=cache_dir, | |
| local_dir=cache_dir | |
| ) | |
| print(f"Model downloaded successfully to {model_path}") | |
| return model_path | |
| except Exception as e: | |
| print(f"Error downloading model: {str(e)}") | |
| raise | |
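
# Note: when called without arguments, ensure_model() falls back to the
# "small" tier from MODEL_OPTIONS; below, the tier chosen by
# select_best_model() is passed explicitly.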

# Ensure the selected model is downloaded
model_path = ensure_model(MODEL_NAME, REPO_ID)

# Configure local LLM with LlamaCPP
print("\nInitializing LLM...")
llm = LlamaCPP(
    model_path=model_path,
    temperature=0.7,
    max_new_tokens=256,
    context_window=2048,
    verbose=False  # Reduce logging
    # n_batch and n_threads are not valid parameters for LlamaCPP and should not be used.
    # If you encounter segmentation faults, try reducing context_window or check your system resources.
)
print("LLM initialized successfully")

# Configure global settings
print("\nConfiguring settings...")
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
print("Settings configured")

# Create the index at startup
print("\nCreating symptom index...")
symptom_index = create_symptom_index()
print("Index created successfully")

# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""

def process_speech(audio_data, history):
    """Process speech input and convert to text."""
    try:
        if not audio_data:
            return []

        if isinstance(audio_data, tuple) and len(audio_data) == 2:
            sample_rate, audio_array = audio_data

            # Audio preprocessing
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            audio_array = audio_array.astype(np.float32)
            peak = np.max(np.abs(audio_array))
            if peak > 0:  # guard against silent input
                audio_array /= peak

            # Ensure correct sampling rate
            if sample_rate != 16000:
                resampler = T.Resample(sample_rate, 16000)
                audio_tensor = torch.FloatTensor(audio_array)
                audio_tensor = resampler(audio_tensor)
                audio_array = audio_tensor.numpy()
                sample_rate = 16000

            # Transcribe with error handling
            try:
                # Format dictionary correctly with required keys
                input_features = {
                    "raw": audio_array,
                    "sampling_rate": sample_rate
                }
                result = transcriber(input_features)

                # Handle different result types
                if isinstance(result, dict) and "text" in result:
                    transcript = result["text"].strip()
                elif isinstance(result, str):
                    transcript = result.strip()
                else:
                    print(f"Unexpected transcriber result type: {type(result)}")
                    return []

                if not transcript:
                    print("No transcription generated")
                    return []

                # Query symptoms with transcribed text
                diagnosis_query = f"""
                Given these symptoms: '{transcript}'
                Identify the most likely ICD-10 diagnoses and key questions.
                Focus on clinical implications.
                """
                response = symptom_index.as_query_engine().query(diagnosis_query)

                return [
                    {"role": "user", "content": transcript},
                    {"role": "assistant", "content": json.dumps({
                        "diagnoses": [],
                        "confidences": [],
                        "follow_up": str(response)
                    })}
                ]
            except Exception as e:
                print(f"Transcription error: {str(e)}")
                return []
        else:
            print(f"Invalid audio format: {type(audio_data)}")
            return []
    except Exception as e:
        print(f"Processing error: {str(e)}")
        return []

def update_transcription(audio_path):
    """Update transcription box with speech recognition results."""
    if not audio_path:
        return ""
    # Extract transcription from audio result
    transcript = audio_path[1] if isinstance(audio_path, tuple) else audio_path
    return transcript

# Build enhanced Gradio interface
with gr.Blocks(
    theme="default",
    css="""
    * {
        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI',
                     Roboto, Ubuntu, 'Helvetica Neue', Arial, sans-serif;
    }
    code, pre {
        font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
                     'Liberation Mono', 'Courier New', monospace;
    }
    """
) as demo:
| gr.Markdown(""" | |
| # 🏥 Medical Symptom to ICD-10 Code Assistant | |
| ## About | |
| This application is part of the Agents+MCP Hackathon. It helps medical professionals | |
| and patients understand potential diagnoses based on described symptoms. | |
| ### How it works: | |
| 1. Either click the record button and describe your symptoms or type them into the textbox | |
| 2. The AI will analyze your description and suggest possible diagnoses | |
| 3. Answer follow-up questions to refine the diagnosis | |
| """) | |

    with gr.Row():
        with gr.Column(scale=2):
            # Text input above the microphone
            with gr.Row():
                text_input = gr.Textbox(
                    label="Type your symptoms",
                    placeholder="Or type your symptoms here...",
                    lines=3
                )
                submit_btn = gr.Button("Submit", variant="primary")

            # Microphone row
            with gr.Row():
                microphone = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Describe your symptoms"
                )
                transcript_box = gr.Textbox(
                    label="Transcribed Text",
                    interactive=False,
                    show_label=True
                )

            clear_btn = gr.Button("Clear Chat", variant="secondary")

            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                container=True,
                type="messages"  # matches the role/content dicts returned by the handlers
            )

        with gr.Column(scale=1):
            with gr.Accordion("Advanced Settings", open=False):
                api_key = gr.Textbox(
                    label="OpenAI API Key (optional)",
                    type="password",
                    placeholder="sk-..."
                )
                model_selector = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    label="Model Tier",
                    value="small",
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    label="Temperature"
                )

    # Event handlers
    clear_btn.click(lambda: None, None, chatbot, queue=False)

    def format_response_for_user(response_dict):
        """Format the assistant's response dictionary into a user-friendly string."""
        diagnoses = response_dict.get("diagnoses", [])
        confidences = response_dict.get("confidences", [])
        follow_up = response_dict.get("follow_up", "")

        result = ""
        if diagnoses:
            result += "Possible Diagnoses:\n"
            for i, diag in enumerate(diagnoses):
                conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else ""
                result += f"- {diag}{conf}\n"
        if follow_up:
            result += f"\nFollow-up: {follow_up}"
        return result.strip()
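
    # Example with hypothetical values:
    #   format_response_for_user({"diagnoses": ["J20.9"], "confidences": [0.8],
    #                             "follow_up": "Is the cough dry or productive?"})
    # returns:
    #   Possible Diagnoses:
    #   - J20.9 (80.0%)
    #
    #   Follow-up: Is the cough dry or productive?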

    def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7):
        """Handle streaming speech processing and chat updates."""
        if not audio_path:
            return history

        try:
            if isinstance(audio_path, tuple) and len(audio_path) == 2:
                sample_rate, audio_array = audio_path

                # Audio preprocessing
                if audio_array.ndim > 1:
                    audio_array = audio_array.mean(axis=1)
                audio_array = audio_array.astype(np.float32)
                peak = np.max(np.abs(audio_array))
                if peak > 0:  # guard against silent input
                    audio_array /= peak

                # Ensure correct sampling rate
                if sample_rate != 16000:
                    resampler = T.Resample(
                        orig_freq=sample_rate,
                        new_freq=16000
                    )
                    audio_tensor = torch.FloatTensor(audio_array)
                    audio_tensor = resampler(audio_tensor)
                    audio_array = audio_tensor.numpy()
                    sample_rate = 16000

                # Format input dictionary exactly as required
                transcriber_input = {
                    "raw": audio_array,
                    "sampling_rate": sample_rate
                }

                # Get transcription from Whisper
                result = transcriber(transcriber_input)

                # Extract text from result
                transcript = ""
                if isinstance(result, dict):
                    transcript = result.get("text", "").strip()
                elif isinstance(result, str):
                    transcript = result.strip()

                if not transcript:
                    return history

                # Process the symptoms
                diagnosis_query = f"""
                Based on these symptoms: '{transcript}'
                Provide relevant ICD-10 codes and diagnostic questions.
                """
                response = symptom_index.as_query_engine().query(diagnosis_query)

                # Format and return chat messages
                return history + [
                    {"role": "user", "content": transcript},
                    {"role": "assistant", "content": format_response_for_user({
                        "diagnoses": [],
                        "confidences": [],
                        "follow_up": str(response)
                    })}
                ]
            return history
        except Exception as e:
            print(f"Streaming error: {str(e)}")
            return history

    microphone.stream(
        fn=enhanced_process_speech,
        inputs=[microphone, chatbot, api_key, model_selector, temperature],
        outputs=chatbot,
        show_progress="hidden",
        api_name=False,
        queue=True  # Enable queuing for better stream handling
    )

    def process_audio(audio_array, sample_rate):
        """Pre-process audio for Whisper: mono, 16 kHz, peak-normalized."""
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        # Convert to tensor for resampling
        audio_tensor = torch.FloatTensor(audio_array)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            audio_tensor = resampler(audio_tensor)

        # Normalize (guard against silent input)
        peak = torch.max(torch.abs(audio_tensor))
        if peak > 0:
            audio_tensor = audio_tensor / peak

        # The ASR pipeline runs its own feature extraction, so return raw audio
        # in the {"raw": ..., "sampling_rate": ...} form it expects.
        return {
            "raw": audio_tensor.numpy(),
            "sampling_rate": 16000  # resampled rate
        }

    # Update transcription handler
    def update_live_transcription(audio):
        """Real-time transcription updates."""
        if not audio or not isinstance(audio, tuple):
            return ""

        try:
            sample_rate, audio_array = audio
            features = process_audio(audio_array, sample_rate)

            # Get pipeline and transcribe
            asr = get_asr_pipeline()
            result = asr(features)

            if isinstance(result, dict):
                return result.get("text", "").strip()
            elif isinstance(result, str):
                return result.strip()
            return ""
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return ""

    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=transcript_box,
        show_progress="hidden",
        queue=True
    )

    clear_btn.click(
        fn=lambda: (None, "", ""),
        outputs=[chatbot, transcript_box, text_input],
        queue=False
    )

    def cleanup_memory():
        """Release unused memory (placeholder for future memory management)."""
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def process_text_input(text, history):
        """Process text input with memory management."""
        if not text:
            return history

        try:
            # Limit input length
            if len(text) > 500:
                text = text[:500] + "..."

            # Process the symptoms
            diagnosis_query = f"""
            Based on these symptoms: '{text}'
            Provide relevant ICD-10 codes and diagnostic questions.
            Focus on clinical implications.
            Limit response to 1000 characters.
            """
            response = symptom_index.as_query_engine().query(diagnosis_query)

            # Clean up memory
            cleanup_memory()

            return history + [
                {"role": "user", "content": text},
                {"role": "assistant", "content": format_response_for_user({
                    "diagnoses": [],
                    "confidences": [],
                    "follow_up": str(response)[:1000]  # limit response length
                })}
            ]
        except Exception as e:
            print(f"Text processing error: {str(e)}")
            cleanup_memory()
            return history

    submit_btn.click(
        fn=process_text_input,
        inputs=[text_input, chatbot],
        outputs=chatbot,
        queue=True
    )

    # Add footer with social links
    gr.Markdown("""
    ---
    ### 👋 About the Creator

    Hi! I'm Graham Paasch, an experienced technology professional!

    🎥 **Check out my YouTube channel** for more tech content:
    [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

    💼 **Looking for a skilled developer?**
    I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

    ⭐ If you found this tool helpful, please consider:
    - Subscribing to my YouTube channel
    - Connecting on LinkedIn
    - Sharing this tool with others in healthcare tech
    """)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        allowed_paths=["*"]
    )