#!/usr/bin/env python
# -*- coding: utf-8 -*-
# sampling.py -- Python version: 3.6

import numpy as np
from torchvision import datasets, transforms


def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the MNIST dataset.
    :param dataset: MNIST training dataset
    :param num_users: number of clients to partition the data among
    :return: dict mapping each client id to a set of image indices
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        # Draw num_items distinct indices for this client, then remove
        # them from the pool so the client partitions stay disjoint.
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users
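
# Usage sketch for mnist_iid, assuming `dataset_train` is the 60,000-image
# MNIST training set (loaded as in the __main__ block below): every client
# receives a disjoint set of 60000/100 = 600 random indices.
#
#   users = mnist_iid(dataset_train, 100)
#   assert all(len(v) == 600 for v in users.values())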


def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D. client data from the MNIST dataset.
    :param dataset: MNIST training dataset
    :param num_users: number of clients to partition the data among
    :return: dict mapping each client id to an array of image indices
    """
    # 60,000 training imgs --> 200 shards of 300 imgs each
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # .targets replaces the deprecated .train_labels attribute
    labels = dataset.targets.numpy()

    # sort indices by label so each shard covers a narrow range of digits
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)
    return dict_users
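

def classes_per_client(dataset, dict_users):
    """
    Illustrative sanity-check helper (a sketch, assuming `dataset.targets`
    holds the labels): count how many distinct classes each client received.
    Because the non-I.I.D. samplers cut shards from label-sorted indices,
    most clients should end up with images from only about two classes (a
    shard can straddle a class boundary, so slightly more is possible).
    """
    labels = np.array(dataset.targets)
    return {i: len(set(labels[np.array(list(v), dtype=int)].tolist()))
            for i, v in dict_users.items()}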


def mnist_noniid_unequal(dataset, num_users):
    """
    Sample non-I.I.D. client data from the MNIST dataset such that
    clients have unequal amounts of data.
    :param dataset: MNIST training dataset
    :param num_users: number of clients to partition the data among
    :return: dict mapping each client id to an array of image indices,
        with a varying number of images per client
    """
    # 60,000 training imgs --> 1200 shards of 50 imgs each
    num_shards, num_imgs = 1200, 50
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # .targets replaces the deprecated .train_labels attribute
    labels = dataset.targets.numpy()

    # sort indices by label so each shard covers a narrow range of digits
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Minimum and maximum shards assigned per client:
    min_shard = 1
    max_shard = 30

    # Divide the shards into random chunks of unequal size, one per client,
    # rescaled so that the chunk sizes sum to (approximately) num_shards.
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)
    random_shard_size = np.around(random_shard_size /
                                  sum(random_shard_size) * num_shards)
    random_shard_size = random_shard_size.astype(int)

    # Assign the shards randomly to each client
    if sum(random_shard_size) > num_shards:
        # Rounding overshot the shard budget: first assign each client one
        # shard, to ensure every client has at least one shard of data ...
        for i in range(num_users):
            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        random_shard_size = random_shard_size - 1

        # ... then randomly assign the remaining shards until the pool
        # is exhausted.
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            shard_size = random_shard_size[i]
            if shard_size > len(idx_shard):
                shard_size = len(idx_shard)
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)
    else:
        for i in range(num_users):
            shard_size = random_shard_size[i]
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        if len(idx_shard) > 0:
            # Add any leftover shards to the client with the fewest images.
            shard_size = len(idx_shard)
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[k] = np.concatenate(
                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

    return dict_users
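
# Usage sketch for mnist_noniid_unequal, again assuming `dataset_train` is
# loaded as in __main__: client sizes track the random shard counts, so the
# spread between the smallest and largest client can be large.
#
#   users = mnist_noniid_unequal(dataset_train, 100)
#   sizes = sorted(len(v) for v in users.values())
#   print(sizes[0], sizes[-1])   # e.g. 50 vs. upwards of a thousand images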


def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CIFAR-10 dataset.
    :param dataset: CIFAR-10 training dataset
    :param num_users: number of clients to partition the data among
    :return: dict mapping each client id to a set of image indices
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D. client data from the CIFAR-10 dataset.
    :param dataset: CIFAR-10 training dataset
    :param num_users: number of clients to partition the data among
    :return: dict mapping each client id to an array of image indices
    """
    # 50,000 training imgs --> 200 shards of 250 imgs each
    num_shards, num_imgs = 200, 250
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # CIFAR-10 stores its labels as a plain list, so wrap it in an array
    # (.targets replaces the deprecated .train_labels attribute)
    labels = np.array(dataset.targets)

    # sort indices by label so each shard covers a narrow range of classes
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)
    return dict_users


if __name__ == '__main__':
    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,),
                                                            (0.3081,))
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)
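
    # Quick sanity check (illustrative): confirm that the sampler covered the
    # whole training set and report the per-client partition sizes.
    sizes = [len(v) for v in d.values()]
    print('clients: {}, min/max imgs per client: {}/{}, total: {}'.format(
        num, min(sizes), max(sizes), sum(sizes)))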