#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
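"""Federated averaging (FedAvg) training script.

Each communication round samples a fraction of simulated users, trains a
copy of the global model locally on each user's data shard (Update.py),
and averages the resulting weights into the global model (averaging.py).
Supports MNIST and CIFAR-10 with IID, non-IID, and unequal non-IID splits.
"""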
import os
import copy
import pickle

import numpy as np
import torch
import matplotlib
# matplotlib.use('Agg')  # must be called before importing pyplot; uncomment for headless runs
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from tqdm import tqdm
from tensorboardX import SummaryWriter

from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal, cifar_iid, cifar_noniid
from options import args_parser
from Update import LocalUpdate
from FedNets import MLP, CNNMnist, CNNCifar
from averaging import average_weights
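
# sampling, options, Update, FedNets, and averaging are local modules from
# this repository (sampling.py, options.py, etc.) and must be importable
# from the working directory.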


if __name__ == '__main__':
    # parse args
    args = args_parser()

    # define paths
    path_project = os.path.abspath('..')
    summary = SummaryWriter('local')

    # load dataset and split users
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST(
            '../data/mnist/', train=True, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        elif args.unequal:
            dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10(
            '../data/cifar', train=True, transform=transform,
            target_transform=None, download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid(dataset_train, args.num_users)
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # BUILD MODEL
    if args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'mlp':
        # flatten the image dimensions to get the MLP input size
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
        else:
            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    train_loss = []
    train_accuracy = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    val_acc_list, net_list = [], []
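
    # Each communication round below follows the FedAvg pattern: sample a
    # subset of users, run local updates on copies of the global model, and
    # average the returned weights back into net_glob.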
    print_every = 5
    for epoch in tqdm(range(args.epochs)):
        w_locals, loss_locals = [], []
        # sample a fraction of users for this round (at least one)
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
            w, loss = local.update_weights(net=copy.deepcopy(net_glob))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        # update global weights and copy them into net_glob
        w_glob = average_weights(w_locals)
        net_glob.load_state_dict(w_glob)

        # record the average local training loss; print it every `print_every` rounds
        loss_avg = sum(loss_locals) / len(loss_locals)
        if epoch % print_every == 0:
            print('\nTrain loss:', loss_avg)
        train_loss.append(loss_avg)

        # calculate average accuracy over all users at every round
        list_acc, list_loss = [], []
        net_glob.eval()
        for c in range(args.num_users):
            local = LocalUpdate(args=args, dataset=dataset_train,
                                idxs=dict_users[c], tb=summary)
            acc, loss = local.test(net=net_glob)
            list_acc.append(acc)
            list_loss.append(loss)
        net_glob.train()  # restore training mode before the next round
        train_accuracy.append(sum(list_acc) / len(list_acc))

    # plot loss curve
    os.makedirs('../save', exist_ok=True)  # ensure the output directory exists
    plt.figure()
    plt.title('Training Loss vs Communication rounds')
    plt.plot(range(len(train_loss)), train_loss, color='r')
    plt.ylabel('Training loss')
    plt.xlabel('Communication Rounds')
    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))

    # plot average accuracy vs communication rounds
    plt.figure()
    plt.title('Average Accuracy vs Communication rounds')
    plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
    plt.ylabel('Average Accuracy')
    plt.xlabel('Communication Rounds')
    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))

    print("Final Average Accuracy after {} communication rounds: {:.2f}%".format(
        args.epochs, 100. * train_accuracy[-1]))

    # save the train_loss and train_accuracy objects
    os.makedirs('../save/objects', exist_ok=True)  # ensure the output directory exists
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)
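
# Hypothetical invocation, assuming this script is saved as main_fed.py and
# that options.py defines flags matching the args attributes used above
# (--dataset, --model, --epochs, --gpu, --iid, --num_users, --frac,
# --local_ep, --local_bs); check options.py for the actual names and defaults:
#
#   python main_fed.py --dataset mnist --model cnn --epochs 50 --gpu -1 --iid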