fix(phi-1): Checks length of `attention_mask` if it is passed as a direct tensor.
modeling_mixformer_sequential.py CHANGED

@@ -35,7 +35,7 @@ from __future__ import annotations

 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field

 import torch

@@ -541,8 +541,8 @@ class MHA(nn.Module):
         kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)

         if attention_mask is not None:
-            attention_mask
-            attention_mask = attention_mask.to(qkv.device)
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)

         attention_kwargs = {"attention_mask": attention_mask}
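For reference, the patched branch can be read as the small helper below. This is a minimal sketch, not code from the repo: the helper name `normalize_attention_mask` and its standalone form are illustrative only, and the new `Union` import presumably supports a widened `attention_mask` type hint elsewhere in the file (that hunk is not shown here). The sketch captures the two behaviors the commit adds: unwrapping a mask that arrives wrapped in a tuple, and casting it to `bool` before moving it onto the same device as the `qkv` tensor, where the earlier code assumed a bare tensor and left the dtype untouched.

```python
# Minimal sketch of the patched mask handling in MHA, lifted out of the
# module for illustration. `normalize_attention_mask` is a hypothetical
# name; it simply mirrors the two patched lines.
from typing import Tuple, Union

import torch


def normalize_attention_mask(
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    device: torch.device,
) -> torch.Tensor:
    # Unwrap the tensor if the mask was passed wrapped in a tuple
    # rather than as a direct tensor.
    attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
    # Cast to bool and move onto the target device (qkv.device in the model).
    return attention_mask.bool().to(device)


mask = torch.tensor([[1, 1, 0]])
# A bare tensor and a 1-tuple now normalize to the same bool mask.
assert torch.equal(
    normalize_attention_mask(mask, torch.device("cpu")),
    normalize_attention_mask((mask,), torch.device("cpu")),
)
```

Folding the tuple check and the `bool()` cast into the existing `if attention_mask is not None:` branch keeps the normalization in one place, so the code that builds `attention_kwargs` can rely on receiving either `None` or a bool tensor on the right device.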