update modeling_qwen.py
modeling_qwen.py  +4 -3
@@ -175,6 +175,7 @@ class FlashSelfAttention(torch.nn.Module):
         assert all((i.is_cuda for i in (q, k, v)))
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
+        seqlen_out = seqlen_q

         q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
         cu_seqlens_q = torch.arange(
@@ -187,11 +188,11 @@ class FlashSelfAttention(torch.nn.Module):

         if attention_mask is not None:
             k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
-
-            if self.training or q.size(0) == k.size(0):
+            if q.size(0) == v.size(0):
                 q = q[indices_k]
                 cu_seqlens_q = cu_seqlens_k
                 seqlen_q = seqlen_k
+            v = v[indices_k]
         else:
             cu_seqlens_k = torch.arange(
                 0,
@@ -222,7 +223,7 @@ class FlashSelfAttention(torch.nn.Module):
             causal=is_causal,
         )
         if attention_mask is not None and seqlen_q == seqlen_k:
-            output = self.pad_input(output, indices_k, batch_size, seqlen_q)
+            output = self.pad_input(output, indices_k, batch_size, seqlen_out)
         else:
             new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
             output = output.view(new_shape)
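Why the change: inside the unpadding branch, seqlen_q is overwritten with seqlen_k, so the old call self.pad_input(output, indices_k, batch_size, seqlen_q) could re-pad the output to the wrong length; the new seqlen_out snapshots the original query length first. Likewise, v must always be unpadded to match k, while q is only unpadded when it still has one query per key/value token (q.size(0) == v.size(0)), which presumably excludes incremental decoding where q holds a single token per sequence. The sketch below shows the unpad/pad round-trip this relies on, in plain PyTorch. The helper names mirror the unpad_input/pad_input methods used above (which appear to wrap flash-attn's bert_padding utilities); the bodies and toy shapes here are illustrative stand-ins, not the library's implementation.

import torch

def unpad_input(x, attention_mask):
    # x: (batch, seqlen, ...); attention_mask: (batch, seqlen), 1 = real token.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # cu_seqlens and max_seqlen mirror the extra return values the diff unpacks
    # (cu_seqlens_k, seqlen_k); they are what flash attention consumes.
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    max_seqlen = int(seqlens.max())
    # Drop padded positions: (batch, seqlen, ...) -> (total_tokens, ...).
    x_unpad = x.reshape(-1, *x.shape[2:])[indices]
    return x_unpad, indices, cu_seqlens, max_seqlen

def pad_input(x_unpad, indices, batch_size, seqlen):
    # Inverse of unpad_input: scatter tokens back to (batch, seqlen, ...).
    # `seqlen` must be the ORIGINAL padded length, which is exactly what the
    # commit preserves in `seqlen_out` before `seqlen_q` is overwritten.
    out = torch.zeros(
        batch_size * seqlen, *x_unpad.shape[1:], dtype=x_unpad.dtype
    )
    out[indices] = x_unpad
    return out.reshape(batch_size, seqlen, *x_unpad.shape[1:])

# Round trip on a toy right-padded batch (shapes are hypothetical).
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
x = torch.randn(2, 4, 8)
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, mask)
restored = pad_input(x_unpad, indices, batch_size=2, seqlen=4)
assert torch.equal(restored[mask.bool()], x[mask.bool()])

Note that padding back with max_seqlen (the post-unpad seqlen_k) instead of the original length would produce a tensor of the wrong width whenever every sequence in the batch carries some padding, which is the failure mode the seqlen_out bookkeeping avoids.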