docstrings
fkiraly committed Jan 4, 2025
1 parent 159e87b commit 99440c8
Showing 1 changed file with 118 additions and 70 deletions.
188 changes: 118 additions & 70 deletions pytorch_forecasting/data/timeseries.py
@@ -226,59 +226,93 @@ class TimeSeriesDataSet(Dataset):
entries can be also lists which are then encoded together
(e.g. useful for product categories)
static_reals (List[str]): list of continuous variables that do not change over time
time_varying_known_categoricals (List[str]): list of categorical variables that change over
time and are known in the future, entries can be also lists which are then encoded together
static_reals : list of str, optional, default=None
list of continuous variables that do not change over time
time_varying_known_categoricals : list of str, optional, default=None
list of categorical variables that change over time and are known in the future,
entries can be also lists which are then encoded together
(e.g. useful for special days or promotion categories)
time_varying_known_reals (List[str]): list of continuous variables that change over
time and are known in the future (e.g. price of a product, but not demand of a product)
time_varying_unknown_categoricals (List[str]): list of categorical variables that change over
time and are not known in the future, entries can be also lists which are then encoded together
(e.g. useful for weather categories). You might want to include your target here.
time_varying_unknown_reals (List[str]): list of continuous variables that change over
time and are not known in the future. You might want to include your target here.
variable_groups (Dict[str, List[str]]): dictionary mapping a name to a list of columns in the data.
time_varying_known_reals : list of str, optional, default=None
list of continuous variables that change over time and are known in the future
(e.g. price of a product, but not demand of a product)
time_varying_unknown_categoricals : list of str, optional, default=None
list of categorical variables that are not known in the future
and change over time.
entries can be also lists which are then encoded together
(e.g. useful for weather categories).
Target variables should be included here, if categorical.
time_varying_unknown_reals : list of str, optional, default=None
list of continuous variables that are not known in the future
and change over time.
Target variables should be included here, if real.
variable_groups : Dict[str, List[str]], optional, default=None
dictionary mapping a name to a list of columns in the data.
The name should be present
in a categorical or real class argument, to be able to encode or scale the columns by group.
This effectively combines categorical variables and is particularly useful if a categorical variable can
have multiple values at the same time. An example is holidays, which can overlap.
constant_fill_strategy (Dict[str, Union[str, float, int, bool]]): dictionary of column names with
constants to fill in missing values if there are
gaps in the sequence (by default forward fill strategy is used). The values will be only used if
``allow_missing_timesteps=True``. A common use case is to denote that demand was 0 if the sample
is not in the dataset.
allow_missing_timesteps (bool): if to allow missing timesteps that are automatically filled up. Missing
values
refer to gaps in the ``time_idx``, e.g. if a specific timeseries has only samples for
1, 2, 4, 5, the sample for 3 will be generated on-the-fly.
Allow missings does not deal with ``NA`` values. You should fill NA values before
passing the dataframe to the TimeSeriesDataSet.
lags (Dict[str, List[int]]): dictionary of variable names mapped to list of time steps by
which the variable should be lagged.
Lags can be useful to indicate seasonality to the models. If you know the seasonalit(ies) of your data,
add at least the target variables with the corresponding lags to improve performance.
Lags must not be larger than the shortest time series as all time series will be cut by the largest
lag value to prevent NA values. A lagged variable has to appear in the time-varying variables. If you
only want the lagged but not the current value, lag it manually in your input data using
``data[lagged_variable_name] = data.sort_values(time_idx).groupby(group_ids, observed=True).shift(lag)``
.
Defaults to no lags.
add_relative_time_idx (bool): if to add a relative time index as feature (i.e. for each sampled sequence,
the index will range from -encoder_length to prediction_length)
add_target_scales (bool): if to add scales for target to static real features (i.e. add the center and scale
of the unnormalized timeseries as features)
add_encoder_length (bool): if to add encoder length to list of static real variables.
Defaults to "auto", i.e. ``True`` if ``min_encoder_length != max_encoder_length``.
target_normalizer (Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer, str, list, tuple]):
transformer that take group_ids, target and time_idx to normalize targets.
in a categorical or real class argument, to be able to encode or scale the
columns by group.
This effectively combines categorical variables and is particularly useful
if a categorical variable can have multiple values at the same time.
An example is holidays, which can overlap.
constant_fill_strategy : dict, optional, default=None
Keys must be str, values can be str, float, int or bool.
Dictionary of column names with constants to fill in missing values if there
are gaps in the sequence (by default forward fill strategy is used).
The values will be only used if ``allow_missing_timesteps=True``.
A common use case is to denote that demand was 0 if the sample is not in the
dataset.
allow_missing_timesteps : bool, optional, default=False
whether to allow missing timesteps that are automatically filled up.
Missing values refer to gaps in the ``time_idx``, e.g. if a specific
timeseries has only samples for 1, 2, 4, 5, the sample for 3 will be
generated on-the-fly.
Allowing missing timesteps does not handle ``NA`` values. You should fill NA
values before passing the dataframe to the TimeSeriesDataSet.
lags : Dict[str, List[int]], optional, default=None
dictionary of variable names mapped to list of time steps by which the
variable should be lagged.
Lags can be useful to indicate seasonality to the models.
Useful if the seasonality (or seasonalities) of the data is known.
In this case, it is recommended to add the target variables
with the corresponding lags to improve performance.
Lags must not be larger than the shortest time series, as all time series
will be cut by the largest lag value to prevent NA values.
A lagged variable has to appear in the time-varying variables.
If you only want the lagged but not the current value, lag it manually in
your input data using
``data[lagged_varname] = ``
``data.sort_values(time_idx).groupby(group_ids, observed=True).shift(lag)``.
add_relative_time_idx : bool, optional, default=False
whether to add a relative time index as feature, i.e.,
for each sampled sequence, the index will range from -encoder_length to
prediction_length.
add_target_scales : bool, optional, default=False
whether to add scales for target to static real features, i.e., add the
center and scale of the unnormalized timeseries as features.
add_encoder_length : Union[bool, str], optional, default="auto"
whether to add encoder length to list of static real variables.
Defaults to "auto", which is the same as
``True`` iff ``min_encoder_length != max_encoder_length``.
target_normalizer : torch transformer, str, list, tuple, optional, default="auto"
Transformer that takes group_ids, target and time_idx to normalize targets.
You can choose from
:py:class:`~pytorch_forecasting.data.encoders.TorchNormalizer`,
:py:class:`~pytorch_forecasting.data.encoders.GroupNormalizer`,
:py:class:`~pytorch_forecasting.data.encoders.NaNLabelEncoder`,
:py:class:`~pytorch_forecasting.data.encoders.EncoderNormalizer`
(on which overfitting tests will fail)
or `None` for using no normalizer. For multiple targets, use a
or ``None`` for using no normalizer. For multiple targets, use a
:py:class:`~pytorch_forecasting.data.encoders.MultiNormalizer`.
By default an appropriate normalizer is chosen automatically.
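The ``lags`` docstring above suggests lagging a variable manually when only the lagged value is wanted. A minimal sketch of that idea with pandas follows. The helper name ``add_manual_lag``, the column names and the toy data are assumptions for illustration, not part of the library. Unlike the one-line hint in the docstring, the sketch selects the single column before calling ``shift``, and it assumes there are no gaps in the time index, since ``shift`` moves by rows rather than by time steps.

```python
import pandas as pd


def add_manual_lag(data, column, lag, time_idx="time_idx", group_ids=("series",)):
    """Return a sorted copy of ``data`` with ``column`` lagged by ``lag`` rows per group.

    Hypothetical helper for illustration; not part of pytorch_forecasting.
    """
    group_cols = list(group_ids)
    # sort so that rows within each group are in time order before shifting
    out = data.sort_values(group_cols + [time_idx]).copy()
    lagged_name = f"{column}_lagged_by_{lag}"
    # shift only the selected column, separately within each group
    out[lagged_name] = out.groupby(group_cols, observed=True)[column].shift(lag)
    return out


# toy data: two series with three time steps each
data = pd.DataFrame(
    {
        "series": ["a", "a", "a", "b", "b", "b"],
        "time_idx": [0, 1, 2, 0, 1, 2],
        "demand": [10.0, 11.0, 12.0, 20.0, 21.0, 22.0],
    }
)
lagged = add_manual_lag(data, "demand", lag=1)
```

The first ``lag`` rows of every group have no lagged value and come out as ``NA``, so they need to be dropped or filled before the frame is passed on.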
@@ -984,9 +1018,11 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
), "__time_idx__ is a protected column and must not be present in data"
data["__time_idx__"] = data[self.time_idx] # save unscaled
for target in self.target_names:
assert (
f"__target__{target}" not in data.columns
), f"__target__{target} is a protected column and must not be present in data"
msg = (
f"__target__{target} is a protected column "
"and must not be present in data"
)
assert f"__target__{target}" not in data.columns, msg
data[f"__target__{target}"] = data[target]
if self.weight is not None:
data["__weight__"] = data[self.weight]
@@ -1042,12 +1078,14 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
data[target] = transformed[idx]

if isinstance(self.target_normalizer[idx], NaNLabelEncoder):
# overwrite target because it requires encoding (continuous targets should not be normalized)
# overwrite target because it requires encoding
# (continuous targets should not be normalized)
data[f"__target__{target}"] = data[target]

elif isinstance(self.target_normalizer, NaNLabelEncoder):
data[self.target] = self.target_normalizer.transform(data[self.target])
# overwrite target because it requires encoding (continuous targets should not be normalized)
# overwrite target because it requires encoding
# (continuous targets should not be normalized)
data[f"__target__{self.target}"] = data[self.target]
scales = None

@@ -1124,24 +1162,29 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
name, np.array([value]), data=data, inverse=False
)[0]

# shorten data by maximum of lagged sequences to avoid NA values - shorten only after encoding
# shorten data by maximum of lagged sequences to avoid NA values -
# shorten only after encoding
if self.max_lag > 0:
# negative tail implementation as .groupby().tail(-self.max_lag) is not implemented in pandas
# negative tail implementation as .groupby().tail(-self.max_lag)
# is not implemented in pandas
g = data.groupby(self._group_ids, observed=True)
data = g._selected_obj[g.cumcount() >= self.max_lag]
return data
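The comment above describes a "negative tail": dropping the first ``max_lag`` rows of every group. A small sketch of the same idea using only public pandas API is shown here. The function name and toy data are assumptions for illustration, and the sketch indexes the original frame directly rather than going through the private ``_selected_obj`` attribute used in the diff.

```python
import pandas as pd


def drop_first_n_per_group(data, group_ids, n):
    """Drop the first ``n`` rows of each group, keeping the original row order.

    Hypothetical helper for illustration; assumes ``data`` is already sorted
    by time within each group.
    """
    g = data.groupby(group_ids, observed=True)
    # cumcount numbers the rows of each group 0, 1, 2, ... in order of appearance
    return data[g.cumcount() >= n]


data = pd.DataFrame(
    {
        "series": ["a", "a", "a", "b", "b"],
        "time_idx": [0, 1, 2, 0, 1],
    }
)
trimmed = drop_first_n_per_group(data, ["series"], n=1)
```

Groups with ``n`` or fewer rows disappear entirely, which is why the lags must not exceed the length of the shortest series.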

def get_transformer(self, name: str, group_id: bool = False):
"""
Get transformer for variable.
"""Get transformer for variable.
Args:
name (str): variable name
group_id (bool, optional): If the passed name refers to a group id (different encoders are used for these).
Defaults to False.
Parameters
----------
name : str
variable name
group_id : bool, optional, default=False
Whether the passed name refers to a group id;
different encoders are used for these.
Returns:
transformer
Returns
-------
transformer
"""
if group_id:
name = self._group_ids_mapping[name]
@@ -1188,7 +1231,7 @@ def transform_values(
values : Union[pd.Series, torch.Tensor, np.ndarray]
values to encode/scale
data : pd.DataFrame, optional, default=None
extra data used for scaling (e.g. dataframe with groups columns), by default None
extra data used for scaling (e.g. dataframe with groups columns)
inverse : bool, optional, default=False
whether transform is inverse (True), or plain (False)
group_id : bool, optional, default=False
@@ -1232,15 +1275,18 @@ def transform_values(
return values

def _data_to_tensors(self, data: pd.DataFrame) -> Dict[str, torch.Tensor]:
"""
Convert data to tensors for faster access with :py:meth:`~__getitem__`.
"""Convert data to tensors for faster access with :py:meth:`~__getitem__`.
Args:
data (pd.DataFrame): preprocessed data
Parameters
----------
data : pd.DataFrame
preprocessed data
Returns:
Dict[str, torch.Tensor]: dictionary of tensors for continuous, categorical data, groups, target and
time index
Returns
-------
Dict[str, torch.Tensor]
dictionary of tensors for continuous, categorical data, groups, target and
time index
"""

index = check_for_nonfinite(
@@ -1366,8 +1412,10 @@ def variable_to_group_mapping(self) -> Dict[str, str]:
"""
Mapping from categorical variables to variables in input data.
Returns:
Dict[str, str]: dictionary mapping from :py:meth:`~categorical` to :py:meth:`~flat_categoricals`.
Returns
-------
Dict[str, str]
dictionary, maps :py:meth:`~categorical` to :py:meth:`~flat_categoricals`.
"""
groups = {}
for group_name, sublist in self._variable_groups.items():
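The hunk is cut off inside the loop, so the rest of the method is not visible here. Going only by the docstring and the method name, one plausible reading is that a dictionary of the ``variable_groups`` shape (group name to list of columns) is inverted into a column-to-group mapping. The sketch below illustrates that reading; the function name and the example groups are assumptions, not the library's code.

```python
from typing import Dict, List


def invert_variable_groups(variable_groups: Dict[str, List[str]]) -> Dict[str, str]:
    """Map every column listed in a group back to the name of its group.

    Hypothetical helper for illustration; not taken from pytorch_forecasting.
    """
    mapping: Dict[str, str] = {}
    for group_name, columns in variable_groups.items():
        for column in columns:
            # if a column appears in several groups, the last group wins
            mapping[column] = group_name
    return mapping


mapping = invert_variable_groups({"holidays": ["easter", "christmas"]})
```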