baseline_main_fp16.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. from tqdm import tqdm
  5. import matplotlib.pyplot as plt
  6. import torch
  7. from torch.utils.data import DataLoader
  8. from utils import get_dataset, set_device, build_model
  9. from options import args_parser
  10. from update import test_inference
  11. from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
  12. import pickle
  13. import time
  14. from sys import exit
  15. from torchsummary import summary
  16. if __name__ == '__main__':
  17. start_time = time.time()
  18. args = args_parser()
  19. # Select CPU or GPU
  20. device = set_device(args)
  21. # load datasets
  22. train_dataset, test_dataset, _ = get_dataset(args)
  23. # BUILD MODEL
  24. global_model = build_model(args, train_dataset)
  25. # Set the model to train and send it to device.
  26. global_model.to(device)
  27. # Set model to use Floating Point 16
  28. global_model.to(dtype=torch.float16) ##########
  29. global_model.train()
  30. print(global_model)
  31. #img_size = train_dataset[0][0].shape
  32. #summary(global_model, img_size) ####
  33. #print(global_model.parameters())
  34. # Training
  35. # Set optimizer and criterion
  36. if args.optimizer == 'sgd':
  37. optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
  38. momentum=0.5)
  39. elif args.optimizer == 'adam':
  40. optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
  41. weight_decay=1e-4)
  42. elif args.optimizer == 'adagrad':
  43. optimizer = torch.optim.Adagrad(global_model.parameters(), lr=args.lr,
  44. weight_decay=1e-4)
  45. elif args.optimizer == 'adamax':
  46. optimizer = torch.optim.Adamax(global_model.parameters(), lr=args.lr,
  47. weight_decay=1e-4)
  48. elif args.optimizer == 'rmsprop':
  49. optimizer = torch.optim.RMSprop(global_model.parameters(), lr=args.lr,
  50. weight_decay=1e-4)
  51. else:
  52. exit('Error- unrecognized optimizer: ' + args.optimizer)
  53. # look under optim for more info on scheduler
  54. #scheduler=torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
  55. trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  56. criterion = torch.nn.NLLLoss().to(device)
  57. criterion.to(dtype = torch.float16) ################
  58. epoch_loss = []
  59. for epoch in tqdm(range(args.epochs)):
  60. batch_loss = []
  61. for batch_idx, (images, labels) in enumerate(trainloader):
  62. images=images.to(dtype=torch.float16) #################
  63. images, labels = images.to(device), labels.to(device)
  64. optimizer.zero_grad()
  65. outputs = global_model(images)
  66. loss = criterion(outputs, labels)
  67. loss.backward()
  68. optimizer.step()
  69. if batch_idx % 50 == 0:
  70. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  71. epoch+1, batch_idx * len(images), len(trainloader.dataset),
  72. 100. * batch_idx / len(trainloader), loss.item()))
  73. batch_loss.append(loss.item())
  74. loss_avg = sum(batch_loss)/len(batch_loss)
  75. print('\nTrain loss:', loss_avg)
  76. epoch_loss.append(loss_avg)
  77. # testing
  78. test_acc, test_loss = test_inference(args, global_model, test_dataset, dtype=torch.float16) ############
  79. print('Test on', len(test_dataset), 'samples')
  80. print("Test Accuracy: {:.2f}%".format(100*test_acc))
  81. # Saving the objects train_loss, test_acc, test_loss:
  82. file_name = '../save/objects_fp16/BaseSGD_{}_{}_epoch[{}]_lr[{}]_iid[{}]_FP16.pkl'.\
  83. format(args.dataset, args.model, epoch, args.lr, args.iid)
  84. with open(file_name, 'wb') as f:
  85. pickle.dump([epoch_loss, test_acc, test_loss], f)
  86. print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
  87. # # Plot loss
  88. # plt.figure()
  89. # plt.plot(range(len(epoch_loss)), epoch_loss)
  90. # plt.xlabel('epochs')
  91. # plt.ylabel('Train loss')
  92. # plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
  93. # args.epochs))