Skip to content

Commit

Permalink
scaler formula fix + tests on scaler
Browse files Browse the repository at this point in the history
  • Loading branch information
timetoai committed Sep 2, 2023
1 parent 0c6f327 commit 6fa07b1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/timediffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def fit_transform(self, data):
return self.transform(data)

def inverse_transform(self, data):
return data * self.std + self.mu
return data * (self.std + self.eps) + self.mu
18 changes: 17 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,25 @@
import torch
from torch import nn

from timediffusion import count_params, kl_div
from timediffusion import count_params, kl_div, DimUniversalStandardScaler



@pytest.mark.parametrize(
"arr",
[
np.sin(np.arange(10)) * 10,
torch.arange(10).float(),
]
)
def test_duscaler(arr):
scaler = DimUniversalStandardScaler()
tarr = scaler.fit_transform(arr)
tarr1 = scaler.transform(arr)
rarr = scaler.inverse_transform(tarr)
assert abs((tarr - tarr1).mean()) < scaler.eps
assert abs((rarr - arr).mean()) < scaler.eps

@pytest.mark.parametrize(
"x,y",
[
Expand Down

0 comments on commit 6fa07b1

Please sign in to comment.