Przeglądaj źródła

fix error in federated_main.py->idxs=user_groups[c], and implement first hierarchical structure with 2 clusters

wesleyjtann 4 lat temu
rodzic
commit
83c5219a4c

Plik diff jest za duży
+ 0 - 1673
src/.ipynb_checkpoints/federated_main-hierarchical-checkpoint.ipynb


Plik diff jest za duży
+ 2178 - 0
src/.ipynb_checkpoints/federated_main-hierarchical_v0-checkpoint.ipynb


Plik diff jest za duży
+ 2573 - 0
src/.ipynb_checkpoints/federated_main-hierarchical_v1-checkpoint.ipynb


BIN
src/__pycache__/update.cpython-37.pyc


Plik diff jest za duży
+ 0 - 1673
src/federated_main-hierarchical.ipynb


Plik diff jest za duży
+ 2178 - 0
src/federated_main-hierarchical_v0.ipynb


Plik diff jest za duży
+ 2573 - 0
src/federated_main-hierarchical_v1.ipynb


+ 5 - 0
src/federated_main.py

@@ -81,6 +81,7 @@ if __name__ == '__main__':
 
         More details: It sets the mode to train (see source code). You can call either model.eval() or model.train(mode=False) to tell that you are testing. It is somewhat intuitive to expect train function to train model but it does not do that. It just sets the mode.
         """
+        # ============== TRAIN ============== 
         global_model.train()
         m = max(int(args.frac * args.num_users), 1) # C = args.frac. Setting number of clients m for training
         idxs_users = np.random.choice(range(args.num_users), m, replace=False) # args.num_users=100 total clients. Choosing a random array of indices. Subset of clients.
@@ -102,11 +103,15 @@ if __name__ == '__main__':
         loss_avg = sum(local_losses) / len(local_losses)
         train_loss.append(loss_avg) # Performance measure
 
+        # ============== EVAL ============== 
         # Calculate avg training accuracy over all users at every epoch
         list_acc, list_loss = [], []
         global_model.eval() # must set your model into evaluation mode when computing model output values if dropout or bach norm used for training.
 
         for c in range(args.num_users): # 0 to 99
+            local_model = LocalUpdate(args=args, dataset=train_dataset,
+                                      # idxs=user_groups[idx], logger=logger)
+            # Fix error idxs=user_groups[idx] to idxs=user_groups[c]                                      
             local_model = LocalUpdate(args=args, dataset=train_dataset,
                                       idxs=user_groups[idx], logger=logger)
             acc, loss = local_model.inference(model=global_model)

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików