"""Gradio UI setup"""
import os
import time
import gradio as gr
import spaces
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
import config
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import (
initialize_medical_model,
is_model_loaded,
get_model_loading_state,
set_model_loading_state,
initialize_tts_model,
initialize_whisper_model,
TTS_AVAILABLE,
WHISPER_AVAILABLE,
)
from logger import logger
MAX_DURATION = 120
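# Time budget (seconds) intended for @spaces.GPU(max_duration=...) calls;
# currently referenced only by the commented-out decorators below.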
def create_demo():
"""Create and return Gradio demo interface"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Row(elem_classes="main-container"):
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag and Drop Files Here",
file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with MedSwin... Type your question below.",
show_label=False,
type="messages"
)
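# With type="messages", the history is a list of {"role": ..., "content": ...}
# dicts; the TTS and status helpers below rely on that shape.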
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your medical question here...",
show_label=False,
container=False,
lines=1,
scale=10
)
mic_button = gr.Audio(
sources=["microphone"],
type="filepath",
label="",
show_label=False,
container=False,
scale=1
)
submit_button = gr.Button("❤", elem_classes="submit-btn", scale=1)
recording_timer = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
container=False,
elem_classes="recording-timer"
)
recording_start_time = [None]
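# A single-element list acts as a mutable closure cell so the nested handlers
# below can update the shared start time without `nonlocal`. Note that the
# timer label is set once on start and is not ticked while recording.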
def handle_recording_start():
"""Called when recording starts"""
recording_start_time[0] = time.time()
return gr.update(visible=True, value="Recording... 0s")
def handle_recording_stop(audio):
"""Called when recording stops"""
recording_start_time[0] = None
if audio is None:
return gr.update(visible=False, value=""), ""
transcribed = transcribe_audio(audio)
return gr.update(visible=False, value=""), transcribed
mic_button.start_recording(
fn=handle_recording_start,
outputs=[recording_timer]
)
mic_button.stop_recording(
fn=handle_recording_stop,
inputs=[mic_button],
outputs=[recording_timer, message_input]
)
with gr.Row():
tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
tts_audio = gr.Audio(label="", visible=True, autoplay=True, show_label=False, container=False)
def generate_speech_from_chat(history):
"""Extract last assistant message and generate speech"""
if not history:
logger.warning("[TTS] No history available")
return None
last_msg = history[-1]
if last_msg.get("role") == "assistant":
text = last_msg.get("content", "").replace(" 🔊", "").strip()
if text:
logger.info(f"[TTS] Generating speech for text: {text[:100]}...")
try:
audio_path = generate_speech(text)
if audio_path and os.path.exists(audio_path):
logger.info(f"[TTS] ā
Generated audio successfully: {audio_path}")
return audio_path
else:
logger.warning(f"[TTS] ā Failed to generate audio or file doesn't exist: {audio_path}")
return None
except Exception as e:
logger.error(f"[TTS] Error generating speech: {e}")
import traceback
logger.debug(f"[TTS] Traceback: {traceback.format_exc()}")
return None
else:
logger.warning("[TTS] Empty text extracted from assistant message")
else:
logger.warning(f"[TTS] Last message is not from assistant: {last_msg.get('role')}")
return None
def update_tts_button(history):
if history and history[-1].get("role") == "assistant":
return gr.update(visible=True)
return gr.update(visible=False)
chatbot.change(
fn=update_tts_button,
inputs=[chatbot],
outputs=[tts_button]
)
tts_button.click(
fn=generate_speech_from_chat,
inputs=[chatbot],
outputs=[tts_audio]
)
with gr.Accordion("āļø Advanced Settings", open=False):
with gr.Row():
disable_agentic_reasoning = gr.Checkbox(
value=False,
label="Disable agentic reasoning",
info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
)
show_agentic_thought = gr.Button(
"Show agentic thought",
size="sm"
)
enable_clinical_intake = gr.Checkbox(
value=True,
label="Enable clinical intake (max 5 Q&A)",
info="Ask focused follow-up questions before breaking down the case"
)
agentic_thoughts_box = gr.Textbox(
label="Agentic Thoughts",
placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
lines=8,
max_lines=15,
interactive=False,
visible=False,
elem_classes="agentic-thoughts"
)
with gr.Row():
use_rag = gr.Checkbox(
value=False,
label="Enable Document RAG",
info="Answer based on uploaded documents (upload required)"
)
use_web_search = gr.Checkbox(
value=False,
label="Enable Web Search (MCP)",
info="Fetch knowledge from online medical resources"
)
medical_model = gr.Radio(
choices=list(MEDSWIN_MODELS.keys()),
value=DEFAULT_MEDICAL_MODEL,
label="Medical Model",
info="MedSwin DT (default), others download on selection"
)
model_status = gr.Textbox(
value="Checking model status...",
label="Model Status",
interactive=False,
visible=True,
lines=3,
max_lines=3,
elem_classes="model-status"
)
system_prompt = gr.Textbox(
value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=512,
maximum=4096,
step=128,
value=2048,
label="Max New Tokens",
info="Increased for medical models to prevent early stopping"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
# MedSwin Model Links
gr.Markdown(
"""
📚 MedSwin Models on Hugging Face
Click any model name to view details on Hugging Face
"""
)
show_thoughts_state = gr.State(value=False)
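# Per-session flag for whether the agentic-thoughts box is visible
# (gr.State keeps the value isolated per browser session).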
def toggle_thoughts_box(current_state):
"""Toggle visibility of agentic thoughts box"""
new_state = not current_state
return gr.update(visible=new_state), new_state
show_agentic_thought.click(
fn=toggle_thoughts_box,
inputs=[show_thoughts_state],
outputs=[agentic_thoughts_box, show_thoughts_state]
)
# GPU-decorated function to load any model (for user selection)
# @spaces.GPU(max_duration=MAX_DURATION)
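# (Decorator left disabled on purpose: per the startup notes further below,
# loading without @spaces.GPU avoids consuming ZeroGPU quota at load time.)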
def load_model_with_gpu(model_name):
"""Load medical model (GPU-decorated for ZeroGPU compatibility)"""
try:
if not is_model_loaded(model_name):
logger.info(f"Loading medical model: {model_name}...")
set_model_loading_state(model_name, "loading")
try:
initialize_medical_model(model_name)
logger.info(f"ā
Medical model {model_name} loaded successfully!")
return "ā
The model has been loaded successfully", True
except Exception as e:
logger.error(f"Failed to load medical model {model_name}: {e}")
set_model_loading_state(model_name, "error")
return f"ā Error loading model: {str(e)[:100]}", False
else:
logger.info(f"Medical model {model_name} is already loaded")
return "ā
The model has been loaded successfully", True
except Exception as e:
logger.error(f"Error loading model {model_name}: {e}")
return f"ā Error: {str(e)[:100]}", False
def load_model_and_update_status(model_name):
"""Load model and update status, return status text and whether model is ready"""
try:
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"ā
MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"ā MedSwin ({model_name}): error loading")
else:
# Use GPU-decorated function to load the model
try:
result = load_model_with_gpu(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
if is_ready:
status_lines.append(f"ā
MedSwin ({model_name}): loaded and ready")
else:
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
else:
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
except Exception as e:
logger.error(f"Error calling load_model_with_gpu: {e}")
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
# TTS model status (only show if available or if there's an issue)
if TTS_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("ā
TTS (maya1): loaded and ready")
else:
# TTS available but not loaded - optional feature
pass # Don't show if not loaded, it's optional
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("ā
ASR (Whisper): loaded and ready")
else:
status_lines.append("ā³ ASR (Whisper): will load on first use")
else:
status_lines.append("ā ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
except Exception as e:
return f"ā Error: {str(e)[:100]}", False
def check_model_status(model_name):
"""Check current model status without loading"""
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"ā
MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"ā MedSwin ({model_name}): error loading")
else:
status_lines.append(f"ā ļø MedSwin ({model_name}): not loaded")
# TTS model status (only show if available and loaded)
if TTS_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("ā
TTS (maya1): loaded and ready")
# Don't show if TTS library available but model not loaded (optional feature)
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("ā
ASR (Whisper): loaded and ready")
else:
status_lines.append("ā³ ASR (Whisper): will load on first use")
else:
status_lines.append("ā ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
# GPU-decorated function to load ONLY medical model on startup
# According to ZeroGPU best practices:
# 1. Load models to CPU in global scope (no GPU decorator needed)
# 2. Move models to GPU only in inference functions (with @spaces.GPU decorator)
# However, for large models, loading to CPU and then moving to GPU uses more memory,
# so a GPU-direct variant is kept below as an alternative; the startup path
# (load_startup_and_update_ui) currently uses the CPU-first loader.
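# A minimal sketch of the CPU-load / GPU-move split described above
# (illustrative only; assumes a Hugging Face-style model object with .to()):
#
#     model = AutoModelForCausalLM.from_pretrained(name)  # CPU, at startup
#
#     @spaces.GPU(max_duration=MAX_DURATION)
#     def infer(prompt):
#         model.to("cuda")  # the GPU is claimed only inside the decorated call
#         ...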
def load_medical_model_on_startup_cpu():
"""
Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
Model will be moved to GPU during first inference
"""
status_messages = []
try:
# Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load to CPU (no GPU decorator needed)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
logger.info(f"[STARTUP] ā
Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
else:
status_messages.append(f"ā ļø MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"ā MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("ā³ ASR (Whisper): will load on first use")
else:
status_messages.append("ā ASR: library not available")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ā
Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
return f"ā ļø Error loading model: {error_msg[:100]}"
# Alternative: Load directly to GPU (requires GPU decorator)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_medical_model_on_startup_gpu():
"""
Load model directly to GPU on startup (alternative approach)
Uses GPU quota but model is immediately ready for inference
"""
import torch
status_messages = []
try:
# Clear GPU cache at start
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("[STARTUP] Cleared GPU cache before model loading")
# Load only medical model (MedSwin) - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load directly to GPU (within GPU-decorated function)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
logger.info(f"[STARTUP] ā
Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
else:
status_messages.append(f"ā ļø MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"ā MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("ā³ ASR (Whisper): will load on first use")
else:
status_messages.append("ā ASR: library not available")
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("[STARTUP] Cleared GPU cache after model loading")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ā
Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
# Check if it's a ZeroGPU quota/rate limit error
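# (The misspelled "runnning out" below is checked deliberately; it presumably
# mirrors a typo observed in some upstream ZeroGPU error strings.)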
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error:
logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
# Return status message indicating quota error (will be handled by retry logic)
status_messages.append("ā ļø ZeroGPU quota error - will retry")
status_text = "\n".join(status_messages)
# Also add ASR status
if WHISPER_AVAILABLE:
status_text += "\nā³ ASR (Whisper): will load on first use"
return status_text # Return status instead of raising, let wrapper handle retry
logger.error(f"[STARTUP] ā Error in model loading startup: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
return f"ā ļø Startup loading error: {str(e)[:100]}"
# Initialize status on load
def init_model_status():
try:
result = check_model_status(DEFAULT_MEDICAL_MODEL)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "ā ļø Unable to check model status"
except Exception as e:
logger.error(f"Error in init_model_status: {e}")
return f"ā ļø Error: {str(e)[:100]}"
# Update status when model selection changes
def update_model_status_on_change(model_name):
try:
result = check_model_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "ā ļø Unable to check model status"
except Exception as e:
logger.error(f"Error in update_model_status_on_change: {e}")
return f"ā ļø Error: {str(e)[:100]}"
# Handle model selection change
def on_model_change(model_name):
try:
result = load_model_and_update_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
submit_enabled = is_ready
return (
status_text,
gr.update(interactive=submit_enabled),
gr.update(interactive=submit_enabled)
)
else:
error_msg = "ā ļø Unable to load model status"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
except Exception as e:
logger.error(f"Error in on_model_change: {e}")
error_msg = f"ā ļø Error: {str(e)[:100]}"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
# Update status display periodically or on model status changes
def refresh_model_status(model_name):
return update_model_status_on_change(model_name)
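# Note: init_model_status and refresh_model_status are not currently wired to
# any event in this module.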
medical_model.change(
fn=on_model_change,
inputs=[medical_model],
outputs=[model_status, submit_button, message_input]
)
# GPU-decorated function to load Whisper ASR model on-demand
# @spaces.GPU(max_duration=MAX_DURATION)
def load_whisper_model_on_demand():
"""Load Whisper ASR model when needed"""
try:
if WHISPER_AVAILABLE and config.global_whisper_model is None:
logger.info("[ASR] Loading Whisper model on-demand...")
initialize_whisper_model()
if config.global_whisper_model is not None:
logger.info("[ASR] ā
Whisper model loaded successfully!")
return "ā
ASR (Whisper): loaded"
else:
logger.warning("[ASR] ā ļø Whisper model failed to load")
return "ā ļø ASR (Whisper): failed to load"
elif config.global_whisper_model is not None:
return "ā
ASR (Whisper): already loaded"
else:
return "ā ASR: library not available"
except Exception as e:
logger.error(f"[ASR] Error loading Whisper model: {e}")
return f"ā ASR: error - {str(e)[:100]}"
# Load medical model on startup and update status
# Use a wrapper to handle GPU context properly with retry logic
def load_startup_and_update_ui():
"""
Load model on startup with retry logic (max 3 attempts) and return status with UI updates
Uses CPU-first approach (ZeroGPU best practice):
- Load model to CPU (no GPU decorator needed, avoids quota issues)
- Model will be moved to GPU during first inference
"""
import time
max_retries = 3
base_delay = 5.0 # Start with 5 seconds delay
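# Retry delays grow linearly: base_delay * attempt = 5s, 10s, 15s.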
for attempt in range(1, max_retries + 1):
try:
logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
# Use CPU-first approach (no GPU decorator, avoids quota issues)
status_text = load_medical_model_on_startup_cpu()
# Check if model is ready and update submit button state
is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
if is_ready:
logger.info(f"[STARTUP] ā
Model loaded successfully on attempt {attempt}")
return status_text, gr.update(interactive=is_ready), gr.update(interactive=is_ready)
else:
# Check if status text indicates quota error
if status_text and ("quota" in status_text.lower() or "ZeroGPU" in status_text or
"429" in status_text or "runnning out" in status_text.lower() or
"running out" in status_text.lower()):
if attempt < max_retries:
delay = base_delay * attempt
logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Quota exhausted after retries - allow user to proceed, model will load on-demand
status_msg = "ā ļø ZeroGPU quota exhausted.\nā³ Model will load automatically when you send a message.\nš” You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
return status_msg, gr.update(interactive=True), gr.update(interactive=True)
# Model didn't load, but no exception - might be a state issue
logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
if attempt < max_retries:
delay = base_delay * attempt # Linear backoff: 5s, 10s, 15s
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Even if model didn't load, allow user to try selecting another model
return status_text + "\n⚠️ Model not loaded. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
except Exception as e:
error_msg = str(e)
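# Same quota/rate-limit detection as in load_medical_model_on_startup_gpu,
# including the misspelled "runnning out" variant.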
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error and attempt < max_retries:
delay = base_delay * attempt # Linear backoff: 5s, 10s, 15s
logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
if is_quota_error:
# If quota exhausted, allow user to proceed - model will load on-demand
error_display = "ā ļø ZeroGPU quota exhausted.\nā³ Model will load automatically when you send a message.\nš” You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
return error_display, gr.update(interactive=True), gr.update(interactive=True)
else:
error_display = f"ā ļø Startup error: {str(e)[:100]}"
if attempt >= max_retries:
logger.error(f"[STARTUP] Failed after {max_retries} attempts")
return error_display, gr.update(interactive=False), gr.update(interactive=False)
# Should not reach here, but just in case
return "ā ļø Startup failed after retries. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
demo.load(
fn=load_startup_and_update_ui,
inputs=None,
outputs=[model_status, submit_button, message_input]
)
# Note: We removed the preload on focus functionality because:
# 1. Model loading requires GPU access (device_map="auto" needs GPU in ZeroGPU)
# 2. The startup function already loads the model with GPU decorator
# 3. Preloading without GPU decorator would fail or cause conflicts
# 4. If startup fails, user can select a model from dropdown to trigger loading
# Wrap stream_chat - ensure model is loaded before starting (don't load inside stream_chat to save time)
def stream_chat_with_model_check(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
):
# Check if model is loaded - if not, show error (don't load here to save stream_chat time)
model_loaded = is_model_loaded(medical_model_name)
if not model_loaded:
loading_state = get_model_loading_state(medical_model_name)
# Debug logging to understand why model check fails
logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
if loading_state == "loading":
error_msg = f"ā³ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
else:
error_msg = f"ā ļø {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
return
# If request is None, create a mock request for compatibility
if request is None:
class MockRequest:
session_hash = "anonymous"
request = MockRequest()
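# The mock exposes only session_hash, presumably the sole request attribute
# the downstream pipeline reads for per-session state.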
# Model is loaded, proceed with stream_chat (no model loading here to save time)
# Note: We handle "BodyStreamBuffer was aborted" errors by catching stream disconnections
# and not attempting to yield after the client has disconnected
last_result = None
stream_aborted = False
try:
for result in stream_chat(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
):
last_result = result
try:
yield result
except (GeneratorExit, StopIteration, RuntimeError) as stream_error:
# Stream was closed/aborted by client - don't try to yield again
error_msg_lower = str(stream_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream was aborted by client, stopping gracefully")
stream_aborted = True
break
raise
except (GeneratorExit, StopIteration) as stream_exit:
# Stream was closed - this is normal, just log and exit
logger.info(f"[UI] Stream closed normally")
stream_aborted = True
return
except Exception as e:
# Handle any errors gracefully
error_str = str(e)
error_msg_lower = error_str.lower()
# Check if this is a stream abort error
is_stream_abort = (
'bodystreambuffer' in error_msg_lower or
('stream' in error_msg_lower and 'abort' in error_msg_lower) or
('connection' in error_msg_lower and 'abort' in error_msg_lower) or
(isinstance(e, (GeneratorExit, StopIteration, RuntimeError)) and 'abort' in error_msg_lower)
)
if is_stream_abort:
logger.info(f"[UI] Stream was aborted (BodyStreamBuffer or similar): {error_str[:100]}")
stream_aborted = True
# If we have a result, it was already yielded, so just return
return
is_gpu_timeout = 'gpu task aborted' in error_msg_lower or 'timeout' in error_msg_lower
logger.error(f"Error in stream_chat_with_model_check: {error_str}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
# Check if we have a valid answer in the last result
has_valid_answer = False
if last_result is not None:
try:
last_history, last_thoughts = last_result
# Find the last assistant message in the history
if last_history and isinstance(last_history, list):
for msg in reversed(last_history):
if isinstance(msg, dict) and msg.get("role") == "assistant":
assistant_content = msg.get("content", "")
# Check if it's a valid answer (not empty, not an error message)
if assistant_content and len(assistant_content.strip()) > 0:
# Not an error message
if not assistant_content.strip().startswith("⚠️") and not assistant_content.strip().startswith("⏳"):
has_valid_answer = True
break
except Exception as parse_error:
logger.debug(f"Error parsing last_result: {parse_error}")
# If stream was aborted, don't try to yield - just return
if stream_aborted:
logger.info(f"[UI] Stream was aborted, not yielding final result")
return
# If we have a valid answer, use it (don't show error message)
if has_valid_answer:
logger.info(f"[UI] Error occurred but final answer already generated, displaying it without error message")
try:
yield last_result
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding final result, ignoring")
else:
raise
return
# For GPU timeouts, try to use last result even if it's partial
if is_gpu_timeout and last_result is not None:
logger.info(f"[UI] GPU timeout occurred, using last available result")
try:
yield last_result
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding timeout result, ignoring")
else:
raise
return
# Only show error for non-timeout errors when we have no valid answer
# For GPU timeouts with no result, show empty message (not error)
if is_gpu_timeout:
logger.info(f"[UI] GPU timeout with no result, showing empty assistant message")
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": ""}]
try:
yield updated_history, ""
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding empty message, ignoring")
else:
raise
else:
# For other errors, show minimal error message only if no result
error_display = f"ā ļø An error occurred: {error_str[:200]}"
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_display}]
try:
yield updated_history, ""
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding error message, ignoring")
else:
raise
submit_button.click(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
message_input.submit(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
return demo