"""privacy_engine_xl.py — noise utilities and an opacus.PrivacyEngine subclass
supporting gaussian, laplacian and exponential noise distributions."""
import os
import warnings
from typing import List, Union

import opacus
import torch
  5. def generate_noise(max_norm, parameter, sigma, noise_type, device):
  6. if sigma > 0:
  7. mean = 0
  8. scale_scalar = sigma * max_norm
  9. scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
  10. if noise_type == "gaussian":
  11. dist = torch.distributions.normal.Normal(mean, scale)
  12. elif noise_type == "laplacian":
  13. dist = torch.distributions.laplace.Laplace(mean, scale)
  14. elif noise_type == "exponential":
  15. rate = 1 / scale
  16. dist = torch.distributions.exponential.Exponential(rate)
  17. else:
  18. dist = torch.distributions.normal.Normal(mean, scale)
  19. noise = dist.sample()
  20. return noise
  21. return 0.0
  22. def apply_noise(weights, batch_size, sigma, noise_type, device, loss_reduction="mean"):
  23. for p in weights.values():
  24. noise = generate_noise(0, p, sigma, noise_type, device)
  25. if loss_reduction == "mean":
  26. noise /= batch_size
  27. p += noise
  28. class PrivacyEngineXL(opacus.PrivacyEngine):
  29. def __init__(
  30. self,
  31. module: torch.nn.Module,
  32. batch_size: int,
  33. sample_size: int,
  34. alphas: List[float],
  35. noise_multiplier: float,
  36. max_grad_norm: Union[float, List[float]],
  37. secure_rng: bool = False,
  38. grad_norm_type: int = 2,
  39. batch_first: bool = True,
  40. target_delta: float = 1e-6,
  41. loss_reduction: str = "mean",
  42. noise_type: str="gaussian",
  43. **misc_settings
  44. ):
  45. import warnings
  46. if secure_rng:
  47. warnings.warn(
  48. "Secure RNG was turned on. However it is not yet implemented for the noise distributions of privacy_engine_xl."
  49. )
  50. opacus.PrivacyEngine.__init__(
  51. self,
  52. module,
  53. batch_size,
  54. sample_size,
  55. alphas,
  56. noise_multiplier,
  57. max_grad_norm,
  58. secure_rng,
  59. grad_norm_type,
  60. batch_first,
  61. target_delta,
  62. loss_reduction,
  63. **misc_settings)
  64. self.noise_type = noise_type
  65. def _generate_noise(self, max_norm, parameter):
  66. if self.noise_multiplier > 0:
  67. mean = 0
  68. scale_scalar = self.noise_multiplier * max_norm
  69. scale = torch.full(size=parameter.grad.shape, fill_value=scale_scalar, dtype=torch.float32, device=self.device)
  70. if self.noise_type == "gaussian":
  71. dist = torch.distributions.normal.Normal(mean, scale)
  72. elif self.noise_type == "laplacian":
  73. dist = torch.distributions.laplace.Laplace(mean, scale)
  74. elif self.noise_type == "exponential":
  75. rate = 1 / scale
  76. dist = torch.distributions.exponential.Exponential(rate)
  77. else:
  78. dist = torch.distributions.normal.Normal(mean, scale)
  79. noise = dist.sample()
  80. return noise
  81. return 0.0