瀏覽代碼

conform to pep8

AshwinRJ 5 年之前
父節點
當前提交
ff15a0fc6b
共有 2 個文件被更改,包括 32 次插入、17 次刪除
  1. 8 5
      src/federated_main.py
  2. 24 12
      src/sampling.py

+ 8 - 5
src/federated_main.py

@@ -51,7 +51,8 @@ if __name__ == '__main__':
         len_in = 1
         for x in img_size:
             len_in *= x
-            global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
+            global_model = MLP(dim_in=len_in, dim_hidden=64,
+                               dim_out=args.num_classes)
     else:
         exit('Error: unrecognized model')
 
@@ -142,8 +143,9 @@ if __name__ == '__main__':
     # plt.plot(range(len(train_loss)), train_loss, color='r')
     # plt.ylabel('Training loss')
     # plt.xlabel('Communication Rounds')
-    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
-    #                                                                              args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
+    #             format(args.dataset, args.model, args.epochs, args.frac,
+    #                    args.iid, args.local_ep, args.local_bs))
     #
     # # Plot Average Accuracy vs Communication rounds
     # plt.figure()
@@ -151,5 +153,6 @@ if __name__ == '__main__':
     # plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
     # plt.ylabel('Average Accuracy')
     # plt.xlabel('Communication Rounds')
-    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
-    #                                                                             args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
+    #             format(args.dataset, args.model, args.epochs, args.frac,
+    #                    args.iid, args.local_ep, args.local_bs))

+ 24 - 12
src/sampling.py

@@ -17,7 +17,8 @@ def mnist_iid(dataset, num_users):
     num_items = int(len(dataset)/num_users)
     dict_users, all_idxs = {}, [i for i in range(len(dataset))]
     for i in range(num_users):
-        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
+        dict_users[i] = set(np.random.choice(all_idxs, num_items,
+                                             replace=False))
         all_idxs = list(set(all_idxs) - dict_users[i])
     return dict_users
 
@@ -78,8 +79,10 @@ def mnist_noniid_unequal(dataset, num_users):
 
     # Divide the shards into random chunks for every client
     # s.t the sum of these chunks = num_shards
-    random_shard_size = np.random.randint(min_shard, max_shard+1, size=num_users)
-    random_shard_size = np.around(random_shard_size/sum(random_shard_size) * num_shards)
+    random_shard_size = np.random.randint(min_shard, max_shard+1,
+                                          size=num_users)
+    random_shard_size = np.around(random_shard_size /
+                                  sum(random_shard_size) * num_shards)
     random_shard_size = random_shard_size.astype(int)
 
     # Assign the shards randomly to each client
@@ -92,7 +95,8 @@ def mnist_noniid_unequal(dataset, num_users):
             idx_shard = list(set(idx_shard) - rand_set)
             for rand in rand_set:
                 dict_users[i] = np.concatenate(
-                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
+                    axis=0)
 
         random_shard_size = random_shard_size-1
 
@@ -103,31 +107,37 @@ def mnist_noniid_unequal(dataset, num_users):
             shard_size = random_shard_size[i]
             if shard_size > len(idx_shard):
                 shard_size = len(idx_shard)
-            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+            rand_set = set(np.random.choice(idx_shard, shard_size,
+                                            replace=False))
             idx_shard = list(set(idx_shard) - rand_set)
             for rand in rand_set:
                 dict_users[i] = np.concatenate(
-                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
+                    axis=0)
     else:
 
         for i in range(num_users):
             shard_size = random_shard_size[i]
-            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+            rand_set = set(np.random.choice(idx_shard, shard_size,
+                                            replace=False))
             idx_shard = list(set(idx_shard) - rand_set)
             for rand in rand_set:
                 dict_users[i] = np.concatenate(
-                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
+                    axis=0)
 
         if len(idx_shard) > 0:
             # Add the leftover shards to the client with minimum images:
             shard_size = len(idx_shard)
             # Add the remaining shard to the client with lowest data
             k = min(dict_users, key=lambda x: len(dict_users.get(x)))
-            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+            rand_set = set(np.random.choice(idx_shard, shard_size,
+                                            replace=False))
             idx_shard = list(set(idx_shard) - rand_set)
             for rand in rand_set:
                 dict_users[k] = np.concatenate(
-                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
+                    axis=0)
 
     return dict_users
 
@@ -142,7 +152,8 @@ def cifar_iid(dataset, num_users):
     num_items = int(len(dataset)/num_users)
     dict_users, all_idxs = {}, [i for i in range(len(dataset))]
     for i in range(num_users):
-        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
+        dict_users[i] = set(np.random.choice(all_idxs, num_items,
+                                             replace=False))
         all_idxs = list(set(all_idxs) - dict_users[i])
     return dict_users
 
@@ -180,7 +191,8 @@ if __name__ == '__main__':
     dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
-                                       transforms.Normalize((0.1307,), (0.3081,))
+                                       transforms.Normalize((0.1307,),
+                                                            (0.3081,))
                                    ]))
     num = 100
     d = mnist_noniid(dataset_train, num)