# update.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import torch
  5. from torch import nn
  6. from torch.utils.data import DataLoader, Dataset
  7. import utils
  8. #from utils import set_device
  9. class DatasetSplit(Dataset):
  10. """An abstract Dataset class wrapped around Pytorch Dataset class.
  11. """
  12. def __init__(self, dataset, idxs):
  13. self.dataset = dataset
  14. self.idxs = [int(i) for i in idxs]
  15. def __len__(self):
  16. return len(self.idxs)
  17. def __getitem__(self, item):
  18. image, label = self.dataset[self.idxs[item]]
  19. return torch.tensor(image), torch.tensor(label)
  20. class LocalUpdate(object):
  21. def __init__(self, args, dataset, idxs, logger):
  22. self.args = args
  23. self.logger = logger
  24. self.trainloader, self.validloader, self.testloader = self.train_val_test(
  25. dataset, list(idxs))
  26. # Select CPU or GPU
  27. self.device = utils.set_device(args)
  28. # Default criterion set to NLL loss function
  29. self.criterion = nn.NLLLoss().to(self.device)
  30. def train_val_test(self, dataset, idxs):
  31. """
  32. Returns train, validation and test dataloaders for a given dataset
  33. and user indexes.
  34. """
  35. # split indexes for train, validation, and test (80, 10, 10)
  36. idxs_train = idxs[:int(0.8*len(idxs))]
  37. idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
  38. idxs_test = idxs[int(0.9*len(idxs)):]
  39. trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
  40. batch_size=self.args.local_bs, shuffle=True)
  41. validloader = DataLoader(DatasetSplit(dataset, idxs_val),
  42. batch_size=int(len(idxs_val)/10), shuffle=False)
  43. testloader = DataLoader(DatasetSplit(dataset, idxs_test),
  44. batch_size=int(len(idxs_test)/10), shuffle=False)
  45. return trainloader, validloader, testloader
  46. def update_weights(self, model, global_round, dtype=torch.float32):
  47. # Set mode to train model
  48. model.train()
  49. epoch_loss = []
  50. # Set dtype for criterion
  51. self.criterion.to(dtype)
  52. # Set optimizer for the local updates
  53. if self.args.optimizer == 'sgd':
  54. optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
  55. momentum=0.5)
  56. elif self.args.optimizer == 'adam':
  57. optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
  58. weight_decay=1e-4)
  59. for iter in range(self.args.local_ep):
  60. batch_loss = []
  61. for batch_idx, (images, labels) in enumerate(self.trainloader):
  62. images, labels = images.to(self.device), labels.to(self.device)
  63. images = images.to(dtype)
  64. # labels shouldn't be cast to criterion_dtype, and should remain as dtype long
  65. model.zero_grad()
  66. log_probs = model(images)
  67. loss = self.criterion(log_probs, labels)
  68. loss.backward()
  69. optimizer.step()
  70. # if self.args.verbose and (batch_idx % 10 == 0):
  71. # print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  72. # global_round, iter, batch_idx * len(images),
  73. # len(self.trainloader.dataset),
  74. # 100. * batch_idx / len(self.trainloader), loss.item()))
  75. self.logger.add_scalar('loss', loss.item())
  76. batch_loss.append(loss.item())
  77. epoch_loss.append(sum(batch_loss)/len(batch_loss))
  78. return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
  79. def inference(self, model, dtype=torch.float32):
  80. """ Returns the inference accuracy and loss.
  81. """
  82. model.eval()
  83. loss, total, correct = 0.0, 0.0, 0.0
  84. # Set dtype for criterion
  85. self.criterion.to(dtype)
  86. for batch_idx, (images, labels) in enumerate(self.testloader):
  87. images, labels = images.to(self.device), labels.to(self.device)
  88. images = images.to(dtype)
  89. # Inference
  90. outputs = model(images)
  91. batch_loss = self.criterion(outputs, labels)
  92. loss += batch_loss.item()
  93. # Prediction
  94. _, pred_labels = torch.max(outputs, 1)
  95. pred_labels = pred_labels.view(-1)
  96. correct += torch.sum(torch.eq(pred_labels, labels)).item()
  97. total += len(labels)
  98. accuracy = correct/total
  99. return accuracy, loss
  100. def test_inference(args, model, test_dataset, dtype=torch.float32):
  101. """ Returns the test accuracy and loss.
  102. """
  103. model.eval()
  104. model.to(dtype)
  105. loss, total, correct = 0.0, 0.0, 0.0
  106. # Select CPU or GPU
  107. device = utils.set_device(args)
  108. criterion = nn.NLLLoss().to(device)
  109. # Set dtype for criterion
  110. criterion.to(dtype)
  111. testloader = DataLoader(test_dataset, batch_size=128,
  112. shuffle=False)
  113. for batch_idx, (images, labels) in enumerate(testloader):
  114. images, labels = images.to(device), labels.to(device)
  115. images = images.to(dtype)
  116. # Inference
  117. outputs = model(images)
  118. batch_loss = criterion(outputs, labels)
  119. loss += batch_loss.item()
  120. # Prediction
  121. _, pred_labels = torch.max(outputs, 1)
  122. pred_labels = pred_labels.view(-1)
  123. correct += torch.sum(torch.eq(pred_labels, labels)).item()
  124. total += len(labels)
  125. accuracy = correct/total
  126. return accuracy, loss