"""Gradio UI setup"""
import time
import gradio as gr
import spaces
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
import config
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import (
initialize_medical_model,
is_model_loaded,
get_model_loading_state,
set_model_loading_state,
initialize_tts_model,
initialize_whisper_model,
TTS_AVAILABLE,
WHISPER_AVAILABLE,
)
from logger import logger
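# Intended cap (in seconds) for @spaces.GPU-decorated calls; note that the
# @spaces.GPU decorators in this file are currently commented out.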
MAX_DURATION = 120
def create_demo():
"""Create and return Gradio demo interface"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Row(elem_classes="main-container"):
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag and Drop Files Here",
file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with MedSwin... Type your question below.",
show_label=False,
type="messages"
)
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your medical question here...",
show_label=False,
container=False,
lines=1,
scale=10
)
mic_button = gr.Audio(
sources=["microphone"],
type="filepath",
label="",
show_label=False,
container=False,
scale=1
)
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
recording_timer = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
container=False,
elem_classes="recording-timer"
)
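# A one-element list serves as a mutable cell so the nested recording handlers
# below can share the start timestamp without declaring a global.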
recording_start_time = [None]
def handle_recording_start():
"""Called when recording starts"""
recording_start_time[0] = time.time()
return gr.update(visible=True, value="Recording... 0s")
def handle_recording_stop(audio):
"""Called when recording stops"""
recording_start_time[0] = None
if audio is None:
return gr.update(visible=False, value=""), ""
transcribed = transcribe_audio(audio)
return gr.update(visible=False, value=""), transcribed
mic_button.start_recording(
fn=handle_recording_start,
outputs=[recording_timer]
)
mic_button.stop_recording(
fn=handle_recording_stop,
inputs=[mic_button],
outputs=[recording_timer, message_input]
)
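# Hidden row holding the TTS widgets: tts_text is a placeholder (not wired in
# this file) and tts_audio receives the synthesized speech for the last reply.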
with gr.Row(visible=False) as tts_row:
tts_text = gr.Textbox(visible=False)
tts_audio = gr.Audio(label="Generated Speech", visible=False)
def generate_speech_from_chat(history):
"""Extract last assistant message and generate speech"""
if not history or len(history) == 0:
return None
last_msg = history[-1]
if last_msg.get("role") == "assistant":
text = last_msg.get("content", "").replace(" 🔊", "").strip()
if text:
audio_path = generate_speech(text)
return audio_path
return None
tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
def update_tts_button(history):
if history and len(history) > 0 and history[-1].get("role") == "assistant":
return gr.update(visible=True)
return gr.update(visible=False)
chatbot.change(
fn=update_tts_button,
inputs=[chatbot],
outputs=[tts_button]
)
tts_button.click(
fn=generate_speech_from_chat,
inputs=[chatbot],
outputs=[tts_audio]
)
with gr.Accordion("⚙️ Advanced Settings", open=False):
with gr.Row():
disable_agentic_reasoning = gr.Checkbox(
value=False,
label="Disable agentic reasoning",
info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
)
show_agentic_thought = gr.Button(
"Show agentic thought",
size="sm"
)
enable_clinical_intake = gr.Checkbox(
value=True,
label="Enable clinical intake (max 5 Q&A)",
info="Ask focused follow-up questions before breaking down the case"
)
agentic_thoughts_box = gr.Textbox(
label="Agentic Thoughts",
placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
lines=8,
max_lines=15,
interactive=False,
visible=False,
elem_classes="agentic-thoughts"
)
with gr.Row():
use_rag = gr.Checkbox(
value=False,
label="Enable Document RAG",
info="Answer based on uploaded documents (upload required)"
)
use_web_search = gr.Checkbox(
value=False,
label="Enable Web Search (MCP)",
info="Fetch knowledge from online medical resources"
)
medical_model = gr.Radio(
choices=list(MEDSWIN_MODELS.keys()),
value=DEFAULT_MEDICAL_MODEL,
label="Medical Model",
info="MedSwin DT (default), others download on selection"
)
model_status = gr.Textbox(
value="Checking model status...",
label="Model Status",
interactive=False,
visible=True,
lines=3,
max_lines=3,
elem_classes="model-status"
)
system_prompt = gr.Textbox(
value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=512,
maximum=4096,
step=128,
value=2048,
label="Max New Tokens",
info="Increased for medical models to prevent early stopping"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
# MedSwin Model Links
gr.Markdown(
"""
<div style="margin-top: 20px; padding: 15px; background-color: #f5f5f5; border-radius: 8px;">
<h4 style="margin-top: 0; margin-bottom: 10px;">🔗 MedSwin Models on Hugging Face</h4>
<div style="display: flex; flex-wrap: wrap; gap: 10px;">
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DT</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-NuSLERP-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Nsl</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-Linear-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DL</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Ti</a>
<a href="https://huggingface.co/MedSwin/MedSwin-Merged-TA-SFT-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin TA</a>
<a href="https://huggingface.co/MedSwin/MedSwin-7B-SFT" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin SFT</a>
<a href="https://huggingface.co/MedSwin/MedSwin-7B-KD" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin KD</a>
</div>
<p style="margin-top: 10px; margin-bottom: 0; font-size: 11px; color: #666;">Click any model name to view details on Hugging Face</p>
</div>
"""
)
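# Per-session flag tracking whether the agentic thoughts box is currently shown.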
show_thoughts_state = gr.State(value=False)
def toggle_thoughts_box(current_state):
"""Toggle visibility of agentic thoughts box"""
new_state = not current_state
return gr.update(visible=new_state), new_state
show_agentic_thought.click(
fn=toggle_thoughts_box,
inputs=[show_thoughts_state],
outputs=[agentic_thoughts_box, show_thoughts_state]
)
# On-demand loader for any medical model (used when the user picks one from the dropdown).
# The @spaces.GPU decorator is currently disabled; re-enable it to claim ZeroGPU here.
# @spaces.GPU(max_duration=MAX_DURATION)
def load_model_with_gpu(model_name):
"""Load a medical model (intended to be GPU-decorated for ZeroGPU compatibility; decorator currently disabled)"""
try:
if not is_model_loaded(model_name):
logger.info(f"Loading medical model: {model_name}...")
set_model_loading_state(model_name, "loading")
try:
initialize_medical_model(model_name)
logger.info(f"✅ Medical model {model_name} loaded successfully!")
return "✅ The model has been loaded successfully", True
except Exception as e:
logger.error(f"Failed to load medical model {model_name}: {e}")
set_model_loading_state(model_name, "error")
return f"❌ Error loading model: {str(e)[:100]}", False
else:
logger.info(f"Medical model {model_name} is already loaded")
return "✅ The model has been loaded successfully", True
except Exception as e:
logger.error(f"Error loading model {model_name}: {e}")
return f"❌ Error: {str(e)[:100]}", False
def load_model_and_update_status(model_name):
"""Load model and update status, return status text and whether model is ready"""
try:
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"❌ MedSwin ({model_name}): error loading")
else:
# Use GPU-decorated function to load the model
try:
result = load_model_with_gpu(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
if is_ready:
status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
else:
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
else:
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
except Exception as e:
logger.error(f"Error calling load_model_with_gpu: {e}")
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
# TTS model status (only show if available or if there's an issue)
if TTS_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("✅ TTS (maya1): loaded and ready")
else:
# TTS available but not loaded - optional feature
pass # Don't show if not loaded, it's optional
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("✅ ASR (Whisper): loaded and ready")
else:
status_lines.append("⏳ ASR (Whisper): will load on first use")
else:
status_lines.append("❌ ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
except Exception as e:
return f"❌ Error: {str(e)[:100]}", False
def check_model_status(model_name):
"""Check current model status without loading"""
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"❌ MedSwin ({model_name}): error loading")
else:
status_lines.append(f"⚠️ MedSwin ({model_name}): not loaded")
# TTS model status (only show if available and loaded)
if TTS_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("✅ TTS (maya1): loaded and ready")
# Don't show if TTS library available but model not loaded (optional feature)
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("✅ ASR (Whisper): loaded and ready")
else:
status_lines.append("⏳ ASR (Whisper): will load on first use")
else:
status_lines.append("❌ ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
# GPU-decorated function to load ONLY medical model on startup
# According to ZeroGPU best practices:
# 1. Load models to CPU in global scope (no GPU decorator needed)
# 2. Move models to GPU only in inference functions (with @spaces.GPU decorator)
# However, for large models, loading to CPU then moving to GPU uses more memory
# So we use a hybrid approach: load to GPU directly but within GPU-decorated function
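# Illustrative sketch of the CPU-first pattern described above (assumed names;
# comments only -- the real loading logic lives in models.initialize_medical_model):
#
#     model = AutoModelForCausalLM.from_pretrained(model_id)   # global scope, CPU
#
#     @spaces.GPU
#     def infer(prompt):
#         model.to("cuda")   # GPU is only claimed inside the decorated call
#         ...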
def load_medical_model_on_startup_cpu():
"""
Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
Model will be moved to GPU during first inference
"""
status_messages = []
try:
# Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load to CPU (no GPU decorator needed)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
else:
status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("⏳ ASR (Whisper): will load on first use")
else:
status_messages.append("❌ ASR: library not available")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
return f"⚠️ Error loading model: {error_msg[:100]}"
# Alternative: Load directly to GPU (requires GPU decorator)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_medical_model_on_startup_gpu():
"""
Load model directly to GPU on startup (alternative approach)
Uses GPU quota but model is immediately ready for inference
"""
import torch
status_messages = []
try:
# Clear GPU cache at start
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("[STARTUP] Cleared GPU cache before model loading")
# Load only medical model (MedSwin) - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load directly to GPU (within GPU-decorated function)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
else:
status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("⏳ ASR (Whisper): will load on first use")
else:
status_messages.append("❌ ASR: library not available")
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("[STARTUP] Cleared GPU cache after model loading")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
# Check if it's a ZeroGPU quota/rate limit error
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error:
logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
# Return status message indicating quota error (will be handled by retry logic)
status_messages.append("⚠️ ZeroGPU quota error - will retry")
status_text = "\n".join(status_messages)
# Also add ASR status
if WHISPER_AVAILABLE:
status_text += "\n⏳ ASR (Whisper): will load on first use"
return status_text # Return status instead of raising, let wrapper handle retry
logger.error(f"[STARTUP] ❌ Error in model loading startup: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
return f"⚠️ Startup loading error: {str(e)[:100]}"
# Initialize status on load
def init_model_status():
try:
result = check_model_status(DEFAULT_MEDICAL_MODEL)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "⚠️ Unable to check model status"
except Exception as e:
logger.error(f"Error in init_model_status: {e}")
return f"⚠️ Error: {str(e)[:100]}"
# Update status when model selection changes
def update_model_status_on_change(model_name):
try:
result = check_model_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "⚠️ Unable to check model status"
except Exception as e:
logger.error(f"Error in update_model_status_on_change: {e}")
return f"⚠️ Error: {str(e)[:100]}"
# Handle model selection change
def on_model_change(model_name):
try:
result = load_model_and_update_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
submit_enabled = is_ready
return (
status_text,
gr.update(interactive=submit_enabled),
gr.update(interactive=submit_enabled)
)
else:
error_msg = "⚠️ Unable to load model status"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
except Exception as e:
logger.error(f"Error in on_model_change: {e}")
error_msg = f"⚠️ Error: {str(e)[:100]}"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
# Update status display periodically or on model status changes
def refresh_model_status(model_name):
return update_model_status_on_change(model_name)
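# Note: refresh_model_status is defined for manual/periodic refreshes but is not
# attached to any event in this file.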
medical_model.change(
fn=on_model_change,
inputs=[medical_model],
outputs=[model_status, submit_button, message_input]
)
# On-demand loader for the Whisper ASR model (the @spaces.GPU decorator below is currently disabled)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_whisper_model_on_demand():
"""Load Whisper ASR model when needed"""
try:
if WHISPER_AVAILABLE and config.global_whisper_model is None:
logger.info("[ASR] Loading Whisper model on-demand...")
initialize_whisper_model()
if config.global_whisper_model is not None:
logger.info("[ASR] ✅ Whisper model loaded successfully!")
return "✅ ASR (Whisper): loaded"
else:
logger.warning("[ASR] ⚠️ Whisper model failed to load")
return "⚠️ ASR (Whisper): failed to load"
elif config.global_whisper_model is not None:
return "✅ ASR (Whisper): already loaded"
else:
return "❌ ASR: library not available"
except Exception as e:
logger.error(f"[ASR] Error loading Whisper model: {e}")
return f"❌ ASR: error - {str(e)[:100]}"
# Load medical model on startup and update status
# Use a wrapper to handle GPU context properly with retry logic
def load_startup_and_update_ui():
"""
Load model on startup with retry logic (max 3 attempts) and return status with UI updates
Uses CPU-first approach (ZeroGPU best practice):
- Load model to CPU (no GPU decorator needed, avoids quota issues)
- Model will be moved to GPU during first inference
"""
import time
max_retries = 3
base_delay = 5.0 # Start with 5 seconds delay
for attempt in range(1, max_retries + 1):
try:
logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
# Use CPU-first approach (no GPU decorator, avoids quota issues)
status_text = load_medical_model_on_startup_cpu()
# Check if model is ready and update submit button state
is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
if is_ready:
logger.info(f"[STARTUP] ✅ Model loaded successfully on attempt {attempt}")
return status_text, gr.update(interactive=is_ready), gr.update(interactive=is_ready)
else:
# Check if status text indicates quota error
if status_text and ("quota" in status_text.lower() or "ZeroGPU" in status_text or
"429" in status_text or "runnning out" in status_text.lower() or
"running out" in status_text.lower()):
if attempt < max_retries:
delay = base_delay * attempt
logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Quota exhausted after retries - allow user to proceed, model will load on-demand
status_msg = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
return status_msg, gr.update(interactive=True), gr.update(interactive=True)
# Model didn't load, but no exception - might be a state issue
logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
if attempt < max_retries:
delay = base_delay * attempt  # Linear backoff: 5s, 10s, 15s
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Even if model didn't load, allow user to try selecting another model
return status_text + "\n⚠️ Model not loaded. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
except Exception as e:
error_msg = str(e)
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error and attempt < max_retries:
delay = base_delay * attempt  # Linear backoff: 5s, 10s, 15s
logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
if is_quota_error:
# If quota exhausted, allow user to proceed - model will load on-demand
error_display = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
return error_display, gr.update(interactive=True), gr.update(interactive=True)
else:
error_display = f"⚠️ Startup error: {str(e)[:100]}"
if attempt >= max_retries:
logger.error(f"[STARTUP] Failed after {max_retries} attempts")
return error_display, gr.update(interactive=False), gr.update(interactive=False)
# Should not reach here, but just in case
return "⚠️ Startup failed after retries. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
demo.load(
fn=load_startup_and_update_ui,
inputs=None,
outputs=[model_status, submit_button, message_input]
)
# Note: the preload-on-focus functionality was removed because:
# 1. Model loading requires GPU access (device_map="auto" needs GPU on ZeroGPU)
# 2. The startup handler already loads the model (currently via the CPU-first path above)
# 3. Preloading without a GPU decorator would fail or cause conflicts
# 4. If startup fails, the user can select a model from the dropdown to trigger loading
# Wrap stream_chat - ensure model is loaded before starting (don't load inside stream_chat to save time)
def stream_chat_with_model_check(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
):
# Check if model is loaded - if not, show error (don't load here to save stream_chat time)
model_loaded = is_model_loaded(medical_model_name)
if not model_loaded:
loading_state = get_model_loading_state(medical_model_name)
# Debug logging to understand why model check fails
logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
if loading_state == "loading":
error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
else:
error_msg = f"⚠️ {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
return
# If request is None, create a mock request for compatibility
if request is None:
class MockRequest:
session_hash = "anonymous"
request = MockRequest()
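# (stream_chat presumably keys per-session state on request.session_hash, so the
# stub only needs that one attribute.)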
# Model is loaded, proceed with stream_chat (no model loading here to save time)
# Note: We handle "BodyStreamBuffer was aborted" errors by catching stream disconnections
# and not attempting to yield after the client has disconnected
last_result = None
stream_aborted = False
try:
for result in stream_chat(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
):
last_result = result
try:
yield result
except (GeneratorExit, StopIteration, RuntimeError) as stream_error:
# Stream was closed/aborted by client - don't try to yield again
error_msg_lower = str(stream_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream was aborted by client, stopping gracefully")
stream_aborted = True
break
raise
except (GeneratorExit, StopIteration) as stream_exit:
# Stream was closed - this is normal, just log and exit
logger.info(f"[UI] Stream closed normally")
stream_aborted = True
return
except Exception as e:
# Handle any errors gracefully
error_str = str(e)
error_msg_lower = error_str.lower()
# Check if this is a stream abort error
is_stream_abort = (
'bodystreambuffer' in error_msg_lower or
('stream' in error_msg_lower and 'abort' in error_msg_lower) or
('connection' in error_msg_lower and 'abort' in error_msg_lower) or
(isinstance(e, (GeneratorExit, StopIteration, RuntimeError)) and 'abort' in error_msg_lower)
)
if is_stream_abort:
logger.info(f"[UI] Stream was aborted (BodyStreamBuffer or similar): {error_str[:100]}")
stream_aborted = True
# If we have a result, it was already yielded, so just return
return
is_gpu_timeout = 'gpu task aborted' in error_msg_lower or 'timeout' in error_msg_lower
logger.error(f"Error in stream_chat_with_model_check: {error_str}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
# Check if we have a valid answer in the last result
has_valid_answer = False
if last_result is not None:
try:
last_history, last_thoughts = last_result
# Find the last assistant message in the history
if last_history and isinstance(last_history, list):
for msg in reversed(last_history):
if isinstance(msg, dict) and msg.get("role") == "assistant":
assistant_content = msg.get("content", "")
# Check if it's a valid answer (not empty, not an error message)
if assistant_content and len(assistant_content.strip()) > 0:
# Not an error message
if not assistant_content.strip().startswith("⚠️") and not assistant_content.strip().startswith("⏳"):
has_valid_answer = True
break
except Exception as parse_error:
logger.debug(f"Error parsing last_result: {parse_error}")
# If stream was aborted, don't try to yield - just return
if stream_aborted:
logger.info(f"[UI] Stream was aborted, not yielding final result")
return
# If we have a valid answer, use it (don't show error message)
if has_valid_answer:
logger.info(f"[UI] Error occurred but final answer already generated, displaying it without error message")
try:
yield last_result
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding final result, ignoring")
else:
raise
return
# For GPU timeouts, try to use last result even if it's partial
if is_gpu_timeout and last_result is not None:
logger.info(f"[UI] GPU timeout occurred, using last available result")
try:
yield last_result
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding timeout result, ignoring")
else:
raise
return
# Only show error for non-timeout errors when we have no valid answer
# For GPU timeouts with no result, show empty message (not error)
if is_gpu_timeout:
logger.info(f"[UI] GPU timeout with no result, showing empty assistant message")
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": ""}]
try:
yield updated_history, ""
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding empty message, ignoring")
else:
raise
else:
# For other errors, show minimal error message only if no result
error_display = f"⚠️ An error occurred: {error_str[:200]}"
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_display}]
try:
yield updated_history, ""
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding error message, ignoring")
else:
raise
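# Wire both the send button and the Enter key on the textbox to the same guarded
# streaming handler so they behave identically.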
submit_button.click(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
message_input.submit(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
return demo
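# Minimal usage sketch (assumption: the real entry point lives elsewhere, e.g. an
# app.py that may also configure queuing or auth before launching):
#
#     if __name__ == "__main__":
#         demo = create_demo()
#         demo.queue().launch()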