"""Gradio UI setup""" import time import gradio as gr import spaces from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL import config from indexing import create_or_update_index from pipeline import stream_chat from voice import transcribe_audio, generate_speech from models import ( initialize_medical_model, is_model_loaded, get_model_loading_state, set_model_loading_state, initialize_tts_model, initialize_whisper_model, TTS_AVAILABLE, WHISPER_AVAILABLE, ) from logger import logger def create_demo(): """Create and return Gradio demo interface""" with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo: gr.HTML(TITLE) gr.HTML(DESCRIPTION) with gr.Row(elem_classes="main-container"): with gr.Column(elem_classes="upload-section"): file_upload = gr.File( file_count="multiple", label="Drag and Drop Files Here", file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"], elem_id="file-upload" ) upload_button = gr.Button("Upload & Index", elem_classes="upload-button") status_output = gr.Textbox( label="Status", placeholder="Upload files to start...", interactive=False ) file_info_output = gr.HTML( label="File Information", elem_classes="processing-info" ) upload_button.click( fn=create_or_update_index, inputs=[file_upload], outputs=[status_output, file_info_output] ) with gr.Column(elem_classes="chatbot-container"): chatbot = gr.Chatbot( height=500, placeholder="Chat with MedSwin... Type your question below.", show_label=False, type="messages" ) with gr.Row(elem_classes="input-row"): message_input = gr.Textbox( placeholder="Type your medical question here...", show_label=False, container=False, lines=1, scale=10 ) mic_button = gr.Audio( sources=["microphone"], type="filepath", label="", show_label=False, container=False, scale=1 ) submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1) recording_timer = gr.Textbox( value="", label="", show_label=False, interactive=False, visible=False, container=False, elem_classes="recording-timer" ) recording_start_time = [None] def handle_recording_start(): """Called when recording starts""" recording_start_time[0] = time.time() return gr.update(visible=True, value="Recording... 
0s") def handle_recording_stop(audio): """Called when recording stops""" recording_start_time[0] = None if audio is None: return gr.update(visible=False, value=""), "" transcribed = transcribe_audio(audio) return gr.update(visible=False, value=""), transcribed mic_button.start_recording( fn=handle_recording_start, outputs=[recording_timer] ) mic_button.stop_recording( fn=handle_recording_stop, inputs=[mic_button], outputs=[recording_timer, message_input] ) with gr.Row(visible=False) as tts_row: tts_text = gr.Textbox(visible=False) tts_audio = gr.Audio(label="Generated Speech", visible=False) def generate_speech_from_chat(history): """Extract last assistant message and generate speech""" if not history or len(history) == 0: return None last_msg = history[-1] if last_msg.get("role") == "assistant": text = last_msg.get("content", "").replace(" 🔊", "").strip() if text: audio_path = generate_speech(text) return audio_path return None tts_button = gr.Button("🔊 Play Response", visible=False, size="sm") def update_tts_button(history): if history and len(history) > 0 and history[-1].get("role") == "assistant": return gr.update(visible=True) return gr.update(visible=False) chatbot.change( fn=update_tts_button, inputs=[chatbot], outputs=[tts_button] ) tts_button.click( fn=generate_speech_from_chat, inputs=[chatbot], outputs=[tts_audio] ) with gr.Accordion("⚙️ Advanced Settings", open=False): with gr.Row(): disable_agentic_reasoning = gr.Checkbox( value=False, label="Disable agentic reasoning", info="Use MedSwin model alone without agentic reasoning, RAG, or web search" ) show_agentic_thought = gr.Button( "Show agentic thought", size="sm" ) enable_clinical_intake = gr.Checkbox( value=True, label="Enable clinical intake (max 5 Q&A)", info="Ask focused follow-up questions before breaking down the case" ) agentic_thoughts_box = gr.Textbox( label="Agentic Thoughts", placeholder="Internal thoughts from MedSwin and supervisor will appear here...", lines=8, max_lines=15, interactive=False, visible=False, elem_classes="agentic-thoughts" ) with gr.Row(): use_rag = gr.Checkbox( value=False, label="Enable Document RAG", info="Answer based on uploaded documents (upload required)" ) use_web_search = gr.Checkbox( value=False, label="Enable Web Search (MCP)", info="Fetch knowledge from online medical resources" ) medical_model = gr.Radio( choices=list(MEDSWIN_MODELS.keys()), value=DEFAULT_MEDICAL_MODEL, label="Medical Model", info="MedSwin DT (default), others download on selection" ) model_status = gr.Textbox( value="Checking model status...", label="Model Status", interactive=False, visible=True, elem_classes="model-status" ) system_prompt = gr.Textbox( value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. 
                    with gr.Tab("Generation Parameters"):
                        temperature = gr.Slider(
                            minimum=0, maximum=1, step=0.1, value=0.2,
                            label="Temperature"
                        )
                        max_new_tokens = gr.Slider(
                            minimum=512, maximum=4096, step=128, value=2048,
                            label="Max New Tokens",
                            info="Increased for medical models to prevent early stopping"
                        )
                        top_p = gr.Slider(
                            minimum=0.0, maximum=1.0, step=0.1, value=0.7,
                            label="Top P"
                        )
                        top_k = gr.Slider(
                            minimum=1, maximum=100, step=1, value=50,
                            label="Top K"
                        )
                        penalty = gr.Slider(
                            minimum=0.0, maximum=2.0, step=0.1, value=1.2,
                            label="Repetition Penalty"
                        )
                    with gr.Tab("Retrieval Parameters"):
                        retriever_k = gr.Slider(
                            minimum=5, maximum=30, step=1, value=15,
                            label="Initial Retrieval Size (Top K)"
                        )
                        merge_threshold = gr.Slider(
                            minimum=0.1, maximum=0.9, step=0.1, value=0.5,
                            label="Merge Threshold (lower = more merging)"
                        )

        show_thoughts_state = gr.State(value=False)

        def toggle_thoughts_box(current_state):
            """Toggle visibility of agentic thoughts box"""
            new_state = not current_state
            return gr.update(visible=new_state), new_state

        show_agentic_thought.click(
            fn=toggle_thoughts_box,
            inputs=[show_thoughts_state],
            outputs=[agentic_thoughts_box, show_thoughts_state]
        )

        # GPU-decorated function to load any model (for user selection).
        # Note: the ZeroGPU decorator's keyword is `duration` (seconds).
        @spaces.GPU(duration=120)
        def load_model_with_gpu(model_name):
            """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
            try:
                if not is_model_loaded(model_name):
                    logger.info(f"Loading medical model: {model_name}...")
                    set_model_loading_state(model_name, "loading")
                    try:
                        initialize_medical_model(model_name)
                        logger.info(f"✅ Medical model {model_name} loaded successfully!")
                        return "✅ The model has been loaded successfully", True
                    except Exception as e:
                        logger.error(f"Failed to load medical model {model_name}: {e}")
                        set_model_loading_state(model_name, "error")
                        return f"❌ Error loading model: {str(e)[:100]}", False
                else:
                    logger.info(f"Medical model {model_name} is already loaded")
                    return "✅ The model has been loaded successfully", True
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                return f"❌ Error: {str(e)[:100]}", False

        def load_model_and_update_status(model_name):
            """Load model and update status; return status text and whether the model is ready"""
            try:
                status_lines = []

                # Medical model status
                if is_model_loaded(model_name):
                    status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
                else:
                    state = get_model_loading_state(model_name)
                    if state == "loading":
                        status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                    elif state == "error":
                        status_lines.append(f"❌ MedSwin ({model_name}): error loading")
                    else:
                        # Use the GPU-decorated function to load the model
                        status_text, is_ready = load_model_with_gpu(model_name)
                        if is_ready:
                            status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
                        else:
                            status_lines.append(f"⏳ MedSwin ({model_name}): loading...")

                # TTS model status
                if TTS_AVAILABLE:
                    if config.global_tts_model is not None:
                        status_lines.append("✅ TTS (maya1): loaded and ready")
                    else:
                        status_lines.append("⚠️ TTS (maya1): not loaded")
                else:
                    status_lines.append("❌ TTS: library not available")

                # ASR (Whisper) model status
                if WHISPER_AVAILABLE:
                    if config.global_whisper_model is not None:
                        status_lines.append("✅ ASR (Whisper large-v3-turbo): loaded and ready")
                    else:
                        status_lines.append("⚠️ ASR (Whisper large-v3-turbo): not loaded")
                else:
                    status_lines.append("❌ ASR: library not available")

                status_text = "\n".join(status_lines)
                is_ready = is_model_loaded(model_name)
                return status_text, is_ready
            except Exception as e:
                return f"❌ Error: {str(e)[:100]}", False
        def check_model_status(model_name):
            """Check current model status without loading"""
            status_lines = []

            # Medical model status
            if is_model_loaded(model_name):
                status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
            else:
                state = get_model_loading_state(model_name)
                if state == "loading":
                    status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                elif state == "error":
                    status_lines.append(f"❌ MedSwin ({model_name}): error loading")
                else:
                    status_lines.append(f"⚠️ MedSwin ({model_name}): not loaded")

            # TTS model status
            if TTS_AVAILABLE:
                if config.global_tts_model is not None:
                    status_lines.append("✅ TTS (maya1): loaded and ready")
                else:
                    status_lines.append("⚠️ TTS (maya1): not loaded")
            else:
                status_lines.append("❌ TTS: library not available")

            # ASR (Whisper) model status
            if WHISPER_AVAILABLE:
                if config.global_whisper_model is not None:
                    status_lines.append("✅ ASR (Whisper large-v3-turbo): loaded and ready")
                else:
                    status_lines.append("⚠️ ASR (Whisper large-v3-turbo): not loaded")
            else:
                status_lines.append("❌ ASR: library not available")

            status_text = "\n".join(status_lines)
            is_ready = is_model_loaded(model_name)
            return status_text, is_ready

        # GPU-decorated function to load the default model on startup
        @spaces.GPU(duration=120)
        def load_default_model_on_startup():
            """Load default medical model on startup (GPU-decorated for ZeroGPU compatibility)"""
            try:
                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
                    logger.info(f"Loading default medical model on startup: {DEFAULT_MEDICAL_MODEL}...")
                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
                    try:
                        initialize_medical_model(DEFAULT_MEDICAL_MODEL)
                        logger.info(f"✅ Default medical model {DEFAULT_MEDICAL_MODEL} loaded successfully on startup!")
                        return f"✅ {DEFAULT_MEDICAL_MODEL} loaded successfully"
                    except Exception as e:
                        logger.error(f"Failed to load default medical model on startup: {e}")
                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                        return f"❌ Error loading model: {str(e)[:100]}"
                else:
                    logger.info(f"Default medical model {DEFAULT_MEDICAL_MODEL} is already loaded")
                    return f"✅ {DEFAULT_MEDICAL_MODEL} is ready"
            except Exception as e:
                logger.error(f"Error in model loading startup: {e}")
                return f"⚠️ Startup loading error: {str(e)[:100]}"

        # GPU-decorated function to load the default TTS and ASR models on startup
        @spaces.GPU(duration=120)
        def load_voice_models_on_startup():
            """Load default TTS model (maya1) and ASR model (Whisper) on startup"""
            try:
                # Load TTS model
                if TTS_AVAILABLE:
                    logger.info("Loading default TTS model (maya1) on startup...")
                    initialize_tts_model()
                    if config.global_tts_model is not None:
                        logger.info("✅ Default TTS model (maya1) loaded successfully on startup!")
                    else:
                        logger.warning("⚠️ TTS model failed to load on startup")
                else:
                    logger.warning("TTS library not installed; skipping TTS preload.")

                # Load ASR (Whisper) model
                if WHISPER_AVAILABLE:
                    logger.info("Loading default ASR model (Whisper large-v3-turbo) on startup...")
                    initialize_whisper_model()
                    if config.global_whisper_model is not None:
                        logger.info("✅ Default ASR model (Whisper large-v3-turbo) loaded successfully on startup!")
                    else:
                        logger.warning("⚠️ ASR model failed to load on startup")
                else:
                    logger.warning("Whisper transformers not installed; skipping ASR preload.")
            except Exception as e:
                logger.error(f"Error in voice models loading startup: {e}")
                import traceback
                logger.debug(f"Full traceback: {traceback.format_exc()}")
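
        # Note: on ZeroGPU Spaces, each @spaces.GPU-decorated call runs in its own
        # short-lived GPU session, which is why the medical-model and voice-model
        # preloads are chained as separate demo.load() steps below rather than
        # combined into a single loader.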
        # Initialize status on load
        def init_model_status():
            status_text, is_ready = check_model_status(DEFAULT_MEDICAL_MODEL)
            return status_text

        # Update status when model selection changes
        def update_model_status_on_change(model_name):
            status_text, is_ready = check_model_status(model_name)
            return status_text

        # Handle model selection change
        def on_model_change(model_name):
            status_text, is_ready = load_model_and_update_status(model_name)
            submit_enabled = is_ready
            return (
                status_text,
                gr.update(interactive=submit_enabled),
                gr.update(interactive=submit_enabled)
            )

        # Update status display periodically or on model status changes
        def refresh_model_status(model_name):
            return update_model_status_on_change(model_name)

        medical_model.change(
            fn=on_model_change,
            inputs=[medical_model],
            outputs=[model_status, submit_button, message_input]
        )

        # Load models on startup; they will be loaded in separate GPU sessions.
        # First load the medical model
        demo.load(
            fn=load_default_model_on_startup,
            inputs=None,
            outputs=[model_status]
        )
        # Then load the voice models (TTS and ASR)
        demo.load(
            fn=load_voice_models_on_startup,
            inputs=None,
            outputs=None
        )
        # Finally update status to show all models
        demo.load(
            fn=lambda: check_model_status(DEFAULT_MEDICAL_MODEL)[0],
            inputs=None,
            outputs=[model_status]
        )

        # Wrap stream_chat to check model status before execution
        def stream_chat_with_model_check(
            message, history, system_prompt, temperature, max_new_tokens,
            top_p, top_k, penalty, retriever_k, merge_threshold, use_rag,
            medical_model_name, use_web_search, enable_clinical_intake,
            disable_agentic_reasoning, show_thoughts,
            request: gr.Request = None
        ):
            try:
                # Check whether the model is loaded; try to load it if not
                if not is_model_loaded(medical_model_name):
                    status_text, is_ready = load_model_and_update_status(medical_model_name)
                    if not is_ready:
                        error_msg = (
                            "⚠️ Model is not ready. Please wait for the model to finish "
                            "loading before sending messages."
                        )
                        updated_history = history + [{"role": "assistant", "content": error_msg}]
                        yield updated_history, ""
                        return

                # Model is ready; proceed with chat
                if request is None:
                    # If request is None, create a mock request for compatibility
                    class MockRequest:
                        session_hash = "anonymous"
                    request = MockRequest()

                for result in stream_chat(
                    message, history, system_prompt, temperature, max_new_tokens,
                    top_p, top_k, penalty, retriever_k, merge_threshold, use_rag,
                    medical_model_name, use_web_search, enable_clinical_intake,
                    disable_agentic_reasoning, show_thoughts, request
                ):
                    yield result
            except Exception as e:
                # Handle any errors gracefully
                error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
                updated_history = history + [{"role": "assistant", "content": error_msg}]
                yield updated_history, ""

        submit_button.click(
            fn=stream_chat_with_model_check,
            inputs=[
                message_input, chatbot, system_prompt, temperature, max_new_tokens,
                top_p, top_k, penalty, retriever_k, merge_threshold, use_rag,
                medical_model, use_web_search, enable_clinical_intake,
                disable_agentic_reasoning, show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )
        message_input.submit(
            fn=stream_chat_with_model_check,
            inputs=[
                message_input, chatbot, system_prompt, temperature, max_new_tokens,
                top_p, top_k, penalty, retriever_k, merge_threshold, use_rag,
                medical_model, use_web_search, enable_clinical_intake,
                disable_agentic_reasoning, show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )

    return demo
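

# A minimal entrypoint sketch for local runs, assuming this module is imported by
# (or serves as) the Space's app file; the queue()/launch() settings here are
# illustrative, not taken from the original source.
if __name__ == "__main__":
    demo = create_demo()
    demo.queue().launch()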