diff --git a/pytorch_forecasting/data.py b/pytorch_forecasting/data.py index 0d9034cd..3f1012d8 100644 --- a/pytorch_forecasting/data.py +++ b/pytorch_forecasting/data.py @@ -1193,21 +1193,20 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D Will shuffle and drop last batch if True. Defaults to True. batch_size (int): batch size for training model. Defaults to 64. **kwargs: additional arguments to ``DataLoader()`` - - + + Examples: - + To samples for training: - + .. code-block:: python - + from torch.utils.data import WeightedRandomSampler - + # length of probabilties for sampler have to be equal to the length of the index probabilities = np.sqrt(1 + data.loc[dataset.index, "target"]) sampler = WeightedRandomSampler(probabilities, len(probabilities)) dataset.to_dataloader(train=True, sampler=sampler, shuffle=False) - Returns: DataLoader: dataloader that returns Tuple. diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index a2e6064a..a0442708 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -60,4 +60,3 @@ def test_integration(dataloaders_with_coveratiates, tmp_path): net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) finally: shutil.rmtree(tmp_path, ignore_errors=True) -