Browse Source

Added CNN Arch for CIFAR

AshwinRJ 6 years ago
parent
commit
e9f57c2877
1 changed files with 29 additions and 0 deletions
  1. 29 0
      NN_Arch.py

+ 29 - 0
NN_Arch.py

@@ -5,6 +5,7 @@ import torch.nn.Functional as F
 
 # MLP Arch with 1 Hidden layer
 
+
 class MLP(nn.Module):
     def __init__(self, input_dim, hidden, out_dim):
 
@@ -23,6 +24,7 @@ class MLP(nn.Module):
         x = self.linear2(x)
         return self.softmax(x)
 
+
 # CNN Arch for MNIST
 
 
@@ -46,3 +48,30 @@ class CNN_Mnist(nn.Module):
         x = F.dropout(x, training=self.training)
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
+
+
+# CNN Arch -- CIFAR
+
+
+class CNN_Cifar(nn.Module):
+
+    def __init__(self, args):
+
+        super(CNN_Cifar, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16*5*5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, args.num_classes)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = self.pool(x)
+        x = F.relu(self.conv2(x))
+        x = self.pool(x)
+        x = x.view(-1, 16*5*5)  # Dim of fc1
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return F.log_softmax(x, dim=1)