|
@@ -81,6 +81,7 @@ if __name__ == '__main__':
|
|
|
|
|
|
More details: `model.train()` only sets the mode to training (see the source code). Call either `model.eval()` or `model.train(mode=False)` to indicate that you are testing. It is intuitive to expect the `train` function to train the model, but it does not do that — it just sets the mode.
|
|
|
"""
|
|
|
# ============== TRAIN ==============
# Switch the global model to training mode (affects dropout/batch-norm
# behavior only — it does not itself perform any training).
global_model.train()

# Sample the participating clients for this round.
# args.frac is C, the fraction of clients used per round; always pick
# at least one client even when frac * num_users rounds down to zero.
m = max(int(args.frac * args.num_users), 1)

# Draw m distinct client indices uniformly at random from the
# args.num_users total clients.
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
|
|
@@ -102,11 +103,15 @@ if __name__ == '__main__':
|
|
|
# Record this round's performance: the mean of the per-client
# local training losses.
loss_avg = sum(local_losses) / len(local_losses)
train_loss.append(loss_avg)
|
|
|
|
|
|
# ============== EVAL ==============
# Calculate the average training accuracy over all users at every epoch.
list_acc, list_loss = [], []
# Must set the model to evaluation mode before computing outputs,
# since dropout / batch norm layers behave differently at train time.
global_model.eval()

for c in range(args.num_users):  # evaluate every client: 0 .. num_users-1
    # BUG FIX: index user_groups by the loop variable `c`, not the stale
    # `idx` left over from the training loop above — using `idx` evaluated
    # every client on the same (last-trained) client's data partition.
    local_model = LocalUpdate(args=args, dataset=train_dataset,
                              idxs=user_groups[c], logger=logger)
    acc, loss = local_model.inference(model=global_model)
|