2 Commits a88815e07c ... a39ff65a6a

Author SHA1 Message Date
  Jens Keim a39ff65a6a allow lists or tuples for weights as well 3 years ago
  Jens Keim 0b24012e06 fix serverside noise 3 years ago
1 changed files with 9 additions and 3 deletions
  1. 9 3
      privacy_engine_xl.py

+ 9 - 3
privacy_engine_xl.py

@@ -47,7 +47,7 @@ def generate_noise(max_norm, parameter, noise_multiplier, noise_type, device):
     return 0.0
 
 # Server side Noise
-def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_reduction="mean"):
+def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean"):
     """
     A function for applying noise to weights on the (intermediate) server side that utilizes the generate_noise function above.
 
@@ -55,6 +55,9 @@ def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_
         The weights to which to apply the noise.
     @param batch_size
         Batch size used for averaging.
+    @param max_norm
+        The maximum norm of the per-sample gradients. Any gradient with norm
+        higher than this will be clipped to this value.
     @param noise_multiplier
         The ratio of the standard deviation of the Gaussian noise to
         the L2-sensitivity of the function to which the noise is added
@@ -67,8 +70,11 @@ def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_
         The method of loss reduction.
         currently supported: mean
     """
-    for p in weights.values():
-        noise = generate_noise(0, p, noise_multiplier, noise_type, device)
+    if isinstance(weights, dict):
+        weights = weights.values()
+
+    for p in weights:
+        noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
         if loss_reduction == "mean":
             noise /= batch_size
         p += noise