federated-hierarchical_main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details
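
# Note: options, update, models, and utils are repo-local modules. Judging
# from the call sites below, get_dataset(args) returns the training set, the
# test set, and a dict mapping each user id to the indices of that user's
# samples; this description is inferred from usage, not from the modules'
# source.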

# BUILD MODEL
def build_model(args, train_dataset):
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: flatten the image dimensions into the
        # input size
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=200,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    return global_model

# Federated training within a single cluster
def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
    cluster_train_loss, cluster_train_accuracy = [], []
    cluster_val_acc_list, cluster_net_list = [], []
    cluster_cv_loss, cluster_cv_acc = [], []
    # print_every = 1
    cluster_val_loss_pre, counter = 0, 0

    for epoch in range(epochs):
        cluster_local_weights, cluster_local_losses = [], []
        # print(f'\n | Cluster Training Round : {epoch+1} |\n')

        cluster_global_model.train()
        # Sample a fraction of the cluster's users for this round
        m = max(int(args.frac * len(cluster)), 1)
        idxs_users = np.random.choice(cluster, m, replace=False)

        for idx in idxs_users:
            # logger is defined in the __main__ block below
            cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=usergrp[idx], logger=logger)
            cluster_w, cluster_loss = cluster_local_model.update_weights(
                model=copy.deepcopy(cluster_global_model), global_round=epoch)
            cluster_local_weights.append(copy.deepcopy(cluster_w))
            cluster_local_losses.append(copy.deepcopy(cluster_loss))
            print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(
                epoch, idx, cluster_loss))

        # Average the local weights and update the cluster model
        cluster_global_weights = average_weights(cluster_local_weights)
        cluster_global_model.load_state_dict(cluster_global_weights)

        cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
        cluster_train_loss.append(cluster_loss_avg)

    return cluster_global_weights, cluster_loss_avg
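
# For reference, average_weights (imported from utils) is used above as
# coordinate-wise FedAvg over the collected state dicts. A minimal sketch of
# what such a helper typically looks like (an assumption; the repo's utils.py
# is authoritative):
#
#   def average_weights(w):
#       w_avg = copy.deepcopy(w[0])
#       for key in w_avg.keys():
#           for i in range(1, len(w)):
#               w_avg[key] += w[i][key]
#           w_avg[key] = torch.div(w_avg[key], len(w))
#       return w_avg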

if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # ======= Splitting into clusters. FL groups =======
    # Integer division so the user ids below stay integral
    cluster_size = args.num_users // args.num_clusters
    print("Each cluster size: ", cluster_size)

    # Cluster 1
    A1 = np.arange(cluster_size, dtype=int)
    user_groupsA = {k: user_groups[k] for k in A1 if k in user_groups}
    print("Size of cluster 1: ", len(user_groupsA))

    # Cluster 2
    B1 = np.arange(cluster_size, 2 * cluster_size, dtype=int)
    user_groupsB = {k: user_groups[k] for k in B1 if k in user_groups}
    print("Size of cluster 2: ", len(user_groupsB))
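
    # The two-cluster split above generalizes directly; a sketch for an
    # arbitrary args.num_clusters (illustration only, not used below):
    #
    #   clusters = [np.arange(i * cluster_size, (i + 1) * cluster_size, dtype=int)
    #               for i in range(args.num_clusters)]
    #   cluster_groups = [{k: user_groups[k] for k in c if k in user_groups}
    #                     for c in clusters]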

    # MODEL PARAM SUMMARY
    global_model = build_model(args, train_dataset)
    pytorch_total_params = sum(p.numel() for p in global_model.parameters())
    print(pytorch_total_params)

    from torchsummary import summary
    # NOTE: (1, 28, 28) matches MNIST/FMNIST inputs; CIFAR would need (3, 32, 32)
    summary(global_model, (1, 28, 28))

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # ======= Set the cluster models to train and send them to device. =======
    # Cluster A
    cluster_modelA = build_model(args, train_dataset)
    cluster_modelA.to(device)
    cluster_modelA.train()
    # copy weights
    cluster_modelA_weights = cluster_modelA.state_dict()

    # Cluster B
    cluster_modelB = build_model(args, train_dataset)
    cluster_modelB.to(device)
    cluster_modelB.train()
    # copy weights
    cluster_modelB_weights = cluster_modelB.state_dict()

    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0
    testacc_check, epoch = 0, 0

    # Train until the average train accuracy reaches args.test_acc (in %)
    # for epoch in tqdm(range(args.epochs)):
    while testacc_check < args.test_acc:
        local_weights, local_losses, local_accuracies = [], [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # ============== TRAIN ==============
        global_model.train()

        # Cluster A
        A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA,
                                       A1, user_groupsA, args.epochs)
        local_weights.append(copy.deepcopy(A_weights))
        local_losses.append(copy.deepcopy(A_losses))

        # Cluster B
        B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB,
                                       B1, user_groupsB, args.epochs)
        local_weights.append(copy.deepcopy(B_weights))
        local_losses.append(copy.deepcopy(B_losses))

        # averaging the cluster weights into the global weights
        global_weights = average_weights(local_weights)
        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # ============== EVAL ==============
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # Stopping criterion: average train accuracy in percent
        testacc_check = 100 * train_accuracy[-1]
        epoch = epoch + 1

        # print global training loss after every 'print_every' rounds
        # (epoch has already been incremented above)
        if epoch % print_every == 0:
            print(f' \nAvg Training Stats after {epoch} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    # print(f' \n Results after {args.epochs} global rounds of training:')
    print(f"\nAvg Training Stats after {epoch} global rounds:")
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, epoch, args.frac, args.iid,
               args.local_ep, args.local_bs)
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)
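
# Example usage (a sketch; flag names are inferred from the args.* attributes
# used above and must be checked against options.py):
#
#   python federated-hierarchical_main.py --model=cnn --dataset=mnist \
#       --num_users=100 --num_clusters=2 --frac=0.1 --test_acc=95
#
# The saved pickle can later be reloaded for plotting, e.g.:
#
#   with open(file_name, 'rb') as f:
#       train_loss, train_accuracy = pickle.load(f)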