123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # Python version: 3.6
- import copy
- import torch
- from torchvision import datasets, transforms
- from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
- from sampling import cifar_iid, cifar_noniid
def get_dataset(args):
    """Return the train/test datasets and a per-user split of the train set.

    Args:
        args: options object exposing ``dataset`` (one of ``'cifar'``,
            ``'mnist'`` or ``'fmnist'``), ``iid``, ``unequal`` and
            ``num_users``.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns -- presumably a dict
        keyed by user index whose values are that user's data (TODO confirm
        against the ``sampling`` module).

    Raises:
        NotImplementedError: for ``'cifar'`` with ``iid`` false and
            ``unequal`` true (no unequal CIFAR sampler exists).
        ValueError: if ``args.dataset`` is not a supported name.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # Three-channel normalisation: CIFAR images are RGB.
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training data.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID with unequal splits per user: no CIFAR sampler for this.
            raise NotImplementedError(
                'Non-IID unequal splits are not implemented for CIFAR')
        else:
            # Non-IID with equal splits per user.
            user_groups = cifar_noniid(train_dataset, args.num_users)

    # Explicit membership test: `x == 'mnist' or 'fmnist'` is always truthy
    # and would route every unknown name into this branch.
    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # NOTE(review): these look like the usual MNIST mean/std and are also
        # applied to Fashion-MNIST -- confirm whether fmnist-specific
        # statistics are wanted.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the training data.
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID with unequal splits per user.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Non-IID with equal splits per user.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected 'cifar', 'mnist' or "
            "'fmnist'".format(args.dataset))

    return train_dataset, test_dataset, user_groups
def average_weights(w):
    """Return the element-wise average of a list of weight dicts.

    Args:
        w: non-empty list of dict-like mappings (e.g. model ``state_dict()``
            results) that share the same keys. Their values are tensors
            supporting in-place ``+=`` and ``torch.div``.

    Returns:
        A new mapping (a deep copy of ``w[0]``) in which each entry is the sum
        of the corresponding entries across ``w`` divided by ``len(w)``. The
        input mappings are not modified.

    Raises:
        ValueError: if ``w`` is empty.
    """
    if not w:
        raise ValueError('average_weights() requires at least one set of weights')

    # Deep copy so the in-place accumulation below never mutates w[0].
    w_avg = copy.deepcopy(w[0])
    num_models = len(w)
    for key in w_avg.keys():
        for other in w[1:]:
            w_avg[key] += other[key]
        # torch.div without rounding_mode is true division, so integer-typed
        # entries are returned as floating-point tensors.
        w_avg[key] = torch.div(w_avg[key], num_models)
    return w_avg
|