|
@@ -17,6 +17,7 @@ from options import args_parser
|
|
|
from update import LocalUpdate, test_inference
|
|
|
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
|
|
|
from utils import get_dataset, average_weights, exp_details
|
|
|
+import math
|
|
|
|
|
|
|
|
|
# BUILD MODEL
|
|
@@ -58,7 +59,8 @@ def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs
|
|
|
# print(f'\n | Cluster Training Round : {epoch+1} |\n')
|
|
|
|
|
|
cluster_global_model.train()
|
|
|
- m = max(int(args.frac * len(cluster)), 1)
|
|
|
+ # m = max(int(args.frac * len(cluster)), 1)
|
|
|
+ m = max(int(math.ceil(args.frac * len(cluster))), 1)
|
|
|
idxs_users = np.random.choice(cluster, m, replace=False)
|
|
|
|
|
|
|
|
@@ -78,7 +80,21 @@ def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs
|
|
|
cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
|
|
|
cluster_train_loss.append(cluster_loss_avg)
|
|
|
|
|
|
- return cluster_global_weights, cluster_loss_avg
|
|
|
+ # ============== EVAL ==============
|
|
|
+ # Calculate avg training accuracy over all users at every epoch
|
|
|
+ list_acc, list_loss = [], []
|
|
|
+ cluster_global_model.eval()
|
|
|
+ for c in range(len(cluster)):
|
|
|
+ local_model = LocalUpdate(args=args, dataset=train_dataset,
|
|
|
+ idxs=user_groups[c], logger=logger)
|
|
|
+ acc, loss = local_model.inference(model=global_model)
|
|
|
+ list_acc.append(acc)
|
|
|
+ list_loss.append(loss)
|
|
|
+ train_accuracy.append(sum(list_acc)/len(list_acc))
|
|
|
+ # Add
|
|
|
+ print("Cluster accuracy: ", 100*train_accuracy[-1])
|
|
|
+
|
|
|
+ return cluster_global_model, cluster_global_weights, cluster_loss_avg
|
|
|
|
|
|
|
|
|
|
|
@@ -113,14 +129,14 @@ if __name__ == '__main__':
|
|
|
B1 = np.arange(cluster_size, cluster_size+cluster_size, dtype=int)
|
|
|
user_groupsB = {k:user_groups[k] for k in B1 if k in user_groups}
|
|
|
print("Size of cluster 2: ", len(user_groupsB))
|
|
|
- # Cluster 3
|
|
|
- C1 = np.arange(2*cluster_size, 3*cluster_size, dtype=int)
|
|
|
- user_groupsC = {k:user_groups[k] for k in C1 if k in user_groups}
|
|
|
- print("Size of cluster 3: ", len(user_groupsC))
|
|
|
- # Cluster 4
|
|
|
- D1 = np.arange(3*cluster_size, 4*cluster_size, dtype=int)
|
|
|
- user_groupsD = {k:user_groups[k] for k in D1 if k in user_groups}
|
|
|
- print("Size of cluster 4: ", len(user_groupsD))
|
|
|
+ # # Cluster 3
|
|
|
+ # C1 = np.arange(2*cluster_size, 3*cluster_size, dtype=int)
|
|
|
+ # user_groupsC = {k:user_groups[k] for k in C1 if k in user_groups}
|
|
|
+ # print("Size of cluster 3: ", len(user_groupsC))
|
|
|
+ # # Cluster 4
|
|
|
+ # D1 = np.arange(3*cluster_size, 4*cluster_size, dtype=int)
|
|
|
+ # user_groupsD = {k:user_groups[k] for k in D1 if k in user_groups}
|
|
|
+ # print("Size of cluster 4: ", len(user_groupsD))
|
|
|
|
|
|
# MODEL PARAM SUMMARY
|
|
|
global_model = build_model(args, train_dataset)
|
|
@@ -153,18 +169,18 @@ if __name__ == '__main__':
|
|
|
cluster_modelB.train()
|
|
|
# copy weights
|
|
|
cluster_modelB_weights = cluster_modelB.state_dict()
|
|
|
- # Cluster C
|
|
|
- cluster_modelC = build_model(args, train_dataset)
|
|
|
- cluster_modelC.to(device)
|
|
|
- cluster_modelC.train()
|
|
|
- # copy weights
|
|
|
- cluster_modelC_weights = cluster_modelC.state_dict()
|
|
|
- # Cluster D
|
|
|
- cluster_modelD = build_model(args, train_dataset)
|
|
|
- cluster_modelD.to(device)
|
|
|
- cluster_modelD.train()
|
|
|
- # copy weights
|
|
|
- cluster_modelD_weights = cluster_modelD.state_dict()
|
|
|
+ # # Cluster C
|
|
|
+ # cluster_modelC = build_model(args, train_dataset)
|
|
|
+ # cluster_modelC.to(device)
|
|
|
+ # cluster_modelC.train()
|
|
|
+ # # copy weights
|
|
|
+ # cluster_modelC_weights = cluster_modelC.state_dict()
|
|
|
+ # # Cluster D
|
|
|
+ # cluster_modelD = build_model(args, train_dataset)
|
|
|
+ # cluster_modelD.to(device)
|
|
|
+ # cluster_modelD.train()
|
|
|
+ # # copy weights
|
|
|
+ # cluster_modelD_weights = cluster_modelD.state_dict()
|
|
|
|
|
|
|
|
|
train_loss, train_accuracy = [], []
|
|
@@ -183,21 +199,23 @@ if __name__ == '__main__':
|
|
|
global_model.train()
|
|
|
|
|
|
# Cluster A
|
|
|
- A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs)
|
|
|
+ A_model, A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs)
|
|
|
local_weights.append(copy.deepcopy(A_weights))
|
|
|
- local_losses.append(copy.deepcopy(A_losses))
|
|
|
+ local_losses.append(copy.deepcopy(A_losses))
|
|
|
+ cluster_modelA = A_model
|
|
|
# Cluster B
|
|
|
- B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs)
|
|
|
+ B_model, B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs)
|
|
|
local_weights.append(copy.deepcopy(B_weights))
|
|
|
local_losses.append(copy.deepcopy(B_losses))
|
|
|
- # Cluster C
|
|
|
- C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
|
|
|
- local_weights.append(copy.deepcopy(C_weights))
|
|
|
- local_losses.append(copy.deepcopy(C_losses))
|
|
|
- # Cluster D
|
|
|
- D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs)
|
|
|
- local_weights.append(copy.deepcopy(D_weights))
|
|
|
- local_losses.append(copy.deepcopy(D_losses))
|
|
|
+ cluster_modelB = B_model
|
|
|
+ # # Cluster C
|
|
|
+ # C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
|
|
|
+ # local_weights.append(copy.deepcopy(C_weights))
|
|
|
+ # local_losses.append(copy.deepcopy(C_losses))
|
|
|
+ # # Cluster D
|
|
|
+ # D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs)
|
|
|
+ # local_weights.append(copy.deepcopy(D_weights))
|
|
|
+ # local_losses.append(copy.deepcopy(D_losses))
|
|
|
|
|
|
|
|
|
# averaging global weights
|
|
@@ -242,7 +260,7 @@ if __name__ == '__main__':
|
|
|
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
|
|
|
|
|
|
# Saving the objects train_loss and train_accuracy:
|
|
|
- file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
|
|
|
+ file_name = '../save/objects/HFL_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
|
|
|
format(args.dataset, args.model, epoch, args.frac, args.iid,
|
|
|
args.local_ep, args.local_bs)
|
|
|
|