remove fixed-size causal mask
modeling_qwen.py (+3 -76)
@@ -395,62 +395,6 @@ class QWenAttention(nn.Module):
 
         return attn_output, attn_weights
 
-    def _upcast_and_reordered_attn(
-        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
-    ):
-        bsz, num_heads, q_seq_len, dk = query.size()
-        _, _, k_seq_len, _ = key.size()
-
-        attn_weights = torch.empty(
-            bsz * num_heads,
-            q_seq_len,
-            k_seq_len,
-            dtype=torch.float32,
-            device=query.device,
-        )
-
-        scale_factor = 1.0
-        if self.scale_attn_weights:
-            scale_factor /= float(value.size(-1)) ** 0.5
-
-        with autocast(enabled=False):
-            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-                -1, dk, k_seq_len
-            )
-            attn_weights = torch.baddbmm(
-                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
-            )
-            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-
-        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = registered_causal_mask[
-            :, :, key_length - query_length : key_length, :key_length
-        ]
-        mask_value = torch.finfo(attn_weights.dtype).min
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
-            attn_weights.device
-        )
-        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if attn_weights.dtype != torch.float32:
-            raise RuntimeError(
-                "Error with upcasting, attn_weights does not have dtype torch.float32"
-            )
-        attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
     def _split_heads(self, tensor, num_heads, attn_head_size):
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
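The removed `_upcast_and_reordered_attn` computed attention scores in float32 via `torch.baddbmm` and applied causality by slicing the module-level `registered_causal_mask` buffer down to the current query/key window. The slice is worth seeing in isolation, since the rest of this commit replaces that fixed buffer with a mask built per call. A minimal sketch of what the slice selects (shapes chosen for illustration, not taken from the model config):

```python
import torch

max_positions = 16
query_length, key_length = 3, 10  # e.g. 3 new tokens attending over a 10-token cache

# The buffer the model used to register: a full lower-triangular mask of
# shape (1, 1, max_positions, max_positions).
registered_causal_mask = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

# The slice taken by the removed helper: rows for the new queries,
# columns for every key position they may attend to.
causal_mask = registered_causal_mask[
    :, :, key_length - query_length : key_length, :key_length
]
print(causal_mask.shape)  # torch.Size([1, 1, 3, 10])
```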
@@ -465,7 +409,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -558,6 +501,9 @@ class QWenAttention(nn.Module):
             q, k, v = query, key, value
             attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
         else:
+            registered_causal_mask = torch.tril(
+                torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+            ).view(1, 1, key.size(1), key.size(1))
             query = query.permute(0, 2, 1, 3)
             if not self.use_cache_quantization:
                 key = key.permute(0, 2, 1, 3)
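The three added lines build the causal mask on the fly from the live key length instead of reading it from a pre-registered buffer, so the mask always matches the sequence actually being processed. A self-contained sketch of the same construction and one common way such a mask is applied to attention scores (the tensor shapes and the `masked_fill` step are illustrative, not lifted from the file):

```python
import torch

# Before the permutes in the diff, key is laid out as (batch, seq_len, heads, head_dim),
# so key.size(1) is the current key/value sequence length.
batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
key = torch.randn(batch, seq_len, num_heads, head_dim)

# Same construction as the added lines: a lower-triangular boolean mask sized to
# the live sequence, broadcastable over batch and heads.
registered_causal_mask = torch.tril(
    torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
).view(1, 1, key.size(1), key.size(1))

# Illustrative downstream use: keep scores where the mask is True, push the
# rest to the dtype minimum before softmax.
scores = torch.randn(batch, num_heads, seq_len, seq_len)
scores = scores.masked_fill(~registered_causal_mask, torch.finfo(scores.dtype).min)
probs = torch.softmax(scores, dim=-1)
```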
@@ -650,7 +596,6 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -664,7 +609,6 @@ class QWenBlock(nn.Module):
         attn_outputs = self.attn(
             layernorm_output,
             rotary_pos_emb_list,
-            registered_causal_mask=registered_causal_mask,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -764,21 +708,6 @@ class QWenModel(QWenPreTrainedModel):
 
         self.use_flash_attn = config.use_flash_attn
         self.is_fp32 = not (config.bf16 or config.fp16)
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-        ):
-            self.registered_causal_mask = None
-        else:
-            max_positions = config.max_position_embeddings
-            self.register_buffer(
-                "registered_causal_mask",
-                torch.tril(
-                    torch.ones((max_positions, max_positions), dtype=torch.bool)
-                ).view(1, 1, max_positions, max_positions),
-                persistent=False,
-            )
 
         self.h = nn.ModuleList(
             [
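The deleted `__init__` branch allocated the mask once at `max_position_embeddings x max_position_embeddings` whenever the flash-attention path was unavailable, and kept it resident for the lifetime of the module; the replacement in the attention forward only ever materializes a `seq_len x seq_len` mask. A rough size comparison, using 8192 purely as an illustrative stand-in for `config.max_position_embeddings`:

```python
import torch

max_positions = 8192  # illustrative stand-in for config.max_position_embeddings
seq_len = 512         # whatever the current forward pass actually sees

fixed = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
dynamic = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))

# torch.bool tensors take one byte per element.
print(fixed.numel())    # 67108864 -> ~64 MiB held as long as the module lives
print(dynamic.numel())  # 262144   -> ~256 KiB, rebuilt per attention call
```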
@@ -950,7 +879,6 @@ class QWenModel(QWenPreTrainedModel):
                     create_custom_forward(block),
                     hidden_states,
                     rotary_pos_emb_list,
-                    self.registered_causal_mask,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -962,7 +890,6 @@ class QWenModel(QWenPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     rotary_pos_emb_list=rotary_pos_emb_list,
-                    registered_causal_mask=self.registered_causal_mask,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,