"""Gradio UI setup for the MedSwin medical chat app."""
import os
import time
import traceback
import gradio as gr
import spaces
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
import config
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import (
    initialize_medical_model,
    is_model_loaded,
    get_model_loading_state,
    set_model_loading_state,
    initialize_tts_model,
    initialize_whisper_model,
    TTS_AVAILABLE,
    SNAC_AVAILABLE,
    WHISPER_AVAILABLE,
)
from logger import logger

MAX_DURATION = 120

# Prefixes marking an assistant message as a status/error notice rather than an answer.
_NOTICE_PREFIXES = ("āš ļø", "ā³")


def _is_quota_error(text):
    """Return True if ``text`` looks like a ZeroGPU quota / rate-limit error message."""
    if not text:
        return False
    lowered = text.lower()
    return (
        "429" in text
        or "Too Many Requests" in text
        or "ZeroGPU" in text
        or "quota" in lowered
        # Both spellings are checked; the triple-n one presumably matches an
        # upstream error message verbatim.
        or "runnning out" in lowered
        or "running out" in lowered
    )


def _looks_like_stream_abort(exc):
    """Heuristically decide whether ``exc`` signals that the client stream went away."""
    message = str(exc).lower()
    return any(marker in message for marker in ("abort", "stream", "buffer"))


def _yield_guarded(value, what):
    """Yield ``value`` once; ignore the error if the client already disconnected.

    Meant to be used with ``yield from`` for a streaming handler's final output.
    An error thrown in at the yield whose text looks like a stream abort is
    logged and swallowed; any other error propagates.
    """
    try:
        yield value
    except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
        if _looks_like_stream_abort(yield_error):
            logger.info(f"[UI] Stream aborted while yielding {what}, ignoring")
        else:
            raise


def _has_valid_answer(result):
    """Return True if ``result`` already contains a usable assistant answer.

    ``result`` is one ``(history, thoughts)`` item yielded by ``stream_chat``.
    Only the last assistant message in ``history`` is inspected; it counts as an
    answer when it is a non-blank string that does not start with a notice prefix.
    """
    if result is None:
        return False
    try:
        history, _thoughts = result
    except (TypeError, ValueError) as parse_error:
        logger.debug(f"Error parsing last_result: {parse_error}")
        return False
    if not history or not isinstance(history, list):
        return False
    for msg in reversed(history):
        if isinstance(msg, dict) and msg.get("role") == "assistant":
            content = msg.get("content", "")
            if not isinstance(content, str):
                return False
            stripped = content.strip()
            return bool(stripped) and not stripped.startswith(_NOTICE_PREFIXES)
    return False


def _auxiliary_model_status_lines():
    """Return status lines for the optional TTS and ASR (Whisper) models.

    TTS is only reported once its model is loaded; ASR always gets one line.
    """
    lines = []
    if SNAC_AVAILABLE and config.global_tts_model is not None:
        lines.append("āœ… TTS (maya1): loaded and ready")
    if WHISPER_AVAILABLE:
        if config.global_whisper_model is not None:
            lines.append("āœ… ASR (Whisper): loaded and ready")
        else:
            lines.append("ā³ ASR (Whisper): will load on first use")
    else:
        lines.append("āŒ ASR: library not available")
    return lines


