Преглед на файлове

Update and rename Update.py to update.py

Ashwin R Jadhav преди 4 години
родител
ревизия
4cf63daad4
променени са 2 файла, в които са добавени 110 реда и са изтрити 87 реда
  1. 0 87
      src/Update.py
  2. 110 0
      src/update.py

+ 0 - 87
src/Update.py

@@ -1,87 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Python version: 3.6
-
-import torch
-from torch import nn, autograd
-from torch.utils.data import DataLoader, Dataset
-import numpy as np
-from sklearn import metrics
-
-
-class DatasetSplit(Dataset):
-    def __init__(self, dataset, idxs):
-        self.dataset = dataset
-        self.idxs = [int(i) for i in idxs]
-
-    def __len__(self):
-        return len(self.idxs)
-
-    def __getitem__(self, item):
-        image, label = self.dataset[self.idxs[item]]
-        return image, label
-
-
-class LocalUpdate(object):
-    def __init__(self, args, dataset, idxs, tb):
-        self.args = args
-        self.loss_func = nn.NLLLoss()
-        self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
-        self.tb = tb
-
-    def train_val_test(self, dataset, idxs):
-        # split train, validation, and test
-        idxs_train = idxs[:420]
-        idxs_val = idxs[420:480]
-        idxs_test = idxs[480:]
-        train = DataLoader(DatasetSplit(dataset, idxs_train),
-                           batch_size=self.args.local_bs, shuffle=True)
-        val = DataLoader(DatasetSplit(dataset, idxs_val),
-                         batch_size=int(len(idxs_val)/10), shuffle=True)
-        test = DataLoader(DatasetSplit(dataset, idxs_test),
-                          batch_size=int(len(idxs_test)/10), shuffle=True)
-        return train, val, test
-
-    def update_weights(self, net):
-        net.train()
-        # train and update
-        # Add support for other optimizers
-        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)
-
-        epoch_loss = []
-        for iter in range(self.args.local_ep):
-            batch_loss = []
-            for batch_idx, (images, labels) in enumerate(self.ldr_train):
-                if self.args.gpu != -1:
-                    images, labels = images.cuda(), labels.cuda()
-                images, labels = autograd.Variable(images), autograd.Variable(labels)
-                net.zero_grad()
-                log_probs = net(images)
-                loss = self.loss_func(log_probs, labels)
-                loss.backward()
-                optimizer.step()
-                if self.args.gpu != -1:
-                    loss = loss.cpu()
-                if self.args.verbose and batch_idx % 10 == 0:
-                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
-                        100. * batch_idx / len(self.ldr_train), loss.data[0]))
-                self.tb.add_scalar('loss', loss.data[0])
-                batch_loss.append(loss.data[0])
-            epoch_loss.append(sum(batch_loss)/len(batch_loss))
-        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
-
-    def test(self, net):
-        for batch_idx, (images, labels) in enumerate(self.ldr_test):
-            if self.args.gpu != -1:
-                images, labels = images.cuda(), labels.cuda()
-            images, labels = autograd.Variable(images), autograd.Variable(labels)
-            log_probs = net(images)
-            loss = self.loss_func(log_probs, labels)
-        if self.args.gpu != -1:
-            loss = loss.cpu()
-            log_probs = log_probs.cpu()
-            labels = labels.cpu()
-        y_pred = np.argmax(log_probs.data, axis=1)
-        acc = metrics.accuracy_score(y_true=labels.data, y_pred=y_pred)
-        return acc, loss.data[0]

+ 110 - 0
src/update.py

@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+
+
+class DatasetSplit(Dataset):
+    """An abstract Dataset class wrapped around Pytorch Dataset class.
+    """
+
+    def __init__(self, dataset, idxs):
+        self.dataset = dataset
+        self.idxs = [int(i) for i in idxs]
+
+    def __len__(self):
+        return len(self.idxs)
+
+    def __getitem__(self, item):
+        image, label = self.dataset[self.idxs[item]]
+        return torch.tensor(image), torch.tensor(label)
+
+
+class LocalUpdate(object):
+    def __init__(self, args, dataset, idxs, logger):
+        self.args = args
+        self.logger = logger
+        self.trainloader, self.validloader, self.testloader = self.train_val_test(
+            dataset, list(idxs))
+        self.device = 'cuda' if args.gpu else 'cpu'
+        # Default criterion set to NLL loss function
+        self.criterion = nn.NLLLoss().to(self.device)
+
+    def train_val_test(self, dataset, idxs):
+        """
+        Returns train, validation and test dataloaders for a given dataset
+        and user indexes.
+        """
+        # split indexes for train, validation, and test (80, 10, 10)
+        idxs_train = idxs[:(0.8*len(idxs))]
+        idxs_val = idxs[(0.8*len(idxs)):(0.9*len(idxs))]
+        idxs_test = idxs[(0.9*len(idxs)):]
+
+        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
+                                 batch_size=self.args.local_bs, shuffle=True)
+        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
+                                 batch_size=int(len(idxs_val)/10), shuffle=True)
+        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
+                                batch_size=int(len(idxs_test)/10), shuffle=True)
+        return trainloader, validloader, testloader
+
+    def update_weights(self, model):
+        # Set mode to train model
+        model.train()
+        epoch_loss = []
+
+        # Set optimizer for the local updates
+        if self.args.optimizer == 'sgd':
+            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
+                                        momentum=0.5)
+        elif self.args.optimizer == 'adam':
+            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
+                                         weight_decay=1e-4)
+
+        for iter in range(self.args.local_ep):
+            batch_loss = []
+            for batch_idx, (images, labels) in enumerate(self.trainloader):
+                images, labels = images.to(self.device), labels.to(self.device)
+
+                model.zero_grad()
+                log_probs = model(images)
+                loss = self.criterion(log_probs, labels)
+                loss.backward()
+                optimizer.step()
+
+                if self.args.verbose and batch_idx % 10 == 0:
+                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                        iter, batch_idx * len(images), len(self.trainloader.dataset),
+                        100. * batch_idx / len(self.trainloader), loss.item()))
+                self.logger.add_scalar('loss', loss.item())
+                batch_loss.append(loss.item())
+            epoch_loss.append(sum(batch_loss)/len(batch_loss))
+
+        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
+
+    def inference(self, model):
+        """ Returns the inference accuracy and loss.
+        """
+
+        model.eval()
+        loss, total, correct = 0.0, 0.0, 0.0
+
+        for batch_idx, (images, labels) in enumerate(self.testloader):
+            images, labels = images.to(self.device), labels.to(self.device)
+
+            # Inference
+            outputs = model(images)
+            batch_loss = self.criterion(outputs, labels)
+            loss += batch_loss.item()
+
+            # Prediction
+            _, pred_labels = torch.max(outputs, 1)
+            pred_labels = pred_labels.view(-1)
+            correct += torch.sum(torch.eq(pred_labels, labels)).item()
+            total += len(labels)
+
+        accuracy = correct/total
+        return accuracy, loss.item()