diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py
index 52a9a0f..7604edc 100644
--- a/src/defense/evasion_defense.py
+++ b/src/defense/evasion_defense.py
@@ -94,6 +94,7 @@ def __init__(
         num_levels: int = 32
     ):
         super().__init__()
+        assert num_levels > 1
         self.num_levels = num_levels
 
     def pre_batch(
@@ -112,9 +113,12 @@ def quantize(
     ):
         x_min = x.min()
         x_max = x.max()
-        x_normalized = (x - x_min) / (x_max - x_min)
-        x_quantized = torch.round(x_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
-        x_quantized = x_quantized * (x_max - x_min) + x_min
+        if x_min != x_max:
+            x_normalized = (x - x_min) / (x_max - x_min)
+            x_quantized = torch.round(x_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
+            x_quantized = x_quantized * (x_max - x_min) + x_min
+        else:
+            x_quantized = x
         return x_quantized
 
 