瀏覽代碼

fix serverside noise

add max_norm to apply_noise
Jens Keim 3 年之前
父節點
當前提交
0b24012e06
共有 1 個文件被更改,包括 5 次插入2 次删除
  1. 5 2
      privacy_engine_xl.py

+ 5 - 2
privacy_engine_xl.py

@@ -47,7 +47,7 @@ def generate_noise(max_norm, parameter, noise_multiplier, noise_type, device):
     return 0.0
 
 # Server side Noise
-def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_reduction="mean"):
+def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean"):
     """
     A function for applying noise to weights on the (intermediate) server side that utilizes the generate_noise function above.
 
@@ -55,6 +55,9 @@ def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_
         The weights to which to apply the noise.
     @param batch_size
         Batch size used for averaging.
+    @param max_norm
+        The maximum norm of the per-sample gradients. Any gradient with norm
+        higher than this will be clipped to this value.
     @param noise_multiplier
         The ratio of the standard deviation of the Gaussian noise to
         the L2-sensitivity of the function to which the noise is added
@@ -68,7 +71,7 @@ def apply_noise(weights, batch_size, noise_multiplier, noise_type, device, loss_
         currently supported: mean
     """
     for p in weights.values():
-        noise = generate_noise(0, p, noise_multiplier, noise_type, device)
+        noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
         if loss_reduction == "mean":
             noise /= batch_size
         p += noise