Spaces:
Running
on
Zero
Running
on
Zero
Y Phung Nguyen
committed on
Commit
·
a5fe328
1
Parent(s):
c5ac360
Upd models loader #2
Browse files
models.py
CHANGED
|
@@ -63,6 +63,11 @@ def initialize_medical_model(model_name: str):
|
|
| 63 |
set_model_loading_state(model_name, "loading")
|
| 64 |
logger.info(f"Initializing medical model: {model_name}...")
|
| 65 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
model_path = config.MEDSWIN_MODELS[model_name]
|
| 67 |
tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
|
| 68 |
model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -76,9 +81,17 @@ def initialize_medical_model(model_name: str):
|
|
| 76 |
config.global_medical_tokenizers[model_name] = tokenizer
|
| 77 |
set_model_loading_state(model_name, "loaded")
|
| 78 |
logger.info(f"Medical model {model_name} initialized successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
except Exception as e:
|
| 80 |
set_model_loading_state(model_name, "error")
|
| 81 |
logger.error(f"Failed to initialize medical model {model_name}: {e}")
|
|
|
|
|
|
|
|
|
|
| 82 |
raise
|
| 83 |
else:
|
| 84 |
# Model already loaded, ensure state is set
|
|
@@ -93,13 +106,26 @@ def initialize_tts_model():
|
|
| 93 |
return None
|
| 94 |
if config.global_tts_model is None:
|
| 95 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
logger.info("Initializing TTS model for voice generation...")
|
| 97 |
config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
|
| 98 |
logger.info("TTS model initialized successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
except Exception as e:
|
| 100 |
logger.warning(f"TTS model initialization failed: {e}")
|
| 101 |
logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
|
| 102 |
config.global_tts_model = None
|
|
|
|
|
|
|
|
|
|
| 103 |
return config.global_tts_model
|
| 104 |
|
| 105 |
def initialize_whisper_model():
|
|
@@ -109,6 +135,11 @@ def initialize_whisper_model():
|
|
| 109 |
return None
|
| 110 |
if config.global_whisper_model is None:
|
| 111 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
|
| 113 |
model_id = "openai/whisper-large-v3-turbo"
|
| 114 |
processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
|
|
@@ -121,10 +152,18 @@ def initialize_whisper_model():
|
|
| 121 |
# Store both processor and model
|
| 122 |
config.global_whisper_model = {"processor": processor, "model": model}
|
| 123 |
logger.info(f"Whisper model ({model_id}) initialized successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
except Exception as e:
|
| 125 |
logger.warning(f"Whisper model initialization failed: {e}")
|
| 126 |
logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
|
| 127 |
config.global_whisper_model = None
|
|
|
|
|
|
|
|
|
|
| 128 |
return config.global_whisper_model
|
| 129 |
|
| 130 |
def get_or_create_embed_model():
|
|
|
|
| 63 |
set_model_loading_state(model_name, "loading")
|
| 64 |
logger.info(f"Initializing medical model: {model_name}...")
|
| 65 |
try:
|
| 66 |
+
# Clear GPU cache before loading to prevent memory issues
|
| 67 |
+
if torch.cuda.is_available():
|
| 68 |
+
torch.cuda.empty_cache()
|
| 69 |
+
logger.debug("Cleared GPU cache before model loading")
|
| 70 |
+
|
| 71 |
model_path = config.MEDSWIN_MODELS[model_name]
|
| 72 |
tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
|
| 73 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 81 |
config.global_medical_tokenizers[model_name] = tokenizer
|
| 82 |
set_model_loading_state(model_name, "loaded")
|
| 83 |
logger.info(f"Medical model {model_name} initialized successfully")
|
| 84 |
+
|
| 85 |
+
# Clear cache after loading to free up temporary memory
|
| 86 |
+
if torch.cuda.is_available():
|
| 87 |
+
torch.cuda.empty_cache()
|
| 88 |
+
logger.debug("Cleared GPU cache after model loading")
|
| 89 |
except Exception as e:
|
| 90 |
set_model_loading_state(model_name, "error")
|
| 91 |
logger.error(f"Failed to initialize medical model {model_name}: {e}")
|
| 92 |
+
# Clear cache on error
|
| 93 |
+
if torch.cuda.is_available():
|
| 94 |
+
torch.cuda.empty_cache()
|
| 95 |
raise
|
| 96 |
else:
|
| 97 |
# Model already loaded, ensure state is set
|
|
|
|
| 106 |
return None
|
| 107 |
if config.global_tts_model is None:
|
| 108 |
try:
|
| 109 |
+
# Clear GPU cache before loading
|
| 110 |
+
if torch.cuda.is_available():
|
| 111 |
+
torch.cuda.empty_cache()
|
| 112 |
+
logger.debug("Cleared GPU cache before TTS model loading")
|
| 113 |
+
|
| 114 |
logger.info("Initializing TTS model for voice generation...")
|
| 115 |
config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
|
| 116 |
logger.info("TTS model initialized successfully")
|
| 117 |
+
|
| 118 |
+
# Clear cache after loading
|
| 119 |
+
if torch.cuda.is_available():
|
| 120 |
+
torch.cuda.empty_cache()
|
| 121 |
+
logger.debug("Cleared GPU cache after TTS model loading")
|
| 122 |
except Exception as e:
|
| 123 |
logger.warning(f"TTS model initialization failed: {e}")
|
| 124 |
logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
|
| 125 |
config.global_tts_model = None
|
| 126 |
+
# Clear cache on error
|
| 127 |
+
if torch.cuda.is_available():
|
| 128 |
+
torch.cuda.empty_cache()
|
| 129 |
return config.global_tts_model
|
| 130 |
|
| 131 |
def initialize_whisper_model():
|
|
|
|
| 135 |
return None
|
| 136 |
if config.global_whisper_model is None:
|
| 137 |
try:
|
| 138 |
+
# Clear GPU cache before loading
|
| 139 |
+
if torch.cuda.is_available():
|
| 140 |
+
torch.cuda.empty_cache()
|
| 141 |
+
logger.debug("Cleared GPU cache before Whisper model loading")
|
| 142 |
+
|
| 143 |
logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
|
| 144 |
model_id = "openai/whisper-large-v3-turbo"
|
| 145 |
processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
|
|
|
|
| 152 |
# Store both processor and model
|
| 153 |
config.global_whisper_model = {"processor": processor, "model": model}
|
| 154 |
logger.info(f"Whisper model ({model_id}) initialized successfully")
|
| 155 |
+
|
| 156 |
+
# Clear cache after loading
|
| 157 |
+
if torch.cuda.is_available():
|
| 158 |
+
torch.cuda.empty_cache()
|
| 159 |
+
logger.debug("Cleared GPU cache after Whisper model loading")
|
| 160 |
except Exception as e:
|
| 161 |
logger.warning(f"Whisper model initialization failed: {e}")
|
| 162 |
logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
|
| 163 |
config.global_whisper_model = None
|
| 164 |
+
# Clear cache on error
|
| 165 |
+
if torch.cuda.is_available():
|
| 166 |
+
torch.cuda.empty_cache()
|
| 167 |
return config.global_whisper_model
|
| 168 |
|
| 169 |
def get_or_create_embed_model():
|
ui.py
CHANGED
|
@@ -290,7 +290,7 @@ def create_demo():
|
|
| 290 |
)
|
| 291 |
|
| 292 |
# GPU-decorated function to load any model (for user selection)
|
| 293 |
-
@spaces.GPU(max_duration=120)
|
| 294 |
def load_model_with_gpu(model_name):
|
| 295 |
"""Load medical model (GPU-decorated for ZeroGPU compatibility)"""
|
| 296 |
try:
|
|
@@ -404,59 +404,103 @@ def create_demo():
|
|
| 404 |
is_ready = is_model_loaded(model_name)
|
| 405 |
return status_text, is_ready
|
| 406 |
|
| 407 |
-
# GPU-decorated function to load
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
|
| 413 |
-
logger.info(f"Loading
|
| 414 |
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
|
| 415 |
try:
|
| 416 |
initialize_medical_model(DEFAULT_MEDICAL_MODEL)
|
| 417 |
-
|
| 418 |
-
|
| 419 |
except Exception as e:
|
| 420 |
-
|
|
|
|
| 421 |
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
|
| 422 |
-
return f"❌ Error loading model: {str(e)[:100]}"
|
| 423 |
else:
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
try:
|
| 435 |
-
# Load TTS model
|
| 436 |
if TTS_AVAILABLE:
|
| 437 |
-
logger.info("Loading
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
else:
|
| 444 |
-
|
|
|
|
| 445 |
|
| 446 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
if WHISPER_AVAILABLE:
|
| 448 |
-
logger.info("Loading
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
else:
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
except Exception as e:
|
| 457 |
-
logger.error(f"Error in
|
| 458 |
import traceback
|
| 459 |
-
logger.debug(f"Full traceback: {traceback.format_exc()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
# Initialize status on load
|
| 462 |
def init_model_status():
|
|
@@ -522,33 +566,10 @@ def create_demo():
|
|
| 522 |
outputs=[model_status, submit_button, message_input]
|
| 523 |
)
|
| 524 |
|
| 525 |
-
# Load models
|
| 526 |
-
#
|
| 527 |
-
demo.load(
|
| 528 |
-
fn=load_default_model_on_startup,
|
| 529 |
-
inputs=None,
|
| 530 |
-
outputs=[model_status]
|
| 531 |
-
)
|
| 532 |
-
# Then load voice models (TTS and ASR)
|
| 533 |
-
demo.load(
|
| 534 |
-
fn=load_voice_models_on_startup,
|
| 535 |
-
inputs=None,
|
| 536 |
-
outputs=None
|
| 537 |
-
)
|
| 538 |
-
# Finally update status to show all models
|
| 539 |
-
def update_status_after_load():
    """Return the status text for the default medical model.

    Calls ``check_model_status(DEFAULT_MEDICAL_MODEL)`` and, when the result
    is a non-empty 2-tuple, returns its first element. Any other result
    yields a fixed warning string. Exceptions are logged and turned into a
    short error string (message truncated to 100 characters).
    """
    try:
        outcome = check_model_status(DEFAULT_MEDICAL_MODEL)
        # Only a non-empty 2-tuple is treated as a valid status result.
        is_status_pair = (
            bool(outcome) and isinstance(outcome, tuple) and len(outcome) == 2
        )
        return outcome[0] if is_status_pair else "⚠️ Unable to check model status"
    except Exception as err:
        logger.error(f"Error updating status after load: {err}")
        return f"⚠️ Error: {str(err)[:100]}"
|
| 549 |
-
|
| 550 |
demo.load(
|
| 551 |
-
fn=
|
| 552 |
inputs=None,
|
| 553 |
outputs=[model_status]
|
| 554 |
)
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
# GPU-decorated function to load any model (for user selection)
|
| 293 |
+
# @spaces.GPU(max_duration=120)
|
| 294 |
def load_model_with_gpu(model_name):
|
| 295 |
"""Load medical model (GPU-decorated for ZeroGPU compatibility)"""
|
| 296 |
try:
|
|
|
|
| 404 |
is_ready = is_model_loaded(model_name)
|
| 405 |
return status_text, is_ready
|
| 406 |
|
| 407 |
+
# GPU-decorated function to load ALL models sequentially on startup
|
| 408 |
+
# This prevents ZeroGPU conflicts from multiple simultaneous GPU requests
|
| 409 |
+
# @spaces.GPU(max_duration=180)
|
| 410 |
+
def load_all_models_on_startup():
    """Load the medical, TTS and ASR models one after another and report status.

    The steps run strictly in sequence (medical model, then TTS, then
    Whisper), with a two-second pause and a CUDA cache release after the
    first two steps. A failure inside one step is recorded in the status
    text and does not prevent the remaining steps from running.

    GPU cache clearing is treated as best-effort housekeeping: a failure
    there is logged as a warning and never aborts loading or escapes the
    error handler.

    Returns:
        str: One status line per model joined with newlines, or a
        "⚠️ Startup loading error: ..." message (exception text truncated
        to 100 characters) if an unexpected exception escapes the
        per-model handlers.
    """
    import time
    import traceback

    import torch

    def _release_gpu_cache(note=None, log=logger.debug):
        """Release cached CUDA memory if a GPU is present; never raises."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if note is not None:
                    log(note)
        except Exception as cache_err:
            # Housekeeping only: a CUDA hiccup here must not mask the real
            # loading result or make the error handler itself raise.
            logger.warning(f"[STARTUP] GPU cache clear failed: {cache_err}")

    status_messages = []

    try:
        # Clear GPU cache at start
        _release_gpu_cache(
            "[STARTUP] Cleared GPU cache before model loading", log=logger.info
        )

        # Step 1: Load medical model (MedSwin)
        if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
            logger.info(f"[STARTUP] Step 1/3: Loading medical model: {DEFAULT_MEDICAL_MODEL}...")
            set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
            try:
                initialize_medical_model(DEFAULT_MEDICAL_MODEL)
                status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded")
                logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded successfully!")
            except Exception as e:
                status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error")
                logger.error(f"[STARTUP] Failed to load medical model: {e}")
                set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
        else:
            status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
            logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")

        # Small delay to let GPU settle and clear cache
        time.sleep(2)
        _release_gpu_cache("[STARTUP] Cleared GPU cache after medical model")

        # Step 2: Load TTS model (maya1)
        if TTS_AVAILABLE:
            logger.info("[STARTUP] Step 2/3: Loading TTS model (maya1)...")
            try:
                initialize_tts_model()
                # The initializer signals failure by leaving the global unset.
                if config.global_tts_model is not None:
                    status_messages.append("✅ TTS (maya1): loaded")
                    logger.info("[STARTUP] ✅ TTS model loaded successfully!")
                else:
                    status_messages.append("⚠️ TTS (maya1): failed")
                    logger.warning("[STARTUP] ⚠️ TTS model failed to load")
            except Exception as e:
                status_messages.append("❌ TTS (maya1): error")
                logger.error(f"[STARTUP] TTS model loading error: {e}")
        else:
            status_messages.append("❌ TTS: library not available")
            logger.warning("[STARTUP] TTS library not installed")

        # Small delay to let GPU settle and clear cache
        time.sleep(2)
        _release_gpu_cache("[STARTUP] Cleared GPU cache after TTS model")

        # Step 3: Load ASR model (Whisper)
        if WHISPER_AVAILABLE:
            logger.info("[STARTUP] Step 3/3: Loading ASR model (Whisper)...")
            try:
                initialize_whisper_model()
                # The initializer signals failure by leaving the global unset.
                if config.global_whisper_model is not None:
                    status_messages.append("✅ ASR (Whisper): loaded")
                    logger.info("[STARTUP] ✅ ASR model loaded successfully!")
                else:
                    status_messages.append("⚠️ ASR (Whisper): failed")
                    logger.warning("[STARTUP] ⚠️ ASR model failed to load")
            except Exception as e:
                status_messages.append("❌ ASR (Whisper): error")
                logger.error(f"[STARTUP] ASR model loading error: {e}")
        else:
            status_messages.append("❌ ASR: library not available")
            logger.warning("[STARTUP] Whisper library not installed")

        # Final cache clear
        _release_gpu_cache("[STARTUP] Final GPU cache clear")

        # Return combined status
        status_text = "\n".join(status_messages)
        logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
        return status_text

    except Exception as e:
        logger.error(f"[STARTUP] ❌ Error in model loading startup: {e}")
        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
        # Clear cache on error; guarded so the status string below is
        # always returned even if CUDA itself is the failure source.
        _release_gpu_cache()
        return f"⚠️ Startup loading error: {str(e)[:100]}"
|
| 504 |
|
| 505 |
# Initialize status on load
|
| 506 |
def init_model_status():
|
|
|
|
| 566 |
outputs=[model_status, submit_button, message_input]
|
| 567 |
)
|
| 568 |
|
| 569 |
+
# Load ALL models sequentially in a SINGLE GPU session to avoid ZeroGPU conflicts
|
| 570 |
+
# This prevents "GPU aborted" errors from multiple simultaneous GPU requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
demo.load(
|
| 572 |
+
fn=load_all_models_on_startup,
|
| 573 |
inputs=None,
|
| 574 |
outputs=[model_status]
|
| 575 |
)
|