瀏覽代碼

First commit

AshwinRJ 6 年之前
當前提交
7ad2e4503b
共有 2 個文件被更改,包括 24 次插入0 次删除
  1. 23 0
      NN_Arch.py
  2. 1 0
      README.md

+ 23 - 0
NN_Arch.py

@@ -0,0 +1,23 @@
+import torch
+import torch.nn as nn
+import torch.nn.Functional as F
+# MLP Arch with 1 Hidden layer
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim, hidden, out_dim):
+
+        super(MLP, self).__init__()
+        self.linear1 = nn.Linear(input_dim, hidden)
+        self.linear2 = nn.Linear(hidden, out_dim)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout()
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(self, x):
+        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
+        x = self.linear1(x)
+        x = self.dropout(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return self.softmax(x)

+ 1 - 0
README.md

@@ -0,0 +1 @@
+# Federated-Learning