baseline_main.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from utils import get_dataset
from options import args_parser
from update import test_inference
from FedNets import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
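
# Local helpers used below (defined elsewhere in this repo), as used in this script:
#   get_dataset(args)                          -> (train_dataset, test_dataset, user_groups)
#   args_parser()                              -> namespace with the options read below
#                                                 (model, dataset, gpu, lr, optimizer,
#                                                  epochs, num_classes)
#   test_inference(args, model, test_dataset)  -> (test accuracy, test loss)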


if __name__ == '__main__':
    args = args_parser()
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # Load dataset and user groups
    train_dataset, test_dataset, _ = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: the input dimension is the flattened image
        # size (product of channels x height x width).
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train mode and send it to the device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Training
    # Set optimizer and criterion
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=0.5)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)
    else:
        exit('Error: unrecognized optimizer')

    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    criterion = torch.nn.NLLLoss().to(device)
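    # Note: NLLLoss expects log-probabilities, so the models imported from
    # FedNets are assumed to end in a log_softmax layer.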
    epoch_loss = []

    for epoch in tqdm(range(args.epochs)):
        batch_loss = []

        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch + 1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
            batch_loss.append(loss.item())

        loss_avg = sum(batch_loss) / len(batch_loss)
        print('\nTrain loss:', loss_avg)
        epoch_loss.append(loss_avg)

    # Plot the average training loss per epoch
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
                                                 args.epochs))

    # Testing
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print('Test Accuracy: {:.2f}%'.format(100 * test_acc))