|
@@ -2,6 +2,8 @@
|
|
# -*- coding: utf-8 -*-
|
|
# -*- coding: utf-8 -*-
|
|
# Python version: 3.6
|
|
# Python version: 3.6
|
|
|
|
|
|
|
|
+import copy
|
|
|
|
+import torch
|
|
from torchvision import datasets, transforms
|
|
from torchvision import datasets, transforms
|
|
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
|
|
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
|
|
from sampling import cifar_iid, cifar_noniid
|
|
from sampling import cifar_iid, cifar_noniid
|
|
@@ -68,3 +70,15 @@ def get_dataset(args):
|
|
user_groups = mnist_noniid(train_dataset, args.num_users)
|
|
user_groups = mnist_noniid(train_dataset, args.num_users)
|
|
|
|
|
|
return train_dataset, test_dataset, user_groups
|
|
return train_dataset, test_dataset, user_groups
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def average_weights(w):
|
|
|
|
+ """
|
|
|
|
+ Returns the average of the weights.
|
|
|
|
+ """
|
|
|
|
+ w_avg = copy.deepcopy(w[0])
|
|
|
|
+ for key in w_avg.keys():
|
|
|
|
+ for i in range(1, len(w)):
|
|
|
|
+ w_avg[key] += w[i][key]
|
|
|
|
+ w_avg[key] = torch.div(w_avg[key], len(w))
|
|
|
|
+ return w_avg
|