Pārlūkot izejas kodu

catch more wordings for noise type

use .lower() to compare
Jens Keim 3 gadi atpakaļ
vecāks
revīzija
7cbe19d755
1 mainītis faili ar 3 papildinājumiem un 3 dzēšanām
  1. 3 3
      privacy_engine_xl.py

+ 3 - 3
privacy_engine_xl.py

@@ -10,11 +10,11 @@ def generate_noise(max_norm, parameter, sigma, noise_type, device):
 
         scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
 
-        if noise_type == "gaussian":
+        if noise_type.lower() in ["normal", "gauss", "gaussian"]:
             dist = torch.distributions.normal.Normal(mean, scale)
-        elif noise_type == "laplacian":
+        elif noise_type.lower() in ["laplace", "laplacian"]:
             dist = torch.distributions.laplace.Laplace(mean, scale)
-        elif noise_type == "exponential":
+        elif noise_type.lower() in ["exponential"]:
             rate = 1 / scale
             dist = torch.distributions.exponential.Exponential(rate)
         else: