Y Phung Nguyen committed on
Commit
a5fe328
·
1 Parent(s): c5ac360

Update models loader #2

Browse files
Files changed (2) hide show
  1. models.py +39 -0
  2. ui.py +86 -65
models.py CHANGED
@@ -63,6 +63,11 @@ def initialize_medical_model(model_name: str):
63
  set_model_loading_state(model_name, "loading")
64
  logger.info(f"Initializing medical model: {model_name}...")
65
  try:
 
 
 
 
 
66
  model_path = config.MEDSWIN_MODELS[model_name]
67
  tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
68
  model = AutoModelForCausalLM.from_pretrained(
@@ -76,9 +81,17 @@ def initialize_medical_model(model_name: str):
76
  config.global_medical_tokenizers[model_name] = tokenizer
77
  set_model_loading_state(model_name, "loaded")
78
  logger.info(f"Medical model {model_name} initialized successfully")
 
 
 
 
 
79
  except Exception as e:
80
  set_model_loading_state(model_name, "error")
81
  logger.error(f"Failed to initialize medical model {model_name}: {e}")
 
 
 
82
  raise
83
  else:
84
  # Model already loaded, ensure state is set
@@ -93,13 +106,26 @@ def initialize_tts_model():
93
  return None
94
  if config.global_tts_model is None:
95
  try:
 
 
 
 
 
96
  logger.info("Initializing TTS model for voice generation...")
97
  config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
98
  logger.info("TTS model initialized successfully")
 
 
 
 
 
99
  except Exception as e:
100
  logger.warning(f"TTS model initialization failed: {e}")
101
  logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
102
  config.global_tts_model = None
 
 
 
103
  return config.global_tts_model
104
 
105
  def initialize_whisper_model():
@@ -109,6 +135,11 @@ def initialize_whisper_model():
109
  return None
110
  if config.global_whisper_model is None:
111
  try:
 
 
 
 
 
112
  logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
113
  model_id = "openai/whisper-large-v3-turbo"
114
  processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
@@ -121,10 +152,18 @@ def initialize_whisper_model():
121
  # Store both processor and model
122
  config.global_whisper_model = {"processor": processor, "model": model}
123
  logger.info(f"Whisper model ({model_id}) initialized successfully")
 
 
 
 
 
124
  except Exception as e:
125
  logger.warning(f"Whisper model initialization failed: {e}")
126
  logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
127
  config.global_whisper_model = None
 
 
 
128
  return config.global_whisper_model
129
 
130
  def get_or_create_embed_model():
 
63
  set_model_loading_state(model_name, "loading")
64
  logger.info(f"Initializing medical model: {model_name}...")
65
  try:
66
+ # Clear GPU cache before loading to prevent memory issues
67
+ if torch.cuda.is_available():
68
+ torch.cuda.empty_cache()
69
+ logger.debug("Cleared GPU cache before model loading")
70
+
71
  model_path = config.MEDSWIN_MODELS[model_name]
72
  tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
73
  model = AutoModelForCausalLM.from_pretrained(
 
81
  config.global_medical_tokenizers[model_name] = tokenizer
82
  set_model_loading_state(model_name, "loaded")
83
  logger.info(f"Medical model {model_name} initialized successfully")
84
+
85
+ # Clear cache after loading to free up temporary memory
86
+ if torch.cuda.is_available():
87
+ torch.cuda.empty_cache()
88
+ logger.debug("Cleared GPU cache after model loading")
89
  except Exception as e:
90
  set_model_loading_state(model_name, "error")
91
  logger.error(f"Failed to initialize medical model {model_name}: {e}")
92
+ # Clear cache on error
93
+ if torch.cuda.is_available():
94
+ torch.cuda.empty_cache()
95
  raise
96
  else:
97
  # Model already loaded, ensure state is set
 
106
  return None
107
  if config.global_tts_model is None:
108
  try:
109
+ # Clear GPU cache before loading
110
+ if torch.cuda.is_available():
111
+ torch.cuda.empty_cache()
112
+ logger.debug("Cleared GPU cache before TTS model loading")
113
+
114
  logger.info("Initializing TTS model for voice generation...")
115
  config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
116
  logger.info("TTS model initialized successfully")
117
+
118
+ # Clear cache after loading
119
+ if torch.cuda.is_available():
120
+ torch.cuda.empty_cache()
121
+ logger.debug("Cleared GPU cache after TTS model loading")
122
  except Exception as e:
123
  logger.warning(f"TTS model initialization failed: {e}")
124
  logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
125
  config.global_tts_model = None
126
+ # Clear cache on error
127
+ if torch.cuda.is_available():
128
+ torch.cuda.empty_cache()
129
  return config.global_tts_model
130
 
131
  def initialize_whisper_model():
 
135
  return None
136
  if config.global_whisper_model is None:
137
  try:
138
+ # Clear GPU cache before loading
139
+ if torch.cuda.is_available():
140
+ torch.cuda.empty_cache()
141
+ logger.debug("Cleared GPU cache before Whisper model loading")
142
+
143
  logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
144
  model_id = "openai/whisper-large-v3-turbo"
145
  processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
 
152
  # Store both processor and model
153
  config.global_whisper_model = {"processor": processor, "model": model}
154
  logger.info(f"Whisper model ({model_id}) initialized successfully")
155
+
156
+ # Clear cache after loading
157
+ if torch.cuda.is_available():
158
+ torch.cuda.empty_cache()
159
+ logger.debug("Cleared GPU cache after Whisper model loading")
160
  except Exception as e:
161
  logger.warning(f"Whisper model initialization failed: {e}")
162
  logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
163
  config.global_whisper_model = None
164
+ # Clear cache on error
165
+ if torch.cuda.is_available():
166
+ torch.cuda.empty_cache()
167
  return config.global_whisper_model
168
 
169
  def get_or_create_embed_model():
ui.py CHANGED
@@ -290,7 +290,7 @@ def create_demo():
290
  )
291
 
292
  # GPU-decorated function to load any model (for user selection)
293
- @spaces.GPU(max_duration=120)
294
  def load_model_with_gpu(model_name):
295
  """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
296
  try:
@@ -404,59 +404,103 @@ def create_demo():
404
  is_ready = is_model_loaded(model_name)
405
  return status_text, is_ready
406
 
407
- # GPU-decorated function to load model on startup
408
- @spaces.GPU(max_duration=120)
409
- def load_default_model_on_startup():
410
- """Load default medical model on startup (GPU-decorated for ZeroGPU compatibility)"""
 
 
 
 
 
411
  try:
 
 
 
 
 
 
412
  if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
413
- logger.info(f"Loading default medical model on startup: {DEFAULT_MEDICAL_MODEL}...")
414
  set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
415
  try:
416
  initialize_medical_model(DEFAULT_MEDICAL_MODEL)
417
- logger.info(f"✅ Default medical model {DEFAULT_MEDICAL_MODEL} loaded successfully on startup!")
418
- return f"✅ {DEFAULT_MEDICAL_MODEL} loaded successfully"
419
  except Exception as e:
420
- logger.error(f"Failed to load default medical model on startup: {e}")
 
421
  set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
422
- return f"❌ Error loading model: {str(e)[:100]}"
423
  else:
424
- logger.info(f"Default medical model {DEFAULT_MEDICAL_MODEL} is already loaded")
425
- return f" {DEFAULT_MEDICAL_MODEL} is ready"
426
- except Exception as e:
427
- logger.error(f"Error in model loading startup: {e}")
428
- return f"⚠️ Startup loading error: {str(e)[:100]}"
429
-
430
- # GPU-decorated function to load default TTS and ASR models on startup
431
- @spaces.GPU(max_duration=120)
432
- def load_voice_models_on_startup():
433
- """Load default TTS model (maya1) and ASR model (Whisper) on startup"""
434
- try:
435
- # Load TTS model
436
  if TTS_AVAILABLE:
437
- logger.info("Loading default TTS model (maya1) on startup...")
438
- initialize_tts_model()
439
- if config.global_tts_model is not None:
440
- logger.info("✅ Default TTS model (maya1) loaded successfully on startup!")
441
- else:
442
- logger.warning("⚠️ TTS model failed to load on startup")
 
 
 
 
 
 
443
  else:
444
- logger.warning("TTS library not installed; skipping TTS preload.")
 
445
 
446
- # Load ASR (Whisper) model
 
 
 
 
 
 
447
  if WHISPER_AVAILABLE:
448
- logger.info("Loading default ASR model (Whisper large-v3-turbo) on startup...")
449
- initialize_whisper_model()
450
- if config.global_whisper_model is not None:
451
- logger.info("✅ Default ASR model (Whisper large-v3-turbo) loaded successfully on startup!")
452
- else:
453
- logger.warning("⚠️ ASR model failed to load on startup")
 
 
 
 
 
 
454
  else:
455
- logger.warning("Whisper transformers not installed; skipping ASR preload.")
 
 
 
 
 
 
 
 
 
 
 
 
456
  except Exception as e:
457
- logger.error(f"Error in voice models loading startup: {e}")
458
  import traceback
459
- logger.debug(f"Full traceback: {traceback.format_exc()}")
 
 
 
 
460
 
461
  # Initialize status on load
462
  def init_model_status():
@@ -522,33 +566,10 @@ def create_demo():
522
  outputs=[model_status, submit_button, message_input]
523
  )
524
 
525
- # Load models on startup - they will be loaded in separate GPU sessions
526
- # First load medical model
527
- demo.load(
528
- fn=load_default_model_on_startup,
529
- inputs=None,
530
- outputs=[model_status]
531
- )
532
- # Then load voice models (TTS and ASR)
533
- demo.load(
534
- fn=load_voice_models_on_startup,
535
- inputs=None,
536
- outputs=None
537
- )
538
- # Finally update status to show all models
539
- def update_status_after_load():
540
- try:
541
- result = check_model_status(DEFAULT_MEDICAL_MODEL)
542
- if result and isinstance(result, tuple) and len(result) == 2:
543
- return result[0]
544
- else:
545
- return "⚠️ Unable to check model status"
546
- except Exception as e:
547
- logger.error(f"Error updating status after load: {e}")
548
- return f"⚠️ Error: {str(e)[:100]}"
549
-
550
  demo.load(
551
- fn=update_status_after_load,
552
  inputs=None,
553
  outputs=[model_status]
554
  )
 
290
  )
291
 
292
  # GPU-decorated function to load any model (for user selection)
293
+ # @spaces.GPU(max_duration=120)
294
  def load_model_with_gpu(model_name):
295
  """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
296
  try:
 
404
  is_ready = is_model_loaded(model_name)
405
  return status_text, is_ready
406
 
407
+ # GPU-decorated function to load ALL models sequentially on startup
408
+ # This prevents ZeroGPU conflicts from multiple simultaneous GPU requests
409
+ # @spaces.GPU(max_duration=180)
410
+ def load_all_models_on_startup():
411
+ """Load all models sequentially in a single GPU session to avoid ZeroGPU conflicts"""
412
+ import time
413
+ import torch
414
+ status_messages = []
415
+
416
  try:
417
+ # Clear GPU cache at start
418
+ if torch.cuda.is_available():
419
+ torch.cuda.empty_cache()
420
+ logger.info("[STARTUP] Cleared GPU cache before model loading")
421
+
422
+ # Step 1: Load medical model (MedSwin)
423
  if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
424
+ logger.info(f"[STARTUP] Step 1/3: Loading medical model: {DEFAULT_MEDICAL_MODEL}...")
425
  set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
426
  try:
427
  initialize_medical_model(DEFAULT_MEDICAL_MODEL)
428
+ status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded")
429
+ logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} loaded successfully!")
430
  except Exception as e:
