
Remove duplicated directory

AshwinRJ 4 years ago
parent commit 53de533bd7

+ 0 - 86
Federated_Avg/FedNets.py

@@ -1,86 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-from torch import nn
-import torch.nn.functional as F
-
-
-class MLP(nn.Module):
-    def __init__(self, dim_in, dim_hidden, dim_out):
-        super(MLP, self).__init__()
-        self.layer_input = nn.Linear(dim_in, dim_hidden)
-        self.relu = nn.ReLU()
-        self.dropout = nn.Dropout()
-        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
-        self.softmax = nn.Softmax(dim=1)
-
-    def forward(self, x):
-        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
-        x = self.layer_input(x)
-        x = self.dropout(x)
-        x = self.relu(x)
-        x = self.layer_hidden(x)
-        return self.softmax(x)
-
-
-class CNNMnist(nn.Module):
-    def __init__(self, args):
-        super(CNNMnist, self).__init__()
-        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
-        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
-        self.conv2_drop = nn.Dropout2d()
-        self.fc1 = nn.Linear(320, 50)
-        self.fc2 = nn.Linear(50, args.num_classes)
-
-    def forward(self, x):
-        x = F.relu(F.max_pool2d(self.conv1(x), 2))
-        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
-        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, training=self.training)
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
-
-
-class CNNFashion_Mnist(nn.Module):
-    def __init__(self, args):
-        super(CNNFashion_Mnist, self).__init__()
-        self.layer1 = nn.Sequential(
-            nn.Conv2d(1, 16, kernel_size=5, padding=2),
-            nn.BatchNorm2d(16),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer2 = nn.Sequential(
-            nn.Conv2d(16, 32, kernel_size=5, padding=2),
-            nn.BatchNorm2d(32),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.fc = nn.Linear(7*7*32, 10)
-
-    def forward(self, x):
-        out = self.layer1(x)
-        out = self.layer2(out)
-        out = out.view(out.size(0), -1)
-        out = self.fc(out)
-        return out
-
-
-class CNNCifar(nn.Module):
-    def __init__(self, args):
-        super(CNNCifar, self).__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, args.num_classes)
-
-    def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-        return F.log_softmax(x, dim=1)
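
For reference, the removed CNNMnist flattens its conv features into an nn.Linear(320, 50); a minimal sketch (dummy batch, same layer hyper-parameters, single-channel MNIST input assumed) showing where the 320 comes from:

import torch
from torch import nn

conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 28x28 -> 24x24
conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 12x12 -> 8x8
pool = nn.MaxPool2d(2)

x = torch.randn(4, 1, 28, 28)             # dummy batch of 4 MNIST-sized images
x = pool(conv1(x))                         # -> (4, 10, 12, 12)
x = pool(conv2(x))                         # -> (4, 20, 4, 4)
print(x.flatten(1).shape)                  # torch.Size([4, 320]), matching nn.Linear(320, 50)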

+ 0 - 87
Federated_Avg/Update.py

@@ -1,87 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-import torch
-from torch import nn, autograd
-from torch.utils.data import DataLoader, Dataset
-import numpy as np
-from sklearn import metrics
-
-
-class DatasetSplit(Dataset):
-    def __init__(self, dataset, idxs):
-        self.dataset = dataset
-        self.idxs = [int(i) for i in idxs]
-
-    def __len__(self):
-        return len(self.idxs)
-
-    def __getitem__(self, item):
-        image, label = self.dataset[self.idxs[item]]
-        return image, label
-
-
-class LocalUpdate(object):
-    def __init__(self, args, dataset, idxs, tb):
-        self.args = args
-        self.loss_func = nn.NLLLoss()
-        self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
-        self.tb = tb
-
-    def train_val_test(self, dataset, idxs):
-        # split train, validation, and test
-        idxs_train = idxs[:420]
-        idxs_val = idxs[420:480]
-        idxs_test = idxs[480:]
-        train = DataLoader(DatasetSplit(dataset, idxs_train),
-                           batch_size=self.args.local_bs, shuffle=True)
-        val = DataLoader(DatasetSplit(dataset, idxs_val),
-                         batch_size=int(len(idxs_val)/10), shuffle=True)
-        test = DataLoader(DatasetSplit(dataset, idxs_test),
-                          batch_size=int(len(idxs_test)/10), shuffle=True)
-        return train, val, test
-
-    def update_weights(self, net):
-        net.train()
-        # train and update
-        # Add support for other optimizers
-        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)
-
-        epoch_loss = []
-        for iter in range(self.args.local_ep):
-            batch_loss = []
-            for batch_idx, (images, labels) in enumerate(self.ldr_train):
-                if self.args.gpu != -1:
-                    images, labels = images.cuda(), labels.cuda()
-                images, labels = autograd.Variable(images), autograd.Variable(labels)
-                net.zero_grad()
-                log_probs = net(images)
-                loss = self.loss_func(log_probs, labels)
-                loss.backward()
-                optimizer.step()
-                if self.args.gpu != -1:
-                    loss = loss.cpu()
-                if self.args.verbose and batch_idx % 10 == 0:
-                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
-                        100. * batch_idx / len(self.ldr_train), loss.data[0]))
-                self.tb.add_scalar('loss', loss.data[0])
-                batch_loss.append(loss.data[0])
-            epoch_loss.append(sum(batch_loss)/len(batch_loss))
-        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
-
-    def test(self, net):
-        for batch_idx, (images, labels) in enumerate(self.ldr_test):
-            if self.args.gpu != -1:
-                images, labels = images.cuda(), labels.cuda()
-            images, labels = autograd.Variable(images), autograd.Variable(labels)
-            log_probs = net(images)
-            loss = self.loss_func(log_probs, labels)
-        if self.args.gpu != -1:
-            loss = loss.cpu()
-            log_probs = log_probs.cpu()
-            labels = labels.cpu()
-        y_pred = np.argmax(log_probs.data, axis=1)
-        acc = metrics.accuracy_score(y_true=labels.data, y_pred=y_pred)
-        return acc, loss.data[0]
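
The removed LocalUpdate.update_weights still uses pre-0.4 PyTorch idioms (autograd.Variable, loss.data[0]); a minimal, self-contained sketch of the same local-SGD step on current PyTorch, with a stand-in dataset and model, might look like:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-in local dataset and model (hypothetical, for illustration only)
local_data = TensorDataset(torch.randn(60, 10), torch.randint(0, 3, (60,)))
ldr_train = DataLoader(local_data, batch_size=10, shuffle=True)
net = nn.Sequential(nn.Linear(10, 3), nn.LogSoftmax(dim=1))

loss_func = nn.NLLLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

net.train()
epoch_loss = []
for _ in range(5):                          # local epochs E
    batch_loss = []
    for images, labels in ldr_train:
        net.zero_grad()
        log_probs = net(images)
        loss = loss_func(log_probs, labels)
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())      # .item() replaces the deprecated loss.data[0]
    epoch_loss.append(sum(batch_loss) / len(batch_loss))

local_weights = net.state_dict()            # what update_weights hands back for averaging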

+ 0 - 16
Federated_Avg/averaging.py

@@ -1,16 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-import copy
-import torch
-from torch import nn
-
-
-def average_weights(w):
-    w_avg = copy.deepcopy(w[0])
-    for k in w_avg.keys():
-        for i in range(1, len(w)):
-            w_avg[k] += w[i][k]
-        w_avg[k] = torch.div(w_avg[k], len(w))
-    return w_avg
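
The removed average_weights is the core of FedAvg: an unweighted element-wise mean of the clients' state_dicts (equivalent to the paper's data-size-weighted average only when clients hold equal amounts of data). A minimal usage sketch with toy models (assumed names, for illustration) for context:

import copy
import torch
from torch import nn

def average_weights(w):
    # element-wise mean over the clients' state_dicts (unweighted FedAvg)
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg

# Toy "client" models sharing one architecture
clients = [nn.Linear(4, 2) for _ in range(3)]
w_glob = average_weights([m.state_dict() for m in clients])

global_model = nn.Linear(4, 2)
global_model.load_state_dict(w_glob)   # averaged weights become the new global model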

+ 0 - 161
Federated_Avg/main_fedavg.py

@@ -1,161 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-import matplotlib
-import matplotlib.pyplot as plt
-# matplotlib.use('Agg')
-
-import os
-import copy
-import numpy as np
-from torchvision import datasets, transforms
-from tqdm import tqdm
-import torch
-import torch.nn.functional as F
-from torch import autograd
-
-from tensorboardX import SummaryWriter
-
-from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_unequal
-from options import args_parser
-from Update import LocalUpdate
-from FedNets import MLP, CNNMnist, CNNCifar
-from averaging import average_weights
-import pickle
-
-
-if __name__ == '__main__':
-    # parse args
-    args = args_parser()
-
-    # define paths
-    path_project = os.path.abspath('..')
-
-    summary = SummaryWriter('local')
-
-    # load dataset and split users
-    if args.dataset == 'mnist':
-        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
-                                       transform=transforms.Compose([
-                                           transforms.ToTensor(),
-                                           transforms.Normalize((0.1307,), (0.3081,))
-                                       ]))
-        # sample users
-        if args.iid:
-            dict_users = mnist_iid(dataset_train, args.num_users)
-        elif args.unequal:
-            dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
-        else:
-            dict_users = mnist_noniid(dataset_train, args.num_users)
-
-    elif args.dataset == 'cifar':
-        transform = transforms.Compose(
-            [transforms.ToTensor(),
-             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-        dataset_train = datasets.CIFAR10(
-            '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
-        if args.iid:
-            dict_users = cifar_iid(dataset_train, args.num_users)
-        else:
-            dict_users = cifar_noniid(dataset_train, args.num_users)
-    else:
-        exit('Error: unrecognized dataset')
-    img_size = dataset_train[0][0].shape
-
-    # BUILD MODEL
-    if args.model == 'cnn' and args.dataset == 'mnist':
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = CNNMnist(args=args).cuda()
-        else:
-            net_glob = CNNMnist(args=args)
-    elif args.model == 'cnn' and args.dataset == 'cifar':
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = CNNCifar(args=args).cuda()
-        else:
-            net_glob = CNNCifar(args=args)
-    elif args.model == 'mlp':
-        len_in = 1
-        for x in img_size:
-            len_in *= x
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
-        else:
-            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
-    else:
-        exit('Error: unrecognized model')
-    print(net_glob)
-    net_glob.train()
-
-    # copy weights
-    w_glob = net_glob.state_dict()
-
-    # training
-    train_loss = []
-    train_accuracy = []
-    cv_loss, cv_acc = [], []
-    val_loss_pre, counter = 0, 0
-    net_best = None
-    val_acc_list, net_list = [], []
-    for iter in tqdm(range(args.epochs)):
-        w_locals, loss_locals = [], []
-        m = max(int(args.frac * args.num_users), 1)
-        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
-        for idx in idxs_users:
-            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
-            w, loss = local.update_weights(net=copy.deepcopy(net_glob))
-            w_locals.append(copy.deepcopy(w))
-            loss_locals.append(copy.deepcopy(loss))
-        # update global weights
-        w_glob = average_weights(w_locals)
-
-        # copy weight to net_glob
-        net_glob.load_state_dict(w_glob)
-
-        # print loss after every 'i' rounds
-        print_every = 5
-        loss_avg = sum(loss_locals) / len(loss_locals)
-        if iter % print_every == 0:
-            print('\nTrain loss:', loss_avg)
-        train_loss.append(loss_avg)
-
-        # Calculate avg accuracy over all users at every epoch
-        list_acc, list_loss = [], []
-        net_glob.eval()
-        for c in range(args.num_users):
-            net_local = LocalUpdate(args=args, dataset=dataset_train,
-                                    idxs=dict_users[c], tb=summary)
-            acc, loss = net_local.test(net=net_glob)
-            list_acc.append(acc)
-            list_loss.append(loss)
-        train_accuracy.append(sum(list_acc)/len(list_acc))
-
-    # Plot Loss curve
-    # plt.figure()
-    # plt.title('Training Loss vs Communication rounds')
-    # plt.plot(range(len(train_loss)), train_loss, color='r')
-    # plt.ylabel('Training loss')
-    # plt.xlabel('Communication Rounds')
-    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(args.dataset,
-    #                                                                              args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
-    #
-    # # Plot Average Accuracy vs Communication rounds
-    # plt.figure()
-    # plt.title('Average Accuracy vs Communication rounds')
-    # plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
-    # plt.ylabel('Average Accuracy')
-    # plt.xlabel('Communication Rounds')
-    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(args.dataset,
-    #                                                                             args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs))
-
-    print("Final Average Accuracy after {} epochs: {:.2f}%".format(
-        args.epochs, 100.*train_accuracy[-1]))
-
-# Saving the objects train_loss and train_accuracy:
-file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(args.dataset,
-                                                                            args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)
-with open(file_name, 'wb') as f:
-    pickle.dump([train_loss, train_accuracy], f)

+ 0 - 149
Federated_Avg/main_nn.py

@@ -1,149 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-
-from tqdm import tqdm
-import torch
-import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from torch import autograd
-import torch.optim as optim
-from torchvision import datasets, transforms
-
-from options import args_parser
-from FedNets import MLP, CNNMnist, CNNCifar
-
-import matplotlib
-import matplotlib.pyplot as plt
-matplotlib.use('Agg')
-
-
-def test(net_g, data_loader):
-    # testing
-    net_g.eval()
-    test_loss = 0
-    correct = 0
-    l = len(data_loader)
-    for idx, (data, target) in enumerate(data_loader):
-        if args.gpu != -1:
-            data, target = data.cuda(), target.cuda()
-        data, target = autograd.Variable(data, volatile=True), autograd.Variable(target)
-        log_probs = net_g(data)
-        test_loss += F.nll_loss(log_probs, target, size_average=False).data[0]
-        y_pred = log_probs.data.max(1, keepdim=True)[1]
-        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
-
-    test_loss /= len(data_loader.dataset)
-    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
-        test_loss, correct, len(data_loader.dataset),
-        100. * correct / len(data_loader.dataset)))
-
-    return correct, test_loss
-
-
-if __name__ == '__main__':
-    # parse args
-    args = args_parser()
-
-    torch.manual_seed(args.seed)
-
-    # load dataset and split users
-    if args.dataset == 'mnist':
-        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
-                                       transform=transforms.Compose([
-                                           transforms.ToTensor(),
-                                           transforms.Normalize((0.1307,), (0.3081,))
-                                       ]))
-        img_size = dataset_train[0][0].shape
-
-    elif args.dataset == 'cifar':
-        transform = transforms.Compose(
-            [transforms.ToTensor(),
-             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-        dataset_train = datasets.CIFAR10(
-            '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
-        img_size = dataset_train[0][0].shape
-    else:
-        exit('Error: unrecognized dataset')
-
-    # build model
-    if args.model == 'cnn' and args.dataset == 'cifar':
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = CNNCifar(args=args).cuda()
-        else:
-            net_glob = CNNCifar(args=args)
-    elif args.model == 'cnn' and args.dataset == 'mnist':
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = CNNMnist(args=args).cuda()
-        else:
-            net_glob = CNNMnist(args=args)
-    elif args.model == 'mlp':
-        len_in = 1
-        for x in img_size:
-            len_in *= x
-        if args.gpu != -1:
-            torch.cuda.set_device(args.gpu)
-            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
-        else:
-            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
-    else:
-        exit('Error: unrecognized model')
-    print(net_glob)
-
-    # training
-    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
-    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
-
-    list_loss = []
-    net_glob.train()
-    for epoch in tqdm(range(args.epochs)):
-        batch_loss = []
-        for batch_idx, (data, target) in enumerate(train_loader):
-            if args.gpu != -1:
-                data, target = data.cuda(), target.cuda()
-            data, target = autograd.Variable(data), autograd.Variable(target)
-            optimizer.zero_grad()
-            output = net_glob(data)
-            loss = F.nll_loss(output, target)
-            loss.backward()
-            optimizer.step()
-            if batch_idx % 50 == 0:
-                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                    epoch, batch_idx * len(data), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), loss.data[0]))
-            batch_loss.append(loss.data[0])
-        loss_avg = sum(batch_loss)/len(batch_loss)
-        print('\nTrain loss:', loss_avg)
-        list_loss.append(loss_avg)
-
-    # plot loss
-    plt.figure()
-    plt.plot(range(len(list_loss)), list_loss)
-    plt.xlabel('epochs')
-    plt.ylabel('train loss')
-    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))
-
-    # testing
-    if args.dataset == 'mnist':
-        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True,
-                                      transform=transforms.Compose([
-                                          transforms.ToTensor(),
-                                          transforms.Normalize((0.1307,), (0.3081,))
-                                      ]))
-        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
-
-    elif args.dataset == 'cifar':
-        transform = transforms.Compose(
-            [transforms.ToTensor(),
-             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-        dataset_test = datasets.CIFAR10('../data/cifar', train=False,
-                                        transform=transform, target_transform=None, download=True)
-        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
-    else:
-        exit('Error: unrecognized dataset')
-
-    print('Test on', len(dataset_test), 'samples')
-    test_acc, test_loss = test(net_glob, test_loader)
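
Like Update.py, the removed test() relies on APIs deprecated since PyTorch 0.4 (Variable(..., volatile=True), size_average=False, loss.data[0]); an equivalent evaluation loop on current PyTorch, sketched here with a stand-in model and test set:

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-in model and test set (hypothetical, for illustration only)
net_g = nn.Sequential(nn.Linear(10, 3), nn.LogSoftmax(dim=1))
test_loader = DataLoader(TensorDataset(torch.randn(100, 10),
                                       torch.randint(0, 3, (100,))),
                         batch_size=25)

net_g.eval()
test_loss, correct = 0.0, 0
with torch.no_grad():                                   # replaces volatile=True
    for data, target in test_loader:
        log_probs = net_g(data)
        test_loss += F.nll_loss(log_probs, target,
                                reduction='sum').item()  # replaces size_average=False
        y_pred = log_probs.argmax(dim=1, keepdim=True)
        correct += y_pred.eq(target.view_as(y_pred)).sum().item()

test_loss /= len(test_loader.dataset)
print('Average loss: {:.4f}  Accuracy: {}/{} ({:.2f}%)'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))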

+ 0 - 53
Federated_Avg/options.py

@@ -1,53 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-import argparse
-
-
-def args_parser():
-    parser = argparse.ArgumentParser()
-
-    # federated arguments (Notation for the arguments followed from paper)
-    parser.add_argument('--epochs', type=int, default=10,
-                        help="number of rounds of training")
-    parser.add_argument('--num_users', type=int, default=100,
-                        help="number of users: K")
-    parser.add_argument('--frac', type=float, default=0.1,
-                        help='the fraction of clients: C')
-    parser.add_argument('--local_ep', type=int, default=5,
-                        help="the number of local epochs: E")
-    parser.add_argument('--local_bs', type=int, default=10,
-                        help="local batch size: B")
-    parser.add_argument('--lr', type=float, default=0.01,
-                        help='learning rate')
-    parser.add_argument('--momentum', type=float, default=0.5,
-                        help='SGD momentum (default: 0.5)')
-
-    # model arguments
-    parser.add_argument('--model', type=str, default='mlp', help='model name')
-    parser.add_argument('--kernel_num', type=int, default=9,
-                        help='number of each kind of kernel')
-    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
-                        help='comma-separated kernel size to use for convolution')
-    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imgs")
-    parser.add_argument('--norm', type=str, default='batch_norm',
-                        help="batch_norm, layer_norm, or None")
-    parser.add_argument('--num_filters', type=int, default=32,
-                        help="number of filters for conv nets -- 32 for mini-imagenet, 64 for omiglot.")
-    parser.add_argument('--max_pool', type=str, default='True',
-                        help="Whether use max pooling rather than strided convolutions")
-
-    # other arguments
-    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
-    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
-    parser.add_argument('--gpu', type=int, default=1, help="GPU ID")
-    parser.add_argument('--iid', type=int, default=0,
-                        help='whether i.i.d or not: 1 for iid, 0 for non-iid')
-    parser.add_argument('--unequal', type=int, default=0,
-                        help='whether to use unequal data splits for  non-i.i.d setting (use 0 for equal splits)')
-    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
-    parser.add_argument('--verbose', type=int, default=1, help='verbose')
-    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
-    args = parser.parse_args()
-    return args

+ 0 - 186
Federated_Avg/sampling.py

@@ -1,186 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-
-import numpy as np
-from torchvision import datasets, transforms
-
-
-def mnist_iid(dataset, num_users):
-    """
-    Sample I.I.D. client data from MNIST dataset
-    :param dataset:
-    :param num_users:
-    :return: dict of image index
-    """
-    num_items = int(len(dataset)/num_users)
-    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
-    for i in range(num_users):
-        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
-        all_idxs = list(set(all_idxs) - dict_users[i])
-    return dict_users
-
-
-def mnist_noniid(dataset, num_users):
-    """
-    Sample non-I.I.D client data from MNIST dataset
-    :param dataset:
-    :param num_users:
-    :return:
-    """
-    # 60,000 training imgs -->  200 imgs/shard X 300 shards
-    num_shards, num_imgs = 200, 300
-    idx_shard = [i for i in range(num_shards)]
-    dict_users = {i: np.array([]) for i in range(num_users)}
-    idxs = np.arange(num_shards*num_imgs)
-    labels = dataset.train_labels.numpy()
-
-    # sort labels
-    idxs_labels = np.vstack((idxs, labels))
-    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
-    idxs = idxs_labels[0, :]
-
-    # divide and assign 2 shards/client
-    for i in range(num_users):
-        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
-        idx_shard = list(set(idx_shard) - rand_set)
-        for rand in rand_set:
-            dict_users[i] = np.concatenate(
-                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-    return dict_users
-
-
-def mnist_noniid_unequal(dataset, num_users):
-    """
-    Sample non-I.I.D client data from MNIST dataset s.t clients
-    have unequal amount of data
-    :param dataset:
-    :param num_users:
-    :returns a dict of clients with each clients assigned certain
-    number of training imgs
-    """
-    # 60,000 training imgs --> 50 imgs/shard X 1200 shards
-    num_shards, num_imgs = 1200, 50
-    idx_shard = [i for i in range(num_shards)]
-    dict_users = {i: np.array([]) for i in range(num_users)}
-    idxs = np.arange(num_shards*num_imgs)
-    labels = dataset.train_labels.numpy()
-
-    # sort labels
-    idxs_labels = np.vstack((idxs, labels))
-    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
-    idxs = idxs_labels[0, :]
-
-    # Minimum and maximum shards assigned per client:
-    min_shard = 1
-    max_shard = 30
-
-    # Divide the shards into random chunks for every client
-    # s.t the sum of these chunks = num_shards
-    random_shard_size = np.random.randint(min_shard, max_shard+1, size=num_users)
-    random_shard_size = np.around(random_shard_size/sum(random_shard_size) * num_shards)
-    random_shard_size = random_shard_size.astype(int)
-
-    # Assign the shards randomly to each client
-    if sum(random_shard_size) > num_shards:
-
-        for i in range(num_users):
-            # First assign each client 1 shard to ensure every client has
-            # atleast one shard of data
-            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
-            idx_shard = list(set(idx_shard) - rand_set)
-            for rand in rand_set:
-                dict_users[i] = np.concatenate(
-                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
-        random_shard_size = random_shard_size-1
-
-        # Next, randomly assign the remaining shards
-        for i in range(num_users):
-            if len(idx_shard) == 0:
-                continue
-            shard_size = random_shard_size[i]
-            if shard_size > len(idx_shard):
-                shard_size = len(idx_shard)
-            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
-            idx_shard = list(set(idx_shard) - rand_set)
-            for rand in rand_set:
-                dict_users[i] = np.concatenate(
-                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-    else:
-
-        for i in range(num_users):
-            shard_size = random_shard_size[i]
-            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
-            idx_shard = list(set(idx_shard) - rand_set)
-            for rand in rand_set:
-                dict_users[i] = np.concatenate(
-                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
-        if len(idx_shard) > 0:
-            # Add the leftover shards to the client with minimum images:
-            shard_size = len(idx_shard)
-            # Add the remaining shard to the client with lowest data
-            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
-            rand_set = set(np.random.choice(idx_shard, shard_size, replace=False))
-            idx_shard = list(set(idx_shard) - rand_set)
-            for rand in rand_set:
-                dict_users[k] = np.concatenate(
-                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
-    return dict_users
-
-
-def cifar_iid(dataset, num_users):
-    """
-    Sample I.I.D. client data from CIFAR10 dataset
-    :param dataset:
-    :param num_users:
-    :return: dict of image index
-    """
-    num_items = int(len(dataset)/num_users)
-    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
-    for i in range(num_users):
-        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
-        all_idxs = list(set(all_idxs) - dict_users[i])
-    return dict_users
-
-
-def cifar_noniid(dataset, num_users):
-    """
-    Sample non-I.I.D client data from CIFAR10 dataset
-    :param dataset:
-    :param num_users:
-    :return:
-    """
-    num_shards, num_imgs = 200, 250
-    idx_shard = [i for i in range(num_shards)]
-    dict_users = {i: np.array([]) for i in range(num_users)}
-    idxs = np.arange(num_shards*num_imgs)
-    # labels = dataset.train_labels.numpy()
-    labels = np.array(dataset.train_labels)
-
-    # sort labels
-    idxs_labels = np.vstack((idxs, labels))
-    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
-    idxs = idxs_labels[0, :]
-
-    # divide and assign
-    for i in range(num_users):
-        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
-        idx_shard = list(set(idx_shard) - rand_set)
-        for rand in rand_set:
-            dict_users[i] = np.concatenate(
-                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-    return dict_users
-
-
-if __name__ == '__main__':
-    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
-                                   transform=transforms.Compose([
-                                       transforms.ToTensor(),
-                                       transforms.Normalize((0.1307,), (0.3081,))
-                                   ]))
-    num = 100
-    d = mnist_noniid(dataset_train, num)
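
For context, the removed mnist_noniid matches the shard-based pathological non-IID split used in the FedAvg paper: labels are sorted, cut into 200 shards of 300 images, and each of 100 clients draws two shards, i.e. 600 images covering at most two digits. A self-contained sketch of that assignment with synthetic labels (an assumption, so it runs without downloading MNIST):

import numpy as np

# Synthetic stand-in for sorted MNIST labels: 10 classes, 6,000 images each
num_shards, num_imgs, num_users = 200, 300, 100
labels = np.repeat(np.arange(10), 6000)          # already label-sorted
idxs = np.arange(num_shards * num_imgs)

idx_shard = list(range(num_shards))
dict_users = {i: np.array([], dtype=int) for i in range(num_users)}
for i in range(num_users):
    rand_set = set(np.random.choice(idx_shard, 2, replace=False))
    idx_shard = list(set(idx_shard) - rand_set)
    for rand in rand_set:
        dict_users[i] = np.concatenate(
            (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]))

print(len(dict_users[0]))                 # 600 images per client
print(np.unique(labels[dict_users[0]]))   # at most two distinct classes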