FedNets.py

import torch
import torch.nn as nn
import torch.nn.functional as F


# MLP with one hidden layer
class MLP(nn.Module):
    def __init__(self, input_dim, hidden, out_dim):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden)
        self.linear2 = nn.Linear(hidden, out_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Flatten (N, C, H, W) input to (N, C*H*W)
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        return self.softmax(x)
# CNN for MNIST
class CNN_Mnist(nn.Module):
    def __init__(self, args):
        super(CNN_Mnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout_2d = nn.Dropout2d()
        # 20 channels * 4 * 4 spatial after two conv/pool stages on 28x28 input
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2)  # 28x28 -> 24x24 (conv) -> 12x12 (pool)
        x = F.relu(x)
        x = F.max_pool2d(self.dropout_2d(self.conv2(x)), 2)  # 12x12 -> 8x8 -> 4x4
        x = F.relu(x)
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# CNN for CIFAR
class CNN_Cifar(nn.Module):
    def __init__(self, args):
        super(CNN_Cifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels * 5 * 5 spatial after two conv/pool stages on 32x32 input
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # 32x32 -> 28x28
        x = self.pool(x)           # -> 14x14
        x = F.relu(self.conv2(x))  # -> 10x10
        x = self.pool(x)           # -> 5x5
        x = x.view(-1, 16 * 5 * 5)  # Flatten to match fc1
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
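

# Minimal smoke-test sketch (not part of the original file): builds each
# network and checks output shapes. The `args` namespace with `num_channels`
# and `num_classes` is an assumption about the training script's argparse
# config; the values below are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(num_channels=1, num_classes=10)

    mnist_batch = torch.randn(4, 1, 28, 28)  # (N, C, H, W)
    cifar_batch = torch.randn(4, 3, 32, 32)

    mlp = MLP(input_dim=28 * 28, hidden=64, out_dim=10)
    print(mlp(mnist_batch).shape)        # torch.Size([4, 10])

    cnn_mnist = CNN_Mnist(args)
    print(cnn_mnist(mnist_batch).shape)  # torch.Size([4, 10])

    cnn_cifar = CNN_Cifar(args)
    print(cnn_cifar(cifar_batch).shape)  # torch.Size([4, 10])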