LiamKhoaLe committed
Commit 9b1b152 · 1 Parent(s): 9c11064

Simplify MedSwin #10

Files changed (1)
  1. model.py +68 -163
model.py CHANGED
@@ -130,7 +130,8 @@ def get_embedding_model():
     return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
 
 def _generate_with_medswin_internal(
-    model_name: str,
+    medical_model_obj,
+    medical_tokenizer,
     prompt: str,
     max_new_tokens: int,
     temperature: float,
@@ -140,21 +141,15 @@ def _generate_with_medswin_internal(
     eos_token_id: int,
     pad_token_id: int,
     prompt_length: int,
-    min_new_tokens: int = 100
+    min_new_tokens: int = 100,
+    streamer: TextIteratorStreamer = None,
+    stopping_criteria: StoppingCriteriaList = None
 ):
     """
-    Internal GPU function that only takes picklable arguments.
-    This function is decorated with @spaces.GPU and creates streamer/stopping criteria internally.
-
-    Returns: TextIteratorStreamer that can be consumed by the caller
+    Internal generation function that runs directly on GPU.
+    Model is already on GPU via device_map="auto", so no @spaces.GPU decorator needed.
+    This avoids pickling issues with streamer and stopping_criteria.
     """
-    # Get model and tokenizer from global storage (already loaded)
-    medical_model_obj = global_medical_models.get(model_name)
-    medical_tokenizer = global_medical_tokenizers.get(model_name)
-
-    if medical_model_obj is None or medical_tokenizer is None:
-        raise ValueError(f"Model {model_name} not initialized. Call initialize_medical_model first.")
-
     # Ensure model is in evaluation mode
     medical_model_obj.eval()
 
@@ -175,37 +170,38 @@ def _generate_with_medswin_internal(
     actual_prompt_length = inputs['input_ids'].shape[1]
     logger.info(f"Tokenized prompt: {actual_prompt_length} tokens on device {device}")
 
-    # Create streamer inside GPU function (can't be pickled, so create here)
-    streamer = TextIteratorStreamer(
-        medical_tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True,
-        timeout=None
-    )
+    # Use provided streamer and stopping_criteria (created in caller to avoid pickling)
+    if streamer is None:
+        streamer = TextIteratorStreamer(
+            medical_tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+            timeout=None
+        )
 
-    # Create stopping criteria inside GPU function (can't be pickled)
-    # Use a simple flag-based stopping instead of threading.Event
-    class SimpleStoppingCriteria(StoppingCriteria):
-        def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
-            super().__init__()
-            self.eos_token_id = eos_token_id
-            self.prompt_length = prompt_length
-            self.min_new_tokens = min_new_tokens
+    if stopping_criteria is None:
+        # Create simple stopping criteria if not provided
+        class SimpleStoppingCriteria(StoppingCriteria):
+            def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
+                super().__init__()
+                self.eos_token_id = eos_token_id
+                self.prompt_length = prompt_length
+                self.min_new_tokens = min_new_tokens
 
-        def __call__(self, input_ids, scores, **kwargs):
-            current_length = input_ids.shape[1]
-            new_tokens = current_length - self.prompt_length
-            last_token = input_ids[0, -1].item()
-
-            # Don't stop on EOS if we haven't generated enough new tokens
-            if new_tokens < self.min_new_tokens:
-                return False
-            # Allow EOS after minimum new tokens have been generated
-            return last_token == self.eos_token_id
-
-    stopping_criteria = StoppingCriteriaList([
-        SimpleStoppingCriteria(eos_token_id, actual_prompt_length, min_new_tokens)
-    ])
+            def __call__(self, input_ids, scores, **kwargs):
+                current_length = input_ids.shape[1]
+                new_tokens = current_length - self.prompt_length
+                last_token = input_ids[0, -1].item()
+
+                # Don't stop on EOS if we haven't generated enough new tokens
+                if new_tokens < self.min_new_tokens:
+                    return False
+                # Allow EOS after minimum new tokens have been generated
+                return last_token == self.eos_token_id
+
+        stopping_criteria = StoppingCriteriaList([
+            SimpleStoppingCriteria(eos_token_id, actual_prompt_length, min_new_tokens)
+        ])
 
     # Prepare generation kwargs - following standard MedAlpaca/LLaMA pattern
     # Ensure all parameters are valid and within expected ranges
@@ -235,58 +231,17 @@ def _generate_with_medswin_internal(
     generation_kwargs["pad_token_id"] = pad_token_id
 
     # Run generation on GPU with torch.no_grad() for efficiency
-    # Start generation in a separate thread so we can return the streamer immediately
-    def run_generation():
-        with torch.no_grad():
-            try:
-                logger.debug(f"Starting generation with max_new_tokens={max_new_tokens}, temperature={generation_kwargs['temperature']}, top_p={generation_kwargs['top_p']}, top_k={generation_kwargs['top_k']}")
-                logger.debug(f"EOS token ID: {eos_token_id}, PAD token ID: {pad_token_id}")
-                medical_model_obj.generate(**generation_kwargs)
-            except Exception as e:
-                logger.error(f"Error during generation: {e}")
-                import traceback
-                logger.error(traceback.format_exc())
-                raise
-
-    # Start generation in background thread
-    gen_thread = threading.Thread(target=run_generation, daemon=True)
-    gen_thread.start()
-
-    # Return streamer so caller can consume it
-    return streamer
-
-
-@spaces.GPU(max_duration=120)
-def generate_with_medswin_gpu(
-    model_name: str,
-    prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-    penalty: float,
-    eos_token_id: int,
-    pad_token_id: int,
-    prompt_length: int,
-    min_new_tokens: int = 100
-):
-    """
-    GPU-decorated wrapper that only takes picklable arguments.
-    This function is called by generate_with_medswin which handles unpicklable objects.
-    """
-    return _generate_with_medswin_internal(
-        model_name=model_name,
-        prompt=prompt,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        penalty=penalty,
-        eos_token_id=eos_token_id,
-        pad_token_id=pad_token_id,
-        prompt_length=prompt_length,
-        min_new_tokens=min_new_tokens
-    )
+    # Model is already on GPU, so this will run on GPU automatically
+    with torch.no_grad():
+        try:
+            logger.debug(f"Starting generation with max_new_tokens={max_new_tokens}, temperature={generation_kwargs['temperature']}, top_p={generation_kwargs['top_p']}, top_k={generation_kwargs['top_k']}")
+            logger.debug(f"EOS token ID: {eos_token_id}, PAD token ID: {pad_token_id}")
+            medical_model_obj.generate(**generation_kwargs)
+        except Exception as e:
+            logger.error(f"Error during generation: {e}")
+            import traceback
+            logger.error(traceback.format_exc())
+            raise
 
 
 def generate_with_medswin(
@@ -305,24 +260,18 @@ def generate_with_medswin(
     stopping_criteria: StoppingCriteriaList
 ):
     """
-    Public API function that maintains backward compatibility.
-    This function is NOT decorated with @spaces.GPU to avoid pickling issues.
-    It calls the GPU-decorated function internally.
-
-    Note: stop_event and the original streamer/stopping_criteria are kept for API compatibility
-    but the actual generation uses new objects created inside the GPU function.
-    """
-    # Get model name from global storage (find which model this is)
-    model_name = None
-    for name, model in global_medical_models.items():
-        if model is medical_model_obj:
-            model_name = name
-            break
+    Public API function for model generation.
 
-    if model_name is None:
-        raise ValueError("Model not found in global storage. Ensure model is initialized via initialize_medical_model.")
+    This function is NOT decorated with @spaces.GPU because:
+    1. The model is already on GPU via device_map="auto" during initialization
+    2. Generation will automatically run on GPU where the model is located
+    3. This avoids pickling issues with streamer, stop_event, and stopping_criteria
 
-    # Calculate prompt length for stopping criteria
+    The @spaces.GPU decorator is only needed for model loading, which is handled
+    separately in initialize_medical_model (though that also doesn't need it since
+    device_map="auto" handles GPU placement).
+    """
+    # Calculate prompt length for stopping criteria (if not already calculated)
     inputs = medical_tokenizer(
         prompt,
         return_tensors="pt",
@@ -332,10 +281,11 @@ def generate_with_medswin(
     )
     prompt_length = inputs['input_ids'].shape[1]
 
-    # Call GPU function with only picklable arguments
-    # The GPU function will create its own streamer and stopping criteria
-    gpu_streamer = generate_with_medswin_gpu(
-        model_name=model_name,
+    # Call internal generation function directly
+    # Model is already on GPU, so generation will happen on GPU automatically
+    _generate_with_medswin_internal(
+        medical_model_obj=medical_model_obj,
+        medical_tokenizer=medical_tokenizer,
         prompt=prompt,
        max_new_tokens=max_new_tokens,
         temperature=temperature,
@@ -345,56 +295,11 @@ def generate_with_medswin(
         eos_token_id=eos_token_id,
         pad_token_id=pad_token_id,
         prompt_length=prompt_length,
-        min_new_tokens=100
+        min_new_tokens=100,
+        streamer=streamer,  # Use the provided streamer (created in caller)
+        stopping_criteria=stopping_criteria  # Use the provided stopping criteria
     )
 
-    # Copy tokens from GPU streamer to the original streamer
-    # TextIteratorStreamer uses a queue internally (usually named 'queue' or '_queue')
-    # We need to read from GPU streamer and write to the original streamer's queue
-    def copy_stream():
-        try:
-            # Find the queue in the original streamer
-            streamer_queue = None
-            if hasattr(streamer, 'queue'):
-                streamer_queue = streamer.queue
-            elif hasattr(streamer, '_queue'):
-                streamer_queue = streamer._queue
-            else:
-                # Try to get queue from tokenizer's queue if available
-                logger.warning("Could not find streamer queue attribute, trying alternative method")
-                # TextIteratorStreamer might store queue differently - check all attributes
-                for attr in dir(streamer):
-                    if 'queue' in attr.lower() and not attr.startswith('__'):
-                        try:
-                            streamer_queue = getattr(streamer, attr)
-                            if hasattr(streamer_queue, 'put'):
-                                break
-                        except:
-                            pass
-
-            if streamer_queue is None:
-                logger.error("Could not access streamer queue - tokens will be lost!")
-                return
-
-            # Read tokens from GPU streamer and put them into original streamer's queue
-            for token in gpu_streamer:
-                streamer_queue.put(token)
-
-            # Signal end of stream (TextIteratorStreamer uses None or StopIteration)
-            try:
-                streamer_queue.put(None)
-            except:
-                pass
-
-        except Exception as e:
-            logger.error(f"Error copying stream: {e}")
-            import traceback
-            logger.error(traceback.format_exc())
-
-    # Start copying in background
-    copy_thread = threading.Thread(target=copy_stream, daemon=True)
-    copy_thread.start()
-
-    # Return immediately - caller will consume from original streamer
+    # Function returns immediately - generation happens in background via streamer
    return
 
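
For reference, a minimal caller-side sketch of the streaming pattern this diff moves to: the caller now owns the TextIteratorStreamer (and optionally a StoppingCriteriaList), runs generate_with_medswin on a worker thread, and reads decoded text from the streamer, replacing the removed @spaces.GPU wrapper and the queue-copying bridge. The full public signature of generate_with_medswin is not visible in the hunks above, so the keyword arguments, import path, and sampling values below are assumptions, not part of the commit.

# Hypothetical usage sketch. Parameter names follow the keyword arguments the diff
# forwards to _generate_with_medswin_internal; the module path, sampling values, and
# the omission of a stop_event argument are assumptions.
import threading
from transformers import StoppingCriteriaList, TextIteratorStreamer

from model import generate_with_medswin  # assumed module path

def stream_answer(model, tokenizer, prompt: str) -> str:
    # The caller creates the streamer; generation writes into it directly.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # model.generate() now blocks inside _generate_with_medswin_internal, so run
    # the public API on a worker thread and consume tokens on this thread.
    worker = threading.Thread(
        target=generate_with_medswin,
        kwargs=dict(
            medical_model_obj=model,
            medical_tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            streamer=streamer,
            stopping_criteria=StoppingCriteriaList([]),  # or custom criteria
        ),
        daemon=True,
    )
    worker.start()

    pieces = []
    for text in streamer:  # yields decoded text chunks as they are generated
        pieces.append(text)
    worker.join()
    return "".join(pieces)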