@@ -21,7 +21,7 @@ if __name__ == '__main__':
torch.cuda.set_device(args.gpu)
device = 'cuda' if args.gpu else 'cpu'
- # load dataset and user groups
+ # load datasets
train_dataset, test_dataset, _ = get_dataset(args)
# BUILD MODEL