rokati commited on
Commit
2a9b828
verified
1 Parent(s): a32706d

Upload model_architecture.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_architecture.py +34 -0
model_architecture.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class MLP(nn.Module):
5
+ def __init__(self, input_dim, hidden_dims=[128, 64, 32], dropout_rate=0.3):
6
+ """
7
+ Multi-Layer Perceptron for xG prediction
8
+
9
+ Args:
10
+ input_dim: Number of input features
11
+ hidden_dims: List of hidden layer dimensions
12
+ dropout_rate: Dropout probability
13
+ """
14
+ super(MLP, self).__init__()
15
+
16
+ layers = []
17
+ prev_dim = input_dim
18
+
19
+ # Build hidden layers
20
+ for hidden_dim in hidden_dims:
21
+ layers.append(nn.Linear(prev_dim, hidden_dim))
22
+ layers.append(nn.ReLU())
23
+ layers.append(nn.BatchNorm1d(hidden_dim))
24
+ layers.append(nn.Dropout(dropout_rate))
25
+ prev_dim = hidden_dim
26
+
27
+ # Output layer
28
+ layers.append(nn.Linear(prev_dim, 1))
29
+ layers.append(nn.Sigmoid())
30
+
31
+ self.network = nn.Sequential(*layers)
32
+
33
+ def forward(self, x):
34
+ return self.network(x)