"""
Model inference functions that require GPU.
These functions are tagged with @spaces.GPU(max_duration=120) to ensure
they only run on GPU and don't waste GPU time on CPU operations.
"""
import os
import torch
import logging
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList,
)
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import spaces
import threading
logger = logging.getLogger(__name__)
# Model configurations
MEDSWIN_MODELS = {
"MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
"MedSwin KD": "MedSwin/MedSwin-7B-KD",
"MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
}
DEFAULT_MEDICAL_MODEL = "MedSwin TA"
EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1"
HF_TOKEN = os.environ.get("HF_TOKEN")
# Global model storage (shared with app.py)
# These will be initialized in app.py and accessed here
global_medical_models = {}
global_medical_tokenizers = {}
def initialize_medical_model(model_name: str):
"""Initialize medical model (MedSwin) - download on demand"""
global global_medical_models, global_medical_tokenizers
if model_name not in global_medical_models or global_medical_models[model_name] is None:
logger.info(f"Initializing medical model: {model_name}...")
model_path = MEDSWIN_MODELS[model_name]
tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="auto",
trust_remote_code=True,
token=HF_TOKEN,
torch_dtype=torch.float16
)
global_medical_models[model_name] = model
global_medical_tokenizers[model_name] = tokenizer
logger.info(f"Medical model {model_name} initialized successfully")
return global_medical_models[model_name], global_medical_tokenizers[model_name]
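
# Example (sketch, commented out): how app.py might share and pre-warm the
# caches above. The module name "model_inference" and the exact call site in
# app.py are assumptions for illustration only.
#
#   from model_inference import initialize_medical_model, DEFAULT_MEDICAL_MODEL
#
#   model, tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)  # first call downloads and caches
#   model, tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)  # later calls reuse the cached objects
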
@spaces.GPU(max_duration=120)
def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
"""Get LLM for RAG indexing (uses medical model) - GPU only"""
# Use medical model for RAG indexing instead of translation model
medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
return HuggingFaceLLM(
context_window=4096,
max_new_tokens=max_new_tokens,
tokenizer=medical_tokenizer,
model=medical_model_obj,
generate_kwargs={
"do_sample": True,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p
}
)
@spaces.GPU(max_duration=120)
def get_embedding_model():
"""Get embedding model for RAG - GPU only"""
return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
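
# Example (sketch, commented out): wiring the two helpers above into a
# llama_index RAG pipeline. The Settings/VectorStoreIndex imports assume
# llama-index >= 0.10, and "docs/" is a placeholder corpus path.
#
#   from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
#
#   Settings.llm = get_llm_for_rag(temperature=0.3)
#   Settings.embed_model = get_embedding_model()
#   documents = SimpleDirectoryReader("docs/").load_data()
#   index = VectorStoreIndex.from_documents(documents)
#   answer = index.as_query_engine().query("What are the contraindications for metformin?")
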
@spaces.GPU(max_duration=120)
def generate_with_medswin(
medical_model_obj,
medical_tokenizer,
prompt: str,
max_new_tokens: int,
temperature: float,
top_p: float,
top_k: int,
penalty: float,
eos_token_id: int,
pad_token_id: int,
stop_event: threading.Event,
streamer: TextIteratorStreamer,
stopping_criteria: StoppingCriteriaList
):
"""
Generate text with MedSwin model - GPU only
This function only performs the actual model inference on GPU.
All other operations (prompt preparation, post-processing) should be done outside.
"""
# Tokenize prompt (this is a CPU operation but happens here for simplicity)
# The actual GPU work is in model.generate()
inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
# Prepare generation kwargs
generation_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=penalty,
do_sample=True,
stopping_criteria=stopping_criteria,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id
)
    # Run generation on GPU - this is the only GPU operation; output reaches the caller via the streamer
    medical_model_obj.generate(**generation_kwargs)
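
# --- Example usage (sketch) -------------------------------------------------
# How a caller might drive generate_with_medswin: build a TextIteratorStreamer,
# turn stop_event into a StoppingCriteria, run generation in a background
# thread, and consume tokens from the streamer in the foreground. StopOnEvent
# and example_stream_generation are illustrative names, not part of the app's API.
class StopOnEvent(StoppingCriteria):
    """Stop generation as soon as the shared threading.Event is set."""

    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.stop_event.is_set()


def example_stream_generation(prompt: str, model_name: str = DEFAULT_MEDICAL_MODEL) -> str:
    """Illustrative caller: streams a completion and returns the full text."""
    model, tokenizer = initialize_medical_model(model_name)
    stop_event = threading.Event()
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = threading.Thread(
        target=generate_with_medswin,
        kwargs=dict(
            medical_model_obj=model,
            medical_tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            stop_event=stop_event,
            streamer=streamer,
            stopping_criteria=StoppingCriteriaList([StopOnEvent(stop_event)]),
        ),
    )
    thread.start()
    chunks = []
    for new_text in streamer:  # yields text fragments as they are generated
        chunks.append(new_text)
    thread.join()
    return "".join(chunks)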