| """Model initialization and management""" | |
| import os | |
| import torch | |
| import threading | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from llama_index.llms.huggingface import HuggingFaceLLM | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from logger import logger | |
| import config | |
| import spaces | |
| try: | |
| from TTS.api import TTS | |
| TTS_AVAILABLE = True | |
| except ImportError: | |
| TTS_AVAILABLE = False | |
| TTS = None | |
| try: | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| try: | |
| import torchaudio | |
| except ImportError: | |
| torchaudio = None | |
| WHISPER_AVAILABLE = True | |
| except ImportError: | |
| WHISPER_AVAILABLE = False | |
| WhisperProcessor = None | |
| WhisperForConditionalGeneration = None | |
| torchaudio = None | |
| # Model loading state tracking | |
| _model_loading_states = {} | |
| _model_loading_lock = threading.Lock() | |
def set_model_loading_state(model_name: str, state: str):
    """
    Set model loading state: 'loading', 'loaded', 'error'
    Note: No GPU decorator needed - this just sets a dictionary value, no GPU access required.
    """
    with _model_loading_lock:
        _model_loading_states[model_name] = state
        logger.debug(f"Model {model_name} state set to: {state}")


def get_model_loading_state(model_name: str) -> str:
    """
    Get model loading state: 'loading', 'loaded', 'error', or 'unknown'
    Note: No GPU decorator needed - this just reads a dictionary value, no GPU access required.
    """
    with _model_loading_lock:
        return _model_loading_states.get(model_name, "unknown")


def is_model_loaded(model_name: str) -> bool:
    """Check if model is loaded and ready"""
    with _model_loading_lock:
        return (model_name in config.global_medical_models and
                config.global_medical_models[model_name] is not None and
                _model_loading_states.get(model_name) == "loaded")

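# Illustrative only (not part of the original module): a hypothetical helper showing
# how the state-tracking functions above could gate callers while a model downloads.
# The name `_example_wait_until_loaded` and its timeout/poll defaults are assumptions.
def _example_wait_until_loaded(model_name: str, timeout_s: float = 600.0, poll_s: float = 2.0) -> bool:
    """Sketch: block until the model reports 'loaded', reports 'error', or the timeout expires."""
    import time
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        state = get_model_loading_state(model_name)
        if state == "loaded" and is_model_loaded(model_name):
            return True
        if state == "error":
            return False
        time.sleep(poll_s)
    return False
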
def initialize_medical_model(model_name: str):
    """Initialize medical model (MedSwin) - download on demand"""
    if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
        set_model_loading_state(model_name, "loading")
        logger.info(f"Initializing medical model: {model_name}...")
        try:
            # Clear GPU cache before loading to prevent memory issues
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache before model loading")
            model_path = config.MEDSWIN_MODELS[model_name]
            tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                trust_remote_code=True,
                token=config.HF_TOKEN,
                torch_dtype=torch.float16
            )
            # Set models in config BEFORE setting state to "loaded"
            config.global_medical_models[model_name] = model
            config.global_medical_tokenizers[model_name] = tokenizer
            # Set state to "loaded" AFTER models are stored
            set_model_loading_state(model_name, "loaded")
            logger.info(f"Medical model {model_name} initialized successfully")
            # Verify the state was set correctly
            if not is_model_loaded(model_name):
                logger.warning(
                    f"Model {model_name} initialized but is_model_loaded() returns False. "
                    f"State: {get_model_loading_state(model_name)}, "
                    f"in dict: {model_name in config.global_medical_models}"
                )
            # Clear cache after loading to free up temporary memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache after model loading")
        except Exception as e:
            set_model_loading_state(model_name, "error")
            logger.error(f"Failed to initialize medical model {model_name}: {e}")
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
    else:
        # Model already loaded, ensure state is set
        if get_model_loading_state(model_name) != "loaded":
            logger.info(f"Model {model_name} exists in config but state not set to 'loaded'. Setting state now.")
            set_model_loading_state(model_name, "loaded")
    return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]

def initialize_tts_model():
    """Initialize TTS model for text-to-speech"""
    if not TTS_AVAILABLE:
        logger.warning("TTS library not installed. TTS features will be disabled.")
        return None
    if config.global_tts_model is None:
        try:
            # Clear GPU cache before loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache before TTS model loading")
            logger.info("Initializing TTS model for voice generation...")
            config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
            logger.info("TTS model initialized successfully")
            # Clear cache after loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache after TTS model loading")
        except Exception as e:
            logger.warning(f"TTS model initialization failed: {e}")
            logger.warning("TTS features will be disabled. If the pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
            config.global_tts_model = None
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return config.global_tts_model

def initialize_whisper_model():
    """Initialize Whisper model for speech-to-text (ASR) from Hugging Face"""
    if not WHISPER_AVAILABLE:
        logger.warning("Whisper transformers not installed. ASR features will be disabled.")
        return None
    if config.global_whisper_model is None:
        try:
            # Clear GPU cache before loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache before Whisper model loading")
            logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
            model_id = "openai/whisper-large-v3-turbo"
            processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
            model = WhisperForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                token=config.HF_TOKEN
            )
            # Store both processor and model
            config.global_whisper_model = {"processor": processor, "model": model}
            logger.info(f"Whisper model ({model_id}) initialized successfully")
            # Clear cache after loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache after Whisper model loading")
        except Exception as e:
            logger.warning(f"Whisper model initialization failed: {e}")
            logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
            config.global_whisper_model = None
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return config.global_whisper_model

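# Illustrative only (not part of the original module): a minimal sketch of how the
# stored {"processor", "model"} pair could be used for transcription. The function
# name and the assumption of a 16 kHz mono float waveform input are hypothetical.
def _example_transcribe(audio_array):
    """Sketch: transcribe a 16 kHz mono waveform with the shared Whisper model."""
    whisper = initialize_whisper_model()
    if whisper is None:
        return None
    processor, model = whisper["processor"], whisper["model"]
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(model.device, dtype=model.dtype)
    predicted_ids = model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
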
def get_or_create_embed_model():
    """Reuse embedding model to avoid reloading weights on each request"""
    if config.global_embed_model is None:
        logger.info("Initializing shared embedding model for RAG retrieval...")
        config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
    return config.global_embed_model


def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Get LLM for RAG indexing (uses medical model)"""
    medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=medical_tokenizer,
        model=medical_model_obj,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p
        }
    )

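# Illustrative only: a minimal local smoke test, not part of the original app flow.
# It assumes config defines DEFAULT_MEDICAL_MODEL, EMBEDDING_MODEL, HF_TOKEN and the
# global_* registries used above; on a ZeroGPU Space the heavy calls would normally
# run inside functions decorated with @spaces.GPU.
if __name__ == "__main__":
    embed = get_or_create_embed_model()
    logger.info(f"Embedding model ready: {type(embed).__name__}")
    llm = get_llm_for_rag(temperature=0.2, max_new_tokens=64)
    logger.info(f"RAG LLM ready: {type(llm).__name__}")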