AshwinRJ 5 éve
szülő
commit
23c590a2a8
2 módosított fájl, 27 hozzáadás és 26 törlés
  1. 16 16
      Federated_Avg/main_fedavg.py
  2. 11 10
      Federated_Avg/sampling.py

+ 16 - 16
Federated_Avg/main_fedavg.py

@@ -134,22 +134,22 @@ if __name__ == '__main__':
         train_accuracy.append(sum(list_acc)/len(list_acc))
 
     # Plot Loss curve
-    plt.figure()
-    plt.title('Training Loss vs Communication rounds')
-    plt.plot(range(len(train_loss)), train_loss, color='r')
-    plt.ylabel('Training loss')
-    plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
-                                                                                 args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
-
-    # Plot Average Accuracy vs Communication rounds
-    plt.figure()
-    plt.title('Average Accuracy vs Communication rounds')
-    plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
-    plt.ylabel('Average Accuracy')
-    plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
-                                                                                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    # plt.figure()
+    # plt.title('Training Loss vs Communication rounds')
+    # plt.plot(range(len(train_loss)), train_loss, color='r')
+    # plt.ylabel('Training loss')
+    # plt.xlabel('Communication Rounds')
+    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
+    #                                                                              args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    #
+    # # Plot Average Accuracy vs Communication rounds
+    # plt.figure()
+    # plt.title('Average Accuracy vs Communication rounds')
+    # plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
+    # plt.ylabel('Average Accuracy')
+    # plt.xlabel('Communication Rounds')
+    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
+    #                                                                             args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
 
     print("Final Average Accuracy after {} epochs: {:.2f}%".format(
         args.epochs, 100.*train_accuracy[-1]))

+ 11 - 10
Federated_Avg/sampling.py

@@ -98,7 +98,7 @@ def mnist_noniid_unequal(dataset, num_users):
 
         # Next, randomly assign the remaining shards
         for i in range(num_users):
-            if len(idx_shard == 0):
+            if len(idx_shard) == 0:
                 continue
             shard_size = random_shard_size[i]
             if shard_size > len(idx_shard):
@@ -118,15 +118,16 @@ def mnist_noniid_unequal(dataset, num_users):
                 dict_users[i] = np.concatenate(
                     (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
 
-        # Add the leftover shards to the client with minimum images:
-        shard_size = len(idx_shard)
-        # Add the remaining shard to the client with lowest data
-        k = min(dict_users, key=lambda x: len(dict_users.get(x)))
-        rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
-        idx_shard = list(set(idx_shard) - rand_set)
-        for rand in rand_set:
-            dict_users[k] = np.concatenate(
-                (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+        if len(idx_shard) > 0:
+            # Add the leftover shards to the client with minimum images:
+            shard_size = len(idx_shard)
+            # Add the remaining shard to the client with lowest data
+            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
+            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
+            idx_shard = list(set(idx_shard) - rand_set)
+            for rand in rand_set:
+                dict_users[k] = np.concatenate(
+                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
 
     return dict_users