Skip to content

Commit

Permalink
Merge pull request #30 from jdb78/feature/dependency
Browse files Browse the repository at this point in the history
Add calculation of partial dependencies
  • Loading branch information
jdb78 authored Sep 2, 2020
2 parents 41c5f3b + aa1ba56 commit 6843748
Show file tree
Hide file tree
Showing 9 changed files with 765 additions and 133 deletions.
543 changes: 468 additions & 75 deletions docs/source/tutorials/stallion.ipynb

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions pytorch_forecasting/data/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,14 +691,18 @@ def __len__(self) -> int:
"""
return self.index.shape[0]

def set_overwrite_values(self, values: Union[float, torch.Tensor], variable: str, target: str = "decoder") -> None:
def set_overwrite_values(
self, values: Union[float, torch.Tensor], variable: str, target: Union[str, slice] = "decoder"
) -> None:
"""
Convenience method to quickly overwrite values in decoder or encoder (or both) for a specific variable.
Args:
values (Union[float, torch.Tensor]): values to use for overwrite.
variable (str): variable whose values should be overwritten.
target (str, optional): positions to overwrite. One of "decoder", "encoder" or "all". Defaults to "decoder".
target (Union[str, slice], optional): positions to overwrite. One of "decoder", "encoder" or "all" or
a slice object which is directly used to overwrite indices, e.g. ``slice(-5, None)`` will overwrite
the last 5 values. Defaults to "decoder".
"""
values = torch.tensor(self.transform_values(variable, np.asarray(values).reshape(-1), inverse=False)).squeeze()
assert target in [
Expand All @@ -707,6 +711,9 @@ def set_overwrite_values(self, values: Union[float, torch.Tensor], variable: str
"encoder",
], f"target has be one of 'all', 'decoder' or 'encoder' but target={target} instead"

if variable in self.static_categoricals or variable in self.static_categoricals:
target = "all"

if variable == self.target:
raise NotImplementedError("Target variable is not supported")
if self.weight is not None and self.weight == variable:
Expand Down Expand Up @@ -856,7 +863,9 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:

# overwrite values
if self._overwrite_values is not None:
if self._overwrite_values["target"] == "all":
if isinstance(self._overwrite_values["target"], slice):
positions = self._overwrite_values["target"]
elif self._overwrite_values["target"] == "all":
positions = slice(None)
elif self._overwrite_values["target"] == "encoder":
positions = slice(None, encoder_length)
Expand Down
7 changes: 5 additions & 2 deletions pytorch_forecasting/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,19 @@ def forward(self, y_pred: Dict[str, torch.Tensor], target: Union[torch.Tensor, r
mask = torch.arange(target.size(1), device=target.device).unsqueeze(0) >= lengths.unsqueeze(-1)
if losses.ndim > 2:
mask = mask.unsqueeze(-1)
dim_normalizer = losses.size(-1)
else:
dim_normalizer = 1.0
# reduce to one number
if self.reduction == "none":
loss = losses.masked_fill(mask, float("nan"))
else:
if self.reduction == "mean":
losses = losses.masked_fill(mask, 0.0)
loss = losses.sum() / lengths.sum()
loss = losses.sum() / lengths.sum() / dim_normalizer
elif self.reduction == "sqrt-mean":
losses = losses.masked_fill(mask, 0.0)
loss = losses.sum() / lengths.sum()
loss = losses.sum() / lengths.sum() / dim_normalizer
loss = loss.sqrt()
assert not torch.isnan(loss), (
"Loss should not be nan - i.e. something went wrong "
Expand Down
200 changes: 173 additions & 27 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from copy import deepcopy
import inspect
from pytorch_forecasting.data.encoders import GroupNormalizer
from torch import unsqueeze
from torch import optim
import cloudpickle
Expand All @@ -11,7 +12,7 @@
from tqdm.notebook import tqdm

from pytorch_forecasting.metrics import SMAPE
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics.metric import TensorMetric
from pytorch_forecasting.optim import Ranger
Expand Down Expand Up @@ -50,13 +51,6 @@ def forward(self, x):
encoding_target = x["encoder_target"]
return dict(prediction=..., target_scale=x["target_scale"])
# implement lightning steps
def training_step(self, batch, batch_idx):
x, y = batch
return {"loss": self.loss(self(x), y)}
# implement further steps
"""

