Update modeling_mimo.py
#7
by chengfeng17 - opened

modeling_mimo.py  +6 -6  CHANGED
@@ -27,10 +27,10 @@ class MiMoMTPLayers(nn.Module):
                 hidden_states,
                 attention_mask,
                 position_ids,
-
+                past_key_value: Optional[Cache]=None,
                 output_attentions: Optional[bool]=False,
                 use_cache: Optional[bool]=False,
-
+                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                 cache_position=None,
                 **kwargs):
         input_embeds = self.token_layernorm(input_embeds)
@@ -38,15 +38,15 @@ class MiMoMTPLayers(nn.Module):
         hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(hidden_states,
                                        attention_mask=attention_mask,
                                        position_ids=position_ids,
-
+                                       past_key_value=past_key_value,
                                        output_attentions=output_attentions,
                                        use_cache=use_cache,
                                        cache_position=cache_position,
-
-                                       **kwargs)
+                                       position_embeddings=position_embeddings,
+                                       **kwargs)[0]
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
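
Read together, the two hunks make MiMoMTPLayers.forward accept a past_key_value cache and precomputed position_embeddings, forward both to self.self_attn, and keep only element [0] of the attention output (recent transformers attention modules typically return an (attn_output, attn_weights) tuple rather than a bare tensor). The sketch below assembles the patched method from the new side of both hunks; the imports, the class wrapper, the def forward(...) opening, and the line between the hunks that produces previous_hidden_states are assumptions filled in for illustration and are not part of this PR.

from typing import Optional, Tuple

import torch
from torch import nn
from transformers.cache_utils import Cache


class MiMoMTPLayers(nn.Module):
    # __init__ (token_layernorm, input_proj, self_attn, input_layernorm,
    # post_attention_layernorm, ...) lives elsewhere in modeling_mimo.py and is
    # untouched by this PR.

    def forward(self,
                input_embeds,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value: Optional[Cache]=None,
                output_attentions: Optional[bool]=False,
                use_cache: Optional[bool]=False,
                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                cache_position=None,
                **kwargs):
        input_embeds = self.token_layernorm(input_embeds)
        # Hypothetical: line 37 sits between the two hunks and is not shown in the
        # diff; it derives previous_hidden_states from the incoming hidden_states.
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # The attention module returns a tuple; [0] keeps only the attention output.
        hidden_states = self.self_attn(hidden_states,
                                       attention_mask=attention_mask,
                                       position_ids=position_ids,
                                       past_key_value=past_key_value,
                                       output_attentions=output_attentions,
                                       use_cache=use_cache,
                                       cache_position=cache_position,
                                       position_embeddings=position_embeddings,
                                       **kwargs)[0]
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # The MLP / final-layernorm tail of the layer lies outside both hunks and is
        # unchanged by this PR.
        ...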