federated-hierarchical8_main_fp16.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import math
import pickle
import random
import time

import numpy as np
import torch
from tqdm import tqdm
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details, set_device, build_model, fl_train


if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # Select CPU or GPU
    device = set_device(args)

    # Load dataset and user groups
    train_dataset, test_dataset, user_groupsold = get_dataset(args)

    # ======= Shuffle the user IDs so cluster membership is random =======
    keys = list(user_groupsold.keys())
    random.shuffle(keys)
    user_groups = {key: user_groupsold[key] for key in keys}
    keylist = list(user_groups.keys())
    print("keylist: ", keylist)
    # ======= Split users into clusters (FL groups) =======
    if args.num_clusters != 8:
        exit("This script assumes exactly 8 clusters; set num_clusters to 8.")
    cluster_size = int(args.num_users / args.num_clusters)
    print("Each cluster size: ", cluster_size)
    # Cluster 1
    A1 = keylist[:cluster_size]
    # A1 = np.random.choice(keylist, cluster_size, replace=False)
    print("A1: ", A1)
    user_groupsA = {k: user_groups[k] for k in A1 if k in user_groups}
    print("Size of cluster 1: ", len(user_groupsA))
    # Cluster 2
    B1 = keylist[cluster_size:2*cluster_size]
    # B1 = np.random.choice(keylist, cluster_size, replace=False)
    print("B1: ", B1)
    user_groupsB = {k: user_groups[k] for k in B1 if k in user_groups}
    print("Size of cluster 2: ", len(user_groupsB))
    # Cluster 3
    C1 = keylist[2*cluster_size:3*cluster_size]
    # C1 = np.random.choice(keylist, cluster_size, replace=False)
    print("C1: ", C1)
    user_groupsC = {k: user_groups[k] for k in C1 if k in user_groups}
    print("Size of cluster 3: ", len(user_groupsC))
    # Cluster 4
    D1 = keylist[3*cluster_size:4*cluster_size]
    # D1 = np.random.choice(keylist, cluster_size, replace=False)
    print("D1: ", D1)
    user_groupsD = {k: user_groups[k] for k in D1 if k in user_groups}
    print("Size of cluster 4: ", len(user_groupsD))
    # Cluster 5
    E1 = keylist[4*cluster_size:5*cluster_size]
    # E1 = np.random.choice(keylist, cluster_size, replace=False)
    print("E1: ", E1)
    user_groupsE = {k: user_groups[k] for k in E1 if k in user_groups}
    print("Size of cluster 5: ", len(user_groupsE))
    # Cluster 6
    F1 = keylist[5*cluster_size:6*cluster_size]
    # F1 = np.random.choice(keylist, cluster_size, replace=False)
    print("F1: ", F1)
    user_groupsF = {k: user_groups[k] for k in F1 if k in user_groups}
    print("Size of cluster 6: ", len(user_groupsF))
    # Cluster 7
    G1 = keylist[6*cluster_size:7*cluster_size]
    # G1 = np.random.choice(keylist, cluster_size, replace=False)
    print("G1: ", G1)
    user_groupsG = {k: user_groups[k] for k in G1 if k in user_groups}
    print("Size of cluster 7: ", len(user_groupsG))
    # Cluster 8
    H1 = keylist[7*cluster_size:]
    # H1 = np.random.choice(keylist, cluster_size, replace=False)
    print("H1: ", H1)
    user_groupsH = {k: user_groups[k] for k in H1 if k in user_groups}
    print("Size of cluster 8: ", len(user_groupsH))
    # MODEL PARAM SUMMARY
    global_model = build_model(args, train_dataset)
    pytorch_total_params = sum(p.numel() for p in global_model.parameters())
    print("Model total number of parameters: ", pytorch_total_params)
    # from torchsummary import summary
    # summary(global_model, (1, 28, 28))

    # Set the model to train and send it to device.
    global_model.to(device)
    # Cast the model weights to 16-bit floating point
    global_model.to(dtype=torch.float16)
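    # Note: .to(dtype=torch.float16) converts every parameter and buffer to
    # half precision, so all subsequent training and inference runs in pure
    # fp16. This halves memory traffic but narrows the representable range;
    # mixed-precision training (e.g. torch.cuda.amp) is the more common way
    # to get fp16 speed while keeping fp32 master weights.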
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    # ======= Set the cluster models to train and send them to device. =======
    # Cluster A
    cluster_modelA = build_model(args, train_dataset)
    cluster_modelA.to(device)
    cluster_modelA.to(dtype=torch.float16)
    cluster_modelA.train()
    cluster_modelA_weights = cluster_modelA.state_dict()
    # Cluster B
    cluster_modelB = build_model(args, train_dataset)
    cluster_modelB.to(device)
    cluster_modelB.to(dtype=torch.float16)
    cluster_modelB.train()
    cluster_modelB_weights = cluster_modelB.state_dict()
    # Cluster C
    cluster_modelC = build_model(args, train_dataset)
    cluster_modelC.to(device)
    cluster_modelC.to(dtype=torch.float16)
    cluster_modelC.train()
    cluster_modelC_weights = cluster_modelC.state_dict()
    # Cluster D
    cluster_modelD = build_model(args, train_dataset)
    cluster_modelD.to(device)
    cluster_modelD.to(dtype=torch.float16)
    cluster_modelD.train()
    cluster_modelD_weights = cluster_modelD.state_dict()
    # Cluster E
    cluster_modelE = build_model(args, train_dataset)
    cluster_modelE.to(device)
    cluster_modelE.to(dtype=torch.float16)
    cluster_modelE.train()
    cluster_modelE_weights = cluster_modelE.state_dict()
    # Cluster F
    cluster_modelF = build_model(args, train_dataset)
    cluster_modelF.to(device)
    cluster_modelF.to(dtype=torch.float16)
    cluster_modelF.train()
    cluster_modelF_weights = cluster_modelF.state_dict()
    # Cluster G
    cluster_modelG = build_model(args, train_dataset)
    cluster_modelG.to(device)
    cluster_modelG.to(dtype=torch.float16)
    cluster_modelG.train()
    cluster_modelG_weights = cluster_modelG.state_dict()
    # Cluster H
    cluster_modelH = build_model(args, train_dataset)
    cluster_modelH.to(device)
    cluster_modelH.to(dtype=torch.float16)
    cluster_modelH.train()
    cluster_modelH_weights = cluster_modelH.state_dict()
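    # As with the user split, the eight cluster models could be built in a
    # loop, e.g.
    #   cluster_models = [build_model(args, train_dataset).to(device)
    #                     for _ in range(args.num_clusters)]
    # The explicit names are kept to match the per-cluster training blocks
    # below.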
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0
    testacc_check, epoch = 0, 0
    idx = np.random.randint(0, 99)
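    # idx is only referenced by the commented-out debug print in the
    # evaluation block below.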
    # for epoch in tqdm(range(args.epochs)):
    for epoch in range(args.epochs):
        # while testacc_check < args.test_acc or epoch < args.epochs:
        # while epoch < args.epochs:
        local_weights, local_losses, local_accuracies = [], [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # ============== TRAIN ==============
        global_model.train()
        # ===== Cluster A =====
        A_model, A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(A_weights))
        local_losses.append(copy.deepcopy(A_losses))
        cluster_modelA = global_model  # = A_model
        # ===== Cluster B =====
        B_model, B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(B_weights))
        local_losses.append(copy.deepcopy(B_losses))
        cluster_modelB = global_model  # = B_model
        # ===== Cluster C =====
        C_model, C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(C_weights))
        local_losses.append(copy.deepcopy(C_losses))
        cluster_modelC = global_model  # = C_model
        # ===== Cluster D =====
        D_model, D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(D_weights))
        local_losses.append(copy.deepcopy(D_losses))
        cluster_modelD = global_model  # = D_model
        # ===== Cluster E =====
        E_model, E_weights, E_losses = fl_train(args, train_dataset, cluster_modelE, E1, user_groupsE, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(E_weights))
        local_losses.append(copy.deepcopy(E_losses))
        cluster_modelE = global_model  # = E_model
        # ===== Cluster F =====
        F_model, F_weights, F_losses = fl_train(args, train_dataset, cluster_modelF, F1, user_groupsF, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(F_weights))
        local_losses.append(copy.deepcopy(F_losses))
        cluster_modelF = global_model  # = F_model
        # ===== Cluster G =====
        G_model, G_weights, G_losses = fl_train(args, train_dataset, cluster_modelG, G1, user_groupsG, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(G_weights))
        local_losses.append(copy.deepcopy(G_losses))
        cluster_modelG = global_model  # = G_model
        # ===== Cluster H =====
        H_model, H_weights, H_losses = fl_train(args, train_dataset, cluster_modelH, H1, user_groupsH, args.Cepochs, logger, cluster_dtype=torch.float16)
        local_weights.append(copy.deepcopy(H_weights))
        local_losses.append(copy.deepcopy(H_losses))
        cluster_modelH = global_model  # = H_model
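        # Each cluster model is re-synchronised to the shared global model for
        # the next round; the trailing "= X_model" comments preserve the
        # alternative of carrying each cluster's own weights forward instead.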
        # average the cluster weights into the global weights
        global_weights = average_weights(local_weights)
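        # average_weights (from utils) is assumed to take the element-wise
        # mean of the eight cluster state dicts, i.e. plain FedAvg with equal
        # cluster weighting.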
        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # ============== EVAL ==============
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        # print("========== idx ========== ", idx)
        for c in range(args.num_users):
            # for c in range(cluster_size):
            # C = np.random.choice(keylist, int(args.frac * args.num_users), replace=False)  # random set of clients
            # print("C: ", C)
            # for c in C:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model, dtype=torch.float16)
            list_acc.append(acc)
            list_loss.append(loss)
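        # Evaluating against every user each round is the most thorough (and
        # slowest) option; the commented-out variants above sample a random
        # subset of clients instead.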
        train_accuracy.append(sum(list_acc)/len(list_acc))
        testacc_check = 100 * train_accuracy[-1]
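        # testacc_check and the manual epoch increment below exist to support
        # the commented-out accuracy-threshold while-loop variant of the
        # training loop.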
        epoch = epoch + 1

        # print global training loss after every 'print_every' rounds
        # (epoch was just incremented, so it already equals the number of
        # completed rounds)
        if epoch % print_every == 0:
            print(f' \nAvg Training Stats after {epoch} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
            print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset, dtype=torch.float16)

    # print(f' \n Results after {args.epochs} global rounds of training:')
    print(f"\nAvg Training Stats after {epoch} global rounds:")
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects_fp16/HFL8_{}_{}_{}_lr[{}]_C[{}]_iid[{}]_E[{}]_B[{}]_FP16.pkl'.\
        format(args.dataset, args.model, epoch, args.lr, args.frac, args.iid,
               args.local_ep, args.local_bs)
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))