Browse Source

add clipping

Jens Keim 2 years ago
parent
commit
59fd146c03
1 changed files with 9 additions and 1 deletions
  1. 9 1
      privacy_engine_xl.py

+ 9 - 1
privacy_engine_xl.py

@@ -47,7 +47,7 @@ def generate_noise(max_norm, parameter, noise_multiplier, noise_type, device):
     return 0.0
 
 # Server side Noise
-def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean"):
+def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean", clipping=False):
     """
     A function for applying noise to weights on the (intermediate) server side that utilizes the generate_noise function above.
 
@@ -73,7 +73,15 @@ def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, dev
     if isinstance(weights, dict):
         weights = weights.values()
 
+    if max_norm == None:
+        max_norm = 1.0
+
     for p in weights:
+        if clipping:
+            norm = torch.norm(p, p=2)
+            div_norm = max(1, norm/max_norm)
+            p /= div_norm
+
         noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
         if loss_reduction == "mean":
             noise /= batch_size