Upload 2 files
- configuration_nemotron_h.py +1 -1
- modeling_nemotron_h.py +45 -48
configuration_nemotron_h.py
CHANGED
@@ -239,4 +239,4 @@ class NemotronHConfig(PretrainedConfig):
         return [
             "mamba" if self.hybrid_override_pattern[i] == "M" else
             "attention" if self.hybrid_override_pattern[i] == "*" else "mlp"
-            for i in range(self.num_hidden_layers)]
+            for i in range(self.num_hidden_layers)]
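Note on what this comprehension computes: each character of hybrid_override_pattern picks the layer type at that index ("M" selects a Mamba2 mixer, "*" selects attention, any other character selects an MLP). A minimal standalone sketch with a made-up four-layer pattern (the pattern string is illustrative, not taken from a released config):

# Illustrative pattern only, not a released NemotronH config.
hybrid_override_pattern = "M*-M"   # "M" -> mamba, "*" -> attention, else -> mlp
num_hidden_layers = len(hybrid_override_pattern)

layers_block_type = [
    "mamba" if hybrid_override_pattern[i] == "M" else
    "attention" if hybrid_override_pattern[i] == "*" else "mlp"
    for i in range(num_hidden_layers)
]
print(layers_block_type)  # ['mamba', 'attention', 'mlp', 'mamba']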
modeling_nemotron_h.py
CHANGED
@@ -469,14 +469,12 @@ class NemotronHMamba2Mixer(nn.Module):
                 self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
             )
         else:
-                activation=self.activation,
-            ).transpose(1, 2)
+            hidden_states_B_C = causal_conv1d_fn(
+                x=hidden_states_B_C.transpose(1, 2),
+                weight=self.conv1d.weight.squeeze(1),
+                bias=self.conv1d.bias,
+                activation=self.activation,
+            ).transpose(1, 2)
         hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
         hidden_states, B, C = torch.split(
             hidden_states_B_C,
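For context on the new else-branch: causal_conv1d_fn is the fused causal-convolution kernel (imported from the causal-conv1d package in Mamba-style modeling code), while the branch above it runs the module's own nn.Conv1d and truncates the output back to seq_len. Below is a minimal pure-PyTorch sketch of what the fused call computes, written as an illustration rather than the kernel's actual implementation (the helper name and the SiLU default are my own assumptions):

import torch
import torch.nn.functional as F

def causal_conv1d_reference(x, weight, bias=None, activation="silu"):
    # x: (batch, channels, seq_len); weight: (channels, kernel_size), one filter
    # per channel (depthwise), matching self.conv1d.weight.squeeze(1) above.
    channels, kernel_size = weight.shape
    # Pad on the left only, so each output position sees current and past inputs.
    x = F.pad(x, (kernel_size - 1, 0))
    out = F.conv1d(x, weight.unsqueeze(1), bias=bias, groups=channels)
    # The fused kernel also applies the activation (self.activation, typically "silu").
    return F.silu(out) if activation in ("silu", "swish") else out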
@@ -485,23 +483,21 @@
         )

         # 3. SSM transformation
-            **dt_limit_kwargs,
-        )
+        scan_output, ssm_state = mamba_chunk_scan_combined(
+            hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+            dt,
+            A,
+            B.view(batch_size, seq_len, self.n_groups, -1),
+            C.view(batch_size, seq_len, self.n_groups, -1),
+            chunk_size=self.chunk_size,
+            D=self.D,
+            z=None,
+            seq_idx=None,
+            return_final_states=True,
+            dt_bias=self.dt_bias,
+            dt_softplus=True,
+            **dt_limit_kwargs,
+        )

         # Init cache
         if ssm_state is not None and cache_params is not None:
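A shape sketch for the mamba_chunk_scan_combined call, since the .view(...) reshapes are the only hint of the expected layout here: the sizes below are made up for illustration, and the shape conventions are my reading of the mamba_ssm chunked-scan kernel rather than anything stated in this diff.

import torch

# Hypothetical sizes, chosen only to make the reshapes concrete.
batch_size, seq_len = 2, 16
num_heads, head_dim = 8, 64              # flat hidden size = num_heads * head_dim
n_groups, ssm_state_size = 1, 128

hidden_states = torch.randn(batch_size, seq_len, num_heads * head_dim)
B = torch.randn(batch_size, seq_len, n_groups * ssm_state_size)
C = torch.randn(batch_size, seq_len, n_groups * ssm_state_size)

# The .view(...) calls split the flat projections into per-head / per-group
# layouts before they reach the chunked scan:
x_heads = hidden_states.view(batch_size, seq_len, -1, head_dim)    # (2, 16, 8, 64)
B_groups = B.view(batch_size, seq_len, n_groups, -1)               # (2, 16, 1, 128)
C_groups = C.view(batch_size, seq_len, n_groups, -1)               # (2, 16, 1, 128)

# With return_final_states=True the kernel returns a pair, which is why the call
# site unpacks scan_output and ssm_state; the final state is what feeds the
# cache-init branch right below.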
@@ -768,30 +764,31 @@ class NemotronHBlock(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-                hidden_states
+        with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
+            # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+
+            if self.block_type == "mamba":
+                hidden_states = self.mixer(
+                    hidden_states, cache_params=cache_params, cache_position=cache_position
+                )
+            elif self.block_type == "attention":
+                hidden_states = self.mixer(
+                    hidden_states, cache_position=cache_position
+                )
+                hidden_states = hidden_states[0]
+            elif self.block_type == "mlp":
+                hidden_states = self.mixer(
+                    hidden_states
+                )
+            else:
+                raise ValueError(f"Invalid block_type: {self.block_type}")

+            hidden_states = residual + hidden_states
+            return hidden_states


 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
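The rewritten NemotronHBlock.forward keeps the pre-norm mixer-plus-residual flow but issues it inside the default CUDA stream of the device the input lives on; per the in-code comment, this works around NaNs observed when the layers are sharded across multiple GPUs. Below is a self-contained sketch of just that wrapper pattern, using a stand-in module rather than the real block (TinyBlock and its layers are illustrative assumptions, not this repository's code):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Stand-in for the block: norm -> mixer -> residual add.
    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mixer = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Select the default stream of the input's own device, so that when
        # layers sit on different GPUs each block's kernels are queued on the
        # stream of the GPU that owns its tensors. torch.cuda.stream(None) is
        # a no-op context, so the same code still runs on CPU.
        stream = (
            torch.cuda.default_stream(hidden_states.device)
            if hidden_states.is_cuda
            else None
        )
        with torch.cuda.stream(stream):
            residual = hidden_states
            hidden_states = self.mixer(self.norm(hidden_states))
            return residual + hidden_states

block = TinyBlock()
out = block(torch.randn(2, 4, 32))  # runs on CPU; on GPU the work is queued on that device's default stream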