utils.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import copy
  5. import torch
  6. from torchvision import datasets, transforms
  7. from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
  8. from sampling import cifar_iid, cifar_noniid
  9. def get_dataset(args):
  10. """ Returns train and test datasets and a user group which is a dict where
  11. the keys are the user index and the values are the corresponding data for
  12. each of those users.
  13. """
  14. if args.dataset == 'cifar':
  15. data_dir = '../data/cifar/'
  16. apply_transform = transforms.Compose(
  17. [transforms.ToTensor(),
  18. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  19. train_dataset = datasets.MNIST(data_dir, train=True, download=True,
  20. transform=apply_transform)
  21. test_dataset = datasets.MNIST(data_dir, train=False, download=True,
  22. transform=apply_transform)
  23. # sample training data amongst users
  24. if args.iid:
  25. # Sample IID user data from Mnist
  26. user_groups = cifar_iid(train_dataset, args.num_users)
  27. else:
  28. # Sample Non-IID user data from Mnist
  29. if args.unequal:
  30. # Chose uneuqal splits for every user
  31. raise NotImplementedError()
  32. else:
  33. # Chose euqal splits for every user
  34. user_groups = cifar_noniid(train_dataset, args.num_users)
  35. elif args.dataset == 'mnist' or 'fmnist':
  36. if args.dataset == 'mnist':
  37. data_dir = '../data/mnist/'
  38. else:
  39. data_dir = '../data/fmnist/'
  40. apply_transform = transforms.Compose([
  41. transforms.ToTensor(),
  42. transforms.Normalize((0.1307,), (0.3081,))])
  43. train_dataset = datasets.MNIST(data_dir, train=True, download=True,
  44. transform=apply_transform)
  45. test_dataset = datasets.MNIST(data_dir, train=False, download=True,
  46. transform=apply_transform)
  47. # sample training data amongst users
  48. if args.iid:
  49. # Sample IID user data from Mnist
  50. user_groups = mnist_iid(train_dataset, args.num_users)
  51. else:
  52. # Sample Non-IID user data from Mnist
  53. if args.unequal:
  54. # Chose uneuqal splits for every user
  55. user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
  56. else:
  57. # Chose euqal splits for every user
  58. user_groups = mnist_noniid(train_dataset, args.num_users)
  59. return train_dataset, test_dataset, user_groups
  60. def average_weights(w):
  61. """
  62. Returns the average of the weights.
  63. """
  64. w_avg = copy.deepcopy(w[0])
  65. for key in w_avg.keys():
  66. for i in range(1, len(w)):
  67. w_avg[key] += w[i][key]
  68. w_avg[key] = torch.div(w_avg[key], len(w))
  69. return w_avg