Update modeling_nort5.py
Browse files- modeling_nort5.py +2 -2
modeling_nort5.py
CHANGED
|
@@ -221,7 +221,7 @@ class Attention(nn.Module):
|
|
| 221 |
- torch.arange(512, dtype=torch.long).unsqueeze(0)
|
| 222 |
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
|
| 223 |
position_indices = config.position_bucket_size - 1 + position_indices
|
| 224 |
-
self.register_buffer("position_indices", position_indices, persistent=
|
| 225 |
|
| 226 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 227 |
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
|
@@ -271,7 +271,7 @@ class Attention(nn.Module):
|
|
| 271 |
- torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
|
| 272 |
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
| 273 |
position_indices = self.config.position_bucket_size - 1 + position_indices
|
| 274 |
-
self.register_buffer("position_indices", position_indices.to(q.device), persistent=
|
| 275 |
|
| 276 |
q = self.pre_layer_norm(q)
|
| 277 |
query = self.in_proj_q(q) # shape: [T, B, D]
|
|
|
|
| 221 |
- torch.arange(512, dtype=torch.long).unsqueeze(0)
|
| 222 |
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
|
| 223 |
position_indices = config.position_bucket_size - 1 + position_indices
|
| 224 |
+
self.register_buffer("position_indices", position_indices, persistent=False)
|
| 225 |
|
| 226 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 227 |
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
|
|
|
| 271 |
- torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
|
| 272 |
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
| 273 |
position_indices = self.config.position_bucket_size - 1 + position_indices
|
| 274 |
+
self.register_buffer("position_indices", position_indices.to(q.device), persistent=False)
|
| 275 |
|
| 276 |
q = self.pre_layer_norm(q)
|
| 277 |
query = self.in_proj_q(q) # shape: [T, B, D]
|