AshwinRJ 6 years ago
parent
commit
acb916a8d7
1 changed file with 5 additions and 39 deletions

main_fedavg.py  +5 -39

@@ -24,29 +24,6 @@ import matplotlib.pyplot as plt
 matplotlib.use('Agg')
 
 
-# def test(net_g, data_loader, args):
-#     # testing
-#     test_loss = 0
-#     correct = 0
-#     # Test for the below line
-#     l = len(data_loader)
-#
-#     for idx, (data, target) in enumerate(data_loader):
-#         if args.gpu != -1:
-#             data, target = data.cuda(), target.cuda()
-#         data, target = autograd.Variable(data), autograd.Variable(target)
-#         log_probs = net_g(data)
-#         test_loss += F.nll_loss(log_probs, target, size_average=False).data[0]  # sum up batch loss
-#         y_pred = log_probs.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
-#         correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
-#
-#     test_loss /= len(data_loader.dataset)
-#     print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
-#         test_loss, correct, len(data_loader.dataset),
-#         100. * correct / len(data_loader.dataset)))
-#     return correct, test_loss
-
-
 if __name__ == '__main__':
     # parse args
     args = args_parser()
@@ -84,14 +61,12 @@ if __name__ == '__main__':
     img_size = dataset_train[0][0].shape
 
     # BUILD MODEL
-    # Using same models for MNIST and FashionMNIST
     if args.model == 'cnn' and args.dataset == 'mnist':
         if args.gpu != -1:
             torch.cuda.set_device(args.gpu)
             net_glob = CNNMnist(args=args).cuda()
         else:
             net_glob = CNNMnist(args=args)
-
     elif args.model == 'cnn' and args.dataset == 'cifar':
         if args.gpu != -1:
             torch.cuda.set_device(args.gpu)
@@ -137,16 +112,16 @@ if __name__ == '__main__':
         # copy weight to net_glob
         net_glob.load_state_dict(w_glob)
 
-        # print loss
+        # print loss after every round
         loss_avg = sum(loss_locals) / len(loss_locals)
-        if args.epochs % 10 == 0:
+        if iter % 1 == 0:
             print('\nTrain loss:', loss_avg)
         loss_train.append(loss_avg)
 
         # Calculate avg accuracy over all users at every epoch
         list_acc, list_loss = [], []
         net_glob.eval()
-        for c in tqdm(range(args.num_users)):
+        for c in range(args.num_users):
             net_local = LocalUpdate(args=args, dataset=dataset_train,
                                     idxs=dict_users[c], tb=summary)
             acc, loss = net_local.test(net=net_glob)
@@ -171,15 +146,6 @@ if __name__ == '__main__':
     plt.xlabel('Communication Rounds')
     plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_acc.png'.format(args.dataset,
                                                                 args.model, args.epochs, args.frac, args.iid))
-    # testing (original)
-    list_acc, list_loss = [], []
-    net_glob.eval()
-    for c in tqdm(range(args.num_users)):
-        net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
-        acc, loss = net_local.test(net=net_glob)
-        list_acc.append(acc)
-        list_loss.append(loss)
-    print("Final Average Accuracy after {} epochs: {:.2f}%".format(
-        args.epochs, (100.*sum(list_acc)/len(list_acc))))
 
-    print("Final Average Accuracy after {} epochs: {:.2f}%".format(args.epochs, train_accuracy[-1])
+    print("Final Average Accuracy after {} epochs: {:.2f}%".format(
+        args.epochs, 100.*train_accuracy[-1]))
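
Context for the hunks above: the per-round training loss printed each iteration is simply the arithmetic mean of the local client losses (loss_avg = sum(loss_locals) / len(loss_locals)), and w_glob, which is loaded back into net_glob every round, is the element-wise average of the clients' model weights. A minimal sketch of that aggregation step is below, assuming each entry of w_locals is a PyTorch state_dict with identical keys; the helper name average_weights is illustrative and does not appear in this diff, which only shows the result being loaded via net_glob.load_state_dict(w_glob).

import copy
import torch

def average_weights(w_locals):
    # FedAvg aggregation: element-wise mean over the clients' state_dicts.
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[key] += w[key]
        w_avg[key] = torch.div(w_avg[key], len(w_locals))
    return w_avg

# Assumed usage inside the training loop:
#   w_glob = average_weights(w_locals)
#   net_glob.load_state_dict(w_glob)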