|
@@ -163,6 +163,14 @@ if __name__ == '__main__':
|
|
plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_loss.png'.format(args.dataset,
|
|
plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_loss.png'.format(args.dataset,
|
|
args.model, args.epochs, args.frac, args.iid))
|
|
args.model, args.epochs, args.frac, args.iid))
|
|
|
|
|
|
|
|
+ # Plot Average Accuracy vs Communication rounds
|
|
|
|
+ plt.figure()
|
|
|
|
+ plt.title('Average Accuracy vs Communication rounds')
|
|
|
|
+ plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
|
|
|
|
+ plt.ylabel('Average Accuracy')
|
|
|
|
+ plt.xlabel('Communication Rounds')
|
|
|
|
+ plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_acc.png'.format(args.dataset,
|
|
|
|
+ args.model, args.epochs, args.frac, args.iid))
|
|
# testing (original)
|
|
# testing (original)
|
|
list_acc, list_loss = [], []
|
|
list_acc, list_loss = [], []
|
|
net_glob.eval()
|
|
net_glob.eval()
|