update.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import torch
  5. from torch import nn
  6. from torch.utils.data import DataLoader, Dataset
  7. class DatasetSplit(Dataset):
  8. """An abstract Dataset class wrapped around Pytorch Dataset class.
  9. """
  10. def __init__(self, dataset, idxs):
  11. self.dataset = dataset
  12. self.idxs = [int(i) for i in idxs]
  13. def __len__(self):
  14. return len(self.idxs)
  15. def __getitem__(self, item):
  16. image, label = self.dataset[self.idxs[item]]
  17. return torch.tensor(image), torch.tensor(label)
  18. class LocalUpdate(object):
  19. def __init__(self, args, dataset, idxs, logger):
  20. self.args = args
  21. self.logger = logger
  22. self.trainloader, self.validloader, self.testloader = self.train_val_test(
  23. dataset, list(idxs))
  24. self.device = 'cuda' if args.gpu else 'cpu'
  25. # Default criterion set to NLL loss function
  26. self.criterion = nn.NLLLoss().to(self.device)
  27. def train_val_test(self, dataset, idxs):
  28. """
  29. Returns train, validation and test dataloaders for a given dataset
  30. and user indexes.
  31. """
  32. # split indexes for train, validation, and test (80, 10, 10)
  33. idxs_train = idxs[:(0.8*len(idxs))]
  34. idxs_val = idxs[(0.8*len(idxs)):(0.9*len(idxs))]
  35. idxs_test = idxs[(0.9*len(idxs)):]
  36. trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
  37. batch_size=self.args.local_bs, shuffle=True)
  38. validloader = DataLoader(DatasetSplit(dataset, idxs_val),
  39. batch_size=int(len(idxs_val)/10), shuffle=True)
  40. testloader = DataLoader(DatasetSplit(dataset, idxs_test),
  41. batch_size=int(len(idxs_test)/10), shuffle=True)
  42. return trainloader, validloader, testloader
  43. def update_weights(self, model):
  44. # Set mode to train model
  45. model.train()
  46. epoch_loss = []
  47. # Set optimizer for the local updates
  48. if self.args.optimizer == 'sgd':
  49. optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
  50. momentum=0.5)
  51. elif self.args.optimizer == 'adam':
  52. optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
  53. weight_decay=1e-4)
  54. for iter in range(self.args.local_ep):
  55. batch_loss = []
  56. for batch_idx, (images, labels) in enumerate(self.trainloader):
  57. images, labels = images.to(self.device), labels.to(self.device)
  58. model.zero_grad()
  59. log_probs = model(images)
  60. loss = self.criterion(log_probs, labels)
  61. loss.backward()
  62. optimizer.step()
  63. if self.args.verbose and batch_idx % 10 == 0:
  64. print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  65. iter, batch_idx * len(images), len(self.trainloader.dataset),
  66. 100. * batch_idx / len(self.trainloader), loss.item()))
  67. self.logger.add_scalar('loss', loss.item())
  68. batch_loss.append(loss.item())
  69. epoch_loss.append(sum(batch_loss)/len(batch_loss))
  70. return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
  71. def inference(self, model):
  72. """ Returns the inference accuracy and loss.
  73. """
  74. model.eval()
  75. loss, total, correct = 0.0, 0.0, 0.0
  76. for batch_idx, (images, labels) in enumerate(self.testloader):
  77. images, labels = images.to(self.device), labels.to(self.device)
  78. # Inference
  79. outputs = model(images)
  80. batch_loss = self.criterion(outputs, labels)
  81. loss += batch_loss.item()
  82. # Prediction
  83. _, pred_labels = torch.max(outputs, 1)
  84. pred_labels = pred_labels.view(-1)
  85. correct += torch.sum(torch.eq(pred_labels, labels)).item()
  86. total += len(labels)
  87. accuracy = correct/total
  88. return accuracy, loss.item()