Upload modeling_nemotron_h.py

modeling_nemotron_h.py  CHANGED  (+26 -23)
@@ -335,7 +335,7 @@ class NemotronHMamba2Mixer(nn.Module):
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
-       self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size)
+       self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups)
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

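The first hunk changes the group size handed to the gated RMS norm: with group_size=self.intermediate_size // self.n_groups, each of the n_groups channel groups is normalized over its own slice of the intermediate dimension instead of the whole width being treated as a single group. The snippet below is a minimal, illustrative sketch of a grouped gated RMS norm written for this note; it is not the actual MambaRMSNormGated implementation, and the sizes at the bottom are made-up example values rather than the model's configuration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedGatedRMSNorm(nn.Module):
    # Illustrative only: RMS-normalize each block of `group_size` channels
    # independently, after applying a SiLU gate (the usual Mamba-2 arrangement).
    def __init__(self, hidden_size, group_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.group_size = group_size
        self.eps = eps

    def forward(self, hidden_states, gate=None):
        dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        if gate is not None:
            x = x * F.silu(gate.to(torch.float32))
        *lead, hidden = x.shape
        xg = x.view(*lead, hidden // self.group_size, self.group_size)
        xg = xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + self.eps)
        return (xg.view(*lead, hidden) * self.weight).to(dtype)

# intermediate_size=1024 with n_groups=8 gives group_size=128, so each of the
# 8 groups is normalized over its own 128 channels (example numbers only).
norm = GroupedGatedRMSNorm(hidden_size=1024, group_size=1024 // 8)
out = norm(torch.randn(2, 4, 1024), gate=torch.randn(2, 4, 1024))
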
@@ -469,13 +469,14 @@ class NemotronHMamba2Mixer(nn.Module):
                self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
            )
        else:
-           [7 lines removed: the previous causal_conv1d_fn call; their exact text is not preserved in this view]
+           # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+           with torch.cuda.stream(torch.cuda.default_stream(hidden_states_B_C.device)):
+               hidden_states_B_C = causal_conv1d_fn(
+                   x=hidden_states_B_C.transpose(1, 2),
+                   weight=self.conv1d.weight.squeeze(1),
+                   bias=self.conv1d.bias,
+                   activation=self.activation,
+               ).transpose(1, 2)
            hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
            hidden_states, B, C = torch.split(
                hidden_states_B_C,

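Both remaining hunks wrap a fused kernel call in torch.cuda.stream(torch.cuda.default_stream(...)). Entering that context makes the enclosed kernels launch on the device's default stream rather than whatever stream happens to be current; per the comment added in the diff, this avoids NaN outputs observed when the causal-conv kernel ran on a non-default stream with multiple GPUs. Below is a minimal, self-contained sketch of the same pattern in isolation; the helper name and the matmul payload are made up for illustration.

import torch

def run_on_default_stream(fn, *args, **kwargs):
    # Enqueue fn's CUDA kernels on the device's default stream instead of the
    # currently active stream. Hypothetical helper, written only to isolate
    # the pattern used in the diff.
    device = torch.cuda.current_device()
    with torch.cuda.stream(torch.cuda.default_stream(device)):
        return fn(*args, **kwargs)

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    y = run_on_default_stream(torch.matmul, x, x)
    torch.cuda.synchronize()
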
@@ -484,21 +485,23 @@ class NemotronHMamba2Mixer(nn.Module):
            )

            # 3. SSM transformation
-           [15 lines removed: the previous mamba_chunk_scan_combined call; their exact text is not preserved in this view]
+           # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+           with torch.cuda.stream(torch.cuda.default_stream(hidden_states_B_C.device)):
+               scan_output, ssm_state = mamba_chunk_scan_combined(
+                   hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+                   dt,
+                   A,
+                   B.view(batch_size, seq_len, self.n_groups, -1),
+                   C.view(batch_size, seq_len, self.n_groups, -1),
+                   chunk_size=self.chunk_size,
+                   D=self.D,
+                   z=None,
+                   seq_idx=None,
+                   return_final_states=True,
+                   dt_bias=self.dt_bias,
+                   dt_softplus=True,
+                   **dt_limit_kwargs,
+               )

            # Init cache
            if ssm_state is not None and cache_params is not None:
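The same stream pinning is applied around mamba_chunk_scan_combined, the chunked selective-scan kernel. With return_final_states=True it returns both the scan output and the final SSM state, which the "# Init cache" branch shown as context handles when a cache is present. The .view(...) calls in the added lines lay the inputs out as heads and groups; the shape sketch below uses made-up dimensions (not the model's configuration) just to show the resulting layouts.

import torch

# Example sizes only; the real values come from the model config.
batch_size, seq_len = 2, 16
num_heads, head_dim = 8, 64          # intermediate_size = num_heads * head_dim
n_groups, ssm_state_size = 4, 128    # B and C each carry n_groups * ssm_state_size channels

hidden_states = torch.randn(batch_size, seq_len, num_heads * head_dim)
B = torch.randn(batch_size, seq_len, n_groups * ssm_state_size)
C = torch.randn(batch_size, seq_len, n_groups * ssm_state_size)

x = hidden_states.view(batch_size, seq_len, -1, head_dim)   # (batch, seq, num_heads, head_dim)
B = B.view(batch_size, seq_len, n_groups, -1)               # (batch, seq, n_groups, ssm_state_size)
C = C.view(batch_size, seq_len, n_groups, -1)               # (batch, seq, n_groups, ssm_state_size)
print(x.shape, B.shape, C.shape)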