federated-hierarchical8_main.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
import copy
import time
import pickle
import math
import random

import numpy as np
from tqdm import tqdm
import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details
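
# Hierarchical federated learning with 8 clusters: users are shuffled and split
# into eight groups, each cluster runs FedAvg over its own users in fl_train(),
# and the resulting cluster weights are averaged into one global model each
# global round (see the __main__ block below).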

# BUILD MODEL
def build_model(args, train_dataset):
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=args.mlpdim,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    return global_model
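
# build_model() is reused below to construct the global model and each of the
# eight per-cluster models with the same architecture.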

# Federated training within a single cluster: standard FedAvg over the
# cluster's users for `epochs` rounds.
def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
    cluster_train_loss, cluster_train_acc = [], []
    cluster_val_acc_list, cluster_net_list = [], []
    cluster_cv_loss, cluster_cv_acc = [], []
    # print_every = 1
    cluster_val_loss_pre, counter = 0, 0

    for epoch in range(epochs):
        cluster_local_weights, cluster_local_losses = [], []
        # print(f'\n | Cluster Training Round : {epoch+1} |\n')
        cluster_global_model.train()

        # Sample at most 10 users from the cluster each round
        # m = max(int(args.frac * len(cluster)), 1)
        # m = max(int(math.ceil(args.frac * len(cluster))), 1)
        m = min(len(cluster), 10)
        idxs_users = np.random.choice(cluster, m, replace=False)

        for idx in idxs_users:
            # `logger` is the module-level SummaryWriter created in __main__
            cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=usergrp[idx], logger=logger)
            cluster_w, cluster_loss = cluster_local_model.update_weights(
                model=copy.deepcopy(cluster_global_model), global_round=epoch)
            cluster_local_weights.append(copy.deepcopy(cluster_w))
            cluster_local_losses.append(copy.deepcopy(cluster_loss))
            # print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))

        # Average the local weights and update the cluster model
        cluster_global_weights = average_weights(cluster_local_weights)
        cluster_global_model.load_state_dict(cluster_global_weights)

        cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
        cluster_train_loss.append(cluster_loss_avg)

        # ============== EVAL ==============
        # Average training accuracy over the users sampled in this round
        list_acc, list_loss = [], []
        cluster_global_model.eval()
        for c in idxs_users:
            cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=usergrp[c], logger=logger)
            acc, loss = cluster_local_model.inference(model=cluster_global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        print("Cluster accuracy: ", 100 * sum(list_acc) / len(list_acc))

    return cluster_global_model, cluster_global_weights, cluster_loss_avg
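
# ======= Main experiment driver =======
# Builds the global and per-cluster models, splits the shuffled users into 8
# clusters, and alternates cluster-level FedAvg with global averaging.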
if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groupsold = get_dataset(args)
    # user_groups = user_groupsold
    # keylist = list(user_groups.keys())

    # ======= Shuffle the user groups =======
    keys = list(user_groupsold.keys())
    random.shuffle(keys)
    user_groups = {key: user_groupsold[key] for key in keys}
    # print(user_groups.keys())
    keylist = list(user_groups.keys())
    print("keylist: ", keylist)

    # ======= Splitting into clusters (FL groups) =======
    cluster_size = int(args.num_users / args.num_clusters)
    # cluster_size = 50
    print("Each cluster size: ", cluster_size)

    # Cluster 1
    A1 = keylist[:cluster_size]
    # A1 = np.random.choice(keylist, cluster_size, replace=False)
    print("A1: ", A1)
    user_groupsA = {k: user_groups[k] for k in A1 if k in user_groups}
    print("Size of cluster 1: ", len(user_groupsA))
    # Cluster 2
    B1 = keylist[cluster_size:2 * cluster_size]
    # B1 = np.random.choice(keylist, cluster_size, replace=False)
    print("B1: ", B1)
    user_groupsB = {k: user_groups[k] for k in B1 if k in user_groups}
    print("Size of cluster 2: ", len(user_groupsB))
    # Cluster 3
    C1 = keylist[2 * cluster_size:3 * cluster_size]
    # C1 = np.random.choice(keylist, cluster_size, replace=False)
    print("C1: ", C1)
    user_groupsC = {k: user_groups[k] for k in C1 if k in user_groups}
    print("Size of cluster 3: ", len(user_groupsC))
    # Cluster 4
    D1 = keylist[3 * cluster_size:4 * cluster_size]
    # D1 = np.random.choice(keylist, cluster_size, replace=False)
    print("D1: ", D1)
    user_groupsD = {k: user_groups[k] for k in D1 if k in user_groups}
    print("Size of cluster 4: ", len(user_groupsD))
    # Cluster 5
    E1 = keylist[4 * cluster_size:5 * cluster_size]
    print("E1: ", E1)
    user_groupsE = {k: user_groups[k] for k in E1 if k in user_groups}
    print("Size of cluster 5: ", len(user_groupsE))
    # Cluster 6
    F1 = keylist[5 * cluster_size:6 * cluster_size]
    print("F1: ", F1)
    user_groupsF = {k: user_groups[k] for k in F1 if k in user_groups}
    print("Size of cluster 6: ", len(user_groupsF))
    # Cluster 7
    G1 = keylist[6 * cluster_size:7 * cluster_size]
    print("G1: ", G1)
    user_groupsG = {k: user_groups[k] for k in G1 if k in user_groups}
    print("Size of cluster 7: ", len(user_groupsG))
    # Cluster 8
    H1 = keylist[7 * cluster_size:]
    print("H1: ", H1)
    user_groupsH = {k: user_groups[k] for k in H1 if k in user_groups}
    print("Size of cluster 8: ", len(user_groupsH))
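    # Note: cluster 8 takes keylist[7*cluster_size:], so any remainder when
    # num_users is not divisible by num_clusters ends up in cluster 8.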

    # MODEL PARAM SUMMARY
    global_model = build_model(args, train_dataset)
    pytorch_total_params = sum(p.numel() for p in global_model.parameters())
    print("Model total number of parameters: ", pytorch_total_params)
    # from torchsummary import summary
    # summary(global_model, (1, 28, 28))

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # ======= Set the cluster models to train and send them to device. =======
    # Cluster A
    cluster_modelA = build_model(args, train_dataset)
    cluster_modelA.to(device)
    cluster_modelA.train()
    cluster_modelA_weights = cluster_modelA.state_dict()
    # Cluster B
    cluster_modelB = build_model(args, train_dataset)
    cluster_modelB.to(device)
    cluster_modelB.train()
    cluster_modelB_weights = cluster_modelB.state_dict()
    # Cluster C
    cluster_modelC = build_model(args, train_dataset)
    cluster_modelC.to(device)
    cluster_modelC.train()
    cluster_modelC_weights = cluster_modelC.state_dict()
    # Cluster D
    cluster_modelD = build_model(args, train_dataset)
    cluster_modelD.to(device)
    cluster_modelD.train()
    cluster_modelD_weights = cluster_modelD.state_dict()
    # Cluster E
    cluster_modelE = build_model(args, train_dataset)
    cluster_modelE.to(device)
    cluster_modelE.train()
    cluster_modelE_weights = cluster_modelE.state_dict()
    # Cluster F
    cluster_modelF = build_model(args, train_dataset)
    cluster_modelF.to(device)
    cluster_modelF.train()
    cluster_modelF_weights = cluster_modelF.state_dict()
    # Cluster G
    cluster_modelG = build_model(args, train_dataset)
    cluster_modelG.to(device)
    cluster_modelG.train()
    cluster_modelG_weights = cluster_modelG.state_dict()
    # Cluster H
    cluster_modelH = build_model(args, train_dataset)
    cluster_modelH.to(device)
    cluster_modelH.train()
    cluster_modelH_weights = cluster_modelH.state_dict()

    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0
    testacc_check, epoch = 0, 0
    idx = np.random.randint(0, 99)
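
    # ======= Hierarchical training loop =======
    # Each global round: every cluster runs args.Cepochs rounds of FedAvg on its
    # own users, then the 8 cluster weight sets are averaged into global_model.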
    # for epoch in tqdm(range(args.epochs)):
    for epoch in range(args.epochs):
        # while testacc_check < args.test_acc or epoch < args.epochs:
        # while epoch < args.epochs:
        local_weights, local_losses, local_accuracies = [], [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # ============== TRAIN ==============
        global_model.train()

        # ===== Cluster A =====
        A_model, A_weights, A_losses = fl_train(
            args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs)
        local_weights.append(copy.deepcopy(A_weights))
        local_losses.append(copy.deepcopy(A_losses))
        cluster_modelA = global_model  # = A_model
        # ===== Cluster B =====
        B_model, B_weights, B_losses = fl_train(
            args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs)
        local_weights.append(copy.deepcopy(B_weights))
        local_losses.append(copy.deepcopy(B_losses))
        cluster_modelB = global_model  # = B_model
        # ===== Cluster C =====
        C_model, C_weights, C_losses = fl_train(
            args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
        local_weights.append(copy.deepcopy(C_weights))
        local_losses.append(copy.deepcopy(C_losses))
        cluster_modelC = global_model  # = C_model
        # ===== Cluster D =====
        D_model, D_weights, D_losses = fl_train(
            args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs)
        local_weights.append(copy.deepcopy(D_weights))
        local_losses.append(copy.deepcopy(D_losses))
        cluster_modelD = global_model  # = D_model
        # ===== Cluster E =====
        E_model, E_weights, E_losses = fl_train(
            args, train_dataset, cluster_modelE, E1, user_groupsE, args.Cepochs)
        local_weights.append(copy.deepcopy(E_weights))
        local_losses.append(copy.deepcopy(E_losses))
        cluster_modelE = global_model  # = E_model
        # ===== Cluster F =====
        F_model, F_weights, F_losses = fl_train(
            args, train_dataset, cluster_modelF, F1, user_groupsF, args.Cepochs)
        local_weights.append(copy.deepcopy(F_weights))
        local_losses.append(copy.deepcopy(F_losses))
        cluster_modelF = global_model  # = F_model
        # ===== Cluster G =====
        G_model, G_weights, G_losses = fl_train(
            args, train_dataset, cluster_modelG, G1, user_groupsG, args.Cepochs)
        local_weights.append(copy.deepcopy(G_weights))
        local_losses.append(copy.deepcopy(G_losses))
        cluster_modelG = global_model  # = G_model
        # ===== Cluster H =====
        H_model, H_weights, H_losses = fl_train(
            args, train_dataset, cluster_modelH, H1, user_groupsH, args.Cepochs)
        local_weights.append(copy.deepcopy(H_weights))
        local_losses.append(copy.deepcopy(H_losses))
        cluster_modelH = global_model  # = H_model

        # Average the cluster weights into the global model
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # ============== EVAL ==============
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        # print("========== idx ========== ", idx)
        for c in range(args.num_users):
            # for c in range(cluster_size):
            # C = np.random.choice(keylist, int(args.frac * args.num_users), replace=False) # random set of clients
            # for c in C:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        testacc_check = 100 * train_accuracy[-1]
        epoch = epoch + 1  # leftover from the while-loop variant; the for loop reassigns epoch each round

        # print global training loss after every 'print_every' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    # print(f' \n Results after {args.epochs} global rounds of training:')
    print(f"\nAvg Training Stats after {epoch} global rounds:")
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/HFL8_{}_{}_{}_lr[{}]_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, epoch, args.lr, args.frac, args.iid,
               args.local_ep, args.local_bs)
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)
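
    # To inspect the saved curves later (minimal sketch; mirrors the dump above):
    #     with open(file_name, 'rb') as f:
    #         train_loss, train_accuracy = pickle.load(f)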