From 7930b9a77a8042998fa0524b99b38a8548d5675d Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Sat, 17 Oct 2020 21:39:25 +0100 Subject: [PATCH] FIx second stacking variable length tensors occurance --- .../models/temporal_fusion_transformer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 180a9aa0..d050fa09 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -823,7 +823,7 @@ def _log_interpretation(self, outputs, label="train"): # log lengths of encoder/decoder for type in ["encoder", "decoder"]: fig, ax = plt.subplots() - lengths = torch.stack([out["interpretation"][f"{type}_length_histogram"] for out in outputs]).sum(0).cpu() + lengths = padded_stack([out["interpretation"][f"{type}_length_histogram"] for out in outputs]).sum(0).cpu() if type == "decoder": start = 1 else: