# main_fedavg.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import os
import pickle
import sys

import matplotlib
# matplotlib.use('Agg')  # headless backend; must be called before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import autograd
from tensorboardX import SummaryWriter

from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_unequal
from options import args_parser
from Update import LocalUpdate
from FedNets import MLP, CNNMnist, CNNCifar
from averaging import average_weights
  22. if __name__ == '__main__':
  23. # parse args
  24. args = args_parser()
  25. # define paths
  26. path_project = os.path.abspath('..')
  27. summary = SummaryWriter('local')
  28. # load dataset and split users
  29. if args.dataset == 'mnist':
  30. dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
  31. transform=transforms.Compose([
  32. transforms.ToTensor(),
  33. transforms.Normalize((0.1307,), (0.3081,))
  34. ]))
  35. # sample users
  36. if args.iid:
  37. dict_users = mnist_iid(dataset_train, args.num_users)
  38. elif args.unequal:
  39. dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
  40. else:
  41. dict_users = mnist_noniid(dataset_train, args.num_users)
  42. elif args.dataset == 'cifar':
  43. transform = transforms.Compose(
  44. [transforms.ToTensor(),
  45. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  46. dataset_train = datasets.CIFAR10(
  47. '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
  48. if args.iid:
  49. dict_users = cifar_iid(dataset_train, args.num_users)
  50. else:
  51. dict_users = cifar_noniid(dataset_train, args.num_users)
  52. else:
  53. exit('Error: unrecognized dataset')
  54. img_size = dataset_train[0][0].shape
  55. # BUILD MODEL
  56. if args.model == 'cnn' and args.dataset == 'mnist':
  57. if args.gpu != -1:
  58. torch.cuda.set_device(args.gpu)
  59. net_glob = CNNMnist(args=args).cuda()
  60. else:
  61. net_glob = CNNMnist(args=args)
  62. elif args.model == 'cnn' and args.dataset == 'cifar':
  63. if args.gpu != -1:
  64. torch.cuda.set_device(args.gpu)
  65. net_glob = CNNCifar(args=args).cuda()
  66. else:
  67. net_glob = CNNCifar(args=args)
  68. elif args.model == 'mlp':
  69. len_in = 1
  70. for x in img_size:
  71. len_in *= x
  72. if args.gpu != -1:
  73. torch.cuda.set_device(args.gpu)
  74. net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
  75. else:
  76. net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
  77. else:
  78. exit('Error: unrecognized model')
  79. print(net_glob)
  80. net_glob.train()
  81. # copy weights
  82. w_glob = net_glob.state_dict()
  83. # training
  84. train_loss = []
  85. train_accuracy = []
  86. cv_loss, cv_acc = [], []
  87. val_loss_pre, counter = 0, 0
  88. net_best = None
  89. val_acc_list, net_list = [], []
  90. for iter in tqdm(range(args.epochs)):
  91. w_locals, loss_locals = [], []
  92. m = max(int(args.frac * args.num_users), 1)
  93. idxs_users = np.random.choice(range(args.num_users), m, replace=False)
  94. for idx in idxs_users:
  95. local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
  96. w, loss = local.update_weights(net=copy.deepcopy(net_glob))
  97. w_locals.append(copy.deepcopy(w))
  98. loss_locals.append(copy.deepcopy(loss))
  99. # update global weights
  100. w_glob = average_weights(w_locals)
  101. # copy weight to net_glob
  102. net_glob.load_state_dict(w_glob)
  103. # print loss after every 'i' rounds
  104. print_every = 5
  105. loss_avg = sum(loss_locals) / len(loss_locals)
  106. if iter % print_every == 0:
  107. print('\nTrain loss:', loss_avg)
  108. train_loss.append(loss_avg)
  109. # Calculate avg accuracy over all users at every epoch
  110. list_acc, list_loss = [], []
  111. net_glob.eval()
  112. for c in range(args.num_users):
  113. net_local = LocalUpdate(args=args, dataset=dataset_train,
  114. idxs=dict_users[c], tb=summary)
  115. acc, loss = net_local.test(net=net_glob)
  116. list_acc.append(acc)
  117. list_loss.append(loss)
  118. train_accuracy.append(sum(list_acc)/len(list_acc))
  119. # Plot Loss curve
  120. # plt.figure()
  121. # plt.title('Training Loss vs Communication rounds')
  122. # plt.plot(range(len(train_loss)), train_loss, color='r')
  123. # plt.ylabel('Training loss')
  124. # plt.xlabel('Communication Rounds')
  125. # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
  126. # args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
  127. #
  128. # # Plot Average Accuracy vs Communication rounds
  129. # plt.figure()
  130. # plt.title('Average Accuracy vs Communication rounds')
  131. # plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
  132. # plt.ylabel('Average Accuracy')
  133. # plt.xlabel('Communication Rounds')
  134. # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
  135. # args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
  136. print("Final Average Accuracy after {} epochs: {:.2f}%".format(
  137. args.epochs, 100.*train_accuracy[-1]))
  138. # Saving the objects train_loss and train_accuracy:
  139. file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
  140. args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
  141. with open(file_name, 'wb') as f:
  142. pickle.dump([train_loss, train_accuracy], f)