
Add unequal data split support for non-IID sampling

AshwinRJ 6 years ago
parent
commit
00c0211e4c

BIN
.DS_Store


BIN
Federated_Avg/.DS_Store


+ 21 - 12
Federated_Avg/main_fedavg.py

@@ -2,6 +2,9 @@
 # -*- coding: utf-8 -*-
 # Python version: 3.6
 
+import matplotlib
+import matplotlib.pyplot as plt
+# matplotlib.use('Agg')
 
 import os
 import copy
@@ -11,17 +14,15 @@ from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 from torch import autograd
+
 from tensorboardX import SummaryWriter
 
-from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
+from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_unequal
 from options import args_parser
 from Update import LocalUpdate
 from FedNets import MLP, CNNMnist, CNNCifar
 from averaging import average_weights
-
-import matplotlib
-import matplotlib.pyplot as plt
-matplotlib.use('Agg')
+import pickle
 
 
 if __name__ == '__main__':
@@ -43,6 +44,8 @@ if __name__ == '__main__':
         # sample users
         if args.iid:
             dict_users = mnist_iid(dataset_train, args.num_users)
+        elif args.unequal:
+            dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
         else:
             dict_users = mnist_noniid(dataset_train, args.num_users)
 
@@ -91,7 +94,7 @@ if __name__ == '__main__':
     w_glob = net_glob.state_dict()
 
     # training
-    loss_train = []
+    train_loss = []
     train_accuracy = []
     cv_loss, cv_acc = [], []
     val_loss_pre, counter = 0, 0
@@ -117,7 +120,7 @@ if __name__ == '__main__':
         loss_avg = sum(loss_locals) / len(loss_locals)
         if iter % print_every == 0:
             print('\nTrain loss:', loss_avg)
-        loss_train.append(loss_avg)
+        train_loss.append(loss_avg)
 
         # Calculate avg accuracy over all users at every epoch
         list_acc, list_loss = [], []
@@ -133,11 +136,11 @@ if __name__ == '__main__':
     # Plot Loss curve
     plt.figure()
     plt.title('Training Loss vs Communication rounds')
-    plt.plot(range(len(loss_train)), loss_train, color='r')
+    plt.plot(range(len(train_loss)), train_loss, color='r')
     plt.ylabel('Training loss')
     plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_loss.png'.format(args.dataset,
-                                                                         args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
+                                                                                 args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
 
     # Plot Average Accuracy vs Communication rounds
     plt.figure()
@@ -145,8 +148,14 @@ if __name__ == '__main__':
     plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
     plt.ylabel('Average Accuracy')
     plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_acc.png'.format(args.dataset,
-                                                                        args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
+                                                                                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
 
     print("Final Average Accuracy after {} epochs: {:.2f}%".format(
         args.epochs, 100.*train_accuracy[-1]))
+
+# Saving the objects train_loss and train_accuracy:
+file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
+                                                                            args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
+with open(file_name, 'wb') as f:
+    pickle.dump([train_loss, train_accuracy], f)
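
A side note on the new pickle dump: the saved lists can be reloaded later for offline plotting. A minimal sketch, assuming a finished run whose arguments produced the file name below (the concrete path depends on the actual settings, following the same format string as file_name above):

    import pickle
    import matplotlib.pyplot as plt

    # Hypothetical file name; substitute the values from the actual run.
    path = '../save/objects/mnist_cnn_10_C[0.1]_iid[0]_E[10]_B[10].pkl'
    with open(path, 'rb') as f:
        train_loss, train_accuracy = pickle.load(f)

    plt.plot(range(len(train_loss)), train_loss)
    plt.xlabel('Communication Rounds')
    plt.ylabel('Training loss')
    plt.show()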

+ 4 - 2
Federated_Avg/options.py

@@ -29,11 +29,13 @@ def args_parser():
                         help="Whether use max pooling rather than strided convolutions")
 
     # other arguments
-    parser.add_argument('--dataset', type=str, default='cifar', help="name of dataset")
+    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
     parser.add_argument('--iid', type=int, default=0,
                         help='whether i.i.d or not, 1 for iid, 0 for non-iid')
+    parser.add_argument('--unequal', type=int, default=0,
+                        help='whether to use unequal data splits among clients in the non-i.i.d setting (1 for unequal splits, 0 for equal splits)')
     parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
-    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
+    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imgs")
     parser.add_argument('--gpu', type=int, default=1, help="GPU ID")
     parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
     parser.add_argument('--verbose', type=int, default=1,
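
With the new flag, an unequal non-i.i.d. run can be requested from the command line. A minimal sketch using only the flags visible in this diff (other hyperparameter flags defined elsewhere in options.py keep their defaults, and the script is assumed to be run from inside Federated_Avg/ so its relative paths resolve):

    cd Federated_Avg
    python main_fedavg.py --dataset mnist --iid 0 --unequal 1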

+ 84 - 2
Federated_Avg/sampling.py

@@ -29,6 +29,7 @@ def mnist_noniid(dataset, num_users):
     :param num_users:
     :return:
     """
+    # 60,000 training imgs -->  200 shards X 300 imgs/shard
     num_shards, num_imgs = 200, 300
     idx_shard = [i for i in range(num_shards)]
     dict_users = {i: np.array([]) for i in range(num_users)}
@@ -40,7 +41,7 @@ def mnist_noniid(dataset, num_users):
     idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
     idxs = idxs_labels[0, :]
 
-    # divide and assign
+    # divide and assign 2 shards/client
     for i in range(num_users):
         rand_set = set(np.random.choice(idx_shard, 2, replace=False))
         idx_shard = list(set(idx_shard) - rand_set)
@@ -50,6 +51,86 @@ def mnist_noniid(dataset, num_users):
     return dict_users
 
 
+def mnist_noniid_unequal(dataset, num_users):
+    """
+    Sample non-I.I.D. client data from the MNIST dataset s.t. clients
+    have unequal amounts of data
+    :param dataset:
+    :param num_users:
+    :returns: a dict mapping each client to the indices of its assigned
+    training imgs
+    """
+    # 60,000 training imgs --> 50 imgs/shard X 1200 shards
+    num_shards, num_imgs = 1200, 50
+    idx_shard = [i for i in range(num_shards)]
+    dict_users = {i: np.array([]) for i in range(num_users)}
+    idxs = np.arange(num_shards*num_imgs)
+    labels = dataset.train_labels.numpy()
+
+    # sort labels
+    idxs_labels = np.vstack((idxs, labels))
+    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
+    idxs = idxs_labels[0, :]
+
+    # Minimum and maximum shards assigned per client:
+    min_shard = 1
+    max_shard = 30
+
+    # Divide the shards into random chunks for every client
+    # s.t. the sum of these chunks = num_shards
+    random_shard_size = np.random.randint(min_shard, max_shard+1, size=num_users)
+    random_shard_size = np.around(random_shard_size/sum(random_shard_size) * num_shards)
+    random_shard_size = random_shard_size.astype(int)
+
+    # Assign the shards randomly to each client
+    if sum(random_shard_size) > num_shards:
+
+        for i in range(num_users):
+            # First assign each client 1 shard to ensure every client has
+            # at least one shard of data
+            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
+            idx_shard = list(set(idx_shard) - rand_set)
+            for rand in rand_set:
+                dict_users[i] = np.concatenate(
+                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+
+        random_shard_size = random_shard_size - 1
+
+        # Next, randomly assign the remaining shards
+        for i in range(num_users):
+            if len(idx_shard) == 0:
+                continue
+            shard_size = random_shard_size[i]
+            if shard_size > len(idx_shard):
+                shard_size = len(idx_shard)
+            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+            idx_shard = list(set(idx_shard) - rand_set)
+            for rand in rand_set:
+                dict_users[i] = np.concatenate(
+                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+    else:
+
+        for i in range(num_users):
+            shard_size = random_shard_size[i]
+            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+            idx_shard = list(set(idx_shard) - rand_set)
+            for rand in rand_set:
+                dict_users[i] = np.concatenate(
+                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+
+        # Add any leftover shards to the client with the fewest images
+        shard_size = len(idx_shard)
+        k = min(dict_users, key=lambda x: len(dict_users.get(x)))
+        rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+        idx_shard = list(set(idx_shard) - rand_set)
+        for rand in rand_set:
+            dict_users[k] = np.concatenate(
+                (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+
+    return dict_users
+
+
 def cifar_iid(dataset, num_users):
     """
     Sample I.I.D. client data from CIFAR10 dataset
@@ -76,7 +157,8 @@ def cifar_noniid(dataset, num_users):
     idx_shard = [i for i in range(num_shards)]
     dict_users = {i: np.array([]) for i in range(num_users)}
     idxs = np.arange(num_shards*num_imgs)
-    labels = dataset.train_labels.numpy()
+    # labels = dataset.train_labels.numpy()
+    labels = np.array(dataset.train_labels)
 
     # sort labels
     idxs_labels = np.vstack((idxs, labels))
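
A quick sanity check for the new mnist_noniid_unequal sampler, as a sketch: it assumes torchvision's MNIST is available under ../data/mnist (the same path main_fedavg.py uses) and verifies that every client receives data and that the per-client counts are unequal while all 60,000 training images are distributed.

    import numpy as np
    from torchvision import datasets, transforms
    from sampling import mnist_noniid_unequal

    # Load MNIST the same way main_fedavg.py does (path is an assumption).
    dataset_train = datasets.MNIST('../data/mnist', train=True, download=True,
                                   transform=transforms.ToTensor())
    dict_users = mnist_noniid_unequal(dataset_train, num_users=100)

    sizes = np.array([len(v) for v in dict_users.values()])
    # Unequal per-client counts; the total should cover all 60,000 training imgs.
    print(sizes.min(), sizes.max(), sizes.sum())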

+ 0 - 0
Federated_Avg/README.md → README.md


BIN
data/.DS_Store


+ 0 - 0
data/cifar/.gitkeep


+ 0 - 0
data/mnist/.gitkeep


BIN
save/.DS_Store


BIN
save/fed_mnist_cnn_2_C0.1_iid1_acc.png


BIN
save/fed_mnist_cnn_2_C0.1_iid1_loss.png


BIN
save/objects/.DS_Store