"""Model initialization and management""" import os import torch import threading from transformers import AutoModelForCausalLM, AutoTokenizer from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.embeddings.huggingface import HuggingFaceEmbedding from logger import logger import config import spaces try: from TTS.api import TTS TTS_AVAILABLE = True except ImportError: TTS_AVAILABLE = False TTS = None try: from transformers import WhisperProcessor, WhisperForConditionalGeneration try: import torchaudio except ImportError: torchaudio = None WHISPER_AVAILABLE = True except ImportError: WHISPER_AVAILABLE = False WhisperProcessor = None WhisperForConditionalGeneration = None torchaudio = None # Model loading state tracking _model_loading_states = {} _model_loading_lock = threading.Lock() def set_model_loading_state(model_name: str, state: str): """ Set model loading state: 'loading', 'loaded', 'error' Note: No GPU decorator needed - this just sets a dictionary value, no GPU access required. """ with _model_loading_lock: _model_loading_states[model_name] = state logger.debug(f"Model {model_name} state set to: {state}") def get_model_loading_state(model_name: str) -> str: """ Get model loading state: 'loading', 'loaded', 'error', or 'unknown' Note: No GPU decorator needed - this just reads a dictionary value, no GPU access required. """ with _model_loading_lock: return _model_loading_states.get(model_name, "unknown") def is_model_loaded(model_name: str) -> bool: """Check if model is loaded and ready""" with _model_loading_lock: return (model_name in config.global_medical_models and config.global_medical_models[model_name] is not None and _model_loading_states.get(model_name) == "loaded") def initialize_medical_model(model_name: str): """Initialize medical model (MedSwin) - download on demand""" if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None: set_model_loading_state(model_name, "loading") logger.info(f"Initializing medical model: {model_name}...") try: model_path = config.MEDSWIN_MODELS[model_name] tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN) model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto", trust_remote_code=True, token=config.HF_TOKEN, torch_dtype=torch.float16 ) config.global_medical_models[model_name] = model config.global_medical_tokenizers[model_name] = tokenizer set_model_loading_state(model_name, "loaded") logger.info(f"Medical model {model_name} initialized successfully") except Exception as e: set_model_loading_state(model_name, "error") logger.error(f"Failed to initialize medical model {model_name}: {e}") raise else: # Model already loaded, ensure state is set if get_model_loading_state(model_name) != "loaded": set_model_loading_state(model_name, "loaded") return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name] def initialize_tts_model(): """Initialize TTS model for text-to-speech""" if not TTS_AVAILABLE: logger.warning("TTS library not installed. TTS features will be disabled.") return None if config.global_tts_model is None: try: logger.info("Initializing TTS model for voice generation...") config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False) logger.info("TTS model initialized successfully") except Exception as e: logger.warning(f"TTS model initialization failed: {e}") logger.warning("TTS features will be disabled. 
If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts") config.global_tts_model = None return config.global_tts_model def initialize_whisper_model(): """Initialize Whisper model for speech-to-text (ASR) from Hugging Face""" if not WHISPER_AVAILABLE: logger.warning("Whisper transformers not installed. ASR features will be disabled.") return None if config.global_whisper_model is None: try: logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...") model_id = "openai/whisper-large-v3-turbo" processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN) model = WhisperForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype=torch.float16, token=config.HF_TOKEN ) # Store both processor and model config.global_whisper_model = {"processor": processor, "model": model} logger.info(f"Whisper model ({model_id}) initialized successfully") except Exception as e: logger.warning(f"Whisper model initialization failed: {e}") logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio") config.global_whisper_model = None return config.global_whisper_model def get_or_create_embed_model(): """Reuse embedding model to avoid reloading weights each request""" if config.global_embed_model is None: logger.info("Initializing shared embedding model for RAG retrieval...") config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN) return config.global_embed_model def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50): """Get LLM for RAG indexing (uses medical model)""" medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL) return HuggingFaceLLM( context_window=4096, max_new_tokens=max_new_tokens, tokenizer=medical_tokenizer, model=medical_model_obj, generate_kwargs={ "do_sample": True, "temperature": temperature, "top_k": top_k, "top_p": top_p } )
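

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API). A minimal
# warm-up flow under the assumption that the config attributes referenced
# above exist; "sample.wav" is a hypothetical placeholder path. The Whisper
# transcription step uses the standard transformers processor/generate API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Warm the default medical model and confirm its tracked state.
    model, tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
    print(config.DEFAULT_MEDICAL_MODEL, "->", get_model_loading_state(config.DEFAULT_MEDICAL_MODEL))

    # Optional audio models: each returns None (with a logged warning) when
    # its dependencies are missing, so failures here are non-fatal.
    tts = initialize_tts_model()
    whisper = initialize_whisper_model()

    # Transcribe a local file with the stored processor/model pair.
    if whisper is not None and torchaudio is not None:
        waveform, sr = torchaudio.load("sample.wav")  # hypothetical input file
        if sr != 16000:
            # Whisper expects 16 kHz audio; resample if needed.
            waveform = torchaudio.functional.resample(waveform, sr, 16000)
        # Downmix to mono and build input features.
        inputs = whisper["processor"](
            waveform.mean(dim=0).numpy(), sampling_rate=16000, return_tensors="pt"
        )
        features = inputs.input_features.to(whisper["model"].device, dtype=whisper["model"].dtype)
        ids = whisper["model"].generate(input_features=features)
        print(whisper["processor"].batch_decode(ids, skip_special_tokens=True)[0])

    # Build a RAG-ready LLM with lower-temperature sampling.
    llm = get_llm_for_rag(temperature=0.3, max_new_tokens=512)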