sampling.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import numpy as np
  5. from torchvision import datasets, transforms
  6. def mnist_iid(dataset, num_users):
  7. """
  8. Sample I.I.D. client data from MNIST dataset
  9. :param dataset:
  10. :param num_users:
  11. :return: dict of image index
  12. """
  13. num_items = int(len(dataset)/num_users)
  14. dict_users, all_idxs = {}, [i for i in range(len(dataset))]
  15. for i in range(num_users):
  16. dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
  17. all_idxs = list(set(all_idxs) - dict_users[i])
  18. return dict_users
  19. def mnist_noniid(dataset, num_users):
  20. """
  21. Sample non-I.I.D client data from MNIST dataset
  22. :param dataset:
  23. :param num_users:
  24. :return:
  25. """
  26. num_shards, num_imgs = 200, 300
  27. idx_shard = [i for i in range(num_shards)]
  28. dict_users = {i: np.array([]) for i in range(num_users)}
  29. idxs = np.arange(num_shards*num_imgs)
  30. labels = dataset.train_labels.numpy()
  31. # sort labels
  32. idxs_labels = np.vstack((idxs, labels))
  33. idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
  34. idxs = idxs_labels[0, :]
  35. # divide and assign
  36. for i in range(num_users):
  37. rand_set = set(np.random.choice(idx_shard, 2, replace=False))
  38. idx_shard = list(set(idx_shard) - rand_set)
  39. for rand in rand_set:
  40. dict_users[i] = np.concatenate(
  41. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
  42. return dict_users
  43. def cifar_iid(dataset, num_users):
  44. """
  45. Sample I.I.D. client data from CIFAR10 dataset
  46. :param dataset:
  47. :param num_users:
  48. :return: dict of image index
  49. """
  50. num_items = int(len(dataset)/num_users)
  51. dict_users, all_idxs = {}, [i for i in range(len(dataset))]
  52. for i in range(num_users):
  53. dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
  54. all_idxs = list(set(all_idxs) - dict_users[i])
  55. return dict_users
  56. def cifar_noniid(dataset, num_users):
  57. """
  58. Sample non-I.I.D client data from CIFAR10 dataset
  59. :param dataset:
  60. :param num_users:
  61. :return:
  62. """
  63. num_shards, num_imgs = 200, 250
  64. idx_shard = [i for i in range(num_shards)]
  65. dict_users = {i: np.array([]) for i in range(num_users)}
  66. idxs = np.arange(num_shards*num_imgs)
  67. labels = dataset.train_labels.numpy()
  68. # sort labels
  69. idxs_labels = np.vstack((idxs, labels))
  70. idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
  71. idxs = idxs_labels[0, :]
  72. # divide and assign
  73. for i in range(num_users):
  74. rand_set = set(np.random.choice(idx_shard, 2, replace=False))
  75. idx_shard = list(set(idx_shard) - rand_set)
  76. for rand in rand_set:
  77. dict_users[i] = np.concatenate(
  78. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
  79. return dict_users
  80. if __name__ == '__main__':
  81. dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
  82. transform=transforms.Compose([
  83. transforms.ToTensor(),
  84. transforms.Normalize((0.1307,), (0.3081,))
  85. ]))
  86. num = 100
  87. d = mnist_noniid(dataset_train, num)