Update modeling_rwkv5.py
Browse files- modeling_rwkv5.py +4 -1
modeling_rwkv5.py
CHANGED
|
@@ -789,7 +789,10 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
|
|
| 789 |
# only last token for inputs_ids if the state is passed along.
|
| 790 |
if state is not None:
|
| 791 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 792 |
-
|
|
|
|
|
|
|
|
|
|
| 793 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 794 |
if inputs_embeds is not None and state is None:
|
| 795 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
|
|
| 789 |
# only last token for inputs_ids if the state is passed along.
|
| 790 |
if state is not None:
|
| 791 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 792 |
+
else:
|
| 793 |
+
# add in \n at the beginning
|
| 794 |
+
input_ids = torch.cat([torch.full([1,1],11,device=input_ids.device,dtype=input_ids.dtype), input_ids])
|
| 795 |
+
|
| 796 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 797 |
if inputs_embeds is not None and state is None:
|
| 798 |
model_inputs = {"inputs_embeds": inputs_embeds}
|