Skip to content

Commit

Permalink
Fix formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Beitner committed Aug 16, 2020
1 parent bfe6a5b commit 9164982
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
13 changes: 6 additions & 7 deletions pytorch_forecasting/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,21 +1193,20 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
Will shuffle and drop last batch if True. Defaults to True.
batch_size (int): batch size for training model. Defaults to 64.
**kwargs: additional arguments to ``DataLoader()``
Examples:
To samples for training:
.. code-block:: python
from torch.utils.data import WeightedRandomSampler
# length of probabilties for sampler have to be equal to the length of the index
probabilities = np.sqrt(1 + data.loc[dataset.index, "target"])
sampler = WeightedRandomSampler(probabilities, len(probabilities))
dataset.to_dataloader(train=True, sampler=sampler, shuffle=False)
Returns:
DataLoader: dataloader that returns Tuple.
Expand Down
1 change: 0 additions & 1 deletion tests/test_models/test_temporal_fusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,3 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
finally:
shutil.rmtree(tmp_path, ignore_errors=True)

0 comments on commit 9164982

Please sign in to comment.