privacy_engine_xl.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import torch
  2. import opacus
  3. from typing import List, Union
  4. import os
  5. class PrivacyEngineXL(opacus.PrivacyEngine):
  6. def __init__(
  7. self,
  8. module: torch.nn.Module,
  9. batch_size: int,
  10. sample_size: int,
  11. alphas: List[float],
  12. noise_multiplier: float,
  13. max_grad_norm: Union[float, List[float]],
  14. secure_rng: bool = False,
  15. grad_norm_type: int = 2,
  16. batch_first: bool = True,
  17. target_delta: float = 1e-6,
  18. loss_reduction: str = "mean",
  19. noise_type: str="gaussian",
  20. **misc_settings
  21. ):
  22. import warnings
  23. if secure_rng:
  24. warnings.warn(
  25. "Secure RNG was turned on. However it is not yet implemented for the noise distributions of privacy_engine_xl."
  26. )
  27. opacus.PrivacyEngine.__init__(
  28. self,
  29. module,
  30. batch_size,
  31. sample_size,
  32. alphas,
  33. noise_multiplier,
  34. max_grad_norm,
  35. secure_rng,
  36. grad_norm_type,
  37. batch_first,
  38. target_delta,
  39. loss_reduction,
  40. **misc_settings)
  41. self.noise_type = noise_type
  42. def _generate_noise(self, max_norm, parameter):
  43. if self.noise_multiplier > 0:
  44. mean = 0
  45. scale_scalar = self.noise_multiplier * max_norm
  46. scale = torch.full(size=parameter.grad.shape, fill_value=scale_scalar, dtype=torch.float32, device=self.device)
  47. if self.noise_type == "gaussian":
  48. dist = torch.distributions.normal.Normal(mean, scale)
  49. elif self.noise_type == "laplacian":
  50. dist = torch.distributions.laplace.Laplace(mean, scale)
  51. elif self.noise_type == "exponential":
  52. rate = 1 / scale
  53. dist = torch.distributions.exponential.Exponential(rate)
  54. else:
  55. dist = torch.distributions.normal.Normal(mean, scale)
  56. noise = dist.sample()
  57. return noise
  58. return 0.0