@@ -18,6 +18,7 @@ from update import LocalUpdate, test_inference
 from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
 from utils import get_dataset, average_weights, exp_details
 import math
+import random


 # BUILD MODEL
@@ -48,7 +49,7 @@ def build_model(args, train_dataset):
 # Defining the training function
 def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):

-    cluster_train_loss, cluster_train_accuracy = [], []
+    cluster_train_loss, cluster_train_acc = [], []
     cluster_val_acc_list, cluster_net_list = [], []
     cluster_cv_loss, cluster_cv_acc = [], []
     # print_every = 1
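The hunk above references build_model(args, train_dataset), whose body falls outside this diff. As a hedged sketch of what such a factory plausibly looks like, given the MLP, CNNMnist, CNNFashion_Mnist, and CNNCifar imports in the first hunk; the branch conditions and constructor signatures here are assumptions, not the repository's actual code:

    # Hedged sketch of the build_model factory referenced in the hunk header.
    def build_model(args, train_dataset):
        if args.model == 'cnn':
            if args.dataset == 'mnist':
                model = CNNMnist(args=args)
            elif args.dataset == 'fmnist':
                model = CNNFashion_Mnist(args=args)
            elif args.dataset == 'cifar':
                model = CNNCifar(args=args)
            else:
                exit('Error: unrecognized dataset')
        elif args.model == 'mlp':
            # Flatten the input shape to size the MLP's first layer.
            img_size = train_dataset[0][0].shape
            len_in = 1
            for dim in img_size:
                len_in *= dim
            model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
        else:
            exit('Error: unrecognized model')
        return model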
@@ -85,14 +86,16 @@ def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs
         list_acc, list_loss = [], []
         cluster_global_model.eval()
         for c in range(len(cluster)):
+            # local_model = LocalUpdate(args=args, dataset=train_dataset,
+            #                           idxs=user_groups[c], logger=logger)
             local_model = LocalUpdate(args=args, dataset=train_dataset,
-                                      idxs=user_groups[c], logger=logger)
+                                      idxs=user_groups[idx], logger=logger)
             acc, loss = local_model.inference(model=global_model)
             list_acc.append(acc)
             list_loss.append(loss)
-        train_accuracy.append(sum(list_acc)/len(list_acc))
+        cluster_train_acc.append(sum(list_acc)/len(list_acc))
     # Add
-    print("Cluster accuracy: ", 100*train_accuracy[-1])
+    print("Cluster accuracy: ", 100*cluster_train_acc[-1])

     return cluster_global_model, cluster_global_weights, cluster_loss_avg
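Note that in the evaluation loop above, the new idxs=user_groups[idx] argument scores every iteration against the same partition (whatever idx currently holds), even though the loop variable c walks the cluster, and the inference call still targets global_model. If the intent is to score each cluster member on its own shard, a hedged sketch reusing fl_train's own cluster and usergrp parameters might look like this (evaluating against cluster_global_model instead of global_model is likewise an assumption):

    # Sketch: score each member of this FL group on its own data shard.
    # `cluster` (member keys) and `usergrp` (key -> sample indices) are the
    # fl_train parameters shown above; everything else mirrors the loop body.
    list_acc, list_loss = [], []
    cluster_global_model.eval()
    for key in cluster:
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                  idxs=usergrp[key], logger=logger)
        acc, loss = local_model.inference(model=cluster_global_model)
        list_acc.append(acc)
        list_loss.append(loss)
    cluster_train_acc.append(sum(list_acc) / len(list_acc))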
@@ -115,18 +118,31 @@ if __name__ == '__main__':
     device = 'cuda' if args.gpu else 'cpu'

     # load dataset and user groups
-    train_dataset, test_dataset, user_groups = get_dataset(args)
-
+    train_dataset, test_dataset, user_groupsold = get_dataset(args)
+
+    # ======= Shuffle dataset =======
+    keys = list(user_groupsold.keys())
+    random.shuffle(keys)
+    user_groups = dict()
+    for key in keys:
+        user_groups.update({key:user_groupsold[key]})
+    # print(user_groups.keys())
+    keylist = list(user_groups.keys())
+    print("keylist: ", keylist)
     # ======= Splitting into clusters. FL groups =======
-    cluster_size = args.num_users / args.num_clusters
+    cluster_size = int(args.num_users / args.num_clusters)
     print("Each cluster size: ", cluster_size)

     # Cluster 1
-    A1 = np.arange(cluster_size, dtype=int)
+    # A1 = np.arange(cluster_size, dtype=int)
+    A1 = keylist[:cluster_size]
+    print("A1: ", A1)
     user_groupsA = {k:user_groups[k] for k in A1 if k in user_groups}
     print("Size of cluster 1: ", len(user_groupsA))
     # Cluster 2
-    B1 = np.arange(cluster_size, cluster_size+cluster_size, dtype=int)
+    # B1 = np.arange(cluster_size, cluster_size+cluster_size, dtype=int)
+    B1 = keylist[cluster_size:2*cluster_size]
+    print("B1: ", B1)
     user_groupsB = {k:user_groups[k] for k in B1 if k in user_groups}
     print("Size of cluster 2: ", len(user_groupsB))
     # # Cluster 3
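The shuffle-then-slice pattern above (rebuild the dict in randomized key order, then carve contiguous keylist slices per cluster) generalizes to any args.num_clusters. A hedged sketch; make_clusters is a hypothetical helper, not part of the repository:

    import random

    # Hypothetical helper generalizing the A1/B1 slicing to n clusters.
    def make_clusters(user_groups, num_clusters):
        keys = list(user_groups.keys())
        random.shuffle(keys)              # randomize client-to-cluster assignment
        size = len(keys) // num_clusters  # users per cluster, as in cluster_size
        return [{k: user_groups[k] for k in keys[i * size:(i + 1) * size]}
                for i in range(num_clusters)]

Each returned dict plays the role of user_groupsA or user_groupsB above.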
@@ -169,6 +185,7 @@ if __name__ == '__main__':
     cluster_modelB.train()
     # copy weights
     cluster_modelB_weights = cluster_modelB.state_dict()
+
     # # Cluster C
     # cluster_modelC = build_model(args, train_dataset)
     # cluster_modelC.to(device)
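Each cluster currently gets a hand-named model (cluster_modelA, cluster_modelB, and the commented-out cluster_modelC). A hedged sketch of the same setup as a list, so it scales with args.num_clusters rather than with copy-pasted blocks:

    # Sketch: one model per cluster, built in a loop; purely illustrative.
    cluster_models = []
    for _ in range(args.num_clusters):
        m = build_model(args, train_dataset)
        m.to(device)
        m.train()
        cluster_models.append(m)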
@@ -188,10 +205,10 @@ if __name__ == '__main__':
     cv_loss, cv_acc = [], []
     print_every = 1
     val_loss_pre, counter = 0, 0
-    testacc_check, epoch = 0, 0
+    testacc_check, epoch, idx = 0, 0, 0

-    # for epoch in tqdm(range(args.epochs)):
-    while testacc_check < args.test_acc:
+    for epoch in tqdm(range(args.epochs)):
+    # while testacc_check < args.test_acc:
         local_weights, local_losses, local_accuracies = [], [], []
         print(f'\n | Global Training Round : {epoch+1} |\n')
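The change above trades the accuracy-threshold while-loop for a fixed epoch budget. The two styles can be combined, running at most args.epochs rounds but stopping once testacc_check reaches args.test_acc, as in this hedged sketch (the body of a round is elided, and reading testacc_check from train_accuracy is an assumption):

    for epoch in tqdm(range(args.epochs)):
        # ... one global round: train each cluster, average weights, evaluate ...
        testacc_check = 100 * train_accuracy[-1]  # latest average training accuracy
        if testacc_check >= args.test_acc:
            break  # target accuracy reached; stop early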
@@ -208,6 +225,7 @@ if __name__ == '__main__':
         local_weights.append(copy.deepcopy(B_weights))
         local_losses.append(copy.deepcopy(B_losses))
         cluster_modelB = B_model
+
         # # Cluster C
         # C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
         # local_weights.append(copy.deepcopy(C_weights))
@@ -231,9 +249,10 @@ if __name__ == '__main__':
         # Calculate avg training accuracy over all users at every epoch
         list_acc, list_loss = [], []
         global_model.eval()
+        # print("========== idx ========== ", idx)
         for c in range(args.num_users):
             local_model = LocalUpdate(args=args, dataset=train_dataset,
-                                      idxs=user_groups[c], logger=logger)
+                                      idxs=user_groups[idx], logger=logger)
             acc, loss = local_model.inference(model=global_model)
             list_acc.append(acc)
             list_loss.append(loss)
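The local_weights collected above feed utils.average_weights, imported in the first hunk but not shown in this diff. A hedged sketch of the element-wise FedAvg mean such a function typically computes:

    import copy
    import torch

    # Sketch of a typical FedAvg average_weights: element-wise mean over the
    # collected state_dicts. The imported function's actual body may differ.
    def average_weights(w):
        w_avg = copy.deepcopy(w[0])
        for key in w_avg.keys():
            for i in range(1, len(w)):
                w_avg[key] += w[i][key]
            w_avg[key] = torch.div(w_avg[key], len(w))
        return w_avg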