# MedLLM-Agent / ui.py
"""Gradio UI setup"""
import time
import gradio as gr
import spaces
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import initialize_medical_model, is_model_loaded, get_model_loading_state, set_model_loading_state
from logger import logger
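
# The demo is a single Blocks app: a document-upload column that feeds the
# RAG index, a chat column with microphone input and optional TTS playback,
# and an advanced-settings accordion covering model choice plus generation
# and retrieval parameters.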
def create_demo():
"""Create and return Gradio demo interface"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Row(elem_classes="main-container"):
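            # Left column: file upload feeding create_or_update_index (the RAG store).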
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag and Drop Files Here",
file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
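            # Right column: chat history plus text and microphone input.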
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with MedSwin... Type your question below.",
show_label=False,
type="messages"
)
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your medical question here...",
show_label=False,
container=False,
lines=1,
scale=10
)
mic_button = gr.Audio(
sources=["microphone"],
type="filepath",
label="",
show_label=False,
container=False,
scale=1
)
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
recording_timer = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
container=False,
elem_classes="recording-timer"
)
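            # Single-element list: mutable closure state shared by the recording
            # handlers (stores the start time so elapsed time can be computed).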
recording_start_time = [None]
def handle_recording_start():
"""Called when recording starts"""
recording_start_time[0] = time.time()
return gr.update(visible=True, value="Recording... 0s")
def handle_recording_stop(audio):
"""Called when recording stops"""
recording_start_time[0] = None
if audio is None:
return gr.update(visible=False, value=""), ""
transcribed = transcribe_audio(audio)
return gr.update(visible=False, value=""), transcribed
mic_button.start_recording(
fn=handle_recording_start,
outputs=[recording_timer]
)
mic_button.stop_recording(
fn=handle_recording_stop,
inputs=[mic_button],
outputs=[recording_timer, message_input]
)
with gr.Row(visible=False) as tts_row:
tts_text = gr.Textbox(visible=False)
tts_audio = gr.Audio(label="Generated Speech", visible=False)
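                # The pipeline may append a " 🔊" marker to assistant replies;
                # strip it so the marker is not read aloud.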
                def generate_speech_from_chat(history):
                    """Extract the last assistant message and generate speech for it"""
                    if not history:
                        return gr.update(visible=False)
                    last_msg = history[-1]
                    if last_msg.get("role") == "assistant":
                        text = last_msg.get("content", "").replace(" 🔊", "").strip()
                        if text:
                            audio_path = generate_speech(text)
                            # Reveal the player; the component is created hidden
                            return gr.update(value=audio_path, visible=True)
                    return gr.update(visible=False)
tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
                def update_tts_button(history):
                    if history and history[-1].get("role") == "assistant":
                        return gr.update(visible=True)
                    return gr.update(visible=False)
chatbot.change(
fn=update_tts_button,
inputs=[chatbot],
outputs=[tts_button]
)
tts_button.click(
fn=generate_speech_from_chat,
inputs=[chatbot],
outputs=[tts_audio]
)
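        # Advanced settings: agentic-reasoning and intake toggles, RAG/web-search
        # switches, model selection, and generation/retrieval parameters.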
with gr.Accordion("⚙️ Advanced Settings", open=False):
with gr.Row():
disable_agentic_reasoning = gr.Checkbox(
value=False,
label="Disable agentic reasoning",
info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
)
show_agentic_thought = gr.Button(
"Show agentic thought",
size="sm"
)
enable_clinical_intake = gr.Checkbox(
value=True,
label="Enable clinical intake (max 5 Q&A)",
info="Ask focused follow-up questions before breaking down the case"
)
agentic_thoughts_box = gr.Textbox(
label="Agentic Thoughts",
placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
lines=8,
max_lines=15,
interactive=False,
visible=False,
elem_classes="agentic-thoughts"
)
with gr.Row():
use_rag = gr.Checkbox(
value=False,
label="Enable Document RAG",
info="Answer based on uploaded documents (upload required)"
)
use_web_search = gr.Checkbox(
value=False,
label="Enable Web Search (MCP)",
info="Fetch knowledge from online medical resources"
)
medical_model = gr.Radio(
choices=list(MEDSWIN_MODELS.keys()),
value=DEFAULT_MEDICAL_MODEL,
label="Medical Model",
info="MedSwin DT (default), others download on selection"
)
model_status = gr.Textbox(
value="Checking model status...",
label="Model Status",
interactive=False,
visible=True,
elem_classes="model-status"
)
system_prompt = gr.Textbox(
value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=512,
maximum=4096,
step=128,
value=2048,
label="Max New Tokens",
info="Increased for medical models to prevent early stopping"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
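        # Per-session flag tracking whether the agentic-thoughts box is visible;
        # the button below simply flips it.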
show_thoughts_state = gr.State(value=False)
def toggle_thoughts_box(current_state):
"""Toggle visibility of agentic thoughts box"""
new_state = not current_state
return gr.update(visible=new_state), new_state
show_agentic_thought.click(
fn=toggle_thoughts_box,
inputs=[show_thoughts_state],
outputs=[agentic_thoughts_box, show_thoughts_state]
)
# GPU-decorated function to load any model (for user selection)
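        # On ZeroGPU Spaces, GPU-backed work must run inside a @spaces.GPU-decorated
        # function; `duration` caps the allocation window in seconds.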
        @spaces.GPU(duration=120)
def load_model_with_gpu(model_name):
"""Load medical model (GPU-decorated for ZeroGPU compatibility)"""
try:
if not is_model_loaded(model_name):
logger.info(f"Loading medical model: {model_name}...")
set_model_loading_state(model_name, "loading")
try:
initialize_medical_model(model_name)
logger.info(f"✅ Medical model {model_name} loaded successfully!")
return "✅ The model has been loaded successfully", True
except Exception as e:
logger.error(f"Failed to load medical model {model_name}: {e}")
set_model_loading_state(model_name, "error")
return f"❌ Error loading model: {str(e)[:100]}", False
else:
logger.info(f"Medical model {model_name} is already loaded")
return "✅ The model has been loaded successfully", True
except Exception as e:
logger.error(f"Error loading model {model_name}: {e}")
return f"❌ Error: {str(e)[:100]}", False
def load_model_and_update_status(model_name):
"""Load model and update status, return status text and whether model is ready"""
try:
if is_model_loaded(model_name):
return "✅ The model has been loaded successfully", True
state = get_model_loading_state(model_name)
if state == "loading":
return "⏳ The model is being loaded, please wait...", False
elif state == "error":
return "❌ Error loading model. Please try again.", False
# Use GPU-decorated function to load the model
status_text, is_ready = load_model_with_gpu(model_name)
return status_text, is_ready
except Exception as e:
return f"❌ Error: {str(e)[:100]}", False
def check_model_status(model_name):
"""Check current model status without loading"""
if is_model_loaded(model_name):
return "✅ The model has been loaded successfully", True
state = get_model_loading_state(model_name)
if state == "loading":
return "⏳ The model is being loaded, please wait...", False
elif state == "error":
return "❌ Error loading model. Please try again.", False
else:
return "⚠️ Model not loaded. Click to load or it will load on first use.", False
# GPU-decorated function to load model on startup
        @spaces.GPU(duration=120)
def load_default_model_on_startup():
"""Load default medical model on startup (GPU-decorated for ZeroGPU compatibility)"""
try:
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"Loading default medical model on startup: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
initialize_medical_model(DEFAULT_MEDICAL_MODEL)
logger.info(f"✅ Default medical model {DEFAULT_MEDICAL_MODEL} loaded successfully on startup!")
return f"✅ {DEFAULT_MEDICAL_MODEL} loaded successfully"
except Exception as e:
logger.error(f"Failed to load default medical model on startup: {e}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
return f"❌ Error loading model: {str(e)[:100]}"
else:
logger.info(f"Default medical model {DEFAULT_MEDICAL_MODEL} is already loaded")
return f"✅ {DEFAULT_MEDICAL_MODEL} is ready"
except Exception as e:
logger.error(f"Error in model loading startup: {e}")
return f"⚠️ Startup loading error: {str(e)[:100]}"
        # Lightweight status check (no GPU request); defined for reuse but not
        # wired to an event: demo.load below uses the GPU-decorated startup loader.
def init_model_status():
status_text, is_ready = check_model_status(DEFAULT_MEDICAL_MODEL)
return status_text
# Handle model selection change
def on_model_change(model_name):
status_text, is_ready = load_model_and_update_status(model_name)
submit_enabled = is_ready
return (
status_text,
gr.update(interactive=submit_enabled),
gr.update(interactive=submit_enabled)
)
medical_model.change(
fn=on_model_change,
inputs=[medical_model],
outputs=[model_status, submit_button, message_input]
)
# Load default model on startup (GPU-decorated function)
demo.load(
fn=load_default_model_on_startup,
outputs=[model_status]
)
# Wrap stream_chat to check model status before execution
def stream_chat_with_model_check(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
):
try:
# Check if model is loaded
if not is_model_loaded(medical_model_name):
# Try to load it
status_text, is_ready = load_model_and_update_status(medical_model_name)
if not is_ready:
error_msg = "⚠️ Model is not ready. Please wait for the model to finish loading before sending messages."
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
return
# Model is ready, proceed with chat
if request is None:
# If request is None, create a mock request for compatibility
class MockRequest:
session_hash = "anonymous"
request = MockRequest()
for result in stream_chat(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
):
yield result
except Exception as e:
# Handle any errors gracefully
error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
submit_button.click(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
message_input.submit(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
return demo
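
if __name__ == "__main__":
    # Minimal local entrypoint (a sketch; the Space's actual app.py may differ).
    # queue() is required so the generator-based chat handler can stream.
    demo = create_demo()
    demo.queue().launch()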