# ActPRM/nets.py
# Copyright 2024 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deep networks."""
from copy import deepcopy
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

def init_weights(m):
    @torch.no_grad()
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Sample from N(mean, std), then resample any entries falling outside
        # [mean - 2 * std, mean + 2 * std] until all values are in range.
        t.normal_(mean, std)
        while True:
            cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
            if not cond.any():
                break
            w = torch.empty_like(t).normal_(mean, std)
            # Copy back in place so the parameter itself is updated, not a
            # detached intermediate tensor.
            t.copy_(torch.where(cond, w, t))
        return t

    if isinstance(m, (nn.Linear, EnsembleFC)):
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m.in_features)))
        if m.bias is not None:
            m.bias.data.fill_(0.0)
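
# Usage sketch (following how the models below apply the initializer; the layer
# here is purely illustrative):
#   layer = nn.Linear(128, 1)
#   layer.apply(init_weights)  # truncated-normal weights, zero bias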

def init_weights_uniform(m):
    if isinstance(m, (nn.Linear, EnsembleFC)):
        input_dim = m.in_features
        # uniform_ is the in-place initializer; torch.nn.init.uniform is deprecated.
        torch.nn.init.uniform_(m.weight, -1 / np.sqrt(input_dim), 1 / np.sqrt(input_dim))
        if m.bias is not None:
            m.bias.data.fill_(0.0)

class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return x * torch.sigmoid(x)
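
# Note (an observation, not from the original comments): Swish(x) = x * sigmoid(x)
# is the same function as nn.SiLU / F.silu in current PyTorch; the explicit
# module is presumably kept for compatibility with older versions.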

class MLPModel(nn.Module):
    def __init__(self, encoding_dim, hidden_dim=128, activation="relu") -> None:
        super().__init__()
        self.hidden_size = hidden_dim
        self.output_dim = 1
        self.nn1 = nn.Linear(encoding_dim, hidden_dim)
        self.nn2 = nn.Linear(hidden_dim, hidden_dim)
        self.nn_out = nn.Linear(hidden_dim, self.output_dim)
        self.apply(init_weights)
        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        # Flatten all parameters into a single vector.
        return torch.cat([p.view(-1) for p in self.parameters()])

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.nn1(encoding))
        x = self.activation(self.nn2(x))
        score = self.nn_out(x)
        return score

    def init(self):
        # Snapshot the initial parameters; kept on the same device as the model
        # so regularization() never mixes devices.
        self.init_params = self.get_params().detach().clone()

    def regularization(self):
        """Prior towards independent initialization."""
        return ((self.get_params() - self.init_params) ** 2).mean()
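
# Workflow sketch for the anchored regularizer (inferred from this class, not a
# documented API): init() must be called once before regularization(), since it
# snapshots the parameters the penalty is anchored to:
#   model = MLPModel(encoding_dim=768)
#   model.init()
#   loss = task_loss + reg_weight * model.regularization()
# `task_loss` and `reg_weight` are hypothetical names used only for illustration.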

class EnsembleFC(nn.Module):
    """Fully connected layer with one independent weight matrix per ensemble member."""

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        bias: bool = True,
        dtype=torch.float32,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        # Parameters are allocated here but left uninitialized; init_weights
        # (applied by the enclosing model) fills them in.
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features, dtype=dtype))
        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, out_features, dtype=dtype))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.weight.dtype)
        # Batched affine map: one matmul per ensemble member.
        wx = torch.einsum("eblh,ehm->eblm", input, self.weight)
        if self.bias is None:
            return wx
        return wx + self.bias[:, None, None, :]  # w times x + b
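
# Shape convention implied by the einsum in forward (an inference, not
# documented in the original):
#   fc = EnsembleFC(in_features=16, out_features=8, ensemble_size=4)
#   y = fc(torch.randn(4, 2, 5, 16))  # (ensemble, batch, len, in) -> (4, 2, 5, 8)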

def get_params(model):
    return torch.cat([p.view(-1) for p in model.parameters()])

class _EnsembleModel(nn.Module):
    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super().__init__()
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.output_dim = 1
        self.nn1 = EnsembleFC(encoding_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn2 = EnsembleFC(hidden_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn_out = EnsembleFC(hidden_dim, self.output_dim, num_ensemble, dtype=dtype)
        self.apply(init_weights)
        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.nn1(encoding))
        x = self.activation(self.nn2(x))
        score = self.nn_out(x)
        return score

class EnsembleModel(nn.Module):
    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super().__init__()
        self.encoding_dim = encoding_dim
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.model = _EnsembleModel(encoding_dim, num_ensemble, hidden_dim, activation, dtype)
        # Frozen copy of the freshly initialized model, used only as the anchor
        # for the regularization term below.
        self.reg_model = deepcopy(self.model)
        for param in self.reg_model.parameters():
            param.requires_grad = False

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        return self.model(encoding)

    def regularization(self):
        """Prior towards independent initialization."""
        model_params = get_params(self.model)
        reg_params = get_params(self.reg_model).detach()
        return ((model_params - reg_params) ** 2).mean()
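
if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original file): exercises the ensemble forward pass and the anchored
    # regularizer on CPU, under the shape convention noted after EnsembleFC.
    torch.manual_seed(0)
    net = EnsembleModel(encoding_dim=32, num_ensemble=4, hidden_dim=64)
    enc = torch.randn(4, 2, 5, 32)  # (ensemble, batch, length, encoding_dim)
    scores = net(enc)
    assert scores.shape == (4, 2, 5, 1)
    # Right after construction the model equals its frozen copy, so the
    # regularization term should be (near) zero.
    print("scores:", tuple(scores.shape), "reg:", float(net.regularization()))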