utils.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. from torchvision import datasets, transforms
  5. from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
  6. from sampling import cifar_iid, cifar_noniid
  7. def get_dataset(args):
  8. """ Returns train and test datasets and a user group which is a dict where
  9. the keys are the user index and the values are the corresponding data for
  10. each of those users.
  11. """
  12. if args.dataset == 'cifar':
  13. data_dir = '../data/cifar/'
  14. apply_transform = transforms.Compose(
  15. [transforms.ToTensor(),
  16. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  17. train_dataset = datasets.MNIST(data_dir, train=True, download=True,
  18. transform=apply_transform)
  19. test_dataset = datasets.MNIST(data_dir, train=False, download=True,
  20. transform=apply_transform)
  21. # sample training data amongst users
  22. if args.iid:
  23. # Sample IID user data from Mnist
  24. user_groups = cifar_iid(train_dataset, args.num_users)
  25. else:
  26. # Sample Non-IID user data from Mnist
  27. if args.unequal:
  28. # Chose uneuqal splits for every user
  29. raise NotImplementedError()
  30. else:
  31. # Chose euqal splits for every user
  32. user_groups = cifar_noniid(train_dataset, args.num_users)
  33. elif args.dataset == 'mnist' or 'fmnist':
  34. if args.dataset == 'mnist':
  35. data_dir = '../data/mnist/'
  36. else:
  37. data_dir = '../data/fmnist/'
  38. apply_transform = transforms.Compose([
  39. transforms.ToTensor(),
  40. transforms.Normalize((0.1307,), (0.3081,))])
  41. train_dataset = datasets.MNIST(data_dir, train=True, download=True,
  42. transform=apply_transform)
  43. test_dataset = datasets.MNIST(data_dir, train=False, download=True,
  44. transform=apply_transform)
  45. # sample training data amongst users
  46. if args.iid:
  47. # Sample IID user data from Mnist
  48. user_groups = mnist_iid(train_dataset, args.num_users)
  49. else:
  50. # Sample Non-IID user data from Mnist
  51. if args.unequal:
  52. # Chose uneuqal splits for every user
  53. user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
  54. else:
  55. # Chose euqal splits for every user
  56. user_groups = mnist_noniid(train_dataset, args.num_users)
  57. return train_dataset, test_dataset, user_groups