Commit
Merge pull request #902 from jdb78/feature/fix-attention-concat
Fix attention concat
jdb78 authored Mar 23, 2022
2 parents eea8c16 + d810313 commit af9e1d3
Showing 6 changed files with 190 additions and 51 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,13 +1,14 @@
 # Release Notes

-## v0.10.0 UNRELEASED
+## v0.10.0 Adding N-HiTS network (N-BEATS successor) (23/03/2022)

 ### Added

 - Added new `N-HiTS` network that has consistently beaten `N-BEATS` (#890)
 - Allow using [torchmetrics](https://torchmetrics.readthedocs.io/) as loss metrics (#776) (see the sketch after this list)
 - Enable fitting `EncoderNormalizer()` with limited data history using `max_length` argument (#782)
 - More flexible `MultiEmbedding()` with convenience `output_size` and `input_size` properties (#829)
+- Fix concatenation of attention (#902)
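
A hedged usage sketch for the torchmetrics entry above (not part of the commit; assumes an existing `TimeSeriesDataSet` named `dataset`, and any torchmetrics regression metric should work the same way):

    from torchmetrics import MeanSquaredError

    from pytorch_forecasting import TemporalFusionTransformer

    # hypothetical setup: `dataset` is a pre-built TimeSeriesDataSet
    model = TemporalFusionTransformer.from_dataset(dataset, loss=MeanSquaredError())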

### Fixed

6 changes: 3 additions & 3 deletions pytorch_forecasting/metrics.py
@@ -745,9 +745,9 @@ def to_quantiles(self, out: Dict[str, torch.Tensor], quantiles=None):
        else:
            quantiles = self.quantiles
        predictions = super().to_prediction(out)
-        return torch.stack([torch.tensor(scipy.stats.poisson(predictions.cpu()).ppf(q)) for q in quantiles], dim=-1).to(
-            predictions.device
-        )
+        return torch.stack(
+            [torch.tensor(scipy.stats.poisson(predictions.detach().numpy()).ppf(q)) for q in quantiles], dim=-1
+        ).to(predictions.device)
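
As a standalone illustration of the rewritten return statement (a minimal sketch with made-up rate predictions, not part of the commit; scipy operates on numpy arrays, hence `.detach().numpy()`):

    import scipy.stats
    import torch

    predictions = torch.tensor([3.0, 7.0])  # hypothetical Poisson rates, one per sample
    quantiles = [0.1, 0.5, 0.9]
    out = torch.stack(
        [torch.tensor(scipy.stats.poisson(predictions.detach().numpy()).ppf(q)) for q in quantiles], dim=-1
    ).to(predictions.device)
    # out has shape (2, 3): one row per prediction, one column per quantile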


class QuantileLoss(MultiHorizonMetric):
20 changes: 19 additions & 1 deletion pytorch_forecasting/models/base_model.py
@@ -6,6 +6,7 @@
from copy import deepcopy
import inspect
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+import warnings

import matplotlib.pyplot as plt
import numpy as np
@@ -75,7 +76,24 @@ def _torch_cat_na(x: List[torch.Tensor]) -> torch.Tensor:
            )
            for xi in x
        ]
-    return torch.cat(x, dim=0)
+
+    # check if remaining dimensions are all equal
+    if x[0].ndim > 2:
+        remaining_dimensions_equal = all([all([xi.size(i) == x[0].size(i) for xi in x]) for i in range(2, x[0].ndim)])
+    else:
+        remaining_dimensions_equal = True
+
+    # concatenate into a single tensor if possible
+    if remaining_dimensions_equal:
+        return torch.cat(x, dim=0)
+    else:
+        # otherwise fall back to a flat list of samples, but warn
+        warnings.warn(
+            f"Not all dimensions are equal for tensor shapes. Example tensor {x[0].shape}. "
+            "Returning list instead of torch.Tensor.",
+            UserWarning,
+        )
+        return [xii for xi in x for xii in xi]


def _concatenate_output(
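The new fallback in `_torch_cat_na` can be illustrated with a small sketch (hypothetical tensors, not part of the commit): when trailing dimensions disagree across batches, concatenation along the batch dimension is impossible, so the batches are flattened into a list of per-sample tensors instead.

    import torch

    a = torch.zeros(2, 4, 3)  # batch of 2 samples, attention span 3
    b = torch.zeros(2, 4, 5)  # batch of 2 samples, attention span 5
    # torch.cat([a, b], dim=0) would raise an error here, so the fallback is
    flat = [sample for batch in [a, b] for sample in batch]
    # flat is a list of 4 tensors with shapes (4, 3), (4, 3), (4, 5), (4, 5)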
129 changes: 83 additions & 46 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -23,7 +23,7 @@
    InterpretableMultiHeadAttention,
    VariableSelectionNetwork,
)
-from pytorch_forecasting.utils import autocorrelation, create_mask, detach, integer_histogram, padded_stack, to_list
+from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list


class TemporalFusionTransformer(BaseModelWithCovariates):
@@ -501,7 +501,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        return self.to_network_output(
            prediction=self.transform_output(output, target_scale=x["target_scale"]),
-            attention=attn_output_weights,
+            encoder_attention=attn_output_weights[..., :max_encoder_length],
+            decoder_attention=attn_output_weights[..., max_encoder_length:],
            static_variables=static_variable_selection,
            encoder_variables=encoder_sparse_weights,
            decoder_variables=decoder_sparse_weights,
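Shape-wise, the new output split can be sketched as follows (hypothetical sizes, not part of the commit; the last axis indexes attended-to timesteps, encoder positions first, then decoder positions):

    import torch

    # hypothetical: batch=2, decoder steps=3, heads=1, max_encoder_length=5
    attn_output_weights = torch.rand(2, 3, 1, 5 + 3)
    encoder_attention = attn_output_weights[..., :5]  # shape (2, 3, 1, 5)
    decoder_attention = attn_output_weights[..., 5:]  # shape (2, 3, 1, 3)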
@@ -540,7 +541,6 @@ def interpret_output(
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
-        attention_as_autocorrelation: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        interpret output of model
@@ -550,12 +550,77 @@
reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
normalizing by encode lengths
attention_prediction_horizon: which prediction horizon to use for attention
attention_as_autocorrelation: if to record attention as autocorrelation - this should be set to true in
case of ``reduction != "none"`` and differing prediction times of the samples. Defaults to False
Returns:
interpretations that can be plotted with ``plot_interpretation()``
"""
# take attention and concatenate if a list to proper attention object
if isinstance(out["decoder_attention"], (list, tuple)):
batch_size = len(out["decoder_attention"])
# start with decoder attention
# assume issue is in last dimension, we need to find max
max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
first_elm = out["decoder_attention"][0]
# create new attention tensor into which we will scatter
decoder_attention = torch.full(
(batch_size, *first_elm.shape[:-1], max_last_dimension),
float("nan"),
dtype=first_elm.dtype,
device=first_elm.device,
)
# scatter into tensor
for idx, x in enumerate(out["decoder_attention"]):
decoder_length = out["decoder_lengths"][idx]
decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]

# same game for encoder attention
# create new attention tensor into which we will scatter
first_elm = out["encoder_attention"][0]
encoder_attention = torch.full(
(batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
float("nan"),
dtype=first_elm.dtype,
device=first_elm.device,
)
# scatter into tensor
for idx, x in enumerate(out["encoder_attention"]):
encoder_length = out["encoder_lengths"][idx]
encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[
..., :encoder_length
]
else:
decoder_attention = out["decoder_attention"]
decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")
# roll encoder attention (so start last encoder value is on the right)
encoder_attention = out["encoder_attention"]
shifts = encoder_attention.size(3) - out["encoder_lengths"]
new_index = (
torch.arange(encoder_attention.size(3))[None, None, None].expand_as(encoder_attention)
- shifts[:, None, None, None]
) % encoder_attention.size(3)
encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
# expand encoder_attentiont to full size
if encoder_attention.size(-1) < self.hparams.max_encoder_length:
encoder_attention = torch.concat(
[
torch.full(
(
*encoder_attention.shape[:-1],
self.hparams.max_encoder_length - out["encoder_lengths"].max(),
),
float("nan"),
dtype=encoder_attention.dtype,
device=encoder_attention.device,
),
encoder_attention,
],
dim=-1,
)

# combine attention vector
attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
attention[attention < 1e-5] = float("nan")

# histogram of decode and encode lengths
encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
@@ -582,53 +647,25 @@
static_variables = out["static_variables"].squeeze(1)
# attention is batch x time x heads x time_to_attend
# average over heads + only keep prediction attention and attention on observed timesteps
attention = out["attention"][
:, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
].mean(1)
attention = masked_op(
attention[
:, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon
],
op="mean",
dim=1,
)

if reduction != "none": # if to average over batches
static_variables = static_variables.sum(dim=0)
encoder_variables = encoder_variables.sum(dim=0)
decoder_variables = decoder_variables.sum(dim=0)

# reorder attention or averaging
for i in range(len(attention)): # very inefficient but does the trick
if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
relevant_attention = attention[
i, : out["encoder_lengths"][i] + attention_prediction_horizon
].clone()
if attention_as_autocorrelation:
relevant_attention = autocorrelation(relevant_attention)
attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon :] = relevant_attention
attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
elif attention_as_autocorrelation:
attention[i] = autocorrelation(attention[i])

attention = attention.sum(dim=0)
if reduction == "mean":
attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
attention = attention / attention.sum(-1).unsqueeze(-1) # renormalize
elif reduction == "sum":
pass
else:
raise ValueError(f"Unknown reduction {reduction}")

attention = torch.zeros(
self.hparams.max_encoder_length + attention_prediction_horizon, device=self.device
).scatter(
dim=0,
index=torch.arange(
self.hparams.max_encoder_length + attention_prediction_horizon - attention.size(-1),
self.hparams.max_encoder_length + attention_prediction_horizon,
device=self.device,
),
src=attention,
)
attention = masked_op(attention, dim=0, op=reduction)
else:
attention = attention / attention.sum(-1).unsqueeze(-1) # renormalize
attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1) # renormalize

interpretation = dict(
attention=attention,
attention=attention.masked_fill(torch.isnan(attention), 0.0),
static_variables=static_variables,
encoder_variables=encoder_variables,
decoder_variables=decoder_variables,
@@ -677,15 +714,15 @@ def plot_prediction(

        # add attention on secondary axis
        if plot_attention:
-            interpretation = self.interpret_output(out)
+            interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
            for f in to_list(fig):
                ax = f.axes[0]
                ax2 = ax.twinx()
                ax2.set_ylabel("Attention")
-                encoder_length = x["encoder_lengths"][idx]
+                encoder_length = x["encoder_lengths"][0]
                ax2.plot(
                    torch.arange(-encoder_length, 0),
-                    interpretation["attention"][idx, :encoder_length].detach().cpu(),
+                    interpretation["attention"][0, -encoder_length:].detach().cpu(),
                    alpha=0.2,
                    color="k",
                )
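A hedged end-to-end sketch of the updated attention plotting (not part of the commit; assumes a trained `model` and a `val_dataloader`, and follows the library's documented raw-prediction pattern):

    raw_predictions, x = model.predict(val_dataloader, mode="raw", return_x=True)
    model.plot_prediction(x, raw_predictions, idx=0, plot_attention=True)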
36 changes: 36 additions & 0 deletions pytorch_forecasting/utils.py
@@ -338,6 +338,17 @@ def items(self):
    def keys(self):
        return self._fields

+    def iget(self, idx: Union[int, slice]):
+        """Select item(s) row-wise.
+
+        Args:
+            idx ([int, slice]): item to select
+
+        Returns:
+            Output of selected item(s).
+        """
+        return self.__class__(*(x[idx] for x in self))
+

class TupleOutputMixIn:
    """MixIn to give outputs namedtuple-like access capabilities with the ``to_network_output()`` function."""
@@ -436,3 +447,28 @@ def detach(
        return [detach(xi) for xi in x]
    else:
        return x
+
+
+def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None) -> torch.Tensor:
+    """Calculate operation on masked tensor.
+
+    Args:
+        tensor (torch.Tensor): tensor to conduct operation over
+        op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
+        dim (int, optional): dimension to aggregate over. Defaults to 0.
+        mask (torch.Tensor, optional): boolean mask to apply (True=keep, False=ignore).
+            Masks nan values by default.
+
+    Returns:
+        torch.Tensor: tensor with the ``dim`` dimension aggregated away
+    """
+    if mask is None:
+        mask = ~torch.isnan(tensor)
+    masked = tensor.masked_fill(~mask, 0.0)
+    summed = masked.sum(dim=dim)
+    if op == "mean":
+        return summed / mask.sum(dim=dim)  # mean over unmasked entries
+    elif op == "sum":
+        return summed
+    else:
+        raise ValueError(f"unknown operation {op}")
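A quick check of `masked_op` semantics (minimal sketch, not part of the commit; assumes the function above is importable from `pytorch_forecasting.utils`):

    import torch

    from pytorch_forecasting.utils import masked_op

    t = torch.tensor([[1.0, float("nan"), 3.0]])
    masked_op(t, op="mean", dim=1)  # tensor([2.]), the nan entry is ignored
    masked_op(t, op="sum", dim=1)   # tensor([4.])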
47 changes: 47 additions & 0 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -2,6 +2,7 @@
import shutil
import sys

+import numpy as np
import pytest
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
@@ -256,6 +257,52 @@ def test_prediction_with_dataloder(model, dataloaders_with_covariates, kwargs):
    model.predict(val_dataloader, fast_dev_run=True, **kwargs)


+def test_prediction_with_dataloder_raw(data_with_covariates, tmp_path):
+    # tests correct concatenation of raw output
+    test_data = data_with_covariates.copy()
+    np.random.seed(2)
+    test_data = test_data.sample(frac=0.5)
+
+    dataset = TimeSeriesDataSet(
+        test_data,
+        time_idx="time_idx",
+        max_encoder_length=8,
+        max_prediction_length=10,
+        min_prediction_length=1,
+        min_encoder_length=1,
+        target="volume",
+        group_ids=["agency", "sku"],
+        constant_fill_strategy=dict(volume=0.0),
+        allow_missing_timesteps=True,
+        time_varying_unknown_reals=["volume"],
+        time_varying_known_reals=["time_idx"],
+        target_normalizer=GroupNormalizer(groups=["agency", "sku"]),
+    )
+
+    net = TemporalFusionTransformer.from_dataset(
+        dataset,
+        learning_rate=1e-6,
+        hidden_size=4,
+        attention_head_size=1,
+        dropout=0.2,
+        hidden_continuous_size=2,
+        log_interval=1,
+        log_val_interval=1,
+        log_gradient_flow=True,
+    )
+    logger = TensorBoardLogger(tmp_path)
+    trainer = pl.Trainer(max_epochs=1, gradient_clip_val=1e-6, logger=logger)
+    trainer.fit(net, train_dataloaders=dataset.to_dataloader(batch_size=4, num_workers=0))
+
+    # choose small batch size to provoke issue
+    res = net.predict(dataset.to_dataloader(batch_size=2, num_workers=0), mode="raw")
+    # check that interpretation works
+    net.interpret_output(res)["attention"]
+    assert net.interpret_output(res.iget(slice(1)))["attention"].size() == torch.Size(
+        (1, net.hparams.max_encoder_length)
+    )


def test_prediction_with_dataset(model, dataloaders_with_covariates):
    val_dataloader = dataloaders_with_covariates["val"]
    model.predict(val_dataloader.dataset, fast_dev_run=True)
