|
@@ -0,0 +1,70 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# Python version: 3.6
|
|
|
+
|
|
|
+from torchvision import datasets, transforms
|
|
|
+from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
|
|
|
+from sampling import cifar_iid, cifar_noniid
|
|
|
+
|
|
|
+
|
|
|
+def get_dataset(args):
|
|
|
+ """ Returns train and test datasets and a user group which is a dict where
|
|
|
+ the keys are the user index and the values are the corresponding data for
|
|
|
+ each of those users.
|
|
|
+ """
|
|
|
+
|
|
|
+ if args.dataset == 'cifar':
|
|
|
+ data_dir = '../data/cifar/'
|
|
|
+ apply_transform = transforms.Compose(
|
|
|
+ [transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
|
|
+
|
|
|
+ train_dataset = datasets.MNIST(data_dir, train=True, download=True,
|
|
|
+ transform=apply_transform)
|
|
|
+
|
|
|
+ test_dataset = datasets.MNIST(data_dir, train=False, download=True,
|
|
|
+ transform=apply_transform)
|
|
|
+
|
|
|
+ # sample training data amongst users
|
|
|
+ if args.iid:
|
|
|
+ # Sample IID user data from Mnist
|
|
|
+ user_groups = cifar_iid(train_dataset, args.num_users)
|
|
|
+ else:
|
|
|
+ # Sample Non-IID user data from Mnist
|
|
|
+ if args.unequal:
|
|
|
+ # Chose uneuqal splits for every user
|
|
|
+ raise NotImplementedError()
|
|
|
+ else:
|
|
|
+ # Chose euqal splits for every user
|
|
|
+ user_groups = cifar_noniid(train_dataset, args.num_users)
|
|
|
+
|
|
|
+ elif args.dataset == 'mnist' or 'fmnist':
|
|
|
+ if args.dataset == 'mnist':
|
|
|
+ data_dir = '../data/mnist/'
|
|
|
+ else:
|
|
|
+ data_dir = '../data/fmnist/'
|
|
|
+
|
|
|
+ apply_transform = transforms.Compose([
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.1307,), (0.3081,))])
|
|
|
+
|
|
|
+ train_dataset = datasets.MNIST(data_dir, train=True, download=True,
|
|
|
+ transform=apply_transform)
|
|
|
+
|
|
|
+ test_dataset = datasets.MNIST(data_dir, train=False, download=True,
|
|
|
+ transform=apply_transform)
|
|
|
+
|
|
|
+ # sample training data amongst users
|
|
|
+ if args.iid:
|
|
|
+ # Sample IID user data from Mnist
|
|
|
+ user_groups = mnist_iid(train_dataset, args.num_users)
|
|
|
+ else:
|
|
|
+ # Sample Non-IID user data from Mnist
|
|
|
+ if args.unequal:
|
|
|
+ # Chose uneuqal splits for every user
|
|
|
+ user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
|
|
|
+ else:
|
|
|
+ # Chose euqal splits for every user
|
|
|
+ user_groups = mnist_noniid(train_dataset, args.num_users)
|
|
|
+
|
|
|
+ return train_dataset, test_dataset, user_groups
|