# privacy_engine_xl.py (2.4 KB)

  1. import torch
  2. import opacus
  3. from typing import List, Union
  4. import os
  5. def generate_noise(max_norm, parameter, noise_multiplier, noise_type, device):
  6. if noise_multiplier > 0:
  7. mean = 0
  8. scale_scalar = noise_multiplier * max_norm
  9. scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
  10. if noise_type.lower() in ["normal", "gauss", "gaussian"]:
  11. dist = torch.distributions.normal.Normal(mean, scale)
  12. elif noise_type.lower() in ["laplace", "laplacian"]:
  13. dist = torch.distributions.laplace.Laplace(mean, scale)
  14. elif noise_type.lower() in ["exponential"]:
  15. rate = 1 / scale
  16. dist = torch.distributions.exponential.Exponential(rate)
  17. else:
  18. dist = torch.distributions.normal.Normal(mean, scale)
  19. noise = dist.sample()
  20. return noise
  21. return 0.0
  22. def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_reduction="mean"):
  23. for p in weights.values():
  24. noise = generate_noise(0, p, noise_multiplier, noise_type, device)
  25. if loss_reduction == "mean":
  26. noise /= batch_size
  27. p += noise
  28. class PrivacyEngineXL(opacus.PrivacyEngine):
  29. def __init__(
  30. self,
  31. module: torch.nn.Module,
  32. batch_size: int,
  33. sample_size: int,
  34. alphas: List[float],
  35. noise_multiplier: float,
  36. max_grad_norm: Union[float, List[float]],
  37. secure_rng: bool = False,
  38. grad_norm_type: int = 2,
  39. batch_first: bool = True,
  40. target_delta: float = 1e-6,
  41. loss_reduction: str = "mean",
  42. noise_type: str="gaussian",
  43. **misc_settings
  44. ):
  45. import warnings
  46. if secure_rng:
  47. warnings.warn(
  48. "Secure RNG was turned on. However it is not yet implemented for the noise distributions of privacy_engine_xl."
  49. )
  50. opacus.PrivacyEngine.__init__(
  51. self,
  52. module,
  53. batch_size,
  54. sample_size,
  55. alphas,
  56. noise_multiplier,
  57. max_grad_norm,
  58. secure_rng,
  59. grad_norm_type,
  60. batch_first,
  61. target_delta,
  62. loss_reduction,
  63. **misc_settings)
  64. self.noise_type = noise_type
  65. def _generate_noise(self, max_norm, parameter):
  66. return generate_noise(max_norm, parameter, self.noise_multiplier, self.noise_type, self.device)