NN_Arch.py 646 B

1234567891011121314151617181920212223
import torch
import torch.nn as nn
import torch.nn.functional as F
  # MLP architecture with one hidden layer.
class MLP(nn.Module):
    """Multi-layer perceptron with a single hidden layer.

    Pipeline: flatten non-batch dims -> Linear(input_dim, hidden) -> Dropout
    -> ReLU -> Linear(hidden, out_dim) -> Softmax over dim 1.

    Args:
        input_dim: Number of features per sample after all non-batch
            dimensions are flattened (e.g. C*H*W for a (N, C, H, W) batch).
        hidden: Width of the hidden layer.
        out_dim: Number of output units.
        dropout_p: Dropout probability applied after the first linear layer.
            Defaults to 0.5, the ``nn.Dropout()`` default used previously.
    """

    def __init__(self, input_dim, hidden, out_dim, dropout_p=0.5):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden)
        self.linear2 = nn.Linear(hidden, out_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_p)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a batch ``x`` to per-row softmax probabilities.

        ``x`` must have a leading batch dimension. Every remaining dimension is
        flattened, so ``(N, D)``, ``(N, H, W)`` and ``(N, C, H, W)`` inputs all
        work as long as the per-sample element count equals ``input_dim``.

        Returns:
            Tensor of shape ``(N, out_dim)``; each row is a softmax over dim 1.

        NOTE(review): the output is already softmax-normalised. If this model
        is trained with ``nn.CrossEntropyLoss`` (which applies log-softmax
        internally), softmax ends up applied twice. Confirm which loss the
        training loop uses.
        """
        # flatten(start_dim=1) collapses all non-batch dims for any rank and,
        # unlike .view, also accepts non-contiguous tensors. A view sized by
        # shape[1]*shape[-2]*shape[-1] is only correct for 4-D input.
        x = torch.flatten(x, start_dim=1)
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        return self.softmax(x)