federated-hierarchical_main.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
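"""Hierarchical (cluster-based) federated learning entry point.

Users are split into clusters; each cluster runs its own federated
averaging rounds (fl_train), and the resulting cluster models are then
averaged into a single global model on every global round.
"""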
import os
import copy
import time
import pickle

import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter
from torchsummary import summary

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details

# BUILD MODEL
def build_model(args, train_dataset):
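    """Build the model selected by args.model / args.dataset:
    a CNN for mnist/fmnist/cifar, or an MLP sized from the input images."""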
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: flatten each image into a 1-D input vector
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=200,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    return global_model


# Defining the training function
def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
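    """Run `epochs` rounds of federated averaging over the users in `cluster`
    and return the final averaged weights and the last round's mean loss."""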
    cluster_train_loss = []

    for epoch in tqdm(range(epochs)):
        cluster_local_weights, cluster_local_losses = [], []
        print(f'\n | Cluster Training Round : {epoch+1} |\n')

        cluster_global_model.train()
        # Sample a fraction of the cluster's users for this round
        m = max(int(args.frac * len(cluster)), 1)
        idxs_users = np.random.choice(cluster, m, replace=False)

        for idx in idxs_users:
            # `logger` is the module-level SummaryWriter created in __main__
            cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=usergrp[idx], logger=logger)
            cluster_w, cluster_loss = cluster_local_model.update_weights(
                model=copy.deepcopy(cluster_global_model), global_round=epoch)
            cluster_local_weights.append(copy.deepcopy(cluster_w))
            cluster_local_losses.append(copy.deepcopy(cluster_loss))

        # Average the sampled users' weights and load them into the cluster model
        cluster_global_weights = average_weights(cluster_local_weights)
        cluster_global_model.load_state_dict(cluster_global_weights)

        cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
        cluster_train_loss.append(cluster_loss_avg)

    return cluster_global_weights, cluster_loss_avg


if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)
    # ======= Splitting into clusters (FL groups) =======
    # Assumes num_users is evenly divisible by num_clusters
    cluster_size = args.num_users // args.num_clusters
    print("Each cluster size: ", cluster_size)

    # Cluster 1
    A1 = np.arange(cluster_size, dtype=int)
    user_groupsA = {k: user_groups[k] for k in A1 if k in user_groups}
    print("Size of cluster 1: ", len(user_groupsA))
    # Cluster 2
    B1 = np.arange(cluster_size, 2 * cluster_size, dtype=int)
    user_groupsB = {k: user_groups[k] for k in B1 if k in user_groups}
    print("Size of cluster 2: ", len(user_groupsB))
    # MODEL PARAM SUMMARY
    global_model = build_model(args, train_dataset)
    pytorch_total_params = sum(p.numel() for p in global_model.parameters())
    print("Total model parameters:", pytorch_total_params)

    # Input size (1, 28, 28) matches MNIST/FMNIST; adjust for other datasets
    summary(global_model, (1, 28, 28))

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    # ======= Set the cluster models to train and send them to device =======
    # Cluster A model
    cluster_modelA = build_model(args, train_dataset)
    cluster_modelA.to(device)
    cluster_modelA.train()
    # copy weights
    cluster_modelA_weights = cluster_modelA.state_dict()

    # Cluster B model
    cluster_modelB = build_model(args, train_dataset)
    cluster_modelB.to(device)
    cluster_modelB.train()
    # copy weights
    cluster_modelB_weights = cluster_modelB.state_dict()
    train_loss, train_accuracy = [], []
    print_every = 1
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # ============== TRAIN ==============
        global_model.train()

        # Cluster A: run 2 rounds of intra-cluster federated averaging
        A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA,
                                       A1, user_groupsA, 2)
        local_weights.append(copy.deepcopy(A_weights))
        local_losses.append(copy.deepcopy(A_losses))
        # Cluster B: same procedure on the second cluster
        B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB,
                                       B1, user_groupsB, 2)
        local_weights.append(copy.deepcopy(B_weights))
        local_losses.append(copy.deepcopy(B_losses))

        # Average the cluster models into the global model
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)
        # ============== EVAL ==============
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'print_every' rounds
        if (epoch + 1) % print_every == 0:
            print(f'\nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))
    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f'\n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))
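
    # `pickle` is imported above but otherwise unused; a minimal sketch of how
    # the loss/accuracy curves could be persisted for later plotting. The
    # save path below is a hypothetical choice, not part of the original script.
    save_path = '../save/objects/fedhier_{}_{}_{}ep.pkl'.format(
        args.dataset, args.model, args.epochs)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)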