|
@@ -47,7 +47,7 @@ def generate_noise(max_norm, parameter, noise_multiplier, noise_type, device):
|
|
|
return 0.0
|
|
|
|
|
|
# Server side Noise
|
|
|
-def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean"):
|
|
|
+def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, device, loss_reduction="mean", clipping=False):
|
|
|
"""
|
|
|
A function for applying noise to weights on the (intermediate) server side that utilizes the generate_noise function above.
|
|
|
|
|
@@ -73,7 +73,15 @@ def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, dev
|
|
|
if isinstance(weights, dict):
|
|
|
weights = weights.values()
|
|
|
|
|
|
+ if max_norm == None:
|
|
|
+ max_norm = 1.0
|
|
|
+
|
|
|
for p in weights:
|
|
|
+ if clipping:
|
|
|
+ norm = torch.norm(p, p=2)
|
|
|
+ div_norm = max(1, norm/max_norm)
|
|
|
+ p /= div_norm
|
|
|
+
|
|
|
noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
|
|
|
if loss_reduction == "mean":
|
|
|
noise /= batch_size
|