Bladeren bron

Added test inference

AshwinRJ 4 jaren geleden
bovenliggende
commit
2b46351813
1 gewijzigde bestanden met toevoegingen van 41 en 10 verwijderingen
  1. 41 10
      src/update.py

+ 41 - 10
src/update.py

@@ -39,19 +39,19 @@ class LocalUpdate(object):
         and user indexes.
         """
         # split indexes for train, validation, and test (80, 10, 10)
-        idxs_train = idxs[:(0.8*len(idxs))]
-        idxs_val = idxs[(0.8*len(idxs)):(0.9*len(idxs))]
-        idxs_test = idxs[(0.9*len(idxs)):]
+        idxs_train = idxs[:int(0.8*len(idxs))]
+        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
+        idxs_test = idxs[int(0.9*len(idxs)):]
 
         trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                  batch_size=self.args.local_bs, shuffle=True)
         validloader = DataLoader(DatasetSplit(dataset, idxs_val),
-                                 batch_size=int(len(idxs_val)/10), shuffle=True)
+                                 batch_size=int(len(idxs_val)/10), shuffle=False)
         testloader = DataLoader(DatasetSplit(dataset, idxs_test),
-                                batch_size=int(len(idxs_test)/10), shuffle=True)
+                                batch_size=int(len(idxs_test)/10), shuffle=False)
         return trainloader, validloader, testloader
 
-    def update_weights(self, model):
+    def update_weights(self, model, global_round):
         # Set mode to train model
         model.train()
         epoch_loss = []
@@ -75,9 +75,10 @@ class LocalUpdate(object):
                 loss.backward()
                 optimizer.step()
 
-                if self.args.verbose and batch_idx % 10 == 0:
-                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                        iter, batch_idx * len(images), len(self.trainloader.dataset),
+                if self.args.verbose and (batch_idx % 10 == 0):
+                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                        global_round, iter, batch_idx * len(images),
+                        len(self.trainloader.dataset),
                         100. * batch_idx / len(self.trainloader), loss.item()))
                 self.logger.add_scalar('loss', loss.item())
                 batch_loss.append(loss.item())
@@ -107,4 +108,34 @@ class LocalUpdate(object):
             total += len(labels)
 
         accuracy = correct/total
-        return accuracy, loss.item()
+        return accuracy, loss
+
+
+def test_inference(args, model, test_dataset):
+    """ Returns the test accuracy and loss.
+    """
+
+    model.eval()
+    loss, total, correct = 0.0, 0.0, 0.0
+
+    device = 'cuda' if args.gpu else 'cpu'
+    criterion = nn.NLLLoss().to(device)
+    testloader = DataLoader(test_dataset, batch_size=128,
+                            shuffle=False)
+
+    for batch_idx, (images, labels) in enumerate(testloader):
+        images, labels = images.to(device), labels.to(device)
+
+        # Inference
+        outputs = model(images)
+        batch_loss = criterion(outputs, labels)
+        loss += batch_loss.item()
+
+        # Prediction
+        _, pred_labels = torch.max(outputs, 1)
+        pred_labels = pred_labels.view(-1)
+        correct += torch.sum(torch.eq(pred_labels, labels)).item()
+        total += len(labels)
+
+    accuracy = correct/total
+    return accuracy, loss