|
@@ -2,137 +2,139 @@
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# Python version: 3.6
|
|
|
|
|
|
-import matplotlib
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-# matplotlib.use('Agg')
|
|
|
|
|
|
import os
|
|
|
import copy
|
|
|
+import time
|
|
|
+import pickle
|
|
|
import numpy as np
|
|
|
-from torchvision import datasets, transforms
|
|
|
from tqdm import tqdm
|
|
|
-import torch
|
|
|
-import torch.nn.functional as F
|
|
|
-from torch import autograd
|
|
|
|
|
|
-from tensorboardX import SummaryWriter
|
|
|
+import torch
|
|
|
+from tensorboardX import SummaryWrepoch
|
|
|
|
|
|
-from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_unequal
|
|
|
from options import args_parser
|
|
|
from Update import LocalUpdate
|
|
|
-from FedNets import MLP, CNNMnist, CNNCifar
|
|
|
+from FedNets import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
|
|
|
from averaging import average_weights
|
|
|
-import pickle
|
|
|
+from utils import get_dataset
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- # parse args
|
|
|
- args = args_parser()
|
|
|
+ start_time = time.time()
|
|
|
|
|
|
# define paths
|
|
|
path_project = os.path.abspath('..')
|
|
|
+ summary = SummaryWrepoch('local')
|
|
|
|
|
|
- summary = SummaryWriter('local')
|
|
|
-
|
|
|
- # load dataset and split users
|
|
|
- if args.dataset == 'mnist':
|
|
|
- dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
|
|
|
- transform=transforms.Compose([
|
|
|
- transforms.ToTensor(),
|
|
|
- transforms.Normalize((0.1307,), (0.3081,))
|
|
|
- ]))
|
|
|
- # sample users
|
|
|
- if args.iid:
|
|
|
- dict_users = mnist_iid(dataset_train, args.num_users)
|
|
|
- elif args.unequal:
|
|
|
- dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
|
|
|
- else:
|
|
|
- dict_users = mnist_noniid(dataset_train, args.num_users)
|
|
|
-
|
|
|
- elif args.dataset == 'cifar':
|
|
|
- transform = transforms.Compose(
|
|
|
- [transforms.ToTensor(),
|
|
|
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
|
|
- dataset_train = datasets.CIFAR10(
|
|
|
- '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
|
|
|
- if args.iid:
|
|
|
- dict_users = cifar_iid(dataset_train, args.num_users)
|
|
|
- else:
|
|
|
- dict_users = cifar_noniid(dataset_train, args.num_users)
|
|
|
- else:
|
|
|
- exit('Error: unrecognized dataset')
|
|
|
- img_size = dataset_train[0][0].shape
|
|
|
+ args = args_parser()
|
|
|
+ if args.gpu:
|
|
|
+ torch.cuda.set_device(args.gpu)
|
|
|
+ device = 'cuda' if args.gpu else 'cpu'
|
|
|
+
|
|
|
+ # load dataset and user groups
|
|
|
+ train_dataset, test_dataset, user_groups = get_dataset(args)
|
|
|
|
|
|
# BUILD MODEL
|
|
|
- if args.model == 'cnn' and args.dataset == 'mnist':
|
|
|
- if args.gpu != -1:
|
|
|
- torch.cuda.set_device(args.gpu)
|
|
|
- net_glob = CNNMnist(args=args).cuda()
|
|
|
- else:
|
|
|
- net_glob = CNNMnist(args=args)
|
|
|
- elif args.model == 'cnn' and args.dataset == 'cifar':
|
|
|
- if args.gpu != -1:
|
|
|
- torch.cuda.set_device(args.gpu)
|
|
|
- net_glob = CNNCifar(args=args).cuda()
|
|
|
- else:
|
|
|
- net_glob = CNNCifar(args=args)
|
|
|
+ if args.model == 'cnn':
|
|
|
+ # Convolutional neural netork
|
|
|
+ if args.dataset == 'mnist':
|
|
|
+ global_model = CNNMnist(args=args)
|
|
|
+ elif args.dataset == 'fmnist':
|
|
|
+ global_model = CNNFashion_Mnist(args=args)
|
|
|
+ elif args.dataset == 'cifar':
|
|
|
+ global_model = CNNCifar(args=args)
|
|
|
+
|
|
|
elif args.model == 'mlp':
|
|
|
+ # Multi-layer preceptron
|
|
|
+ img_size = train_dataset[0][0].shape
|
|
|
len_in = 1
|
|
|
for x in img_size:
|
|
|
len_in *= x
|
|
|
- if args.gpu != -1:
|
|
|
- torch.cuda.set_device(args.gpu)
|
|
|
- net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
|
|
|
- else:
|
|
|
- net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
|
|
|
+ global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
|
|
|
else:
|
|
|
exit('Error: unrecognized model')
|
|
|
- print(net_glob)
|
|
|
- net_glob.train()
|
|
|
+
|
|
|
+ # Set the model to train and send it to device.
|
|
|
+ global_model.to(device)
|
|
|
+ global_model.train()
|
|
|
+ print(global_model)
|
|
|
|
|
|
# copy weights
|
|
|
- w_glob = net_glob.state_dict()
|
|
|
+ global_weights = global_model.state_dict()
|
|
|
|
|
|
# training
|
|
|
- train_loss = []
|
|
|
- train_accuracy = []
|
|
|
+ train_loss, train_accuracy = [], []
|
|
|
+ val_acc_list, net_list = [], []
|
|
|
cv_loss, cv_acc = [], []
|
|
|
+ print_every = 20
|
|
|
val_loss_pre, counter = 0, 0
|
|
|
- net_best = None
|
|
|
- val_acc_list, net_list = [], []
|
|
|
- for iter in tqdm(range(args.epochs)):
|
|
|
- w_locals, loss_locals = [], []
|
|
|
+
|
|
|
+ for epoch in tqdm(range(args.epochs)):
|
|
|
+ global_model.train()
|
|
|
+ local_weights, local_losses = [], []
|
|
|
+
|
|
|
m = max(int(args.frac * args.num_users), 1)
|
|
|
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
|
|
|
+
|
|
|
for idx in idxs_users:
|
|
|
- local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
|
|
|
- w, loss = local.update_weights(net=copy.deepcopy(net_glob))
|
|
|
- w_locals.append(copy.deepcopy(w))
|
|
|
- loss_locals.append(copy.deepcopy(loss))
|
|
|
+ local_model = LocalUpdate(args=args, dataset=train_dataset,
|
|
|
+ idxs=user_groups[idx], logger=summary)
|
|
|
+ w, loss = local_model.update_weights(net=copy.deepcopy(global_model))
|
|
|
+ local_weights.append(copy.deepcopy(w))
|
|
|
+ local_losses.append(copy.deepcopy(loss))
|
|
|
+
|
|
|
# update global weights
|
|
|
- w_glob = average_weights(w_locals)
|
|
|
+ global_weights = average_weights(local_weights)
|
|
|
|
|
|
- # copy weight to net_glob
|
|
|
- net_glob.load_state_dict(w_glob)
|
|
|
+ # copy weight to global model
|
|
|
+ global_model.load_state_dict(global_weights)
|
|
|
|
|
|
- # print loss after every 'i' rounds
|
|
|
- print_every = 5
|
|
|
- loss_avg = sum(loss_locals) / len(loss_locals)
|
|
|
- if iter % print_every == 0:
|
|
|
+ # print loss after every 20 rounds
|
|
|
+ loss_avg = sum(local_losses) / len(local_losses)
|
|
|
+ if (epoch+1) % print_every == 0:
|
|
|
print('\nTrain loss:', loss_avg)
|
|
|
train_loss.append(loss_avg)
|
|
|
|
|
|
- # Calculate avg accuracy over all users at every epoch
|
|
|
+ # Calculate avg training accuracy over all users at every epoch
|
|
|
list_acc, list_loss = [], []
|
|
|
- net_glob.eval()
|
|
|
+ global_model.eval()
|
|
|
for c in range(args.num_users):
|
|
|
- net_local = LocalUpdate(args=args, dataset=dataset_train,
|
|
|
- idxs=dict_users[c], tb=summary)
|
|
|
- acc, loss = net_local.test(net=net_glob)
|
|
|
+ local_model = LocalUpdate(args=args, dataset=train_dataset,
|
|
|
+ idxs=user_groups[idx], logger=summary)
|
|
|
+ acc, loss = local_model.inference(net=global_model)
|
|
|
list_acc.append(acc)
|
|
|
list_loss.append(loss)
|
|
|
train_accuracy.append(sum(list_acc)/len(list_acc))
|
|
|
|
|
|
+ # Test inference after completion of training
|
|
|
+ test_acc, test_loss = [], []
|
|
|
+ for c in tqdm(range(args.num_users)):
|
|
|
+ local_model = LocalUpdate(args=args, dataset=test_dataset,
|
|
|
+ idxs=user_groups[idx], logger=summary)
|
|
|
+ acc, loss = local_model.test(net=global_model)
|
|
|
+ test_acc.append(acc)
|
|
|
+ test_loss.append(loss)
|
|
|
+
|
|
|
+ print("Final Average Train Accuracy after {} epochs: {:.2f}%".format(
|
|
|
+ args.epochs, 100.*train_accuracy[-1]))
|
|
|
+
|
|
|
+ print("Final Average Test Accuracy after {} epochs: {:.2f}%".format(
|
|
|
+ args.epochs, (100.*sum(test_acc)/len(test_acc))))
|
|
|
+
|
|
|
+ # # Saving the objects train_loss and train_accuracy:
|
|
|
+ file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
|
|
|
+ args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
|
|
|
+ with open(file_name, 'wb') as f:
|
|
|
+ pickle.dump([train_loss, train_accuracy], f)
|
|
|
+
|
|
|
+ print('Total Time: {0:0.4f}'.format(time.time()-start_time))
|
|
|
+
|
|
|
+ # PLOTTING (optional)
|
|
|
+ # import matplotlib
|
|
|
+ # import matplotlib.pyplot as plt
|
|
|
+ # matplotlib.use('Agg')
|
|
|
+
|
|
|
# Plot Loss curve
|
|
|
# plt.figure()
|
|
|
# plt.title('Training Loss vs Communication rounds')
|
|
@@ -150,12 +152,3 @@ if __name__ == '__main__':
|
|
|
# plt.xlabel('Communication Rounds')
|
|
|
# plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
|
|
|
# args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
|
|
|
-
|
|
|
- print("Final Average Accuracy after {} epochs: {:.2f}%".format(
|
|
|
- args.epochs, 100.*train_accuracy[-1]))
|
|
|
-
|
|
|
-# Saving the objects train_loss and train_accuracy:
|
|
|
-file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
|
|
|
- args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
|
|
|
-with open(file_name, 'wb') as f:
|
|
|
- pickle.dump([train_loss, train_accuracy], f)
|