corrected skill score, allowed history length different from steps ahead
nepslor committed Mar 6, 2024
1 parent 0fd9f1f commit f6429a8
Showing 2 changed files with 36 additions and 16 deletions.
pyforecaster/forecasting_models/neural_models/INN.py (50 changes: 35 additions & 15 deletions)
@@ -19,17 +19,36 @@ def loss_fn(params, inputs, targets, model=None):
                        - jnp.log(jnp.mean(predictions**2, axis=0)) - 1)
     return kl_gaussian_loss
 
-def end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None, inverter=identity):
-    e_preds = model(params, inputs)
-    e_target = embedder(params, jnp.hstack([inputs[:, targets.shape[1]:], targets]))[:, :targets.shape[1]]
-    e_preds = inverter(params, jnp.hstack([inputs[:, e_preds.shape[1]:], e_preds]))
-    e_target = inverter(params, e_target)
-    err = e_target - e_preds
+def quasi_end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None, inverter=identity):
+
+    # embedded forecasts
+    e_target_hat = model(params, inputs)
+
+    # retrieve the embedding of the target
+    inputs_future = jnp.hstack([inputs[:, targets.shape[1]:], targets])
+    e_target = embedder(params, inputs_future)[:, -targets.shape[1]:]
+
+    # make them match
+    err = e_target - e_target_hat
     persistent_error = e_target - jnp.roll(e_target, 1, axis=1)
     skill_score = jnp.mean(err**2) / jnp.mean(persistent_error**2)
     return skill_score
+
+
+def full_end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None, inverter=identity):
+
+    # embedded forecasts
+    e_target_hat = model(params, inputs)
+
+    # retrieve the real forecasted target
+    e_inputs = embedder(params, inputs)
+    targets_hat = inverter(params, jnp.hstack([e_inputs[:, targets.shape[1]:], e_target_hat]))[:, -targets.shape[1]:]
+
+    err = targets - targets_hat
+    mse = jnp.mean(err ** 2)
+    return mse
 
 
 class CausalInvertibleModule(nn.Module):
     num_layers: int = 3
     features: int = 32
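Note that the target embedding is now taken from the last targets.shape[1] columns of the embedded future window rather than the first ones, so the embedded history can be longer than the forecast horizon: this is the "history length different from steps ahead" of the commit title. The skill score itself normalizes the forecast error by a lag-1 persistence baseline. A minimal, self-contained sketch of that score (illustrative only, with dummy data; assumes 2-D (batch, horizon) arrays):

import jax.numpy as jnp

def skill_score(e_target_hat, e_target):
    # forecast MSE normalized by the MSE of a lag-1 persistence baseline
    # taken along the horizon axis; values below 1 beat persistence
    err = e_target - e_target_hat
    persistent_error = e_target - jnp.roll(e_target, 1, axis=1)
    return jnp.mean(err ** 2) / jnp.mean(persistent_error ** 2)

# toy data: 4 samples, 24-step horizon
e_target = jnp.tile(jnp.linspace(0., 1., 24), (4, 1))
e_target_hat = e_target + 0.01                 # near-perfect forecast
print(skill_score(e_target_hat, e_target))     # far below 1: beats persistence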
@@ -111,11 +130,14 @@ def set_arch(self):
                                              self.end_to_end in ['full', 'quasi']) else CausalInvertibleModule(num_layers=self.n_layers, features=self.n_hidden_x)
 
         self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
-        if self.end_to_end in ['full', 'quasi']:
-            self.loss_fn = jitting_wrapper(end_to_end_loss_fn, self.predict_batch,
+        if self.end_to_end == 'quasi':
+            self.loss_fn = jitting_wrapper(quasi_end_to_end_loss_fn, self.predict_batch,
                                            embedder=vmap(jitting_wrapper(embed, self.model), in_axes=(None, 0)),
+                                           inverter=identity)
+        elif self.end_to_end == 'full':
+            self.loss_fn = jitting_wrapper(full_end_to_end_loss_fn, self.predict_batch,
+                                           embedder=vmap(jitting_wrapper(embed, self.model), in_axes=(None, 0)),
-                                           inverter=vmap(jitting_wrapper(invert, self.model), in_axes=(None, 0)) if
-                                           self.end_to_end == 'full' else identity)
+                                           inverter=vmap(jitting_wrapper(invert, self.model), in_axes=(None, 0)))
         else:
             self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch)

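To see what the 'full' branch optimizes, the loss can be exercised with identity stand-ins for the embedder and inverter (hypothetical callables for illustration; the real ones are the vmapped, jitted embed and invert built above). With identity maps it reduces to a plain signal-space MSE:

import jax.numpy as jnp

identity_map = lambda params, x: x            # stand-in embedder/inverter
naive_model = lambda params, x: x[:, -24:]    # hypothetical: repeat last 24 inputs

inputs, targets = jnp.ones((4, 144)), jnp.ones((4, 24))
e_target_hat = naive_model(None, inputs)      # (4, 24) embedded forecast
e_inputs = identity_map(None, inputs)         # (4, 144) embedded history
targets_hat = identity_map(None, jnp.hstack([e_inputs[:, targets.shape[1]:],
                                             e_target_hat]))[:, -targets.shape[1]:]
mse = jnp.mean((targets - targets_hat) ** 2)  # 0.0 here: forecast equals targets

The 'quasi' branch instead keeps inverter=identity and measures the error in embedding space, which avoids backpropagating through the inverse at the cost of optimizing a surrogate objective.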
@@ -143,7 +165,8 @@ def predict(self, inputs, return_sigma=False, **kwargs):
         if self.end_to_end in ['full', 'quasi']:
             embeddings = embed(self.pars, x, self.model)
             embeddings_hat = self.predict_batch(self.pars, x)
-            y_hat = invert(self.pars, embeddings_hat, self.model)
+            e_future = np.hstack([embeddings[:, embeddings_hat.shape[1]:], embeddings_hat])
+            y_hat = invert(self.pars, e_future, self.model)[:, -embeddings_hat.shape[1]:]
 
             # embedding-predicted embedding distributions
             import matplotlib.pyplot as plt
@@ -155,11 +178,8 @@
             # inputs-forecast distributions
             if self.target_scaler is not None:
                 y_hat = self.target_scaler.inverse_transform(y_hat)
-            try:
-                ax[1].hist(np.array(y_hat.ravel()), bins=100, alpha=0.5, density=True, label='forecast ')
-            except:
-                print('Error in plotting forecast distribution')
             ax[1].hist(np.array(inputs.values.ravel()), bins=100, alpha=0.5, density=True, label='inputs')
+            ax[1].hist(np.array(y_hat.ravel()), bins=100, alpha=0.5, density=True, label='forecast ')
             ax[1].legend()
         else:
             y_hat = self.predict_batch(self.pars, x)
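The predict change mirrors the loss fix: instead of inverting the raw forecast, the predicted embedding is stitched onto the tail of the embedded history so that invert sees a full-length window. A NumPy shape sketch (dummy arrays and assumed sizes; the real invert needs trained parameters, so an identity stand-in is used for the final slice):

import numpy as np

history, horizon = 144, 24                    # assumed sizes
embeddings = np.zeros((4, history))           # stands in for embed(pars, x, model)
embeddings_hat = np.zeros((4, horizon))       # stands in for predict_batch(pars, x)

# drop the oldest `horizon` columns, append the forecast
e_future = np.hstack([embeddings[:, embeddings_hat.shape[1]:], embeddings_hat])
assert e_future.shape == (4, history)         # full-length window for invert()

# after inverting, only the trailing forecast columns are kept
y_hat = e_future[:, -embeddings_hat.shape[1]:]
assert y_hat.shape == (4, horizon)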
tests/test_nns.py (2 changes: 1 addition & 1 deletion)
@@ -502,7 +502,7 @@ def test_invertible_causal_nn(self):
 
         m = CausalInvertibleNN(learning_rate=1e-2, batch_size=200, load_path=None, n_hidden_x=144,
                                n_layers=2, normalize_target=False, n_epochs=5, stopping_rounds=20, rel_tol=-1,
-                               end_to_end='quasi', n_hidden_y=300, n_prediction_layers=3, n_out=144).fit(e_tr.iloc[:, :144], e_tr.iloc[:, 144:])
+                               end_to_end='full', n_hidden_y=300, n_prediction_layers=3, n_out=144).fit(e_tr.iloc[:, :144], e_tr.iloc[:, -144:])
 
         z_hat_ete = m.predict(e_te.iloc[:, :144])

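The test now exercises the 'full' path, and the target slice switches from iloc[:, 144:] to iloc[:, -144:]. The two coincide only when the frame has exactly 288 columns; with a wider frame the negative slice is what keeps the target width equal to n_out. A self-contained sketch with a hypothetical 300-column frame:

import numpy as np
import pandas as pd

e_tr = pd.DataFrame(np.zeros((10, 300)))    # hypothetical: 300 columns
assert e_tr.iloc[:, 144:].shape[1] == 156   # old slice: width drifts with the frame
assert e_tr.iloc[:, -144:].shape[1] == 144  # new slice: always matches n_out=144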
