Spaces: Running on Zero
| """Model initialization and management""" | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from llama_index.llms.huggingface import HuggingFaceLLM | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from logger import logger | |
| import config | |
# Optional dependency: Coqui TTS. If the library is missing we expose a
# feature flag (TTS_AVAILABLE) and a None placeholder so that callers can
# degrade gracefully instead of failing at import time.
try:
    from TTS.api import TTS
    TTS_AVAILABLE = True
except ImportError:
    TTS_AVAILABLE = False
    TTS = None
def initialize_medical_model(model_name: str):
    """Lazily load a MedSwin medical model and its tokenizer.

    The model is downloaded/initialized only on the first request for a given
    name and cached in the module-level registries in ``config``; subsequent
    calls reuse the cached objects.

    Args:
        model_name: Key into ``config.MEDSWIN_MODELS`` identifying the model.

    Returns:
        Tuple of ``(model, tokenizer)`` for the requested model name.
    """
    cached = config.global_medical_models.get(model_name)
    if cached is None:
        logger.info(f"Initializing medical model: {model_name}...")
        repo_id = config.MEDSWIN_MODELS[model_name]
        tok = AutoTokenizer.from_pretrained(repo_id, token=config.HF_TOKEN)
        # device_map="auto" lets accelerate place layers; fp16 halves memory.
        mdl = AutoModelForCausalLM.from_pretrained(
            repo_id,
            device_map="auto",
            trust_remote_code=True,
            token=config.HF_TOKEN,
            torch_dtype=torch.float16,
        )
        config.global_medical_models[model_name] = mdl
        config.global_medical_tokenizers[model_name] = tok
        logger.info(f"Medical model {model_name} initialized successfully")
    return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
def initialize_tts_model():
    """Lazily create the shared text-to-speech model.

    Returns:
        The cached TTS instance, or ``None`` when the TTS library is not
        installed or initialization fails (TTS features are then disabled).
    """
    if not TTS_AVAILABLE:
        logger.warning("TTS library not installed. TTS features will be disabled.")
        return None
    # Already initialized (or a previous attempt failed and left None) — the
    # cached value is returned either way.
    if config.global_tts_model is not None:
        return config.global_tts_model
    try:
        logger.info("Initializing TTS model for voice generation...")
        config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
        logger.info("TTS model initialized successfully")
    except Exception as e:
        # Best-effort: TTS is an optional feature, so failure is logged, not raised.
        logger.warning(f"TTS model initialization failed: {e}")
        logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
        config.global_tts_model = None
    return config.global_tts_model
def get_or_create_embed_model():
    """Return the process-wide embedding model, creating it on first use.

    Caching the HuggingFaceEmbedding instance in ``config`` avoids reloading
    the embedding weights on every RAG request.
    """
    if config.global_embed_model is not None:
        return config.global_embed_model
    logger.info("Initializing shared embedding model for RAG retrieval...")
    config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
    return config.global_embed_model
def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Build a HuggingFaceLLM wrapper for RAG indexing.

    Uses the default medical model (loaded lazily via
    ``initialize_medical_model``) with the given sampling parameters.

    Args:
        temperature: Sampling temperature passed to ``generate``.
        max_new_tokens: Generation length cap.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.

    Returns:
        A configured ``HuggingFaceLLM`` instance.
    """
    model, tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
    sampling_kwargs = {
        "do_sample": True,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
    }
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=tokenizer,
        model=model,
        generate_kwargs=sampling_kwargs,
    )