Spaces:
Running
on
Zero
Running
on
Zero
Y Phung Nguyen
committed on
Commit
·
4fb2874
1
Parent(s):
67da541
Resolve continuous aborted GPU tasks
Browse files
models.py
CHANGED
|
@@ -43,7 +43,6 @@ def is_model_loaded(model_name: str) -> bool:
|
|
| 43 |
config.global_medical_models[model_name] is not None and
|
| 44 |
_model_loading_states.get(model_name) == "loaded")
|
| 45 |
|
| 46 |
-
@spaces.GPU(max_duration=120)
|
| 47 |
def initialize_medical_model(model_name: str):
|
| 48 |
"""Initialize medical model (MedSwin) - download on demand"""
|
| 49 |
if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
|
|
@@ -73,7 +72,6 @@ def initialize_medical_model(model_name: str):
|
|
| 73 |
set_model_loading_state(model_name, "loaded")
|
| 74 |
return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
|
| 75 |
|
| 76 |
-
@spaces.GPU(max_duration=120)
|
| 77 |
def initialize_tts_model():
|
| 78 |
"""Initialize TTS model for text-to-speech"""
|
| 79 |
if not TTS_AVAILABLE:
|
|
@@ -90,7 +88,6 @@ def initialize_tts_model():
|
|
| 90 |
config.global_tts_model = None
|
| 91 |
return config.global_tts_model
|
| 92 |
|
| 93 |
-
@spaces.GPU(max_duration=120)
|
| 94 |
def get_or_create_embed_model():
|
| 95 |
"""Reuse embedding model to avoid reloading weights each request"""
|
| 96 |
if config.global_embed_model is None:
|
|
@@ -98,7 +95,6 @@ def get_or_create_embed_model():
|
|
| 98 |
config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
|
| 99 |
return config.global_embed_model
|
| 100 |
|
| 101 |
-
@spaces.GPU(max_duration=120)
|
| 102 |
def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
|
| 103 |
"""Get LLM for RAG indexing (uses medical model)"""
|
| 104 |
medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
|
|
|
|
| 43 |
config.global_medical_models[model_name] is not None and
|
| 44 |
_model_loading_states.get(model_name) == "loaded")
|
| 45 |
|
|
|
|
| 46 |
def initialize_medical_model(model_name: str):
|
| 47 |
"""Initialize medical model (MedSwin) - download on demand"""
|
| 48 |
if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
|
|
|
|
| 72 |
set_model_loading_state(model_name, "loaded")
|
| 73 |
return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
|
| 74 |
|
|
|
|
| 75 |
def initialize_tts_model():
|
| 76 |
"""Initialize TTS model for text-to-speech"""
|
| 77 |
if not TTS_AVAILABLE:
|
|
|
|
| 88 |
config.global_tts_model = None
|
| 89 |
return config.global_tts_model
|
| 90 |
|
|
|
|
| 91 |
def get_or_create_embed_model():
|
| 92 |
"""Reuse embedding model to avoid reloading weights each request"""
|
| 93 |
if config.global_embed_model is None:
|
|
|
|
| 95 |
config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
|
| 96 |
return config.global_embed_model
|
| 97 |
|
|
|
|
| 98 |
def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
|
| 99 |
"""Get LLM for RAG indexing (uses medical model)"""
|
| 100 |
medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
|