def create_demo():
    """Build the MedSwin Gradio interface and wire up its event handlers.

    Layout: a document upload/indexing column beside a chat column (chatbot,
    text and microphone input, text-to-speech playback, and an "Advanced
    Settings" accordion with model selection and generation/retrieval
    parameters).

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)

        with gr.Row(elem_classes="main-container"):
            # Left column: document upload and indexing (used by Document RAG).
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag and Drop Files Here",
                    file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
                    elem_id="file-upload",
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False,
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info",
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output],
                )

            # Right column: chat, voice input/output and settings.
            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with MedSwin... Type your question below.",
                    show_label=False,
                    type="messages",
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your medical question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=10,
                    )
                    mic_button = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="",
                        show_label=False,
                        container=False,
                        scale=1,
                    )
                    submit_button = gr.Button("āž¤", elem_classes="submit-btn", scale=1)
                recording_timer = gr.Textbox(
                    value="",
                    label="",
                    show_label=False,
                    interactive=False,
                    visible=False,
                    container=False,
                    elem_classes="recording-timer",
                )

                # Mutable cell shared by the recording handlers below; it is
                # written here but not read anywhere in this function.
                recording_start_time = [None]

                def handle_recording_start():
                    """Record the start time and show the (static) recording indicator."""
                    recording_start_time[0] = time.time()
                    return gr.update(visible=True, value="Recording... 0s")

                def handle_recording_stop(audio):
                    """Hide the indicator and put the transcription (if any) in the message box."""
                    recording_start_time[0] = None
                    if audio is None:
                        return gr.update(visible=False, value=""), ""
                    transcribed = transcribe_audio(audio)
                    return gr.update(visible=False, value=""), transcribed

                mic_button.start_recording(
                    fn=handle_recording_start,
                    outputs=[recording_timer],
                )
                mic_button.stop_recording(
                    fn=handle_recording_stop,
                    inputs=[mic_button],
                    outputs=[recording_timer, message_input],
                )

                with gr.Row():
                    tts_button = gr.Button("šŸ”Š Play Response", visible=False, size="sm")
                    tts_audio = gr.Audio(label="", visible=True, autoplay=True, show_label=False, container=False)

                def generate_speech_from_chat(history):
                    """Synthesize speech for the last assistant message.

                    Returns:
                        The generated audio file path, or None when there is no
                        history, the last message is not from the assistant, its
                        text is empty, or generation fails (each case is logged).
                    """
                    if not history:
                        logger.warning("[TTS] No history available")
                        return None
                    last_msg = history[-1]
                    role = last_msg.get("role")
                    if role != "assistant":
                        logger.warning(f"[TTS] Last message is not from assistant: {role}")
                        return None
                    content = last_msg.get("content", "")
                    # Content may be None or a non-text payload; treat it as empty
                    # instead of letting `.replace` raise outside the try below.
                    if not isinstance(content, str):
                        content = ""
                    text = content.replace(" šŸ”Š", "").strip()
                    if not text:
                        logger.warning("[TTS] Empty text extracted from assistant message")
                        return None
                    logger.info(f"[TTS] Generating speech for text: {text[:100]}...")
                    try:
                        audio_path = generate_speech(text)
                        if audio_path and os.path.exists(audio_path):
                            logger.info(f"[TTS] āœ… Generated audio successfully: {audio_path}")
                            return audio_path
                        logger.warning(f"[TTS] āŒ Failed to generate audio or file doesn't exist: {audio_path}")
                        return None
                    except Exception as e:
                        logger.error(f"[TTS] Error generating speech: {e}")
                        logger.debug(f"[TTS] Traceback: {traceback.format_exc()}")
                        return None

                def update_tts_button(history):
                    """Show "Play Response" only when the last message is from the assistant."""
                    if history and history[-1].get("role") == "assistant":
                        return gr.update(visible=True)
                    return gr.update(visible=False)

                chatbot.change(
                    fn=update_tts_button,
                    inputs=[chatbot],
                    outputs=[tts_button],
                )
                tts_button.click(
                    fn=generate_speech_from_chat,
                    inputs=[chatbot],
                    outputs=[tts_audio],
                )

                with gr.Accordion("āš™ļø Advanced Settings", open=False):
                    with gr.Row():
                        disable_agentic_reasoning = gr.Checkbox(
                            value=False,
                            label="Disable agentic reasoning",
                            info="Use MedSwin model alone without agentic reasoning, RAG, or web search",
                        )
                        show_agentic_thought = gr.Button(
                            "Show agentic thought",
                            size="sm",
                        )
                        enable_clinical_intake = gr.Checkbox(
                            value=True,
                            label="Enable clinical intake (max 5 Q&A)",
                            info="Ask focused follow-up questions before breaking down the case",
                        )
                    agentic_thoughts_box = gr.Textbox(
                        label="Agentic Thoughts",
                        placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
                        lines=8,
                        max_lines=15,
                        interactive=False,
                        visible=False,
                        elem_classes="agentic-thoughts",
                    )
                    with gr.Row():
                        use_rag = gr.Checkbox(
                            value=False,
                            label="Enable Document RAG",
                            info="Answer based on uploaded documents (upload required)",
                        )
                        use_web_search = gr.Checkbox(
                            value=False,
                            label="Enable Web Search (MCP)",
                            info="Fetch knowledge from online medical resources",
                        )
                    medical_model = gr.Radio(
                        choices=list(MEDSWIN_MODELS.keys()),
                        value=DEFAULT_MEDICAL_MODEL,
                        label="Medical Model",
                        info="MedSwin DT (default), others download on selection",
                    )
                    model_status = gr.Textbox(
                        value="Checking model status...",
                        label="Model Status",
                        interactive=False,
                        visible=True,
                        lines=3,
                        max_lines=3,
                        elem_classes="model-status",
                    )
                    system_prompt = gr.Textbox(
                        value=(
                            "As a medical specialist, provide detailed and accurate answers based on the "
                            "provided medical documents and context. Ensure all information is clinically "
                            "accurate and cite sources when available. Provide answers directly without "
                            "conversational prefixes like 'Here is...', 'This is...', or 'To answer your "
                            "question...'. Start with the actual content immediately."
                        ),
                        label="System Prompt",
                        lines=3,
                    )
                    with gr.Tab("Generation Parameters"):
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            step=0.1,
                            value=0.2,
                            label="Temperature",
                        )
                        max_new_tokens = gr.Slider(
                            minimum=512,
                            maximum=4096,
                            step=128,
                            value=2048,
                            label="Max New Tokens",
                            info="Increased for medical models to prevent early stopping",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            value=0.7,
                            label="Top P",
                        )
                        top_k = gr.Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top K",
                        )
                        penalty = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.2,
                            label="Repetition Penalty",
                        )
                    with gr.Tab("Retrieval Parameters"):
                        retriever_k = gr.Slider(
                            minimum=5,
                            maximum=30,
                            step=1,
                            value=15,
                            label="Initial Retrieval Size (Top K)",
                        )
                        merge_threshold = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            step=0.1,
                            value=0.5,
                            label="Merge Threshold (lower = more merging)",
                        )

                    # MedSwin Model Links
                    gr.Markdown(
                        """

