diff --git a/examples/stallion.py b/examples/stallion.py index 257d1c84..408c4d11 100644 --- a/examples/stallion.py +++ b/examples/stallion.py @@ -25,7 +25,7 @@ data = get_stallion_data() -data["month"] = data.date.dt.month +data["month"] = data.date.dt.month.astype("str").astype("category") data["log_volume"] = np.log(data.volume + 1e-8) data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month @@ -125,7 +125,6 @@ # fig.show() # tft.hparams.learning_rate = res.suggestion() - trainer.fit( tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, ) diff --git a/pytorch_forecasting/data.py b/pytorch_forecasting/data.py index e5158967..95b58574 100644 --- a/pytorch_forecasting/data.py +++ b/pytorch_forecasting/data.py @@ -64,7 +64,7 @@ def fit(self, y): self.classes_[val] = idx + 1 else: self.classes_ = {val: idx for idx, val in enumerate(np.unique(y))} - self.classes_vector = np.array(list(self.classes_.keys())) + self.classes_vector_ = np.array(list(self.classes_.keys())) return self def transform(self, y): @@ -86,11 +86,11 @@ def transform(self, y): return encoded def inverse_transform(self, y): - if y.max() >= len(self.classes_vector): + if y.max() >= len(self.classes_vector_): raise KeyError("New unknown values detected") # decode - decoded = self.classes_vector[y] + decoded = self.classes_vector_[y] return decoded @@ -149,7 +149,7 @@ def __init__( self.coerce_positive = coerce_positive def get_parameters(self, *args, **kwargs): - return torch.tensor([self.center_, self.scale]) + return torch.tensor([self.center_, self.scale_]) def _preprocess_y(self, y): if self.coerce_positive is None and not self.log_scale: @@ -168,10 +168,10 @@ def fit(self, y): if self.method == "standard": if isinstance(y, torch.Tensor): self.center_ = torch.mean(y) - self.scale = torch.std(y) / (self.center_ + self.eps) + self.scale_ = torch.std(y) / (self.center_ + self.eps) else: self.center_ = np.mean(y) - self.scale = np.std(y) / (self.center_ + self.eps) + self.scale_ = np.std(y) / (self.center_ + self.eps) elif self.method == "robust": if isinstance(y, torch.Tensor): @@ -182,7 +182,7 @@ def fit(self, y): self.center_ = np.median(y) q_75 = np.percentiley(y, 75) q_25 = np.percentiley(y, 25) - self.scale = (q_75 - q_25) / (self.center_ + self.eps) / 2.0 + self.scale_ = (q_75 - q_25) / (self.center_ + self.eps) / 2.0 return self def transform(self, y, return_norm: bool = False): @@ -192,7 +192,7 @@ def transform(self, y, return_norm: bool = False): else: y = np.log(y + self.log_zero_value + self.eps) if self.center: - y = (y / (self.center_ + self.eps) - 1) / (self.scale + self.eps) + y = (y / (self.center_ + self.eps) - 1) / (self.scale_ + self.eps) else: y = y / (self.center_ + self.eps) if return_norm: @@ -208,7 +208,7 @@ def __call__(self, data: Dict[str, torch.Tensor]): norm = data["target_scale"] # use correct shape for norm - if data["prediction"].ndim > 1: + if data["prediction"].ndim > norm.ndim: norm = norm.unsqueeze(-1) # transform @@ -286,14 +286,14 @@ def fit(self, y, X): assert not self.scale_by_group, "No groups are defined, i.e. `scale_by_group=[]`" if self.method == "standard": mean = np.mean(y) - self.norm = mean, np.std(y) / (mean + self.eps) + self.norm_ = mean, np.std(y) / (mean + self.eps) else: quantiles = np.quantile(y, [0.25, 0.5, 0.75]) - self.norm = quantiles[1], (quantiles[2] - quantiles[0]) / (quantiles[1] + self.eps) + self.norm_ = quantiles[1], (quantiles[2] - quantiles[0]) / (quantiles[1] + self.eps) elif self.scale_by_group: if self.method == "standard": - self.norm = { + self.norm_ = { g: X[[g]] .assign(y=y) .groupby(g, observed=True) @@ -302,7 +302,7 @@ def fit(self, y, X): for g in self.groups } else: - self.norm = { + self.norm_ = { g: X[[g]] .assign(y=y) .groupby(g, observed=True) @@ -315,11 +315,11 @@ def fit(self, y, X): for g in self.groups } # calculate missings - self._missing = {group: scales.median().to_dict() for group, scales in self.norm.items()} + self.missing_ = {group: scales.median().to_dict() for group, scales in self.norm_.items()} else: if self.method == "standard": - self.norm = ( + self.norm_ = ( X[self.groups] .assign(y=y) .groupby(self.groups, observed=True) @@ -327,7 +327,7 @@ def fit(self, y, X): .assign(scale=lambda x: x.scale / (x["mean"] + self.eps)) ) else: - self.norm = ( + self.norm_ = ( X[self.groups] .assign(y=y) .groupby(self.groups, observed=True) @@ -338,7 +338,7 @@ def fit(self, y, X): scale=lambda x: (x[0.75] - x[0.25] + self.eps) / (x[0.5] + self.eps) / 2.0, )[["median", "scale"]] ) - self._missing = self.norm.median().to_dict() + self.missing_ = self.norm_.median().to_dict() return self @property @@ -379,32 +379,33 @@ def get_parameters(self, groups, group_names: List[str] = None): assert len(group_names) == len(self.groups), "Passed groups and fitted do not match" if len(self.groups) == 0: - return np.asarray(self.norm) + params = np.asarray(self.norm_).squeeze() elif self.scale_by_group: norm = np.array([1.0, 1.0]) for group, group_name in zip(groups, group_names): try: - norm = norm * self.norm[group_name].loc[group].to_numpy() + norm = norm * self.norm_[group_name].loc[group].to_numpy() except KeyError: - norm = norm * np.asarray([self._missing[group_name][name] for name in self.names]) + norm = norm * np.asarray([self.missing_[group_name][name] for name in self.names]) norm = np.power(norm, 1.0 / len(self.groups)) - return norm + params = norm else: try: - return self.norm.loc[groups].to_numpy() + params = self.norm_.loc[groups].to_numpy() except (KeyError, TypeError): - return np.asarray([self._missing[name] for name in self.names]) + params = np.asarray([self.missing_[name] for name in self.names]) + return params def get_norm(self, X): if len(self.groups) == 0: - norm = np.asarray(self.norm).reshape(1, -1) + norm = np.asarray(self.norm_).reshape(1, -1) elif self.scale_by_group: norm = [ np.prod( [ X[group_name] - .map(self.norm[group_name][name]) - .fillna(self._missing[group_name][name]) + .map(self.norm_[group_name][name]) + .fillna(self.missing_[group_name][name]) .to_numpy() for group_name in self.groups ], @@ -414,7 +415,7 @@ def get_norm(self, X): ] norm = np.power(np.stack(norm, axis=1), 1.0 / len(self.groups)) else: - norm = X[self.groups].set_index(self.groups).join(self.norm).fillna(self._missing).to_numpy() + norm = X[self.groups].set_index(self.groups).join(self.norm_).fillna(self.missing_).to_numpy() return norm @@ -512,7 +513,7 @@ def __init__( self.target = target self.weight = weight self.time_idx = time_idx - self.group_ids = group_ids + self.group_ids = [] + group_ids self.static_categoricals = [] + static_categoricals self.static_reals = [] + static_reals self.time_varying_known_categoricals = [] + time_varying_known_categoricals @@ -523,14 +524,14 @@ def __init__( self.add_relative_time_idx = add_relative_time_idx self.randomize_length = randomize_length self.min_prediction_idx = min_prediction_idx or data[self.time_idx].min() - self.constant_fill_strategy = constant_fill_strategy + self.constant_fill_strategy = {} if len(constant_fill_strategy) == 0 else constant_fill_strategy self.predict_mode = predict_mode self.allow_missings = allow_missings self.target_normalizer = target_normalizer - self.categorical_encoders = categorical_encoders - self.scalers = scalers + self.categorical_encoders = {} if len(categorical_encoders) == 0 else categorical_encoders + self.scalers = {} if len(scalers) == 0 else scalers self.add_target_scales = add_target_scales - self.variable_groups = variable_groups + self.variable_groups = {} if len(variable_groups) == 0 else variable_groups # add_decoder_length if isinstance(add_decoder_length, str): @@ -591,6 +592,9 @@ def __init__( self.static_reals.append("decoder_length") data["decoder_length"] = 0.0 # dummy - real value will be set dynamiclly in __getitem__() + # validate + self._validate_data(data) + # preprocess data data = self._preprocess_data(data) @@ -600,6 +604,19 @@ def __init__( # convert to torch tensor for high performance data loading later self.data = self._data_to_tensors(data) + def _validate_data(self, data: pd.DataFrame): + # check for numeric categoricals which can cause hick-ups in logging + category_columns = data.head(1).select_dtypes("category").columns + object_columns = data.head(1).select_dtypes(object).columns + for name in self.flat_categoricals: + if not ( + name in object_columns + or (name in category_columns and data[name].cat.categories.dtype.kind not in "bifc") + ): + raise ValueError( + f"Data type of category {name} was found to be numeric - use a string type / categorified string" + ) + def save(self, fname: str) -> None: """ Save dataset to disk @@ -639,7 +656,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: ) elif self.categorical_encoders[name] is not None: try: - check_is_fitted(self.target_normalizer) + check_is_fitted(self.categorical_encoders[name]) except NotFittedError: self.categorical_encoders[name] = self.categorical_encoders[name].fit( data[columns].to_numpy().reshape(-1) @@ -649,7 +666,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: self.categorical_encoders[name] = NaNLabelEncoder(add_nan=allow_nans).fit(data[name]) elif self.categorical_encoders[name] is not None: try: - check_is_fitted(self.target_normalizer) + check_is_fitted(self.categorical_encoders[name]) except NotFittedError: self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name]) @@ -716,7 +733,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: self.scalers[name] = StandardScaler().fit(data[[name]]) elif self.scalers[name] is not None: try: - check_is_fitted(self.target_normalizer) + check_is_fitted(self.scalers[name]) except NotFittedError: if isinstance(self.scalers[name], GroupNormalizer): self.scalers[name] = self.scalers[name].fit(data[[name]], data) diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index e05fc46f..52bddc7c 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -31,6 +31,7 @@ def __init__( weight_decay: float = 1e-3, loss=SMAPE(), reduce_on_plateau_patience: int = 1000, + **kwargs, ): """ Initialize NBeats Model @@ -64,7 +65,7 @@ def __init__( reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 """ self.save_hyperparameters() - super().__init__() + super().__init__(**kwargs) self.loss = loss # setup stacks @@ -106,9 +107,6 @@ def __init__( def forward(self, x: Dict[str, torch.Tensor]): target = x["encoder_target"] - if self.loss.log_space: - target = torch.log(target + 1e-8) - timesteps = self.hparams.context_length + self.hparams.prediction_length generic_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] trend_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] @@ -139,6 +137,7 @@ def forward(self, x: Dict[str, torch.Tensor]): return dict( prediction=forecast, + target_scale=x["target_scale"], backcast=backcast, trend=torch.stack(trend_forecast, dim=0).sum(0), seasonality=torch.stack(seasonal_forecast, dim=0).sum(0), @@ -148,7 +147,6 @@ def forward(self, x: Dict[str, torch.Tensor]): @classmethod def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): new_kwargs = {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length} - new_kwargs["dataset_parameters"] = dataset.get_parameters() new_kwargs.update(kwargs) # validate arguments @@ -167,7 +165,7 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): return super().from_dataset(dataset, **new_kwargs) def step(self, x, y, batch_idx, label) -> Dict[str, torch.Tensor]: - log, out = self.step(x, y, batch_idx=batch_idx, label=label) + log, out = super().step(x, y, batch_idx=batch_idx, label=label) self._log_interpretation(x, out, batch_idx=batch_idx, label=label) return log, out diff --git a/tests/conftest.py b/tests/conftest.py index 2b66a1f0..876f2446 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,3 +5,52 @@ sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) sys.path.insert(0, "examples") + +from examples.data import get_stallion_data +from pytorch_forecasting import TimeSeriesDataSet + + +@pytest.fixture +def test_data(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + + data = data[lambda x: x.time_idx < 10] # downsample + return data + + +@pytest.fixture +def test_dataset(test_data): + training = TimeSeriesDataSet( + test_data, + time_idx="time_idx", + target="volume", + time_varying_known_reals=["price_regular"], + group_ids=["agency", "sku"], + static_categoricals=["agency"], + max_encoder_length=5, + max_prediction_length=2, + randomize_length=None, + ) + return training diff --git a/tests/test_data.py b/tests/test_data.py index a4e9dd3d..3949f9d0 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -119,19 +119,6 @@ def test_GroupNormalizer(kwargs, groups): ).all(), "Inverse transform should reverse transform" -@pytest.fixture -def test_data(): - data = get_stallion_data() - data["month"] = data.date.dt.month - data["log_volume"] = np.log1p(data.volume) - data["weight"] = 1 + np.sqrt(data.volume) - - data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month - data["time_idx"] -= data["time_idx"].min() - data = data[lambda x: x.time_idx < 10] # downsample - return data - - def check_dataloader_output(dataset: TimeSeriesDataSet, out: Dict[str, torch.Tensor]): x, y = out @@ -219,22 +206,6 @@ def test_TimeSeriesDataSet(test_data, kwargs): check_dataloader_output(dataset, next(iter(dataset.to_dataloader(num_workers=0)))) -@pytest.fixture -def test_dataset(test_data): - training = TimeSeriesDataSet( - test_data, - time_idx="time_idx", - target="volume", - time_varying_known_reals=["price_regular"], - group_ids=["agency", "sku"], - static_categoricals=["agency"], - max_encoder_length=5, - max_prediction_length=2, - randomize_length=None, - ) - return training - - def test_from_dataset(test_dataset, test_data): dataset = TimeSeriesDataSet.from_dataset(test_dataset, test_data) check_dataloader_output(dataset, next(iter(dataset.to_dataloader(num_workers=0)))) diff --git a/tests/test_models/conftest.py b/tests/test_models/conftest.py index a6bb2c61..39c4771a 100644 --- a/tests/test_models/conftest.py +++ b/tests/test_models/conftest.py @@ -8,13 +8,30 @@ @pytest.fixture def data_with_covariates(): data = get_stallion_data() - data["month"] = data.date.dt.month + data["month"] = data.date.dt.month.astype(str) data["log_volume"] = np.log1p(data.volume) data["weight"] = 1 + np.sqrt(data.volume) data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month data["time_idx"] -= data["time_idx"].min() + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + return data @@ -70,9 +87,9 @@ def dataloaders_with_coveratiates(data_with_covariates): @pytest.fixture def dataloaders_fixed_window_without_coveratiates(data_with_covariates): - data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) - data["static"] = 2 - validation = data.series.sample(20) + data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=10) + data["static"] = "2" + validation = data.series.iloc[:2] max_encoder_length = 60 max_prediction_length = 20 diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index ddafff31..6df04281 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -5,7 +5,10 @@ from pytorch_forecasting.metrics import QuantileLoss from pytorch_forecasting.models import TemporalFusionTransformer - +# todo: run with multiple normalizers +# todo: run with muliple datasets and normalizers: ... +# todo: monotonicity +# todo: test different parameters def test_integration(dataloaders_with_coveratiates, tmp_path): train_dataloader = dataloaders_with_coveratiates["train"] val_dataloader = dataloaders_with_coveratiates["val"] @@ -50,3 +53,7 @@ def test_integration(dataloaders_with_coveratiates, tmp_path): net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) finally: shutil.rmtree(tmp_path, ignore_errors=True) + + +def test_monotinicity(): + pass