Upload fused_norm_gate.py
Browse files- fused_norm_gate.py +896 -0
fused_norm_gate.py
ADDED
|
@@ -0,0 +1,896 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright (c) 2023, Tri Dao.
|
| 4 |
+
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
|
| 5 |
+
# Implement residual + layer_norm / rms_norm.
|
| 6 |
+
|
| 7 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
| 8 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
| 9 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
| 10 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import functools
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import triton
|
| 22 |
+
import triton.language as tl
|
| 23 |
+
|
| 24 |
+
def contiguous(fn):
|
| 25 |
+
@functools.wraps(fn)
|
| 26 |
+
def wrapper(ctx, *args, **kwargs):
|
| 27 |
+
return fn(ctx,
|
| 28 |
+
*(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
|
| 29 |
+
**{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
|
| 30 |
+
return wrapper
|
| 31 |
+
|
| 32 |
+
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
| 33 |
+
dtype = x.dtype
|
| 34 |
+
if upcast:
|
| 35 |
+
weight = weight.float()
|
| 36 |
+
bias = bias.float() if bias is not None else None
|
| 37 |
+
if upcast:
|
| 38 |
+
x = x.float()
|
| 39 |
+
residual = residual.float() if residual is not None else residual
|
| 40 |
+
if residual is not None:
|
| 41 |
+
x = (x + residual).to(x.dtype)
|
| 42 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
| 43 |
+
dtype
|
| 44 |
+
)
|
| 45 |
+
return out if not prenorm else (out, x)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
| 49 |
+
dtype = x.dtype
|
| 50 |
+
if upcast:
|
| 51 |
+
weight = weight.float()
|
| 52 |
+
bias = bias.float() if bias is not None else None
|
| 53 |
+
if upcast:
|
| 54 |
+
x = x.float()
|
| 55 |
+
residual = residual.float() if residual is not None else residual
|
| 56 |
+
if residual is not None:
|
| 57 |
+
x = (x + residual).to(x.dtype)
|
| 58 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
| 59 |
+
out = (x * rstd * weight) + \
|
| 60 |
+
bias if bias is not None else (x * rstd * weight)
|
| 61 |
+
out = out.to(dtype)
|
| 62 |
+
return out if not prenorm else (out, x)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@triton.autotune(
|
| 66 |
+
configs=[
|
| 67 |
+
triton.Config({}, num_warps=1),
|
| 68 |
+
triton.Config({}, num_warps=2),
|
| 69 |
+
triton.Config({}, num_warps=4),
|
| 70 |
+
triton.Config({}, num_warps=8),
|
| 71 |
+
triton.Config({}, num_warps=16),
|
| 72 |
+
triton.Config({}, num_warps=32),
|
| 73 |
+
],
|
| 74 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
| 75 |
+
)
|
| 76 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 77 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
| 78 |
+
@triton.jit
|
| 79 |
+
def _layer_norm_fwd_1pass_kernel(
|
| 80 |
+
X, # pointer to the input
|
| 81 |
+
O, # pointer to the gate
|
| 82 |
+
Y, # pointer to the output
|
| 83 |
+
W, # pointer to the weights
|
| 84 |
+
B, # pointer to the biases
|
| 85 |
+
RESIDUAL, # pointer to the residual
|
| 86 |
+
RESIDUAL_OUT, # pointer to the residual
|
| 87 |
+
Mean, # pointer to the mean
|
| 88 |
+
Rstd, # pointer to the 1/std
|
| 89 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 90 |
+
stride_y_row,
|
| 91 |
+
stride_res_row,
|
| 92 |
+
stride_res_out_row,
|
| 93 |
+
N, # number of columns in X
|
| 94 |
+
eps, # epsilon to avoid division by zero
|
| 95 |
+
IS_RMS_NORM: tl.constexpr,
|
| 96 |
+
BLOCK_N: tl.constexpr,
|
| 97 |
+
HAS_RESIDUAL: tl.constexpr,
|
| 98 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
| 99 |
+
HAS_WEIGHT: tl.constexpr,
|
| 100 |
+
HAS_BIAS: tl.constexpr
|
| 101 |
+
):
|
| 102 |
+
# Map the program id to the row of X and Y it should compute.
|
| 103 |
+
row = tl.program_id(0)
|
| 104 |
+
X += row * stride_x_row
|
| 105 |
+
Y += row * stride_y_row
|
| 106 |
+
O += row * stride_x_row
|
| 107 |
+
if HAS_RESIDUAL:
|
| 108 |
+
RESIDUAL += row * stride_res_row
|
| 109 |
+
if STORE_RESIDUAL_OUT:
|
| 110 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
| 111 |
+
# Compute mean and variance
|
| 112 |
+
cols = tl.arange(0, BLOCK_N)
|
| 113 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 114 |
+
if HAS_RESIDUAL:
|
| 115 |
+
residual = tl.load(RESIDUAL + cols, mask=cols <
|
| 116 |
+
N, other=0.0).to(tl.float32)
|
| 117 |
+
x += residual
|
| 118 |
+
if STORE_RESIDUAL_OUT:
|
| 119 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
| 120 |
+
if not IS_RMS_NORM:
|
| 121 |
+
mean = tl.sum(x, axis=0) / N
|
| 122 |
+
tl.store(Mean + row, mean)
|
| 123 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
| 124 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 125 |
+
else:
|
| 126 |
+
xbar = tl.where(cols < N, x, 0.0)
|
| 127 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 128 |
+
rstd = 1 / tl.sqrt(var + eps)
|
| 129 |
+
tl.store(Rstd + row, rstd)
|
| 130 |
+
# Normalize and apply linear transformation
|
| 131 |
+
mask = cols < N
|
| 132 |
+
if HAS_WEIGHT:
|
| 133 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 134 |
+
if HAS_BIAS:
|
| 135 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
| 136 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 137 |
+
y = x_hat * w if HAS_WEIGHT else x_hat
|
| 138 |
+
if HAS_BIAS:
|
| 139 |
+
y = y + b
|
| 140 |
+
|
| 141 |
+
# Swish output gate
|
| 142 |
+
o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 143 |
+
y = y * o * tl.sigmoid(o)
|
| 144 |
+
|
| 145 |
+
# Write output
|
| 146 |
+
tl.store(Y + cols, y, mask=mask)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _layer_norm_fwd(
|
| 150 |
+
x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
| 151 |
+
):
|
| 152 |
+
if residual is not None:
|
| 153 |
+
residual_dtype = residual.dtype
|
| 154 |
+
M, N = x.shape
|
| 155 |
+
assert x.stride(-1) == 1
|
| 156 |
+
if residual is not None:
|
| 157 |
+
assert residual.stride(-1) == 1
|
| 158 |
+
assert residual.shape == (M, N)
|
| 159 |
+
if weight is not None:
|
| 160 |
+
assert weight.shape == (N,)
|
| 161 |
+
assert weight.stride(-1) == 1
|
| 162 |
+
if bias is not None:
|
| 163 |
+
assert bias.stride(-1) == 1
|
| 164 |
+
assert bias.shape == (N,)
|
| 165 |
+
# allocate output
|
| 166 |
+
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
| 167 |
+
assert y.stride(-1) == 1
|
| 168 |
+
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
| 169 |
+
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
| 170 |
+
assert residual_out.stride(-1) == 1
|
| 171 |
+
else:
|
| 172 |
+
residual_out = None
|
| 173 |
+
mean = torch.empty((M,), dtype=torch.float32,
|
| 174 |
+
device="cuda") if not is_rms_norm else None
|
| 175 |
+
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
| 176 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 177 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 178 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 179 |
+
if N > BLOCK_N:
|
| 180 |
+
raise RuntimeError(
|
| 181 |
+
"This layer norm doesn't support feature dim >= 64KB.")
|
| 182 |
+
# heuristics for number of warps
|
| 183 |
+
with torch.cuda.device(x.device.index):
|
| 184 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
| 185 |
+
x,
|
| 186 |
+
o,
|
| 187 |
+
y,
|
| 188 |
+
weight,
|
| 189 |
+
bias,
|
| 190 |
+
residual,
|
| 191 |
+
residual_out,
|
| 192 |
+
mean,
|
| 193 |
+
rstd,
|
| 194 |
+
x.stride(0),
|
| 195 |
+
y.stride(0),
|
| 196 |
+
residual.stride(0) if residual is not None else 0,
|
| 197 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
| 198 |
+
N,
|
| 199 |
+
eps,
|
| 200 |
+
is_rms_norm,
|
| 201 |
+
BLOCK_N,
|
| 202 |
+
residual is not None,
|
| 203 |
+
residual_out is not None,
|
| 204 |
+
weight is not None,
|
| 205 |
+
bias is not None,
|
| 206 |
+
)
|
| 207 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype
|
| 208 |
+
return y, mean, rstd, residual_out if residual_out is not None else x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@triton.autotune(
|
| 212 |
+
configs=[
|
| 213 |
+
triton.Config({}, num_warps=1),
|
| 214 |
+
triton.Config({}, num_warps=2),
|
| 215 |
+
triton.Config({}, num_warps=4),
|
| 216 |
+
triton.Config({}, num_warps=8),
|
| 217 |
+
triton.Config({}, num_warps=16),
|
| 218 |
+
triton.Config({}, num_warps=32),
|
| 219 |
+
],
|
| 220 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
| 221 |
+
)
|
| 222 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 223 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
| 224 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
| 225 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
| 226 |
+
@triton.jit
|
| 227 |
+
def _layer_norm_bwd_kernel(
|
| 228 |
+
X, # pointer to the input
|
| 229 |
+
O, # pointer to the gate
|
| 230 |
+
W, # pointer to the weights
|
| 231 |
+
B, # pointer to the biases
|
| 232 |
+
Y, # pointer to the output to be recomputed
|
| 233 |
+
DY, # pointer to the output gradient
|
| 234 |
+
DX, # pointer to the input gradient
|
| 235 |
+
DO, # pointer to the gate gradient
|
| 236 |
+
DW, # pointer to the partial sum of weights gradient
|
| 237 |
+
DB, # pointer to the partial sum of biases gradient
|
| 238 |
+
DRESIDUAL,
|
| 239 |
+
DRESIDUAL_IN,
|
| 240 |
+
Mean, # pointer to the mean
|
| 241 |
+
Rstd, # pointer to the 1/std
|
| 242 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 243 |
+
stride_y_row,
|
| 244 |
+
stride_dy_row,
|
| 245 |
+
stride_dx_row,
|
| 246 |
+
stride_dres_row,
|
| 247 |
+
stride_dres_in_row,
|
| 248 |
+
M, # number of rows in X
|
| 249 |
+
N, # number of columns in X
|
| 250 |
+
eps, # epsilon to avoid division by zero
|
| 251 |
+
rows_per_program,
|
| 252 |
+
IS_RMS_NORM: tl.constexpr,
|
| 253 |
+
BLOCK_N: tl.constexpr,
|
| 254 |
+
HAS_DRESIDUAL: tl.constexpr,
|
| 255 |
+
STORE_DRESIDUAL: tl.constexpr,
|
| 256 |
+
HAS_WEIGHT: tl.constexpr,
|
| 257 |
+
HAS_BIAS: tl.constexpr,
|
| 258 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
| 259 |
+
):
|
| 260 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
| 261 |
+
row_block_id = tl.program_id(0)
|
| 262 |
+
row_start = row_block_id * rows_per_program
|
| 263 |
+
cols = tl.arange(0, BLOCK_N)
|
| 264 |
+
mask = cols < N
|
| 265 |
+
X += row_start * stride_x_row
|
| 266 |
+
O += row_start * stride_x_row
|
| 267 |
+
if HAS_DRESIDUAL:
|
| 268 |
+
DRESIDUAL += row_start * stride_dres_row
|
| 269 |
+
if STORE_DRESIDUAL:
|
| 270 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
| 271 |
+
DY += row_start * stride_dy_row
|
| 272 |
+
DX += row_start * stride_dx_row
|
| 273 |
+
DO += row_start * stride_dx_row
|
| 274 |
+
if RECOMPUTE_OUTPUT:
|
| 275 |
+
Y += row_start * stride_y_row
|
| 276 |
+
if HAS_WEIGHT:
|
| 277 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 278 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 279 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
| 280 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
| 281 |
+
if HAS_BIAS:
|
| 282 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 283 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
| 284 |
+
for row in range(row_start, row_end):
|
| 285 |
+
# Load data to SRAM
|
| 286 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
| 287 |
+
o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)
|
| 288 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
| 289 |
+
|
| 290 |
+
if not IS_RMS_NORM:
|
| 291 |
+
mean = tl.load(Mean + row)
|
| 292 |
+
rstd = tl.load(Rstd + row)
|
| 293 |
+
# Compute dx
|
| 294 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 295 |
+
xhat = tl.where(mask, xhat, 0.0)
|
| 296 |
+
|
| 297 |
+
y = xhat * w if HAS_WEIGHT else xhat
|
| 298 |
+
if HAS_BIAS:
|
| 299 |
+
y = y + b
|
| 300 |
+
if RECOMPUTE_OUTPUT:
|
| 301 |
+
tl.store(Y + cols, y, mask=mask)
|
| 302 |
+
|
| 303 |
+
sigmoid_o = tl.sigmoid(o)
|
| 304 |
+
do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))
|
| 305 |
+
dy = dy * o * sigmoid_o
|
| 306 |
+
wdy = dy
|
| 307 |
+
if HAS_WEIGHT:
|
| 308 |
+
wdy = dy * w
|
| 309 |
+
dw += dy * xhat
|
| 310 |
+
if HAS_BIAS:
|
| 311 |
+
db += dy
|
| 312 |
+
if not IS_RMS_NORM:
|
| 313 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 314 |
+
c2 = tl.sum(wdy, axis=0) / N
|
| 315 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
| 316 |
+
else:
|
| 317 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 318 |
+
dx = (wdy - xhat * c1) * rstd
|
| 319 |
+
if HAS_DRESIDUAL:
|
| 320 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
| 321 |
+
dx += dres
|
| 322 |
+
# Write dx
|
| 323 |
+
if STORE_DRESIDUAL:
|
| 324 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
| 325 |
+
tl.store(DX + cols, dx, mask=mask)
|
| 326 |
+
tl.store(DO + cols, do, mask=mask)
|
| 327 |
+
|
| 328 |
+
X += stride_x_row
|
| 329 |
+
O += stride_x_row
|
| 330 |
+
if HAS_DRESIDUAL:
|
| 331 |
+
DRESIDUAL += stride_dres_row
|
| 332 |
+
if STORE_DRESIDUAL:
|
| 333 |
+
DRESIDUAL_IN += stride_dres_in_row
|
| 334 |
+
if RECOMPUTE_OUTPUT:
|
| 335 |
+
Y += stride_y_row
|
| 336 |
+
DY += stride_dy_row
|
| 337 |
+
DX += stride_dx_row
|
| 338 |
+
DO += stride_dx_row
|
| 339 |
+
if HAS_WEIGHT:
|
| 340 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
| 341 |
+
if HAS_BIAS:
|
| 342 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def _layer_norm_bwd(
|
| 346 |
+
dy,
|
| 347 |
+
x,
|
| 348 |
+
o,
|
| 349 |
+
weight,
|
| 350 |
+
bias,
|
| 351 |
+
eps,
|
| 352 |
+
mean,
|
| 353 |
+
rstd,
|
| 354 |
+
dresidual=None,
|
| 355 |
+
has_residual=False,
|
| 356 |
+
is_rms_norm=False,
|
| 357 |
+
x_dtype=None,
|
| 358 |
+
recompute_output=False,
|
| 359 |
+
):
|
| 360 |
+
M, N = x.shape
|
| 361 |
+
assert x.stride(-1) == 1
|
| 362 |
+
assert dy.stride(-1) == 1
|
| 363 |
+
assert dy.shape == (M, N)
|
| 364 |
+
if dresidual is not None:
|
| 365 |
+
assert dresidual.stride(-1) == 1
|
| 366 |
+
assert dresidual.shape == (M, N)
|
| 367 |
+
if weight is not None:
|
| 368 |
+
assert weight.shape == (N,)
|
| 369 |
+
assert weight.stride(-1) == 1
|
| 370 |
+
if bias is not None:
|
| 371 |
+
assert bias.stride(-1) == 1
|
| 372 |
+
assert bias.shape == (N,)
|
| 373 |
+
# allocate output
|
| 374 |
+
dx = (
|
| 375 |
+
torch.empty_like(x)
|
| 376 |
+
if x_dtype is None
|
| 377 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
| 378 |
+
)
|
| 379 |
+
do = (
|
| 380 |
+
torch.empty_like(o)
|
| 381 |
+
if x_dtype is None
|
| 382 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
| 383 |
+
)
|
| 384 |
+
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
|
| 385 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
| 386 |
+
|
| 387 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 388 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 389 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 390 |
+
if N > BLOCK_N:
|
| 391 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 392 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
| 393 |
+
_dw = (
|
| 394 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
| 395 |
+
if weight is not None
|
| 396 |
+
else None
|
| 397 |
+
)
|
| 398 |
+
_db = (
|
| 399 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
| 400 |
+
if bias is not None
|
| 401 |
+
else None
|
| 402 |
+
)
|
| 403 |
+
rows_per_program = math.ceil(M / sm_count)
|
| 404 |
+
grid = (sm_count,)
|
| 405 |
+
with torch.cuda.device(x.device.index):
|
| 406 |
+
_layer_norm_bwd_kernel[grid](
|
| 407 |
+
x,
|
| 408 |
+
o,
|
| 409 |
+
weight,
|
| 410 |
+
bias,
|
| 411 |
+
y,
|
| 412 |
+
dy,
|
| 413 |
+
dx,
|
| 414 |
+
do,
|
| 415 |
+
_dw,
|
| 416 |
+
_db,
|
| 417 |
+
dresidual,
|
| 418 |
+
dresidual_in,
|
| 419 |
+
mean,
|
| 420 |
+
rstd,
|
| 421 |
+
x.stride(0),
|
| 422 |
+
0 if not recompute_output else y.stride(0),
|
| 423 |
+
dy.stride(0),
|
| 424 |
+
dx.stride(0),
|
| 425 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
| 426 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
| 427 |
+
M,
|
| 428 |
+
N,
|
| 429 |
+
eps,
|
| 430 |
+
rows_per_program,
|
| 431 |
+
is_rms_norm,
|
| 432 |
+
BLOCK_N,
|
| 433 |
+
dresidual is not None,
|
| 434 |
+
dresidual_in is not None,
|
| 435 |
+
weight is not None,
|
| 436 |
+
bias is not None,
|
| 437 |
+
)
|
| 438 |
+
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
|
| 439 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
| 440 |
+
# Don't need to compute dresidual_in separately in this case
|
| 441 |
+
if has_residual and dx.dtype == x.dtype:
|
| 442 |
+
dresidual_in = dx
|
| 443 |
+
return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class LayerNormSwishGateFn(torch.autograd.Function):
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
@contiguous
|
| 450 |
+
def forward(
|
| 451 |
+
ctx,
|
| 452 |
+
x,
|
| 453 |
+
o,
|
| 454 |
+
weight,
|
| 455 |
+
bias,
|
| 456 |
+
residual=None,
|
| 457 |
+
eps=1e-6,
|
| 458 |
+
prenorm=False,
|
| 459 |
+
residual_in_fp32=False,
|
| 460 |
+
is_rms_norm=False,
|
| 461 |
+
):
|
| 462 |
+
x_shape_og = x.shape
|
| 463 |
+
o_shape_og = o.shape
|
| 464 |
+
# reshape input data into 2D tensor
|
| 465 |
+
x = x.reshape(-1, x.shape[-1])
|
| 466 |
+
o = o.reshape(-1, o.shape[-1])
|
| 467 |
+
if residual is not None:
|
| 468 |
+
assert residual.shape == x_shape_og
|
| 469 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 470 |
+
residual_dtype = (
|
| 471 |
+
residual.dtype
|
| 472 |
+
if residual is not None
|
| 473 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 474 |
+
)
|
| 475 |
+
y, mean, rstd, residual_out = _layer_norm_fwd(
|
| 476 |
+
x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
|
| 477 |
+
)
|
| 478 |
+
ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd)
|
| 479 |
+
ctx.x_shape_og = x_shape_og
|
| 480 |
+
ctx.o_shape_og = o_shape_og
|
| 481 |
+
ctx.eps = eps
|
| 482 |
+
ctx.is_rms_norm = is_rms_norm
|
| 483 |
+
ctx.has_residual = residual is not None
|
| 484 |
+
ctx.prenorm = prenorm
|
| 485 |
+
ctx.x_dtype = x.dtype
|
| 486 |
+
y = y.reshape(x_shape_og)
|
| 487 |
+
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
| 488 |
+
|
| 489 |
+
@staticmethod
|
| 490 |
+
@contiguous
|
| 491 |
+
def backward(ctx, dy, *args):
|
| 492 |
+
x, o, weight, bias, mean, rstd = ctx.saved_tensors
|
| 493 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
| 494 |
+
assert dy.shape == x.shape
|
| 495 |
+
if ctx.prenorm:
|
| 496 |
+
dresidual = args[0]
|
| 497 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 498 |
+
assert dresidual.shape == x.shape
|
| 499 |
+
else:
|
| 500 |
+
dresidual = None
|
| 501 |
+
dx, do, dw, db, dresidual_in = _layer_norm_bwd(
|
| 502 |
+
dy,
|
| 503 |
+
x,
|
| 504 |
+
o,
|
| 505 |
+
weight,
|
| 506 |
+
bias,
|
| 507 |
+
ctx.eps,
|
| 508 |
+
mean,
|
| 509 |
+
rstd,
|
| 510 |
+
dresidual,
|
| 511 |
+
ctx.has_residual,
|
| 512 |
+
ctx.is_rms_norm,
|
| 513 |
+
x_dtype=ctx.x_dtype,
|
| 514 |
+
)
|
| 515 |
+
return (
|
| 516 |
+
dx.reshape(ctx.x_shape_og),
|
| 517 |
+
do.reshape(ctx.o_shape_og),
|
| 518 |
+
dw,
|
| 519 |
+
db,
|
| 520 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 521 |
+
None,
|
| 522 |
+
None,
|
| 523 |
+
None,
|
| 524 |
+
None,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class LayerNormSwishGateLinearFn(torch.autograd.Function):
|
| 529 |
+
|
| 530 |
+
@staticmethod
|
| 531 |
+
@contiguous
|
| 532 |
+
def forward(
|
| 533 |
+
ctx,
|
| 534 |
+
x,
|
| 535 |
+
o,
|
| 536 |
+
norm_weight,
|
| 537 |
+
norm_bias,
|
| 538 |
+
linear_weight,
|
| 539 |
+
linear_bias,
|
| 540 |
+
residual=None,
|
| 541 |
+
eps=1e-6,
|
| 542 |
+
prenorm=False,
|
| 543 |
+
residual_in_fp32=False,
|
| 544 |
+
is_rms_norm=False,
|
| 545 |
+
):
|
| 546 |
+
x_shape_og = x.shape
|
| 547 |
+
o_shape_og = o.shape
|
| 548 |
+
# reshape input data into 2D tensor
|
| 549 |
+
x = x.reshape(-1, x.shape[-1])
|
| 550 |
+
o = o.reshape(-1, o.shape[-1])
|
| 551 |
+
if residual is not None:
|
| 552 |
+
assert residual.shape == x_shape_og
|
| 553 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 554 |
+
residual_dtype = (
|
| 555 |
+
residual.dtype
|
| 556 |
+
if residual is not None
|
| 557 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 558 |
+
)
|
| 559 |
+
y, mean, rstd, residual_out = _layer_norm_fwd(
|
| 560 |
+
x,
|
| 561 |
+
o,
|
| 562 |
+
norm_weight,
|
| 563 |
+
norm_bias,
|
| 564 |
+
eps,
|
| 565 |
+
residual,
|
| 566 |
+
residual_dtype=residual_dtype,
|
| 567 |
+
is_rms_norm=is_rms_norm
|
| 568 |
+
)
|
| 569 |
+
y = y.reshape(x_shape_og)
|
| 570 |
+
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
| 571 |
+
linear_weight = linear_weight.to(dtype)
|
| 572 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 573 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 574 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
| 575 |
+
ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd)
|
| 576 |
+
ctx.x_shape_og = x_shape_og
|
| 577 |
+
ctx.o_shape_og = o_shape_og
|
| 578 |
+
ctx.eps = eps
|
| 579 |
+
ctx.is_rms_norm = is_rms_norm
|
| 580 |
+
ctx.has_residual = residual is not None
|
| 581 |
+
ctx.prenorm = prenorm
|
| 582 |
+
ctx.x_dtype = x.dtype
|
| 583 |
+
ctx.linear_bias_is_none = linear_bias is None
|
| 584 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 585 |
+
|
| 586 |
+
@staticmethod
|
| 587 |
+
@contiguous
|
| 588 |
+
def backward(ctx, dout, *args):
|
| 589 |
+
x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 590 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
| 591 |
+
dy = F.linear(dout, linear_weight.t())
|
| 592 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 593 |
+
assert dy.shape == x.shape
|
| 594 |
+
if ctx.prenorm:
|
| 595 |
+
dresidual = args[0]
|
| 596 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 597 |
+
assert dresidual.shape == x.shape
|
| 598 |
+
else:
|
| 599 |
+
dresidual = None
|
| 600 |
+
dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
| 601 |
+
dy,
|
| 602 |
+
x,
|
| 603 |
+
o,
|
| 604 |
+
norm_weight,
|
| 605 |
+
norm_bias,
|
| 606 |
+
ctx.eps,
|
| 607 |
+
mean,
|
| 608 |
+
rstd,
|
| 609 |
+
dresidual=dresidual,
|
| 610 |
+
has_residual=ctx.has_residual,
|
| 611 |
+
is_rms_norm=ctx.is_rms_norm,
|
| 612 |
+
x_dtype=ctx.x_dtype,
|
| 613 |
+
recompute_output=True,
|
| 614 |
+
)
|
| 615 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 616 |
+
return (
|
| 617 |
+
dx.reshape(ctx.x_shape_og),
|
| 618 |
+
do.reshape(ctx.o_shape_og),
|
| 619 |
+
dnorm_weight,
|
| 620 |
+
dnorm_bias,
|
| 621 |
+
dlinear_weight,
|
| 622 |
+
dlinear_bias,
|
| 623 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 624 |
+
None,
|
| 625 |
+
None,
|
| 626 |
+
None,
|
| 627 |
+
None,
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def layer_norm_swish_gate_fn(
|
| 632 |
+
x,
|
| 633 |
+
o,
|
| 634 |
+
weight,
|
| 635 |
+
bias,
|
| 636 |
+
residual=None,
|
| 637 |
+
prenorm=False,
|
| 638 |
+
residual_in_fp32=False,
|
| 639 |
+
eps=1e-6
|
| 640 |
+
):
|
| 641 |
+
return LayerNormSwishGateFn.apply(
|
| 642 |
+
x,
|
| 643 |
+
o,
|
| 644 |
+
weight,
|
| 645 |
+
bias,
|
| 646 |
+
residual,
|
| 647 |
+
eps,
|
| 648 |
+
prenorm,
|
| 649 |
+
residual_in_fp32,
|
| 650 |
+
False
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def rms_norm_swish_gate_fn(
|
| 655 |
+
x,
|
| 656 |
+
o,
|
| 657 |
+
weight,
|
| 658 |
+
bias,
|
| 659 |
+
residual=None,
|
| 660 |
+
prenorm=False,
|
| 661 |
+
residual_in_fp32=False,
|
| 662 |
+
eps=1e-6
|
| 663 |
+
):
|
| 664 |
+
return LayerNormSwishGateFn.apply(
|
| 665 |
+
x,
|
| 666 |
+
o,
|
| 667 |
+
weight,
|
| 668 |
+
bias,
|
| 669 |
+
residual,
|
| 670 |
+
eps,
|
| 671 |
+
prenorm,
|
| 672 |
+
residual_in_fp32,
|
| 673 |
+
True
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def layer_norm_swish_gate_linear_fn(
|
| 678 |
+
x,
|
| 679 |
+
o,
|
| 680 |
+
norm_weight,
|
| 681 |
+
norm_bias,
|
| 682 |
+
linear_weight,
|
| 683 |
+
linear_bias,
|
| 684 |
+
residual=None,
|
| 685 |
+
prenorm=False,
|
| 686 |
+
residual_in_fp32=False,
|
| 687 |
+
eps=1e-6
|
| 688 |
+
):
|
| 689 |
+
return LayerNormSwishGateLinearFn.apply(
|
| 690 |
+
x,
|
| 691 |
+
o,
|
| 692 |
+
norm_weight,
|
| 693 |
+
norm_bias,
|
| 694 |
+
linear_weight,
|
| 695 |
+
linear_bias,
|
| 696 |
+
residual,
|
| 697 |
+
eps,
|
| 698 |
+
prenorm,
|
| 699 |
+
residual_in_fp32,
|
| 700 |
+
False
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def rms_norm_swish_gate_linear_fn(
|
| 705 |
+
x,
|
| 706 |
+
o,
|
| 707 |
+
norm_weight,
|
| 708 |
+
norm_bias,
|
| 709 |
+
linear_weight,
|
| 710 |
+
linear_bias,
|
| 711 |
+
residual=None,
|
| 712 |
+
prenorm=False,
|
| 713 |
+
residual_in_fp32=False,
|
| 714 |
+
eps=1e-6
|
| 715 |
+
):
|
| 716 |
+
return LayerNormSwishGateLinearFn.apply(
|
| 717 |
+
x,
|
| 718 |
+
o,
|
| 719 |
+
norm_weight,
|
| 720 |
+
norm_bias,
|
| 721 |
+
linear_weight,
|
| 722 |
+
linear_bias,
|
| 723 |
+
residual,
|
| 724 |
+
eps,
|
| 725 |
+
prenorm,
|
| 726 |
+
residual_in_fp32,
|
| 727 |
+
True
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
class FusedLayerNormSwishGate(nn.Module):
|
| 732 |
+
|
| 733 |
+
def __init__(
|
| 734 |
+
self,
|
| 735 |
+
hidden_size,
|
| 736 |
+
elementwise_affine: bool = True,
|
| 737 |
+
eps=1e-5
|
| 738 |
+
) -> FusedLayerNormSwishGate:
|
| 739 |
+
super().__init__()
|
| 740 |
+
|
| 741 |
+
self.hidden_size = hidden_size
|
| 742 |
+
self.elementwise_affine = elementwise_affine
|
| 743 |
+
self.eps = eps
|
| 744 |
+
|
| 745 |
+
if elementwise_affine:
|
| 746 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 747 |
+
else:
|
| 748 |
+
self.register_parameter("weight", None)
|
| 749 |
+
self.register_parameter("bias", None)
|
| 750 |
+
|
| 751 |
+
def __repr__(self) -> str:
|
| 752 |
+
s = f"{self.__class__.__name__}({self.hidden_size}"
|
| 753 |
+
if not self.elementwise_affine:
|
| 754 |
+
s += f", elementwise_affine={self.elementwise_affine}"
|
| 755 |
+
s += f", eps={self.eps}"
|
| 756 |
+
s += ")"
|
| 757 |
+
return s
|
| 758 |
+
|
| 759 |
+
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
|
| 760 |
+
return layer_norm_swish_gate_fn(
|
| 761 |
+
x,
|
| 762 |
+
o,
|
| 763 |
+
self.weight,
|
| 764 |
+
self.bias,
|
| 765 |
+
residual=residual,
|
| 766 |
+
eps=self.eps,
|
| 767 |
+
prenorm=prenorm,
|
| 768 |
+
residual_in_fp32=residual_in_fp32
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
class FusedRMSNormSwishGate(nn.Module):
|
| 773 |
+
|
| 774 |
+
def __init__(
|
| 775 |
+
self,
|
| 776 |
+
hidden_size,
|
| 777 |
+
elementwise_affine: bool = True,
|
| 778 |
+
eps=1e-5
|
| 779 |
+
) -> FusedRMSNormSwishGate:
|
| 780 |
+
super().__init__()
|
| 781 |
+
|
| 782 |
+
self.hidden_size = hidden_size
|
| 783 |
+
self.elementwise_affine = elementwise_affine
|
| 784 |
+
self.eps = eps
|
| 785 |
+
|
| 786 |
+
if elementwise_affine:
|
| 787 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 788 |
+
else:
|
| 789 |
+
self.register_parameter("weight", None)
|
| 790 |
+
self.register_parameter("bias", None)
|
| 791 |
+
|
| 792 |
+
def __repr__(self) -> str:
|
| 793 |
+
s = f"{self.__class__.__name__}({self.hidden_size}"
|
| 794 |
+
if not self.elementwise_affine:
|
| 795 |
+
s += f", elementwise_affine={self.elementwise_affine}"
|
| 796 |
+
s += f", eps={self.eps}"
|
| 797 |
+
s += ")"
|
| 798 |
+
return s
|
| 799 |
+
|
| 800 |
+
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
|
| 801 |
+
return rms_norm_swish_gate_fn(
|
| 802 |
+
x,
|
| 803 |
+
o,
|
| 804 |
+
self.weight,
|
| 805 |
+
self.bias,
|
| 806 |
+
residual=residual,
|
| 807 |
+
eps=self.eps,
|
| 808 |
+
prenorm=prenorm,
|
| 809 |
+
residual_in_fp32=residual_in_fp32
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class FusedLayerNormSwishGateLinear(nn.Module):
|
| 814 |
+
|
| 815 |
+
def __init__(
|
| 816 |
+
self,
|
| 817 |
+
hidden_size,
|
| 818 |
+
elementwise_affine: bool = True,
|
| 819 |
+
eps=1e-5
|
| 820 |
+
) -> FusedLayerNormSwishGateLinear:
|
| 821 |
+
super().__init__()
|
| 822 |
+
|
| 823 |
+
self.hidden_size = hidden_size
|
| 824 |
+
self.elementwise_affine = elementwise_affine
|
| 825 |
+
self.eps = eps
|
| 826 |
+
|
| 827 |
+
if elementwise_affine:
|
| 828 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 829 |
+
else:
|
| 830 |
+
self.register_parameter("weight", None)
|
| 831 |
+
self.register_parameter("bias", None)
|
| 832 |
+
|
| 833 |
+
def __repr__(self) -> str:
|
| 834 |
+
s = f"{self.__class__.__name__}({self.hidden_size}"
|
| 835 |
+
if not self.elementwise_affine:
|
| 836 |
+
s += f", elementwise_affine={self.elementwise_affine}"
|
| 837 |
+
s += f", eps={self.eps}"
|
| 838 |
+
s += ")"
|
| 839 |
+
return s
|
| 840 |
+
|
| 841 |
+
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
|
| 842 |
+
return layer_norm_swish_gate_linear_fn(
|
| 843 |
+
x,
|
| 844 |
+
o,
|
| 845 |
+
self.weight,
|
| 846 |
+
self.bias,
|
| 847 |
+
weight,
|
| 848 |
+
bias,
|
| 849 |
+
residual=residual,
|
| 850 |
+
eps=self.eps,
|
| 851 |
+
prenorm=prenorm,
|
| 852 |
+
residual_in_fp32=residual_in_fp32
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
class FusedRMSNormSwishGateLinear(nn.Module):
|
| 857 |
+
|
| 858 |
+
def __init__(
|
| 859 |
+
self,
|
| 860 |
+
hidden_size,
|
| 861 |
+
elementwise_affine: bool = True,
|
| 862 |
+
eps=1e-5
|
| 863 |
+
) -> FusedRMSNormSwishGateLinear:
|
| 864 |
+
super().__init__()
|
| 865 |
+
|
| 866 |
+
self.hidden_size = hidden_size
|
| 867 |
+
self.elementwise_affine = elementwise_affine
|
| 868 |
+
self.eps = eps
|
| 869 |
+
|
| 870 |
+
if elementwise_affine:
|
| 871 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 872 |
+
else:
|
| 873 |
+
self.register_parameter("weight", None)
|
| 874 |
+
self.register_parameter("bias", None)
|
| 875 |
+
|
| 876 |
+
def __repr__(self) -> str:
|
| 877 |
+
s = f"{self.__class__.__name__}({self.hidden_size}"
|
| 878 |
+
if not self.elementwise_affine:
|
| 879 |
+
s += f", elementwise_affine={self.elementwise_affine}"
|
| 880 |
+
s += f", eps={self.eps}"
|
| 881 |
+
s += ")"
|
| 882 |
+
return s
|
| 883 |
+
|
| 884 |
+
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
|
| 885 |
+
return rms_norm_swish_gate_linear_fn(
|
| 886 |
+
x,
|
| 887 |
+
o,
|
| 888 |
+
self.weight,
|
| 889 |
+
self.bias,
|
| 890 |
+
weight,
|
| 891 |
+
bias,
|
| 892 |
+
residual=residual,
|
| 893 |
+
eps=self.eps,
|
| 894 |
+
prenorm=prenorm,
|
| 895 |
+
residual_in_fp32=residual_in_fp32
|
| 896 |
+
)
|