| """Model initialization and management""" | |
| import torch | |
| import threading | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from llama_index.llms.huggingface import HuggingFaceLLM | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from logger import logger | |
| import config | |
| try: | |
| from TTS.api import TTS | |
| TTS_AVAILABLE = True | |
| except ImportError: | |
| TTS_AVAILABLE = False | |
| TTS = None | |


# Model loading state tracking
_model_loading_states = {}
_model_loading_lock = threading.Lock()


def set_model_loading_state(model_name: str, state: str):
    """Set model loading state: 'loading', 'loaded', or 'error'."""
    with _model_loading_lock:
        _model_loading_states[model_name] = state
        logger.debug(f"Model {model_name} state set to: {state}")


def get_model_loading_state(model_name: str) -> str:
    """Get model loading state: 'loading', 'loaded', 'error', or 'unknown'."""
    with _model_loading_lock:
        return _model_loading_states.get(model_name, "unknown")


def is_model_loaded(model_name: str) -> bool:
    """Check whether the model is loaded and ready to serve requests."""
    with _model_loading_lock:
        return (
            model_name in config.global_medical_models
            and config.global_medical_models[model_name] is not None
            and _model_loading_states.get(model_name) == "loaded"
        )
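

# Usage sketch (hypothetical request-path guard; the surrounding handler and
# user-facing message are illustrative, not part of this module):
#
#     if not is_model_loaded(model_name):
#         if get_model_loading_state(model_name) == "error":
#             raise RuntimeError(f"Medical model {model_name} failed to load")
#         return "The model is still loading, please retry shortly."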


def initialize_medical_model(model_name: str):
    """Initialize a medical model (MedSwin) - download on demand."""
    if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
        set_model_loading_state(model_name, "loading")
        logger.info(f"Initializing medical model: {model_name}...")
        try:
            model_path = config.MEDSWIN_MODELS[model_name]
            tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                trust_remote_code=True,
                token=config.HF_TOKEN,
                torch_dtype=torch.float16,
            )
            config.global_medical_models[model_name] = model
            config.global_medical_tokenizers[model_name] = tokenizer
            set_model_loading_state(model_name, "loaded")
            logger.info(f"Medical model {model_name} initialized successfully")
        except Exception as e:
            set_model_loading_state(model_name, "error")
            logger.error(f"Failed to initialize medical model {model_name}: {e}")
            raise
    else:
        # Model already loaded; make sure the tracked state agrees.
        if get_model_loading_state(model_name) != "loaded":
            set_model_loading_state(model_name, "loaded")
    return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
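

# Usage sketch (hypothetical): pre-warm the default model in a background
# thread at startup so the first request does not pay the download/load cost.
#
#     threading.Thread(
#         target=initialize_medical_model,
#         args=(config.DEFAULT_MEDICAL_MODEL,),
#         daemon=True,
#     ).start()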


def initialize_tts_model():
    """Initialize the TTS model for text-to-speech."""
    if not TTS_AVAILABLE:
        logger.warning("TTS library not installed. TTS features will be disabled.")
        return None
    if config.global_tts_model is None:
        try:
            logger.info("Initializing TTS model for voice generation...")
            config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
            logger.info("TTS model initialized successfully")
        except Exception as e:
            logger.warning(f"TTS model initialization failed: {e}")
            logger.warning(
                "TTS features will be disabled. If the pyworld dependency is missing, "
                "try: pip install TTS --no-deps && pip install coqui-tts"
            )
            config.global_tts_model = None
    return config.global_tts_model
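

# Usage sketch (hypothetical; tts_to_file is part of the Coqui TTS API, but
# extra kwargs such as speaker or language depend on config.TTS_MODEL):
#
#     tts = initialize_tts_model()
#     if tts is not None:
#         tts.tts_to_file(text="Take one tablet twice daily.", file_path="reply.wav")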


def get_or_create_embed_model():
    """Reuse the embedding model to avoid reloading weights on each request."""
    if config.global_embed_model is None:
        logger.info("Initializing shared embedding model for RAG retrieval...")
        config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
    return config.global_embed_model
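

# Usage sketch (hypothetical; registers the shared embedder as the LlamaIndex
# default so index builds and queries reuse the same weights):
#
#     from llama_index.core import Settings
#     Settings.embed_model = get_or_create_embed_model()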


def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Get an LLM for RAG indexing and synthesis (uses the default medical model)."""
    medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=medical_tokenizer,
        model=medical_model_obj,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    )
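

# Usage sketch (hypothetical end-to-end wiring; the "data" directory, query
# text, and temperature are illustrative, not part of this module):
#
#     from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
#
#     Settings.llm = get_llm_for_rag(temperature=0.3)
#     Settings.embed_model = get_or_create_embed_model()
#     docs = SimpleDirectoryReader("data").load_data()
#     index = VectorStoreIndex.from_documents(docs)
#     print(index.as_query_engine().query("Summarize the key drug interactions."))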