LiamKhoaLe committed on
Commit ec4d4b3 · 1 Parent(s): 3115184

Use GPU dynamically

Files changed (2):
  1. app.py +36 -58
  2. model.py +129 -0
app.py CHANGED
@@ -31,6 +31,15 @@ from llama_index.core.retrievers import AutoMergingRetriever
 from llama_index.core.storage.docstore import SimpleDocumentStore
 from llama_index.llms.huggingface import HuggingFaceLLM
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+# Import GPU-tagged model functions
+from model import (
+    get_llm_for_rag as get_llm_for_rag_gpu,
+    get_embedding_model as get_embedding_model_gpu,
+    generate_with_medswin,
+    initialize_medical_model,
+    global_medical_models,
+    global_medical_tokenizers
+)
 from tqdm import tqdm
 from langdetect import detect, LangDetectException
 # MCP imports
@@ -189,9 +198,8 @@ CSS = """
 }
 """

-# Global model storage
-global_medical_models = {}
-global_medical_tokenizers = {}
+# Global model storage - models are stored in model.py
+# Import the global model storage from model.py
 global_file_info = {}
 global_tts_model = None

@@ -454,24 +462,7 @@ async def call_agent(user_prompt: str, system_prompt: str = None, files: list =
         logger.debug(traceback.format_exc())
         return ""

-def initialize_medical_model(model_name: str):
-    """Initialize medical model (MedSwin) - download on demand"""
-    global global_medical_models, global_medical_tokenizers
-    if model_name not in global_medical_models or global_medical_models[model_name] is None:
-        logger.info(f"Initializing medical model: {model_name}...")
-        model_path = MEDSWIN_MODELS[model_name]
-        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            device_map="auto",
-            trust_remote_code=True,
-            token=HF_TOKEN,
-            torch_dtype=torch.float16
-        )
-        global_medical_models[model_name] = model
-        global_medical_tokenizers[model_name] = tokenizer
-        logger.info(f"Medical model {model_name} initialized successfully")
-    return global_medical_models[model_name], global_medical_tokenizers[model_name]
+# initialize_medical_model is now imported from model.py


 def initialize_tts_model():
@@ -1038,23 +1029,7 @@ def summarize_web_content(content_list: list, query: str) -> str:
         return content_list[0].get('content', '')[:500]
     return ""

-def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
-    """Get LLM for RAG indexing (uses medical model)"""
-    # Use medical model for RAG indexing instead of translation model
-    medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
-
-    return HuggingFaceLLM(
-        context_window=4096,
-        max_new_tokens=max_new_tokens,
-        tokenizer=medical_tokenizer,
-        model=medical_model_obj,
-        generate_kwargs={
-            "do_sample": True,
-            "temperature": temperature,
-            "top_k": top_k,
-            "top_p": top_p
-        }
-    )
+# get_llm_for_rag is now imported from model.py as get_llm_for_rag_gpu


 async def autonomous_reasoning_gemini(query: str) -> dict:
     """Autonomous reasoning using Gemini MCP"""
@@ -1450,7 +1425,6 @@ def extract_text_from_document(file):
         logger.error(f"Error processing document: {e}")
         return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}")

-@spaces.GPU(max_duration=120)
 def create_or_update_index(files, request: gr.Request):
     global global_file_info

@@ -1460,9 +1434,9 @@ def create_or_update_index(files, request: gr.Request):
     start_time = time.time()
     user_id = request.session_hash
     save_dir = f"./{user_id}_index"
-    # Initialize LlamaIndex modules
-    llm = get_llm_for_rag()
-    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+    # Initialize LlamaIndex modules - use GPU functions for model inference only
+    llm = get_llm_for_rag_gpu()
+    embed_model = get_embedding_model_gpu()
     Settings.llm = llm
     Settings.embed_model = embed_model
     file_stats = []
@@ -1557,7 +1531,6 @@
     output_container += "</div>"
     return f"Successfully indexed {len(files)} files.", output_container

-@spaces.GPU(max_duration=120)
 def stream_chat(
     message: str,
     history: list,
@@ -1630,7 +1603,8 @@ def stream_chat(
     rag_context = ""
     source_info = ""
     if final_use_rag and has_rag_index:
-        embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+        # Use GPU function for embedding model
+        embed_model = get_embedding_model_gpu()
         Settings.embed_model = embed_model
         storage_context = StorageContext.from_defaults(persist_dir=index_dir)
         index = load_index_from_storage(storage_context, settings=Settings)
@@ -1775,21 +1749,25 @@ def stream_chat(
     top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
     penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2

-    generation_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=penalty,
-        do_sample=True,
-        stopping_criteria=stopping_criteria,
-        eos_token_id=eos_token_id,
-        pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
+    # Call GPU function for model inference only
+    thread = threading.Thread(
+        target=generate_with_medswin,
+        kwargs={
+            "medical_model_obj": medical_model_obj,
+            "medical_tokenizer": medical_tokenizer,
+            "prompt": prompt,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "penalty": penalty,
+            "eos_token_id": eos_token_id,
+            "pad_token_id": medical_tokenizer.pad_token_id or eos_token_id,
+            "stop_event": stop_event,
+            "streamer": streamer,
+            "stopping_criteria": stopping_criteria
+        }
     )
-
-    thread = threading.Thread(target=medical_model_obj.generate, kwargs=generation_kwargs)
     thread.start()

     updated_history = history + [
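
Note on the streaming path (not shown in this diff): after thread.start(), stream_chat drains the TextIteratorStreamer on the CPU side while generate_with_medswin runs inside the GPU-tagged call. The following consumer loop is a minimal sketch, not code from the commit; it assumes the streamer, stop_event, history and message variables above and an OpenAI-style role/content history format:

    partial_response = ""
    for new_text in streamer:            # decoded chunks arrive as generate() produces tokens
        if stop_event.is_set():          # user requested a stop; abandon the stream
            break
        partial_response += new_text
        updated_history = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": partial_response},
        ]
        yield updated_history            # stream the growing answer back to the UI
    thread.join()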
model.py ADDED
@@ -0,0 +1,129 @@
+"""
+Model inference functions that require GPU.
+These functions are tagged with @spaces.GPU(max_duration=120) to ensure
+they only run on GPU and don't waste GPU time on CPU operations.
+"""
+
+import os
+import torch
+import logging
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+)
+from llama_index.llms.huggingface import HuggingFaceLLM
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+import spaces
+import threading
+
+logger = logging.getLogger(__name__)
+
+# Model configurations
+MEDSWIN_MODELS = {
+    "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
+    "MedSwin KD": "MedSwin/MedSwin-7B-KD",
+    "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
+}
+DEFAULT_MEDICAL_MODEL = "MedSwin TA"
+EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1"
+HF_TOKEN = os.environ.get("HF_TOKEN")
+
+# Global model storage (shared with app.py)
+# These will be initialized in app.py and accessed here
+global_medical_models = {}
+global_medical_tokenizers = {}
+
+
+def initialize_medical_model(model_name: str):
+    """Initialize medical model (MedSwin) - download on demand"""
+    global global_medical_models, global_medical_tokenizers
+    if model_name not in global_medical_models or global_medical_models[model_name] is None:
+        logger.info(f"Initializing medical model: {model_name}...")
+        model_path = MEDSWIN_MODELS[model_name]
+        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            device_map="auto",
+            trust_remote_code=True,
+            token=HF_TOKEN,
+            torch_dtype=torch.float16
+        )
+        global_medical_models[model_name] = model
+        global_medical_tokenizers[model_name] = tokenizer
+        logger.info(f"Medical model {model_name} initialized successfully")
+    return global_medical_models[model_name], global_medical_tokenizers[model_name]
+
+
+@spaces.GPU(max_duration=120)
+def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
+    """Get LLM for RAG indexing (uses medical model) - GPU only"""
+    # Use medical model for RAG indexing instead of translation model
+    medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
+
+    return HuggingFaceLLM(
+        context_window=4096,
+        max_new_tokens=max_new_tokens,
+        tokenizer=medical_tokenizer,
+        model=medical_model_obj,
+        generate_kwargs={
+            "do_sample": True,
+            "temperature": temperature,
+            "top_k": top_k,
+            "top_p": top_p
+        }
+    )
+
+
+@spaces.GPU(max_duration=120)
+def get_embedding_model():
+    """Get embedding model for RAG - GPU only"""
+    return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+
+
+@spaces.GPU(max_duration=120)
+def generate_with_medswin(
+    medical_model_obj,
+    medical_tokenizer,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    penalty: float,
+    eos_token_id: int,
+    pad_token_id: int,
+    stop_event: threading.Event,
+    streamer: TextIteratorStreamer,
+    stopping_criteria: StoppingCriteriaList
+):
+    """
+    Generate text with MedSwin model - GPU only
+
+    This function only performs the actual model inference on GPU.
+    All other operations (prompt preparation, post-processing) should be done outside.
+    """
+    # Tokenize prompt (this is a CPU operation but happens here for simplicity)
+    # The actual GPU work is in model.generate()
+    inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
+
+    # Prepare generation kwargs
+    generation_kwargs = dict(
+        **inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=penalty,
+        do_sample=True,
+        stopping_criteria=stopping_criteria,
+        eos_token_id=eos_token_id,
+        pad_token_id=pad_token_id
+    )
+
+    # Run generation on GPU - this is the only GPU operation
+    medical_model_obj.generate(**generation_kwargs)
+
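
For reference, a minimal driver showing how the new model.py API fits together outside of app.py. This is a sketch under stated assumptions, not code from the commit: the prompt text and sampling values are made up, stopping_criteria is left empty, and running it requires HF_TOKEN plus access to the MedSwin weights:

    from threading import Event, Thread
    from transformers import TextIteratorStreamer, StoppingCriteriaList
    from model import initialize_medical_model, generate_with_medswin, DEFAULT_MEDICAL_MODEL

    # Load (or reuse) the cached model and tokenizer
    model, tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_event = Event()

    # Run the GPU-tagged generation in a background thread, as app.py does
    thread = Thread(
        target=generate_with_medswin,
        kwargs={
            "medical_model_obj": model,
            "medical_tokenizer": tokenizer,
            "prompt": "Summarize the key symptoms of sepsis.",  # hypothetical prompt
            "max_new_tokens": 256,
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 50,
            "penalty": 1.2,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
            "stop_event": stop_event,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([]),
        },
    )
    thread.start()

    # Drain decoded chunks on the CPU side while the GPU generates
    answer = ""
    for chunk in streamer:
        answer += chunk
    thread.join()
    print(answer)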