diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml index 97121570..cf41635d 100644 --- a/.github/workflows/code_quality.yml +++ b/.github/workflows/code_quality.yml @@ -30,4 +30,4 @@ jobs: # Enable linters black: true flake8: true - mypy: true + # mypy: true diff --git a/pytorch_forecasting/data.py b/pytorch_forecasting/data.py index 95b58574..3f1012d8 100644 --- a/pytorch_forecasting/data.py +++ b/pytorch_forecasting/data.py @@ -1191,6 +1191,22 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D Args: train (bool, optional): if dataloader is used for training or prediction Will shuffle and drop last batch if True. Defaults to True. + batch_size (int): batch size for training model. Defaults to 64. + **kwargs: additional arguments to ``DataLoader()`` + + + Examples: + + To weight samples for training: + + .. code-block:: python + + from torch.utils.data import WeightedRandomSampler + + # length of probabilities for sampler has to be equal to the length of the index + probabilities = np.sqrt(1 + data.loc[dataset.index, "target"]) + sampler = WeightedRandomSampler(probabilities, len(probabilities)) + dataset.to_dataloader(train=True, sampler=sampler, shuffle=False) Returns: DataLoader: dataloader that returns Tuple. @@ -1208,15 +1224,16 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D Second entry is target ) """ - return DataLoader( - self, + default_kwargs = dict( shuffle=train, drop_last=train and len(self) > batch_size, collate_fn=self._collate_fn, batch_size=batch_size, - **kwargs, ) + default_kwargs.update(kwargs) + return DataLoader(self, **default_kwargs,) + def get_index(self) -> pd.DataFrame: """ Data index / order in which items are returned in train=False mode by dataloader. 
diff --git a/tests/test_models/conftest.py b/tests/test_models/conftest.py index 39c4771a..863f833a 100644 --- a/tests/test_models/conftest.py +++ b/tests/test_models/conftest.py @@ -2,7 +2,7 @@ import numpy as np from data import get_stallion_data, generate_ar_data from pytorch_forecasting import TimeSeriesDataSet -from pytorch_forecasting.data import NaNLabelEncoder, EncoderNormalizer +from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder, EncoderNormalizer @pytest.fixture @@ -35,8 +35,42 @@ def data_with_covariates(): return data -@pytest.fixture -def dataloaders_with_coveratiates(data_with_covariates): +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"], + constant_fill_strategy={"volume": 0}, + dropout_categoricals=["sku"], + ), + dict(static_categoricals=["agency", "sku"]), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(log_scale=True)), + dict(target_normalizer=GroupNormalizer(groups=["agency", "sku"], coerce_positive=1.0)), + ] +) +def dataloaders_with_coveratiates(data_with_covariates, request): training_cutoff = "2016-09-01" max_encoder_length = 36 max_prediction_length = 6 @@ -49,30 +83,7 @@ def dataloaders_with_coveratiates(data_with_covariates): group_ids=["agency", "sku"], 
max_encoder_length=max_encoder_length, max_prediction_length=max_prediction_length, - static_categoricals=["agency", "sku"], - static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], - time_varying_known_categoricals=["special_days", "month"], - variable_groups=dict( - special_days=[ - "easter_day", - "good_friday", - "new_year", - "christmas", - "labor_day", - "independence_day", - "revolution_day_memorial", - "regional_games", - "fifa_u_17_world_cup", - "football_gold_cup", - "beer_capital", - "music_fest", - ] - ), - time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"], - time_varying_unknown_categoricals=[], - time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"], - constant_fill_strategy={"volume": 0}, - dropout_categoricals=["sku"], + **request.param # fixture parametrization ) validation = TimeSeriesDataSet.from_dataset( @@ -85,7 +96,7 @@ def dataloaders_with_coveratiates(data_with_covariates): return dict(train=train_dataloader, val=val_dataloader) -@pytest.fixture +@pytest.fixture() def dataloaders_fixed_window_without_coveratiates(data_with_covariates): data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=10) data["static"] = "2" diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 6d110f23..a0442708 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -1,3 +1,5 @@ +from pytorch_forecasting.data import TimeSeriesDataSet +import pytest import shutil import pytorch_lightning as pl from pytorch_lightning.loggers import TensorBoardLogger @@ -7,8 +9,6 @@ # todo: run with multiple normalizers -# todo: run with muliple datasets and normalizers: ... 
-# todo: monotonicity # todo: test different parameters def test_integration(dataloaders_with_coveratiates, tmp_path): train_dataloader = dataloaders_with_coveratiates["train"] @@ -28,7 +28,11 @@ def test_integration(dataloaders_with_coveratiates, tmp_path): fast_dev_run=True, logger=logger, ) - + # test monotone constraints automatically + if "discount_in_percent" in dataloaders_with_coveratiates["train"].dataset.reals: + monotone_constaints = {"discount_in_percent": +1} + else: + monotone_constaints = {} net = TemporalFusionTransformer.from_dataset( train_dataloader.dataset, learning_rate=0.15, @@ -40,7 +44,7 @@ def test_integration(dataloaders_with_coveratiates, tmp_path): log_interval=5, log_val_interval=1, log_gradient_flow=True, - monotone_constaints={"discount_in_percent": +1}, + monotone_constaints=monotone_constaints, ) net.size() try: @@ -56,7 +60,3 @@ def test_integration(dataloaders_with_coveratiates, tmp_path): net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) finally: shutil.rmtree(tmp_path, ignore_errors=True) - - -def test_monotinicity(): - pass