123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
class DatasetSplit(Dataset):
    """A Dataset view that exposes only the samples at the given indices.

    Wraps an existing PyTorch ``Dataset`` so a federated client sees just
    its own partition of the data.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Indices may arrive as numpy/float types; normalise to plain ints
        # so they are valid subscripts for the wrapped dataset.
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)
class LocalUpdate(object):
    """Local training/evaluation logic for one federated-learning client.

    Args:
        args: options namespace; fields read here are ``gpu``, ``optimizer``,
            ``lr``, ``local_bs``, ``local_ep`` and ``verbose``.
        dataset: the full dataset; this client only uses the samples in ``idxs``.
        idxs: indices of this client's samples within ``dataset``.
        logger: object with an ``add_scalar(tag, value)`` method
            (presumably a TensorBoard SummaryWriter — confirm against caller).
    """

    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        # Default criterion is NLL loss: the model is expected to output
        # log-probabilities (e.g. via LogSoftmax).
        self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes, split 80/10/10.
        """
        # BUG FIX: slice bounds must be ints — the original used raw float
        # products (0.8*len(idxs)), which raise TypeError when slicing.
        n = len(idxs)
        idxs_train = idxs[:int(0.8 * n)]
        idxs_val = idxs[int(0.8 * n):int(0.9 * n)]
        idxs_test = idxs[int(0.9 * n):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        # BUG FIX: int(len/10) is 0 for splits smaller than 10 samples and
        # DataLoader rejects batch_size=0; clamp to at least 1. Evaluation
        # loaders do not need shuffling.
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(int(len(idxs_val) / 10), 1),
                                 shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(int(len(idxs_test) / 10), 1),
                                shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model):
        """Run ``args.local_ep`` epochs of local training on ``model``.

        Returns:
            tuple: (updated ``state_dict``, mean per-epoch training loss).
        """
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates.
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)
        else:
            # BUG FIX: an unknown optimizer previously left `optimizer`
            # unbound and crashed later with NameError; fail fast instead.
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.args.optimizer))

        # Renamed loop variable from `iter` (shadowed the builtin).
        for epoch in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        """Evaluate ``model`` on this client's test split.

        Returns:
            tuple: (accuracy in [0, 1], summed loss over all test batches).
        """
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # No gradients are needed for evaluation; saves memory and time.
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)

                # Inference
                outputs = model(images)
                batch_loss = self.criterion(outputs, labels)
                loss += batch_loss.item()

                # Prediction
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct / total
        # BUG FIX: `loss` is a plain Python float here, so the original
        # `loss.item()` raised AttributeError. Return the float directly.
        return accuracy, loss
|