|
@@ -24,29 +24,6 @@ import matplotlib.pyplot as plt
|
|
|
matplotlib.use('Agg')
|
|
|
|
|
|
|
|
|
-# def test(net_g, data_loader, args):
|
|
|
-# # testing
|
|
|
-# test_loss = 0
|
|
|
-# correct = 0
|
|
|
-# # Test for the below line
|
|
|
-# l = len(data_loader)
|
|
|
-#
|
|
|
-# for idx, (data, target) in enumerate(data_loader):
|
|
|
-# if args.gpu != -1:
|
|
|
-# data, target = data.cuda(), target.cuda()
|
|
|
-# data, target = autograd.Variable(data), autograd.Variable(target)
|
|
|
-# log_probs = net_g(data)
|
|
|
-# test_loss += F.nll_loss(log_probs, target, size_average=False).data[0] # sum up batch loss
|
|
|
-# y_pred = log_probs.data.max(1, keepdim=True)[1] # get the index of the max log-probability
|
|
|
-# correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
|
|
|
-#
|
|
|
-# test_loss /= len(data_loader.dataset)
|
|
|
-# print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
|
|
|
-# test_loss, correct, len(data_loader.dataset),
|
|
|
-# 100. * correct / len(data_loader.dataset)))
|
|
|
-# return correct, test_loss
|
|
|
-
|
|
|
-
|
|
|
if __name__ == '__main__':
|
|
|
# parse args
|
|
|
args = args_parser()
|
|
@@ -84,14 +61,12 @@ if __name__ == '__main__':
|
|
|
img_size = dataset_train[0][0].shape
|
|
|
|
|
|
# BUILD MODEL
|
|
|
- # Using same models for MNIST and FashionMNIST
|
|
|
if args.model == 'cnn' and args.dataset == 'mnist':
|
|
|
if args.gpu != -1:
|
|
|
torch.cuda.set_device(args.gpu)
|
|
|
net_glob = CNNMnist(args=args).cuda()
|
|
|
else:
|
|
|
net_glob = CNNMnist(args=args)
|
|
|
-
|
|
|
elif args.model == 'cnn' and args.dataset == 'cifar':
|
|
|
if args.gpu != -1:
|
|
|
torch.cuda.set_device(args.gpu)
|
|
@@ -137,16 +112,16 @@ if __name__ == '__main__':
|
|
|
# copy weight to net_glob
|
|
|
net_glob.load_state_dict(w_glob)
|
|
|
|
|
|
- # print loss
|
|
|
+ # print loss after every round
|
|
|
loss_avg = sum(loss_locals) / len(loss_locals)
|
|
|
- if args.epochs % 10 == 0:
|
|
|
+ if iter % 1 == 0:
|
|
|
print('\nTrain loss:', loss_avg)
|
|
|
loss_train.append(loss_avg)
|
|
|
|
|
|
# Calculate avg accuracy over all users at every epoch
|
|
|
list_acc, list_loss = [], []
|
|
|
net_glob.eval()
|
|
|
- for c in tqdm(range(args.num_users)):
|
|
|
+ for c in range(args.num_users):
|
|
|
net_local = LocalUpdate(args=args, dataset=dataset_train,
|
|
|
idxs=dict_users[c], tb=summary)
|
|
|
acc, loss = net_local.test(net=net_glob)
|
|
@@ -171,15 +146,6 @@ if __name__ == '__main__':
|
|
|
plt.xlabel('Communication Rounds')
|
|
|
plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_acc.png'.format(args.dataset,
|
|
|
args.model, args.epochs, args.frac, args.iid))
|
|
|
- # testing (original)
|
|
|
- list_acc, list_loss = [], []
|
|
|
- net_glob.eval()
|
|
|
- for c in tqdm(range(args.num_users)):
|
|
|
- net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
|
|
|
- acc, loss = net_local.test(net=net_glob)
|
|
|
- list_acc.append(acc)
|
|
|
- list_loss.append(loss)
|
|
|
- print("Final Average Accuracy after {} epochs: {:.2f}%".format(
|
|
|
- args.epochs, (100.*sum(list_acc)/len(list_acc))))
|
|
|
|
|
|
- print("Final Average Accuracy after {} epochs: {:.2f}%".format(args.epochs, train_accuracy[-1])
|
|
|
+ print("Final Average Accuracy after {} epochs: {:.2f}%".format(
|
|
|
+ args.epochs, 100.*train_accuracy[-1]))
|