| """Gradio UI setup""" | |
| import time | |
| import gradio as gr | |
| import spaces | |
| from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL | |
| import config | |
| from indexing import create_or_update_index | |
| from pipeline import stream_chat | |
| from voice import transcribe_audio, generate_speech | |
| from models import ( | |
| initialize_medical_model, | |
| is_model_loaded, | |
| get_model_loading_state, | |
| set_model_loading_state, | |
| initialize_tts_model, | |
| initialize_whisper_model, | |
| TTS_AVAILABLE, | |
| WHISPER_AVAILABLE, | |
| ) | |
| from logger import logger | |
| MAX_DURATION = 120 | |


def create_demo():
    """Create and return the Gradio demo interface"""
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)
        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag and Drop Files Here",
                    file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )
| with gr.Column(elem_classes="chatbot-container"): | |
| chatbot = gr.Chatbot( | |
| height=500, | |
| placeholder="Chat with MedSwin... Type your question below.", | |
| show_label=False, | |
| type="messages" | |
| ) | |
| with gr.Row(elem_classes="input-row"): | |
| message_input = gr.Textbox( | |
| placeholder="Type your medical question here...", | |
| show_label=False, | |
| container=False, | |
| lines=1, | |
| scale=10 | |
| ) | |
| mic_button = gr.Audio( | |
| sources=["microphone"], | |
| type="filepath", | |
| label="", | |
| show_label=False, | |
| container=False, | |
| scale=1 | |
| ) | |
| submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1) | |
| recording_timer = gr.Textbox( | |
| value="", | |
| label="", | |
| show_label=False, | |
| interactive=False, | |
| visible=False, | |
| container=False, | |
| elem_classes="recording-timer" | |
| ) | |
| recording_start_time = [None] | |
| def handle_recording_start(): | |
| """Called when recording starts""" | |
| recording_start_time[0] = time.time() | |
| return gr.update(visible=True, value="Recording... 0s") | |
| def handle_recording_stop(audio): | |
| """Called when recording stops""" | |
| recording_start_time[0] = None | |
| if audio is None: | |
| return gr.update(visible=False, value=""), "" | |
| transcribed = transcribe_audio(audio) | |
| return gr.update(visible=False, value=""), transcribed | |
| mic_button.start_recording( | |
| fn=handle_recording_start, | |
| outputs=[recording_timer] | |
| ) | |
| mic_button.stop_recording( | |
| fn=handle_recording_stop, | |
| inputs=[mic_button], | |
| outputs=[recording_timer, message_input] | |
| ) | |
                with gr.Row(visible=False) as tts_row:
                    tts_text = gr.Textbox(visible=False)
                    tts_audio = gr.Audio(label="Generated Speech", visible=False)

                def generate_speech_from_chat(history):
                    """Extract last assistant message and generate speech"""
                    if not history or len(history) == 0:
                        return None
                    last_msg = history[-1]
                    if last_msg.get("role") == "assistant":
                        text = last_msg.get("content", "").replace(" 🔊", "").strip()
                        if text:
                            audio_path = generate_speech(text)
                            return audio_path
                    return None

                tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")

                def update_tts_button(history):
                    if history and len(history) > 0 and history[-1].get("role") == "assistant":
                        return gr.update(visible=True)
                    return gr.update(visible=False)

                chatbot.change(
                    fn=update_tts_button,
                    inputs=[chatbot],
                    outputs=[tts_button]
                )
                tts_button.click(
                    fn=generate_speech_from_chat,
                    inputs=[chatbot],
                    outputs=[tts_audio]
                )
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                disable_agentic_reasoning = gr.Checkbox(
                    value=False,
                    label="Disable agentic reasoning",
                    info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
                )
                show_agentic_thought = gr.Button(
                    "Show agentic thought",
                    size="sm"
                )
                enable_clinical_intake = gr.Checkbox(
                    value=True,
                    label="Enable clinical intake (max 5 Q&A)",
                    info="Ask focused follow-up questions before breaking down the case"
                )
            agentic_thoughts_box = gr.Textbox(
                label="Agentic Thoughts",
                placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
                lines=8,
                max_lines=15,
                interactive=False,
                visible=False,
                elem_classes="agentic-thoughts"
            )
            with gr.Row():
                use_rag = gr.Checkbox(
                    value=False,
                    label="Enable Document RAG",
                    info="Answer based on uploaded documents (upload required)"
                )
                use_web_search = gr.Checkbox(
                    value=False,
                    label="Enable Web Search (MCP)",
                    info="Fetch knowledge from online medical resources"
                )
            medical_model = gr.Radio(
                choices=list(MEDSWIN_MODELS.keys()),
                value=DEFAULT_MEDICAL_MODEL,
                label="Medical Model",
                info="MedSwin DT (default), others download on selection"
            )
            model_status = gr.Textbox(
                value="Checking model status...",
                label="Model Status",
                interactive=False,
                visible=True,
                lines=3,
                max_lines=3,
                elem_classes="model-status"
            )
            system_prompt = gr.Textbox(
                value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
                label="System Prompt",
                lines=3
            )
| with gr.Tab("Generation Parameters"): | |
| temperature = gr.Slider( | |
| minimum=0, | |
| maximum=1, | |
| step=0.1, | |
| value=0.2, | |
| label="Temperature" | |
| ) | |
| max_new_tokens = gr.Slider( | |
| minimum=512, | |
| maximum=4096, | |
| step=128, | |
| value=2048, | |
| label="Max New Tokens", | |
| info="Increased for medical models to prevent early stopping" | |
| ) | |
| top_p = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.1, | |
| value=0.7, | |
| label="Top P" | |
| ) | |
| top_k = gr.Slider( | |
| minimum=1, | |
| maximum=100, | |
| step=1, | |
| value=50, | |
| label="Top K" | |
| ) | |
| penalty = gr.Slider( | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.1, | |
| value=1.2, | |
| label="Repetition Penalty" | |
| ) | |
| with gr.Tab("Retrieval Parameters"): | |
| retriever_k = gr.Slider( | |
| minimum=5, | |
| maximum=30, | |
| step=1, | |
| value=15, | |
| label="Initial Retrieval Size (Top K)" | |
| ) | |
| merge_threshold = gr.Slider( | |
| minimum=0.1, | |
| maximum=0.9, | |
| step=0.1, | |
| value=0.5, | |
| label="Merge Threshold (lower = more merging)" | |
| ) | |
        # MedSwin Model Links
        gr.Markdown(
            """
<div style="margin-top: 20px; padding: 15px; background-color: #f5f5f5; border-radius: 8px;">
  <h4 style="margin-top: 0; margin-bottom: 10px;">🔗 MedSwin Models on Hugging Face</h4>
  <div style="display: flex; flex-wrap: wrap; gap: 10px;">
    <a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DT</a>
    <a href="https://huggingface.co/MedSwin/MedSwin-Merged-NuSLERP-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Nsl</a>
    <a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-Linear-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DL</a>
    <a href="https://huggingface.co/MedSwin/MedSwin-Merged-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Ti</a>
    <a href="https://huggingface.co/MedSwin/MedSwin-Merged-TA-SFT-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin TA</a>
    <a href="https://huggingface.co/MedSwin/MedSwin-7B-SFT" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin SFT</a>
    <a href="https://huggingface.co/MedSwin/MedSwin-7B-KD" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin KD</a>
  </div>
  <p style="margin-top: 10px; margin-bottom: 0; font-size: 11px; color: #666;">Click any model name to view details on Hugging Face</p>
</div>
"""
        )
        show_thoughts_state = gr.State(value=False)

        def toggle_thoughts_box(current_state):
            """Toggle visibility of the agentic thoughts box"""
            new_state = not current_state
            return gr.update(visible=new_state), new_state

        show_agentic_thought.click(
            fn=toggle_thoughts_box,
            inputs=[show_thoughts_state],
            outputs=[agentic_thoughts_box, show_thoughts_state]
        )
        # GPU-decorated function to load any model (for user selection)
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_model_with_gpu(model_name):
            """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
            try:
                if not is_model_loaded(model_name):
                    logger.info(f"Loading medical model: {model_name}...")
                    set_model_loading_state(model_name, "loading")
                    try:
                        initialize_medical_model(model_name)
                        logger.info(f"✅ Medical model {model_name} loaded successfully!")
                        return "✅ The model has been loaded successfully", True
                    except Exception as e:
                        logger.error(f"Failed to load medical model {model_name}: {e}")
                        set_model_loading_state(model_name, "error")
                        return f"❌ Error loading model: {str(e)[:100]}", False
                else:
                    logger.info(f"Medical model {model_name} is already loaded")
                    return "✅ The model has been loaded successfully", True
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                return f"❌ Error: {str(e)[:100]}", False
        def load_model_and_update_status(model_name):
            """Load model and update status, return status text and whether model is ready"""
            try:
                status_lines = []
                # Medical model status
                if is_model_loaded(model_name):
                    status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
                else:
                    state = get_model_loading_state(model_name)
                    if state == "loading":
                        status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                    elif state == "error":
                        status_lines.append(f"❌ MedSwin ({model_name}): error loading")
                    else:
                        # Use GPU-decorated function to load the model
                        try:
                            result = load_model_with_gpu(model_name)
                            if result and isinstance(result, tuple) and len(result) == 2:
                                status_text, is_ready = result
                                if is_ready:
                                    status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
                                else:
                                    status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                            else:
                                status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                        except Exception as e:
                            logger.error(f"Error calling load_model_with_gpu: {e}")
                            status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                # TTS model status (only show if available or if there's an issue)
                if TTS_AVAILABLE:
                    if config.global_tts_model is not None:
                        status_lines.append("✅ TTS (maya1): loaded and ready")
                    else:
                        # TTS available but not loaded - optional feature
                        pass  # Don't show if not loaded, it's optional
                # Don't show TTS status if library not available (it's optional)
                # ASR (Whisper) model status
                if WHISPER_AVAILABLE:
                    if config.global_whisper_model is not None:
                        status_lines.append("✅ ASR (Whisper): loaded and ready")
                    else:
                        status_lines.append("⏳ ASR (Whisper): will load on first use")
                else:
                    status_lines.append("❌ ASR: library not available")
                status_text = "\n".join(status_lines)
                is_ready = is_model_loaded(model_name)
                return status_text, is_ready
            except Exception as e:
                return f"❌ Error: {str(e)[:100]}", False
        def check_model_status(model_name):
            """Check current model status without loading"""
            status_lines = []
            # Medical model status
            if is_model_loaded(model_name):
                status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
            else:
                state = get_model_loading_state(model_name)
                if state == "loading":
                    status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                elif state == "error":
                    status_lines.append(f"❌ MedSwin ({model_name}): error loading")
                else:
                    status_lines.append(f"⚠️ MedSwin ({model_name}): not loaded")
            # TTS model status (only show if available and loaded)
            if TTS_AVAILABLE:
                if config.global_tts_model is not None:
                    status_lines.append("✅ TTS (maya1): loaded and ready")
                # Don't show if TTS library available but model not loaded (optional feature)
            # Don't show TTS status if library not available (it's optional)
            # ASR (Whisper) model status
            if WHISPER_AVAILABLE:
                if config.global_whisper_model is not None:
                    status_lines.append("✅ ASR (Whisper): loaded and ready")
                else:
                    status_lines.append("⏳ ASR (Whisper): will load on first use")
            else:
                status_lines.append("❌ ASR: library not available")
            status_text = "\n".join(status_lines)
            is_ready = is_model_loaded(model_name)
            return status_text, is_ready
        # Functions to load ONLY the medical model on startup.
        # According to ZeroGPU best practices:
        # 1. Load models to CPU in global scope (no GPU decorator needed)
        # 2. Move models to GPU only in inference functions (with the @spaces.GPU decorator)
        # However, for large models, loading to CPU and then moving to GPU uses more memory,
        # so two loaders are provided below: a CPU-first loader (used at startup) and an
        # alternative that loads directly to GPU from within a GPU-decorated function.
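        # Illustrative sketch (not part of this app's flow) of the canonical ZeroGPU
        # pattern described above -- the names below are hypothetical:
        #
        #     model = AutoModelForCausalLM.from_pretrained(MODEL_ID)  # CPU, global scope
        #
        #     @spaces.GPU
        #     def infer(prompt):
        #         model.to("cuda")  # move to GPU only inside the decorated call
        #         return run_generation(model, prompt)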
        def load_medical_model_on_startup_cpu():
            """
            Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
            Model will be moved to GPU during first inference
            """
            status_messages = []
            try:
                # Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
                    logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
                    try:
                        # Load to CPU (no GPU decorator needed)
                        initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
                        # Verify model is actually loaded
                        if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                            status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
                            logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
                        else:
                            status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                            logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
                            set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                    except Exception as e:
                        status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
                        logger.error(f"[STARTUP] Failed to load medical model: {e}")
                        import traceback
                        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                else:
                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
                    logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
                # Add ASR status (will load on first use)
                if WHISPER_AVAILABLE:
                    status_messages.append("⏳ ASR (Whisper): will load on first use")
                else:
                    status_messages.append("❌ ASR: library not available")
                # Return status
                status_text = "\n".join(status_messages)
                logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
                return status_text
            except Exception as e:
                error_msg = str(e)
                logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
                return f"⚠️ Error loading model: {error_msg[:100]}"
        # Alternative: Load directly to GPU (requires GPU decorator)
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_medical_model_on_startup_gpu():
            """
            Load model directly to GPU on startup (alternative approach)
            Uses GPU quota but model is immediately ready for inference
            """
            import torch
            status_messages = []
            try:
                # Clear GPU cache at start
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("[STARTUP] Cleared GPU cache before model loading")
                # Load only medical model (MedSwin) - TTS and Whisper load on-demand
                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
                    logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
                    try:
                        # Load directly to GPU (within GPU-decorated function)
                        initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
                        # Verify model is actually loaded
                        if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                            status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
                            logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
                        else:
                            status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                            logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
                            set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                    except Exception as e:
                        status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
                        logger.error(f"[STARTUP] Failed to load medical model: {e}")
                        import traceback
                        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                else:
                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
                    logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
                # Add ASR status (will load on first use)
                if WHISPER_AVAILABLE:
                    status_messages.append("⏳ ASR (Whisper): will load on first use")
                else:
                    status_messages.append("❌ ASR: library not available")
                # Clear cache after loading
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.debug("[STARTUP] Cleared GPU cache after model loading")
                # Return status
                status_text = "\n".join(status_messages)
                logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
                return status_text
            except Exception as e:
                error_msg = str(e)
                # Check if it's a ZeroGPU quota/rate limit error
                is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
                                  "quota" in error_msg.lower() or "ZeroGPU" in error_msg or
                                  "runnning out" in error_msg.lower() or "running out" in error_msg.lower())
                if is_quota_error:
                    logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
                    # Return status message indicating quota error (will be handled by retry logic)
                    status_messages.append("⚠️ ZeroGPU quota error - will retry")
                    status_text = "\n".join(status_messages)
                    # Also add ASR status
                    if WHISPER_AVAILABLE:
                        status_text += "\n⏳ ASR (Whisper): will load on first use"
                    return status_text  # Return status instead of raising, let wrapper handle retry
                logger.error(f"[STARTUP] ❌ Error in model loading startup: {e}")
                import traceback
                logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                # Clear cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return f"⚠️ Startup loading error: {str(e)[:100]}"
        # Initialize status on load
        def init_model_status():
            try:
                result = check_model_status(DEFAULT_MEDICAL_MODEL)
                if result and isinstance(result, tuple) and len(result) == 2:
                    status_text, is_ready = result
                    return status_text
                else:
                    return "⚠️ Unable to check model status"
            except Exception as e:
                logger.error(f"Error in init_model_status: {e}")
                return f"⚠️ Error: {str(e)[:100]}"

        # Update status when model selection changes
        def update_model_status_on_change(model_name):
            try:
                result = check_model_status(model_name)
                if result and isinstance(result, tuple) and len(result) == 2:
                    status_text, is_ready = result
                    return status_text
                else:
                    return "⚠️ Unable to check model status"
            except Exception as e:
                logger.error(f"Error in update_model_status_on_change: {e}")
                return f"⚠️ Error: {str(e)[:100]}"

        # Handle model selection change
        def on_model_change(model_name):
            try:
                result = load_model_and_update_status(model_name)
                if result and isinstance(result, tuple) and len(result) == 2:
                    status_text, is_ready = result
                    submit_enabled = is_ready
                    return (
                        status_text,
                        gr.update(interactive=submit_enabled),
                        gr.update(interactive=submit_enabled)
                    )
                else:
                    error_msg = "⚠️ Unable to load model status"
                    return (
                        error_msg,
                        gr.update(interactive=False),
                        gr.update(interactive=False)
                    )
            except Exception as e:
                logger.error(f"Error in on_model_change: {e}")
                error_msg = f"⚠️ Error: {str(e)[:100]}"
                return (
                    error_msg,
                    gr.update(interactive=False),
                    gr.update(interactive=False)
                )

        # Update status display periodically or on model status changes
        def refresh_model_status(model_name):
            return update_model_status_on_change(model_name)

        medical_model.change(
            fn=on_model_change,
            inputs=[medical_model],
            outputs=[model_status, submit_button, message_input]
        )
        # GPU-decorated function to load Whisper ASR model on-demand
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_whisper_model_on_demand():
            """Load Whisper ASR model when needed"""
            try:
                if WHISPER_AVAILABLE and config.global_whisper_model is None:
                    logger.info("[ASR] Loading Whisper model on-demand...")
                    initialize_whisper_model()
                    if config.global_whisper_model is not None:
                        logger.info("[ASR] ✅ Whisper model loaded successfully!")
                        return "✅ ASR (Whisper): loaded"
                    else:
                        logger.warning("[ASR] ⚠️ Whisper model failed to load")
                        return "⚠️ ASR (Whisper): failed to load"
                elif config.global_whisper_model is not None:
                    return "✅ ASR (Whisper): already loaded"
                else:
                    return "❌ ASR: library not available"
            except Exception as e:
                logger.error(f"[ASR] Error loading Whisper model: {e}")
                return f"❌ ASR: error - {str(e)[:100]}"
        # Load the medical model on startup and update status.
        # Use a wrapper to handle the GPU context properly, with retry logic.
        def load_startup_and_update_ui():
            """
            Load model on startup with retry logic (max 3 attempts) and return status with UI updates
            Uses CPU-first approach (ZeroGPU best practice):
            - Load model to CPU (no GPU decorator needed, avoids quota issues)
            - Model will be moved to GPU during first inference
            """
            import time
            max_retries = 3
            base_delay = 5.0  # Start with a 5-second delay
            for attempt in range(1, max_retries + 1):
                try:
                    logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
                    # Use CPU-first approach (no GPU decorator, avoids quota issues)
                    status_text = load_medical_model_on_startup_cpu()
                    # Check if the model is ready and update the submit button state
                    is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
                    if is_ready:
                        logger.info(f"[STARTUP] ✅ Model loaded successfully on attempt {attempt}")
                        return status_text, gr.update(interactive=is_ready), gr.update(interactive=is_ready)
                    else:
                        # Check if the status text indicates a quota error
                        if status_text and ("quota" in status_text.lower() or "ZeroGPU" in status_text or
                                            "429" in status_text or "runnning out" in status_text.lower() or
                                            "running out" in status_text.lower()):
                            if attempt < max_retries:
                                delay = base_delay * attempt
                                logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
                                time.sleep(delay)
                                continue
                            else:
                                # Quota exhausted after retries - allow user to proceed, model will load on-demand
                                status_msg = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
                                logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
                                return status_msg, gr.update(interactive=True), gr.update(interactive=True)
                        # Model didn't load, but no exception - might be a state issue
                        logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
                        if attempt < max_retries:
                            delay = base_delay * attempt  # Linear backoff: 5s, 10s, 15s
                            logger.info(f"[STARTUP] Retrying in {delay} seconds...")
                            time.sleep(delay)
                            continue
                        else:
                            # Even if the model didn't load, allow the user to try selecting another model
                            return status_text + "\n⚠️ Model not loaded. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
                except Exception as e:
                    error_msg = str(e)
                    is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
                                      "quota" in error_msg.lower() or "ZeroGPU" in error_msg or
                                      "runnning out" in error_msg.lower() or "running out" in error_msg.lower())
                    if is_quota_error and attempt < max_retries:
                        delay = base_delay * attempt  # Linear backoff: 5s, 10s, 15s
                        logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
                        logger.info(f"[STARTUP] Retrying in {delay} seconds...")
                        time.sleep(delay)
                        continue
                    else:
                        logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
                        import traceback
                        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                        if is_quota_error:
                            # If quota exhausted, allow user to proceed - model will load on-demand
                            error_display = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
                            logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
                            return error_display, gr.update(interactive=True), gr.update(interactive=True)
                        else:
                            error_display = f"⚠️ Startup error: {str(e)[:100]}"
                        if attempt >= max_retries:
                            logger.error(f"[STARTUP] Failed after {max_retries} attempts")
                            return error_display, gr.update(interactive=False), gr.update(interactive=False)
            # Should not reach here, but just in case
            return "⚠️ Startup failed after retries. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
        demo.load(
            fn=load_startup_and_update_ui,
            inputs=None,
            outputs=[model_status, submit_button, message_input]
        )

        # Note: the preload-on-focus functionality was removed because:
        # 1. Model loading requires GPU access (device_map="auto" needs a GPU in ZeroGPU)
        # 2. The startup function already loads the model
        # 3. Preloading without the GPU decorator would fail or cause conflicts
        # 4. If startup fails, the user can select a model from the dropdown to trigger loading
        # Wrap stream_chat - ensure the model is loaded before starting
        # (don't load inside stream_chat to save time)
        def stream_chat_with_model_check(
            message, history, system_prompt, temperature, max_new_tokens,
            top_p, top_k, penalty, retriever_k, merge_threshold,
            use_rag, medical_model_name, use_web_search,
            enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
        ):
            # Check if the model is loaded - if not, show an error (don't load here to save stream_chat time)
            model_loaded = is_model_loaded(medical_model_name)
            if not model_loaded:
                loading_state = get_model_loading_state(medical_model_name)
                # Debug logging to understand why the model check fails
                logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
                if loading_state == "loading":
                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
                else:
                    error_msg = f"⚠️ {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
                updated_history = history + [{"role": "assistant", "content": error_msg}]
                yield updated_history, ""
                return
            # If request is None, create a mock request for compatibility
            if request is None:
                class MockRequest:
                    session_hash = "anonymous"
                request = MockRequest()
            # Model is loaded, proceed with stream_chat (no model loading here to save time)
            try:
                for result in stream_chat(
                    message, history, system_prompt, temperature, max_new_tokens,
                    top_p, top_k, penalty, retriever_k, merge_threshold,
                    use_rag, medical_model_name, use_web_search,
                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
                ):
                    yield result
            except Exception as e:
                # Handle any errors gracefully
                logger.error(f"Error in stream_chat_with_model_check: {e}")
                import traceback
                logger.debug(f"Full traceback: {traceback.format_exc()}")
                error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
                updated_history = history + [{"role": "assistant", "content": error_msg}]
                yield updated_history, ""
        submit_button.click(
            fn=stream_chat_with_model_check,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold,
                use_rag,
                medical_model,
                use_web_search,
                enable_clinical_intake,
                disable_agentic_reasoning,
                show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )
        message_input.submit(
            fn=stream_chat_with_model_check,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold,
                use_rag,
                medical_model,
                use_web_search,
                enable_clinical_intake,
                disable_agentic_reasoning,
                show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )
    return demo
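
# Minimal usage sketch (assumption: the Space's entry point lives in a separate
# module such as app.py, and this file is importable as `interface` -- both
# names are hypothetical):
#
#     from interface import create_demo
#
#     demo = create_demo()
#     demo.queue().launch()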