Update.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import torch
  5. from torch import nn, autograd
  6. from torch.utils.data import DataLoader, Dataset
  7. import numpy as np
  8. from sklearn import metrics
  9. class DatasetSplit(Dataset):
  10. def __init__(self, dataset, idxs):
  11. self.dataset = dataset
  12. self.idxs = [int(i) for i in idxs]
  13. def __len__(self):
  14. return len(self.idxs)
  15. def __getitem__(self, item):
  16. image, label = self.dataset[self.idxs[item]]
  17. return image, label
  18. class LocalUpdate(object):
  19. def __init__(self, args, dataset, idxs, tb):
  20. self.args = args
  21. self.loss_func = nn.NLLLoss()
  22. self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
  23. self.tb = tb
  24. def train_val_test(self, dataset, idxs):
  25. # split train, validation, and test
  26. idxs_train = idxs[:420]
  27. idxs_val = idxs[420:480]
  28. idxs_test = idxs[480:]
  29. train = DataLoader(DatasetSplit(dataset, idxs_train),
  30. batch_size=self.args.local_bs, shuffle=True)
  31. val = DataLoader(DatasetSplit(dataset, idxs_val),
  32. batch_size=int(len(idxs_val)/10), shuffle=True)
  33. test = DataLoader(DatasetSplit(dataset, idxs_test),
  34. batch_size=int(len(idxs_test)/10), shuffle=True)
  35. return train, val, test
  36. def update_weights(self, net):
  37. net.train()
  38. # train and update
  39. # Add support for other optimizers
  40. optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)
  41. epoch_loss = []
  42. for iter in range(self.args.local_ep):
  43. batch_loss = []
  44. for batch_idx, (images, labels) in enumerate(self.ldr_train):
  45. if self.args.gpu != -1:
  46. images, labels = images.cuda(), labels.cuda()
  47. images, labels = autograd.Variable(images), autograd.Variable(labels)
  48. net.zero_grad()
  49. log_probs = net(images)
  50. loss = self.loss_func(log_probs, labels)
  51. loss.backward()
  52. optimizer.step()
  53. if self.args.gpu != -1:
  54. loss = loss.cpu()
  55. if self.args.verbose and batch_idx % 10 == 0:
  56. print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  57. iter, batch_idx * len(images), len(self.ldr_train.dataset),
  58. 100. * batch_idx / len(self.ldr_train), loss.data[0]))
  59. self.tb.add_scalar('loss', loss.data[0])
  60. batch_loss.append(loss.data[0])
  61. epoch_loss.append(sum(batch_loss)/len(batch_loss))
  62. return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
  63. def test(self, net):
  64. for batch_idx, (images, labels) in enumerate(self.ldr_test):
  65. if self.args.gpu != -1:
  66. images, labels = images.cuda(), labels.cuda()
  67. images, labels = autograd.Variable(images), autograd.Variable(labels)
  68. log_probs = net(images)
  69. loss = self.loss_func(log_probs, labels)
  70. if self.args.gpu != -1:
  71. loss = loss.cpu()
  72. log_probs = log_probs.cpu()
  73. labels = labels.cpu()
  74. y_pred = np.argmax(log_probs.data, axis=1)
  75. acc = metrics.accuracy_score(y_true=labels.data, y_pred=y_pred)
  76. return acc, loss.data[0]