@@ -12,7 +12,7 @@ from sklearn import metrics
class DatasetSplit(Dataset):
def __init__(self, dataset, idxs):
self.dataset = dataset
- self.idxs = list(idxs)
+ self.idxs = [int(i) for i in idxs]
def __len__(self):
return len(self.idxs)