431
+ status_messages.append(f" MedSwin ({DEFAULT_MEDICAL_MODEL}): error")
432
+ logger.error(f"[STARTUP] Failed to load medical model: {e}")
433
  set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
 
434
  else:
435
+ status_messages.append(f" MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
436
+ logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
437
+
438
+ # Small delay to let GPU settle and clear cache
439
+ time.sleep(2)
440
+ if torch.cuda.is_available():
441
+ torch.cuda.empty_cache()
442
+ logger.debug("[STARTUP] Cleared GPU cache after medical model")
443
+
444
+ # Step 2: Load TTS model (maya1)
 
 
445
  if TTS_AVAILABLE:
446
+ logger.info("[STARTUP] Step 2/3: Loading TTS model (maya1)...")
447
+ try:
448
+ initialize_tts_model()
449
+ if config.global_tts_model is not None:
450
+ status_messages.append("✅ TTS (maya1): loaded")
451
+ logger.info("[STARTUP] TTS model loaded successfully!")
452
+ else:
453
+ status_messages.append("⚠️ TTS (maya1): failed")
454
+ logger.warning("[STARTUP] ⚠️ TTS model failed to load")
455
+ except Exception as e:
456
+ status_messages.append("❌ TTS (maya1): error")
457
+ logger.error(f"[STARTUP] TTS model loading error: {e}")
458
  else:
459
+ status_messages.append("TTS: library not available")
460
+ logger.warning("[STARTUP] TTS library not installed")
461
 
462
+ # Small delay to let GPU settle and clear cache
463
+ time.sleep(2)
464
+ if torch.cuda.is_available():
465
+ torch.cuda.empty_cache()
466
+ logger.debug("[STARTUP] Cleared GPU cache after TTS model")
467
+
468
+ # Step 3: Load ASR model (Whisper)
469
  if WHISPER_AVAILABLE:
470
+ logger.info("[STARTUP] Step 3/3: Loading ASR model (Whisper)...")
471
+ try:
472
+ initialize_whisper_model()
473
+ if config.global_whisper_model is not None:
474
+ status_messages.append("✅ ASR (Whisper): loaded")
475
+ logger.info("[STARTUP] ASR model loaded successfully!")
476
+ else:
477
+ status_messages.append("⚠️ ASR (Whisper): failed")
478
+ logger.warning("[STARTUP] ⚠️ ASR model failed to load")
479
+ except Exception as e:
480
+ status_messages.append("❌ ASR (Whisper): error")
481
+ logger.error(f"[STARTUP] ASR model loading error: {e}")
482
  else:
483
+ status_messages.append(" ASR: library not available")
484
+ logger.warning("[STARTUP] Whisper library not installed")
485
+
486
+ # Final cache clear
487
+ if torch.cuda.is_available():
488
+ torch.cuda.empty_cache()
489
+ logger.debug("[STARTUP] Final GPU cache clear")
490
+
491
+ # Return combined status
492
+ status_text = "\n".join(status_messages)
493
+ logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
494
+ return status_text
495
+
496
  except Exception as e:
497
+ logger.error(f"[STARTUP] ❌ Error in model loading startup: {e}")
498
  import traceback
499
+ logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
500
+ # Clear cache on error
501
+ if torch.cuda.is_available():
502
+ torch.cuda.empty_cache()
503
+ return f"⚠️ Startup loading error: {str(e)[:100]}"
504
 
505
  # Initialize status on load
506
  def init_model_status():
 
566
  outputs=[model_status, submit_button, message_input]
567
  )
568
 
569
+ # Load ALL models sequentially in a SINGLE GPU session to avoid ZeroGPU conflicts
570
+ # This prevents "GPU aborted" errors from multiple simultaneous GPU requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  demo.load(
572
+ fn=load_all_models_on_startup,
573
  inputs=None,
574
  outputs=[model_status]
575
  )