#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn import metrics


class DatasetSplit(Dataset):
    """Wraps a dataset so that only the samples at `idxs` are visible."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    """Runs local training and evaluation for a single federated client."""

    def __init__(self, args, dataset, idxs, tb):
        self.args = args
        self.loss_func = nn.NLLLoss()
        self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(
            dataset, list(idxs))
        self.tb = tb

    def train_val_test(self, dataset, idxs):
        # Split the client's indices into train/validation/test
        # (420/60/120, which assumes 600 samples per client).
        idxs_train = idxs[:420]
        idxs_val = idxs[420:480]
        idxs_test = idxs[480:]
        train = DataLoader(DatasetSplit(dataset, idxs_train),
                           batch_size=self.args.local_bs, shuffle=True)
        val = DataLoader(DatasetSplit(dataset, idxs_val),
                         batch_size=int(len(idxs_val) / 10), shuffle=True)
        test = DataLoader(DatasetSplit(dataset, idxs_test),
                          batch_size=int(len(idxs_test) / 10), shuffle=True)
        return train, val, test

    def update_weights(self, net):
        net.train()
        # Train and update. TODO: add support for other optimizers.
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr,
                                    momentum=0.5)

        epoch_loss = []
        for epoch in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                if self.args.gpu != -1:
                    images, labels = images.cuda(), labels.cuda()
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images),
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                self.tb.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def test(self, net):
        net.eval()
        # Evaluate over the whole local test split: accumulate predictions
        # across batches and average the loss, instead of reporting only the
        # metrics of whichever batch happened to come last.
        total_loss, y_true, y_pred = 0.0, [], []
        with torch.no_grad():
            for images, labels in self.ldr_test:
                if self.args.gpu != -1:
                    images, labels = images.cuda(), labels.cuda()
                log_probs = net(images)
                total_loss += self.loss_func(log_probs, labels).item()
                y_pred.append(log_probs.cpu().argmax(dim=1).numpy())
                y_true.append(labels.cpu().numpy())
        acc = metrics.accuracy_score(y_true=np.concatenate(y_true),
                                     y_pred=np.concatenate(y_pred))
        return acc, total_loss / len(self.ldr_test)
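

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): drives one LocalUpdate
# round on MNIST. The argparse.Namespace fields mirror the attributes this
# module actually reads (lr, local_bs, local_ep, gpu, verbose); the tiny MLP,
# the torchvision MNIST download, and the tensorboardX SummaryWriter are
# illustrative assumptions, since the repo's real training script supplies
# its own args, model, and writer.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    from torchvision import datasets, transforms
    from tensorboardX import SummaryWriter

    args = argparse.Namespace(lr=0.01, local_bs=10, local_ep=1,
                              gpu=-1, verbose=True)
    dataset = datasets.MNIST('./data', train=True, download=True,
                             transform=transforms.ToTensor())
    # Hypothetical model: log-probability outputs to match nn.NLLLoss above.
    net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10), nn.LogSoftmax(dim=1))
    writer = SummaryWriter('./runs/local_update_demo')

    # Give this client its first 600 samples; LocalUpdate splits them
    # 420/60/120 into train/val/test internally.
    local = LocalUpdate(args=args, dataset=dataset, idxs=range(600), tb=writer)
    w, train_loss = local.update_weights(net=net)
    acc, test_loss = local.test(net=net)
    print('train loss {:.4f} | test acc {:.4f} | test loss {:.4f}'.format(
        train_loss, acc, test_loss))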