privacy_engine_xl.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import torch
  2. import opacus
  3. from typing import List, Union
  4. import os
  5. def generate_noise(max_norm, parameter, noise_multiplier, noise_type, device):
  6. """
  7. A noise generation function that can utilize different distributions for noise generation.
  8. @param max_norm
  9. The maximum norm of the per-sample gradients. Any gradient with norm
  10. higher than this will be clipped to this value.
  11. @param parameter
  12. The parameter, based on which the dimension of the noise tensor
  13. will be determined
  14. @param noise_multiplier
  15. The ratio of the standard deviation of the Gaussian noise to
  16. the L2-sensitivity of the function to which the noise is added
  17. @param noise_type
  18. Sets the distribution for the noise generation.
  19. See generate_noise for supported strings.
  20. @param device
  21. The device used for calculations and needed for tensor definition.
  22. @return
  23. a tensor of noise in the same shape as ``parameter``.
  24. """
  25. if noise_multiplier > 0:
  26. mean = 0
  27. scale_scalar = noise_multiplier * max_norm
  28. scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
  29. if noise_type.lower() in ["normal", "gauss", "gaussian"]:
  30. dist = torch.distributions.normal.Normal(mean, scale)
  31. elif noise_type.lower() in ["laplace", "laplacian"]:
  32. dist = torch.distributions.laplace.Laplace(mean, scale)
  33. elif noise_type.lower() in ["exponential"]:
  34. rate = 1 / scale
  35. dist = torch.distributions.exponential.Exponential(rate)
  36. else:
  37. dist = torch.distributions.normal.Normal(mean, scale)
  38. noise = dist.sample()
  39. return noise
  40. return 0.0
# Server-side noise: clip (optionally) and perturb weights directly.
  42. def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean", clipping=False):
  43. """
  44. A function for applying noise to weights on the (intermediate) server side that utilizes the generate_noise function above.
  45. @param weights
  46. The weights to which to apply the noise.
  47. @param batch_size
  48. Batch size used for averaging.
  49. @param max_norm
  50. The maximum norm of the per-sample gradients. Any gradient with norm
  51. higher than this will be clipped to this value.
  52. @param noise_multiplier
  53. The ratio of the standard deviation of the Gaussian noise to
  54. the L2-sensitivity of the function to which the noise is added
  55. @param noise_type
  56. Sets the distribution for the noise generation.
  57. See generate_noise for supported strings.
  58. @param device
  59. The device used for calculations and needed for tensor definition.
  60. @param loss_reduction
  61. The method of loss reduction.
  62. currently supported: mean
  63. """
  64. if isinstance(weights, dict):
  65. weights = weights.values()
  66. if max_norm == None:
  67. max_norm = 1.0
  68. clipped = 0
  69. total = 0
  70. for p in weights:
  71. total += 1
  72. if clipping:
  73. norm = torch.norm(p, p=2)
  74. div_norm = max(1, norm/max_norm)
  75. if div_norm != 1:
  76. clipped += 1
  77. p /= div_norm
  78. noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
  79. if loss_reduction == "mean":
  80. noise /= batch_size
  81. p += noise
  82. return clipped, total
# Client-side noise: an opacus PrivacyEngine with selectable noise distributions.
  84. class PrivacyEngineXL(opacus.PrivacyEngine):
  85. """
  86. A privacy engine that can utilize different distributions for noise generation, based on opacus' privacy engine.
  87. It gets attached to the optimizer just like the privacy engine from opacus.
  88. @param module:
  89. The Pytorch module to which we are attaching the privacy engine
  90. @param batch_size
  91. Training batch size. Used in the privacy accountant.
  92. @param sample_size
  93. The size of the sample (dataset). Used in the privacy accountant.
  94. @param alphas
  95. A list of RDP orders
  96. @param noise_multiplier
  97. The ratio of the standard deviation of the Gaussian noise to
  98. the L2-sensitivity of the function to which the noise is added
  99. @param max_grad_norm
  100. The maximum norm of the per-sample gradients. Any gradient with norm
  101. higher than this will be clipped to this value.
  102. @param secure_rng
  103. If on, it will use ``torchcsprng`` for secure random number generation. Comes with
  104. a significant performance cost, therefore it's recommended that you turn it off when
  105. just experimenting.
  106. @param grad_norm_type
  107. The order of the norm. For instance, 2 represents L-2 norm, while
  108. 1 represents L-1 norm.
  109. @param batch_first
  110. Flag to indicate if the input tensor to the corresponding module
  111. has the first dimension representing the batch. If set to True,
  112. dimensions on input tensor will be ``[batch_size, ..., ...]``.
  113. @param target_delta
  114. The target delta
  115. @param loss_reduction
  116. Indicates if the loss reduction (for aggregating the gradients)
  117. is a sum or a mean operation. Can take values "sum" or "mean"
  118. @param noise_type
  119. Sets the distribution for the noise generation.
  120. See generate_noise for supported strings.
  121. @param **misc_settings
  122. Other arguments to the init
  123. """
  124. def __init__(
  125. self,
  126. module: torch.nn.Module,
  127. batch_size: int,
  128. sample_size: int,
  129. alphas: List[float],
  130. noise_multiplier: float,
  131. max_grad_norm: Union[float, List[float]],
  132. secure_rng: bool = False,
  133. grad_norm_type: int = 2,
  134. batch_first: bool = True,
  135. target_delta: float = 1e-6,
  136. loss_reduction: str = "mean",
  137. noise_type: str="gaussian",
  138. **misc_settings
  139. ):
  140. import warnings
  141. if secure_rng:
  142. warnings.warn(
  143. "Secure RNG was turned on. However it is not yet implemented for the noise distributions of privacy_engine_xl."
  144. )
  145. opacus.PrivacyEngine.__init__(
  146. self,
  147. module,
  148. batch_size,
  149. sample_size,
  150. alphas,
  151. noise_multiplier,
  152. max_grad_norm,
  153. secure_rng,
  154. grad_norm_type,
  155. batch_first,
  156. target_delta,
  157. loss_reduction,
  158. **misc_settings)
  159. self.noise_type = noise_type
  160. def _generate_noise(self, max_norm, parameter):
  161. """
  162. Generates a tensor of noise in the same shape as ``parameter``.
  163. @param max_norm
  164. The maximum norm of the per-sample gradients. Any gradient with norm
  165. higher than this will be clipped to this value.
  166. @param parameter
  167. The parameter, based on which the dimension of the noise tensor
  168. will be determined
  169. @return
  170. a tensor of noise in the same shape as ``parameter``.
  171. """
  172. return generate_noise(max_norm, parameter, self.noise_multiplier, self.noise_type, self.device)