chore: update gradient checkpointing
model.py
CHANGED
@@ -311,8 +311,8 @@ class StripedHyena(nn.Module):
         self.embedding_layer = VocabParallelEmbedding(config)
         self.norm = RMSNorm(config) if config.get("final_norm", True) else None
         self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)
-        self.
-
+        self.gradient_checkpointing = False
+
         if config.get("use_flashfft", "False"):
             raise NotImplementedError("Please use standalone SH code for other custom kernels")
         else:
@@ -349,8 +349,18 @@ class StripedHyena(nn.Module):
         if type(padding_mask) == torch.Tensor:
             x = x * padding_mask[..., None]

-        for block_idx, block in enumerate(self.blocks):
-            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+        for block_idx, block in enumerate(self.blocks):
+            if self.gradient_checkpointing and self.training:
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, inference_params=None, padding_mask=padding_mask)
+
+                    return custom_forward
+
+                x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
+            else:
+                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
         return x, None

     def initialize_inference_params(self):
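The first hunk initializes the new flag to False, so inference behavior is unchanged and training code has to opt in explicitly. A minimal opt-in sketch, assuming the flag is flipped directly on the module (this commit adds no enable/disable helper, and the inputs, padding mask, and loss below are placeholders):

model = StripedHyena(config)
model.gradient_checkpointing = True   # off by default; opt in for training
model.train()                         # the checkpointed path also requires self.training

out, _ = model(x, padding_mask=padding_mask)   # placeholder inputs and mask
out.sum().backward()                  # block activations are recomputed during backward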
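In the second hunk, the create_custom_forward closure exists because torch.utils.checkpoint.checkpoint only forwards positional tensor arguments to the wrapped callable, so inference_params and padding_mask are captured by the closure instead, while use_reentrant=False selects the non-reentrant implementation that recomputes the block during backward. Below is a self-contained sketch of the same pattern, with a made-up SmallBlock standing in for a StripedHyena block; the diff itself assumes checkpoint is already imported from torch.utils.checkpoint elsewhere in model.py.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class SmallBlock(nn.Module):
    # Toy stand-in for a StripedHyena block: returns (output, state) like block() in the diff.
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, inference_params=None, padding_mask=None):
        if padding_mask is not None:
            x = x * padding_mask[..., None]
        return self.proj(x).relu(), None

def create_custom_forward(module, padding_mask):
    # The closure pins the keyword arguments; checkpoint() only passes the tensors through.
    def custom_forward(*inputs):
        return module(*inputs, inference_params=None, padding_mask=padding_mask)
    return custom_forward

block = SmallBlock()
x = torch.randn(2, 8, 16, requires_grad=True)
mask = torch.ones(2, 8)

# Non-reentrant checkpointing: activations inside the block are dropped during the
# forward pass and recomputed during backward, trading compute for memory.
y, _ = checkpoint(create_custom_forward(block, mask), x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)   # torch.Size([2, 8, 16]); gradients flow through the recomputed block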