12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # Python version: 3.6
- from torchvision import datasets, transforms
- from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
- from sampling import cifar_iid, cifar_noniid
def get_dataset(args):
    """Return train and test datasets and a per-user split of the train set.

    Args:
        args: Options object read for ``dataset`` (one of ``'cifar'``,
            ``'mnist'`` or ``'fmnist'``), ``iid``, ``unequal`` and
            ``num_users``.

    Returns:
        Tuple ``(train_dataset, test_dataset, user_groups)``. ``user_groups``
        is a dict where the keys are the user index and the values are the
        corresponding data for each of those users, as produced by the
        selected ``sampling`` helper.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
        NotImplementedError: for CIFAR with a non-IID, unequal split.
    """
    # Validate before touching any other option or downloading anything, so
    # an unknown name fails fast instead of falling into the MNIST branch.
    if args.dataset not in ('cifar', 'mnist', 'fmnist'):
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            f"expected 'cifar', 'mnist' or 'fmnist'")

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # Per-channel mean/std of 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training data.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID with unequal split sizes has no CIFAR sampler yet.
            raise NotImplementedError()
        else:
            # Non-IID with equal split sizes for every user.
            user_groups = cifar_noniid(train_dataset, args.num_users)
    else:  # 'mnist' or 'fmnist'
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST
        # NOTE(review): (0.1307,), (0.3081,) are the usual MNIST mean/std;
        # they are also applied to Fashion-MNIST here -- confirm intended.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the (Fashion-)MNIST training data.
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID with unequal split sizes for every user.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Non-IID with equal split sizes for every user.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
|