Properly propagate `model_inputs`
#14
by
Alnusjaponica
- opened
- modeling_plamo.py +3 -0
modeling_plamo.py
CHANGED
|
@@ -1663,6 +1663,9 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
|
| 1663 |
"position_ids": position_ids,
|
| 1664 |
"past_key_values": past_key_values,
|
| 1665 |
"use_cache": kwargs.get("use_cache"),
|
|
|
|
|
|
|
|
|
|
| 1666 |
"attention_mask": attention_mask,
|
| 1667 |
"image_features": image_features,
|
| 1668 |
}
|
|
|
|
| 1663 |
"position_ids": position_ids,
|
| 1664 |
"past_key_values": past_key_values,
|
| 1665 |
"use_cache": kwargs.get("use_cache"),
|
| 1666 |
+
"output_attentions": kwargs.get("output_attentions"),
|
| 1667 |
+
"output_hidden_states": kwargs.get("output_hidden_states"),
|
| 1668 |
+
"logits_to_keep": kwargs.get("logits_to_keep"),
|
| 1669 |
"attention_mask": attention_mask,
|
| 1670 |
"image_features": image_features,
|
| 1671 |
}
|