From 142f7b73c20a0813a218aa78af97c9bcff96146e Mon Sep 17 00:00:00 2001
From: Daniel Vainsencher
Date: Wed, 3 Aug 2022 17:12:55 -0400
Subject: [PATCH] Fix the slice for plotting attention

Before the fix, the plot works only when the length of
`interpretation["attention"]` happens to match `encoder_length` (in which
case the slice is not needed anyway). After the fix, the lengths always
match.

I'm running into this with encoder_length=0, which is probably a bug
elsewhere, possibly in my code, but this is probably worth fixing anyway.
---
 .../models/temporal_fusion_transformer/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index 8c816dbe..cad1bc7e 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -728,7 +728,7 @@ def plot_prediction(
         encoder_length = x["encoder_lengths"][0]
         ax2.plot(
             torch.arange(-encoder_length, 0),
-            interpretation["attention"][0, -encoder_length:].detach().cpu(),
+            interpretation["attention"][0, :-encoder_length].detach().cpu(),
             alpha=0.2,
             color="k",
         )
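
Demonstration of the slice semantics behind the change (a minimal
standalone sketch; the attention row of length 5 is illustrative, not
taken from the model). Python treats `-0` as `0`, so the old and new
slices diverge exactly in the `encoder_length=0` case reported above:

    import torch

    attention = torch.rand(1, 5)  # illustrative attention row, length 5
    encoder_length = 0            # the degenerate case reported above

    x = torch.arange(-encoder_length, 0)  # arange(0, 0): empty, shape (0,)
    old = attention[0, -encoder_length:]  # [-0:] is [0:]: full row, shape (5,)
    new = attention[0, :-encoder_length]  # [:-0] is [:0]: empty, shape (0,)

    # ax2.plot(x, y) requires len(x) == len(y): the old slice mismatches
    # (0 vs 5) and matplotlib raises; the new slice matches (0 vs 0).
    print(x.shape, old.shape, new.shape)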