#!/usr/bin/env python # -*- coding: utf-8 -*- # Python version: 3.6 from tqdm import tqdm import torch import torch.nn.functional as F from torch.utils.data import DataLoader from torch import autograd import torch.optim as optim from torchvision import datasets, transforms from options import args_parser from FedNets import MLP, CNNMnist, CNNCifar import matplotlib import matplotlib.pyplot as plt matplotlib.use('Agg') def test(net_g, data_loader): # testing net_g.eval() test_loss = 0 correct = 0 l = len(data_loader) for idx, (data, target) in enumerate(data_loader): if args.gpu != -1: data, target = data.cuda(), target.cuda() data, target = autograd.Variable(data, volatile=True), autograd.Variable(target) log_probs = net_g(data) test_loss += F.nll_loss(log_probs, target, size_average=False).data[0] y_pred = log_probs.data.max(1, keepdim=True)[1] correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() test_loss /= len(data_loader.dataset) print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( test_loss, correct, len(data_loader.dataset), 100. * correct / len(data_loader.dataset))) return correct, test_loss if __name__ == '__main__': # parse args args = args_parser() torch.manual_seed(args.seed) # load dataset and split users if args.dataset == 'mnist': dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) img_size = dataset_train[0][0].shape elif args.dataset == 'cifar': transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) dataset_train = datasets.CIFAR10( '../data/cifar', train=True, transform=transform, target_transform=None, download=True) img_size = dataset_train[0][0].shape else: exit('Error: unrecognized dataset') # build model if args.model == 'cnn' and args.dataset == 'cifar': if args.gpu != -1: torch.cuda.set_device(args.gpu) net_glob = CNNCifar(args=args).cuda() else: net_glob = CNNCifar(args=args) elif args.model == 'cnn' and args.dataset == 'mnist': if args.gpu != -1: torch.cuda.set_device(args.gpu) net_glob = CNNMnist(args=args).cuda() else: net_glob = CNNMnist(args=args) elif args.model == 'mlp': len_in = 1 for x in img_size: len_in *= x if args.gpu != -1: torch.cuda.set_device(args.gpu) net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda() else: net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes) else: exit('Error: unrecognized model') print(net_glob) # training optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum) train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True) list_loss = [] net_glob.train() for epoch in tqdm(range(args.epochs)): batch_loss = [] for batch_idx, (data, target) in enumerate(train_loader): if args.gpu != -1: data, target = data.cuda(), target.cuda() data, target = autograd.Variable(data), autograd.Variable(target) optimizer.zero_grad() output = net_glob(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() if batch_idx % 50 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.data[0])) batch_loss.append(loss.data[0]) loss_avg = sum(batch_loss)/len(batch_loss) print('\nTrain loss:', loss_avg) list_loss.append(loss_avg) # plot loss plt.figure() plt.plot(range(len(list_loss)), list_loss) plt.xlabel('epochs') plt.ylabel('train loss') plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs)) # testing if args.dataset == 'mnist': dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) elif args.dataset == 'cifar': transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform, target_transform=None, download=True) test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) else: exit('Error: unrecognized dataset') print('Test on', len(dataset_test), 'samples') test_acc, test_loss = test(net_glob, test_loader)