|
@@ -16,93 +16,93 @@ from tensorboardX import SummaryWriter
|
|
|
from options import args_parser
|
|
|
from update import LocalUpdate, test_inference
|
|
|
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
|
|
|
-from utils import get_dataset, average_weights, exp_details
|
|
|
+from utils import get_dataset, average_weights, exp_details, set_device, build_model, fl_train
|
|
|
import math
|
|
|
import random
|
|
|
|
|
|
|
|
|
-# BUILD MODEL
|
|
|
-def build_model(args, train_dataset):
|
|
|
- if args.model == 'cnn':
|
|
|
- # Convolutional neural network
|
|
|
- if args.dataset == 'mnist':
|
|
|
- global_model = CNNMnist(args=args)
|
|
|
- elif args.dataset == 'fmnist':
|
|
|
- global_model = CNNFashion_Mnist(args=args)
|
|
|
- elif args.dataset == 'cifar':
|
|
|
- global_model = CNNCifar(args=args)
|
|
|
-
|
|
|
- elif args.model == 'mlp':
|
|
|
- # Multi-layer preceptron
|
|
|
- img_size = train_dataset[0][0].shape
|
|
|
- len_in = 1
|
|
|
- for x in img_size:
|
|
|
- len_in *= x
|
|
|
- global_model = MLP(dim_in=len_in, dim_hidden=args.mlpdim,
|
|
|
- dim_out=args.num_classes)
|
|
|
- else:
|
|
|
- exit('Error: unrecognized model')
|
|
|
+# # BUILD MODEL
|
|
|
+# def build_model(args, train_dataset):
|
|
|
+# if args.model == 'cnn':
|
|
|
+# # Convolutional neural network
|
|
|
+# if args.dataset == 'mnist':
|
|
|
+# global_model = CNNMnist(args=args)
|
|
|
+# elif args.dataset == 'fmnist':
|
|
|
+# global_model = CNNFashion_Mnist(args=args)
|
|
|
+# elif args.dataset == 'cifar':
|
|
|
+# global_model = CNNCifar(args=args)
|
|
|
+
|
|
|
+# elif args.model == 'mlp':
|
|
|
+# # Multi-layer preceptron
|
|
|
+# img_size = train_dataset[0][0].shape
|
|
|
+# len_in = 1
|
|
|
+# for x in img_size:
|
|
|
+# len_in *= x
|
|
|
+# global_model = MLP(dim_in=len_in, dim_hidden=args.mlpdim,
|
|
|
+# dim_out=args.num_classes)
|
|
|
+# else:
|
|
|
+# exit('Error: unrecognized model')
|
|
|
|
|
|
- return global_model
|
|
|
+# return global_model
|
|
|
|
|
|
|
|
|
-# Defining the training function
|
|
|
-def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
|
|
|
+# # Defining the training function
|
|
|
+# def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
|
|
|
|
|
|
- cluster_train_loss, cluster_train_acc = [], []
|
|
|
- cluster_val_acc_list, cluster_net_list = [], []
|
|
|
- cluster_cv_loss, cluster_cv_acc = [], []
|
|
|
- # print_every = 1
|
|
|
- cluster_val_loss_pre, counter = 0, 0
|
|
|
-
|
|
|
- for epoch in range(epochs):
|
|
|
- cluster_local_weights, cluster_local_losses = [], []
|
|
|
- # print(f'\n | Cluster Training Round : {epoch+1} |\n')
|
|
|
-
|
|
|
- cluster_global_model.train()
|
|
|
- # m = max(int(args.frac * len(cluster)), 1)
|
|
|
- # m = max(int(math.ceil(args.frac * len(cluster))), 1)
|
|
|
- m = min(int(len(cluster)), 10)
|
|
|
- # print("=== m ==== ", m)
|
|
|
- idxs_users = np.random.choice(cluster, m, replace=False)
|
|
|
-
|
|
|
-
|
|
|
- for idx in idxs_users:
|
|
|
- cluster_local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=usergrp[idx], logger=logger)
|
|
|
- cluster_w, cluster_loss = cluster_local_model.update_weights(model=copy.deepcopy(cluster_global_model), global_round=epoch)
|
|
|
- cluster_local_weights.append(copy.deepcopy(cluster_w))
|
|
|
- cluster_local_losses.append(copy.deepcopy(cluster_loss))
|
|
|
- # print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))
|
|
|
-
|
|
|
- # averaging global weights
|
|
|
- cluster_global_weights = average_weights(cluster_local_weights)
|
|
|
-
|
|
|
- # update global weights
|
|
|
- cluster_global_model.load_state_dict(cluster_global_weights)
|
|
|
-
|
|
|
- cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
|
|
|
- cluster_train_loss.append(cluster_loss_avg)
|
|
|
-
|
|
|
- # ============== EVAL ==============
|
|
|
- # Calculate avg training accuracy over all users at every epoch
|
|
|
- list_acc, list_loss = [], []
|
|
|
- cluster_global_model.eval()
|
|
|
- # C = np.random.choice(cluster, m, replace=False) # random set of clients
|
|
|
- # print("C: ", C)
|
|
|
- # for c in C:
|
|
|
- # for c in range(len(cluster)):
|
|
|
- for c in idxs_users:
|
|
|
- cluster_local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=usergrp[c], logger=logger)
|
|
|
- # local_model = LocalUpdate(args=args, dataset=train_dataset,idxs=user_groups[idx], logger=logger)
|
|
|
- acc, loss = cluster_local_model.inference(model=global_model)
|
|
|
- list_acc.append(acc)
|
|
|
- list_loss.append(loss)
|
|
|
- # cluster_train_acc.append(sum(list_acc)/len(list_acc))
|
|
|
- # Add
|
|
|
- # print("Cluster accuracy: ", 100*cluster_train_acc[-1])
|
|
|
- print("Cluster accuracy: ", 100*sum(list_acc)/len(list_acc))
|
|
|
-
|
|
|
- return cluster_global_model, cluster_global_weights, cluster_loss_avg
|
|
|
+# cluster_train_loss, cluster_train_acc = [], []
|
|
|
+# cluster_val_acc_list, cluster_net_list = [], []
|
|
|
+# cluster_cv_loss, cluster_cv_acc = [], []
|
|
|
+# # print_every = 1
|
|
|
+# cluster_val_loss_pre, counter = 0, 0
|
|
|
+
|
|
|
+# for epoch in range(epochs):
|
|
|
+# cluster_local_weights, cluster_local_losses = [], []
|
|
|
+# # print(f'\n | Cluster Training Round : {epoch+1} |\n')
|
|
|
+
|
|
|
+# cluster_global_model.train()
|
|
|
+# # m = max(int(args.frac * len(cluster)), 1)
|
|
|
+# # m = max(int(math.ceil(args.frac * len(cluster))), 1)
|
|
|
+# m = min(int(len(cluster)), 10)
|
|
|
+# # print("=== m ==== ", m)
|
|
|
+# idxs_users = np.random.choice(cluster, m, replace=False)
|
|
|
+
|
|
|
+
|
|
|
+# for idx in idxs_users:
|
|
|
+# cluster_local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=usergrp[idx], logger=logger)
|
|
|
+# cluster_w, cluster_loss = cluster_local_model.update_weights(model=copy.deepcopy(cluster_global_model), global_round=epoch)
|
|
|
+# cluster_local_weights.append(copy.deepcopy(cluster_w))
|
|
|
+# cluster_local_losses.append(copy.deepcopy(cluster_loss))
|
|
|
+# # print('| Global Round : {} | User : {} | \tLoss: {:.6f}'.format(epoch, idx, cluster_loss))
|
|
|
+
|
|
|
+# # averaging global weights
|
|
|
+# cluster_global_weights = average_weights(cluster_local_weights)
|
|
|
+
|
|
|
+# # update global weights
|
|
|
+# cluster_global_model.load_state_dict(cluster_global_weights)
|
|
|
+
|
|
|
+# cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)
|
|
|
+# cluster_train_loss.append(cluster_loss_avg)
|
|
|
+
|
|
|
+# # ============== EVAL ==============
|
|
|
+# # Calculate avg training accuracy over all users at every epoch
|
|
|
+# list_acc, list_loss = [], []
|
|
|
+# cluster_global_model.eval()
|
|
|
+# # C = np.random.choice(cluster, m, replace=False) # random set of clients
|
|
|
+# # print("C: ", C)
|
|
|
+# # for c in C:
|
|
|
+# # for c in range(len(cluster)):
|
|
|
+# for c in idxs_users:
|
|
|
+# cluster_local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=usergrp[c], logger=logger)
|
|
|
+# # local_model = LocalUpdate(args=args, dataset=train_dataset,idxs=user_groups[idx], logger=logger)
|
|
|
+# acc, loss = cluster_local_model.inference(model=global_model)
|
|
|
+# list_acc.append(acc)
|
|
|
+# list_loss.append(loss)
|
|
|
+# # cluster_train_acc.append(sum(list_acc)/len(list_acc))
|
|
|
+# # Add
|
|
|
+# # print("Cluster accuracy: ", 100*cluster_train_acc[-1])
|
|
|
+# print("Cluster accuracy: ", 100*sum(list_acc)/len(list_acc))
|
|
|
+
|
|
|
+# return cluster_global_model, cluster_global_weights, cluster_loss_avg
|
|
|
|
|
|
|
|
|
|
|
@@ -118,9 +118,8 @@ if __name__ == '__main__':
|
|
|
args = args_parser()
|
|
|
exp_details(args)
|
|
|
|
|
|
- if args.gpu:
|
|
|
- torch.cuda.set_device(args.gpu)
|
|
|
- device = 'cuda' if args.gpu else 'cpu'
|
|
|
+ # Select CPU or GPU
|
|
|
+ device = set_device(args)
|
|
|
|
|
|
# load dataset and user groups
|
|
|
train_dataset, test_dataset, user_groupsold = get_dataset(args)
|
|
@@ -268,42 +267,42 @@ if __name__ == '__main__':
|
|
|
global_model.train()
|
|
|
|
|
|
# ===== Cluster A =====
|
|
|
- A_model, A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs)
|
|
|
+ A_model, A_weights, A_losses = fl_train(args, train_dataset, cluster_modelA, A1, user_groupsA, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(A_weights))
|
|
|
local_losses.append(copy.deepcopy(A_losses))
|
|
|
cluster_modelA = global_model# = A_model
|
|
|
# ===== Cluster B =====
|
|
|
- B_model, B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs)
|
|
|
+ B_model, B_weights, B_losses = fl_train(args, train_dataset, cluster_modelB, B1, user_groupsB, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(B_weights))
|
|
|
local_losses.append(copy.deepcopy(B_losses))
|
|
|
cluster_modelB = global_model# = B_model
|
|
|
# ===== Cluster C =====
|
|
|
- C_model, C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs)
|
|
|
+ C_model, C_weights, C_losses = fl_train(args, train_dataset, cluster_modelC, C1, user_groupsC, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(C_weights))
|
|
|
local_losses.append(copy.deepcopy(C_losses))
|
|
|
cluster_modelC = global_model# = C_model
|
|
|
# ===== Cluster D =====
|
|
|
- D_model, D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs)
|
|
|
+ D_model, D_weights, D_losses = fl_train(args, train_dataset, cluster_modelD, D1, user_groupsD, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(D_weights))
|
|
|
local_losses.append(copy.deepcopy(D_losses))
|
|
|
cluster_modelD = global_model# = D_model
|
|
|
# ===== Cluster E =====
|
|
|
- E_model, E_weights, E_losses = fl_train(args, train_dataset, cluster_modelE, E1, user_groupsE, args.Cepochs)
|
|
|
+ E_model, E_weights, E_losses = fl_train(args, train_dataset, cluster_modelE, E1, user_groupsE, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(E_weights))
|
|
|
local_losses.append(copy.deepcopy(E_losses))
|
|
|
cluster_modelE = global_model# = E_model
|
|
|
# ===== Cluster F =====
|
|
|
- F_model, F_weights, F_losses = fl_train(args, train_dataset, cluster_modelF, F1, user_groupsF, args.Cepochs)
|
|
|
+ F_model, F_weights, F_losses = fl_train(args, train_dataset, cluster_modelF, F1, user_groupsF, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(F_weights))
|
|
|
local_losses.append(copy.deepcopy(F_losses))
|
|
|
cluster_modelF = global_model# = F_model
|
|
|
# ===== Cluster G =====
|
|
|
- G_model, G_weights, G_losses = fl_train(args, train_dataset, cluster_modelG, G1, user_groupsG, args.Cepochs)
|
|
|
+ G_model, G_weights, G_losses = fl_train(args, train_dataset, cluster_modelG, G1, user_groupsG, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(G_weights))
|
|
|
local_losses.append(copy.deepcopy(G_losses))
|
|
|
cluster_modelG = global_model# = G_model
|
|
|
# ===== Cluster H =====
|
|
|
- H_model, H_weights, H_losses = fl_train(args, train_dataset, cluster_modelH, H1, user_groupsH, args.Cepochs)
|
|
|
+ H_model, H_weights, H_losses = fl_train(args, train_dataset, cluster_modelH, H1, user_groupsH, args.Cepochs, logger)
|
|
|
local_weights.append(copy.deepcopy(H_weights))
|
|
|
local_losses.append(copy.deepcopy(H_losses))
|
|
|
cluster_modelH = global_model# = H_model
|