Commit 83908dc (verified), committed by suhara · Parent: d63ce32

Upload modeling_nemotron_h.py

Files changed (1):
modeling_nemotron_h.py  +26 -23
modeling_nemotron_h.py CHANGED
@@ -335,7 +335,7 @@ class NemotronHMamba2Mixer(nn.Module):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size)
+        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups)
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
 
@@ -469,13 +469,14 @@ class NemotronHMamba2Mixer(nn.Module):
                         self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                     )
                 else:
-                    hidden_states_B_C = causal_conv1d_fn(
-                        x=hidden_states_B_C.transpose(1, 2),
-                        weight=self.conv1d.weight.squeeze(1),
-                        bias=self.conv1d.bias,
-                        activation=self.activation,
-                    ).transpose(1, 2)
-
+                    # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+                    with torch.cuda.stream(torch.cuda.default_stream(hidden_states_B_C.device)):
+                        hidden_states_B_C = causal_conv1d_fn(
+                            x=hidden_states_B_C.transpose(1, 2),
+                            weight=self.conv1d.weight.squeeze(1),
+                            bias=self.conv1d.bias,
+                            activation=self.activation,
+                        ).transpose(1, 2)
 
                 hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                 hidden_states, B, C = torch.split(
                     hidden_states_B_C,
@@ -484,21 +485,23 @@ class NemotronHMamba2Mixer(nn.Module):
                 )
 
                 # 3. SSM transformation
-                scan_output, ssm_state = mamba_chunk_scan_combined(
-                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
-                    dt,
-                    A,
-                    B.view(batch_size, seq_len, self.n_groups, -1),
-                    C.view(batch_size, seq_len, self.n_groups, -1),
-                    chunk_size=self.chunk_size,
-                    D=self.D,
-                    z=None,
-                    seq_idx=None,
-                    return_final_states=True,
-                    dt_bias=self.dt_bias,
-                    dt_softplus=True,
-                    **dt_limit_kwargs,
-                )
+                # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+                with torch.cuda.stream(torch.cuda.default_stream(hidden_states_B_C.device)):
+                    scan_output, ssm_state = mamba_chunk_scan_combined(
+                        hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+                        dt,
+                        A,
+                        B.view(batch_size, seq_len, self.n_groups, -1),
+                        C.view(batch_size, seq_len, self.n_groups, -1),
+                        chunk_size=self.chunk_size,
+                        D=self.D,
+                        z=None,
+                        seq_idx=None,
+                        return_final_states=True,
+                        dt_bias=self.dt_bias,
+                        dt_softplus=True,
+                        **dt_limit_kwargs,
+                    )
 
                 # Init cache
                 if ssm_state is not None and cache_params is not None:
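
The diff makes two functional changes to NemotronHMamba2Mixer: the gated RMS norm is now constructed with group_size=self.intermediate_size // self.n_groups rather than the full intermediate_size, presumably so that each of the n_groups channel groups is normalized over its own slice, and the causal_conv1d_fn and mamba_chunk_scan_combined kernels are launched with the device's default CUDA stream made current, which the added comments describe as a workaround for NaN issues when using multiple GPUs.

A minimal sketch of that stream-pinning pattern, assuming a CUDA-capable setup; run_pinned and the silu stand-in kernel are illustrative only and not part of this repository:

import torch

def run_pinned(fn, *args, device: torch.device, **kwargs):
    # Make the device's default stream current before launching the kernel,
    # mirroring the `with torch.cuda.stream(torch.cuda.default_stream(...))`
    # wrapper added in this commit.
    if device.type == "cuda":
        with torch.cuda.stream(torch.cuda.default_stream(device)):
            return fn(*args, **kwargs)
    return fn(*args, **kwargs)  # CPU fallback: no stream handling needed

if torch.cuda.is_available():
    x = torch.randn(2, 64, device="cuda")
    y = run_pinned(torch.nn.functional.silu, x, device=x.device)

And a simplified grouped RMS norm to illustrate the shape arithmetic behind the group_size fix; this is not the repository's MambaRMSNormGated implementation, and the sizes below are made-up example values:

import torch

def grouped_rms_norm(x: torch.Tensor, weight: torch.Tensor, group_size: int, eps: float = 1e-5):
    # Normalize each contiguous block of `group_size` channels over its own slice.
    *lead, hidden = x.shape
    assert hidden % group_size == 0
    xg = x.view(*lead, hidden // group_size, group_size)
    xg = xg * torch.rsqrt(xg.pow(2).mean(dim=-1, keepdim=True) + eps)
    return xg.view(*lead, hidden) * weight

# Example: with intermediate_size = 4096 and n_groups = 8 (illustrative values),
# the corrected call passes group_size = 4096 // 8 = 512, i.e. one normalization
# group per SSM group; the old call treated all 4096 channels as a single group.
x = torch.randn(2, 16, 4096)
w = torch.ones(4096)
out = grouped_rms_norm(x, w, group_size=4096 // 8)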