|
@@ -112,9 +112,10 @@ if __name__ == '__main__':
|
|
|
# copy weight to net_glob
|
|
|
net_glob.load_state_dict(w_glob)
|
|
|
|
|
|
- # print loss after every round
|
|
|
+ # print loss after every 'i' rounds
|
|
|
+ print_every = 5
|
|
|
loss_avg = sum(loss_locals) / len(loss_locals)
|
|
|
- if iter % 1 == 0:
|
|
|
+ if iter % print_every == 0:
|
|
|
print('\nTrain loss:', loss_avg)
|
|
|
loss_train.append(loss_avg)
|
|
|
|
|
@@ -135,8 +136,8 @@ if __name__ == '__main__':
|
|
|
plt.plot(range(len(loss_train)), loss_train, color='r')
|
|
|
plt.ylabel('Training loss')
|
|
|
plt.xlabel('Communication Rounds')
|
|
|
- plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_loss.png'.format(args.dataset,
|
|
|
- args.model, args.epochs, args.frac, args.iid))
|
|
|
+ plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_loss.png'.format(args.dataset,
|
|
|
+ args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
|
|
|
|
|
|
# Plot Average Accuracy vs Communication rounds
|
|
|
plt.figure()
|
|
@@ -144,8 +145,8 @@ if __name__ == '__main__':
|
|
|
plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
|
|
|
plt.ylabel('Average Accuracy')
|
|
|
plt.xlabel('Communication Rounds')
|
|
|
- plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_acc.png'.format(args.dataset,
|
|
|
- args.model, args.epochs, args.frac, args.iid))
|
|
|
+ plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_acc.png'.format(args.dataset,
|
|
|
+ args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
|
|
|
|
|
|
print("Final Average Accuracy after {} epochs: {:.2f}%".format(
|
|
|
args.epochs, 100.*train_accuracy[-1]))
|