Improve docs for to_dataloader
Jan Beitner committed Aug 16, 2020
1 parent 855e1a5 commit 0be859b
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions pytorch_forecasting/data.py
@@ -1191,6 +1191,23 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
Args:
train (bool, optional): if dataloader is used for training or prediction
Will shuffle and drop last batch if True. Defaults to True.
batch_size (int): batch size for training model. Defaults to 64.
**kwargs: additional arguments to ``DataLoader()``
+ Examples:
+ To weight samples for training:
+ .. code-block:: python
+ from torch.utils.data import WeightedRandomSampler
+ # length of probabilities for sampler has to be equal to the length of the index
+ probabilities = np.sqrt(1 + data.loc[dataset.index, "target"])
+ sampler = WeightedRandomSampler(probabilities, len(probabilities))
+ dataset.to_dataloader(train=True, sampler=sampler, shuffle=False)
Returns:
DataLoader: dataloader that returns Tuple.
@@ -1208,15 +1225,16 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
Second entry is target
)
"""
- return DataLoader(
- self,
+ default_kwargs = dict(
  shuffle=train,
  drop_last=train and len(self) > batch_size,
  collate_fn=self._collate_fn,
  batch_size=batch_size,
- **kwargs,
  )

+ default_kwargs.update(kwargs)
+ return DataLoader(self, **default_kwargs,)

def get_index(self) -> pd.DataFrame:
"""
Data index / order in which items are returned in train=False mode by dataloader.
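
The reason for building `default_kwargs` first and then calling `update(kwargs)` can be illustrated with a minimal sketch. It avoids PyTorch entirely: `make_loader` is a hypothetical stand-in for `DataLoader` that only echoes the options it receives, the two `to_dataloader_*` functions are simplified free-function versions of the method before and after this commit, and `collate_fn` is omitted for brevity. Under those assumptions, the old pattern fails as soon as a caller passes `shuffle` explicitly, while the new pattern lets the caller's value win.

```python
def make_loader(dataset, **options):
    """Hypothetical stand-in for DataLoader: returns the options it was given."""
    return options


def to_dataloader_old(dataset, train=True, batch_size=64, **kwargs):
    # Pre-commit pattern: defaults and **kwargs are passed in the same call,
    # so a caller-supplied `shuffle` collides with the default `shuffle`.
    return make_loader(
        dataset,
        shuffle=train,
        drop_last=train and len(dataset) > batch_size,
        batch_size=batch_size,
        **kwargs,
    )


def to_dataloader_new(dataset, train=True, batch_size=64, **kwargs):
    # Post-commit pattern: collect the defaults first, then let the caller's
    # keyword arguments overwrite them before the loader is constructed.
    default_kwargs = dict(
        shuffle=train,
        drop_last=train and len(dataset) > batch_size,
        batch_size=batch_size,
    )
    default_kwargs.update(kwargs)
    return make_loader(dataset, **default_kwargs)


data = list(range(100))

# Old pattern: Python rejects the duplicated keyword argument.
try:
    to_dataloader_old(data, train=True, shuffle=False)
    old_error = None
except TypeError as exc:
    old_error = exc

# New pattern: the explicit shuffle=False replaces the default shuffle=True,
# and extra options such as a sampler are passed through unchanged.
new_options = to_dataloader_new(
    data, train=True, shuffle=False, sampler="placeholder-sampler"
)
```

This matters for the docstring example in this commit: PyTorch's `DataLoader` treats `sampler` and `shuffle=True` as mutually exclusive, so passing a `WeightedRandomSampler` with `train=True` only works if the caller can override the default `shuffle`.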