AshwinRJ 6 years ago
parent
commit
e00b27389e
4 changed files with 383 additions and 0 deletions
  1. Update.py        +87   −0
  2. averaging.py     +16   −0
  3. main_fedavg.py   +177  −0
  4. sampling.py      +103  −0

+ 87 - 0
Update.py

@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import torch
+from torch import nn, autograd
+from torch.utils.data import DataLoader, Dataset
+import numpy as np
+from sklearn import metrics
+
+
+class DatasetSplit(Dataset):
+    def __init__(self, dataset, idxs):
+        self.dataset = dataset
+        self.idxs = list(idxs)
+
+    def __len__(self):
+        return len(self.idxs)
+
+    def __getitem__(self, item):
+        image, label = self.dataset[self.idxs[item]]
+        return image, label
+
+
+class LocalUpdate(object):
+    def __init__(self, args, dataset, idxs, tb):
+        self.args = args
+        self.loss_func = nn.NLLLoss()
+        self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
+        self.tb = tb
+
+    def train_val_test(self, dataset, idxs):
+        # split each user's indices into train/val/test; the hard-coded
+        # 420/60/rest split assumes roughly 600 samples per user
+        idxs_train = idxs[:420]
+        idxs_val = idxs[420:480]
+        idxs_test = idxs[480:]
+        train = DataLoader(DatasetSplit(dataset, idxs_train),
+                           batch_size=self.args.local_bs, shuffle=True)
+        val = DataLoader(DatasetSplit(dataset, idxs_val),
+                         batch_size=int(len(idxs_val)/10), shuffle=True)
+        test = DataLoader(DatasetSplit(dataset, idxs_test),
+                          batch_size=int(len(idxs_test)/10), shuffle=True)
+        return train, val, test
+
+    def update_weights(self, net):
+        net.train()
+        # train and update
+        # Add support for other optimizers
+        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)
+
+        epoch_loss = []
+        for iter in range(self.args.local_ep):
+            batch_loss = []
+            for batch_idx, (images, labels) in enumerate(self.ldr_train):
+                if self.args.gpu != -1:
+                    images, labels = images.cuda(), labels.cuda()
+                images, labels = autograd.Variable(images), autograd.Variable(labels)
+                net.zero_grad()
+                log_probs = net(images)
+                loss = self.loss_func(log_probs, labels)
+                loss.backward()
+                optimizer.step()
+                if self.args.gpu != -1:
+                    loss = loss.cpu()
+                if self.args.verbose and batch_idx % 10 == 0:
+                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
+                        100. * batch_idx / len(self.ldr_train), loss.data[0]))
+                self.tb.add_scalar('loss', loss.data[0])
+                batch_loss.append(loss.data[0])
+            epoch_loss.append(sum(batch_loss)/len(batch_loss))
+        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
+
+    def test(self, net):
+        # evaluate on the full local test split, accumulating the loss and the
+        # predictions over every batch instead of scoring only the last one
+        loss_sum, y_true, y_pred = 0.0, [], []
+        for batch_idx, (images, labels) in enumerate(self.ldr_test):
+            if self.args.gpu != -1:
+                images, labels = images.cuda(), labels.cuda()
+            images, labels = autograd.Variable(images), autograd.Variable(labels)
+            log_probs = net(images)
+            loss = self.loss_func(log_probs, labels)
+            if self.args.gpu != -1:
+                loss = loss.cpu()
+                log_probs = log_probs.cpu()
+                labels = labels.cpu()
+            loss_sum += loss.data[0]
+            y_pred.extend(np.argmax(log_probs.data.numpy(), axis=1))
+            y_true.extend(labels.data.numpy())
+        acc = metrics.accuracy_score(y_true=y_true, y_pred=y_pred)
+        return acc, loss_sum / len(self.ldr_test)
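
LocalUpdate only reads a handful of fields from args (local_bs, lr, local_ep, gpu, verbose); the options.py parser that normally supplies them is not part of this commit. The sketch below is illustrative rather than part of the repo: it fakes those fields with a SimpleNamespace just to show the 420/60/120 per-user split that train_val_test produces.

from types import SimpleNamespace

from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from Update import LocalUpdate

# stand-in for options.py; the field names mirror what LocalUpdate reads,
# the values are purely illustrative
args = SimpleNamespace(local_bs=10, lr=0.01, local_ep=1, gpu=-1, verbose=False)

dataset = datasets.MNIST('../data/mnist/', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))]))

# pretend this client owns the first 600 training images
local = LocalUpdate(args=args, dataset=dataset, idxs=range(600), tb=SummaryWriter('local'))

# the hard-coded split: 420 train / 60 val / 120 test indices
print(len(local.ldr_train.dataset), len(local.ldr_val.dataset), len(local.ldr_test.dataset))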

+ 16 - 0
averaging.py

@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import copy
+import torch
+from torch import nn
+
+
+def average_weights(w):
+    w_avg = copy.deepcopy(w[0])
+    for k in w_avg.keys():
+        for i in range(1, len(w)):
+            w_avg[k] += w[i][k]
+        w_avg[k] = torch.div(w_avg[k], len(w))
+    return w_avg
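
average_weights is the unweighted FedAvg aggregation: every entry of the global state dict becomes the element-wise mean of the corresponding client entries. Because both samplers in this commit give every client the same number of images, the unweighted mean coincides with the sample-size-weighted average described in the FedAvg paper. A tiny sanity check, not part of the commit, with two hand-made state dicts:

import torch

from averaging import average_weights

# two fake "client" state dicts with matching keys and shapes
w0 = {'fc.weight': torch.ones(2, 2) * 2.0, 'fc.bias': torch.zeros(2)}
w1 = {'fc.weight': torch.ones(2, 2) * 4.0, 'fc.bias': torch.ones(2) * 2.0}

w_avg = average_weights([w0, w1])
print(w_avg['fc.weight'])  # every entry is 3.0 = (2.0 + 4.0) / 2
print(w_avg['fc.bias'])    # every entry is 1.0 = (0.0 + 2.0) / 2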

+ 177 - 0
main_fedavg.py

@@ -0,0 +1,177 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+
+import os
+import copy
+import numpy as np
+from torchvision import datasets, transforms
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+from torch import autograd
+from tensorboardX import SummaryWriter
+
+from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
+from options import args_parser
+from Update import LocalUpdate
+from FedNets import MLP, CNNMnist, CNNCifar
+from averaging import average_weights
+
+import matplotlib
+# the non-interactive backend must be selected before pyplot is imported
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+# def test(net_g, data_loader, args):
+#     # testing
+#     test_loss = 0
+#     correct = 0
+#     # Test for the below line
+#     l = len(data_loader)
+#
+#     for idx, (data, target) in enumerate(data_loader):
+#         if args.gpu != -1:
+#             data, target = data.cuda(), target.cuda()
+#         data, target = autograd.Variable(data), autograd.Variable(target)
+#         log_probs = net_g(data)
+#         test_loss += F.nll_loss(log_probs, target, size_average=False).data[0]  # sum up batch loss
+#         y_pred = log_probs.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
+#         correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
+#
+#     test_loss /= len(data_loader.dataset)
+#     print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
+#         test_loss, correct, len(data_loader.dataset),
+#         100. * correct / len(data_loader.dataset)))
+#     return correct, test_loss
+
+
+if __name__ == '__main__':
+    # parse args
+    args = args_parser()
+
+    # define paths
+    path_project = os.path.abspath('..')
+
+    summary = SummaryWriter('local')
+
+    # load dataset and split users
+    if args.dataset == 'mnist':
+        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
+                                       transform=transforms.Compose([
+                                           transforms.ToTensor(),
+                                           transforms.Normalize((0.1307,), (0.3081,))
+                                       ]))
+        # sample users
+        if args.iid:
+            dict_users = mnist_iid(dataset_train, args.num_users)
+        else:
+            dict_users = mnist_noniid(dataset_train, args.num_users)
+
+    elif args.dataset == 'cifar':
+        transform = transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        dataset_train = datasets.CIFAR10(
+            '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
+        if args.iid:
+            dict_users = cifar_iid(dataset_train, args.num_users)
+        else:
+            dict_users = cifar_noniid(dataset_train, args.num_users)
+    else:
+        exit('Error: unrecognized dataset')
+    img_size = dataset_train[0][0].shape
+
+    # BUILD MODEL
+    # Using same models for MNIST and FashionMNIST
+    if args.model == 'cnn' and args.dataset == 'mnist':
+        if args.gpu != -1:
+            torch.cuda.set_device(args.gpu)
+            net_glob = CNNMnist(args=args).cuda()
+        else:
+            net_glob = CNNMnist(args=args)
+
+    elif args.model == 'cnn' and args.dataset == 'cifar':
+        if args.gpu != -1:
+            torch.cuda.set_device(args.gpu)
+            net_glob = CNNCifar(args=args).cuda()
+        else:
+            net_glob = CNNCifar(args=args)
+    elif args.model == 'mlp':
+        len_in = 1
+        for x in img_size:
+            len_in *= x
+        if args.gpu != -1:
+            torch.cuda.set_device(args.gpu)
+            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
+        else:
+            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
+    else:
+        exit('Error: unrecognized model')
+    print(net_glob)
+    net_glob.train()
+
+    # copy weights
+    w_glob = net_glob.state_dict()
+
+    # training
+    loss_train = []
+    train_accuracy = []
+    cv_loss, cv_acc = [], []
+    val_loss_pre, counter = 0, 0
+    net_best = None
+    val_acc_list, net_list = [], []
+    for iter in tqdm(range(args.epochs)):
+        w_locals, loss_locals = [], []
+        m = max(int(args.frac * args.num_users), 1)
+        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
+        for idx in idxs_users:
+            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
+            w, loss = local.update_weights(net=copy.deepcopy(net_glob))
+            w_locals.append(copy.deepcopy(w))
+            loss_locals.append(copy.deepcopy(loss))
+        # update global weights
+        w_glob = average_weights(w_locals)
+
+        # copy weight to net_glob
+        net_glob.load_state_dict(w_glob)
+
+        # print loss
+        loss_avg = sum(loss_locals) / len(loss_locals)
+        # report the average local training loss every 10 rounds
+        if iter % 10 == 0:
+            print('\nTrain loss:', loss_avg)
+        loss_train.append(loss_avg)
+
+        # Calculate avg accuracy over all users at every epoch
+        list_acc, list_loss = [], []
+        net_glob.eval()
+        for c in tqdm(range(args.num_users)):
+            net_local = LocalUpdate(args=args, dataset=dataset_train,
+                                    idxs=dict_users[c], tb=summary)
+            acc, loss = net_local.test(net=net_glob)
+            list_acc.append(acc)
+            list_loss.append(loss)
+        train_accuracy.append(sum(list_acc)/len(list_acc))
+
+    # Plot Loss curve
+    plt.figure()
+    plt.title('Training Loss vs Communication rounds')
+    plt.plot(range(len(loss_train)), loss_train, color='r')
+    plt.ylabel('Training loss')
+    plt.xlabel('Communication Rounds')
+    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}_loss.png'.format(args.dataset,
+                                                                 args.model, args.epochs, args.frac, args.iid))
+
+    # testing (original)
+    list_acc, list_loss = [], []
+    net_glob.eval()
+    for c in tqdm(range(args.num_users)):
+        net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
+        acc, loss = net_local.test(net=net_glob)
+        list_acc.append(acc)
+        list_loss.append(loss)
+    print("Final Average Accuracy after {} epochs: {:.2f}%".format(
+        args.epochs, (100.*sum(list_acc)/len(list_acc))))
+
+    print("Final Average Accuracy after {} epochs: {:.2f}%".format(args.epochs, train_accuracy[-1])

+ 103 - 0
sampling.py

@@ -0,0 +1,103 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+
+import numpy as np
+from torchvision import datasets, transforms
+
+
+def mnist_iid(dataset, num_users):
+    """
+    Sample I.I.D. client data from MNIST dataset
+    :param dataset:
+    :param num_users:
+    :return: dict of image index
+    """
+    num_items = int(len(dataset)/num_users)
+    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
+    for i in range(num_users):
+        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
+        all_idxs = list(set(all_idxs) - dict_users[i])
+    return dict_users
+
+
+def mnist_noniid(dataset, num_users):
+    """
+    Sample non-I.I.D client data from MNIST dataset
+    :param dataset:
+    :param num_users:
+    :return:
+    """
+    num_shards, num_imgs = 200, 300
+    idx_shard = [i for i in range(num_shards)]
+    # keep the index arrays integer-typed so they can index the dataset later
+    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
+    idxs = np.arange(num_shards*num_imgs)
+    labels = dataset.train_labels.numpy()
+
+    # sort labels
+    idxs_labels = np.vstack((idxs, labels))
+    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
+    idxs = idxs_labels[0, :]
+
+    # divide and assign
+    for i in range(num_users):
+        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
+        idx_shard = list(set(idx_shard) - rand_set)
+        for rand in rand_set:
+            dict_users[i] = np.concatenate(
+                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+    return dict_users
+
+
+def cifar_iid(dataset, num_users):
+    """
+    Sample I.I.D. client data from CIFAR10 dataset
+    :param dataset:
+    :param num_users:
+    :return: dict of image index
+    """
+    num_items = int(len(dataset)/num_users)
+    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
+    for i in range(num_users):
+        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
+        all_idxs = list(set(all_idxs) - dict_users[i])
+    return dict_users
+
+
+def cifar_noniid(dataset, num_users):
+    """
+    Sample non-I.I.D client data from CIFAR10 dataset
+    :param dataset:
+    :param num_users:
+    :return:
+    """
+    num_shards, num_imgs = 200, 250
+    idx_shard = [i for i in range(num_shards)]
+    # keep the index arrays integer-typed so they can index the dataset later
+    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
+    idxs = np.arange(num_shards*num_imgs)
+    # CIFAR10's train_labels may be a plain Python list depending on the
+    # torchvision version, so wrap it in an array before sorting
+    labels = np.array(dataset.train_labels)
+
+    # sort labels
+    idxs_labels = np.vstack((idxs, labels))
+    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
+    idxs = idxs_labels[0, :]
+
+    # divide and assign
+    for i in range(num_users):
+        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
+        idx_shard = list(set(idx_shard) - rand_set)
+        for rand in rand_set:
+            dict_users[i] = np.concatenate(
+                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+    return dict_users
+
+
+if __name__ == '__main__':
+    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
+                                   transform=transforms.Compose([
+                                       transforms.ToTensor(),
+                                       transforms.Normalize((0.1307,), (0.3081,))
+                                   ]))
+    num = 100
+    d = mnist_noniid(dataset_train, num)
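
A quick look, again illustrative and not part of the commit, at what mnist_noniid hands each user: two shards of 300 label-sorted images, so 600 indices drawn from only a couple of digit classes (typically two, occasionally more when a shard straddles a label boundary).

import numpy as np
from torchvision import datasets, transforms

from sampling import mnist_noniid

dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,))]))
labels = dataset_train.train_labels.numpy()

dict_users = mnist_noniid(dataset_train, 100)
for uid in range(3):
    user_idxs = dict_users[uid].astype(int)
    # expect 600 indices per user and only a few distinct labels (typically two)
    print(uid, len(user_idxs), np.unique(labels[user_idxs]))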