Add more tests for temporal fusion transformer
Jan Beitner committed Aug 16, 2020
1 parent 0be859b commit bfe6a5b
Showing 2 changed files with 47 additions and 35 deletions.
67 changes: 39 additions & 28 deletions tests/test_models/conftest.py
@@ -2,7 +2,7 @@
import numpy as np
from data import get_stallion_data, generate_ar_data
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder, EncoderNormalizer
from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder, EncoderNormalizer


@pytest.fixture
@@ -35,8 +35,42 @@ def data_with_covariates():
return data


@pytest.fixture
def dataloaders_with_coveratiates(data_with_covariates):
@pytest.fixture(
params=[
dict(),
dict(
static_categoricals=["agency", "sku"],
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
time_varying_known_categoricals=["special_days", "month"],
variable_groups=dict(
special_days=[
"easter_day",
"good_friday",
"new_year",
"christmas",
"labor_day",
"independence_day",
"revolution_day_memorial",
"regional_games",
"fifa_u_17_world_cup",
"football_gold_cup",
"beer_capital",
"music_fest",
]
),
time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"],
time_varying_unknown_categoricals=[],
time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
constant_fill_strategy={"volume": 0},
dropout_categoricals=["sku"],
),
dict(static_categoricals=["agency", "sku"]),
dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2),
dict(target_normalizer=GroupNormalizer(log_scale=True)),
dict(target_normalizer=GroupNormalizer(groups=["agency", "sku"], coerce_positive=1.0)),
]
)
def dataloaders_with_coveratiates(data_with_covariates, request):
training_cutoff = "2016-09-01"
max_encoder_length = 36
max_prediction_length = 6
@@ -49,30 +49,7 @@ def dataloaders_with_coveratiates(data_with_covariates):
group_ids=["agency", "sku"],
max_encoder_length=max_encoder_length,
max_prediction_length=max_prediction_length,
static_categoricals=["agency", "sku"],
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
time_varying_known_categoricals=["special_days", "month"],
variable_groups=dict(
special_days=[
"easter_day",
"good_friday",
"new_year",
"christmas",
"labor_day",
"independence_day",
"revolution_day_memorial",
"regional_games",
"fifa_u_17_world_cup",
"football_gold_cup",
"beer_capital",
"music_fest",
]
),
time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"],
time_varying_unknown_categoricals=[],
time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
constant_fill_strategy={"volume": 0},
dropout_categoricals=["sku"],
**request.param # fixture parametrization
)

validation = TimeSeriesDataSet.from_dataset(
@@ -85,7 +96,7 @@ def dataloaders_with_coveratiates(data_with_covariates):
return dict(train=train_dataloader, val=val_dataloader)


@pytest.fixture
@pytest.fixture()
def dataloaders_fixed_window_without_coveratiates(data_with_covariates):
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=10)
data["static"] = "2"
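
Note: the conftest change above replaces a single hard-coded dataset configuration with a parametrized fixture, so every test requesting dataloaders_with_coveratiates now runs once per entry in params, and the selected dict is forwarded into TimeSeriesDataSet via **request.param. A minimal, self-contained sketch of this pytest pattern (the fixture name, test name, and keyword arguments below are illustrative, not taken from the repository):

import pytest


@pytest.fixture(params=[dict(), dict(min_encoder_length=2), dict(static_categoricals=["agency"])])
def dataset_kwargs(request):
    # request.param is one dict from the params list; pytest re-runs every
    # dependent test once per parametrization
    return request.param


def test_dataset_config(dataset_kwargs):
    base = dict(max_encoder_length=36, max_prediction_length=6)
    config = {**base, **dataset_kwargs}  # mirrors TimeSeriesDataSet(..., **request.param)
    assert config["max_encoder_length"] == 36
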
15 changes: 8 additions & 7 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -1,3 +1,5 @@
from pytorch_forecasting.data import TimeSeriesDataSet
import pytest
import shutil
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
@@ -7,8 +9,6 @@


# todo: run with multiple normalizers
# todo: run with muliple datasets and normalizers: ...
# todo: monotonicity
# todo: test different parameters
def test_integration(dataloaders_with_coveratiates, tmp_path):
train_dataloader = dataloaders_with_coveratiates["train"]
@@ -28,7 +28,11 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
fast_dev_run=True,
logger=logger,
)

# test monotone constraints automatically
if "discount_in_percent" in dataloaders_with_coveratiates["train"].dataset.reals:
monotone_constaints = {"discount_in_percent": +1}
else:
monotone_constaints = {}
net = TemporalFusionTransformer.from_dataset(
train_dataloader.dataset,
learning_rate=0.15,
@@ -40,7 +44,7 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
log_interval=5,
log_val_interval=1,
log_gradient_flow=True,
monotone_constaints={"discount_in_percent": +1},
monotone_constaints=monotone_constaints,
)
net.size()
try:
@@ -57,6 +61,3 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
finally:
shutil.rmtree(tmp_path, ignore_errors=True)


def test_monotinicity():
pass

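Note: with the integration test now enabling monotone_constaints only when discount_in_percent is among the dataset's real-valued features, the empty test_monotinicity stub is removed. If a standalone monotonicity check were reinstated later, one generic, library-agnostic way to phrase it (model_fn, the toy model, and the tolerance below are assumptions for illustration, not pytorch_forecasting API) would be:

import numpy as np


def assert_monotone_increasing(model_fn, x, feature_idx, deltas=(0.0, 0.5, 1.0), tol=1e-6):
    # check that mean predictions do not decrease as one input feature increases
    means = []
    for delta in deltas:
        x_shifted = x.copy()
        x_shifted[:, feature_idx] += delta
        means.append(float(np.mean(model_fn(x_shifted))))
    assert all(a <= b + tol for a, b in zip(means, means[1:]))


# usage with a toy linear model that is monotone increasing in feature 1
assert_monotone_increasing(lambda x: x @ np.array([0.3, 1.0, 0.0]), np.random.rand(8, 3), feature_idx=1)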