From 755ef66dcd217debca835bb638f4801b24985fea Mon Sep 17 00:00:00 2001
From: XinyuWuu
Date: Tue, 10 Sep 2024 14:55:53 +0800
Subject: [PATCH] fix TFT variable selection when there is no exogenous variable

VariableSelectionNetwork.forward only handled the cases of more than one
and exactly one input variable; built with zero input variables (e.g. a
dataset whose only feature is the target, leaving the decoder without any
known covariates), it raised StopIteration on the empty module dict.
Handle num_inputs == 0 explicitly by returning zeros shaped like the
context and a zero-width selection weight tensor, and add a regression
test that trains and predicts on such a dataset.
---
 .../sub_modules.py                          |  9 +++-
 .../test_temporal_fusion_transformer.py     | 48 +++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py b/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py
index f3d6ae35..eefb72b4 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py
@@ -336,7 +336,8 @@ def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
             outputs = var_outputs * sparse_weights
             outputs = outputs.sum(dim=-1)
-        else:  # for one input, do not perform variable selection but just encoding
+        elif self.num_inputs == 1:
+            # for one input, do not perform variable selection but just encoding
             name = next(iter(self.single_variable_grns.keys()))
             variable_embedding = x[name]
             if name in self.prescalers:
@@ -346,6 +347,12 @@ def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
                 sparse_weights = torch.ones(outputs.size(0), outputs.size(1), 1, 1, device=outputs.device)  #
             else:  # ndim == 2 -> batch size, hidden size, n_variables
                 sparse_weights = torch.ones(outputs.size(0), 1, 1, device=outputs.device)
+        else:  # for no input
+            outputs = torch.zeros(context.size(), device=context.device)
+            if outputs.ndim == 3:  # -> batch size, time, hidden size, n_variables
+                sparse_weights = torch.zeros(outputs.size(0), outputs.size(1), 1, 0, device=outputs.device)
+            else:  # ndim == 2 -> batch size, hidden size, n_variables
+                sparse_weights = torch.zeros(outputs.size(0), 1, 0, device=outputs.device)
         return outputs, sparse_weights

diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py
index 9fe78445..7535fc75 100644
--- a/tests/test_models/test_temporal_fusion_transformer.py
+++ b/tests/test_models/test_temporal_fusion_transformer.py
@@ -6,6 +6,7 @@
 from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.loggers import TensorBoardLogger
 import numpy as np
+import pandas as pd
 import pytest
 import torch
@@ -421,3 +422,50 @@ def test_hyperparameter_optimization_integration(dataloaders_with_covariates, tm
         )
     finally:
         shutil.rmtree(tmp_path, ignore_errors=True)
+
+
+def test_no_exogenous_variable():
+    data = pd.DataFrame(
+        {
+            "target": np.ones(1600),
+            "group_id": np.repeat(np.arange(16), 100),
+            "time_idx": np.tile(np.arange(100), 16),
+        }
+    )
+    training_dataset = TimeSeriesDataSet(
+        data=data,
+        time_idx="time_idx",
+        target="target",
+        group_ids=["group_id"],
+        max_encoder_length=10,
+        max_prediction_length=5,
+        time_varying_unknown_reals=["target"],
+        time_varying_known_reals=[],
+    )
+    validation_dataset = TimeSeriesDataSet.from_dataset(training_dataset, data, stop_randomization=True, predict=True)
+    training_data_loader = training_dataset.to_dataloader(train=True, batch_size=8, num_workers=0)
+    validation_data_loader = validation_dataset.to_dataloader(train=False, batch_size=8, num_workers=0)
+    forecaster = TemporalFusionTransformer.from_dataset(
+        training_dataset,
+        log_interval=1,
+    )
+    from lightning.pytorch import Trainer
+
+    trainer = Trainer(
+        max_epochs=2,
+        limit_train_batches=8,
+        limit_val_batches=8,
+    )
+    trainer.fit(
+        forecaster,
+        train_dataloaders=training_data_loader,
+        val_dataloaders=validation_data_loader,
+    )
+    best_model_path = trainer.checkpoint_callback.best_model_path
+    best_model = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
+    best_model.predict(
+        validation_data_loader,
+        return_x=True,
+        return_y=True,
+        return_index=True,
+    )
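
For reference, a minimal standalone sketch of what the new zero-input branch
returns (pure tensor code mirroring the added lines in sub_modules.py; the
context shape of batch size 8, 10 time steps, hidden size 16 is illustrative):

import torch

# mirrors the added "else:  # for no input" branch
context = torch.randn(8, 10, 16)
outputs = torch.zeros(context.size(), device=context.device)
if outputs.ndim == 3:  # -> batch size, time, hidden size, n_variables
    sparse_weights = torch.zeros(outputs.size(0), outputs.size(1), 1, 0, device=outputs.device)
else:  # ndim == 2 -> batch size, hidden size, n_variables
    sparse_weights = torch.zeros(outputs.size(0), 1, 0, device=outputs.device)

assert outputs.shape == (8, 10, 16)           # zeros shaped like the context
assert sparse_weights.shape == (8, 10, 1, 0)  # zero-width: no variables to weight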