|
@@ -0,0 +1,110 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# Python version: 3.6
|
|
|
+
|
|
|
+import torch
|
|
|
+from torch import nn
|
|
|
+from torch.utils.data import DataLoader, Dataset
|
|
|
+
|
|
|
+
|
|
|
+class DatasetSplit(Dataset):
|
|
|
+ """An abstract Dataset class wrapped around Pytorch Dataset class.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, dataset, idxs):
|
|
|
+ self.dataset = dataset
|
|
|
+ self.idxs = [int(i) for i in idxs]
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.idxs)
|
|
|
+
|
|
|
+ def __getitem__(self, item):
|
|
|
+ image, label = self.dataset[self.idxs[item]]
|
|
|
+ return torch.tensor(image), torch.tensor(label)
|
|
|
+
|
|
|
+
|
|
|
+class LocalUpdate(object):
|
|
|
+ def __init__(self, args, dataset, idxs, logger):
|
|
|
+ self.args = args
|
|
|
+ self.logger = logger
|
|
|
+ self.trainloader, self.validloader, self.testloader = self.train_val_test(
|
|
|
+ dataset, list(idxs))
|
|
|
+ self.device = 'cuda' if args.gpu else 'cpu'
|
|
|
+ # Default criterion set to NLL loss function
|
|
|
+ self.criterion = nn.NLLLoss().to(self.device)
|
|
|
+
|
|
|
+ def train_val_test(self, dataset, idxs):
|
|
|
+ """
|
|
|
+ Returns train, validation and test dataloaders for a given dataset
|
|
|
+ and user indexes.
|
|
|
+ """
|
|
|
+ # split indexes for train, validation, and test (80, 10, 10)
|
|
|
+ idxs_train = idxs[:(0.8*len(idxs))]
|
|
|
+ idxs_val = idxs[(0.8*len(idxs)):(0.9*len(idxs))]
|
|
|
+ idxs_test = idxs[(0.9*len(idxs)):]
|
|
|
+
|
|
|
+ trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
|
|
|
+ batch_size=self.args.local_bs, shuffle=True)
|
|
|
+ validloader = DataLoader(DatasetSplit(dataset, idxs_val),
|
|
|
+ batch_size=int(len(idxs_val)/10), shuffle=True)
|
|
|
+ testloader = DataLoader(DatasetSplit(dataset, idxs_test),
|
|
|
+ batch_size=int(len(idxs_test)/10), shuffle=True)
|
|
|
+ return trainloader, validloader, testloader
|
|
|
+
|
|
|
+ def update_weights(self, model):
|
|
|
+ # Set mode to train model
|
|
|
+ model.train()
|
|
|
+ epoch_loss = []
|
|
|
+
|
|
|
+ # Set optimizer for the local updates
|
|
|
+ if self.args.optimizer == 'sgd':
|
|
|
+ optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
|
|
|
+ momentum=0.5)
|
|
|
+ elif self.args.optimizer == 'adam':
|
|
|
+ optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
|
|
|
+ weight_decay=1e-4)
|
|
|
+
|
|
|
+ for iter in range(self.args.local_ep):
|
|
|
+ batch_loss = []
|
|
|
+ for batch_idx, (images, labels) in enumerate(self.trainloader):
|
|
|
+ images, labels = images.to(self.device), labels.to(self.device)
|
|
|
+
|
|
|
+ model.zero_grad()
|
|
|
+ log_probs = model(images)
|
|
|
+ loss = self.criterion(log_probs, labels)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+
|
|
|
+ if self.args.verbose and batch_idx % 10 == 0:
|
|
|
+ print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
|
|
+ iter, batch_idx * len(images), len(self.trainloader.dataset),
|
|
|
+ 100. * batch_idx / len(self.trainloader), loss.item()))
|
|
|
+ self.logger.add_scalar('loss', loss.item())
|
|
|
+ batch_loss.append(loss.item())
|
|
|
+ epoch_loss.append(sum(batch_loss)/len(batch_loss))
|
|
|
+
|
|
|
+ return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
|
|
|
+
|
|
|
+ def inference(self, model):
|
|
|
+ """ Returns the inference accuracy and loss.
|
|
|
+ """
|
|
|
+
|
|
|
+ model.eval()
|
|
|
+ loss, total, correct = 0.0, 0.0, 0.0
|
|
|
+
|
|
|
+ for batch_idx, (images, labels) in enumerate(self.testloader):
|
|
|
+ images, labels = images.to(self.device), labels.to(self.device)
|
|
|
+
|
|
|
+ # Inference
|
|
|
+ outputs = model(images)
|
|
|
+ batch_loss = self.criterion(outputs, labels)
|
|
|
+ loss += batch_loss.item()
|
|
|
+
|
|
|
+ # Prediction
|
|
|
+ _, pred_labels = torch.max(outputs, 1)
|
|
|
+ pred_labels = pred_labels.view(-1)
|
|
|
+ correct += torch.sum(torch.eq(pred_labels, labels)).item()
|
|
|
+ total += len(labels)
|
|
|
+
|
|
|
+ accuracy = correct/total
|
|
|
+ return accuracy, loss.item()
|