AshwinRJ 4 年 前
コミット
eadd078244
1 ファイル変更70 行追加0 行削除
  1. 70 0
      src/utils.py

+ 70 - 0
src/utils.py

@@ -0,0 +1,70 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+from torchvision import datasets, transforms
+from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
+from sampling import cifar_iid, cifar_noniid
+
+
def get_dataset(args):
    """Return the train set, the test set and the per-user split of the train set.

    Args:
        args: Namespace-like object. The attributes read here are:
            dataset: one of 'cifar', 'mnist' or 'fmnist'.
            iid: truthy to split the training data IID amongst users.
            unequal: truthy to request unequal-sized non-IID splits; only
                consulted when ``iid`` is falsy.
            num_users: forwarded unchanged to the sampling helper.

    Returns:
        A ``(train_dataset, test_dataset, user_groups)`` tuple. ``user_groups``
        is whatever the selected sampling helper returns; it is meant to be a
        dict whose keys are user indices and whose values are the data
        assigned to each of those users.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
        NotImplementedError: for the CIFAR non-IID split with unequal shards,
            which has no sampling helper.
    """
    dataset = args.dataset

    # The previous test `args.dataset == 'mnist' or 'fmnist'` was always truthy
    # (a non-empty string literal), so any unknown name silently fell through
    # to the fmnist branch. Validate explicitly instead.
    if dataset not in ('cifar', 'mnist', 'fmnist'):
        raise ValueError(
            "Unsupported dataset {!r}; expected 'cifar', 'mnist' or "
            "'fmnist'".format(dataset))

    if dataset == 'cifar':
        # Fail before triggering any download: there is no unequal non-IID
        # sampler for CIFAR.
        if not args.iid and args.unequal:
            raise NotImplementedError(
                'Unequal non-IID splits are not implemented for cifar')

        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # This branch used to build datasets.MNIST, which loads the wrong
        # data into the cifar directory and does not fit the three-channel
        # normalisation above.
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID user data from CIFAR.
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Non-IID user data from CIFAR, equal splits for every user.
            user_groups = cifar_noniid(train_dataset, args.num_users)

    else:
        if dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            # 'fmnist' used to download plain MNIST into the fmnist directory.
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # NOTE(review): the normalisation constants are kept as they were and
        # are shared by both datasets; confirm whether Fashion-MNIST should
        # use its own statistics.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users. The mnist_* helpers are used
        # for both single-channel datasets.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal splits for every user.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal splits for every user.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups