소스 검색

Added averaging function

AshwinRJ 4 년 전
부모
커밋
a7ad5d04a9
1개의 변경된 파일14개의 추가작업 그리고 0개의 파일을 삭제
  1. 14 0
      src/utils.py

+ 14 - 0
src/utils.py

@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 # Python version: 3.6
 
+import copy
+import torch
 from torchvision import datasets, transforms
 from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
 from sampling import cifar_iid, cifar_noniid
@@ -68,3 +70,15 @@ def get_dataset(args):
                 user_groups = mnist_noniid(train_dataset, args.num_users)
 
     return train_dataset, test_dataset, user_groups
+
+
+def average_weights(w):
+    """
+    Returns the average of the weights.
+    """
+    w_avg = copy.deepcopy(w[0])
+    for key in w_avg.keys():
+        for i in range(1, len(w)):
+            w_avg[key] += w[i][key]
+        w_avg[key] = torch.div(w_avg[key], len(w))
+    return w_avg