| """ | |
| Model inference functions that require GPU. | |
| These functions are tagged with @spaces.GPU(max_duration=120) to ensure | |
| they only run on GPU and don't waste GPU time on CPU operations. | |
| """ | |
import logging
import os
import threading

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import spaces

logger = logging.getLogger(__name__)

# Model configurations
MEDSWIN_MODELS = {
    "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
    "MedSwin KD": "MedSwin/MedSwin-7B-KD",
    "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7",
}
DEFAULT_MEDICAL_MODEL = "MedSwin TA"
EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Global model storage (shared with app.py): these dicts are initialized
# in app.py and accessed here.
global_medical_models = {}
global_medical_tokenizers = {}

def initialize_medical_model(model_name: str):
    """Initialize a medical model (MedSwin), downloading it on demand."""
    global global_medical_models, global_medical_tokenizers
    if model_name not in global_medical_models or global_medical_models[model_name] is None:
        logger.info(f"Initializing medical model: {model_name}...")
        model_path = MEDSWIN_MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
        )
        global_medical_models[model_name] = model
        global_medical_tokenizers[model_name] = tokenizer
        logger.info(f"Medical model {model_name} initialized successfully")
    return global_medical_models[model_name], global_medical_tokenizers[model_name]
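
# Illustrative usage of the on-demand cache above (a sketch, not called by
# app.py; "MedSwin SFT" is one of the real keys in MEDSWIN_MODELS):
#
#   model, tokenizer = initialize_medical_model("MedSwin SFT")  # downloads weights
#   model, tokenizer = initialize_medical_model("MedSwin SFT")  # returns cached pair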

def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Build the LLM used for RAG indexing (the medical model); GPU only."""
    # Use the medical model for RAG indexing instead of the translation model
    medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=medical_tokenizer,
        model=medical_model_obj,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    )

def get_embedding_model():
    """Build the embedding model for RAG; GPU only."""
    return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
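
# A minimal sketch of wiring the two factories above into a llama_index
# pipeline. This helper is illustrative and not called by app.py; the
# "./docs" path is a placeholder, while Settings, SimpleDirectoryReader,
# and VectorStoreIndex are standard llama_index entry points.
def _example_build_rag_index(doc_dir: str = "./docs"):
    """Hedged sketch: build a query engine from local documents."""
    from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex

    Settings.llm = get_llm_for_rag(temperature=0.3)
    Settings.embed_model = get_embedding_model()
    documents = SimpleDirectoryReader(doc_dir).load_data()
    index = VectorStoreIndex.from_documents(documents)
    return index.as_query_engine()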

@spaces.GPU(duration=120)
def generate_with_medswin(
    medical_model_obj,
    medical_tokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    penalty: float,
    eos_token_id: int,
    pad_token_id: int,
    stop_event: threading.Event,
    streamer: TextIteratorStreamer,
    stopping_criteria: StoppingCriteriaList,
):
| """ | |
| Generate text with MedSwin model - GPU only | |
| This function only performs the actual model inference on GPU. | |
| All other operations (prompt preparation, post-processing) should be done outside. | |
| """ | |
| # Tokenize prompt (this is a CPU operation but happens here for simplicity) | |
| # The actual GPU work is in model.generate() | |
| inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device) | |
| # Prepare generation kwargs | |
| generation_kwargs = dict( | |
| **inputs, | |
| streamer=streamer, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| repetition_penalty=penalty, | |
| do_sample=True, | |
| stopping_criteria=stopping_criteria, | |
| eos_token_id=eos_token_id, | |
| pad_token_id=pad_token_id | |
| ) | |
| # Run generation on GPU - this is the only GPU operation | |
| medical_model_obj.generate(**generation_kwargs) |
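

# A hedged sketch of how a caller might drive generate_with_medswin.
# StopOnEvent and _example_stream_generation are illustrative (not part of
# app.py); only the transformers/threading APIs they use are real.
class StopOnEvent(StoppingCriteria):
    """Stop generation as soon as the shared stop_event is set."""

    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.stop_event.is_set()


def _example_stream_generation(prompt: str = "Example prompt"):
    """Run generation in a background thread and stream tokens as they arrive."""
    model, tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
    stop_event = threading.Event()
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
    thread = threading.Thread(
        target=generate_with_medswin,
        args=(
            model, tokenizer, prompt,
            256, 0.7, 0.95, 50, 1.1,  # max_new_tokens, temperature, top_p, top_k, penalty
            tokenizer.eos_token_id, tokenizer.eos_token_id,  # eos_token_id, pad_token_id
            stop_event, streamer, stopping_criteria,
        ),
    )
    thread.start()
    for token_text in streamer:  # blocks until new text is available
        print(token_text, end="", flush=True)
    thread.join()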