import torch
import numpy as np
from . import utils
from torch import nn

class CustomModule(nn.Module):
    """A simple two-layer type I MLP structure."""

    def __init__(self, w1_weight=None, w2_bias=None, w2_weight=None, act='gelu'):
        super().__init__()
        # linear1 maps the model dimension to the hidden dimension, linear2 maps back
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)
        # load the provided weights; linear2 keeps a fixed zero bias
        self.linear1.weight = nn.Parameter(w1_weight.float())
        self.linear1.bias = nn.Parameter(w2_bias.float())
        self.linear2.weight = nn.Parameter(w2_weight.T.float())
        self.linear2.bias = nn.Parameter(torch.zeros_like(self.linear2.bias))

    def forward(self, x):
        return self.linear2(self.act(self.linear1(x)))
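
# A minimal usage sketch for CustomModule (illustrative only): the model size,
# hidden size and random weight tensors below are assumptions, not values used
# anywhere in this repo. The helper is never called at import time.
def _example_custom_module():
    d_model, d_hidden = 768, 3072
    w1 = torch.randn(d_hidden, d_model)   # first-layer weight, (out_features, in_features)
    b1 = torch.randn(d_hidden)            # first-layer bias (passed via the w2_bias argument)
    w2 = torch.randn(d_hidden, d_model)   # second-layer weight, transposed inside the module
    mlp = CustomModule(w1_weight=w1, w2_bias=b1, w2_weight=w2, act='gelu')
    return mlp(torch.randn(2, 16, d_model))  # -> shape (2, 16, d_model)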

class CustomNormModule(nn.Module):
    """A simple two-layer type I MLP structure with input normalisation."""

    def __init__(self,
                 w1_weight=None,
                 w1_bias=None,
                 w2_weight=None,
                 centroid=None,
                 norm_weight=None,
                 norm_bias=None,
                 add_norm=True,
                 return_w1=False,
                 act='relu'):
        super().__init__()
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)
        # parameters of the preceding normalisation and the centring step
        self.centroid = centroid
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias
        if self.norm_bias is None:
            self.norm_bias = 0
        self.add_norm = add_norm
        self.return_w1 = return_w1
        self.linear1.weight = nn.Parameter(w1_weight)
        if w1_bias is not None:
            self.linear1.bias = nn.Parameter(w1_bias)
        self.linear2.weight = nn.Parameter(w2_weight.T)
        self.linear2.bias = nn.Parameter(
            torch.zeros_like(self.linear2.bias).to(w1_weight.dtype).cuda())

    def forward(self, x):
        # normalisation (part I): undo the norm's affine transform and rescale by sqrt(d)
        x = (x - self.norm_bias) / self.norm_weight / np.sqrt(self.centroid.shape[0])
        # centre the input and (optionally) project it onto the unit sphere
        x = x - self.centroid
        if self.add_norm:
            x = x / torch.norm(x, dim=-1)[:, :, None]
        w1_output = self.act(self.linear1(x))
        if self.return_w1:
            # expose the first-layer activations directly
            return w1_output
        w2_output = self.linear2(w1_output)
        return w2_output
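
# A hedged usage sketch for CustomNormModule (illustrative only): all tensors
# below are random placeholders, norm_weight / norm_bias stand in for a
# preceding LayerNorm's affine parameters, and a CUDA device is assumed
# because the constructor moves linear2's bias to .cuda().
def _example_custom_norm_module():
    d_model, d_hidden = 768, 3072
    module = CustomNormModule(
        w1_weight=torch.randn(d_hidden, d_model).cuda(),
        w1_bias=torch.randn(d_hidden).cuda(),
        w2_weight=torch.randn(d_hidden, d_model).cuda(),
        centroid=torch.randn(d_model).cuda(),
        norm_weight=torch.ones(d_model).cuda(),
        norm_bias=torch.zeros(d_model).cuda(),
        act='relu',
    )
    x = torch.randn(2, 16, d_model).cuda()  # forward() expects a 3D (batch, seq, dim) input
    return module(x)  # -> shape (2, 16, d_model)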

class ModifiedMLP(nn.Module):
    """Modified MLP structure: the original MLP plus a parallel custom branch."""

    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x):
        # Get the output from the original MLP
        o = self.original_mlp(x)
        # Add the CustomModule's output computed on the same input
        return o + self.custom_module(x)
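
# A hedged usage sketch for ModifiedMLP (illustrative only): in practice the
# wrapper would presumably replace an existing transformer MLP in place
# (e.g. model.transformer.h[i].mlp for a GPT-2-style model); here the
# "original" MLP is just a stand-in nn.Sequential with random weights.
def _example_modified_mlp():
    d_model, d_hidden = 768, 3072
    original = nn.Sequential(
        nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
    patch = CustomModule(w1_weight=torch.randn(d_hidden, d_model),
                         w2_bias=torch.randn(d_hidden),
                         w2_weight=torch.randn(d_hidden, d_model))
    wrapped = ModifiedMLP(original, patch)
    return wrapped(torch.randn(2, 16, d_model))  # original output + patch output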

class ModifieMambadMLP(nn.Module):
    """Modified MLP structure for Mamba-style blocks that take cache_params."""

    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x, cache_params=None):
        # Get the output from the original block, forwarding cache_params
        o = self.original_mlp(x, cache_params=cache_params)
        # Add the CustomModule's output computed on the same input
        return o + self.custom_module(x)
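
# A hedged usage sketch for ModifieMambadMLP (illustrative only): the wrapped
# block is assumed to accept a cache_params keyword, as Mamba mixer blocks in
# Hugging Face transformers do. The toy block below simply ignores it.
class _ToyMambaBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, cache_params=None):
        return self.proj(x)


def _example_modified_mamba_mlp():
    d_model, d_hidden = 768, 3072
    patch = CustomModule(w1_weight=torch.randn(d_hidden, d_model),
                         w2_bias=torch.randn(d_hidden),
                         w2_weight=torch.randn(d_hidden, d_model))
    wrapped = ModifieMambadMLP(_ToyMambaBlock(d_model), patch)
    return wrapped(torch.randn(2, 16, d_model), cache_params=None)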