Y Phung Nguyen committed
Commit 03d8100 · 1 Parent(s): e52570b

Update model efficiency and GPU task assignment


ZeroGPU tagging: Each MedSwin task runs under the @spaces.GPU(max_duration=120) decorator
No batching: Tasks execute individually (respects token limits)
Retry logic: Automatic retry with exponential backoff for GPU errors
Sequential delays: Short delays between sequential GPU requests to prevent ZeroGPU scheduler conflicts
Model status tracking: Real-time model status updates in the UI
UI protection: Submission is blocked while a model is loading
Auto-loading: Models load automatically when selected

Files changed (5)
  1. config.py +9 -0
  2. models.py +49 -12
  3. pipeline.py +6 -0
  4. supervisor.py +43 -3
  5. ui.py +97 -2
config.py CHANGED

```diff
@@ -131,6 +131,15 @@ CSS = """
     background: #f3e5f5;
     color: #7b1fa2;
 }
+.model-status {
+    margin-top: 5px;
+    padding: 8px;
+    border-radius: 5px;
+    font-size: 13px;
+    font-weight: 500;
+    background-color: #f5f5f5;
+    border: 1px solid #e0e0e0;
+}
 @media (min-width: 768px) {
     .main-container {
         display: flex;
```
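Side note: the new `.model-status` rule pairs with the `elem_classes="model-status"` component added in ui.py below. A minimal, self-contained sketch of how Gradio connects the two (illustrative only; the CSS here is abbreviated from the rule above):

```python
import gradio as gr

# Gradio injects the css string into the page and matches elem_classes
# against it, so the Textbox below renders with the .model-status styling.
css = """
.model-status {
    padding: 8px;
    border-radius: 5px;
    background-color: #f5f5f5;
    border: 1px solid #e0e0e0;
}
"""

with gr.Blocks(css=css) as demo:
    status = gr.Textbox(
        value="Checking model status...",
        label="Model Status",
        interactive=False,
        elem_classes="model-status",  # ties this component to the CSS rule
    )

if __name__ == "__main__":
    demo.launch()
```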
models.py CHANGED

```diff
@@ -1,5 +1,6 @@
 """Model initialization and management"""
 import torch
+import threading
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from llama_index.llms.huggingface import HuggingFaceLLM
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
@@ -13,23 +14,59 @@ except ImportError:
     TTS_AVAILABLE = False
     TTS = None

+# Model loading state tracking
+_model_loading_states = {}
+_model_loading_lock = threading.Lock()
+
+
+def set_model_loading_state(model_name: str, state: str):
+    """Set model loading state: 'loading', 'loaded', 'error'"""
+    with _model_loading_lock:
+        _model_loading_states[model_name] = state
+        logger.debug(f"Model {model_name} state set to: {state}")
+
+
+def get_model_loading_state(model_name: str) -> str:
+    """Get model loading state: 'loading', 'loaded', 'error', or 'unknown'"""
+    with _model_loading_lock:
+        return _model_loading_states.get(model_name, "unknown")
+
+
+def is_model_loaded(model_name: str) -> bool:
+    """Check if model is loaded and ready"""
+    with _model_loading_lock:
+        return (model_name in config.global_medical_models and
+                config.global_medical_models[model_name] is not None and
+                _model_loading_states.get(model_name) == "loaded")
+

 def initialize_medical_model(model_name: str):
     """Initialize medical model (MedSwin) - download on demand"""
     if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
+        set_model_loading_state(model_name, "loading")
         logger.info(f"Initializing medical model: {model_name}...")
-        model_path = config.MEDSWIN_MODELS[model_name]
-        tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            device_map="auto",
-            trust_remote_code=True,
-            token=config.HF_TOKEN,
-            torch_dtype=torch.float16
-        )
-        config.global_medical_models[model_name] = model
-        config.global_medical_tokenizers[model_name] = tokenizer
-        logger.info(f"Medical model {model_name} initialized successfully")
+        try:
+            model_path = config.MEDSWIN_MODELS[model_name]
+            tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                device_map="auto",
+                trust_remote_code=True,
+                token=config.HF_TOKEN,
+                torch_dtype=torch.float16
+            )
+            config.global_medical_models[model_name] = model
+            config.global_medical_tokenizers[model_name] = tokenizer
+            set_model_loading_state(model_name, "loaded")
+            logger.info(f"Medical model {model_name} initialized successfully")
+        except Exception as e:
+            set_model_loading_state(model_name, "error")
+            logger.error(f"Failed to initialize medical model {model_name}: {e}")
+            raise
+    else:
+        # Model already loaded, ensure state is set
+        if get_model_loading_state(model_name) != "loaded":
+            set_model_loading_state(model_name, "loaded")
     return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
```
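The lock is there because Gradio event handlers and the on-demand loader may touch the state dict from different threads. A self-contained sketch of the same pattern, with the model and config specifics stripped out:

```python
import threading

_states = {}
_lock = threading.Lock()

def set_state(name: str, state: str) -> None:
    # Serialize writes so concurrent callbacks never interleave updates
    with _lock:
        _states[name] = state

def get_state(name: str) -> str:
    with _lock:
        return _states.get(name, "unknown")

# Two threads racing on the same key still leave a consistent value:
# the winner is one of the written states, never a torn update.
writers = [
    threading.Thread(target=set_state, args=("medswin-ta", s))
    for s in ("loading", "loaded")
]
for t in writers:
    t.start()
for t in writers:
    t.join()
print(get_state("medswin-ta"))  # "loading" or "loaded", depending on scheduling
```

Single dict reads and writes are already atomic under CPython's GIL; the lock earns its keep in compound check-then-act operations like `is_model_loaded`, which reads the model registry and the state dict as one unit.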
pipeline.py CHANGED

```diff
@@ -571,6 +571,12 @@ def stream_chat(
         if len(rag_contexts) > 1 and idx <= len(rag_contexts):
             task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context

+        # Add small delay between GPU requests to prevent ZeroGPU scheduler conflicts
+        if idx > 1:
+            delay = 0.5  # 500ms delay between sequential GPU requests
+            logger.debug(f"[MEDSWIN] Waiting {delay}s before next GPU request to avoid scheduler conflicts...")
+            time.sleep(delay)
+
         try:
             task_answer = execute_medswin_task(
                 medical_model_obj=medical_model_obj,
```
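One thing to verify when applying this hunk: it calls `time.sleep` without adding an import, so pipeline.py presumably already imports `time` at module level. In isolation, the pacing pattern looks like this (a runnable sketch; the callables stand in for `execute_medswin_task` calls):

```python
import time

def run_sequentially(tasks, delay: float = 0.5):
    """Run GPU-bound callables one at a time, sleeping between requests.

    Mirrors the change above: no pause before the first task, a short pause
    before each later one so the ZeroGPU scheduler isn't hit back-to-back.
    """
    results = []
    for idx, task in enumerate(tasks, start=1):
        if idx > 1:
            time.sleep(delay)  # sequential delay between GPU requests
        results.append(task())
    return results

# Each "task" stands in for one execute_medswin_task call.
print(run_sequentially([lambda: "answer 1", lambda: "answer 2"]))
```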
supervisor.py CHANGED

```diff
@@ -559,8 +559,7 @@ def gemini_supervisor_rag_brainstorm(query: str, retrieved_docs: str, time_elaps
     }


-@spaces.GPU(max_duration=120)
-def execute_medswin_task(
+def _execute_medswin_core(
     medical_model_obj,
     medical_tokenizer,
     task_instruction: str,
@@ -572,7 +571,7 @@ def execute_medswin_task(
     top_k: int,
     penalty: float
 ) -> str:
-    """MedSwin Specialist: Execute a single task assigned by Gemini Supervisor"""
+    """Core MedSwin execution logic (without GPU decorator for retry logic)"""
     if context:
         full_prompt = f"{system_prompt_base}\n\nContext:\n{context}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
     else:
@@ -622,6 +621,47 @@ def execute_medswin_task(
     return response


+@spaces.GPU(max_duration=120)
+def execute_medswin_task(
+    medical_model_obj,
+    medical_tokenizer,
+    task_instruction: str,
+    context: str,
+    system_prompt_base: str,
+    temperature: float,
+    max_new_tokens: int,
+    top_p: float,
+    top_k: int,
+    penalty: float
+) -> str:
+    """
+    MedSwin Specialist: Execute a single task assigned by Gemini Supervisor (with ZeroGPU tag)
+    Includes retry logic with exponential backoff to handle GPU task aborted errors
+    """
+    import time
+    max_retries = 3
+    base_delay = 1.0  # Base delay in seconds
+
+    for attempt in range(max_retries):
+        try:
+            return _execute_medswin_core(
+                medical_model_obj, medical_tokenizer, task_instruction, context,
+                system_prompt_base, temperature, max_new_tokens, top_p, top_k, penalty
+            )
+        except Exception as e:
+            error_msg = str(e).lower()
+            is_gpu_error = 'gpu task aborted' in error_msg or 'gpu' in error_msg or 'zerogpu' in error_msg
+
+            if is_gpu_error and attempt < max_retries - 1:
+                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
+                logger.warning(f"[MEDSWIN] GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
+                time.sleep(delay)
+                continue
+            else:
+                logger.error(f"[MEDSWIN] Task failed after {attempt + 1} attempts: {e}")
+                raise
+
+
 async def gemini_supervisor_synthesize_async(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
     """Gemini Supervisor: Synthesize final answer from all MedSwin responses"""
     context_summary = ""
```
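The rename to `_execute_medswin_core` keeps the retry loop inside the single `@spaces.GPU` wrapper, so a retry re-invokes the undecorated core rather than queuing a brand-new ZeroGPU task. The same backoff policy, factored into a reusable decorator as a sketch (not part of this commit; the error-string match copies the diff above):

```python
import functools
import logging
import time

logger = logging.getLogger(__name__)

def retry_on_gpu_error(max_retries: int = 3, base_delay: float = 1.0):
    """Retry a callable on GPU-related errors with exponential backoff (1s, 2s, 4s)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    msg = str(e).lower()
                    gpu_error = ("gpu task aborted" in msg or "gpu" in msg
                                 or "zerogpu" in msg)
                    if gpu_error and attempt < max_retries - 1:
                        delay = base_delay * (2 ** attempt)
                        logger.warning("GPU error (attempt %d/%d), retrying in %.0fs",
                                       attempt + 1, max_retries, delay)
                        time.sleep(delay)
                    else:
                        raise
        return wrapper
    return decorator

@retry_on_gpu_error(max_retries=3, base_delay=1.0)
def run_core_task():
    # Stand-in for a call into _execute_medswin_core(...)
    return "ok"

print(run_core_task())
```

Note that the `'gpu' in error_msg` test is deliberately broad: any exception whose message merely mentions the GPU, such as a CUDA out-of-memory error, will also be retried with backoff.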
ui.py CHANGED

```diff
@@ -5,6 +5,7 @@ from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODE
 from indexing import create_or_update_index
 from pipeline import stream_chat
 from voice import transcribe_audio, generate_speech
+from models import initialize_medical_model, is_model_loaded, get_model_loading_state, set_model_loading_state


 def create_demo():
@@ -176,6 +177,13 @@ def create_demo():
         label="Medical Model",
         info="MedSwin TA (default), others download on first use"
     )
+    model_status = gr.Textbox(
+        value="Checking model status...",
+        label="Model Status",
+        interactive=False,
+        visible=True,
+        elem_classes="model-status"
+    )

     system_prompt = gr.Textbox(
         value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
@@ -250,8 +258,95 @@ def create_demo():
         outputs=[agentic_thoughts_box, show_thoughts_state]
     )

+    def load_model_and_update_status(model_name):
+        """Load model and update status, return status text and whether model is ready"""
+        try:
+            if is_model_loaded(model_name):
+                return "✅ The model has been loaded successfully", True
+
+            state = get_model_loading_state(model_name)
+            if state == "loading":
+                return "⏳ The model is being loaded, please wait...", False
+            elif state == "error":
+                return "❌ Error loading model. Please try again.", False
+
+            # Start loading
+            set_model_loading_state(model_name, "loading")
+            try:
+                initialize_medical_model(model_name)
+                return "✅ The model has been loaded successfully", True
+            except Exception as e:
+                set_model_loading_state(model_name, "error")
+                return f"❌ Error loading model: {str(e)[:100]}", False
+        except Exception as e:
+            return f"❌ Error: {str(e)[:100]}", False
+
+    def check_model_status(model_name):
+        """Check current model status without loading"""
+        if is_model_loaded(model_name):
+            return "✅ The model has been loaded successfully", True
+        state = get_model_loading_state(model_name)
+        if state == "loading":
+            return "⏳ The model is being loaded, please wait...", False
+        elif state == "error":
+            return "❌ Error loading model. Please try again.", False
+        else:
+            return "⚠️ Model not loaded. Click to load or it will load on first use.", False
+
+    # Initialize status on load
+    def init_model_status():
+        status_text, is_ready = check_model_status(DEFAULT_MEDICAL_MODEL)
+        return status_text
+
+    # Handle model selection change
+    def on_model_change(model_name):
+        status_text, is_ready = load_model_and_update_status(model_name)
+        submit_enabled = is_ready
+        return (
+            status_text,
+            gr.update(interactive=submit_enabled),
+            gr.update(interactive=submit_enabled)
+        )
+
+    medical_model.change(
+        fn=on_model_change,
+        inputs=[medical_model],
+        outputs=[model_status, submit_button, message_input]
+    )
+
+    # Initialize status
+    demo.load(
+        fn=init_model_status,
+        outputs=[model_status]
+    )
+
+    # Wrap stream_chat to check model status before execution
+    def stream_chat_with_model_check(
+        message, history, system_prompt, temperature, max_new_tokens,
+        top_p, top_k, penalty, retriever_k, merge_threshold,
+        use_rag, medical_model_name, use_web_search,
+        enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request
+    ):
+        # Check if model is loaded
+        if not is_model_loaded(medical_model_name):
+            # Try to load it
+            status_text, is_ready = load_model_and_update_status(medical_model_name)
+            if not is_ready:
+                error_msg = "⚠️ Model is not ready. Please wait for the model to finish loading before sending messages."
+                yield history + [{"role": "assistant", "content": error_msg}], ""
+                return
+
+        # Model is ready, proceed with chat
+        for result in stream_chat(
+            message, history, system_prompt, temperature, max_new_tokens,
+            top_p, top_k, penalty, retriever_k, merge_threshold,
+            use_rag, medical_model_name, use_web_search,
+            enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
+        ):
+            yield result
+
     submit_button.click(
-        fn=stream_chat,
+        fn=stream_chat_with_model_check,
         inputs=[
             message_input,
             chatbot,
@@ -274,7 +369,7 @@ def create_demo():
     )

     message_input.submit(
-        fn=stream_chat,
+        fn=stream_chat_with_model_check,
         inputs=[
             message_input,
             chatbot,
```
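Because `stream_chat` is a generator, the wrapper must be a generator too or Gradio would lose streaming; the explicit `for result in ...: yield result` loop is equivalent to `yield from`. A minimal sketch of the guard-then-delegate shape (the stand-in functions are illustrative, not from the commit):

```python
def model_ready() -> bool:
    # Stand-in for is_model_loaded(...) plus the load attempt in the wrapper above
    return True

def stream_chat(message, history):
    # Stand-in for the real streaming generator in pipeline.py
    for text in ("partial...", "final answer"):
        yield history + [{"role": "assistant", "content": text}], ""

def stream_chat_guarded(message, history):
    """Guard, then delegate: yield one error update if the model isn't ready,
    otherwise pass the stream through unchanged."""
    if not model_ready():
        yield history + [{"role": "assistant", "content": "⚠️ Model is not ready."}], ""
        return
    yield from stream_chat(message, history)  # same effect as the for/yield loop

for update in stream_chat_guarded("hi", []):
    print(update)
```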