BAPULM / model /model.py
Moreza009's picture
Upload folder using huggingface_hub
15c5ffb verified
import torch
import torch.nn as nn
class BAPULM(nn.Module):
    """Two-branch feed-forward regressor over a pair of embedding vectors.

    Each input is projected to a common width by its own linear layer and
    passed through ReLU. The two projections are concatenated along dim 1,
    batch-normalised, and fed through a small MLP that ends in a single
    output unit.

    NOTE(review): judging by the attribute names, ``prot`` and ``mol`` are
    presumably protein and molecule embeddings from upstream encoders --
    confirm against the calling code.

    Submodule attribute names and their creation order are kept stable so
    that previously saved ``state_dict`` checkpoints still load and seeded
    initialisation is unchanged when the default sizes are used.

    Args:
        prot_dim: Size of the last dimension of the ``prot`` input.
        mol_dim: Size of the last dimension of the ``mol`` input.
        proj_dim: Width each input is projected to before concatenation;
            the fused representation therefore has ``2 * proj_dim`` features.
        dropout: Drop probability of the shared dropout layer.
    """

    def __init__(
        self,
        *,
        prot_dim: int = 1024,
        mol_dim: int = 768,
        proj_dim: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Width after concatenating the two projected branches; the batch
        # norm and the first MLP layer must both agree with it.
        fused_dim = 2 * proj_dim

        self.prot_linear = nn.Linear(prot_dim, proj_dim)
        self.mol_linear = nn.Linear(mol_dim, proj_dim)
        self.norm = nn.BatchNorm1d(fused_dim, eps=0.001, momentum=0.1, affine=True)
        # One dropout module is reused at two points in ``forward``.
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(fused_dim, 768)
        self.linear2 = nn.Linear(768, 512)
        self.linear3 = nn.Linear(512, 32)
        self.final_linear = nn.Linear(32, 1)

    def forward(self, prot: torch.Tensor, mol: torch.Tensor) -> torch.Tensor:
        """Run both inputs through the network and return the raw output.

        Args:
            prot: Tensor whose last dimension is ``prot_dim``; assumed to be
                2-D ``(batch, prot_dim)`` since the branches are joined on
                dim 1 -- TODO confirm with callers.
            mol: Tensor whose last dimension is ``mol_dim``, with the same
                leading shape as ``prot``.

        Returns:
            Tensor with a trailing dimension of size 1 (``(batch, 1)`` for
            2-D inputs). No activation is applied to the final layer.

        Note:
            In training mode ``BatchNorm1d`` needs more than one value per
            channel, so a batch of size 1 raises; call ``.eval()`` for
            single-sample inference.
        """
        prot_output = torch.relu(self.prot_linear(prot))
        mol_output = torch.relu(self.mol_linear(mol))

        # Fuse the two branches feature-wise, then normalise and regularise.
        combined_output = torch.cat((prot_output, mol_output), dim=1)
        combined_output = self.norm(combined_output)
        combined_output = self.dropout(combined_output)

        x = torch.relu(self.linear1(combined_output))
        x = torch.relu(self.linear2(x))
        x = self.dropout(x)
        x = torch.relu(self.linear3(x))
        return self.final_linear(x)