Browse Source

allow lists or tuples for weights as well

Jens Keim 3 năm trước cách đây
mục cha
commit
a39ff65a6a
1 tập tin đã thay đổi với 4 bổ sung1 xóa
  1. 4 1
      privacy_engine_xl.py

+ 4 - 1
privacy_engine_xl.py

@@ -70,7 +70,10 @@ def apply_noise(weights, batch_size, max_norm, noise_multiplier, noise_type, dev
         The method of loss reduction.
         currently supported: mean
     """
-    for p in weights.values():
+    if isinstance(weights, dict):
+        weights = weights.values()
+
+    for p in weights:
         noise = generate_noise(max_norm, p, noise_multiplier, noise_type, device)
         if loss_reduction == "mean":
             noise /= batch_size