def __init__(
Expand Down Expand Up @@ -516,7 +510,7 @@ def predict(
batch_size: batch size for dataloader - only used if data is not a dataloader is passed
num_workers: number of workers for dataloader - only used if data is not a dataloader is passed
fast_dev_run: if to only return results of first batch
show_progress_bar: if to show progress bar. Defaults to True
show_progress_bar: if to show progress bar. Defaults to False.
return_x: if to return network inputs
Returns:
Expand Down Expand Up @@ -608,6 +602,118 @@ def predict(
output.append(torch.cat(decode_lenghts, dim=0))
return output

def predict_dependency(
self,
data: Union[DataLoader, pd.DataFrame, TimeSeriesDataSet],
variable: str,
values: Iterable,
mode: str = "dataframe",
target="decoder",
show_progress_bar: bool = False,
**kwargs,
) -> Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]:
"""
Predict partial dependency.
Args:
data (Union[DataLoader, pd.DataFrame, TimeSeriesDataSet]): data
variable (str): variable which to modify
values (Iterable): array of values to probe
mode (str, optional): Output mode. Defaults to "dataframe". Either
* "series": values are average prediction and index are probed values
* "dataframe": columns are as obtained by the `dataset.get_index()` method,
prediction (which is the mean prediction over the time horizon),
normalized_prediction (which are predictions devided by the prediction for the first probed value)
the variable name for the probed values
* "raw": outputs a tensor of shape len(values) x prediction_shape
target: Defines which values are overwritten for making a prediction.
Same as in :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.set_overwrite_values`.
Defaults to "decoder".
show_progress_bar: if to show progress bar. Defaults to False.
**kwargs: additional kwargs to :py:meth:`~predict` method
Returns:
Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]: output
"""
values = np.asarray(values)
if isinstance(data, pd.DataFrame): # convert to dataframe
data = TimeSeriesDataSet.from_parameters(self.dataset_parameters, data, predict=True)
elif isinstance(data, DataLoader):
data = data.dataset

results = []
progress_bar = tqdm(desc="Predict", unit=" batches", total=len(values), disable=not show_progress_bar)
for value in values:
# set values
data.set_overwrite_values(variable=variable, values=value, target=target)
# predict
kwargs.setdefault("mode", "prediction")
results.append(self.predict(data, **kwargs))
# increment progress
progress_bar.update()

data.reset_overwrite_values() # reset overwrite values to avoid side-effect

# results to one tensor
results = torch.stack(results, dim=0)

# convert results to requested output format
if mode == "series":
results = results[:, ~torch.isnan(results[0])].mean(1) # average samples and prediction horizon
results = pd.Series(results, index=values)

elif mode == "dataframe":
# take mean over time
is_nan = torch.isnan(results)
results[is_nan] = 0
results = results.sum(-1) / (~is_nan).float().sum(-1)

# create dataframe
dependencies = data.get_index()
dependencies = (
dependencies.iloc[np.tile(np.arange(len(dependencies)), len(values))]
.reset_index(drop=True)
.assign(prediction=results.flatten())
)
dependencies[variable] = values.repeat(len(data))
first_prediction = dependencies.groupby(data.group_ids, observed=True).prediction.transform("first")
dependencies["normalized_prediction"] = dependencies["prediction"] / first_prediction
dependencies["id"] = dependencies.groupby(data.group_ids, observed=True).ngroup()
results = dependencies

elif mode == "raw":
pass

else:
raise ValueError(f"mode {mode} is unknown - see documentation for available modes")

return results


class CovariatesMixin:
"""
Model mix-in for additional methods using covariates.
Assumes the following hyperparameters:
Args:
x_reals: order of continuous variables in tensor passed to forward function
x_categoricals: order of categorical variables in tensor passed to forward function
embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
embedding size
embedding_labels: dictionary mapping (string) indices to list of categorical labels
"""

@property
def categorical_groups_mapping(self) -> Dict[str, str]:
groups = {}
for group_name, sublist in self.hparams.categorical_groups.items():
groups.update({name: group_name for name in sublist})
return groups

def calculate_prediction_actual_by_variable(
self,
x: Dict[str, torch.Tensor],
Expand All @@ -621,13 +727,13 @@ def calculate_prediction_actual_by_variable(
Args:
x: input as ``forward()``
y_pred: predictions obtained by ``self.loss.to_prediction(self(x))``
y_pred: predictions obtained by ``self.transform_output(self(x))``
normalize: if to return normalized averages, i.e. mean or sum of ``y``
bins: number of bins to calculate
std: number of standard deviations for standard scaled continuous variables
Returns:
dictionary that can be used to plot averages with ``plot_prediction_actual_by_variable()``
dictionary that can be used to plot averages with :py:meth:`~plot_prediction_actual_by_variable`
"""
support = {} # histogram
# averages
Expand All @@ -640,7 +746,10 @@ def calculate_prediction_actual_by_variable(
# select valid y values
y_flat = x["decoder_target"][mask]
y_pred_flat = y_pred[mask]
if self.loss.log_space:
log_y = self.dataset_parameters["target_normalizer"] is not None and getattr(
self.dataset_parameters["target_normalizer"], "log_scale", False
)
if log_y:
y_flat = torch.log(y_flat + 1e-8)
y_pred_flat = torch.log(y_pred_flat + 1e-8)

Expand Down Expand Up @@ -675,28 +784,51 @@ def calculate_prediction_actual_by_variable(
# categorical_variables
cats = x["decoder_cat"]
for idx, name in enumerate(self.hparams.x_categoricals): # todo: make it work for grouped categoricals
averages_actual[name], support[name] = groupby_apply(
reduction = "sum"
name = self.categorical_groups_mapping.get(name, name)
averages_actual_cat, support_cat = groupby_apply(
cats[..., idx][mask],
y_flat,
bins=self.hparams.embedding_sizes[idx][0],
bins=self.hparams.embedding_sizes[name][0],
reduction=reduction,
return_histogram=True,
)
averages_prediction[name], _ = groupby_apply(
averages_prediction_cat, _ = groupby_apply(
cats[..., idx][mask],
y_pred_flat,
bins=self.hparams.embedding_sizes[idx][0],
bins=self.hparams.embedding_sizes[name][0],
reduction=reduction,
return_histogram=True,
)

# add either to existing calculations or
if name in averages_actual:
averages_actual[name] += averages_actual_cat
support[name] += support_cat
averages_prediction[name] += averages_prediction_cat
else:
averages_actual[name] = averages_actual_cat
support[name] = support_cat
averages_prediction[name] = averages_prediction_cat

if normalize: # run reduction for categoricals
for name in self.hparams.embedding_sizes.keys():
averages_actual[name] /= support[name].clamp(min=1)
averages_prediction[name] /= support[name].clamp(min=1)

if log_y: # reverse log scaling
for name in support.keys():
averages_actual[name] = torch.exp(averages_actual[name])
averages_prediction[name] = torch.exp(averages_prediction[name])

return {
"support": support,
"average": {"actual": averages_actual, "prediction": averages_prediction},
"std": std,
}

def plot_prediction_actual_by_variable(
self, data: Dict[str, Dict[str, torch.Tensor]], name: str = None
self, data: Dict[str, Dict[str, torch.Tensor]], name: str = None, ax=None
) -> Union[Dict[str, plt.Figure], plt.Figure]:
"""
Plot predicions and actual averages by variables
Expand All @@ -720,23 +852,29 @@ def plot_prediction_actual_by_variable(
# create figure
kwargs = {}
# adjust figure size for figures with many labels
if self.hparams.embedding_sizes[name][0] > 10:
if self.hparams.embedding_sizes.get(name, [1e9])[0] > 10:
kwargs = dict(figsize=(10, 5))
fig, ax = plt.subplots(**kwargs)
if ax is None:
fig, ax = plt.subplots(**kwargs)
else:
fig = ax.get_figure()
ax.set_title(f"{name} averages")
ax.set_xlabel(name)
if self.loss.log_space:
ax.set_ylabel("Log prediction")
else:
ax.set_ylabel("Prediction")
ax.set_ylabel("Prediction")

ax2 = ax.twinx() # second axis for histogram
ax2.set_ylabel("Frequency")

# get values for average plot and histogram
values_actual = data["average"]["actual"][name].cpu().numpy()
values_prediction = data["average"]["prediction"][name].cpu().numpy()
bins = values_actual.size
support = data["average"][name].cpu().numpy()
support = data["support"][name].cpu().numpy()

if self.dataset_parameters["target_normalizer"] is not None and getattr(
self.dataset_parameters["target_normalizer"], "log_scale", False
):
ax.set_yscale("log")

# only display values where samples were observed
support_non_zero = support > 0
Expand All @@ -746,8 +884,14 @@ def plot_prediction_actual_by_variable(

# plot averages
if name in self.hparams.x_reals:
mean, scale = self.dataset_parameters.scalers[name].mean, self.dataset_parameters.scalers[name].scale
x = np.linspace(-data["std"], data["std"], bins) * scale + mean
# create x
scaler = self.dataset_parameters["scalers"][name]
x = np.linspace(-data["std"], data["std"], bins)
# reversing normalization for group normalizer is not possible without sample level information
if not isinstance(scaler, GroupNormalizer):
x = scaler.inverse_transform(x)
ax.set_xlabel(f"Normalized {name}")

if len(x) > 0:
x_step = x[1] - x[0]
else:
Expand All @@ -759,7 +903,7 @@ def plot_prediction_actual_by_variable(
elif name in self.hparams.embedding_labels:
# sort values from lowest to highest
sorting = values_actual.argsort()
labels = np.asarray(self.hparams.embedding_labels[name])[support_non_zero][sorting]
labels = np.asarray(list(self.hparams.embedding_labels[name].keys()))[support_non_zero][sorting]
values_actual = values_actual[sorting]
values_prediction = values_prediction[sorting]
support = support[sorting]
Expand All @@ -783,6 +927,8 @@ def plot_prediction_actual_by_variable(
else:
raise ValueError(f"Unknown name {name}")
# plot support histogram
if len(support) > 1 and np.median(support) < support.max() / 10:
ax2.set_yscale("log")
ax2.bar(x, support, width=x_step, linewidth=0, alpha=0.2, color="k")
# adjust layout and legend
fig.tight_layout()
Expand Down
Loading

0 comments on commit 6843748

Please sign in to comment.