Fix remaining failing tests
Jan Beitner committed Aug 15, 2020
1 parent c92ef96 commit 9b1d12f
Showing 7 changed files with 135 additions and 77 deletions.
examples/stallion.py (3 changes: 1 addition & 2 deletions)

@@ -25,7 +25,7 @@

data = get_stallion_data()

-data["month"] = data.date.dt.month
+data["month"] = data.date.dt.month.astype("str").astype("category")
data["log_volume"] = np.log(data.volume + 1e-8)

data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
@@ -125,7 +125,6 @@
# fig.show()
# tft.hparams.learning_rate = res.suggestion()

-
trainer.fit(
tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
pytorch_forecasting/data.py (87 changes: 52 additions & 35 deletions)

@@ -64,7 +64,7 @@ def fit(self, y):
self.classes_[val] = idx + 1
else:
self.classes_ = {val: idx for idx, val in enumerate(np.unique(y))}
-        self.classes_vector = np.array(list(self.classes_.keys()))
+        self.classes_vector_ = np.array(list(self.classes_.keys()))
return self

def transform(self, y):
@@ -86,11 +86,11 @@ def transform(self, y):
return encoded

def inverse_transform(self, y):
-        if y.max() >= len(self.classes_vector):
+        if y.max() >= len(self.classes_vector_):
raise KeyError("New unknown values detected")

# decode
-        decoded = self.classes_vector[y]
+        decoded = self.classes_vector_[y]
return decoded
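The rename from `classes_vector` to `classes_vector_` follows the scikit-learn convention that attributes learned in `fit` carry a trailing underscore, which is what `check_is_fitted` looks for. A minimal round-trip sketch, assuming `NaNLabelEncoder` is imported from `pytorch_forecasting.data` with default settings:

```python
import numpy as np

from pytorch_forecasting.data import NaNLabelEncoder

encoder = NaNLabelEncoder().fit(np.array(["a", "b", "c"]))
codes = encoder.transform(np.array(["a", "c"]))  # integer codes into classes_vector_
labels = encoder.inverse_transform(codes)        # back to ["a", "c"]

# a code outside the fitted vocabulary trips the guard above:
# encoder.inverse_transform(np.array([99]))  # KeyError: "New unknown values detected"
```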


@@ -149,7 +149,7 @@ def __init__(
self.coerce_positive = coerce_positive

def get_parameters(self, *args, **kwargs):
-        return torch.tensor([self.center_, self.scale])
+        return torch.tensor([self.center_, self.scale_])

def _preprocess_y(self, y):
if self.coerce_positive is None and not self.log_scale:
@@ -168,10 +168,10 @@ def fit(self, y):
if self.method == "standard":
if isinstance(y, torch.Tensor):
self.center_ = torch.mean(y)
-                self.scale = torch.std(y) / (self.center_ + self.eps)
+                self.scale_ = torch.std(y) / (self.center_ + self.eps)
else:
self.center_ = np.mean(y)
-                self.scale = np.std(y) / (self.center_ + self.eps)
+                self.scale_ = np.std(y) / (self.center_ + self.eps)

elif self.method == "robust":
if isinstance(y, torch.Tensor):
@@ -182,7 +182,7 @@ def fit(self, y):
self.center_ = np.median(y)
                q_75 = np.percentile(y, 75)
                q_25 = np.percentile(y, 25)
-            self.scale = (q_75 - q_25) / (self.center_ + self.eps) / 2.0
+            self.scale_ = (q_75 - q_25) / (self.center_ + self.eps) / 2.0
return self

def transform(self, y, return_norm: bool = False):
@@ -192,7 +192,7 @@
else:
y = np.log(y + self.log_zero_value + self.eps)
if self.center:
-            y = (y / (self.center_ + self.eps) - 1) / (self.scale + self.eps)
+            y = (y / (self.center_ + self.eps) - 1) / (self.scale_ + self.eps)
else:
y = y / (self.center_ + self.eps)
if return_norm:
@@ -208,7 +208,7 @@ def __call__(self, data: Dict[str, torch.Tensor]):
norm = data["target_scale"]

# use correct shape for norm
-        if data["prediction"].ndim > 1:
+        if data["prediction"].ndim > norm.ndim:
norm = norm.unsqueeze(-1)

# transform
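Comparing against `norm.ndim` rather than a fixed `1` means the scale tensor only gains a trailing axis when the prediction really has more dimensions, e.g. an extra quantile axis. The shape logic in isolation (plain torch, hypothetical shapes):

```python
import torch

prediction = torch.randn(4, 10, 7)  # batch x time x quantiles (hypothetical)
norm = torch.rand(4, 2)             # batch x (center, scale)

if prediction.ndim > norm.ndim:
    norm = norm.unsqueeze(-1)  # (4, 2, 1): broadcasts over the quantile axis

assert norm.shape == (4, 2, 1)
```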
@@ -286,14 +286,14 @@ def fit(self, y, X):
assert not self.scale_by_group, "No groups are defined, i.e. `scale_by_group=[]`"
if self.method == "standard":
mean = np.mean(y)
-                self.norm = mean, np.std(y) / (mean + self.eps)
+                self.norm_ = mean, np.std(y) / (mean + self.eps)
            else:
                quantiles = np.quantile(y, [0.25, 0.5, 0.75])
-                self.norm = quantiles[1], (quantiles[2] - quantiles[0]) / (quantiles[1] + self.eps)
+                self.norm_ = quantiles[1], (quantiles[2] - quantiles[0]) / (quantiles[1] + self.eps)

elif self.scale_by_group:
if self.method == "standard":
-                self.norm = {
+                self.norm_ = {
g: X[[g]]
.assign(y=y)
.groupby(g, observed=True)
@@ -302,7 +302,7 @@ def fit(self, y, X):
for g in self.groups
}
else:
-                self.norm = {
+                self.norm_ = {
g: X[[g]]
.assign(y=y)
.groupby(g, observed=True)
@@ -315,19 +315,19 @@ def fit(self, y, X):
for g in self.groups
}
# calculate missings
-            self._missing = {group: scales.median().to_dict() for group, scales in self.norm.items()}
+            self.missing_ = {group: scales.median().to_dict() for group, scales in self.norm_.items()}

else:
if self.method == "standard":
-                self.norm = (
+                self.norm_ = (
X[self.groups]
.assign(y=y)
.groupby(self.groups, observed=True)
.agg(mean=("y", "mean"), scale=("y", "std"))
.assign(scale=lambda x: x.scale / (x["mean"] + self.eps))
)
else:
-                self.norm = (
+                self.norm_ = (
X[self.groups]
.assign(y=y)
.groupby(self.groups, observed=True)
@@ -338,7 +338,7 @@ def fit(self, y, X):
scale=lambda x: (x[0.75] - x[0.25] + self.eps) / (x[0.5] + self.eps) / 2.0,
)[["median", "scale"]]
)
-            self._missing = self.norm.median().to_dict()
+            self.missing_ = self.norm_.median().to_dict()
return self

@property
@@ -379,32 +379,33 @@ def get_parameters(self, groups, group_names: List[str] = None):
assert len(group_names) == len(self.groups), "Passed groups and fitted do not match"

if len(self.groups) == 0:
-            return np.asarray(self.norm)
+            params = np.asarray(self.norm_).squeeze()
        elif self.scale_by_group:
            norm = np.array([1.0, 1.0])
            for group, group_name in zip(groups, group_names):
                try:
-                    norm = norm * self.norm[group_name].loc[group].to_numpy()
+                    norm = norm * self.norm_[group_name].loc[group].to_numpy()
                except KeyError:
-                    norm = norm * np.asarray([self._missing[group_name][name] for name in self.names])
+                    norm = norm * np.asarray([self.missing_[group_name][name] for name in self.names])
            norm = np.power(norm, 1.0 / len(self.groups))
-            return norm
+            params = norm
        else:
            try:
-                return self.norm.loc[groups].to_numpy()
+                params = self.norm_.loc[groups].to_numpy()
            except (KeyError, TypeError):
-                return np.asarray([self._missing[name] for name in self.names])
+                params = np.asarray([self.missing_[name] for name in self.names])
+        return params

def get_norm(self, X):
if len(self.groups) == 0:
-            norm = np.asarray(self.norm).reshape(1, -1)
+            norm = np.asarray(self.norm_).reshape(1, -1)
elif self.scale_by_group:
norm = [
np.prod(
[
X[group_name]
-                        .map(self.norm[group_name][name])
-                        .fillna(self._missing[group_name][name])
+                        .map(self.norm_[group_name][name])
+                        .fillna(self.missing_[group_name][name])
.to_numpy()
for group_name in self.groups
],
@@ -414,7 +415,7 @@ def get_norm(self, X):
]
norm = np.power(np.stack(norm, axis=1), 1.0 / len(self.groups))
else:
-            norm = X[self.groups].set_index(self.groups).join(self.norm).fillna(self._missing).to_numpy()
+            norm = X[self.groups].set_index(self.groups).join(self.norm_).fillna(self.missing_).to_numpy()
return norm
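`GroupNormalizer` gets the same fitted-attribute rename (`norm_`, `missing_`); `missing_` holds median parameters used as a fallback whenever a group was never seen during `fit`. A hedged usage sketch with toy data (default `method="standard"`):

```python
import pandas as pd

from pytorch_forecasting.data import GroupNormalizer

X = pd.DataFrame({"agency": ["a", "a", "b", "b"]})
y = pd.Series([1.0, 2.0, 3.0, 4.0])

normalizer = GroupNormalizer(groups=["agency"]).fit(y, X)
known = normalizer.get_parameters(["a"])    # parameters fitted for group "a"
unknown = normalizer.get_parameters(["z"])  # unseen group: falls back to missing_
```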


@@ -512,7 +513,7 @@ def __init__(
self.target = target
self.weight = weight
self.time_idx = time_idx
-        self.group_ids = group_ids
+        self.group_ids = [] + group_ids
self.static_categoricals = [] + static_categoricals
self.static_reals = [] + static_reals
self.time_varying_known_categoricals = [] + time_varying_known_categoricals
@@ -523,14 +524,14 @@ def __init__(
self.add_relative_time_idx = add_relative_time_idx
self.randomize_length = randomize_length
self.min_prediction_idx = min_prediction_idx or data[self.time_idx].min()
-        self.constant_fill_strategy = constant_fill_strategy
+        self.constant_fill_strategy = {} if len(constant_fill_strategy) == 0 else constant_fill_strategy
self.predict_mode = predict_mode
self.allow_missings = allow_missings
self.target_normalizer = target_normalizer
-        self.categorical_encoders = categorical_encoders
-        self.scalers = scalers
+        self.categorical_encoders = {} if len(categorical_encoders) == 0 else categorical_encoders
+        self.scalers = {} if len(scalers) == 0 else scalers
self.add_target_scales = add_target_scales
-        self.variable_groups = variable_groups
+        self.variable_groups = {} if len(variable_groups) == 0 else variable_groups

# add_decoder_length
if isinstance(add_decoder_length, str):
@@ -591,6 +592,9 @@ def __init__(
self.static_reals.append("decoder_length")
            data["decoder_length"] = 0.0  # dummy - real value will be set dynamically in __getitem__()

+        # validate
+        self._validate_data(data)
+
# preprocess data
data = self._preprocess_data(data)

@@ -600,6 +604,19 @@ def __init__(
# convert to torch tensor for high performance data loading later
self.data = self._data_to_tensors(data)

+    def _validate_data(self, data: pd.DataFrame):
+        # check for numeric categoricals which can cause hiccups in logging
+        category_columns = data.head(1).select_dtypes("category").columns
+        object_columns = data.head(1).select_dtypes(object).columns
+        for name in self.flat_categoricals:
+            if not (
+                name in object_columns
+                or (name in category_columns and data[name].cat.categories.dtype.kind not in "bifc")
+            ):
+                raise ValueError(
+                    f"Data type of category {name} was found to be numeric - use a string type / categorified string"
+                )
+
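This validation is why the stallion example above now appends `.astype("str").astype("category")`: a categorical column whose categories have a numeric dtype kind (`b`, `i`, `f`, `c`) is rejected. What the check reacts to, reduced to plain pandas:

```python
import pandas as pd

bad = pd.Series([1, 2, 3]).astype("category")
print(bad.cat.categories.dtype.kind)   # "i" -> _validate_data raises ValueError

good = bad.astype("str").astype("category")
print(good.cat.categories.dtype.kind)  # "O" -> passes
```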
def save(self, fname: str) -> None:
"""
Save dataset to disk
@@ -639,7 +656,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
)
elif self.categorical_encoders[name] is not None:
try:
-                    check_is_fitted(self.target_normalizer)
+                    check_is_fitted(self.categorical_encoders[name])
except NotFittedError:
self.categorical_encoders[name] = self.categorical_encoders[name].fit(
data[columns].to_numpy().reshape(-1)
@@ -649,7 +666,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
self.categorical_encoders[name] = NaNLabelEncoder(add_nan=allow_nans).fit(data[name])
elif self.categorical_encoders[name] is not None:
try:
-                check_is_fitted(self.target_normalizer)
+                check_is_fitted(self.categorical_encoders[name])
except NotFittedError:
self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name])

Expand Down Expand Up @@ -716,7 +733,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
self.scalers[name] = StandardScaler().fit(data[[name]])
elif self.scalers[name] is not None:
try:
-                check_is_fitted(self.target_normalizer)
+                check_is_fitted(self.scalers[name])
except NotFittedError:
if isinstance(self.scalers[name], GroupNormalizer):
self.scalers[name] = self.scalers[name].fit(data[[name]], data)
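The three `check_is_fitted(self.target_normalizer)` calls above were copy-paste slips; each branch now checks the encoder or scaler it is about to reuse. The guard itself is the standard scikit-learn fit-on-demand pattern, sketched here:

```python
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted

scaler = StandardScaler()
try:
    check_is_fitted(scaler)  # raises NotFittedError until .fit() has run
except NotFittedError:
    scaler = scaler.fit([[0.0], [1.0], [2.0]])
```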
pytorch_forecasting/models/nbeats/__init__.py (10 changes: 4 additions & 6 deletions)

@@ -31,6 +31,7 @@ def __init__(
weight_decay: float = 1e-3,
loss=SMAPE(),
reduce_on_plateau_patience: int = 1000,
+        **kwargs,
):
"""
Initialize NBeats Model
@@ -64,7 +65,7 @@ def __init__(
reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
"""
self.save_hyperparameters()
-        super().__init__()
+        super().__init__(**kwargs)
self.loss = loss

# setup stacks
@@ -106,9 +107,6 @@ def __init__(
def forward(self, x: Dict[str, torch.Tensor]):
target = x["encoder_target"]

-        if self.loss.log_space:
-            target = torch.log(target + 1e-8)
-

timesteps = self.hparams.context_length + self.hparams.prediction_length
generic_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)]
trend_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)]
@@ -139,6 +137,7 @@ def forward(self, x: Dict[str, torch.Tensor]):

return dict(
prediction=forecast,
+            target_scale=x["target_scale"],
backcast=backcast,
trend=torch.stack(trend_forecast, dim=0).sum(0),
seasonality=torch.stack(seasonal_forecast, dim=0).sum(0),
@@ -148,7 +147,6 @@ def forward(self, x: Dict[str, torch.Tensor]):
@classmethod
def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
new_kwargs = {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length}
-        new_kwargs["dataset_parameters"] = dataset.get_parameters()
new_kwargs.update(kwargs)

# validate arguments
@@ -167,7 +165,7 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
return super().from_dataset(dataset, **new_kwargs)

def step(self, x, y, batch_idx, label) -> Dict[str, torch.Tensor]:
-        log, out = self.step(x, y, batch_idx=batch_idx, label=label)
+        log, out = super().step(x, y, batch_idx=batch_idx, label=label)
self._log_interpretation(x, out, batch_idx=batch_idx, label=label)
return log, out

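Two of these fixes are easy to miss: `**kwargs` now reaches the base model's `__init__`, and `step` delegates to `super().step(...)` where it previously called `self.step(...)` and recursed without bound. The corrected override pattern, reduced to a self-contained sketch with hypothetical class names:

```python
class BaseModel:
    def step(self, x, y):
        # shared train/validation logic lives in the base class
        return {"loss": 0.0}, {"prediction": y}


class NBeatsLike(BaseModel):
    def step(self, x, y):
        # self.step(x, y) here would recurse forever; delegate instead
        log, out = super().step(x, y)
        return log, out


log, out = NBeatsLike().step(None, [1.0])
```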
tests/conftest.py (49 changes: 49 additions & 0 deletions)

@@ -5,3 +5,52 @@

sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../..")))
sys.path.insert(0, "examples")

+from examples.data import get_stallion_data
+from pytorch_forecasting import TimeSeriesDataSet
+
+
+@pytest.fixture
+def test_data():
+    data = get_stallion_data()
+    data["month"] = data.date.dt.month.astype(str)
+    data["log_volume"] = np.log1p(data.volume)
+    data["weight"] = 1 + np.sqrt(data.volume)
+
+    data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
+    data["time_idx"] -= data["time_idx"].min()
+
+    special_days = [
+        "easter_day",
+        "good_friday",
+        "new_year",
+        "christmas",
+        "labor_day",
+        "independence_day",
+        "revolution_day_memorial",
+        "regional_games",
+        "fifa_u_17_world_cup",
+        "football_gold_cup",
+        "beer_capital",
+        "music_fest",
+    ]
+    data[special_days] = data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
+
+    data = data[lambda x: x.time_idx < 10]  # downsample
+    return data
+
+
+@pytest.fixture
+def test_dataset(test_data):
+    training = TimeSeriesDataSet(
+        test_data,
+        time_idx="time_idx",
+        target="volume",
+        time_varying_known_reals=["price_regular"],
+        group_ids=["agency", "sku"],
+        static_categoricals=["agency"],
+        max_encoder_length=5,
+        max_prediction_length=2,
+        randomize_length=None,
+    )
+    return training
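With these fixtures in `tests/conftest.py`, any test module can request the prepared frame or dataset by argument name; pytest resolves the chain (`test_dataset` depends on `test_data`). A hypothetical consumer:

```python
def test_dataset_not_empty(test_dataset):
    # pytest injects the fixture from conftest.py by matching the parameter name
    assert len(test_dataset) > 0
```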