|
@@ -0,0 +1,149 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# Python version: 3.6
|
|
|
+
|
|
|
+
|
|
|
+from tqdm import tqdm
|
|
|
+import torch
|
|
|
+import torch.nn.functional as F
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+from torch import autograd
|
|
|
+import torch.optim as optim
|
|
|
+from torchvision import datasets, transforms
|
|
|
+
|
|
|
+from options import args_parser
|
|
|
+from FedNets import MLP, CNNMnist, CNNCifar
|
|
|
+
|
|
|
+import matplotlib
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+matplotlib.use('Agg')
|
|
|
+
|
|
|
+
|
|
|
+def test(net_g, data_loader):
|
|
|
+ # testing
|
|
|
+ net_g.eval()
|
|
|
+ test_loss = 0
|
|
|
+ correct = 0
|
|
|
+ l = len(data_loader)
|
|
|
+ for idx, (data, target) in enumerate(data_loader):
|
|
|
+ if args.gpu != -1:
|
|
|
+ data, target = data.cuda(), target.cuda()
|
|
|
+ data, target = autograd.Variable(data, volatile=True), autograd.Variable(target)
|
|
|
+ log_probs = net_g(data)
|
|
|
+ test_loss += F.nll_loss(log_probs, target, size_average=False).data[0]
|
|
|
+ y_pred = log_probs.data.max(1, keepdim=True)[1]
|
|
|
+ correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
|
|
|
+
|
|
|
+ test_loss /= len(data_loader.dataset)
|
|
|
+ print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
|
|
|
+ test_loss, correct, len(data_loader.dataset),
|
|
|
+ 100. * correct / len(data_loader.dataset)))
|
|
|
+
|
|
|
+ return correct, test_loss
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ # parse args
|
|
|
+ args = args_parser()
|
|
|
+
|
|
|
+ torch.manual_seed(args.seed)
|
|
|
+
|
|
|
+ # load dataset and split users
|
|
|
+ if args.dataset == 'mnist':
|
|
|
+ dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
|
|
|
+ transform=transforms.Compose([
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.1307,), (0.3081,))
|
|
|
+ ]))
|
|
|
+ img_size = dataset_train[0][0].shape
|
|
|
+
|
|
|
+ elif args.dataset == 'cifar':
|
|
|
+ transform = transforms.Compose(
|
|
|
+ [transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
|
|
+ dataset_train = datasets.CIFAR10(
|
|
|
+ '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
|
|
|
+ img_size = dataset_train[0][0].shape
|
|
|
+ else:
|
|
|
+ exit('Error: unrecognized dataset')
|
|
|
+
|
|
|
+ # build model
|
|
|
+ if args.model == 'cnn' and args.dataset == 'cifar':
|
|
|
+ if args.gpu != -1:
|
|
|
+ torch.cuda.set_device(args.gpu)
|
|
|
+ net_glob = CNNCifar(args=args).cuda()
|
|
|
+ else:
|
|
|
+ net_glob = CNNCifar(args=args)
|
|
|
+ elif args.model == 'cnn' and args.dataset == 'mnist':
|
|
|
+ if args.gpu != -1:
|
|
|
+ torch.cuda.set_device(args.gpu)
|
|
|
+ net_glob = CNNMnist(args=args).cuda()
|
|
|
+ else:
|
|
|
+ net_glob = CNNMnist(args=args)
|
|
|
+ elif args.model == 'mlp':
|
|
|
+ len_in = 1
|
|
|
+ for x in img_size:
|
|
|
+ len_in *= x
|
|
|
+ if args.gpu != -1:
|
|
|
+ torch.cuda.set_device(args.gpu)
|
|
|
+ net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
|
|
|
+ else:
|
|
|
+ net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
|
|
|
+ else:
|
|
|
+ exit('Error: unrecognized model')
|
|
|
+ print(net_glob)
|
|
|
+
|
|
|
+ # training
|
|
|
+ optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
|
|
|
+ train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
|
|
|
+
|
|
|
+ list_loss = []
|
|
|
+ net_glob.train()
|
|
|
+ for epoch in tqdm(range(args.epochs)):
|
|
|
+ batch_loss = []
|
|
|
+ for batch_idx, (data, target) in enumerate(train_loader):
|
|
|
+ if args.gpu != -1:
|
|
|
+ data, target = data.cuda(), target.cuda()
|
|
|
+ data, target = autograd.Variable(data), autograd.Variable(target)
|
|
|
+ optimizer.zero_grad()
|
|
|
+ output = net_glob(data)
|
|
|
+ loss = F.nll_loss(output, target)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ if batch_idx % 50 == 0:
|
|
|
+ print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
|
|
+ epoch, batch_idx * len(data), len(train_loader.dataset),
|
|
|
+ 100. * batch_idx / len(train_loader), loss.data[0]))
|
|
|
+ batch_loss.append(loss.data[0])
|
|
|
+ loss_avg = sum(batch_loss)/len(batch_loss)
|
|
|
+ print('\nTrain loss:', loss_avg)
|
|
|
+ list_loss.append(loss_avg)
|
|
|
+
|
|
|
+ # plot loss
|
|
|
+ plt.figure()
|
|
|
+ plt.plot(range(len(list_loss)), list_loss)
|
|
|
+ plt.xlabel('epochs')
|
|
|
+ plt.ylabel('train loss')
|
|
|
+ plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))
|
|
|
+
|
|
|
+ # testing
|
|
|
+ if args.dataset == 'mnist':
|
|
|
+ dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True,
|
|
|
+ transform=transforms.Compose([
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.1307,), (0.3081,))
|
|
|
+ ]))
|
|
|
+ test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
|
|
|
+
|
|
|
+ elif args.dataset == 'cifar':
|
|
|
+ transform = transforms.Compose(
|
|
|
+ [transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
|
|
+ dataset_test = datasets.CIFAR10('../data/cifar', train=False,
|
|
|
+ transform=transform, target_transform=None, download=True)
|
|
|
+ test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
|
|
|
+ else:
|
|
|
+ exit('Error: unrecognized dataset')
|
|
|
+
|
|
|
+ print('Test on', len(dataset_test), 'samples')
|
|
|
+ test_acc, test_loss = test(net_glob, test_loader)
|