Y Phung Nguyen committed · Commit c11b620 · 1 Parent(s): 4a5418d

Fix model preloader

Files changed (2)
  1. models.py +7 -0
  2. ui.py +6 -3
models.py CHANGED

@@ -77,11 +77,17 @@ def initialize_medical_model(model_name: str):
             token=config.HF_TOKEN,
             torch_dtype=torch.float16
         )
+        # Set models in config BEFORE setting state to "loaded"
         config.global_medical_models[model_name] = model
         config.global_medical_tokenizers[model_name] = tokenizer
+        # Set state to "loaded" AFTER models are stored
         set_model_loading_state(model_name, "loaded")
         logger.info(f"Medical model {model_name} initialized successfully")
 
+        # Verify the state was set correctly
+        if not is_model_loaded(model_name):
+            logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}")
+
         # Clear cache after loading to free up temporary memory
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -96,6 +102,7 @@ def initialize_medical_model(model_name: str):
     else:
         # Model already loaded, ensure state is set
         if get_model_loading_state(model_name) != "loaded":
+            logger.info(f"Model {model_name} exists in config but state not set to 'loaded'. Setting state now.")
            set_model_loading_state(model_name, "loaded")
     return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
 
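The new comments encode an ordering invariant rather than a behavior change (the hunk is additions only): is_model_loaded() is only trustworthy if the model is stored in the registry before the state flips to "loaded". A minimal sketch of helpers consistent with that reading; the function names come from the diff, but the bodies, the lock, and flattening config into module globals are assumptions:

import threading

# Hypothetical reconstruction of the loading-state helpers called in
# models.py; only the function names are taken from the diff.
_loading_states: dict[str, str] = {}
_lock = threading.Lock()
global_medical_models: dict[str, object] = {}  # stands in for config.global_medical_models

def set_model_loading_state(model_name: str, state: str) -> None:
    with _lock:
        _loading_states[model_name] = state

def get_model_loading_state(model_name: str) -> str:
    with _lock:
        return _loading_states.get(model_name, "not_loaded")

def is_model_loaded(model_name: str) -> bool:
    # Requires BOTH the "loaded" flag and a registry entry. If the flag
    # flipped before the model was stored, a concurrent reader could see
    # "loaded" while the dict lookup still fails - exactly the mismatch
    # the new warning in initialize_medical_model() reports.
    return (get_model_loading_state(model_name) == "loaded"
            and model_name in global_medical_models)

With helpers like these, the added warning fires precisely when the invariant is violated (state says "loaded" but the registry entry is missing), which is the same failure mode the new [STREAM_CHAT] debug log in ui.py probes from the consumer side.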
ui.py CHANGED

@@ -696,8 +696,7 @@ def create_demo():
             preload_model_on_input_focus()
         except Exception as e:
             logger.debug(f"[PRELOAD] Pre-load trigger error (non-critical): {e}")
-        # Return empty string to not update any UI element
-        return ""
+        # Don't return anything - outputs=None means no return value expected
 
     # Trigger model pre-loading when user focuses on message input
     message_input.focus(
@@ -714,8 +713,12 @@
         enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
     ):
         # Check if model is loaded - if not, show error (don't load here to save stream_chat time)
-        if not is_model_loaded(medical_model_name):
+        model_loaded = is_model_loaded(medical_model_name)
+        if not model_loaded:
             loading_state = get_model_loading_state(medical_model_name)
+            # Debug logging to understand why model check fails
+            logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
+
             if loading_state == "loading":
                 error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
             else:
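The first ui.py hunk drops return "" from the focus handler. In Gradio, a listener registered with outputs=None is expected to return nothing; returning a value hands the framework an update with no component to receive it. A minimal sketch of the wiring, with the preloader stubbed out (the handler body and component label are assumptions; Textbox.focus(fn, inputs=..., outputs=...) is the standard Gradio event-listener signature):

import gradio as gr

def preload_model_on_input_focus():
    # Stub for the real preloader in models.py: in the app this would
    # kick off initialize_medical_model() in the background.
    pass

def on_focus():
    try:
        preload_model_on_input_focus()
    except Exception as e:
        print(f"[PRELOAD] Pre-load trigger error (non-critical): {e}")
    # Intentionally no return: with outputs=None, Gradio expects the
    # handler to produce no component updates.

with gr.Blocks() as demo:
    message_input = gr.Textbox(label="Message")
    # Fire the preload as soon as the user focuses the input box.
    message_input.focus(fn=on_focus, inputs=None, outputs=None)

The same logic motivates the stream_chat hunk: caching is_model_loaded() in model_loaded and logging the state, dict membership, and entry value in one line makes it possible to tell apart the three ways the check can fail (state never set, model never stored, or entry present but None).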