jnjj committed
Commit e58f514 · verified · Parent: 7dcc63f

Update app.py

Files changed (1)
  1. app.py +145 -107
app.py CHANGED
@@ -231,7 +231,7 @@ def get_stopping_criteria(req: GenerateRequest, initial_ids: torch.Tensor, token
     if req.max_length is not None and req.max_length > 0:
         max_len_from_req = req.max_length
         if max_len_from_req <= initial_len:
-            logger.warning(f"Requested max_length ({req.max_length}) is less than or equal to prompt length ({initial_len}). Generation will stop immediately.")
+            pass
     elif req.max_new_tokens is not None and req.max_new_tokens > 0:
         max_len_from_req = initial_len + req.max_new_tokens
         if model_max_len is not None:
@@ -264,111 +264,141 @@ def get_stopping_criteria(req: GenerateRequest, initial_ids: torch.Tensor, token
         logger.error(f"Failed to create StopSequenceCriteria: {e}", exc_info=True)
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create StopSequenceCriteria: {e}")
     return criteria
+def generate_next_token_sync(
+    input_ids: torch.Tensor,
+    past_key_values: Optional[Tuple],
+    gen_cfg: GenerationConfig,
+    device: str
+) -> Tuple[torch.Tensor, Any, torch.Tensor]:
+    model_input_ids = input_ids.to(global_model.device)
+    model_past_key_values = past_key_values
+    with torch.no_grad():
+        outputs = global_model(
+            input_ids=model_input_ids,
+            past_key_values=model_past_key_values,
+            use_cache=gen_cfg.use_cache,
+            return_dict=True
+        )
+    logits = outputs.logits[:, -1, :]
+    past = outputs.past_key_values
+    if gen_cfg.do_sample:
+        if gen_cfg.temperature > 1e-8:
+            logits = logits / gen_cfg.temperature
+        if gen_cfg.top_k and gen_cfg.top_k > 0:
+            topk_values, topk_indices = torch.topk(logits, gen_cfg.top_k)
+            logits[logits < topk_values[:, -1]] = -float('Inf')
+        if gen_cfg.top_p < 1.0 - 1e-8:
+            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
+            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+            sorted_indices_to_remove = cumulative_probs > gen_cfg.top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = False
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            logits[:, indices_to_remove] = -float('Inf')
+        token = torch.multinomial(torch.softmax(logits, dim=-1), 1)
+    else:
+        token = torch.argmax(logits, dim=-1, keepdim=True)
+    return token.to('cpu'), past, logits.to('cpu')
 async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> AsyncGenerator[str, None]:
-    initial_len = initial_ids.shape[-1]
-    full_sequence: List[int] = initial_ids.tolist()[0]
+    past = None
     generated_tokens_count = 0
-    start_time = time.time()
-    finish_reason = "unknown"
-    eos_token_id = gen_cfg.eos_token_id
-    pad_token_id = gen_cfg.pad_token_id
+    eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
+    pad_token_id = req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id")
     stop_token_ids = set()
     if eos_token_id is not None:
         stop_token_ids.add(eos_token_id)
     if pad_token_id is not None and pad_token_id != eos_token_id:
         stop_token_ids.add(pad_token_id)
-    stopping_criteria_list = get_stopping_criteria(req, initial_ids.to('cpu'), global_tokenizer)
-    stop_sequence_criteria = None
-    for crit in stopping_criteria_list:
-        if isinstance(crit, StopSequenceCriteria):
-            stop_sequence_criteria = crit
-            break
-    gen_cfg.use_cache = True
-    gen_cfg.num_beams = 1
-    gen_cfg.num_return_sequences = 1
-    gen_cfg.num_beam_groups = 1
+    current_ids = initial_ids.to(device)
+    initial_len = initial_ids.shape[-1]
+    total_ids_list = initial_ids.tolist()[0]
+    start_time = time.time()
+    finish_reason = "unknown"
+    all_stopping_criteria = get_stopping_criteria(req, initial_ids.to('cpu'), global_tokenizer)
+    stream_stopping_criteria = StoppingCriteriaList([
+        crit for crit in all_stopping_criteria
+        if isinstance(crit, (MaxLengthCriteria, StopSequenceCriteria))
+    ])
+    last_step_logits = None
     model_total_capacity = getattr(global_model.config, 'max_position_embeddings', None)
     if model_total_capacity is None:
         model_total_capacity = MAX_CONTEXT_TOKENS + MAX_GENERATION_TOKENS
     effective_max_total_length = req.max_length if req.max_length is not None else initial_len + req.max_new_tokens
     if effective_max_total_length > model_total_capacity:
         effective_max_total_length = model_total_capacity
-    logger.info(f"Starting stream generation (using HF stream): max_new_tokens={req.max_new_tokens}, max_length={req.max_length}, max_time={req.max_time}, initial_len={initial_len}, effective_max_total_length={effective_max_total_length}")
+    current_segment_start_ids = initial_ids.clone()
+    current_segment_initial_len = initial_len
+    current_segment_generated_count = 0
+    yielded_segments = 0
+    logger.info(f"Starting stream generation: max_new_tokens (soft limit)={req.max_new_tokens}, max_length (effective total)={effective_max_total_length}, max_time={req.max_time}, initial_len={initial_len}")
     try:
-        last_sequence_len = initial_len
-        async for generation_output in global_model.generate(
-            input_ids=initial_ids.to(global_model.device),
-            generation_config=gen_cfg,
-            temperature=req.temperature,
-            top_k=req.top_k,
-            top_p=req.top_p,
-            repetition_penalty=req.repetition_penalty,
-            frequency_penalty=req.frequency_penalty,
-            presence_penalty=req.presence_penalty,
-            do_sample=req.do_sample,
-            forced_bos_token_id=req.forced_bos_token_id,
-            forced_eos_token_id=req.forced_eos_token_id,
-            encoder_no_repeat_ngram_size=req.encoder_no_repeat_ngram_size,
-            exponential_decay_length_penalty=req.exponential_decay_length_penalty,
-            typical_p=req.typical_p,
-            encoder_repetition_penalty=req.encoder_repetition_penalty,
-            diversity_penalty=req.diversity_penalty,
-            length_normalization_factor=req.length_normalization_factor,
-            min_new_tokens=req.min_new_tokens,
-            do_normalize_logits=req.do_normalize_logits,
-            stream=True,
-            stopping_criteria=stopping_criteria_list if stopping_criteria_list else None,
-        ):
-            current_sequence: List[int] = generation_output.sequences[0].tolist()
-            new_tokens_ids = current_sequence[last_sequence_len:]
-            if not new_tokens_ids:
-                continue
-            text = global_tokenizer.decode(new_tokens_ids, skip_special_tokens=True)
+        while True:
+            if req.max_time is not None and (time.time() - start_time) > req.max_time:
+                finish_reason = "time"
+                logger.info(f"Stopping stream generation: {finish_reason} reached (>{req.max_time} seconds).")
+                break
+            current_total_len = len(total_ids_list)
+            if current_total_len >= effective_max_total_length:
+                finish_reason = "max_length_reached"
+                logger.info(f"Stopping stream generation: {finish_reason} ({current_total_len} tokens).")
+                break
+            input_ids_sync = current_ids if past is None else token.to(device)
+            token, past, step_logits = await asyncio.to_thread(
+                generate_next_token_sync,
+                input_ids_sync,
+                past,
+                gen_cfg,
+                device
+            )
+            last_step_logits = step_logits
+            generated_token_id = token[0].item()
+            total_ids_list.append(generated_token_id)
+            text = global_tokenizer.decode([generated_token_id], skip_special_tokens=True)
             text = filter_unwanted_json_fragments(text)
-            full_sequence.extend(new_tokens_ids)
-            generated_tokens_count += len(new_tokens_ids)
-            last_sequence_len = len(current_sequence)
             chunk_payload: Dict[str, Any] = {
                 "type": "token",
                 "text": text,
-                "token_ids": new_tokens_ids,
-                "generated_tokens_count": generated_tokens_count,
+                "token_id": generated_token_id,
+                "generated_tokens_count": generated_tokens_count + 1,
+                "segment": yielded_segments,
+                "segment_token_count": current_segment_generated_count + 1,
             }
             yield json.dumps(chunk_payload) + "\n"
-            if req.max_time is not None and (time.time() - start_time) > req.max_time:
-                finish_reason = "time"
-                break
-            if generated_tokens_count >= req.max_new_tokens:
-                finish_reason = "max_new_tokens"
+            generated_tokens_count += 1
+            current_segment_generated_count += 1
+            if generated_token_id in stop_token_ids:
+                finish_reason = "eos_token" if generated_token_id == eos_token_id else "pad_token"
+                logger.info(f"Stopping stream generation: Stop token {generated_token_id} ({finish_reason}) generated.")
                 break
-            if stop_sequence_criteria and stop_sequence_criteria(torch.tensor([full_sequence], device='cpu'), None):
-                finish_reason = "stop_sequence"
+            current_full_ids_tensor = torch.tensor([total_ids_list], device='cpu')
+            if stream_stopping_criteria(current_full_ids_tensor, last_step_logits):
+                criteria_finish = "stopping_criteria"
+                if any(isinstance(c, MaxLengthCriteria) for c in stream_stopping_criteria):
+                    max_len_crit_met = any(len(total_ids_list) >= c.max_length_seq for c in stream_stopping_criteria if isinstance(c, MaxLengthCriteria))
+                    if max_len_crit_met:
+                        criteria_finish = "max_length"
+                stop_seq_crit_met = any(isinstance(c, StopSequenceCriteria) for c in stream_stopping_criteria) and req.stop_sequences
+                if stop_seq_crit_met:
+                    generated_text_so_far = global_tokenizer.decode(total_ids_list[initial_len:], skip_special_tokens=True)
+                    generated_text_so_far = filter_unwanted_json_fragments(generated_text_so_far)
+                    if any(seq and seq in generated_text_so_far for seq in req.stop_sequences):
+                        criteria_finish = "stop_sequence"
+                finish_reason = criteria_finish
+                logger.info(f"Stopping stream generation: {finish_reason} criteria met.")
                 break
-        final_sequence_ids = full_sequence
-        if finish_reason == "unknown":
-            if generated_tokens_count > 0:
-                last_token = final_sequence_ids[-1]
-                if eos_token_id is not None and last_token == eos_token_id:
-                    finish_reason = "eos_token"
-                elif pad_token_id is not None and last_token == pad_token_id:
-                    finish_reason = "pad_token"
-                elif req.max_new_tokens is not None and generated_tokens_count >= req.max_new_tokens:
-                    finish_reason = "max_new_tokens"
-                elif gen_cfg.max_length is not None and len(final_sequence_ids) >= gen_cfg.max_length:
-                    finish_reason = "max_length"
-                elif finish_reason == "unknown":
-                    finish_reason = "completed"
-        final_text_raw = global_tokenizer.decode(final_sequence_ids[initial_len:], skip_special_tokens=True)
+            current_ids = token.to(device)
+        final_text_raw = global_tokenizer.decode(total_ids_list[initial_len:], skip_special_tokens=True)
         final_text_raw = filter_unwanted_json_fragments(final_text_raw)
         final_payload: Dict[str, Any] = {
             "type": "done",
             "total_prompt_tokens": initial_len,
             "total_generated_tokens": generated_tokens_count,
-            "total_sequence_tokens": len(final_sequence_ids),
+            "total_sequence_tokens": len(total_ids_list),
             "final_text": final_text_raw,
             "finish_reason": finish_reason,
+            "segment": yielded_segments,
         }
-        logger.info(f"Stream generation finished. Reason: {finish_reason}. Total tokens: {len(final_sequence_ids)}.")
+        logger.info(f"Stream generation finished. Reason: {finish_reason}. Total tokens: {len(total_ids_list)}. Segment: {yielded_segments}")
         yield json.dumps(final_payload) + "\n"
     except Exception as e:
         logger.error("Stream generation error", exc_info=True)
@@ -384,28 +414,42 @@ async def generate_full_response(req: GenerateRequest, initial_ids: torch.Tensor
     accumulated_text = ""
     finish_reason = "unknown"
     total_generated_count = 0
-    try:
-        async for chunk_json in stream_generation_logic(req, initial_ids, gen_cfg, device):
-            try:
-                data = json.loads(chunk_json)
-                if data.get("type") == "token":
-                    token_ids_chunk = data.get("token_ids", [])
-                    text_chunk = data.get("text", "")
-                    accumulated_tokens.extend(token_ids_chunk)
-                    accumulated_text += text_chunk
-                    total_generated_count = data["generated_tokens_count"]
-                elif data.get("type") == "done":
-                    finish_reason = data.get("finish_reason", "done")
-                    final_text_part = data.get("final_text", "")
-                    accumulated_text = accumulated_text + final_text_part
-                    break
-                elif data.get("type") == "error":
-                    raise RuntimeError(f"Error during streaming generation: {data.get('message', 'Unknown error')}")
-            except json.JSONDecodeError:
-                logger.warning(f"Failed to decode JSON chunk: {chunk_json.strip()}")
-    except Exception as e:
-        logger.error("Error during full response generation from stream", exc_info=True)
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
+    segments_data = []
+    current_segment_tokens = []
+    current_segment_text = ""
+    current_segment_generated_count = 0
+    segment_index = 0
+    async for chunk_json in stream_generation_logic(req, initial_ids, gen_cfg, device):
+        try:
+            data = json.loads(chunk_json)
+            if data.get("type") == "token":
+                token_id = data.get("token_id")
+                text = data.get("text", "")
+                if token_id is not None:
+                    accumulated_tokens.append(token_id)
+                    current_segment_tokens.append(token_id)
+                accumulated_text += text
+                current_segment_text += text
+                total_generated_count = data.get("generated_tokens_count", total_generated_count + 1)
+                current_segment_generated_count = data.get("segment_token_count", current_segment_generated_count + 1)
+            elif data.get("type") == "done":
+                finish_reason = data.get("finish_reason", "done")
+                final_segment_text = data.get("final_text", "")
+                final_segment_text = filter_unwanted_json_fragments(final_segment_text)
+                accumulated_text = filter_unwanted_json_fragments(accumulated_text) + final_segment_text
+                current_segment_text = filter_unwanted_json_fragments(current_segment_text) + final_segment_text
+                segments_data.append({
+                    "segment": segment_index,
+                    "text": current_segment_text,
+                    "token_ids": current_segment_tokens,
+                    "generated_tokens_count": current_segment_generated_count,
+                    "finish_reason": finish_reason if finish_reason != "max_new_tokens_segment" else "completed_segment"
+                })
+                break
+            elif data.get("type") == "error":
+                raise RuntimeError(f"Error during streaming generation: {data.get('message', 'Unknown error')}")
+        except json.JSONDecodeError:
+            logger.warning(f"Failed to decode JSON chunk: {chunk_json.strip()}")
     full_sequence_ids_list = initial_ids.tolist()[0] + accumulated_tokens
     final_payload: Dict[str, Any] = {
         "prompt_tokens": initial_ids.shape[-1],
@@ -417,6 +461,7 @@ async def generate_full_response(req: GenerateRequest, initial_ids: torch.Tensor
             "full_sequence_token_ids": full_sequence_ids_list
         }],
        "total_tokens": initial_ids.shape[-1] + total_generated_count,
+        "segments": segments_data if segments_data else None
     }
     logger.info(f"Full response generation finished. Reason: {finish_reason}. Total tokens: {final_payload['total_tokens']}.")
     return final_payload
@@ -447,12 +492,10 @@ async def load_model():
     device = "cpu"
     if torch.cuda.is_available():
         device = "cuda"
-        logger.info(f"Using device: {device}")
     elif torch.backends.mps.is_available():
         device = "mps"
-        logger.info(f"Using device: {device}")
     else:
-        logger.info(f"Using device: {device}")
+        device = "cpu"
     current_model_name = MODEL_NAME
     current_trust_remote_code = TRUST_REMOTE_CODE
     try:
@@ -467,28 +510,23 @@ async def load_model():
         global_model = AutoModelForCausalLM.from_pretrained(current_model_name, **model_kwargs)
         if 'device_map' not in model_kwargs or model_kwargs['device_map'] is None:
             global_model.to(device)
-            logger.info(f"Manually moved model to device: {device}")
         else:
             model_device = next(global_model.parameters()).device
         global_model.eval()
         global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
         global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
         global_tokens["bos_token_id"] = global_tokenizer.bos_token_id
-        logger.info(f"Tokenizer IDs: EOS={global_tokens['eos_token_id']}, PAD={global_tokens['pad_token_id']}, BOS={global_tokens['bos_token_id']}")
         if global_model.config.pad_token_id is None and global_tokens["pad_token_id"] is None:
             if global_tokens["eos_token_id"] is not None:
                 global_tokenizer.pad_token_id = global_tokens["eos_token_id"]
                 global_model.config.pad_token_id = global_tokens["eos_token_id"]
                 global_tokens["pad_token_id"] = global_tokens["eos_token_id"]
-                logger.warning(f"Model/Tokenizer pad_token_id not set. Using eos_token_id: {global_tokens['pad_token_id']}")
             else:
-                logger.warning("Neither EOS nor PAD tokens are available for this tokenizer/model. Padding for batching/beam search might not work.")
+                pass
         elif global_model.config.pad_token_id is None and global_tokens["pad_token_id"] is not None:
             global_model.config.pad_token_id = global_tokens["pad_token_id"]
-            logger.info(f"Model config pad_token_id was None, set from tokenizer: {global_tokens['pad_token_id']}")
        elif global_model.config.pad_token_id is not None and global_tokens["pad_token_id"] is None:
            global_tokens["pad_token_id"] = global_model.config.pad_token_id
-            logger.warning(f"Tokenizer pad_token_id was None, set from model config: {global_tokens['pad_token_id']}")
         logger.info("Model and tokenizer loaded successfully.")
         logger.info(f"Model device: {next(global_model.parameters()).device}")
     except Exception as e:
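
For reference, a minimal client sketch for the newline-delimited JSON stream that stream_generation_logic now yields. The endpoint path and request body below are assumptions for illustration (the route definition is outside this diff); only the chunk fields "type", "text", "token_id", "finish_reason", "final_text" and "message" come from the payloads defined in this commit.

# Minimal consumption sketch, assuming a POST route that proxies stream_generation_logic.
import json
import requests

resp = requests.post(
    "http://localhost:8000/generate",  # assumed route, not shown in this diff
    json={"prompt": "Hello", "max_new_tokens": 32, "stream": True},  # assumed request schema
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["type"] == "token":
        # incremental text for one generated token
        print(chunk["text"], end="", flush=True)
    elif chunk["type"] == "done":
        print()
        print("finish_reason:", chunk["finish_reason"])
        break
    elif chunk["type"] == "error":
        raise RuntimeError(chunk.get("message", "generation error"))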