Y Phung Nguyen committed · Commit c11b620 · 1 Parent(s): 4a5418d

Fix model preloader
models.py CHANGED

@@ -77,11 +77,17 @@ def initialize_medical_model(model_name: str):
             token=config.HF_TOKEN,
             torch_dtype=torch.float16
         )
+        # Set models in config BEFORE setting state to "loaded"
         config.global_medical_models[model_name] = model
         config.global_medical_tokenizers[model_name] = tokenizer
+        # Set state to "loaded" AFTER models are stored
         set_model_loading_state(model_name, "loaded")
         logger.info(f"Medical model {model_name} initialized successfully")

+        # Verify the state was set correctly
+        if not is_model_loaded(model_name):
+            logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}")
+
         # Clear cache after loading to free up temporary memory
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

@@ -96,6 +102,7 @@ def initialize_medical_model(model_name: str):
     else:
         # Model already loaded, ensure state is set
         if get_model_loading_state(model_name) != "loaded":
+            logger.info(f"Model {model_name} exists in config but state not set to 'loaded'. Setting state now.")
             set_model_loading_state(model_name, "loaded")
     return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
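The BEFORE/AFTER comments above only matter if is_model_loaded() checks more than a state flag. The commit does not show the helper definitions, but a minimal sketch of how such helpers are typically backed might look like the following; only the function names come from the diff, while the dict-and-lock storage and the exact state strings are assumptions:

import threading

import config  # the Space's own config module, as referenced in models.py

# Hypothetical backing store; the real helpers live elsewhere in this repo.
_model_states = {}
_state_lock = threading.Lock()

def set_model_loading_state(model_name: str, state: str) -> None:
    # Assumed states: "not_loaded" -> "loading" -> "loaded" (or "error")
    with _state_lock:
        _model_states[model_name] = state

def get_model_loading_state(model_name: str) -> str:
    with _state_lock:
        return _model_states.get(model_name, "not_loaded")

def is_model_loaded(model_name: str) -> bool:
    # If this checks both the state flag and the model dict, the ordering in
    # the diff is what prevents another thread from observing "loaded" before
    # the model is actually stored in config.global_medical_models.
    return (
        get_model_loading_state(model_name) == "loaded"
        and model_name in config.global_medical_models
    )

Under that assumption, storing the model first and flipping the flag second makes the "loaded" state safe to act on from the UI thread, which is what the preloader fix needs.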
ui.py CHANGED

@@ -696,8 +696,7 @@ def create_demo():
             preload_model_on_input_focus()
         except Exception as e:
             logger.debug(f"[PRELOAD] Pre-load trigger error (non-critical): {e}")
-        #
-        return ""
+        # Don't return anything - outputs=None means no return value expected

     # Trigger model pre-loading when user focuses on message input
     message_input.focus(

@@ -714,8 +713,12 @@ def create_demo():
         enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
     ):
         # Check if model is loaded - if not, show error (don't load here to save stream_chat time)
-
+        model_loaded = is_model_loaded(medical_model_name)
+        if not model_loaded:
         loading_state = get_model_loading_state(medical_model_name)
+            # Debug logging to understand why model check fails
+            logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
+
         if loading_state == "loading":
             error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
         else:
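For context, the focus-triggered preload that these ui.py hunks adjust can be wired up roughly as below. The component and function names follow the diff; the surrounding Blocks scaffolding and the stub preloader are illustrative only:

import logging

import gradio as gr

logger = logging.getLogger(__name__)

def preload_model_on_input_focus():
    # Stand-in for the real preloader in this repo, which kicks off
    # initialize_medical_model() in the background.
    pass

def handle_focus():
    try:
        preload_model_on_input_focus()
    except Exception as e:
        logger.debug(f"[PRELOAD] Pre-load trigger error (non-critical): {e}")
    # No return statement: the event is registered with outputs=None, so
    # Gradio expects no return value. Returning "" here (as before this
    # commit) would mismatch the declared outputs.

with gr.Blocks() as demo:
    message_input = gr.Textbox(label="Message")
    # .focus() fires when the user clicks into the textbox, warming the
    # model before the first chat request arrives.
    message_input.focus(fn=handle_focus, inputs=None, outputs=None)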