Fixes flash-attn import with a try/except statement
configuration_mixformer_sequential.py
CHANGED

@@ -30,6 +30,9 @@ class MixFormerSequentialConfig(PretrainedConfig):
         n_head_kv: Optional[int] = None,
         rotary_dim: Optional[int] = 32,
         activation_function: Optional[str] = "gelu_new",
+        flash_rotary: bool = False,
+        fused_dense: bool = False,
+        attn_pdrop: Optional[float] = 0.0,
         embd_pdrop: Optional[float] = 0.0,
         resid_pdrop: Optional[float] = 0.0,
         layer_norm_epsilon: Optional[float] = 1e-5,
@@ -47,6 +50,9 @@ class MixFormerSequentialConfig(PretrainedConfig):
         self.n_head_kv = n_head_kv
         self.rotary_dim = min(rotary_dim, n_embd // n_head)
         self.activation_function = activation_function
+        self.flash_rotary = flash_rotary
+        self.fused_dense = fused_dense
+        self.attn_pdrop = attn_pdrop
         self.embd_pdrop = embd_pdrop
         self.resid_pdrop = resid_pdrop
         self.layer_norm_epsilon = layer_norm_epsilon
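The three new fields default to conservative values (both fused paths off, no attention dropout), so existing checkpoints keep their behaviour unless a user opts in. A minimal sketch of how they might be toggled, assuming the file is importable from the working directory and that the remaining constructor arguments all carry defaults:

# Illustrative only, not part of the commit.
from configuration_mixformer_sequential import MixFormerSequentialConfig

config = MixFormerSequentialConfig(
    flash_rotary=True,   # request flash-attn's rotary embedding, if the package is installed
    fused_dense=True,    # request flash-attn's FusedDense projections, if the package is installed
    attn_pdrop=0.1,      # dropout forwarded to SelfAttention / CrossAttention
)
print(config.flash_rotary, config.fused_dense, config.attn_pdrop)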
modeling_mixformer_sequential.py
CHANGED

@@ -32,7 +32,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 from __future__ import annotations
-import importlib

 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from .configuration_mixformer_sequential import MixFormerSequentialConfig


-def _is_flash_attn_available() -> bool:
-    return importlib.util.find_spec("flash_attn") is not None
-
-
-if _is_flash_attn_available():
+try:
     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
     from flash_attn.ops.fused_dense import FusedDense
-else:
+except:
     FlashRotaryEmbedding = None
     FusedDense = None

@@ -549,9 +544,6 @@ class MHA(nn.Module):
         bias: bool = True,
         causal: bool = True,
         softmax_scale: Optional[float] = None,
-        dropout: float = 0.0,
-        flash_rotary: bool = True,
-        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
@@ -565,7 +557,7 @@
         if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
             rotary_kwargs["scale_base"] = rotary_emb_scale_base

-        rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+        rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
         if rotary_cls is None:
             rotary_cls = RotaryEmbedding
         self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
@@ -575,7 +567,7 @@
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd

-        linear_cls = FusedDense if fused_dense else nn.Linear
+        linear_cls = FusedDense if config.fused_dense else nn.Linear
         if linear_cls is None:
             linear_cls = nn.Linear

@@ -583,8 +575,8 @@
         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)

         # Attention
-        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
-        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
+        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)

         self.layer_idx = layer_idx
         self.return_residual = return_residual
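Taken together, the guarded import and the `if ... is None` fallbacks mean the model still loads on machines without flash-attn: the optional classes resolve to None and the code drops back to RotaryEmbedding and nn.Linear. A minimal sketch of that pattern in isolation (illustrative only; the helper name pick_linear_cls is hypothetical, and the commit itself uses a bare except rather than except ImportError):

import torch.nn as nn

try:
    # Optional CUDA-only dependency; may be absent on CPU-only installs.
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None

def pick_linear_cls(use_fused: bool) -> type:
    # Mirrors the selection in MHA.__init__: prefer the fused kernel when the
    # config requests it, but fall back to nn.Linear if the import failed.
    linear_cls = FusedDense if use_fused else nn.Linear
    if linear_cls is None:
        linear_cls = nn.Linear
    return linear_cls

linear_cls = pick_linear_cls(use_fused=True)  # FusedDense if flash-attn is installed, else nn.Linear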