utils.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
import numpy as np
from sys import exit
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
import update
# from update import LocalUpdate, test_inference


def get_dataset(args):
    """Returns the train and test datasets and a user group, which is a dict
    mapping each user index to the data indices assigned to that user.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR10
            print("Dataset: CIFAR10 IID")
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample non-IID user data from CIFAR10
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                print("Dataset: CIFAR10 equal Non-IID")
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'mnist':
        data_dir = '../data/mnist/'
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # Sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            print("Dataset: MNIST IID")
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample non-IID user data from MNIST
            if args.unequal:
                print("Dataset: MNIST unequal Non-IID")
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                print("Dataset: MNIST equal Non-IID")
                user_groups = mnist_noniid(train_dataset, args.num_users)
    else:
        exit("No such dataset: " + args.dataset)

    return train_dataset, test_dataset, user_groups
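

# A minimal usage sketch for get_dataset (not part of the original module;
# the field names mirror the options this file already reads from `args`,
# but the concrete values below are illustrative only):
def _example_get_dataset():
    from argparse import Namespace
    args = Namespace(dataset='mnist', iid=1, unequal=0, num_users=100)
    train_dataset, test_dataset, user_groups = get_dataset(args)
    # user_groups maps each user index to the sample indices assigned to it
    print(len(train_dataset), len(test_dataset), len(user_groups))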


def average_weights(w):
    """Returns the element-wise average of a list of model state_dicts."""
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
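

# A toy check for average_weights (not part of the original module): two
# hand-built "state_dicts" whose mean is easy to verify by eye.
def _example_average_weights():
    w1 = {'layer.weight': torch.tensor([1.0, 2.0])}
    w2 = {'layer.weight': torch.tensor([3.0, 4.0])}
    avg = average_weights([w1, w2])
    print(avg['layer.weight'])  # tensor([2., 3.])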


def exp_details(args):
    print('\nExperimental details:')
    print(f'    Model          : {args.model}')
    print(f'    Optimizer      : {args.optimizer}')
    print(f'    Learning rate  : {args.lr}')
    print(f'    Global Rounds  : {args.epochs}\n')

    print('    Federated parameters:')
    if args.iid:
        print('    IID')
    else:
        print('    Non-IID')
    print(f'    Fraction of users  : {args.frac}')
    print(f'    Local Batch size   : {args.local_bs}')
    print(f'    Local Epochs       : {args.local_ep}\n')
    return


def set_device(args):
    # Select CPU or GPU; fall back to CPU when no GPU was requested
    # or CUDA is not actually available.
    if not args.gpu or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device(args.gpu_id)
    return device


def build_model(args, train_dataset):
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: flatten the input image into a vector
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        model = MLP(dim_in=len_in, dim_hidden=args.mlpdim,
                    dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model: ' + args.model)
    return model
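

# A minimal wiring sketch (not part of the original module) showing how
# get_dataset, set_device, and build_model fit together. The `args` fields
# are the ones these functions already read; the values are illustrative.
def _example_build_model():
    from argparse import Namespace
    args = Namespace(dataset='mnist', iid=1, unequal=0, num_users=100,
                     model='mlp', mlpdim=64, num_classes=10,
                     gpu=0, gpu_id=0)
    train_dataset, _, _ = get_dataset(args)
    device = set_device(args)  # CPU here, since args.gpu == 0
    model = build_model(args, train_dataset).to(device)
    print(model)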


def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs, logger, cluster_dtype=torch.float32):
    """
    Trains a cluster's global model with federated averaging: in every round,
    a subset of the cluster's users trains locally and the resulting local
    weights are averaged back into the cluster model.
    """
    cluster_train_loss, cluster_train_acc = [], []
    cluster_val_acc_list, cluster_net_list = [], []
    cluster_cv_loss, cluster_cv_acc = [], []
    # print_every = 1
    cluster_val_loss_pre, counter = 0, 0

    for epoch in range(epochs):
        cluster_local_weights, cluster_local_losses = [], []
        # print(f'\n | Cluster Training Round : {epoch+1} |\n')
        cluster_global_model.train()
        # m = max(int(args.frac * len(cluster)), 1)
        # m = max(int(math.ceil(args.frac * len(cluster))), 1)
        # Sample at most 10 users from the cluster for this round
        m = min(len(cluster), 10)
        idxs_users = np.random.choice(cluster, m, replace=False)

        for idx in idxs_users:
            cluster_local_model = update.LocalUpdate(args=args, dataset=train_dataset, idxs=usergrp[idx], logger=logger)
            cluster_w, cluster_loss = cluster_local_model.update_weights(model=copy.deepcopy(cluster_global_model), global_round=epoch, dtype=cluster_dtype)
            cluster_local_weights.append(copy.deepcopy(cluster_w))
            cluster_local_losses.append(copy.deepcopy(cluster_loss))
            # print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))

        # Average the local weights and load them into the cluster model
        cluster_global_weights = average_weights(cluster_local_weights)
        cluster_global_model.load_state_dict(cluster_global_weights)

        cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
        cluster_train_loss.append(cluster_loss_avg)

        # ============== EVAL ==============
        # Calculate average training accuracy over the sampled users
        list_acc, list_loss = [], []
        cluster_global_model.eval()
        for c in idxs_users:
            cluster_local_model = update.LocalUpdate(args=args, dataset=train_dataset, idxs=usergrp[c], logger=logger)
            acc, loss = cluster_local_model.inference(model=cluster_global_model, dtype=cluster_dtype)
            list_acc.append(acc)
            list_loss.append(loss)
        print("Cluster accuracy: ", 100 * sum(list_acc) / len(list_acc))

    return cluster_global_model, cluster_global_weights, cluster_loss_avg
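

# A hedged end-to-end sketch (not part of the original module) of calling
# fl_train on one cluster. It assumes the same `args` fields used above and
# that update.LocalUpdate accepts the arguments this file already passes;
# logger=None is a placeholder -- the repo likely expects a tensorboardX-style
# SummaryWriter there. The cluster here is simply the first ten user ids.
def _example_fl_train():
    from argparse import Namespace
    args = Namespace(dataset='mnist', iid=1, unequal=0, num_users=100,
                     model='mlp', mlpdim=64, num_classes=10,
                     gpu=0, gpu_id=0, optimizer='sgd', lr=0.01,
                     local_bs=10, local_ep=1, frac=0.1, epochs=1, verbose=0)
    train_dataset, _, user_groups = get_dataset(args)
    model = build_model(args, train_dataset)
    cluster = list(user_groups.keys())[:10]
    model, weights, loss = fl_train(args, train_dataset, model, cluster,
                                    user_groups, epochs=1, logger=None)
    print('final cluster loss:', loss)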