# MedLLM-Agent / ui.py
"""Gradio UI setup"""
import time
import gradio as gr
import spaces
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
import config
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import (
initialize_medical_model,
is_model_loaded,
get_model_loading_state,
set_model_loading_state,
initialize_tts_model,
initialize_whisper_model,
TTS_AVAILABLE,
WHISPER_AVAILABLE,
)
from logger import logger
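# Time budget in seconds used by the commented-out @spaces.GPU decorators below.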
MAX_DURATION = 120
def create_demo():
"""Create and return Gradio demo interface"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Row(elem_classes="main-container"):
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag and Drop Files Here",
file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with MedSwin... Type your question below.",
show_label=False,
type="messages"
)
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your medical question here...",
show_label=False,
container=False,
lines=1,
scale=10
)
mic_button = gr.Audio(
sources=["microphone"],
type="filepath",
label="",
show_label=False,
container=False,
scale=1
)
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
recording_timer = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
container=False,
elem_classes="recording-timer"
)
recording_start_time = [None]
def handle_recording_start():
"""Called when recording starts"""
recording_start_time[0] = time.time()
return gr.update(visible=True, value="Recording... 0s")
def handle_recording_stop(audio):
"""Called when recording stops"""
recording_start_time[0] = None
if audio is None:
return gr.update(visible=False, value=""), ""
transcribed = transcribe_audio(audio)
return gr.update(visible=False, value=""), transcribed
mic_button.start_recording(
fn=handle_recording_start,
outputs=[recording_timer]
)
mic_button.stop_recording(
fn=handle_recording_stop,
inputs=[mic_button],
outputs=[recording_timer, message_input]
)
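# The timer label is shown as a static "Recording... 0s"; recording_start_time is stored
# but never read back. A live elapsed-time readout could be driven by gr.Timer (present in
# recent Gradio releases) - a hypothetical sketch, not wired in (it would also need the
# start/stop handlers to toggle the timer's `active` flag):
#   elapsed_ticker = gr.Timer(1, active=False)
#   elapsed_ticker.tick(
#       fn=lambda: gr.update(
#           value=f"Recording... {int(time.time() - recording_start_time[0])}s"
#       ) if recording_start_time[0] else gr.update(),
#       outputs=[recording_timer],
#   )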
with gr.Row(visible=False) as tts_row:
tts_text = gr.Textbox(visible=False)
tts_audio = gr.Audio(label="Generated Speech", visible=False)
def generate_speech_from_chat(history):
"""Extract last assistant message and generate speech"""
if not history or len(history) == 0:
return None
last_msg = history[-1]
if last_msg.get("role") == "assistant":
text = last_msg.get("content", "").replace(" 🔊", "").strip()
if text:
audio_path = generate_speech(text)
return audio_path
return None
tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
def update_tts_button(history):
if history and len(history) > 0 and history[-1].get("role") == "assistant":
return gr.update(visible=True)
return gr.update(visible=False)
chatbot.change(
fn=update_tts_button,
inputs=[chatbot],
outputs=[tts_button]
)
tts_button.click(
fn=generate_speech_from_chat,
inputs=[chatbot],
outputs=[tts_audio]
)
with gr.Accordion("⚙️ Advanced Settings", open=False):
with gr.Row():
disable_agentic_reasoning = gr.Checkbox(
value=False,
label="Disable agentic reasoning",
info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
)
show_agentic_thought = gr.Button(
"Show agentic thought",
size="sm"
)
enable_clinical_intake = gr.Checkbox(
value=True,
label="Enable clinical intake (max 5 Q&A)",
info="Ask focused follow-up questions before breaking down the case"
)
agentic_thoughts_box = gr.Textbox(
label="Agentic Thoughts",
placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
lines=8,
max_lines=15,
interactive=False,
visible=False,
elem_classes="agentic-thoughts"
)
with gr.Row():
use_rag = gr.Checkbox(
value=False,
label="Enable Document RAG",
info="Answer based on uploaded documents (upload required)"
)
use_web_search = gr.Checkbox(
value=False,
label="Enable Web Search (MCP)",
info="Fetch knowledge from online medical resources"
)
medical_model = gr.Radio(
choices=list(MEDSWIN_MODELS.keys()),
value=DEFAULT_MEDICAL_MODEL,
label="Medical Model",
info="MedSwin DT (default), others download on selection"
)
model_status = gr.Textbox(
value="Checking model status...",
label="Model Status",
interactive=False,
visible=True,
lines=3,
max_lines=3,
elem_classes="model-status"
)
system_prompt = gr.Textbox(
value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=512,
maximum=4096,
step=128,
value=2048,
label="Max New Tokens",
info="Increased for medical models to prevent early stopping"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
# MedSwin Model Links
gr.Markdown(
"""
<div style="margin-top: 20px; padding: 15px; background-color: #f5f5f5; border-radius: 8px;">
<h4 style="margin-top: 0; margin-bottom: 10px;">🔗 MedSwin Models on Hugging Face</h4>
<div style="display: flex; flex-wrap: wrap; gap: 10px;">
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DT</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-NuSLERP-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Nsl</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-Linear-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DL</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Ti</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-TA-SFT-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin TA</a>
<a href="https://huggingface.co/MedSwin/MedSwin-7B-SFT" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin SFT</a>
<a href="https://huggingface.co/MedSwin/MedSwin-7B-KD" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin KD</a>
</div>
<p style="margin-top: 10px; margin-bottom: 0; font-size: 11px; color: #666;">Click any model name to view details on Hugging Face</p>
</div>
"""
)
show_thoughts_state = gr.State(value=False)
def toggle_thoughts_box(current_state):
"""Toggle visibility of agentic thoughts box"""
new_state = not current_state
return gr.update(visible=new_state), new_state
show_agentic_thought.click(
fn=toggle_thoughts_box,
inputs=[show_thoughts_state],
outputs=[agentic_thoughts_box, show_thoughts_state]
)
# Loader for any model chosen by the user (the @spaces.GPU decorator below is currently disabled)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_model_with_gpu(model_name):
"""Load medical model (GPU-decorated for ZeroGPU compatibility)"""
try:
if not is_model_loaded(model_name):
logger.info(f"Loading medical model: {model_name}...")
set_model_loading_state(model_name, "loading")
try:
initialize_medical_model(model_name)
logger.info(f"✅ Medical model {model_name} loaded successfully!")
return "✅ The model has been loaded successfully", True
except Exception as e:
logger.error(f"Failed to load medical model {model_name}: {e}")
set_model_loading_state(model_name, "error")
return f"❌ Error loading model: {str(e)[:100]}", False
else:
logger.info(f"Medical model {model_name} is already loaded")
return "✅ The model has been loaded successfully", True
except Exception as e:
logger.error(f"Error loading model {model_name}: {e}")
return f"❌ Error: {str(e)[:100]}", False
def load_model_and_update_status(model_name):
"""Load model and update status, return status text and whether model is ready"""
try:
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"❌ MedSwin ({model_name}): error loading")
else:
# Use GPU-decorated function to load the model
try:
result = load_model_with_gpu(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
if is_ready:
status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
else:
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
else:
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
except Exception as e:
logger.error(f"Error calling load_model_with_gpu: {e}")
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
# TTS model status (only show if available or if there's an issue)
if TTS_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("✅ TTS (maya1): loaded and ready")
else:
# TTS available but not loaded - optional feature
pass # Don't show if not loaded, it's optional
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("✅ ASR (Whisper): loaded and ready")
else:
status_lines.append("⏳ ASR (Whisper): will load on first use")
else:
status_lines.append("❌ ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
except Exception as e:
return f"❌ Error: {str(e)[:100]}", False
def check_model_status(model_name):
"""Check current model status without loading"""
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"❌ MedSwin ({model_name}): error loading")
else:
status_lines.append(f"⚠️ MedSwin ({model_name}): not loaded")
# TTS model status (only show if available and loaded)
if TTS_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("✅ TTS (maya1): loaded and ready")
# Don't show if TTS library available but model not loaded (optional feature)
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("✅ ASR (Whisper): loaded and ready")
else:
status_lines.append("⏳ ASR (Whisper): will load on first use")
else:
status_lines.append("❌ ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
# Startup loaders for the medical model only (TTS and Whisper load on demand).
# ZeroGPU best practice:
# 1. Load models to CPU in global scope (no GPU decorator needed)
# 2. Move models to GPU only in inference functions (with @spaces.GPU decorator)
# For large models, loading to CPU and then moving to GPU costs extra host memory,
# so two loaders are kept below: a CPU-first loader (used at startup) and a
# direct-to-GPU loader kept as an alternative behind a commented-out @spaces.GPU decorator.
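# A minimal sketch of the two patterns, assuming a generic transformers-style loader
# (the real loading logic lives in models.initialize_medical_model; model_id is hypothetical):
#
#   # (a) CPU-first: load in global scope, move to GPU inside a decorated inference call
#   model = AutoModelForCausalLM.from_pretrained(model_id)       # stays on CPU at startup
#   @spaces.GPU(max_duration=MAX_DURATION)
#   def generate(prompt):
#       model.to("cuda")                                          # GPU held only for the call
#       ...
#
#   # (b) direct-to-GPU: perform the whole load inside a GPU-decorated function
#   @spaces.GPU(max_duration=MAX_DURATION)
#   def load_selected_model():
#       initialize_medical_model(model_id, load_to_gpu=True)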
def load_medical_model_on_startup_cpu():
"""
Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
Model will be moved to GPU during first inference
"""
status_messages = []
try:
# Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load to CPU (no GPU decorator needed)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
else:
status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("⏳ ASR (Whisper): will load on first use")
else:
status_messages.append("❌ ASR: library not available")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
return f"⚠️ Error loading model: {error_msg[:100]}"
# Alternative: load directly to GPU (would require re-enabling the @spaces.GPU decorator below)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_medical_model_on_startup_gpu():
"""
Load model directly to GPU on startup (alternative approach)
Uses GPU quota but model is immediately ready for inference
"""
import torch
status_messages = []
try:
# Clear GPU cache at start
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("[STARTUP] Cleared GPU cache before model loading")
# Load only medical model (MedSwin) - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load directly to GPU (within GPU-decorated function)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
else:
status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("⏳ ASR (Whisper): will load on first use")
else:
status_messages.append("❌ ASR: library not available")
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("[STARTUP] Cleared GPU cache after model loading")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
# Check if it's a ZeroGPU quota/rate limit error
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error:
logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
# Return status message indicating quota error (will be handled by retry logic)
status_messages.append("⚠️ ZeroGPU quota error - will retry")
status_text = "\n".join(status_messages)
# Also add ASR status
if WHISPER_AVAILABLE:
status_text += "\n⏳ ASR (Whisper): will load on first use"
return status_text # Return status instead of raising, let wrapper handle retry
logger.error(f"[STARTUP] ❌ Error in model loading startup: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
return f"⚠️ Startup loading error: {str(e)[:100]}"
# Initialize status on load
def init_model_status():
try:
result = check_model_status(DEFAULT_MEDICAL_MODEL)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "⚠️ Unable to check model status"
except Exception as e:
logger.error(f"Error in init_model_status: {e}")
return f"⚠️ Error: {str(e)[:100]}"
# Update status when model selection changes
def update_model_status_on_change(model_name):
try:
result = check_model_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "⚠️ Unable to check model status"
except Exception as e:
logger.error(f"Error in update_model_status_on_change: {e}")
return f"⚠️ Error: {str(e)[:100]}"
# Handle model selection change
def on_model_change(model_name):
try:
result = load_model_and_update_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
submit_enabled = is_ready
return (
status_text,
gr.update(interactive=submit_enabled),
gr.update(interactive=submit_enabled)
)
else:
error_msg = "⚠️ Unable to load model status"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
except Exception as e:
logger.error(f"Error in on_model_change: {e}")
error_msg = f"⚠️ Error: {str(e)[:100]}"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
# Update status display periodically or on model status changes
def refresh_model_status(model_name):
return update_model_status_on_change(model_name)
medical_model.change(
fn=on_model_change,
inputs=[medical_model],
outputs=[model_status, submit_button, message_input]
)
# Loader for the Whisper ASR model on demand (the @spaces.GPU decorator below is currently disabled)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_whisper_model_on_demand():
"""Load Whisper ASR model when needed"""
try:
if WHISPER_AVAILABLE and config.global_whisper_model is None:
logger.info("[ASR] Loading Whisper model on-demand...")
initialize_whisper_model()
if config.global_whisper_model is not None:
logger.info("[ASR] ✅ Whisper model loaded successfully!")
return "✅ ASR (Whisper): loaded"
else:
logger.warning("[ASR] ⚠️ Whisper model failed to load")
return "⚠️ ASR (Whisper): failed to load"
elif config.global_whisper_model is not None:
return "✅ ASR (Whisper): already loaded"
else:
return "❌ ASR: library not available"
except Exception as e:
logger.error(f"[ASR] Error loading Whisper model: {e}")
return f"❌ ASR: error - {str(e)[:100]}"
# Load medical model on startup and update status
# Use a wrapper with retry logic (CPU-first loading, so no GPU context is needed at startup)
def load_startup_and_update_ui():
"""
Load model on startup with retry logic (max 3 attempts) and return status with UI updates
Uses CPU-first approach (ZeroGPU best practice):
- Load model to CPU (no GPU decorator needed, avoids quota issues)
- Model will be moved to GPU during first inference
"""
import time
max_retries = 3
base_delay = 5.0 # Start with 5 seconds delay
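# With base_delay = 5.0 and max_retries = 3, the loop sleeps 5s after a failed first
# attempt and 10s after a failed second attempt, then stops retrying.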
for attempt in range(1, max_retries + 1):
try:
logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
# Use CPU-first approach (no GPU decorator, avoids quota issues)
status_text = load_medical_model_on_startup_cpu()
# Check if model is ready and update submit button state
is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
if is_ready:
logger.info(f"[STARTUP] ✅ Model loaded successfully on attempt {attempt}")
return status_text, gr.update(interactive=is_ready), gr.update(interactive=is_ready)
else:
# Check if status text indicates quota error
if status_text and ("quota" in status_text.lower() or "ZeroGPU" in status_text or
"429" in status_text or "runnning out" in status_text.lower() or
"running out" in status_text.lower()):
if attempt < max_retries:
delay = base_delay * attempt
logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Quota exhausted after retries - allow user to proceed, model will load on-demand
status_msg = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
return status_msg, gr.update(interactive=True), gr.update(interactive=True)
# Model didn't load, but no exception - might be a state issue
logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
if attempt < max_retries:
delay = base_delay * attempt  # Linear backoff: 5s after the 1st attempt, 10s after the 2nd
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Even if model didn't load, allow user to try selecting another model
return status_text + "\n⚠️ Model not loaded. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
except Exception as e:
error_msg = str(e)
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error and attempt < max_retries:
delay = base_delay * attempt  # Linear backoff: 5s after the 1st attempt, 10s after the 2nd
logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
if is_quota_error:
# If quota exhausted, allow user to proceed - model will load on-demand
error_display = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
return error_display, gr.update(interactive=True), gr.update(interactive=True)
else:
error_display = f"⚠️ Startup error: {str(e)[:100]}"
if attempt >= max_retries:
logger.error(f"[STARTUP] Failed after {max_retries} attempts")
return error_display, gr.update(interactive=False), gr.update(interactive=False)
# Should not reach here, but just in case
return "⚠️ Startup failed after retries. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
demo.load(
fn=load_startup_and_update_ui,
inputs=None,
outputs=[model_status, submit_button, message_input]
)
# Note: We removed the preload on focus functionality because:
# 1. Model loading requires GPU access (device_map="auto" needs GPU in ZeroGPU)
# 2. The startup function already loads the model with GPU decorator
# 3. Preloading without GPU decorator would fail or cause conflicts
# 4. If startup fails, user can select a model from dropdown to trigger loading
# Wrap stream_chat: verify the model is already loaded before streaming (no loading happens here, to keep response latency low)
def stream_chat_with_model_check(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
):
# Check if model is loaded - if not, show error (don't load here to save stream_chat time)
model_loaded = is_model_loaded(medical_model_name)
if not model_loaded:
loading_state = get_model_loading_state(medical_model_name)
# Debug logging to understand why model check fails
logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
if loading_state == "loading":
error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
else:
error_msg = f"⚠️ {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
return
# If request is None, create a mock request for compatibility
if request is None:
class MockRequest:
session_hash = "anonymous"
request = MockRequest()
# Model is loaded, proceed with stream_chat (no model loading here to save time)
try:
for result in stream_chat(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
):
yield result
except Exception as e:
# Handle any errors gracefully
logger.error(f"Error in stream_chat_with_model_check: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
submit_button.click(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
message_input.submit(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
return demo
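# Typical usage from the app entrypoint (a sketch; the actual launcher may differ):
#   demo = create_demo()
#   demo.queue().launch()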