Adds support for flash-attn rotary embedding and fused dense layers.
modeling_mixformer_sequential.py  +59 -19  CHANGED
@@ -32,6 +32,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
+import importlib
 
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -48,6 +49,18 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from .configuration_mixformer_sequential import MixFormerSequentialConfig
 
 
+def _is_flash_attn_available() -> bool:
+    return importlib.util.find_spec("flash_attn") is not None
+
+
+if _is_flash_attn_available():
+    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+    from flash_attn.ops.fused_dense import FusedDense
+else:
+    FlashRotaryEmbedding = None
+    FusedDense = None
+
+
 @dataclass
 class InferenceParams:
     """Inference parameters passed to model to efficiently calculate
@@ -213,6 +226,7 @@ class RotaryEmbedding(nn.Module):
         dim: int,
         base: int = 10000,
         scale_base: Optional[float] = None,
+        pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -221,15 +235,17 @@ class RotaryEmbedding(nn.Module):
         if scale_base is not None:
             raise NotImplementedError
 
-        # Generate and save the inverse frequency buffer (non-trainable)
         self.dim = dim
-        self.base = base
+        self.base = float(base)
         self.scale_base = scale_base
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
         self.device = device
 
+        # Generate and save the inverse frequency buffer (non-trainable)
+        inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+        # Generate and save the scale buffer (non-trainable)
         scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base is not None
@@ -243,23 +259,37 @@ class RotaryEmbedding(nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
+    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
+        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+
     def _update_cos_sin_cache(
         self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
     ) -> None:
+        # Reset the tables if sequence length has been changed, if we are on a
+        # new device or if we are switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
             self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=torch.float32)
 
+            # fp32 is preferred since the output of `torch.arange` can be quite large
+            # and bf16 would lose a lot of precision
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+
+            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+            freqs = torch.outer(t, inv_freq)
             if self.scale is None:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
@@ -269,7 +299,7 @@ class RotaryEmbedding(nn.Module):
                 ) / self.scale_base
                 scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
+                # Force the scale multiplication to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                 self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
@@ -520,6 +550,8 @@ class MHA(nn.Module):
         causal: bool = True,
         softmax_scale: Optional[float] = None,
         dropout: float = 0.0,
+        flash_rotary: bool = True,
+        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
@@ -532,15 +564,23 @@ class MHA(nn.Module):
             rotary_kwargs = {"device": device}
             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
+
+            rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+            if rotary_cls is None:
+                rotary_cls = RotaryEmbedding
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
+        linear_cls = FusedDense if fused_dense else nn.Linear
+        if linear_cls is None:
+            linear_cls = nn.Linear
+
+        self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
+        self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
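The guard around the flash-attn imports is the core of the change: _is_flash_attn_available() probes for the package with importlib.util.find_spec, the fused classes are imported only when the probe succeeds, and the constructor falls back to the stock RotaryEmbedding / nn.Linear whenever the corresponding name is None. A minimal, self-contained sketch of that selection pattern (not part of the commit; pick_linear_cls is an illustrative helper, not a function from the file):

import importlib.util

import torch.nn as nn


def _is_flash_attn_available() -> bool:
    # find_spec returns None when the package cannot be located
    return importlib.util.find_spec("flash_attn") is not None


if _is_flash_attn_available():
    from flash_attn.ops.fused_dense import FusedDense
else:
    FusedDense = None


def pick_linear_cls(fused_dense: bool = True) -> type:
    # Honor the fused_dense flag, but always fall back to nn.Linear
    # when flash-attn (and therefore FusedDense) is unavailable.
    linear_cls = FusedDense if fused_dense else nn.Linear
    if linear_cls is None:
        linear_cls = nn.Linear
    return linear_cls


# Either class accepts (in_features, out_features, bias=...), so the caller
# does not need to know which backend was selected.
qkv_proj = pick_linear_cls(fused_dense=True)(1024, 3 * 1024, bias=True)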
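The new pos_idx_in_fp32 flag exists because of the comment in _update_cos_sin_cache: position indices can be large, and low-precision dtypes cannot represent large integers exactly, which corrupts the rotary angles. A small sketch (not from the commit) showing the effect with bfloat16, whose 8-bit significand stops resolving consecutive integers above 256:

import torch

# Six consecutive positions near the end of a 4k context.
pos_fp32 = torch.arange(4090, 4096, dtype=torch.float32)
pos_bf16 = torch.arange(4090, 4096, dtype=torch.bfloat16)

print(pos_fp32)  # 4090., 4091., 4092., 4093., 4094., 4095.
print(pos_bf16)  # every entry rounds to 4096. -- the positions become indistinguishable

# This is why the cache builds t (and, if the buffer was downcast, inv_freq) in fp32
# and only casts the final cos/sin tables to the model dtype.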
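For reference, the rotary cache boils down to the two formulas the diff factors out: _compute_inv_freq builds 1 / base**(2i/dim), and _update_cos_sin_cache takes the outer product with the position vector before applying cos and sin. A standalone sketch under those definitions (build_rotary_cache is a hypothetical helper, not a function from the file):

import torch


def build_rotary_cache(seqlen: int, dim: int, base: float = 10000.0, dtype=torch.float16):
    # Inverse frequencies for every second channel, in fp32 (same formula as _compute_inv_freq).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # Position indices in fp32, i.e. the pos_idx_in_fp32=True path.
    t = torch.arange(seqlen, dtype=torch.float32)
    # torch.outer keeps the angle matrix in fp32 even under AMP, unlike torch.einsum.
    freqs = torch.outer(t, inv_freq)  # (seqlen, dim // 2)
    # Cast to the model dtype only at the very end, as the updated cache does.
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)


cos, sin = build_rotary_cache(seqlen=2048, dim=32)
print(cos.shape, sin.shape)  # torch.Size([2048, 16]) torch.Size([2048, 16])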