123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # Python version: 3.6
- import torch
- from torch import nn, autograd
- from torch.utils.data import DataLoader, Dataset
- import numpy as np
- from sklearn import metrics
- class DatasetSplit(Dataset):
- def __init__(self, dataset, idxs):
- self.dataset = dataset
- self.idxs = list(idxs)
- def __len__(self):
- return len(self.idxs)
- def __getitem__(self, item):
- image, label = self.dataset[self.idxs[item]]
- return image, label
- class LocalUpdate(object):
- def __init__(self, args, dataset, idxs, tb):
- self.args = args
- self.loss_func = nn.NLLLoss()
- self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
- self.tb = tb
- def train_val_test(self, dataset, idxs):
- # split train, validation, and test
- idxs_train = idxs[:420]
- idxs_val = idxs[420:480]
- idxs_test = idxs[480:]
- train = DataLoader(DatasetSplit(dataset, idxs_train),
- batch_size=self.args.local_bs, shuffle=True)
- val = DataLoader(DatasetSplit(dataset, idxs_val),
- batch_size=int(len(idxs_val)/10), shuffle=True)
- test = DataLoader(DatasetSplit(dataset, idxs_test),
- batch_size=int(len(idxs_test)/10), shuffle=True)
- return train, val, test
- def update_weights(self, net):
- net.train()
- # train and update
- # Add support for other optimizers
- optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)
- epoch_loss = []
- for iter in range(self.args.local_ep):
- batch_loss = []
- for batch_idx, (images, labels) in enumerate(self.ldr_train):
- if self.args.gpu != -1:
- images, labels = images.cuda(), labels.cuda()
- images, labels = autograd.Variable(images), autograd.Variable(labels)
- net.zero_grad()
- log_probs = net(images)
- loss = self.loss_func(log_probs, labels)
- loss.backward()
- optimizer.step()
- if self.args.gpu != -1:
- loss = loss.cpu()
- if self.args.verbose and batch_idx % 10 == 0:
- print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
- iter, batch_idx * len(images), len(self.ldr_train.dataset),
- 100. * batch_idx / len(self.ldr_train), loss.data[0]))
- self.tb.add_scalar('loss', loss.data[0])
- batch_loss.append(loss.data[0])
- epoch_loss.append(sum(batch_loss)/len(batch_loss))
- return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
- def test(self, net):
- for batch_idx, (images, labels) in enumerate(self.ldr_test):
- if self.args.gpu != -1:
- images, labels = images.cuda(), labels.cuda()
- images, labels = autograd.Variable(images), autograd.Variable(labels)
- log_probs = net(images)
- loss = self.loss_func(log_probs, labels)
- if self.args.gpu != -1:
- loss = loss.cpu()
- log_probs = log_probs.cpu()
- labels = labels.cpu()
- y_pred = np.argmax(log_probs.data, axis=1)
- acc = metrics.accuracy_score(y_true=labels.data, y_pred=y_pred)
- return acc, loss.data[0]
|