LiamKhoaLe committed
Commit 9c11064 · 1 Parent(s): dcc293a

Simplify MedSwin #9

Files changed (2):
  1. app.py +20 -4
  2. model.py +258 -31
app.py CHANGED
@@ -725,6 +725,7 @@ def format_prompt_manually(messages: list, tokenizer) -> str:
     - Simple Question/Answer format
     - System prompt as instruction context
     - Clean formatting without extra special tokens
+    - Ensure no double special tokens are added
     """
     # Combine system and user messages into a single instruction
     system_content = ""
@@ -744,12 +745,17 @@ def format_prompt_manually(messages: list, tokenizer) -> str:
 
     # Format for MedAlpaca/LLaMA-based medical models
     # Common format: Instruction + Input -> Response
-    # Following the exact example pattern
+    # Following the exact example pattern - keep it simple and clean
+    # The tokenizer will add BOS token automatically, so we don't add it here
     if system_content:
+        # Clean format: system instruction, then question, then answer prompt
         prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
     else:
         prompt = f"Question: {user_content}\n\nAnswer:"
 
+    # Ensure prompt is clean (no extra whitespace or special characters)
+    prompt = prompt.strip()
+
     return prompt
 
 def detect_language(text: str) -> str:
@@ -1801,8 +1807,15 @@ def stream_chat(
     prompt = format_prompt_manually(messages, medical_tokenizer)
 
     # Calculate prompt length for stopping criteria
-    # Tokenize to get length - use same tokenization as model.py (simple, no extra params)
-    inputs = medical_tokenizer(prompt, return_tensors="pt")
+    # Tokenize to get length - use EXACT same tokenization as model.py
+    # This ensures consistency and prevents tokenization mismatches
+    inputs = medical_tokenizer(
+        prompt,
+        return_tensors="pt",
+        add_special_tokens=True,  # Match model.py tokenization
+        padding=False,
+        truncation=False
+    )
     prompt_length = inputs['input_ids'].shape[1]
     logger.debug(f"Prompt length: {prompt_length} tokens")
 
@@ -1844,10 +1857,13 @@ def stream_chat(
         MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100)
     ])
 
+    # Create streamer with correct settings for LLaMA-based models
+    # skip_special_tokens=True ensures clean text output without special token artifacts
     streamer = TextIteratorStreamer(
         medical_tokenizer,
         skip_prompt=True,
-        skip_special_tokens=True
+        skip_special_tokens=True,  # Skip special tokens in output for clean text
+        timeout=None  # Don't timeout on long generations
    )
 
    temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7
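For context on the tokenization change above: a LLaMA-style tokenizer prepends a BOS token when add_special_tokens=True, so the prompt length measured in app.py only matches what the model actually sees if both sides tokenize with identical settings, and MedicalStoppingCriteria counts new tokens relative to that length. A minimal sketch of the off-by-BOS effect, assuming an illustrative LLaMA-based checkpoint (medalpaca/medalpaca-7b stands in here for whatever MedSwin model path model.py actually configures):

from transformers import AutoTokenizer

# Illustrative checkpoint; substitute the MedSwIN model path used in model.py
tok = AutoTokenizer.from_pretrained("medalpaca/medalpaca-7b")

prompt = "Question: What is hypertension?\n\nAnswer:"

with_special = tok(prompt, return_tensors="pt", add_special_tokens=True)
without_special = tok(prompt, return_tensors="pt", add_special_tokens=False)

# The lengths differ by the BOS token. If app.py measured the prompt one way
# and model.py tokenized it the other way, the min_new_tokens check in the
# stopping criteria would be off by exactly that difference.
print(with_special["input_ids"].shape[1], without_special["input_ids"].shape[1])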
model.py CHANGED
@@ -45,6 +45,7 @@ def initialize_medical_model(model_name: str):
     - Model loading with device_map="auto" for ZeroGPU Spaces
     - Proper pad_token setup for LLaMA-based models
     - Float16 for memory efficiency
+    - Ensure tokenizer padding side is set correctly
     """
     global global_medical_models, global_medical_tokenizers
 
@@ -53,13 +54,34 @@ def initialize_medical_model(model_name: str):
     model_path = MEDSWIN_MODELS[model_name]
 
     # Load tokenizer - simple and clean, following example pattern
-    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
+    # Use fast tokenizer if available (default), fallback to slow if needed
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            token=HF_TOKEN,
+            trust_remote_code=True
+        )
+    except Exception as e:
+        logger.warning(f"Failed to load fast tokenizer, trying slow tokenizer: {e}")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            token=HF_TOKEN,
+            use_fast=False,
+            trust_remote_code=True
+        )
 
     # LLaMA models don't have pad_token by default, set it to eos_token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    # Set padding side to left for generation (LLaMA models expect this)
+    tokenizer.padding_side = "left"
+
+    # Ensure tokenizer is properly configured
+    if not hasattr(tokenizer, 'model_max_length') or tokenizer.model_max_length is None:
+        tokenizer.model_max_length = 4096
+
     # Load model - use device_map="auto" for ZeroGPU Spaces
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
@@ -79,6 +101,7 @@ def initialize_medical_model(model_name: str):
     logger.info(f"Tokenizer vocab size: {len(tokenizer)}")
     logger.info(f"EOS token: {tokenizer.eos_token} (id: {tokenizer.eos_token_id})")
     logger.info(f"PAD token: {tokenizer.pad_token} (id: {tokenizer.pad_token_id})")
+    logger.info(f"Tokenizer padding side: {tokenizer.padding_side}")
 
     return global_medical_models[model_name], global_medical_tokenizers[model_name]
 
@@ -106,7 +129,166 @@ def get_embedding_model():
     """Get embedding model for RAG - GPU only"""
     return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
 
+def _generate_with_medswin_internal(
+    model_name: str,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    penalty: float,
+    eos_token_id: int,
+    pad_token_id: int,
+    prompt_length: int,
+    min_new_tokens: int = 100
+):
+    """
+    Internal GPU function that only takes picklable arguments.
+    This function is decorated with @spaces.GPU and creates streamer/stopping criteria internally.
+
+    Returns: TextIteratorStreamer that can be consumed by the caller
+    """
+    # Get model and tokenizer from global storage (already loaded)
+    medical_model_obj = global_medical_models.get(model_name)
+    medical_tokenizer = global_medical_tokenizers.get(model_name)
+
+    if medical_model_obj is None or medical_tokenizer is None:
+        raise ValueError(f"Model {model_name} not initialized. Call initialize_medical_model first.")
+
+    # Ensure model is in evaluation mode
+    medical_model_obj.eval()
+
+    # Get device - handle device_map="auto" case
+    device = next(medical_model_obj.parameters()).device
+
+    # Tokenize prompt - CRITICAL: use consistent tokenization settings
+    # For LLaMA-based models, the tokenizer automatically adds BOS token
+    inputs = medical_tokenizer(
+        prompt,
+        return_tensors="pt",
+        add_special_tokens=True,  # Let tokenizer add BOS/EOS as needed
+        padding=False,  # No padding for single sequence generation
+        truncation=False  # Don't truncate - let model handle length
+    ).to(device)
+
+    # Log tokenization info for debugging
+    actual_prompt_length = inputs['input_ids'].shape[1]
+    logger.info(f"Tokenized prompt: {actual_prompt_length} tokens on device {device}")
+
+    # Create streamer inside GPU function (can't be pickled, so create here)
+    streamer = TextIteratorStreamer(
+        medical_tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True,
+        timeout=None
+    )
+
+    # Create stopping criteria inside GPU function (can't be pickled)
+    # Use a simple flag-based stopping instead of threading.Event
+    class SimpleStoppingCriteria(StoppingCriteria):
+        def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
+            super().__init__()
+            self.eos_token_id = eos_token_id
+            self.prompt_length = prompt_length
+            self.min_new_tokens = min_new_tokens
+
+        def __call__(self, input_ids, scores, **kwargs):
+            current_length = input_ids.shape[1]
+            new_tokens = current_length - self.prompt_length
+            last_token = input_ids[0, -1].item()
+
+            # Don't stop on EOS if we haven't generated enough new tokens
+            if new_tokens < self.min_new_tokens:
+                return False
+            # Allow EOS after minimum new tokens have been generated
+            return last_token == self.eos_token_id
+
+    stopping_criteria = StoppingCriteriaList([
+        SimpleStoppingCriteria(eos_token_id, actual_prompt_length, min_new_tokens)
+    ])
+
+    # Prepare generation kwargs - following standard MedAlpaca/LLaMA pattern
+    # Ensure all parameters are valid and within expected ranges
+    generation_kwargs = {
+        **inputs,  # Unpack input_ids and attention_mask
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "temperature": max(0.01, min(temperature, 2.0)),  # Clamp temperature to valid range
+        "top_p": max(0.0, min(top_p, 1.0)),  # Clamp top_p to valid range
+        "top_k": max(1, int(top_k)),  # Ensure top_k is at least 1
+        "repetition_penalty": max(1.0, min(penalty, 2.0)),  # Clamp repetition_penalty
+        "do_sample": True,
+        "stopping_criteria": stopping_criteria,
+        "eos_token_id": eos_token_id,
+        "pad_token_id": pad_token_id
+    }
+
+    # Validate token IDs are valid
+    if eos_token_id is None or eos_token_id < 0:
+        logger.warning(f"Invalid EOS token ID: {eos_token_id}, using tokenizer default")
+        eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id
+        generation_kwargs["eos_token_id"] = eos_token_id
+
+    if pad_token_id is None or pad_token_id < 0:
+        logger.warning(f"Invalid PAD token ID: {pad_token_id}, using EOS token")
+        pad_token_id = eos_token_id
+        generation_kwargs["pad_token_id"] = pad_token_id
+
+    # Run generation on GPU with torch.no_grad() for efficiency
+    # Start generation in a separate thread so we can return the streamer immediately
+    def run_generation():
+        with torch.no_grad():
+            try:
+                logger.debug(f"Starting generation with max_new_tokens={max_new_tokens}, temperature={generation_kwargs['temperature']}, top_p={generation_kwargs['top_p']}, top_k={generation_kwargs['top_k']}")
+                logger.debug(f"EOS token ID: {eos_token_id}, PAD token ID: {pad_token_id}")
+                medical_model_obj.generate(**generation_kwargs)
+            except Exception as e:
+                logger.error(f"Error during generation: {e}")
+                import traceback
+                logger.error(traceback.format_exc())
+                raise
+
+    # Start generation in background thread
+    gen_thread = threading.Thread(target=run_generation, daemon=True)
+    gen_thread.start()
+
+    # Return streamer so caller can consume it
+    return streamer
+
+
 @spaces.GPU(max_duration=120)
+def generate_with_medswin_gpu(
+    model_name: str,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    penalty: float,
+    eos_token_id: int,
+    pad_token_id: int,
+    prompt_length: int,
+    min_new_tokens: int = 100
+):
+    """
+    GPU-decorated wrapper that only takes picklable arguments.
+    This function is called by generate_with_medswin which handles unpicklable objects.
+    """
+    return _generate_with_medswin_internal(
+        model_name=model_name,
+        prompt=prompt,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        penalty=penalty,
+        eos_token_id=eos_token_id,
+        pad_token_id=pad_token_id,
+        prompt_length=prompt_length,
+        min_new_tokens=min_new_tokens
+    )
+
+
 def generate_with_medswin(
     medical_model_obj,
     medical_tokenizer,
@@ -123,51 +305,96 @@ def generate_with_medswin(
     stopping_criteria: StoppingCriteriaList
 ):
     """
-    Generate text with MedSwin model - following standard MedAlpaca/LLaMA inference pattern
+    Public API function that maintains backward compatibility.
+    This function is NOT decorated with @spaces.GPU to avoid pickling issues.
+    It calls the GPU-decorated function internally.
 
-    Key points for proper generation:
-    - Simple tokenization without over-complication
-    - Correct device placement for ZeroGPU
-    - Standard generation kwargs for LLaMA-based models
-    - Proper handling of special tokens
+    Note: stop_event and the original streamer/stopping_criteria are kept for API compatibility
+    but the actual generation uses new objects created inside the GPU function.
     """
-    # Ensure model is in evaluation mode
-    medical_model_obj.eval()
+    # Get model name from global storage (find which model this is)
+    model_name = None
+    for name, model in global_medical_models.items():
+        if model is medical_model_obj:
+            model_name = name
+            break
 
-    # Get device - handle device_map="auto" case
-    device = next(medical_model_obj.parameters()).device
-
-    # Tokenize prompt - simple and clean, following example pattern
-    # For LLaMA-based models, tokenizer handles special tokens automatically
-    inputs = medical_tokenizer(prompt, return_tensors="pt").to(device)
+    if model_name is None:
+        raise ValueError("Model not found in global storage. Ensure model is initialized via initialize_medical_model.")
 
-    # Log tokenization info for debugging
+    # Calculate prompt length for stopping criteria
+    inputs = medical_tokenizer(
+        prompt,
+        return_tensors="pt",
+        add_special_tokens=True,
+        padding=False,
+        truncation=False
+    )
     prompt_length = inputs['input_ids'].shape[1]
-    logger.info(f"Tokenized prompt: {prompt_length} tokens on device {device}")
 
-    # Prepare generation kwargs - following standard MedAlpaca/LLaMA pattern
-    generation_kwargs = dict(
-        inputs,
-        streamer=streamer,
+    # Call GPU function with only picklable arguments
+    # The GPU function will create its own streamer and stopping criteria
+    gpu_streamer = generate_with_medswin_gpu(
+        model_name=model_name,
+        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
-        repetition_penalty=penalty,
-        do_sample=True,
-        stopping_criteria=stopping_criteria,
+        penalty=penalty,
        eos_token_id=eos_token_id,
-        pad_token_id=pad_token_id
+        pad_token_id=pad_token_id,
+        prompt_length=prompt_length,
+        min_new_tokens=100
    )
 
-    # Run generation on GPU with torch.no_grad() for efficiency
-    with torch.no_grad():
+    # Copy tokens from GPU streamer to the original streamer
+    # TextIteratorStreamer uses a queue internally (usually named 'queue' or '_queue')
+    # We need to read from GPU streamer and write to the original streamer's queue
+    def copy_stream():
        try:
-            logger.debug(f"Starting generation with max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}, top_k={top_k}")
-            medical_model_obj.generate(**generation_kwargs)
+            # Find the queue in the original streamer
+            streamer_queue = None
+            if hasattr(streamer, 'queue'):
+                streamer_queue = streamer.queue
+            elif hasattr(streamer, '_queue'):
+                streamer_queue = streamer._queue
+            else:
+                # Try to get queue from tokenizer's queue if available
+                logger.warning("Could not find streamer queue attribute, trying alternative method")
+                # TextIteratorStreamer might store queue differently - check all attributes
+                for attr in dir(streamer):
+                    if 'queue' in attr.lower() and not attr.startswith('__'):
+                        try:
+                            streamer_queue = getattr(streamer, attr)
+                            if hasattr(streamer_queue, 'put'):
+                                break
+                        except:
+                            pass
+
+            if streamer_queue is None:
+                logger.error("Could not access streamer queue - tokens will be lost!")
+                return
+
+            # Read tokens from GPU streamer and put them into original streamer's queue
+            for token in gpu_streamer:
+                streamer_queue.put(token)
+
+            # Signal end of stream (TextIteratorStreamer uses None or StopIteration)
+            try:
+                streamer_queue.put(None)
+            except:
+                pass
+
        except Exception as e:
-            logger.error(f"Error during generation: {e}")
+            logger.error(f"Error copying stream: {e}")
            import traceback
            logger.error(traceback.format_exc())
-            raise
+
+    # Start copying in background
+    copy_thread = threading.Thread(target=copy_stream, daemon=True)
+    copy_thread.start()
+
+    # Return immediately - caller will consume from original streamer
+    return
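A hedged sketch of how the new GPU entry point could be driven directly, assuming the globals in model.py have been populated by initialize_medical_model first; the "medswin" key and the sampling values below are illustrative, not taken from this repo's config, and the streamer hand-off works as the new docstrings describe:

from model import initialize_medical_model, generate_with_medswin_gpu

# Hypothetical model key; use whatever key MEDSWIN_MODELS actually defines
model, tokenizer = initialize_medical_model("medswin")

prompt = "Question: What are common symptoms of anemia?\n\nAnswer:"
prompt_length = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]

streamer = generate_with_medswin_gpu(
    model_name="medswin",
    prompt=prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    prompt_length=prompt_length,
)

# generate() runs in a daemon thread inside the GPU function; the returned
# TextIteratorStreamer yields decoded text chunks as they are produced.
for chunk in streamer:
    print(chunk, end="", flush=True)

One detail relevant to the queue relay in copy_stream: recent transformers releases keep the streamer's internal queue in the text_queue attribute and use None as the default stop signal, so the dir()-based fallback is the branch that would normally locate it.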