Properly propagate `model_inputs`

#14
Files changed (1) hide show
  1. modeling_plamo.py +3 -0
modeling_plamo.py CHANGED
@@ -1663,6 +1663,9 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1663
  "position_ids": position_ids,
1664
  "past_key_values": past_key_values,
1665
  "use_cache": kwargs.get("use_cache"),
 
 
 
1666
  "attention_mask": attention_mask,
1667
  "image_features": image_features,
1668
  }
 
1663
  "position_ids": position_ids,
1664
  "past_key_values": past_key_values,
1665
  "use_cache": kwargs.get("use_cache"),
1666
+ "output_attentions": kwargs.get("output_attentions"),
1667
+ "output_hidden_states": kwargs.get("output_hidden_states"),
1668
+ "logits_to_keep": kwargs.get("logits_to_keep"),
1669
  "attention_mask": attention_mask,
1670
  "image_features": image_features,
1671
  }