"""Model initialization and management"""
import torch
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from logger import logger
import config
import spaces  # ZeroGPU support; the "GPU decorator" notes in docstrings below refer to @spaces.GPU, presumably applied elsewhere in the app

# Coqui TTS is an optional dependency; degrade gracefully when it is not installed.
try:
    from TTS.api import TTS
    TTS_AVAILABLE = True
except ImportError:
    TTS_AVAILABLE = False
    TTS = None

# Model loading state tracking
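# Lifecycle: "loading" -> "loaded" on success, or "loading" -> "error" on
# failure; models that were never requested report "unknown".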
_model_loading_states = {}
_model_loading_lock = threading.Lock()

def set_model_loading_state(model_name: str, state: str):
    """
    Set model loading state: 'loading', 'loaded', 'error'
    Note: No GPU decorator needed - this just sets a dictionary value, no GPU access required.
    """
    with _model_loading_lock:
        _model_loading_states[model_name] = state
        logger.debug(f"Model {model_name} state set to: {state}")

def get_model_loading_state(model_name: str) -> str:
    """
    Get model loading state: 'loading', 'loaded', 'error', or 'unknown'
    Note: No GPU decorator needed - this just reads a dictionary value, no GPU access required.
    """
    with _model_loading_lock:
        return _model_loading_states.get(model_name, "unknown")

def is_model_loaded(model_name: str) -> bool:
    """Check if model is loaded and ready"""
    with _model_loading_lock:
        return (model_name in config.global_medical_models and 
                config.global_medical_models[model_name] is not None and
                _model_loading_states.get(model_name) == "loaded")

def initialize_medical_model(model_name: str):
    """Initialize medical model (MedSwin) - download on demand"""
    if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
        set_model_loading_state(model_name, "loading")
        logger.info(f"Initializing medical model: {model_name}...")
        try:
            model_path = config.MEDSWIN_MODELS[model_name]
            tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
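            # device_map="auto" lets accelerate place layers on available
            # devices; float16 halves memory use at a small precision cost.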
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                trust_remote_code=True,
                token=config.HF_TOKEN,
                torch_dtype=torch.float16
            )
            config.global_medical_models[model_name] = model
            config.global_medical_tokenizers[model_name] = tokenizer
            set_model_loading_state(model_name, "loaded")
            logger.info(f"Medical model {model_name} initialized successfully")
        except Exception as e:
            set_model_loading_state(model_name, "error")
            logger.error(f"Failed to initialize medical model {model_name}: {e}")
            raise
    else:
        # Model already loaded, ensure state is set
        if get_model_loading_state(model_name) != "loaded":
            set_model_loading_state(model_name, "loaded")
    return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]

def initialize_tts_model():
    """Initialize TTS model for text-to-speech"""
    if not TTS_AVAILABLE:
        logger.warning("TTS library not installed. TTS features will be disabled.")
        return None
    if config.global_tts_model is None:
        try:
            logger.info("Initializing TTS model for voice generation...")
            config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
            logger.info("TTS model initialized successfully")
        except Exception as e:
            logger.warning(f"TTS model initialization failed: {e}")
            logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
            config.global_tts_model = None
    return config.global_tts_model

def get_or_create_embed_model():
    """Reuse embedding model to avoid reloading weights each request"""
    if config.global_embed_model is None:
        logger.info("Initializing shared embedding model for RAG retrieval...")
        config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
    return config.global_embed_model

def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Get LLM for RAG indexing (uses medical model)"""
    medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
    
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=medical_tokenizer,
        model=medical_model_obj,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p
        }
    )
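
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical; not part of the module API). A minimal smoke
# test, assuming `config` defines DEFAULT_MEDICAL_MODEL, HF_TOKEN, and the
# global_* caches used above, and that model weights can be downloaded.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    name = config.DEFAULT_MEDICAL_MODEL
    model, tokenizer = initialize_medical_model(name)
    print(f"{name}: state={get_model_loading_state(name)}, loaded={is_model_loaded(name)}")

    # The embedding model and RAG LLM reuse the cached weights loaded above.
    embed_model = get_or_create_embed_model()
    llm = get_llm_for_rag(temperature=0.2, max_new_tokens=128)
    print(type(embed_model).__name__, type(llm).__name__)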