|
@@ -39,19 +39,19 @@ class LocalUpdate(object):
|
|
|
and user indexes.
|
|
|
"""
|
|
|
# split indexes for train, validation, and test (80, 10, 10)
|
|
|
- idxs_train = idxs[:(0.8*len(idxs))]
|
|
|
- idxs_val = idxs[(0.8*len(idxs)):(0.9*len(idxs))]
|
|
|
- idxs_test = idxs[(0.9*len(idxs)):]
|
|
|
+ idxs_train = idxs[:int(0.8*len(idxs))]
|
|
|
+ idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
|
|
|
+ idxs_test = idxs[int(0.9*len(idxs)):]
|
|
|
|
|
|
trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
|
|
|
batch_size=self.args.local_bs, shuffle=True)
|
|
|
validloader = DataLoader(DatasetSplit(dataset, idxs_val),
|
|
|
- batch_size=int(len(idxs_val)/10), shuffle=True)
|
|
|
+ batch_size=int(len(idxs_val)/10), shuffle=False)
|
|
|
testloader = DataLoader(DatasetSplit(dataset, idxs_test),
|
|
|
- batch_size=int(len(idxs_test)/10), shuffle=True)
|
|
|
+ batch_size=int(len(idxs_test)/10), shuffle=False)
|
|
|
return trainloader, validloader, testloader
|
|
|
|
|
|
- def update_weights(self, model):
|
|
|
+ def update_weights(self, model, global_round):
|
|
|
# Set mode to train model
|
|
|
model.train()
|
|
|
epoch_loss = []
|
|
@@ -75,9 +75,10 @@ class LocalUpdate(object):
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
|
|
|
- if self.args.verbose and batch_idx % 10 == 0:
|
|
|
- print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
|
|
- iter, batch_idx * len(images), len(self.trainloader.dataset),
|
|
|
+ if self.args.verbose and (batch_idx % 10 == 0):
|
|
|
+ print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
|
|
+ global_round, iter, batch_idx * len(images),
|
|
|
+ len(self.trainloader.dataset),
|
|
|
100. * batch_idx / len(self.trainloader), loss.item()))
|
|
|
self.logger.add_scalar('loss', loss.item())
|
|
|
batch_loss.append(loss.item())
|
|
@@ -107,4 +108,34 @@ class LocalUpdate(object):
|
|
|
total += len(labels)
|
|
|
|
|
|
accuracy = correct/total
|
|
|
- return accuracy, loss.item()
|
|
|
+ return accuracy, loss
|
|
|
+
|
|
|
+
|
|
|
+def test_inference(args, model, test_dataset):
|
|
|
+ """ Returns the test accuracy and loss.
|
|
|
+ """
|
|
|
+
|
|
|
+ model.eval()
|
|
|
+ loss, total, correct = 0.0, 0.0, 0.0
|
|
|
+
|
|
|
+ device = 'cuda' if args.gpu else 'cpu'
|
|
|
+ criterion = nn.NLLLoss().to(device)
|
|
|
+ testloader = DataLoader(test_dataset, batch_size=128,
|
|
|
+ shuffle=False)
|
|
|
+
|
|
|
+ for batch_idx, (images, labels) in enumerate(testloader):
|
|
|
+ images, labels = images.to(device), labels.to(device)
|
|
|
+
|
|
|
+ # Inference
|
|
|
+ outputs = model(images)
|
|
|
+ batch_loss = criterion(outputs, labels)
|
|
|
+ loss += batch_loss.item()
|
|
|
+
|
|
|
+ # Prediction
|
|
|
+ _, pred_labels = torch.max(outputs, 1)
|
|
|
+ pred_labels = pred_labels.view(-1)
|
|
|
+ correct += torch.sum(torch.eq(pred_labels, labels)).item()
|
|
|
+ total += len(labels)
|
|
|
+
|
|
|
+ accuracy = correct/total
|
|
|
+ return accuracy, loss
|