"""Model initialization and management""" import os import torch import threading from transformers import AutoModelForCausalLM, AutoTokenizer from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.embeddings.huggingface import HuggingFaceEmbedding from logger import logger import config import spaces try: from snac import SNAC SNAC_AVAILABLE = True except ImportError: SNAC_AVAILABLE = False SNAC = None # For backward compatibility, check TTS library too (but we use Maya1 directly) try: from TTS.api import TTS TTS_AVAILABLE = True except ImportError: TTS_AVAILABLE = False TTS = None try: from transformers import WhisperProcessor, WhisperForConditionalGeneration try: import torchaudio except ImportError: torchaudio = None WHISPER_AVAILABLE = True except ImportError: WHISPER_AVAILABLE = False WhisperProcessor = None WhisperForConditionalGeneration = None torchaudio = None # Model loading state tracking _model_loading_states = {} _model_loading_lock = threading.Lock() def set_model_loading_state(model_name: str, state: str): """ Set model loading state: 'loading', 'loaded', 'error' Note: No GPU decorator needed - this just sets a dictionary value, no GPU access required. """ with _model_loading_lock: _model_loading_states[model_name] = state logger.debug(f"Model {model_name} state set to: {state}") def get_model_loading_state(model_name: str) -> str: """ Get model loading state: 'loading', 'loaded', 'error', or 'unknown' Note: No GPU decorator needed - this just reads a dictionary value, no GPU access required. """ with _model_loading_lock: return _model_loading_states.get(model_name, "unknown") def is_model_loaded(model_name: str) -> bool: """Check if model is loaded and ready""" with _model_loading_lock: return (model_name in config.global_medical_models and config.global_medical_models[model_name] is not None and _model_loading_states.get(model_name) == "loaded") def initialize_medical_model(model_name: str, load_to_gpu: bool = True): """ Initialize medical model (MedSwin) - download on demand According to ZeroGPU best practices: - If load_to_gpu=True: Load directly to GPU using device_map="auto" (must be called within @spaces.GPU decorated function) - If load_to_gpu=False: Load to CPU first, then move to GPU in inference function Args: model_name: Name of the model to load load_to_gpu: If True, load directly to GPU. If False, load to CPU (for ZeroGPU best practices) """ if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None: set_model_loading_state(model_name, "loading") logger.info(f"Initializing medical model: {model_name}... (load_to_gpu={load_to_gpu})") try: model_path = config.MEDSWIN_MODELS[model_name] tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN) if load_to_gpu: # Load directly to GPU (must be within @spaces.GPU decorated function) # Clear GPU cache before loading to prevent memory issues if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("Cleared GPU cache before model loading") model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto", # Automatically places model on GPU trust_remote_code=True, token=config.HF_TOKEN, torch_dtype=torch.float16 ) # Clear cache after loading if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("Cleared GPU cache after model loading") else: # Load to CPU first (ZeroGPU best practice - no GPU decorator needed) logger.info(f"Loading {model_name} to CPU (will move to GPU during inference)...") model = AutoModelForCausalLM.from_pretrained( model_path, device_map="cpu", # Load to CPU trust_remote_code=True, token=config.HF_TOKEN, torch_dtype=torch.float16 ) logger.info(f"Model {model_name} loaded to CPU successfully") # Set models in config BEFORE setting state to "loaded" config.global_medical_models[model_name] = model config.global_medical_tokenizers[model_name] = tokenizer # Set state to "loaded" AFTER models are stored set_model_loading_state(model_name, "loaded") logger.info(f"Medical model {model_name} initialized successfully") # Verify the state was set correctly if not is_model_loaded(model_name): logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}") except Exception as e: set_model_loading_state(model_name, "error") logger.error(f"Failed to initialize medical model {model_name}: {e}") # Clear cache on error if torch.cuda.is_available(): torch.cuda.empty_cache() raise else: # Model already loaded, ensure state is set if get_model_loading_state(model_name) != "loaded": logger.info(f"Model {model_name} exists in config but state not set to 'loaded'. Setting state now.") set_model_loading_state(model_name, "loaded") return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name] def move_model_to_gpu(model_name: str): """ Move a model from CPU to GPU (for ZeroGPU best practices) Must be called within a @spaces.GPU decorated function According to ZeroGPU best practices: - Models should be loaded to CPU first (no GPU quota used) - Models are moved to GPU only during inference (within @spaces.GPU decorated function) For models loaded with device_map="cpu", we reload with device_map="auto" to avoid meta tensor issues when moving to GPU. """ if model_name not in config.global_medical_models: raise ValueError(f"Model {model_name} not found in config") model = config.global_medical_models[model_name] if model is None: raise ValueError(f"Model {model_name} is None") # Check if model is already on GPU try: # For models with device_map, check the actual device if hasattr(model, 'device'): device_str = str(model.device) if 'cuda' in device_str.lower(): logger.debug(f"Model {model_name} is already on GPU ({device_str})") return model # Check device_map if available if hasattr(model, 'hf_device_map'): device_map = model.hf_device_map if isinstance(device_map, dict): # Check if any device is GPU if any('cuda' in str(dev).lower() for dev in device_map.values()): logger.debug(f"Model {model_name} is already on GPU (device_map)") return model except Exception as e: logger.debug(f"Could not check model device: {e}") # For models loaded with device_map="cpu", we need to reload with device_map="auto" # because models with meta tensors cannot be moved with .to() logger.info(f"Moving model {model_name} from CPU to GPU...") if torch.cuda.is_available(): torch.cuda.empty_cache() # Get model path for reloading if model_name not in config.MEDSWIN_MODELS: raise ValueError(f"Model path for {model_name} not found in config.MEDSWIN_MODELS") model_path = config.MEDSWIN_MODELS[model_name] try: # Reload model with device_map="auto" to place it on GPU # This avoids meta tensor issues when moving from CPU to GPU logger.info(f"Reloading model {model_name} with device_map='auto' for GPU placement...") # Delete the old model to free memory del model if torch.cuda.is_available(): torch.cuda.empty_cache() # Reload with GPU device_map model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto", # Automatically places model on GPU trust_remote_code=True, token=config.HF_TOKEN, torch_dtype=torch.float16 ) config.global_medical_models[model_name] = model logger.info(f"Model {model_name} reloaded to GPU successfully") except Exception as e: logger.error(f"Failed to reload model {model_name} to GPU: {e}") # Try fallback with accelerate dispatch if reload fails try: logger.info(f"Trying accelerate dispatch as fallback...") from accelerate import dispatch_model from accelerate.utils import get_balanced_memory, infer_auto_device_map # Reload model first (in case deletion happened) if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None: logger.info(f"Reloading model {model_name} to CPU for accelerate dispatch...") model = AutoModelForCausalLM.from_pretrained( model_path, device_map="cpu", trust_remote_code=True, token=config.HF_TOKEN, torch_dtype=torch.float16 ) config.global_medical_models[model_name] = model else: model = config.global_medical_models[model_name] # Get device map for GPU max_memory = get_balanced_memory(model, max_memory={0: "20GiB"}) device_map = infer_auto_device_map(model, max_memory=max_memory) model = dispatch_model(model, device_map=device_map) config.global_medical_models[model_name] = model logger.info(f"Model {model_name} moved to GPU successfully using accelerate dispatch") except Exception as e2: logger.error(f"Failed to move model {model_name} to GPU with all methods: {e2}") raise if torch.cuda.is_available(): torch.cuda.empty_cache() return model def initialize_tts_model(): """Initialize Maya1 TTS model for text-to-speech using transformers and SNAC""" if not SNAC_AVAILABLE: logger.warning("SNAC library not installed. Maya1 TTS features will be disabled.") logger.warning("Install with: pip install snac") return None if config.global_tts_model is None: try: # Clear GPU cache before loading if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("Cleared GPU cache before TTS model loading") logger.info("Initializing Maya1 TTS model with Transformers...") # Load Maya1 model and tokenizer model = AutoModelForCausalLM.from_pretrained( config.TTS_MODEL, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, token=config.HF_TOKEN ) tokenizer = AutoTokenizer.from_pretrained( config.TTS_MODEL, trust_remote_code=True, token=config.HF_TOKEN ) logger.info("Loading SNAC decoder...") snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval() if torch.cuda.is_available(): snac_model = snac_model.to("cuda") # Store as a dictionary with model, tokenizer, and snac_model config.global_tts_model = { "model": model, "tokenizer": tokenizer, "snac_model": snac_model } logger.info("Maya1 TTS model initialized successfully") # Clear cache after loading if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("Cleared GPU cache after TTS model loading") except Exception as e: logger.warning(f"Maya1 TTS model initialization failed: {e}") import traceback logger.warning(f"TTS initialization traceback: {traceback.format_exc()}") logger.warning("TTS features will be disabled. Install dependencies: pip install snac transformers") config.global_tts_model = None # Clear cache on error if torch.cuda.is_available(): torch.cuda.empty_cache() return config.global_tts_model def initialize_whisper_model(): """Initialize Whisper model for speech-to-text (ASR) from Hugging Face""" if not WHISPER_AVAILABLE: logger.warning("Whisper transformers not installed. ASR features will be disabled.") return None if config.global_whisper_model is None: try: # Clear GPU cache before loading if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("Cleared GPU cache before Whisper model loading") logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...") model_id = "openai/whisper-large-v3-turbo" processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN) model = WhisperForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype=torch.float16, token=config.HF_TOKEN ) # Store both processor and model config.global_whisper_model = {"processor": processor, "model": model} logger.info(f"Whisper model ({model_id}) initialized successfully") # Clear cache after loading if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("Cleared GPU cache after Whisper model loading") except Exception as e: logger.warning(f"Whisper model initialization failed: {e}") logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio") config.global_whisper_model = None # Clear cache on error if torch.cuda.is_available(): torch.cuda.empty_cache() return config.global_whisper_model def get_or_create_embed_model(): """Reuse embedding model to avoid reloading weights each request""" if config.global_embed_model is None: logger.info("Initializing shared embedding model for RAG retrieval...") config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN) return config.global_embed_model def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50): """Get LLM for RAG indexing (uses medical model)""" medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL) return HuggingFaceLLM( context_window=4096, max_new_tokens=max_new_tokens, tokenizer=medical_tokenizer, model=medical_model_obj, generate_kwargs={ "do_sample": True, "temperature": temperature, "top_k": top_k, "top_p": top_p } )