federated-hierarchical_main.py 9.8 KB

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
import copy
import math
import time
import pickle

import numpy as np
from tqdm import tqdm
import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details
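
# Overview: this script implements two-tier (hierarchical) federated learning.
# Users are split into clusters; fl_train() runs FedAvg rounds inside each
# cluster, and the outer loop in __main__ then averages the cluster models
# into one global model, repeating until a target train accuracy is reached.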

# BUILD MODEL
def build_model(args, train_dataset):
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: flatten the image dimensions into one input size
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    return global_model
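
# Note: build_model assumes args carries at least `model`, `dataset`, and
# `num_classes` (parsed in options.py, not shown here); the MLP hidden width
# is hard-coded to 64.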

# Cluster training function: runs `epochs` FedAvg rounds within one cluster.
def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
    cluster_train_loss, cluster_train_accuracy = [], []
    cluster_val_acc_list, cluster_net_list = [], []
    cluster_cv_loss, cluster_cv_acc = [], []
    # print_every = 1
    cluster_val_loss_pre, counter = 0, 0

    for epoch in range(epochs):
        cluster_local_weights, cluster_local_losses = [], []
        # print(f'\n | Cluster Training Round : {epoch+1} |\n')

        cluster_global_model.train()
        # sample a fraction of the cluster's users for this round
        # m = max(int(args.frac * len(cluster)), 1)
        m = max(int(math.ceil(args.frac * len(cluster))), 1)
        idxs_users = np.random.choice(cluster, m, replace=False)

        for idx in idxs_users:
            cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=usergrp[idx], logger=logger)
            cluster_w, cluster_loss = cluster_local_model.update_weights(
                model=copy.deepcopy(cluster_global_model), global_round=epoch)
            cluster_local_weights.append(copy.deepcopy(cluster_w))
            cluster_local_losses.append(copy.deepcopy(cluster_loss))
            # print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))

        # averaging local weights within the cluster
        cluster_global_weights = average_weights(cluster_local_weights)

        # update cluster model weights
        cluster_global_model.load_state_dict(cluster_global_weights)

        cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
        cluster_train_loss.append(cluster_loss_avg)

        # ============== EVAL ==============
        # Calculate avg training accuracy over all users in the cluster
        list_acc, list_loss = [], []
        cluster_global_model.eval()
        for c in cluster:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=usergrp[c], logger=logger)
            acc, loss = local_model.inference(model=cluster_global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        cluster_train_accuracy.append(sum(list_acc)/len(list_acc))
        print("Cluster accuracy: ", 100*cluster_train_accuracy[-1])

    return cluster_global_model, cluster_global_weights, cluster_loss_avg
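
# Note: fl_train reads the module-level `logger` created under __main__, so it
# only works when this file is run as a script. `average_weights` (from utils,
# not shown) presumably returns the element-wise mean of the given state_dicts,
# i.e. plain FedAvg aggregation.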

if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)
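    # `user_groups` is presumably a dict mapping each user id to the indices
    # of that user's training samples, as produced by get_dataset in utils.py.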

    # ======= Splitting users into clusters (FL groups) =======
    cluster_size = args.num_users // args.num_clusters
    print("Each cluster size: ", cluster_size)

    # Cluster 1
    A1 = np.arange(cluster_size, dtype=int)
    user_groupsA = {k: user_groups[k] for k in A1 if k in user_groups}
    print("Size of cluster 1: ", len(user_groupsA))
    # Cluster 2
    B1 = np.arange(cluster_size, 2*cluster_size, dtype=int)
    user_groupsB = {k: user_groups[k] for k in B1 if k in user_groups}
    print("Size of cluster 2: ", len(user_groupsB))
    # # Cluster 3
    # C1 = np.arange(2*cluster_size, 3*cluster_size, dtype=int)
    # user_groupsC = {k: user_groups[k] for k in C1 if k in user_groups}
    # print("Size of cluster 3: ", len(user_groupsC))
    # # Cluster 4
    # D1 = np.arange(3*cluster_size, 4*cluster_size, dtype=int)
    # user_groupsD = {k: user_groups[k] for k in D1 if k in user_groups}
    # print("Size of cluster 4: ", len(user_groupsD))
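    # Example: with num_users=20 and num_clusters=2, cluster_size is 10, so
    # A1 = [0..9] and B1 = [10..19]; each user_groupsX dict then holds only
    # that cluster's user-to-data-index mapping.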

    # MODEL PARAM SUMMARY
    global_model = build_model(args, train_dataset)
    pytorch_total_params = sum(p.numel() for p in global_model.parameters())
    print("Model total number of parameters: ", pytorch_total_params)
    # from torchsummary import summary
    # summary(global_model, (1, 28, 28))

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # ======= Set the cluster models to train and send them to device. =======
    # Cluster A
    cluster_modelA = build_model(args, train_dataset)
    cluster_modelA.to(device)
    cluster_modelA.train()
    # copy weights
    cluster_modelA_weights = cluster_modelA.state_dict()
    # Cluster B
    cluster_modelB = build_model(args, train_dataset)
    cluster_modelB.to(device)
    cluster_modelB.train()
    # copy weights
    cluster_modelB_weights = cluster_modelB.state_dict()
    # # Cluster C
    # cluster_modelC = build_model(args, train_dataset)
    # cluster_modelC.to(device)
    # cluster_modelC.train()
    # # copy weights
    # cluster_modelC_weights = cluster_modelC.state_dict()
    # # Cluster D
    # cluster_modelD = build_model(args, train_dataset)
    # cluster_modelD.to(device)
    # cluster_modelD.train()
    # # copy weights
    # cluster_modelD_weights = cluster_modelD.state_dict()
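    # Clusters C and D above are scaffolding for a num_clusters=4 run and are
    # left commented out; only clusters A and B are trained below.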

    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0
    testacc_check, epoch = 0, 0

    # for epoch in tqdm(range(args.epochs)):
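    # Train until the running average train accuracy (in %) reaches the
    # target given by args.test_acc, rather than for a fixed number of epochs.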
    while testacc_check < args.test_acc:
        local_weights, local_losses, local_accuracies = [], [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # ============== TRAIN ==============
        global_model.train()

        # Cluster A
        A_model, A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs)
        local_weights.append(copy.deepcopy(A_weights))
        local_losses.append(copy.deepcopy(A_losses))
        cluster_modelA = A_model
        # Cluster B
        B_model, B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs)
        local_weights.append(copy.deepcopy(B_weights))
        local_losses.append(copy.deepcopy(B_losses))
        cluster_modelB = B_model
        # # Cluster C
        # C_model, C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
        # local_weights.append(copy.deepcopy(C_weights))
        # local_losses.append(copy.deepcopy(C_losses))
        # # Cluster D
        # D_model, D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs)
        # local_weights.append(copy.deepcopy(D_weights))
        # local_losses.append(copy.deepcopy(D_losses))
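        # Hierarchical aggregation step: the per-cluster models trained above
        # are averaged into the global model, mirroring FedAvg one level up.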
        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # ============== EVAL ==============
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc)/len(list_acc))

        testacc_check = 100*train_accuracy[-1]
        epoch = epoch + 1

        # print global training loss after every `print_every` rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    # print(f' \n Results after {args.epochs} global rounds of training:')
    print(f"\nAvg Training Stats after {epoch} global rounds:")
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/HFL_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, epoch, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)
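
# Example invocation (flag names assumed from options.py's args_parser;
# adjust to the actual flags defined there):
#   python federated-hierarchical_main.py --dataset mnist --model cnn \
#       --num_users 20 --num_clusters 2 --Cepochs 5 --frac 0.1 --test_acc 95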