Ver código fonte

Updated figure save format

AshwinRJ 5 anos atrás
pai
commit
d47dc60fcc

BIN
.DS_Store


BIN
Federated_Avg/__pycache__/FedNets.cpython-36.pyc


BIN
Federated_Avg/__pycache__/Update.cpython-36.pyc


BIN
Federated_Avg/__pycache__/averaging.cpython-36.pyc


BIN
Federated_Avg/__pycache__/options.cpython-36.pyc


BIN
Federated_Avg/__pycache__/sampling.cpython-36.pyc


BIN
Federated_Avg/local/events.out.tfevents.1543212153.Nitros-MacBook-Pro.local


+ 7 - 6
Federated_Avg/main_fedavg.py

@@ -112,9 +112,10 @@ if __name__ == '__main__':
         # copy weight to net_glob
         net_glob.load_state_dict(w_glob)
 
-        # print loss after every round
+        # print loss after every 'i' rounds
+        print_every = 5
         loss_avg = sum(loss_locals) / len(loss_locals)
-        if iter % 1 == 0:
+        if iter % print_every == 0:
             print('\nTrain loss:', loss_avg)
         loss_train.append(loss_avg)
 
@@ -135,8 +136,8 @@ if __name__ == '__main__':
     plt.plot(range(len(loss_train)), loss_train, color='r')
     plt.ylabel('Training loss')
     plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_loss.png'.format(args.dataset,
-                                                                 args.model, args.epochs, args.frac, args.iid))
+    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_loss.png'.format(args.dataset,
+                                                                         args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
 
     # Plot Average Accuracy vs Communication rounds
     plt.figure()
@@ -144,8 +145,8 @@ if __name__ == '__main__':
     plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
     plt.ylabel('Average Accuracy')
     plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_acc.png'.format(args.dataset,
-                                                                args.model, args.epochs, args.frac, args.iid))
+    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_acc.png'.format(args.dataset,
+                                                                        args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
 
     print("Final Average Accuracy after {} epochs: {:.2f}%".format(
         args.epochs, 100.*train_accuracy[-1]))

BIN
data/.DS_Store


+ 0 - 0
data/cifar/.gitkeep


+ 0 - 0
data/mnist/.gitkeep