|
@@ -36,7 +36,7 @@ def build_model(args, train_dataset):
|
|
|
len_in = 1
|
|
|
for x in img_size:
|
|
|
len_in *= x
|
|
|
- global_model = MLP(dim_in=len_in, dim_hidden=200,
|
|
|
+ global_model = MLP(dim_in=len_in, dim_hidden=64,
|
|
|
dim_out=args.num_classes)
|
|
|
else:
|
|
|
exit('Error: unrecognized model')
|
|
@@ -67,7 +67,7 @@ def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs
|
|
|
cluster_w, cluster_loss = cluster_local_model.update_weights(model=copy.deepcopy(cluster_global_model), global_round=epoch)
|
|
|
cluster_local_weights.append(copy.deepcopy(cluster_w))
|
|
|
cluster_local_losses.append(copy.deepcopy(cluster_loss))
|
|
|
- print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))
|
|
|
+ # print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))
|
|
|
|
|
|
# averaging global weights
|
|
|
cluster_global_weights = average_weights(cluster_local_weights)
|
|
@@ -113,16 +113,23 @@ if __name__ == '__main__':
|
|
|
B1 = np.arange(cluster_size, cluster_size+cluster_size, dtype=int)
|
|
|
user_groupsB = {k:user_groups[k] for k in B1 if k in user_groups}
|
|
|
print("Size of cluster 2: ", len(user_groupsB))
|
|
|
+ # Cluster 3
|
|
|
+ C1 = np.arange(2*cluster_size, 3*cluster_size, dtype=int)
|
|
|
+ user_groupsC = {k:user_groups[k] for k in C1 if k in user_groups}
|
|
|
+ print("Size of cluster 3: ", len(user_groupsC))
|
|
|
+ # Cluster 4
|
|
|
+ D1 = np.arange(3*cluster_size, 4*cluster_size, dtype=int)
|
|
|
+ user_groupsD = {k:user_groups[k] for k in D1 if k in user_groups}
|
|
|
+ print("Size of cluster 4: ", len(user_groupsD))
|
|
|
|
|
|
# MODEL PARAM SUMMARY
|
|
|
global_model = build_model(args, train_dataset)
|
|
|
pytorch_total_params = sum(p.numel() for p in global_model.parameters())
|
|
|
- print(pytorch_total_params)
|
|
|
+ print("Model total number of parameters: ", pytorch_total_params)
|
|
|
|
|
|
- from torchsummary import summary
|
|
|
-
|
|
|
- summary(global_model, (1, 28, 28))
|
|
|
- global_model.parameters()
|
|
|
+ # from torchsummary import summary
|
|
|
+ # summary(global_model, (1, 28, 28))
|
|
|
+ # global_model.parameters()
|
|
|
|
|
|
# Set the model to train and send it to device.
|
|
|
global_model.to(device)
|
|
@@ -134,18 +141,31 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
# ======= Set the cluster models to train and send it to device. =======
|
|
|
+ # Cluster A
|
|
|
cluster_modelA = build_model(args, train_dataset)
|
|
|
cluster_modelA.to(device)
|
|
|
cluster_modelA.train()
|
|
|
# copy weights
|
|
|
cluster_modelA_weights = cluster_modelA.state_dict()
|
|
|
-
|
|
|
- # Set the cluster models to train and send it to device.
|
|
|
+ # Cluster B
|
|
|
cluster_modelB = build_model(args, train_dataset)
|
|
|
cluster_modelB.to(device)
|
|
|
cluster_modelB.train()
|
|
|
# copy weights
|
|
|
- cluster_modelB_weights = cluster_modelA.state_dict()
|
|
|
+ cluster_modelB_weights = cluster_modelB.state_dict()
|
|
|
+ # Cluster C
|
|
|
+ cluster_modelC = build_model(args, train_dataset)
|
|
|
+ cluster_modelC.to(device)
|
|
|
+ cluster_modelC.train()
|
|
|
+ # copy weights
|
|
|
+ cluster_modelC_weights = cluster_modelC.state_dict()
|
|
|
+ # Cluster D
|
|
|
+ cluster_modelD = build_model(args, train_dataset)
|
|
|
+ cluster_modelD.to(device)
|
|
|
+ cluster_modelD.train()
|
|
|
+ # copy weights
|
|
|
+ cluster_modelD_weights = cluster_modelD.state_dict()
|
|
|
+
|
|
|
|
|
|
train_loss, train_accuracy = [], []
|
|
|
val_acc_list, net_list = [], []
|
|
@@ -163,14 +183,21 @@ if __name__ == '__main__':
|
|
|
global_model.train()
|
|
|
|
|
|
# Cluster A
|
|
|
- A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.epochs)
|
|
|
+ A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs)
|
|
|
local_weights.append(copy.deepcopy(A_weights))
|
|
|
- local_losses.append(copy.deepcopy(A_losses))
|
|
|
-
|
|
|
+ local_losses.append(copy.deepcopy(A_losses))
|
|
|
# Cluster B
|
|
|
- B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.epochs)
|
|
|
+ B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs)
|
|
|
local_weights.append(copy.deepcopy(B_weights))
|
|
|
local_losses.append(copy.deepcopy(B_losses))
|
|
|
+ # Cluster C
|
|
|
+ C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
|
|
|
+ local_weights.append(copy.deepcopy(C_weights))
|
|
|
+ local_losses.append(copy.deepcopy(C_losses))
|
|
|
+ # Cluster D
|
|
|
+ D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs)
|
|
|
+ local_weights.append(copy.deepcopy(D_weights))
|
|
|
+ local_losses.append(copy.deepcopy(D_losses))
|
|
|
|
|
|
|
|
|
# averaging global weights
|