#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import autograd
from torchvision import datasets, transforms
from tqdm import tqdm
from tensorboardX import SummaryWriter

import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend BEFORE pyplot is imported
import matplotlib.pyplot as plt

from averaging import average_weights
from FedNets import MLP, CNNMnist, CNNCifar
from options import args_parser
from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
from Update import LocalUpdate
  21. if __name__ == '__main__':
  22. # parse args
  23. args = args_parser()
  24. # define paths
  25. path_project = os.path.abspath('..')
  26. summary = SummaryWriter('local')
  27. # load dataset and split users
  28. if args.dataset == 'mnist':
  29. dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
  30. transform=transforms.Compose([
  31. transforms.ToTensor(),
  32. transforms.Normalize((0.1307,), (0.3081,))
  33. ]))
  34. # sample users
  35. if args.iid:
  36. dict_users = mnist_iid(dataset_train, args.num_users)
  37. else:
  38. dict_users = mnist_noniid(dataset_train, args.num_users)
  39. elif args.dataset == 'cifar':
  40. transform = transforms.Compose(
  41. [transforms.ToTensor(),
  42. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  43. dataset_train = datasets.CIFAR10(
  44. '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
  45. if args.iid:
  46. dict_users = cifar_iid(dataset_train, args.num_users)
  47. else:
  48. dict_users = cifar_noniid(dataset_train, args.num_users)
  49. else:
  50. exit('Error: unrecognized dataset')
  51. img_size = dataset_train[0][0].shape
  52. # BUILD MODEL
  53. if args.model == 'cnn' and args.dataset == 'mnist':
  54. if args.gpu != -1:
  55. torch.cuda.set_device(args.gpu)
  56. net_glob = CNNMnist(args=args).cuda()
  57. else:
  58. net_glob = CNNMnist(args=args)
  59. elif args.model == 'cnn' and args.dataset == 'cifar':
  60. if args.gpu != -1:
  61. torch.cuda.set_device(args.gpu)
  62. net_glob = CNNCifar(args=args).cuda()
  63. else:
  64. net_glob = CNNCifar(args=args)
  65. elif args.model == 'mlp':
  66. len_in = 1
  67. for x in img_size:
  68. len_in *= x
  69. if args.gpu != -1:
  70. torch.cuda.set_device(args.gpu)
  71. net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
  72. else:
  73. net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
  74. else:
  75. exit('Error: unrecognized model')
  76. print(net_glob)
  77. net_glob.train()
  78. # copy weights
  79. w_glob = net_glob.state_dict()
  80. # training
  81. loss_train = []
  82. train_accuracy = []
  83. cv_loss, cv_acc = [], []
  84. val_loss_pre, counter = 0, 0
  85. net_best = None
  86. val_acc_list, net_list = [], []
  87. for iter in tqdm(range(args.epochs)):
  88. w_locals, loss_locals = [], []
  89. m = max(int(args.frac * args.num_users), 1)
  90. idxs_users = np.random.choice(range(args.num_users), m, replace=False)
  91. for idx in idxs_users:
  92. local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
  93. w, loss = local.update_weights(net=copy.deepcopy(net_glob))
  94. w_locals.append(copy.deepcopy(w))
  95. loss_locals.append(copy.deepcopy(loss))
  96. # update global weights
  97. w_glob = average_weights(w_locals)
  98. # copy weight to net_glob
  99. net_glob.load_state_dict(w_glob)
  100. # print loss after every 'i' rounds
  101. print_every = 5
  102. loss_avg = sum(loss_locals) / len(loss_locals)
  103. if iter % print_every == 0:
  104. print('\nTrain loss:', loss_avg)
  105. loss_train.append(loss_avg)
  106. # Calculate avg accuracy over all users at every epoch
  107. list_acc, list_loss = [], []
  108. net_glob.eval()
  109. for c in range(args.num_users):
  110. net_local = LocalUpdate(args=args, dataset=dataset_train,
  111. idxs=dict_users[c], tb=summary)
  112. acc, loss = net_local.test(net=net_glob)
  113. list_acc.append(acc)
  114. list_loss.append(loss)
  115. train_accuracy.append(sum(list_acc)/len(list_acc))
  116. # Plot Loss curve
  117. plt.figure()
  118. plt.title('Training Loss vs Communication rounds')
  119. plt.plot(range(len(loss_train)), loss_train, color='r')
  120. plt.ylabel('Training loss')
  121. plt.xlabel('Communication Rounds')
  122. plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_loss.png'.format(args.dataset,
  123. args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
  124. # Plot Average Accuracy vs Communication rounds
  125. plt.figure()
  126. plt.title('Average Accuracy vs Communication rounds')
  127. plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
  128. plt.ylabel('Average Accuracy')
  129. plt.xlabel('Communication Rounds')
  130. plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_E{}_B{}_acc.png'.format(args.dataset,
  131. args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
  132. print("Final Average Accuracy after {} epochs: {:.2f}%".format(
  133. args.epochs, 100.*train_accuracy[-1]))