# main_nn.py — centralized (non-federated) baseline: trains a single neural
# network on MNIST or CIFAR-10 and reports test accuracy/loss.
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. from tqdm import tqdm
  5. import torch
  6. import torch.nn.functional as F
  7. from torch.utils.data import DataLoader
  8. from torch import autograd
  9. import torch.optim as optim
  10. from torchvision import datasets, transforms
  11. from options import args_parser
  12. from FedNets import MLP, CNNMnist, CNNCifar
  13. import matplotlib
  14. import matplotlib.pyplot as plt
  15. matplotlib.use('Agg')
  16. def test(net_g, data_loader):
  17. # testing
  18. net_g.eval()
  19. test_loss = 0
  20. correct = 0
  21. l = len(data_loader)
  22. for idx, (data, target) in enumerate(data_loader):
  23. if args.gpu != -1:
  24. data, target = data.cuda(), target.cuda()
  25. data, target = autograd.Variable(data, volatile=True), autograd.Variable(target)
  26. log_probs = net_g(data)
  27. test_loss += F.nll_loss(log_probs, target, size_average=False).data[0]
  28. y_pred = log_probs.data.max(1, keepdim=True)[1]
  29. correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
  30. test_loss /= len(data_loader.dataset)
  31. print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
  32. test_loss, correct, len(data_loader.dataset),
  33. 100. * correct / len(data_loader.dataset)))
  34. return correct, test_loss
  35. if __name__ == '__main__':
  36. # parse args
  37. args = args_parser()
  38. torch.manual_seed(args.seed)
  39. # load dataset and split users
  40. if args.dataset == 'mnist':
  41. dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
  42. transform=transforms.Compose([
  43. transforms.ToTensor(),
  44. transforms.Normalize((0.1307,), (0.3081,))
  45. ]))
  46. img_size = dataset_train[0][0].shape
  47. elif args.dataset == 'cifar':
  48. transform = transforms.Compose(
  49. [transforms.ToTensor(),
  50. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  51. dataset_train = datasets.CIFAR10(
  52. '../data/cifar', train=True, transform=transform, target_transform=None, download=True)
  53. img_size = dataset_train[0][0].shape
  54. else:
  55. exit('Error: unrecognized dataset')
  56. # build model
  57. if args.model == 'cnn' and args.dataset == 'cifar':
  58. if args.gpu != -1:
  59. torch.cuda.set_device(args.gpu)
  60. net_glob = CNNCifar(args=args).cuda()
  61. else:
  62. net_glob = CNNCifar(args=args)
  63. elif args.model == 'cnn' and args.dataset == 'mnist':
  64. if args.gpu != -1:
  65. torch.cuda.set_device(args.gpu)
  66. net_glob = CNNMnist(args=args).cuda()
  67. else:
  68. net_glob = CNNMnist(args=args)
  69. elif args.model == 'mlp':
  70. len_in = 1
  71. for x in img_size:
  72. len_in *= x
  73. if args.gpu != -1:
  74. torch.cuda.set_device(args.gpu)
  75. net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
  76. else:
  77. net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
  78. else:
  79. exit('Error: unrecognized model')
  80. print(net_glob)
  81. # training
  82. optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
  83. train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
  84. list_loss = []
  85. net_glob.train()
  86. for epoch in tqdm(range(args.epochs)):
  87. batch_loss = []
  88. for batch_idx, (data, target) in enumerate(train_loader):
  89. if args.gpu != -1:
  90. data, target = data.cuda(), target.cuda()
  91. data, target = autograd.Variable(data), autograd.Variable(target)
  92. optimizer.zero_grad()
  93. output = net_glob(data)
  94. loss = F.nll_loss(output, target)
  95. loss.backward()
  96. optimizer.step()
  97. if batch_idx % 50 == 0:
  98. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  99. epoch, batch_idx * len(data), len(train_loader.dataset),
  100. 100. * batch_idx / len(train_loader), loss.data[0]))
  101. batch_loss.append(loss.data[0])
  102. loss_avg = sum(batch_loss)/len(batch_loss)
  103. print('\nTrain loss:', loss_avg)
  104. list_loss.append(loss_avg)
  105. # plot loss
  106. plt.figure()
  107. plt.plot(range(len(list_loss)), list_loss)
  108. plt.xlabel('epochs')
  109. plt.ylabel('train loss')
  110. plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))
  111. # testing
  112. if args.dataset == 'mnist':
  113. dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True,
  114. transform=transforms.Compose([
  115. transforms.ToTensor(),
  116. transforms.Normalize((0.1307,), (0.3081,))
  117. ]))
  118. test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
  119. elif args.dataset == 'cifar':
  120. transform = transforms.Compose(
  121. [transforms.ToTensor(),
  122. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  123. dataset_test = datasets.CIFAR10('../data/cifar', train=False,
  124. transform=transform, target_transform=None, download=True)
  125. test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
  126. else:
  127. exit('Error: unrecognized dataset')
  128. print('Test on', len(dataset_test), 'samples')
  129. test_acc, test_loss = test(net_glob, test_loader)