"""Deep networks.""" |
|
|
|
|
|
from copy import deepcopy |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
def init_weights(m):
    """Truncated-normal initialization for ``nn.Linear`` and ``EnsembleFC`` layers."""

    @torch.no_grad()
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Sample, then resample any entries lying outside +/- 2 std until none remain.
        t.data.normal_(mean, std)
        while True:
            cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
            if not torch.sum(cond):
                break
            w = torch.empty(t.shape, device=t.device, dtype=t.dtype)
            w.data.normal_(mean, std)
            # Copy back in place so the caller's parameter is actually updated.
            t.data.copy_(torch.where(cond, w, t))
        return t

    if type(m) is nn.Linear or isinstance(m, EnsembleFC):
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m.in_features)))
        if m.bias is not None:
            m.bias.data.fill_(0.0)


def init_weights_uniform(m):
    """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)] for layers exposing ``in_features``."""
    input_dim = m.in_features
    torch.nn.init.uniform_(m.weight, -1 / np.sqrt(input_dim), 1 / np.sqrt(input_dim))
    if m.bias is not None:
        m.bias.data.fill_(0.0)


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        x = x * torch.sigmoid(x)
        return x


class MLPModel(nn.Module):
    def __init__(self, encoding_dim, hidden_dim=128, activation="relu") -> None:
        super(MLPModel, self).__init__()
        self.hidden_size = hidden_dim
        self.output_dim = 1

        self.nn1 = nn.Linear(encoding_dim, hidden_dim)
        self.nn2 = nn.Linear(hidden_dim, hidden_dim)
        self.nn_out = nn.Linear(hidden_dim, self.output_dim)

        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        params = []
        for pp in list(self.parameters()):
            params.append(pp.view(-1))
        return torch.cat(params)

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.nn1(encoding))
        x = self.activation(self.nn2(x))
        score = self.nn_out(x)
        return score

    def init(self):
        """Snapshot the current parameters; ``regularization`` penalizes drift from them."""
        self.init_params = self.get_params().data.clone()
        if torch.cuda.is_available():
            # Keep the snapshot on the GPU so it matches the parameters after ``.cuda()``.
            self.init_params = self.init_params.cuda()

    def regularization(self):
        """Prior towards independent initialization."""
        return ((self.get_params() - self.init_params) ** 2).mean()


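# Illustrative usage of ``init()``/``regularization()`` (a hypothetical training
# snippet, not part of this module): snapshot the freshly initialized weights,
# then penalize drift from that snapshot alongside the task loss.
#
#   model = MLPModel(encoding_dim=32)
#   model.init()                                   # capture init_params
#   loss = criterion(model(batch), target) + reg_weight * model.regularization()
#
# ``criterion``, ``batch``, ``target`` and ``reg_weight`` are placeholders.

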
class EnsembleFC(nn.Module):
    """Fully connected layer applied independently by each ensemble member."""

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        bias: bool = True,
        dtype=torch.float32,
    ) -> None:
        super(EnsembleFC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        # Parameters are left uninitialized here; ``init_weights`` (applied by the
        # models below) fills them in.
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features, dtype=dtype))
        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, out_features, dtype=dtype))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.weight.dtype)
        # Per-member matmul: [e, b, l, h] x [e, h, m] -> [e, b, l, m].
        wx = torch.einsum("eblh,ehm->eblm", input, self.weight)
        if self.bias is None:
            return wx
        return torch.add(wx, self.bias[:, None, None, :])


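# Shape sketch for ``EnsembleFC`` (illustrative only; the sizes below are made up):
# the einsum in ``forward`` expects a 4-D input [ensemble_size, batch, length,
# in_features] and applies a separate weight matrix per ensemble member.
#
#   layer = EnsembleFC(in_features=8, out_features=4, ensemble_size=3)
#   layer.apply(init_weights)
#   out = layer(torch.randn(3, 16, 5, 8))          # -> [3, 16, 5, 4]

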
def get_params(model):
    """Flatten all parameters of ``model`` into a single 1-D tensor."""
    return torch.cat([p.view(-1) for p in model.parameters()])


class _EnsembleModel(nn.Module):
    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super(_EnsembleModel, self).__init__()
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.output_dim = 1

        self.nn1 = EnsembleFC(encoding_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn2 = EnsembleFC(hidden_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn_out = EnsembleFC(hidden_dim, self.output_dim, num_ensemble, dtype=dtype)

        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.nn1(encoding))
        x = self.activation(self.nn2(x))
        score = self.nn_out(x)
        return score

    def regularization(self):
        """Prior towards independent initialization.

        Requires ``self.init_params`` to be set externally (as in ``MLPModel.init``);
        the ``EnsembleModel`` wrapper below provides the regularizer used in practice.
        """
        return ((get_params(self) - self.init_params) ** 2).mean()


class EnsembleModel(nn.Module):
    """Ensemble of MLP scorers regularized towards a frozen copy of their initialization."""

    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super(EnsembleModel, self).__init__()
        self.encoding_dim = encoding_dim
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.model = _EnsembleModel(encoding_dim, num_ensemble, hidden_dim, activation, dtype)
        self.reg_model = deepcopy(self.model)

        # The frozen copy only serves as the regularization target.
        for param in self.reg_model.parameters():
            param.requires_grad = False

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        return self.model(encoding)

    def regularization(self):
        """Prior towards independent initialization."""
        model_params = get_params(self.model)
        reg_params = get_params(self.reg_model).detach()
        return ((model_params - reg_params) ** 2).mean()
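

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch; the dimensions below are arbitrary).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    mlp = MLPModel(encoding_dim=32).to(device)
    mlp.init()                                     # snapshot weights for the regularizer
    enc = torch.randn(4, 32, device=device)        # [batch, encoding_dim]
    print(mlp(enc).shape, mlp.regularization().item())

    ens = EnsembleModel(encoding_dim=32, num_ensemble=3).to(device)
    ens_in = enc.unsqueeze(0).unsqueeze(2).expand(3, -1, -1, -1)  # [ensemble, batch, 1, encoding_dim]
    print(ens(ens_in).shape, ens.regularization().item())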