|
@@ -70,7 +70,10 @@ def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, dev
|
|
The method of loss reduction.
|
|
The method of loss reduction.
|
|
currently supported: mean
|
|
currently supported: mean
|
|
"""
|
|
"""
|
|
- for p in weights.values():
|
|
|
|
|
|
+ if isinstance(weights, dict):
|
|
|
|
+ weights = weights.values()
|
|
|
|
+
|
|
|
|
+ for p in weights:
|
|
noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
|
|
noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
|
|
if loss_reduction == "mean":
|
|
if loss_reduction == "mean":
|
|
noise /= batch_size
|
|
noise /= batch_size
|