diff --git a/src/timediffusion/utils.py b/src/timediffusion/utils.py index f8650a9..116e8a6 100644 --- a/src/timediffusion/utils.py +++ b/src/timediffusion/utils.py @@ -81,4 +81,4 @@ def fit_transform(self, data): return self.transform(data) def inverse_transform(self, data): - return data * self.std + self.mu + return data * (self.std + self.eps) + self.mu diff --git a/tests/test_utils.py b/tests/test_utils.py index 268115c..a298111 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,9 +5,25 @@ import torch from torch import nn -from timediffusion import count_params, kl_div +from timediffusion import count_params, kl_div, DimUniversalStandardScaler + +@pytest.mark.parametrize( + "arr", + [ + np.sin(np.arange(10)) * 10, + torch.arange(10).float(), + ] +) +def test_duscaler(arr): + scaler = DimUniversalStandardScaler() + tarr = scaler.fit_transform(arr) + tarr1 = scaler.transform(arr) + rarr = scaler.inverse_transform(tarr) + assert abs((tarr - tarr1).mean()) < scaler.eps + assert abs((rarr - arr).mean()) < scaler.eps + @pytest.mark.parametrize( "x,y", [