Merge pull request #69 from jdb78/feature/fix_encoder_length_naming
The feature should be called encoder_length, not decoder_length.
jdb78 authored Sep 29, 2020
2 parents 21dc7c9 + 410d4fe commit b0b9068
Showing 2 changed files with 38 additions and 8 deletions.
pytorch_forecasting/data/timeseries.py (8 additions, 8 deletions)
@@ -274,11 +274,11 @@ def __init__(
         # add decoder length to static real variables
         if self.add_encoder_length:
             assert (
-                "decoder_length" not in data.columns
-            ), "decoder_length is a protected column and must not be present in data"
-            if "decoder_length" not in self.time_varying_known_reals and "decoder_length" not in self.reals:
-                self.static_reals.append("decoder_length")
-            data["decoder_length"] = 0.0  # dummy - real value will be set dynamically in __getitem__()
+                "encoder_length" not in data.columns
+            ), "encoder_length is a protected column and must not be present in data"
+            if "encoder_length" not in self.time_varying_known_reals and "encoder_length" not in self.reals:
+                self.static_reals.append("encoder_length")
+            data["encoder_length"] = 0.0  # dummy - real value will be set dynamically in __getitem__()
 
         # validate
         self._validate_data(data)
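
For orientation, here is a minimal sketch (hypothetical toy data and column names, not from the library's tests) of the behavior the hunk above renames: with add_encoder_length=True, the dataset registers "encoder_length" as a static real feature, so it rejects input frames that already carry a column of that name.

```python
import pandas as pd

from pytorch_forecasting import TimeSeriesDataSet

# hypothetical toy frame: one group, ten time steps, plus a clashing column
data = pd.DataFrame(
    {
        "time_idx": range(10),
        "group": "a",
        "target": [float(i) for i in range(10)],
        "encoder_length": 0.0,  # clashes with the protected column name
    }
)
try:
    TimeSeriesDataSet(
        data,
        time_idx="time_idx",
        target="target",
        group_ids=["group"],
        max_encoder_length=3,
        max_prediction_length=2,
        add_encoder_length=True,
    )
except AssertionError as e:
    print(e)  # "encoder_length is a protected column and must not be present in data"
```

Dropping the clashing column lets construction succeed, after which "encoder_length" appears in the dataset's static_reals, per the append above.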
@@ -916,9 +916,9 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
             )
 
         if self.add_encoder_length:
-            data_cont[:, self.reals.index("decoder_length")] = (
-                decoder_length - 0.5 * self.max_encoder_length
-            ) / self.max_encoder_length
+            data_cont[:, self.reals.index("encoder_length")] = (
+                (encoder_length - 0.5 * self.max_encoder_length) / self.max_encoder_length * 2.0
+            )
 
         # rescale target
         if self.target_normalizer is not None and isinstance(self.target_normalizer, EncoderNormalizer):
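
Besides indexing the renamed column, the new expression switches the formula's variable from decoder_length to encoder_length and doubles the result, so an encoder length in [0, max_encoder_length] now maps linearly onto [-1, 1]. A quick check of the new scaling, assuming max_encoder_length = 4:

```python
max_encoder_length = 4
for encoder_length in range(max_encoder_length + 1):
    scaled = (encoder_length - 0.5 * max_encoder_length) / max_encoder_length * 2.0
    print(encoder_length, scaled)  # prints: 0 -1.0, 1 -0.5, 2 0.0, 3 0.5, 4 1.0
```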
tests/test_data.py (30 additions, 0 deletions)
@@ -218,6 +218,36 @@ def test_from_dataset(test_dataset, test_data):
     check_dataloader_output(dataset, next(iter(dataset.to_dataloader(num_workers=0))))
 
 
+def test_from_dataset_equivalence(test_data):
+    training = TimeSeriesDataSet(
+        test_data[lambda x: x.time_idx < x.time_idx.max() - 1],
+        time_idx="time_idx",
+        target="volume",
+        time_varying_known_reals=["price_regular", "time_idx"],
+        group_ids=["agency", "sku"],
+        static_categoricals=["agency"],
+        max_encoder_length=3,
+        max_prediction_length=2,
+        min_prediction_length=1,
+        min_encoder_length=0,
+        randomize_length=None,
+        add_encoder_length=True,
+        add_relative_time_idx=True,
+        add_target_scales=True,
+    )
+    validation1 = TimeSeriesDataSet.from_dataset(training, test_data, predict=True)
+    validation2 = TimeSeriesDataSet.from_dataset(
+        training,
+        test_data[lambda x: x.time_idx > x.time_idx.min() + 2],
+        predict=True,
+    )
+    # ensure validation1 and validation2 datasets are exactly the same despite different data inputs
+    for v1, v2 in zip(iter(validation1.to_dataloader(train=False)), iter(validation2.to_dataloader(train=False))):
+        for k in v1[0].keys():
+            assert torch.isclose(v1[0][k], v2[0][k]).all()
+        assert torch.isclose(v1[1], v2[1]).all()
+
+
 def test_dataset_index(test_dataset):
     index = []
     for x, _ in iter(test_dataset.to_dataloader()):
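
The new test exercises TimeSeriesDataSet.from_dataset, which rebuilds a dataset on fresh data while reusing the source dataset's parameters and fitted encoders; with predict=True it keeps only the last possible sample per series, which is why the two validation sets should match despite being built from different input slices. A minimal sketch of that pattern on hypothetical toy data (column names are illustrative, not the test fixtures):

```python
import pandas as pd

from pytorch_forecasting import TimeSeriesDataSet

# hypothetical toy frame: two groups, ten time steps each
data = pd.DataFrame(
    {
        "time_idx": list(range(10)) * 2,
        "group": ["a"] * 10 + ["b"] * 10,
        "target": [float(i) for i in range(20)],
    }
)
training = TimeSeriesDataSet(
    data[data.time_idx < 8],  # hold out the last two steps for validation
    time_idx="time_idx",
    target="target",
    group_ids=["group"],
    max_encoder_length=3,
    max_prediction_length=2,
    add_encoder_length=True,  # registers the renamed "encoder_length" static real
)
# same parameters and encoders, new data; predict=True keeps only the
# last possible sample of each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True)
```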
