|
@@ -17,7 +17,8 @@ def mnist_iid(dataset, num_users):
|
|
|
num_items = int(len(dataset)/num_users)
|
|
|
dict_users, all_idxs = {}, [i for i in range(len(dataset))]
|
|
|
for i in range(num_users):
|
|
|
- dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
|
|
|
+ dict_users[i] = set(np.random.choice(all_idxs, num_items,
|
|
|
+ replace=False))
|
|
|
all_idxs = list(set(all_idxs) - dict_users[i])
|
|
|
return dict_users
|
|
|
|
|
@@ -78,8 +79,10 @@ def mnist_noniid_unequal(dataset, num_users):
|
|
|
|
|
|
# Divide the shards into random chunks for every client
|
|
|
# s.t the sum of these chunks = num_shards
|
|
|
- random_shard_size = np.random.randint(min_shard, max_shard+1, size=num_users)
|
|
|
- random_shard_size = np.around(random_shard_size/sum(random_shard_size) * num_shards)
|
|
|
+ random_shard_size = np.random.randint(min_shard, max_shard+1,
|
|
|
+ size=num_users)
|
|
|
+ random_shard_size = np.around(random_shard_size /
|
|
|
+ sum(random_shard_size) * num_shards)
|
|
|
random_shard_size = random_shard_size.astype(int)
|
|
|
|
|
|
# Assign the shards randomly to each client
|
|
@@ -92,7 +95,8 @@ def mnist_noniid_unequal(dataset, num_users):
|
|
|
idx_shard = list(set(idx_shard) - rand_set)
|
|
|
for rand in rand_set:
|
|
|
dict_users[i] = np.concatenate(
|
|
|
- (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
|
|
|
+ (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
|
|
|
+ axis=0)
|
|
|
|
|
|
random_shard_size = random_shard_size-1
|
|
|
|
|
@@ -103,31 +107,37 @@ def mnist_noniid_unequal(dataset, num_users):
|
|
|
shard_size = random_shard_size[i]
|
|
|
if shard_size > len(idx_shard):
|
|
|
shard_size = len(idx_shard)
|
|
|
- rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
|
|
|
+ rand_set = set(np.random.choice(idx_shard, shard_size,
|
|
|
+ replace=False))
|
|
|
idx_shard = list(set(idx_shard) - rand_set)
|
|
|
for rand in rand_set:
|
|
|
dict_users[i] = np.concatenate(
|
|
|
- (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
|
|
|
+ (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
|
|
|
+ axis=0)
|
|
|
else:
|
|
|
|
|
|
for i in range(num_users):
|
|
|
shard_size = random_shard_size[i]
|
|
|
- rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
|
|
|
+ rand_set = set(np.random.choice(idx_shard, shard_size,
|
|
|
+ replace=False))
|
|
|
idx_shard = list(set(idx_shard) - rand_set)
|
|
|
for rand in rand_set:
|
|
|
dict_users[i] = np.concatenate(
|
|
|
- (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
|
|
|
+ (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
|
|
|
+ axis=0)
|
|
|
|
|
|
if len(idx_shard) > 0:
|
|
|
# Add the leftover shards to the client with minimum images:
|
|
|
shard_size = len(idx_shard)
|
|
|
# Add the remaining shard to the client with lowest data
|
|
|
k = min(dict_users, key=lambda x: len(dict_users.get(x)))
|
|
|
- rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
|
|
|
+ rand_set = set(np.random.choice(idx_shard, shard_size,
|
|
|
+ replace=False))
|
|
|
idx_shard = list(set(idx_shard) - rand_set)
|
|
|
for rand in rand_set:
|
|
|
dict_users[k] = np.concatenate(
|
|
|
- (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
|
|
|
+ (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
|
|
|
+ axis=0)
|
|
|
|
|
|
return dict_users
|
|
|
|
|
@@ -142,7 +152,8 @@ def cifar_iid(dataset, num_users):
|
|
|
num_items = int(len(dataset)/num_users)
|
|
|
dict_users, all_idxs = {}, [i for i in range(len(dataset))]
|
|
|
for i in range(num_users):
|
|
|
- dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
|
|
|
+ dict_users[i] = set(np.random.choice(all_idxs, num_items,
|
|
|
+ replace=False))
|
|
|
all_idxs = list(set(all_idxs) - dict_users[i])
|
|
|
return dict_users
|
|
|
|
|
@@ -180,7 +191,8 @@ if __name__ == '__main__':
|
|
|
dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
|
|
|
transform=transforms.Compose([
|
|
|
transforms.ToTensor(),
|
|
|
- transforms.Normalize((0.1307,), (0.3081,))
|
|
|
+ transforms.Normalize((0.1307,),
|
|
|
+ (0.3081,))
|
|
|
]))
|
|
|
num = 100
|
|
|
d = mnist_noniid(dataset_train, num)
|