|
@@ -10,11 +10,11 @@ def generate_noise(max_norm, parameter, sigma, noise_type, device):
|
|
|
|
|
|
scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
|
|
scale = torch.full(size=parameter.shape, fill_value=scale_scalar, dtype=torch.float32, device=device)
|
|
|
|
|
|
- if noise_type == "gaussian":
|
|
|
|
|
|
+ if noise_type.lower() in ["normal", "gauss", "gaussian"]:
|
|
dist = torch.distributions.normal.Normal(mean, scale)
|
|
dist = torch.distributions.normal.Normal(mean, scale)
|
|
- elif noise_type == "laplacian":
|
|
|
|
|
|
+ elif noise_type.lower() in ["laplace", "laplacian"]:
|
|
dist = torch.distributions.laplace.Laplace(mean, scale)
|
|
dist = torch.distributions.laplace.Laplace(mean, scale)
|
|
- elif noise_type == "exponential":
|
|
|
|
|
|
+ elif noise_type.lower() in ["exponential"]:
|
|
rate = 1 / scale
|
|
rate = 1 / scale
|
|
dist = torch.distributions.exponential.Exponential(rate)
|
|
dist = torch.distributions.exponential.Exponential(rate)
|
|
else:
|
|
else:
|