瀏覽代碼

Added averaging function

AshwinRJ 5 年之前
父節點
當前提交
a7ad5d04a9
共有 1 個文件被更改,包括 14 次插入0 次删除
  1. 14 0
      src/utils.py

+ 14 - 0
src/utils.py

@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 # Python version: 3.6
 
+import copy
+import torch
 from torchvision import datasets, transforms
 from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
 from sampling import cifar_iid, cifar_noniid
@@ -68,3 +70,15 @@ def get_dataset(args):
                 user_groups = mnist_noniid(train_dataset, args.num_users)
 
     return train_dataset, test_dataset, user_groups
+
+
+def average_weights(w):
+    """
+    Returns the average of the weights.
+    """
+    w_avg = copy.deepcopy(w[0])
+    for key in w_avg.keys():
+        for i in range(1, len(w)):
+            w_avg[key] += w[i][key]
+        w_avg[key] = torch.div(w_avg[key], len(w))
+    return w_avg