sampling.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import numpy as np
  5. from torchvision import datasets, transforms
  6. def mnist_iid(dataset, num_users):
  7. """
  8. Sample I.I.D. client data from MNIST dataset
  9. :param dataset:
  10. :param num_users:
  11. :return: dict of image index
  12. """
  13. num_items = int(len(dataset)/num_users)
  14. dict_users, all_idxs = {}, [i for i in range(len(dataset))]
  15. for i in range(num_users):
  16. dict_users[i] = set(np.random.choice(all_idxs, num_items,
  17. replace=False))
  18. all_idxs = list(set(all_idxs) - dict_users[i])
  19. return dict_users
  20. def mnist_noniid(dataset, num_users):
  21. """
  22. Sample non-I.I.D client data from MNIST dataset
  23. :param dataset:
  24. :param num_users:
  25. :return:
  26. """
  27. # 60,000 training imgs --> 200 imgs/shard X 300 shards
  28. num_shards, num_imgs = 200, 300
  29. idx_shard = [i for i in range(num_shards)]
  30. dict_users = {i: np.array([]) for i in range(num_users)}
  31. idxs = np.arange(num_shards*num_imgs)
  32. labels = dataset.train_labels.numpy()
  33. # sort labels
  34. idxs_labels = np.vstack((idxs, labels))
  35. idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
  36. idxs = idxs_labels[0, :]
  37. # divide and assign 2 shards/client
  38. for i in range(num_users):
  39. rand_set = set(np.random.choice(idx_shard, 2, replace=False))
  40. idx_shard = list(set(idx_shard) - rand_set)
  41. for rand in rand_set:
  42. dict_users[i] = np.concatenate(
  43. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
  44. return dict_users
  45. def mnist_noniid_unequal(dataset, num_users):
  46. """
  47. Sample non-I.I.D client data from MNIST dataset s.t clients
  48. have unequal amount of data
  49. :param dataset:
  50. :param num_users:
  51. :returns a dict of clients with each clients assigned certain
  52. number of training imgs
  53. """
  54. # 60,000 training imgs --> 50 imgs/shard X 1200 shards
  55. num_shards, num_imgs = 1200, 50
  56. idx_shard = [i for i in range(num_shards)]
  57. dict_users = {i: np.array([]) for i in range(num_users)}
  58. idxs = np.arange(num_shards*num_imgs)
  59. labels = dataset.train_labels.numpy()
  60. # sort labels
  61. idxs_labels = np.vstack((idxs, labels))
  62. idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
  63. idxs = idxs_labels[0, :]
  64. # Minimum and maximum shards assigned per client:
  65. min_shard = 1
  66. max_shard = 30
  67. # Divide the shards into random chunks for every client
  68. # s.t the sum of these chunks = num_shards
  69. random_shard_size = np.random.randint(min_shard, max_shard+1,
  70. size=num_users)
  71. random_shard_size = np.around(random_shard_size /
  72. sum(random_shard_size) * num_shards)
  73. random_shard_size = random_shard_size.astype(int)
  74. # Assign the shards randomly to each client
  75. if sum(random_shard_size) > num_shards:
  76. for i in range(num_users):
  77. # First assign each client 1 shard to ensure every client has
  78. # atleast one shard of data
  79. rand_set = set(np.random.choice(idx_shard, 1, replace=False))
  80. idx_shard = list(set(idx_shard) - rand_set)
  81. for rand in rand_set:
  82. dict_users[i] = np.concatenate(
  83. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
  84. axis=0)
  85. random_shard_size = random_shard_size-1
  86. # Next, randomly assign the remaining shards
  87. for i in range(num_users):
  88. if len(idx_shard) == 0:
  89. continue
  90. shard_size = random_shard_size[i]
  91. if shard_size > len(idx_shard):
  92. shard_size = len(idx_shard)
  93. rand_set = set(np.random.choice(idx_shard, shard_size,
  94. replace=False))
  95. idx_shard = list(set(idx_shard) - rand_set)
  96. for rand in rand_set:
  97. dict_users[i] = np.concatenate(
  98. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
  99. axis=0)
  100. else:
  101. for i in range(num_users):
  102. shard_size = random_shard_size[i]
  103. rand_set = set(np.random.choice(idx_shard, shard_size,
  104. replace=False))
  105. idx_shard = list(set(idx_shard) - rand_set)
  106. for rand in rand_set:
  107. dict_users[i] = np.concatenate(
  108. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
  109. axis=0)
  110. if len(idx_shard) > 0:
  111. # Add the leftover shards to the client with minimum images:
  112. shard_size = len(idx_shard)
  113. # Add the remaining shard to the client with lowest data
  114. k = min(dict_users, key=lambda x: len(dict_users.get(x)))
  115. rand_set = set(np.random.choice(idx_shard, shard_size,
  116. replace=False))
  117. idx_shard = list(set(idx_shard) - rand_set)
  118. for rand in rand_set:
  119. dict_users[k] = np.concatenate(
  120. (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
  121. axis=0)
  122. return dict_users
  123. def cifar_iid(dataset, num_users):
  124. """
  125. Sample I.I.D. client data from CIFAR10 dataset
  126. :param dataset:
  127. :param num_users:
  128. :return: dict of image index
  129. """
  130. num_items = int(len(dataset)/num_users)
  131. dict_users, all_idxs = {}, [i for i in range(len(dataset))]
  132. for i in range(num_users):
  133. dict_users[i] = set(np.random.choice(all_idxs, num_items,
  134. replace=False))
  135. all_idxs = list(set(all_idxs) - dict_users[i])
  136. return dict_users
  137. def cifar_noniid(dataset, num_users):
  138. """
  139. Sample non-I.I.D client data from CIFAR10 dataset
  140. :param dataset:
  141. :param num_users:
  142. :return:
  143. """
  144. num_shards, num_imgs = 200, 250
  145. idx_shard = [i for i in range(num_shards)]
  146. dict_users = {i: np.array([]) for i in range(num_users)}
  147. idxs = np.arange(num_shards*num_imgs)
  148. # labels = dataset.train_labels.numpy()
  149. labels = np.array(dataset.train_labels)
  150. # sort labels
  151. idxs_labels = np.vstack((idxs, labels))
  152. idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
  153. idxs = idxs_labels[0, :]
  154. # divide and assign
  155. for i in range(num_users):
  156. rand_set = set(np.random.choice(idx_shard, 2, replace=False))
  157. idx_shard = list(set(idx_shard) - rand_set)
  158. for rand in rand_set:
  159. dict_users[i] = np.concatenate(
  160. (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
  161. return dict_users
  162. if __name__ == '__main__':
  163. dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
  164. transform=transforms.Compose([
  165. transforms.ToTensor(),
  166. transforms.Normalize((0.1307,),
  167. (0.3081,))
  168. ]))
  169. num = 100
  170. d = mnist_noniid(dataset_train, num)