Bläddra i källkod

Improved Logging and Readability

AshwinRJ 5 år sedan
förälder
incheckning
676dcdbd30
1 ändrade filer med 34 tillägg och 32 borttagningar
  1. 34 32
      src/main_fedavg.py

+ 34 - 32
src/main_fedavg.py

@@ -11,10 +11,10 @@ import numpy as np
 from tqdm import tqdm
 
 import torch
-from tensorboardX import SummaryWrepoch
+from tensorboardX import SummaryWriter
 
 from options import args_parser
-from update import LocalUpdate
+from update import LocalUpdate, test_inference
 from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
 from averaging import average_weights
 from utils import get_dataset
@@ -25,12 +25,13 @@ if __name__ == '__main__':
 
     # define paths
     path_project = os.path.abspath('..')
-    summary = SummaryWrepoch('local')
+    logger = SummaryWriter('../logs')
 
     args = args_parser()
+
     if args.gpu:
         torch.cuda.set_device(args.gpu)
-        device = 'cuda' if args.gpu else 'cpu'
+    device = 'cuda' if args.gpu else 'cpu'
 
     # load dataset and user groups
     train_dataset, test_dataset, user_groups = get_dataset(args)
@@ -63,24 +64,26 @@ if __name__ == '__main__':
     # copy weights
     global_weights = global_model.state_dict()
 
-    # training
+    # Training
     train_loss, train_accuracy = [], []
     val_acc_list, net_list = [], []
     cv_loss, cv_acc = [], []
-    print_every = 20
+    print_every = 2
     val_loss_pre, counter = 0, 0
 
     for epoch in tqdm(range(args.epochs)):
-        global_model.train()
         local_weights, local_losses = [], []
+        print(f'\n | Global Training Round : {epoch+1} |\n')
 
+        global_model.train()
         m = max(int(args.frac * args.num_users), 1)
         idxs_users = np.random.choice(range(args.num_users), m, replace=False)
 
         for idx in idxs_users:
             local_model = LocalUpdate(args=args, dataset=train_dataset,
-                                      idxs=user_groups[idx], logger=summary)
-            w, loss = local_model.update_weights(net=copy.deepcopy(global_model))
+                                      idxs=user_groups[idx], logger=logger)
+            w, loss = local_model.update_weights(
+                model=copy.deepcopy(global_model), global_round=epoch)
             local_weights.append(copy.deepcopy(w))
             local_losses.append(copy.deepcopy(loss))
 
@@ -90,10 +93,7 @@ if __name__ == '__main__':
         # copy weight to global model
         global_model.load_state_dict(global_weights)
 
-        # print loss after every 20 rounds
         loss_avg = sum(local_losses) / len(local_losses)
-        if (epoch+1) % print_every == 0:
-            print('\nTrain loss:', loss_avg)
         train_loss.append(loss_avg)
 
         # Calculate avg training accuracy over all users at every epoch
@@ -101,34 +101,36 @@ if __name__ == '__main__':
         global_model.eval()
         for c in range(args.num_users):
             local_model = LocalUpdate(args=args, dataset=train_dataset,
-                                      idxs=user_groups[idx], logger=summary)
-            acc, loss = local_model.inference(net=global_model)
+                                      idxs=user_groups[idx], logger=logger)
+            acc, loss = local_model.inference(model=global_model)
             list_acc.append(acc)
             list_loss.append(loss)
         train_accuracy.append(sum(list_acc)/len(list_acc))
 
+        # print global training loss after every 'i' rounds
+        if (epoch+1) % print_every == 0:
+            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
+            print(f'Training Loss : {np.mean(np.array(train_loss))}')
+            print('Train Accuracy: {:.2f}% \n'.format(
+                100.*(np.mean(np.array(train_accuracy)))))
+
     # Test inference after completion of training
-    test_acc, test_loss = [], []
-    for c in tqdm(range(args.num_users)):
-        local_model = LocalUpdate(args=args, dataset=test_dataset,
-                                  idxs=user_groups[idx], logger=summary)
-        acc, loss = local_model.test(net=global_model)
-        test_acc.append(acc)
-        test_loss.append(loss)
-
-    print("Final Average Train Accuracy after {} epochs: {:.2f}%".format(
-        args.epochs, 100.*train_accuracy[-1]))
-
-    print("Final Average Test Accuracy after {} epochs: {:.2f}%".format(
-        args.epochs, (100.*sum(test_acc)/len(test_acc))))
-
-    # # Saving the objects train_loss and train_accuracy:
-    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
-                                                                                args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
+    test_acc, test_loss = test_inference(args, global_model, test_dataset)
+
+    print(f' \n Results after {args.epochs} global rounds of training:')
+    print("|---- Avg Train Accuracy: {:.2f}%".format(
+        100.*(np.mean(np.array(train_accuracy)))))
+    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
+
+    # Saving the objects train_loss and train_accuracy:
+    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
+        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
+               args.local_ep, args.local_bs)
+
     with open(file_name, 'wb') as f:
         pickle.dump([train_loss, train_accuracy], f)
 
-    print('Total Time: {0:0.4f}'.format(time.time()-start_time))
+    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
 
     # PLOTTING (optional)
     # import matplotlib