123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # Python version: 3.6
- import torch
- from torch import nn
- from torch.utils.data import DataLoader, Dataset
- class DatasetSplit(Dataset):
- """An abstract Dataset class wrapped around Pytorch Dataset class.
- """
- def __init__(self, dataset, idxs):
- self.dataset = dataset
- self.idxs = [int(i) for i in idxs]
- def __len__(self):
- return len(self.idxs)
- def __getitem__(self, item):
- image, label = self.dataset[self.idxs[item]]
- return torch.tensor(image), torch.tensor(label)
- class LocalUpdate(object):
- def __init__(self, args, dataset, idxs, logger):
- self.args = args
- self.logger = logger
- self.trainloader, self.validloader, self.testloader = self.train_val_test(
- dataset, list(idxs))
- self.device = 'cuda' if args.gpu else 'cpu'
- # Default criterion set to NLL loss function
- self.criterion = nn.NLLLoss().to(self.device)
- def train_val_test(self, dataset, idxs):
- """
- Returns train, validation and test dataloaders for a given dataset
- and user indexes.
- """
- # split indexes for train, validation, and test (80, 10, 10)
- idxs_train = idxs[:int(0.8*len(idxs))]
- idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
- idxs_test = idxs[int(0.9*len(idxs)):]
- trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
- batch_size=self.args.local_bs, shuffle=True)
- validloader = DataLoader(DatasetSplit(dataset, idxs_val),
- batch_size=int(len(idxs_val)/10), shuffle=False)
- testloader = DataLoader(DatasetSplit(dataset, idxs_test),
- batch_size=int(len(idxs_test)/10), shuffle=False)
- return trainloader, validloader, testloader
- def update_weights(self, model, global_round):
- # Set mode to train model
- model.train()
- epoch_loss = []
- # Set optimizer for the local updates
- if self.args.optimizer == 'sgd':
- optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
- momentum=0.5)
- elif self.args.optimizer == 'adam':
- optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
- weight_decay=1e-4)
- for iter in range(self.args.local_ep):
- batch_loss = []
- for batch_idx, (images, labels) in enumerate(self.trainloader):
- images, labels = images.to(self.device), labels.to(self.device)
- model.zero_grad()
- log_probs = model(images)
- loss = self.criterion(log_probs, labels)
- loss.backward()
- optimizer.step()
- # if self.args.verbose and (batch_idx % 10 == 0):
- # print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
- # global_round, iter, batch_idx * len(images),
- # len(self.trainloader.dataset),
- # 100. * batch_idx / len(self.trainloader), loss.item()))
- self.logger.add_scalar('loss', loss.item())
- batch_loss.append(loss.item())
- epoch_loss.append(sum(batch_loss)/len(batch_loss))
- return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
- def inference(self, model):
- """ Returns the inference accuracy and loss.
- """
- model.eval()
- loss, total, correct = 0.0, 0.0, 0.0
- for batch_idx, (images, labels) in enumerate(self.testloader):
- images, labels = images.to(self.device), labels.to(self.device)
- # Inference
- outputs = model(images)
- batch_loss = self.criterion(outputs, labels)
- loss += batch_loss.item()
- # Prediction
- _, pred_labels = torch.max(outputs, 1)
- pred_labels = pred_labels.view(-1)
- correct += torch.sum(torch.eq(pred_labels, labels)).item()
- total += len(labels)
- accuracy = correct/total
- return accuracy, loss
- def test_inference(args, model, test_dataset):
- """ Returns the test accuracy and loss.
- """
- model.eval()
- loss, total, correct = 0.0, 0.0, 0.0
- device = 'cuda' if args.gpu else 'cpu'
- criterion = nn.NLLLoss().to(device)
- testloader = DataLoader(test_dataset, batch_size=128,
- shuffle=False)
- for batch_idx, (images, labels) in enumerate(testloader):
- images, labels = images.to(device), labels.to(device)
- # Inference
- outputs = model(images)
- batch_loss = criterion(outputs, labels)
- loss += batch_loss.item()
- # Prediction
- _, pred_labels = torch.max(outputs, 1)
- pred_labels = pred_labels.view(-1)
- correct += torch.sum(torch.eq(pred_labels, labels)).item()
- total += len(labels)
- accuracy = correct/total
- return accuracy, loss
|