MedLLM-Agent / ui.py
Y Phung Nguyen
Upd model loader to UI appender
2574b82
raw
history blame
18.7 kB
"""Gradio UI setup"""
import time
import gradio as gr
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import initialize_medical_model, is_model_loaded, get_model_loading_state, set_model_loading_state
def create_demo():
    """Create and return the Gradio demo interface.

    Layout: a document-upload column next to a chat column (chatbot, text and
    microphone input, text-to-speech playback, and an "Advanced Settings"
    accordion). Chat requests go through ``stream_chat`` after making sure the
    selected medical model is loaded.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    # Model-status messages shared by the status/loading helpers below.
    msg_model_ready = "✅ The model has been loaded successfully"
    msg_model_loading = "⏳ The model is being loaded, please wait..."
    msg_model_error = "❌ Error loading model. Please try again."

    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)

        with gr.Row(elem_classes="main-container"):
            # Left column: document upload + indexing.
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag and Drop Files Here",
                    file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
                    elem_id="file-upload",
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False,
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info",
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output],
                )

            # Right column: chat, voice input, TTS playback and settings.
            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with MedSwin... Type your question below.",
                    show_label=False,
                    type="messages",
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your medical question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=10,
                    )
                    mic_button = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="",
                        show_label=False,
                        container=False,
                        scale=1,
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
                recording_timer = gr.Textbox(
                    value="",
                    label="",
                    show_label=False,
                    interactive=False,
                    visible=False,
                    container=False,
                    elem_classes="recording-timer",
                )

                # NOTE(review): this is only written (on start/stop) and never read,
                # so the timer label stays at "Recording... 0s" while recording.
                recording_start_time = [None]

                def handle_recording_start():
                    """Remember when recording began and reveal the timer label."""
                    recording_start_time[0] = time.time()
                    return gr.update(visible=True, value="Recording... 0s")

                def handle_recording_stop(audio):
                    """Hide the timer and return the transcription of *audio* for the input box.

                    Returns an empty transcription when no audio was captured.
                    """
                    recording_start_time[0] = None
                    hide_timer = gr.update(visible=False, value="")
                    if audio is None:
                        return hide_timer, ""
                    return hide_timer, transcribe_audio(audio)

                mic_button.start_recording(
                    fn=handle_recording_start,
                    outputs=[recording_timer],
                )
                mic_button.stop_recording(
                    fn=handle_recording_stop,
                    inputs=[mic_button],
                    outputs=[recording_timer, message_input],
                )

                # Playback area for synthesized speech. It starts hidden and is
                # revealed by the "Play Response" button once audio exists.
                with gr.Row(visible=False) as tts_row:
                    # NOTE(review): tts_text is not wired to any event.
                    tts_text = gr.Textbox(visible=False)
                    # autoplay: the user explicitly asked to play the response.
                    tts_audio = gr.Audio(label="Generated Speech", visible=False, autoplay=True)
                tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")

                def generate_speech_from_chat(history):
                    """Synthesize speech for the latest assistant message.

                    Returns:
                        (row_update, audio_update): both reveal the playback area
                        when audio was produced, and hide/clear it otherwise.
                    """
                    hidden = (gr.update(visible=False), gr.update(value=None, visible=False))
                    if not history:
                        return hidden
                    last_msg = history[-1]
                    if last_msg.get("role") != "assistant":
                        return hidden
                    content = last_msg.get("content", "")
                    # Content is presumably text, but messages-format entries may also
                    # carry file/component payloads; only plain text can be spoken.
                    if not isinstance(content, str):
                        return hidden
                    # Drop the speaker marker before synthesis.
                    text = content.replace(" 🔊", "").strip()
                    if not text:
                        return hidden
                    audio_path = generate_speech(text)
                    if not audio_path:
                        return hidden
                    return gr.update(visible=True), gr.update(value=audio_path, visible=True)

                def update_tts_button(history):
                    """Show the TTS button only when the last message is from the assistant."""
                    show = bool(history) and history[-1].get("role") == "assistant"
                    return gr.update(visible=show)

                chatbot.change(
                    fn=update_tts_button,
                    inputs=[chatbot],
                    outputs=[tts_button],
                )
                # Both the row and the player start hidden, so both must be revealed
                # or the generated speech can never be heard.
                tts_button.click(
                    fn=generate_speech_from_chat,
                    inputs=[chatbot],
                    outputs=[tts_row, tts_audio],
                )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    with gr.Row():
                        disable_agentic_reasoning = gr.Checkbox(
                            value=False,
                            label="Disable agentic reasoning",
                            info="Use MedSwin model alone without agentic reasoning, RAG, or web search",
                        )
                        show_agentic_thought = gr.Button(
                            "Show agentic thought",
                            size="sm",
                        )
                        enable_clinical_intake = gr.Checkbox(
                            value=True,
                            label="Enable clinical intake (max 5 Q&A)",
                            info="Ask focused follow-up questions before breaking down the case",
                        )
                    agentic_thoughts_box = gr.Textbox(
                        label="Agentic Thoughts",
                        placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
                        lines=8,
                        max_lines=15,
                        interactive=False,
                        visible=False,
                        elem_classes="agentic-thoughts",
                    )
                    with gr.Row():
                        use_rag = gr.Checkbox(
                            value=False,
                            label="Enable Document RAG",
                            info="Answer based on uploaded documents (upload required)",
                        )
                        use_web_search = gr.Checkbox(
                            value=False,
                            label="Enable Web Search (MCP)",
                            info="Fetch knowledge from online medical resources",
                        )
                    medical_model = gr.Radio(
                        choices=list(MEDSWIN_MODELS.keys()),
                        value=DEFAULT_MEDICAL_MODEL,
                        label="Medical Model",
                        info="MedSwin TA (default), others download on first use",
                    )
                    model_status = gr.Textbox(
                        value="Checking model status...",
                        label="Model Status",
                        interactive=False,
                        visible=True,
                        elem_classes="model-status",
                    )
                    system_prompt = gr.Textbox(
                        value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
                        label="System Prompt",
                        lines=3,
                    )
                    with gr.Tab("Generation Parameters"):
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            step=0.1,
                            value=0.2,
                            label="Temperature",
                        )
                        max_new_tokens = gr.Slider(
                            minimum=512,
                            maximum=4096,
                            step=128,
                            value=2048,
                            label="Max New Tokens",
                            info="Increased for medical models to prevent early stopping",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            value=0.7,
                            label="Top P",
                        )
                        top_k = gr.Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top K",
                        )
                        penalty = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.2,
                            label="Repetition Penalty",
                        )
                    with gr.Tab("Retrieval Parameters"):
                        retriever_k = gr.Slider(
                            minimum=5,
                            maximum=30,
                            step=1,
                            value=15,
                            label="Initial Retrieval Size (Top K)",
                        )
                        merge_threshold = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            step=0.1,
                            value=0.5,
                            label="Merge Threshold (lower = more merging)",
                        )

        # Tracks whether the agentic-thoughts box is currently shown.
        show_thoughts_state = gr.State(value=False)

        def toggle_thoughts_box(current_state):
            """Flip the agentic-thoughts box visibility; return (box update, new state)."""
            new_state = not current_state
            return gr.update(visible=new_state), new_state

        show_agentic_thought.click(
            fn=toggle_thoughts_box,
            inputs=[show_thoughts_state],
            outputs=[agentic_thoughts_box, show_thoughts_state],
        )

        def load_model_and_update_status(model_name):
            """Make sure *model_name* is loaded.

            A model whose previous load failed (state ``"error"``) is retried here,
            so the "Please try again" advice shown for that state is actionable.

            Returns:
                (status_text, is_ready)
            """
            try:
                if is_model_loaded(model_name):
                    return msg_model_ready, True
                if get_model_loading_state(model_name) == "loading":
                    # A load is already in progress (this helper marks "loading"
                    # before initializing); don't start a second one.
                    return msg_model_loading, False
                # Not loaded yet, or the previous attempt failed: (re)try now.
                set_model_loading_state(model_name, "loading")
                try:
                    initialize_medical_model(model_name)
                except Exception as e:
                    set_model_loading_state(model_name, "error")
                    return f"❌ Error loading model: {str(e)[:100]}", False
                return msg_model_ready, True
            except Exception as e:
                return f"❌ Error: {str(e)[:100]}", False

        def check_model_status(model_name):
            """Report *model_name*'s status without triggering a load.

            Returns:
                (status_text, is_ready)
            """
            if is_model_loaded(model_name):
                return msg_model_ready, True
            state = get_model_loading_state(model_name)
            if state == "loading":
                return msg_model_loading, False
            if state == "error":
                return msg_model_error, False
            return "⚠️ Model not loaded. Click to load or it will load on first use.", False

        def init_model_status():
            """Initial status text for the default model, shown on page load."""
            status_text, _ = check_model_status(DEFAULT_MEDICAL_MODEL)
            return status_text

        def on_model_change(model_name):
            """Load the newly selected model and gate chat input on readiness.

            Yields an interim "loading" status with the chat inputs disabled
            before the potentially slow load, then the final status.
            """
            if not is_model_loaded(model_name):
                yield (
                    msg_model_loading,
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )
            status_text, is_ready = load_model_and_update_status(model_name)
            yield (
                status_text,
                gr.update(interactive=is_ready),
                gr.update(interactive=is_ready),
            )

        medical_model.change(
            fn=on_model_change,
            inputs=[medical_model],
            outputs=[model_status, submit_button, message_input],
        )
        demo.load(
            fn=init_model_status,
            outputs=[model_status],
        )

        def stream_chat_with_model_check(
            message, history, system_prompt, temperature, max_new_tokens,
            top_p, top_k, penalty, retriever_k, merge_threshold,
            use_rag, medical_model_name, use_web_search,
            enable_clinical_intake, disable_agentic_reasoning, show_thoughts,
            request: gr.Request = None,
        ):
            """Stream ``stream_chat`` output once the selected model is ready.

            Yields ``(history, agentic_thoughts)`` pairs for the chatbot and the
            thoughts box. If the model cannot be made ready, or anything raises,
            a single assistant error message is appended to the history instead.
            """
            # Guard against a missing history so the error paths can always append.
            history = list(history) if history else []
            try:
                if not is_model_loaded(medical_model_name):
                    _, is_ready = load_model_and_update_status(medical_model_name)
                    if not is_ready:
                        error_msg = "⚠️ Model is not ready. Please wait for the model to finish loading before sending messages."
                        yield history + [{"role": "assistant", "content": error_msg}], ""
                        return
                if request is None:
                    # stream_chat presumably reads request.session_hash; provide a stand-in.
                    class MockRequest:
                        session_hash = "anonymous"

                    request = MockRequest()
                yield from stream_chat(
                    message, history, system_prompt, temperature, max_new_tokens,
                    top_p, top_k, penalty, retriever_k, merge_threshold,
                    use_rag, medical_model_name, use_web_search,
                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request,
                )
            except Exception as e:
                error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
                yield history + [{"role": "assistant", "content": error_msg}], ""

        # Order must match stream_chat_with_model_check's positional parameters.
        chat_inputs = [
            message_input,
            chatbot,
            system_prompt,
            temperature,
            max_new_tokens,
            top_p,
            top_k,
            penalty,
            retriever_k,
            merge_threshold,
            use_rag,
            medical_model,
            use_web_search,
            enable_clinical_intake,
            disable_agentic_reasoning,
            show_thoughts_state,
        ]
        chat_outputs = [chatbot, agentic_thoughts_box]
        # The send button and Enter in the textbox trigger the same handler.
        submit_button.click(
            fn=stream_chat_with_model_check,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )
        message_input.submit(
            fn=stream_chat_with_model_check,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )

    return demo