utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import copy
  5. import torch
  6. from torchvision import datasets, transforms
  7. from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
  8. from sampling import cifar_iid, cifar_noniid
  9. def get_dataset(args):
  10. """ Returns train and test datasets and a user group which is a dict where
  11. the keys are the user index and the values are the corresponding data for
  12. each of those users.
  13. """
  14. if args.dataset == 'cifar':
  15. data_dir = '../data/cifar/'
  16. apply_transform = transforms.Compose(
  17. [transforms.ToTensor(),
  18. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  19. # train_dataset = datasets.MNIST(data_dir, train=True, download=True,
  20. # transform=apply_transform)
  21. train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
  22. transform=apply_transform)
  23. # test_dataset = datasets.MNIST(data_dir, train=False, download=True,
  24. # transform=apply_transform)
  25. test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
  26. transform=apply_transform)
  27. # sample training data amongst users
  28. if args.iid:
  29. # Sample IID user data from Mnist
  30. user_groups = cifar_iid(train_dataset, args.num_users)
  31. else:
  32. # Sample Non-IID user data from Mnist
  33. if args.unequal:
  34. # Chose uneuqal splits for every user
  35. raise NotImplementedError()
  36. else:
  37. # Chose euqal splits for every user
  38. user_groups = cifar_noniid(train_dataset, args.num_users)
  39. elif args.dataset == 'mnist' or 'fmnist':
  40. if args.dataset == 'mnist':
  41. data_dir = '../data/mnist/'
  42. else:
  43. data_dir = '../data/fmnist/'
  44. apply_transform = transforms.Compose([
  45. transforms.ToTensor(),
  46. transforms.Normalize((0.1307,), (0.3081,))])
  47. train_dataset = datasets.MNIST(data_dir, train=True, download=True,
  48. transform=apply_transform)
  49. test_dataset = datasets.MNIST(data_dir, train=False, download=True,
  50. transform=apply_transform)
  51. # sample training data amongst users
  52. if args.iid:
  53. # Sample IID user data from Mnist
  54. user_groups = mnist_iid(train_dataset, args.num_users)
  55. else:
  56. # Sample Non-IID user data from Mnist
  57. if args.unequal:
  58. # Chose uneuqal splits for every user
  59. user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
  60. else:
  61. # Chose euqal splits for every user
  62. user_groups = mnist_noniid(train_dataset, args.num_users)
  63. return train_dataset, test_dataset, user_groups
  64. def average_weights(w):
  65. """
  66. Returns the average of the weights.
  67. """
  68. w_avg = copy.deepcopy(w[0])
  69. for key in w_avg.keys():
  70. for i in range(1, len(w)):
  71. w_avg[key] += w[i][key]
  72. w_avg[key] = torch.div(w_avg[key], len(w))
  73. return w_avg
  74. def exp_details(args):
  75. print('\nExperimental details:')
  76. print(f' Model : {args.model}')
  77. print(f' Optimizer : {args.optimizer}')
  78. print(f' Learning : {args.lr}')
  79. print(f' Global Rounds : {args.epochs}\n')
  80. print(' Federated parameters:')
  81. if args.iid:
  82. print(' IID')
  83. else:
  84. print(' Non-IID')
  85. print(f' Fraction of users : {args.frac}')
  86. print(f' Local Batch size : {args.local_bs}')
  87. print(f' Local Epochs : {args.local_ep}\n')
  88. return