šŸ”— MedSwin Models on Hugging Face

MedSwin DT MedSwin Nsl MedSwin DL MedSwin Ti MedSwin TA MedSwin SFT MedSwin KD

Click any model name to view details on Hugging Face

"""
                    )

        # Whether the agentic-thoughts box is currently shown (per session).
        show_thoughts_state = gr.State(value=False)

        def toggle_thoughts_box(current_state):
            """Flip the thoughts-box visibility; returns (visibility update, new state)."""
            new_state = not current_state
            return gr.update(visible=new_state), new_state

        show_agentic_thought.click(
            fn=toggle_thoughts_box,
            inputs=[show_thoughts_state],
            outputs=[agentic_thoughts_box, show_thoughts_state],
        )

        # Loader used when the user selects a model. The ZeroGPU decorator is
        # currently commented out, so this runs without a GPU allocation.
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_model_with_gpu(model_name):
            """Load ``model_name`` unless it is already loaded.

            Intended to be GPU-decorated for ZeroGPU; the decorator is disabled.

            Returns:
                tuple[str, bool]: a user-facing message and whether the model is loaded.
            """
            try:
                if is_model_loaded(model_name):
                    logger.info(f"Medical model {model_name} is already loaded")
                    return "āœ… The model has been loaded successfully", True
                logger.info(f"Loading medical model: {model_name}...")
                set_model_loading_state(model_name, "loading")
                try:
                    initialize_medical_model(model_name)
                except Exception as e:
                    logger.error(f"Failed to load medical model {model_name}: {e}")
                    set_model_loading_state(model_name, "error")
                    return f"āŒ Error loading model: {str(e)[:100]}", False
                # Same post-check the startup loaders perform: initialization may
                # return without the model being registered as loaded.
                if not is_model_loaded(model_name):
                    logger.warning(
                        f"Medical model {model_name} initialization completed but not marked as loaded"
                    )
                    set_model_loading_state(model_name, "error")
                    return "āŒ Error loading model: model not marked as loaded after initialization", False
                logger.info(f"āœ… Medical model {model_name} loaded successfully!")
                return "āœ… The model has been loaded successfully", True
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                return f"āŒ Error: {str(e)[:100]}", False

        def load_model_and_update_status(model_name):
            """Ensure ``model_name`` is loaded (unless a load is in progress) and describe status.

            A model whose previous load failed is retried, so picking it from the
            dropdown recovers from a failed startup load.

            Returns:
                tuple[str, bool]: newline-joined status text and whether the model is loaded.
            """
            try:
                if is_model_loaded(model_name):
                    medical_line = f"āœ… MedSwin ({model_name}): loaded and ready"
                elif get_model_loading_state(model_name) == "loading":
                    medical_line = f"ā³ MedSwin ({model_name}): loading..."
                else:
                    try:
                        _message, loaded_now = load_model_with_gpu(model_name)
                    except Exception as e:
                        logger.error(f"Error calling load_model_with_gpu: {e}")
                        loaded_now = False
                    if loaded_now:
                        medical_line = f"āœ… MedSwin ({model_name}): loaded and ready"
                    else:
                        # Previously reported as "loading..." even though the load had failed.
                        medical_line = f"āŒ MedSwin ({model_name}): error loading"
                status_lines = [medical_line, *_auxiliary_model_status_lines()]
                return "\n".join(status_lines), is_model_loaded(model_name)
            except Exception as e:
                logger.error(f"Error in load_model_and_update_status: {e}")
                return f"āŒ Error: {str(e)[:100]}", False

        def check_model_status(model_name):
            """Describe ``model_name`` and the optional TTS/ASR models without loading anything.

            Returns:
                tuple[str, bool]: newline-joined status text and whether the model is loaded.
            """
            if is_model_loaded(model_name):
                medical_line = f"āœ… MedSwin ({model_name}): loaded and ready"
            else:
                state = get_model_loading_state(model_name)
                if state == "loading":
                    medical_line = f"ā³ MedSwin ({model_name}): loading..."
                elif state == "error":
                    medical_line = f"āŒ MedSwin ({model_name}): error loading"
                else:
                    medical_line = f"āš ļø MedSwin ({model_name}): not loaded"
            status_lines = [medical_line, *_auxiliary_model_status_lines()]
            return "\n".join(status_lines), is_model_loaded(model_name)

        # Startup loading. Two variants exist: load_startup_and_update_ui (below)
        # uses the CPU one, which needs no GPU decorator; the GPU variant is an
        # alternative whose @spaces.GPU decorator is currently commented out.
        # NOTE(review): the original notes cite ZeroGPU guidance (load on CPU at
        # startup, move to GPU inside @spaces.GPU inference functions).
        def _append_default_model_status(status_messages, load_to_gpu):
            """Load DEFAULT_MEDICAL_MODEL unless already loaded; append its and ASR's status lines.

            Initialization failures are logged, recorded in ``status_messages``
            and mark the model state "error"; they are not re-raised.
            """
            device = "GPU" if load_to_gpu else "CPU"
            if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
                logger.info(f"[STARTUP] Loading medical model to {device}: {DEFAULT_MEDICAL_MODEL}...")
                set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
                try:
                    initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=load_to_gpu)
                    # Initialization can return without the model being registered.
                    if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                        status_messages.append(f"āœ… MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to {device}")
                        logger.info(
                            f"[STARTUP] āœ… Medical model {DEFAULT_MEDICAL_MODEL} loaded to {device} successfully!"
                        )
                    else:
                        status_messages.append(f"āš ļø MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                        logger.warning(
                            f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded"
                        )
                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                except Exception as e:
                    status_messages.append(f"āŒ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
                    logger.error(f"[STARTUP] Failed to load medical model: {e}")
                    logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
            else:
                status_messages.append(f"āœ… MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
                logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")

            # ASR (Whisper) is never loaded here; it loads on first use.
            if WHISPER_AVAILABLE:
                status_messages.append("ā³ ASR (Whisper): will load on first use")
            else:
                status_messages.append("āŒ ASR: library not available")

        def load_medical_model_on_startup_cpu():
            """Load the default model to CPU on startup (no GPU decorator).

            Returns:
                str: newline-joined status text, or an error message if setup failed.
            """
            status_messages = []
            try:
                _append_default_model_status(status_messages, load_to_gpu=False)
                status_text = "\n".join(status_messages)
                logger.info(f"[STARTUP] āœ… Model loading complete. Status:\n{status_text}")
                return status_text
            except Exception as e:
                error_msg = str(e)
                logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
                return f"āš ļø Error loading model: {error_msg[:100]}"

        # Alternative: load directly to GPU (would require the GPU decorator).
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_medical_model_on_startup_gpu():
            """Load the default model directly to GPU on startup (currently unused alternative).

            Clears the CUDA cache before and after loading. A ZeroGPU quota error
            yields a "will retry" status instead of raising.

            Returns:
                str: newline-joined status text, or an error message.
            """
            import torch

            status_messages = []
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("[STARTUP] Cleared GPU cache before model loading")

                _append_default_model_status(status_messages, load_to_gpu=True)

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.debug("[STARTUP] Cleared GPU cache after model loading")

                status_text = "\n".join(status_messages)
                logger.info(f"[STARTUP] āœ… Model loading complete. Status:\n{status_text}")
                return status_text
            except Exception as e:
                error_msg = str(e)
                if _is_quota_error(error_msg):
                    logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
                    # Return a status (not raise) so the caller's retry logic can react.
                    status_messages.append("āš ļø ZeroGPU quota error - will retry")
                    status_text = "\n".join(status_messages)
                    if WHISPER_AVAILABLE:
                        status_text += "\nā³ ASR (Whisper): will load on first use"
                    return status_text
                logger.error(f"[STARTUP] āŒ Error in model loading startup: {e}")
                logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return f"āš ļø Startup loading error: {str(e)[:100]}"

        def init_model_status():
            """Return the status text for the default model (no loading)."""
            try:
                status_text, _is_ready = check_model_status(DEFAULT_MEDICAL_MODEL)
                return status_text
            except Exception as e:
                logger.error(f"Error in init_model_status: {e}")
                return f"āš ļø Error: {str(e)[:100]}"

        def update_model_status_on_change(model_name):
            """Return the status text for ``model_name`` (no loading)."""
            try:
                status_text, _is_ready = check_model_status(model_name)
                return status_text
            except Exception as e:
                logger.error(f"Error in update_model_status_on_change: {e}")
                return f"āš ļø Error: {str(e)[:100]}"

        def on_model_change(model_name):
            """Load the newly selected model; enable chat input only if it is ready.

            Returns:
                tuple: (status text, submit button update, message box update).
            """
            try:
                status_text, is_ready = load_model_and_update_status(model_name)
                return (
                    status_text,
                    gr.update(interactive=is_ready),
                    gr.update(interactive=is_ready),
                )
            except Exception as e:
                logger.error(f"Error in on_model_change: {e}")
                error_msg = f"āš ļø Error: {str(e)[:100]}"
                return (
                    error_msg,
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )

        def refresh_model_status(model_name):
            """Return the current status text for ``model_name`` (no loading)."""
            return update_model_status_on_change(model_name)

        medical_model.change(
            fn=on_model_change,
            inputs=[medical_model],
            outputs=[model_status, submit_button, message_input],
        )

        # On-demand Whisper loader; its GPU decorator is currently commented out.
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_whisper_model_on_demand():
            """Load the Whisper ASR model if available and not yet loaded; return a status string."""
            try:
                if WHISPER_AVAILABLE and config.global_whisper_model is None:
                    logger.info("[ASR] Loading Whisper model on-demand...")
                    initialize_whisper_model()
                    if config.global_whisper_model is not None:
                        logger.info("[ASR] āœ… Whisper model loaded successfully!")
                        return "āœ… ASR (Whisper): loaded"
                    logger.warning("[ASR] āš ļø Whisper model failed to load")
                    return "āš ļø ASR (Whisper): failed to load"
                if config.global_whisper_model is not None:
                    return "āœ… ASR (Whisper): already loaded"
                return "āŒ ASR: library not available"
            except Exception as e:
                logger.error(f"[ASR] Error loading Whisper model: {e}")
                return f"āŒ ASR: error - {str(e)[:100]}"

        def load_startup_and_update_ui():
            """Load the default model on page load, with up to 3 attempts.

            Uses the CPU loader. Between attempts it sleeps ``5s * attempt``
            (5s, then 10s). When the failure looks like a ZeroGPU quota problem
            after the last attempt, chat input is still enabled so the model can
            load on demand.

            Returns:
                tuple: (status text, submit button update, message box update).
            """
            max_retries = 3
            base_delay = 5.0
            quota_exhausted_msg = (
                "āš ļø ZeroGPU quota exhausted.\n"
                "ā³ Model will load automatically when you send a message.\n"
                "šŸ’” You can also select a model from the dropdown."
            )

            for attempt in range(1, max_retries + 1):
                try:
                    logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
                    status_text = load_medical_model_on_startup_cpu()

                    if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                        logger.info(f"[STARTUP] āœ… Model loaded successfully on attempt {attempt}")
                        return status_text, gr.update(interactive=True), gr.update(interactive=True)

                    if _is_quota_error(status_text):
                        if attempt < max_retries:
                            delay = base_delay * attempt
                            logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
                            time.sleep(delay)
                            continue
                        logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
                        return quota_exhausted_msg, gr.update(interactive=True), gr.update(interactive=True)

                    # Model didn't load and no quota error was reported.
                    logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
                    if attempt < max_retries:
                        delay = base_delay * attempt  # linear backoff: 5s, 10s
                        logger.info(f"[STARTUP] Retrying in {delay} seconds...")
                        time.sleep(delay)
                        continue
                    # Still let the user try, e.g. by selecting a model from the dropdown.
                    return (
                        status_text + "\nāš ļø Model not loaded. Please select a model from dropdown.",
                        gr.update(interactive=True),
                        gr.update(interactive=True),
                    )
                except Exception as e:
                    is_quota_error = _is_quota_error(str(e))
                    if is_quota_error and attempt < max_retries:
                        delay = base_delay * attempt  # linear backoff: 5s, 10s
                        logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
                        logger.info(f"[STARTUP] Retrying in {delay} seconds...")
                        time.sleep(delay)
                        continue

                    logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
                    logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                    if is_quota_error:
                        logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
                        return quota_exhausted_msg, gr.update(interactive=True), gr.update(interactive=True)
                    error_display = f"āš ļø Startup error: {str(e)[:100]}"
                    if attempt >= max_retries:
                        logger.error(f"[STARTUP] Failed after {max_retries} attempts")
                    return error_display, gr.update(interactive=False), gr.update(interactive=False)

            # Every loop path returns or continues; kept as a defensive fallback.
            return (
                "āš ļø Startup failed after retries. Please select a model from dropdown.",
                gr.update(interactive=True),
                gr.update(interactive=True),
            )

        demo.load(
            fn=load_startup_and_update_ui,
            inputs=None,
            outputs=[model_status, submit_button, message_input],
        )

        # Note: there is no "preload on focus" handler. Reasons given originally:
        # 1. Model loading requires GPU access (device_map="auto" needs GPU in ZeroGPU)
        # 2. The startup handler above already loads the default model on page load
        # 3. Preloading without GPU decorator would fail or cause conflicts
        # 4. If startup fails, the user can select a model from the dropdown to load it
        # NOTE(review): startup now loads to CPU, so points 1 and 3 may be stale.

        def stream_chat_with_model_check(
            message,
            history,
            system_prompt,
            temperature,
            max_new_tokens,
            top_p,
            top_k,
            penalty,
            retriever_k,
            merge_threshold,
            use_rag,
            medical_model_name,
            use_web_search,
            enable_clinical_intake,
            disable_agentic_reasoning,
            show_thoughts,
            request: gr.Request = None,
        ):
            """Stream ``stream_chat`` output, guarding model readiness and mid-stream failures.

            The selected model must already be loaded; it is not loaded here, to
            keep ``stream_chat`` fast. Errors during streaming become a final chat
            update: the last good answer if one exists, otherwise an empty (GPU
            timeout) or error assistant message. Client disconnects end quietly.

            Yields:
                tuple: ``(chat_history, agentic_thoughts_text)``.
            """
            # Guard against a missing history when building fallback messages.
            base_history = history or []

            model_loaded = is_model_loaded(medical_model_name)
            if not model_loaded:
                loading_state = get_model_loading_state(medical_model_name)
                logger.debug(
                    f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}"
                )
                if loading_state == "loading":
                    error_msg = f"ā³ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
                else:
                    error_msg = f"āš ļø {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
                # Include the user's message, as the other error paths below do.
                updated_history = base_history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_msg},
                ]
                yield updated_history, ""
                return

            # stream_chat expects a request with a session_hash.
            if request is None:
                class MockRequest:
                    session_hash = "anonymous"
                request = MockRequest()

            last_result = None
            try:
                for result in stream_chat(
                    message, history, system_prompt, temperature, max_new_tokens,
                    top_p, top_k, penalty, retriever_k, merge_threshold,
                    use_rag, medical_model_name, use_web_search,
                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request,
                ):
                    last_result = result
                    try:
                        yield result
                    except (GeneratorExit, StopIteration, RuntimeError) as stream_error:
                        # Client closed the stream (e.g. "BodyStreamBuffer was aborted"):
                        # stop without yielding again.
                        if _looks_like_stream_abort(stream_error):
                            logger.info("[UI] Stream was aborted by client, stopping gracefully")
                            break
                        raise
            except (GeneratorExit, StopIteration):
                logger.info("[UI] Stream closed normally")
                return
            except Exception as e:
                error_str = str(e)
                error_msg_lower = error_str.lower()

                is_stream_abort = (
                    "bodystreambuffer" in error_msg_lower
                    or ("stream" in error_msg_lower and "abort" in error_msg_lower)
                    or ("connection" in error_msg_lower and "abort" in error_msg_lower)
                    or (isinstance(e, (StopIteration, RuntimeError)) and "abort" in error_msg_lower)
                )
                if is_stream_abort:
                    # Any result was already delivered; nothing more to send.
                    logger.info(f"[UI] Stream was aborted (BodyStreamBuffer or similar): {error_str[:100]}")
                    return

                is_gpu_timeout = "gpu task aborted" in error_msg_lower or "timeout" in error_msg_lower
                logger.error(f"Error in stream_chat_with_model_check: {error_str}")
                logger.debug(f"Full traceback: {traceback.format_exc()}")

                # A complete answer already streamed: show it without an error message.
                if _has_valid_answer(last_result):
                    logger.info("[UI] Error occurred but final answer already generated, displaying it without error message")
                    yield from _yield_guarded(last_result, "final result")
                    return

                # GPU timeout with partial output: show whatever we have.
                if is_gpu_timeout and last_result is not None:
                    logger.info("[UI] GPU timeout occurred, using last available result")
                    yield from _yield_guarded(last_result, "timeout result")
                    return

                if is_gpu_timeout:
                    # GPU timeout with no output: empty assistant message, not an error.
                    logger.info("[UI] GPU timeout with no result, showing empty assistant message")
                    updated_history = base_history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": ""},
                    ]
                    yield from _yield_guarded((updated_history, ""), "empty message")
                else:
                    error_display = f"āš ļø An error occurred: {error_str[:200]}"
                    updated_history = base_history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": error_display},
                    ]
                    yield from _yield_guarded((updated_history, ""), "error message")

        chat_inputs = [
            message_input,
            chatbot,
            system_prompt,
            temperature,
            max_new_tokens,
            top_p,
            top_k,
            penalty,
            retriever_k,
            merge_threshold,
            use_rag,
            medical_model,
            use_web_search,
            enable_clinical_intake,
            disable_agentic_reasoning,
            show_thoughts_state,
        ]
        chat_outputs = [chatbot, agentic_thoughts_box]

        submit_button.click(
            fn=stream_chat_with_model_check,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )
        message_input.submit(
            fn=stream_chat_with_model_check,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )

    return demo