
Added Regular NN Implementation

AshwinRJ 5 years ago
parent
commit
3c93e8265e
4 changed files with 192 additions and 0 deletions
  1. BIN      .DS_Store
  2. +0 −0    FedNets.py
  3. +149 −0  main_nn.py
  4. +43 −0   options.py

.DS_Store  BIN

NN_Arch.py → FedNets.py  +0 −0


main_nn.py  +149 −0

@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+from options import args_parser
+from FedNets import MLP, CNNMnist, CNNCifar
+
+import matplotlib
+matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
+import matplotlib.pyplot as plt
+
+
+def test(net_g, data_loader):
+    # evaluate the global model on the held-out set
+    # (args is read from module scope; it is set in __main__ below)
+    net_g.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True)
+        for idx, (data, target) in enumerate(data_loader):
+            if args.gpu != -1:
+                data, target = data.cuda(), target.cuda()
+            log_probs = net_g(data)
+            # sum the per-batch losses; averaged over the whole dataset below
+            test_loss += F.nll_loss(log_probs, target, reduction='sum').item()
+            y_pred = log_probs.max(1, keepdim=True)[1]
+            correct += y_pred.eq(target.view_as(y_pred)).long().cpu().sum().item()
+
+    test_loss /= len(data_loader.dataset)
+    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
+        test_loss, correct, len(data_loader.dataset),
+        100. * correct / len(data_loader.dataset)))
+
+    return correct, test_loss
+
+
+if __name__ == '__main__':
+    # parse args
+    args = args_parser()
+
+    torch.manual_seed(args.seed)
+
+    # load dataset and split users
+    if args.dataset == 'mnist':
+        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
+                                       transform=transforms.Compose([
+                                           transforms.ToTensor(),
+                                           transforms.Normalize((0.1307,), (0.3081,))
+                                       ]))
+        img_size = dataset_train[0][0].shape
+
+    elif args.dataset == 'cifar':
+        transform = transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        dataset_train = datasets.CIFAR10(
+            '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
+        img_size = dataset_train[0][0].shape
+    else:
+        exit('Error: unrecognized dataset')
+
+    # build model
+    if args.model == 'cnn' and args.dataset == 'cifar':
+        if args.gpu != -1:
+            torch.cuda.set_device(args.gpu)
+            net_glob = CNNCifar(args=args).cuda()
+        else:
+            net_glob = CNNCifar(args=args)
+    elif args.model == 'cnn' and args.dataset == 'mnist':
+        if args.gpu != -1:
+            torch.cuda.set_device(args.gpu)
+            net_glob = CNNMnist(args=args).cuda()
+        else:
+            net_glob = CNNMnist(args=args)
+    elif args.model == 'mlp':
+        len_in = 1
+        for x in img_size:
+            len_in *= x
+        if args.gpu != -1:
+            torch.cuda.set_device(args.gpu)
+            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
+        else:
+            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
+    else:
+        exit('Error: unrecognized model')
+    print(net_glob)
+
+    # training
+    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
+    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
+
+    list_loss = []
+    net_glob.train()
+    for epoch in tqdm(range(args.epochs)):
+        batch_loss = []
+        for batch_idx, (data, target) in enumerate(train_loader):
+            if args.gpu != -1:
+                data, target = data.cuda(), target.cuda()
+            optimizer.zero_grad()
+            output = net_glob(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            optimizer.step()
+            if batch_idx % 50 == 0:
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                    epoch, batch_idx * len(data), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), loss.item()))
+            batch_loss.append(loss.item())
+        loss_avg = sum(batch_loss)/len(batch_loss)
+        print('\nTrain loss:', loss_avg)
+        list_loss.append(loss_avg)
+
+    # plot loss
+    plt.figure()
+    plt.plot(range(len(list_loss)), list_loss)
+    plt.xlabel('epochs')
+    plt.ylabel('train loss')
+    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))
+
+    # testing
+    if args.dataset == 'mnist':
+        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True,
+                                      transform=transforms.Compose([
+                                          transforms.ToTensor(),
+                                          transforms.Normalize((0.1307,), (0.3081,))
+                                      ]))
+        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
+
+    elif args.dataset == 'cifar':
+        transform = transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        dataset_test = datasets.CIFAR10('../data/cifar', train=False,
+                                        transform=transform, target_transform=None, download=True)
+        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
+    else:
+        exit('Error: unrecognized dataset')
+
+    print('Test on', len(dataset_test), 'samples')
+    test_acc, test_loss = test(net_glob, test_loader)
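
Note: main_nn.py builds its model from FedNets.py, which this commit only renames from NN_Arch.py without showing its contents. For orientation only, here is a minimal sketch, assuming a plain single-hidden-layer design, of an MLP class compatible with the MLP(dim_in, dim_hidden, dim_out) call and the F.nll_loss usage above; the actual definitions in FedNets.py may differ.

    import torch.nn as nn
    import torch.nn.functional as F

    class MLP(nn.Module):
        # Hypothetical sketch only; the real FedNets.py may differ.
        def __init__(self, dim_in, dim_hidden, dim_out):
            super(MLP, self).__init__()
            self.layer_input = nn.Linear(dim_in, dim_hidden)
            self.layer_hidden = nn.Linear(dim_hidden, dim_out)

        def forward(self, x):
            x = x.view(x.size(0), -1)        # flatten (B, C, H, W) images
            x = F.relu(self.layer_input(x))
            x = self.layer_hidden(x)
            return F.log_softmax(x, dim=1)   # log-probabilities for F.nll_loss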

options.py  +43 −0

@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import argparse
+
+
+def args_parser():
+    parser = argparse.ArgumentParser()
+    # federated arguments
+    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
+    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
+    parser.add_argument('--frac', type=float, default=0.1, help='the fraction of clients: C')
+    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
+    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
+    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
+    parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum (default: 0.5)')
+
+    # model arguments
+    parser.add_argument('--model', type=str, default='mlp', help='model name')
+    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
+    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
+                        help='comma-separated kernel size to use for convolution')
+    parser.add_argument('--norm', type=str, default='batch_norm',
+                        help="batch_norm, layer_norm, or None")
+    parser.add_argument('--num_filters', type=int, default=32,
+                        help="number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.")
+    parser.add_argument('--max_pool', type=str, default='True',
+                        help="whether to use max pooling rather than strided convolutions")
+
+    # other arguments
+    parser.add_argument('--dataset', type=str, default='cifar', help="name of dataset")
+    parser.add_argument('--iid', type=int, default=0,
+                        help='whether the data split is IID: 1 for IID, 0 for non-IID')
+    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
+    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
+    parser.add_argument('--gpu', type=int, default=1, help="GPU ID")
+    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
+    parser.add_argument('--verbose', type=int, default=1,
+                        help='verbose print, 1 for True, 0 for False')
+    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
+    args = parser.parse_args()
+    return args
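
With these defaults, options.py trains an MLP on CIFAR-10 using GPU 1. A typical CPU invocation of the new script (assuming the datasets can be downloaded into ../data/ and a ../save/ directory exists for the loss plot) would be:

    python main_nn.py --dataset mnist --model mlp --epochs 10 --gpu -1

Here --gpu -1 disables CUDA, matching the args.gpu != -1 checks in main_nn.py.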