@@ -2,6 +2,9 @@
 # -*- coding: utf-8 -*-
 # Python version: 3.6


+import matplotlib
+import matplotlib.pyplot as plt
+# matplotlib.use('Agg')
 import os
 import copy
@@ -11,17 +14,15 @@ from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 from torch import autograd
+
 from tensorboardX import SummaryWriter

-from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
+from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_unequal
 from options import args_parser
 from Update import LocalUpdate
 from FedNets import MLP, CNNMnist, CNNCifar
 from averaging import average_weights
-
-import matplotlib
-import matplotlib.pyplot as plt
-matplotlib.use('Agg')
+import pickle


 if __name__ == '__main__':
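Note: this revision moves the matplotlib imports to the top of the file and leaves `matplotlib.use('Agg')` commented out, so the default interactive backend is used. On a headless machine (no display), re-enabling the non-interactive backend before `pyplot` is imported avoids display-related errors; a minimal sketch of that import order:

```python
import matplotlib
matplotlib.use('Agg')             # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt   # safe to import pyplot now

plt.figure()                      # figures can be created and saved without a display
```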
@@ -43,6 +44,8 @@ if __name__ == '__main__':
         # sample users
         if args.iid:
             dict_users = mnist_iid(dataset_train, args.num_users)
+        elif args.unequal:
+            dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
         else:
             dict_users = mnist_noniid(dataset_train, args.num_users)

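The new `args.unequal` branch relies on `mnist_noniid_unequal` from `sampling.py`, whose implementation is not part of this diff. As an illustration only (not the repository's code), an unequal non-IID sampler can hand out label-sorted shards in randomly sized chunks, so clients end up with different amounts of data:

```python
import numpy as np

def mnist_noniid_unequal_sketch(dataset, num_users, num_shards=200):
    """Illustrative sketch only: split label-sorted MNIST shards unequally across users."""
    num_imgs = len(dataset) // num_shards
    labels = np.array(dataset.targets)[:num_shards * num_imgs]  # older torchvision: dataset.train_labels
    idxs = np.argsort(labels)                    # label-sorted sample indices
    shards = np.array_split(idxs, num_shards)    # each shard is (mostly) single-class

    # draw a random, unequal shard count per user, guaranteeing at least one shard each
    counts = np.random.multinomial(num_shards - num_users, [1.0 / num_users] * num_users) + 1
    pool = list(np.random.permutation(num_shards))
    dict_users = {}
    for uid in range(num_users):
        take, pool = pool[:counts[uid]], pool[counts[uid]:]
        dict_users[uid] = np.concatenate([shards[s] for s in take])
    return dict_users
```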
@@ -91,7 +94,7 @@ if __name__ == '__main__':
     w_glob = net_glob.state_dict()

     # training
-    loss_train = []
+    train_loss = []
     train_accuracy = []
     cv_loss, cv_acc = [], []
     val_loss_pre, counter = 0, 0
@@ -117,7 +120,7 @@ if __name__ == '__main__':
         loss_avg = sum(loss_locals) / len(loss_locals)
         if iter % print_every == 0:
             print('\nTrain loss:', loss_avg)
-        loss_train.append(loss_avg)
+        train_loss.append(loss_avg)

         # Calculate avg accuracy over all users at every epoch
         list_acc, list_loss = [], []
@@ -133,11 +136,11 @@ if __name__ == '__main__':
     # Plot Loss curve
     plt.figure()
     plt.title('Training Loss vs Communication rounds')
-    plt.plot(range(len(loss_train)), loss_train, color='r')
+    plt.plot(range(len(train_loss)), train_loss, color='r')
     plt.ylabel('Training loss')
     plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_loss.png'.format(args.dataset,
-                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
+                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))

     # Plot Average Accuracy vs Communication rounds
     plt.figure()
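Both `plt.savefig` calls write into `../save/`, and the pickle dump added in the next hunk writes into `../save/objects/`; the diff assumes those folders already exist. A small guard (a suggestion, not part of this change) avoids a `FileNotFoundError` on a fresh checkout:

```python
import os

# create the output directories assumed by the save paths, if they are missing
os.makedirs('../save', exist_ok=True)
os.makedirs('../save/objects', exist_ok=True)
```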
@@ -145,8 +148,14 @@ if __name__ == '__main__':
     plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
     plt.ylabel('Average Accuracy')
     plt.xlabel('Communication Rounds')
-    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_acc.png'.format(args.dataset,
-                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
+    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
+                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))

     print("Final Average Accuracy after {} epochs: {:.2f}%".format(
         args.epochs, 100.*train_accuracy[-1]))
+
+    # Saving the objects train_loss and train_accuracy:
+    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
+                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
+    with open(file_name, 'wb') as f:
+        pickle.dump([train_loss, train_accuracy], f)
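After a run, the pickled metrics can be loaded back for offline plotting. A minimal sketch, assuming a file produced by the naming pattern above (the path shown is only an example for one particular set of arguments):

```python
import pickle
import matplotlib.pyplot as plt

# example path following the pattern above; substitute the settings of the actual run
file_name = '../save/objects/mnist_mlp_10_C[0.1]_iid[1]_E[10]_B[10].pkl'

with open(file_name, 'rb') as f:
    train_loss, train_accuracy = pickle.load(f)   # the two-element list dumped by the script

plt.plot(range(len(train_loss)), train_loss)
plt.xlabel('Communication Rounds')
plt.ylabel('Training loss')
plt.show()
```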