Pārlūkot izejas kodu

add functions for server side noise

Jens Keim 4 gadi atpakaļ
vecāks
revīzija
a3ab6c0c1a
1 mainītis faili ar 29 papildinājumiem un 0 dzēšanām
  1. 29 0
      src/privacy_engine_xl.py

+ 29 - 0
src/privacy_engine_xl.py

@@ -3,6 +3,35 @@ import opacus
 from typing import List, Union
 import os
 
+def generate_noise(max_norm, parameter, sigma, noise_type, device):
+    if sigma > 0:
+        mean = 0
+        scale_scalar = sigma * max_norm
+
+        scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
+
+        if noise_type == "gaussian":
+            dist = torch.distributions.normal.Normal(mean, scale)
+        elif noise_type == "laplacian":
+            dist = torch.distributions.laplace.Laplace(mean, scale)
+        elif noise_type == "exponential":
+            rate = 1 / scale
+            dist = torch.distributions.exponential.Exponential(rate)
+        else:
+            dist = torch.distributions.normal.Normal(mean, scale)
+
+        noise = dist.sample()
+
+        return noise
+    return 0.0
+
+def apply_noise(weights, batch_size, sigma, noise_type, device, loss_reduction="mean"):
+    for p in weights.values():
+        noise = generate_noise(0, p, sigma, noise_type, device)
+        if loss_reduction == "mean":
+            noise /= batch_size
+        p += noise
+
 class PrivacyEngineXL(opacus.PrivacyEngine):
 
     def __init__(