Browse Source

Major restructuring

AshwinRJ 4 years ago
parent
commit
5f6673f73a
2 changed files with 89 additions and 95 deletions
  1. src/main_fedavg.py (+87 −94)
  2. src/options.py (+2 −1)

+ 87 - 94
src/main_fedavg.py

@@ -2,137 +2,139 @@
 # -*- coding: utf-8 -*-
 # Python version: 3.6
 
-import matplotlib
-import matplotlib.pyplot as plt
-# matplotlib.use('Agg')
 
 import os
 import copy
+import time
+import pickle
 import numpy as np
-from torchvision import datasets, transforms
 from tqdm import tqdm
-import torch
-import torch.nn.functional as F
-from torch import autograd
 
-from tensorboardX import SummaryWriter
+import torch
+from tensorboardX import SummaryWriter
 
-from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_unequal
 from options import args_parser
 from Update import LocalUpdate
-from FedNets import MLP, CNNMnist, CNNCifar
+from FedNets import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
 from averaging import average_weights
-import pickle
+from utils import get_dataset
 
 
 if __name__ == '__main__':
-    # parse args
-    args = args_parser()
+    start_time = time.time()
 
     # define paths
     path_project = os.path.abspath('..')
+    summary = SummaryWriter('local')
 
-    summary = SummaryWriter('local')
-
-    # load dataset and split users
-    if args.dataset == 'mnist':
-        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
-                                       transform=transforms.Compose([
-                                           transforms.ToTensor(),
-                                           transforms.Normalize((0.1307,), (0.3081,))
-                                       ]))
-        # sample users
-        if args.iid:
-            dict_users = mnist_iid(dataset_train, args.num_users)
-        elif args.unequal:
-            dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
-        else:
-            dict_users = mnist_noniid(dataset_train, args.num_users)
-
-    elif args.dataset == 'cifar':
-        transform = transforms.Compose(
-            [transforms.ToTensor(),
-             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-        dataset_train = datasets.CIFAR10(
-            '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
-        if args.iid:
-            dict_users = cifar_iid(dataset_train, args.num_users)
-        else:
-            dict_users = cifar_noniid(dataset_train, args.num_users)
-    else:
-        exit('Error: unrecognized dataset')
-    img_size = dataset_train[0][0].shape
+    args = args_parser()
+    if args.gpu:
+        # --gpu is parsed without type=, so cast the string ID for set_device
+        torch.cuda.set_device(int(args.gpu))
+    # define device on both paths so global_model.to(device) works on CPU too
+    device = 'cuda' if args.gpu else 'cpu'
+
+    # load dataset and user groups
+    train_dataset, test_dataset, user_groups = get_dataset(args)
 
     # BUILD MODEL
-    if args.model == 'cnn' and args.dataset == 'mnist':
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = CNNMnist(args=args).cuda()
-        else:
-            net_glob = CNNMnist(args=args)
-    elif args.model == 'cnn' and args.dataset == 'cifar':
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = CNNCifar(args=args).cuda()
-        else:
-            net_glob = CNNCifar(args=args)
+    if args.model == 'cnn':
+        # Convolutional neural network
+        if args.dataset == 'mnist':
+            global_model = CNNMnist(args=args)
+        elif args.dataset == 'fmnist':
+            global_model = CNNFashion_Mnist(args=args)
+        elif args.dataset == 'cifar':
+            global_model = CNNCifar(args=args)
+
     elif args.model == 'mlp':
+        # Multi-layer perceptron
+        img_size = train_dataset[0][0].shape
         len_in = 1
         for x in img_size:
             len_in *= x
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
-        else:
-            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
+        global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
     else:
         exit('Error: unrecognized model')
-    print(net_glob)
-    net_glob.train()
+
+    # Set the model to train and send it to device.
+    global_model.to(device)
+    global_model.train()
+    print(global_model)
 
     # copy weights
-    w_glob = net_glob.state_dict()
+    global_weights = global_model.state_dict()
 
     # training
-    train_loss = []
-    train_accuracy = []
+    train_loss, train_accuracy = [], []
+    val_acc_list, net_list = [], []
     cv_loss, cv_acc = [], []
+    print_every = 20
     val_loss_pre, counter = 0, 0
-    net_best = None
-    val_acc_list, net_list = [], []
-    for iter in tqdm(range(args.epochs)):
-        w_locals, loss_locals = [], []
+
+    for epoch in tqdm(range(args.epochs)):
+        global_model.train()
+        local_weights, local_losses = [], []
+
         m = max(int(args.frac * args.num_users), 1)
         idxs_users = np.random.choice(range(args.num_users), m, replace=False)
+
         for idx in idxs_users:
-            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
-            w, loss = local.update_weights(net=copy.deepcopy(net_glob))
-            w_locals.append(copy.deepcopy(w))
-            loss_locals.append(copy.deepcopy(loss))
+            local_model = LocalUpdate(args=args, dataset=train_dataset,
+                                      idxs=user_groups[idx], logger=summary)
+            w, loss = local_model.update_weights(net=copy.deepcopy(global_model))
+            local_weights.append(copy.deepcopy(w))
+            local_losses.append(copy.deepcopy(loss))
+
         # update global weights
-        w_glob = average_weights(w_locals)
+        global_weights = average_weights(local_weights)
 
-        # copy weight to net_glob
-        net_glob.load_state_dict(w_glob)
+        # copy weight to global model
+        global_model.load_state_dict(global_weights)
 
-        # print loss after every 'i' rounds
-        print_every = 5
-        loss_avg = sum(loss_locals) / len(loss_locals)
-        if iter % print_every == 0:
+        # print loss after every 20 rounds
+        loss_avg = sum(local_losses) / len(local_losses)
+        if (epoch+1) % print_every == 0:
             print('\nTrain loss:', loss_avg)
         train_loss.append(loss_avg)
 
-        # Calculate avg accuracy over all users at every epoch
+        # Calculate avg training accuracy over all users at every epoch
         list_acc, list_loss = [], []
-        net_glob.eval()
+        global_model.eval()
         for c in range(args.num_users):
-            net_local = LocalUpdate(args=args, dataset=dataset_train,
-                                    idxs=dict_users[c], tb=summary)
-            acc, loss = net_local.test(net=net_glob)
+            local_model = LocalUpdate(args=args, dataset=train_dataset,
+                                      idxs=user_groups[c], logger=summary)
+            acc, loss = local_model.inference(net=global_model)
             list_acc.append(acc)
             list_loss.append(loss)
         train_accuracy.append(sum(list_acc)/len(list_acc))
 
+    # Test inference after completion of training
+    test_acc, test_loss = [], []
+    for c in tqdm(range(args.num_users)):
+        local_model = LocalUpdate(args=args, dataset=test_dataset,
+                                  idxs=user_groups[c], logger=summary)
+        acc, loss = local_model.test(net=global_model)
+        test_acc.append(acc)
+        test_loss.append(loss)
+
+    print("Final Average Train Accuracy after {} epochs: {:.2f}%".format(
+        args.epochs, 100.*train_accuracy[-1]))
+
+    print("Final Average Test Accuracy after {} epochs: {:.2f}%".format(
+        args.epochs, (100.*sum(test_acc)/len(test_acc))))
+
+    # Saving the objects train_loss and train_accuracy:
+    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
+        format(args.dataset, args.model, args.epochs, args.frac,
+               args.iid, args.local_ep, args.local_bs)
+    with open(file_name, 'wb') as f:
+        pickle.dump([train_loss, train_accuracy], f)
+
+    print('Total Time: {0:0.4f}'.format(time.time()-start_time))
+
+    # PLOTTING (optional)
+    # import matplotlib
+    # import matplotlib.pyplot as plt
+    # matplotlib.use('Agg')
+
     # Plot Loss curve
     # plt.figure()
     # plt.title('Training Loss vs Communication rounds')
@@ -150,12 +152,3 @@ if __name__ == '__main__':
     # plt.xlabel('Communication Rounds')
     # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
     #                                                                             args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
-
-    print("Final Average Accuracy after {} epochs: {:.2f}%".format(
-        args.epochs, 100.*train_accuracy[-1]))
-
-# Saving the objects train_loss and train_accuracy:
-file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
-                                                                            args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
-with open(file_name, 'wb') as f:
-    pickle.dump([train_loss, train_accuracy], f)
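The dataset-loading logic deleted above now lives behind `utils.get_dataset`, which this commit imports but does not show. Below is a minimal sketch of what it plausibly contains, reconstructed from the removed lines (MNIST branch only); the `train=False` test split and the exact signature are assumptions, not code from the commit:

    from torchvision import datasets, transforms
    from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal

    def get_dataset(args):
        """Return train/test datasets and a dict mapping each user id to the
        indices of its local samples. Sketch of the MNIST branch only."""
        if args.dataset == 'mnist':
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))])
            train_dataset = datasets.MNIST('../data/mnist/', train=True,
                                           download=True, transform=transform)
            # assumed: the refactor also loads a held-out split, since the new
            # main script runs test inference on test_dataset
            test_dataset = datasets.MNIST('../data/mnist/', train=False,
                                          download=True, transform=transform)
            if args.iid:
                user_groups = mnist_iid(train_dataset, args.num_users)
            elif args.unequal:
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                user_groups = mnist_noniid(train_dataset, args.num_users)
            return train_dataset, test_dataset, user_groups
        exit('Error: unrecognized dataset')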

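`averaging.average_weights` is likewise not part of this diff. As a reference point, here is a minimal sketch of the equal-weight FedAvg aggregation that `average_weights(local_weights)` in the new training loop is expected to perform, assuming every client state_dict shares the same keys:

    import copy
    import torch

    def average_weights(local_weights):
        """Element-wise average of a list of model state_dicts (plain FedAvg
        with equal client weighting -- an assumption, not code from this commit)."""
        avg = copy.deepcopy(local_weights[0])
        for key in avg.keys():
            for w in local_weights[1:]:
                avg[key] += w[key]
            avg[key] = torch.div(avg[key], len(local_weights))
        return avg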
+ 2 - 1
src/options.py

@@ -41,7 +41,8 @@ def args_parser():
     # other arguments
     parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
     parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
-    parser.add_argument('--gpu', type=int, default=1, help="GPU ID")
+    parser.add_argument('--gpu', default=None,
+                        help="To use cuda, set to a specific GPU ID. "
+                             "Default set to use CPU.")
     parser.add_argument('--iid', type=int, default=0,
                         help='whether i.i.d or not: 1 for iid, 0 for non-iid')
     parser.add_argument('--unequal', type=int, default=0,
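A side effect of dropping `type=int` is that argparse now returns a string (or `None` when the flag is omitted), so `if args.gpu:` in main_fedavg.py is a truthiness check on a string. That is also what makes `--gpu 0` work: the string '0' is truthy even though the integer 0 would not be. A small self-contained sketch of the behavior (flag values are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default=None,
                        help="To use cuda, set to a specific GPU ID. "
                             "Default set to use CPU.")

    args = parser.parse_args([])                 # flag omitted
    assert args.gpu is None and not args.gpu     # -> CPU path

    args = parser.parse_args(['--gpu', '0'])     # untyped -> parsed as a string
    assert args.gpu == '0' and bool(args.gpu)    # '0' is truthy -> CUDA path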