diff --git a/pyforecaster/forecasting_models/neural_models/INN.py b/pyforecaster/forecasting_models/neural_models/INN.py
index fddc1b7..724f933 100644
--- a/pyforecaster/forecasting_models/neural_models/INN.py
+++ b/pyforecaster/forecasting_models/neural_models/INN.py
@@ -19,17 +19,36 @@ def loss_fn(params, inputs, targets, model=None):
                               - jnp.log(jnp.mean(predictions**2, axis=0)) - 1)
     return kl_gaussian_loss
 
 
-def end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None, inverter=identity):
-    e_preds = model(params, inputs)
-    e_target = embedder(params, jnp.hstack([inputs[:, targets.shape[1]:], targets]))[:, :targets.shape[1]]
-    e_preds = inverter(params, jnp.hstack([inputs[:, e_preds.shape[1]:], e_preds]))
-    e_target = inverter(params, e_target)
-    err = e_target - e_preds
+def quasi_end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None, inverter=identity):
+
+    # embedded forecasts
+    e_target_hat = model(params, inputs)
+
+    # retrieve the embedding of the target
+    inputs_future = jnp.hstack([inputs[:, targets.shape[1]:], targets])
+    e_target = embedder(params, inputs_future)[:, -targets.shape[1]:]
+
+    # make them match
+    err = e_target - e_target_hat
     persistent_error = e_target - jnp.roll(e_target, 1, axis=1)
     skill_score = jnp.mean(err**2) / jnp.mean(persistent_error**2)
     return skill_score
 
+def full_end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None, inverter=identity):
+
+    # embedded forecasts
+    e_target_hat = model(params, inputs)
+
+    # retrieve the real forecasted target
+    e_inputs = embedder(params, inputs)
+    targets_hat = inverter(params, jnp.hstack([e_inputs[:, targets.shape[1]:], e_target_hat]))[:, -targets.shape[1]:]
+
+    err = targets - targets_hat
+    mse = jnp.mean(err ** 2)
+    return mse
+
+
 class CausalInvertibleModule(nn.Module):
     num_layers: int = 3
     features: int = 32
@@ -111,11 +130,14 @@ def set_arch(self):
                                          self.end_to_end in ['full', 'quasi']) else CausalInvertibleModule(num_layers=self.n_layers, features=self.n_hidden_x)
         self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
-        if self.end_to_end in ['full', 'quasi']:
-            self.loss_fn = jitting_wrapper(end_to_end_loss_fn, self.predict_batch,
+        if self.end_to_end == 'quasi':
+            self.loss_fn = jitting_wrapper(quasi_end_to_end_loss_fn, self.predict_batch,
+                                           embedder=vmap(jitting_wrapper(embed, self.model), in_axes=(None, 0)),
+                                           inverter=identity)
+        elif self.end_to_end == 'full':
+            self.loss_fn = jitting_wrapper(full_end_to_end_loss_fn, self.predict_batch,
                                            embedder=vmap(jitting_wrapper(embed, self.model), in_axes=(None, 0)),
-                                           inverter=vmap(jitting_wrapper(invert, self.model), in_axes=(None, 0)) if
-                                           self.end_to_end == 'full' else identity)
+                                           inverter=vmap(jitting_wrapper(invert, self.model), in_axes=(None, 0)))
         else:
             self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch)
 
@@ -143,7 +165,8 @@ def predict(self, inputs, return_sigma=False, **kwargs):
         if self.end_to_end in ['full', 'quasi']:
             embeddings = embed(self.pars, x, self.model)
             embeddings_hat = self.predict_batch(self.pars, x)
-            y_hat = invert(self.pars, embeddings_hat, self.model)
+            e_future = np.hstack([embeddings[:, embeddings_hat.shape[1]:], embeddings_hat])
+            y_hat = invert(self.pars, e_future, self.model)[:, -embeddings_hat.shape[1]:]
 
             # embedding-predicted embedding distributions
             import matplotlib.pyplot as plt
@@ -155,11 +178,8 @@ def predict(self, inputs, return_sigma=False, **kwargs):
             # inputs-forecast distributions
             if self.target_scaler is not None:
                 y_hat = self.target_scaler.inverse_transform(y_hat)
-            try:
-                ax[1].hist(np.array(y_hat.ravel()), bins=100, alpha=0.5, density=True, label='forecast ')
-            except:
-                print('Error in plotting forecast distribution')
             ax[1].hist(np.array(inputs.values.ravel()), bins=100, alpha=0.5, density=True, label='inputs')
+            ax[1].hist(np.array(y_hat.ravel()), bins=100, alpha=0.5, density=True, label='forecast ')
             ax[1].legend()
         else:
             y_hat = self.predict_batch(self.pars, x)
diff --git a/tests/test_nns.py b/tests/test_nns.py
index fa4075d..6d5f4fc 100644
--- a/tests/test_nns.py
+++ b/tests/test_nns.py
@@ -502,7 +502,7 @@ def test_invertible_causal_nn(self):
 
         m = CausalInvertibleNN(learning_rate=1e-2, batch_size=200, load_path=None,
                                n_hidden_x=144, n_layers=2, normalize_target=False, n_epochs=5, stopping_rounds=20, rel_tol=-1,
-                               end_to_end='quasi', n_hidden_y=300, n_prediction_layers=3, n_out=144).fit(e_tr.iloc[:, :144], e_tr.iloc[:, 144:])
+                               end_to_end='full', n_hidden_y=300, n_prediction_layers=3, n_out=144).fit(e_tr.iloc[:, :144], e_tr.iloc[:, -144:])
         z_hat_ete = m.predict(e_te.iloc[:, :144])
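
Note on the two objectives introduced above: quasi_end_to_end_loss_fn scores the forecast against the target entirely in the INN's embedding space, normalised by a persistence baseline (a skill score), while full_end_to_end_loss_fn maps the predicted embedding back to signal space and takes a plain MSE there; in both cases the known part of the window is prepended so the embedder/inverter always operates on a full-width input, mirroring the fix in predict(). The sketch below is a minimal standalone illustration, not pyforecaster API: toy_embedder/toy_inverter are hypothetical stand-ins for CausalInvertibleModule's embed/invert passes (an elementwise affine map, so the inverter exactly undoes the embedder).

    import jax
    import jax.numpy as jnp

    def toy_embedder(params, x):
        # hypothetical invertible "embedding": elementwise affine map
        a, b = params
        return a * x + b

    def toy_inverter(params, e):
        # exact inverse of toy_embedder
        a, b = params
        return (e - b) / a

    def quasi_loss(params, e_target_hat, inputs, targets, embedder=toy_embedder):
        # compare forecast and target in embedding space, normalised by a
        # persistence baseline, as quasi_end_to_end_loss_fn does
        inputs_future = jnp.hstack([inputs[:, targets.shape[1]:], targets])
        e_target = embedder(params, inputs_future)[:, -targets.shape[1]:]
        err = e_target - e_target_hat
        persistent_error = e_target - jnp.roll(e_target, 1, axis=1)
        return jnp.mean(err ** 2) / jnp.mean(persistent_error ** 2)

    def full_loss(params, e_target_hat, inputs, targets,
                  embedder=toy_embedder, inverter=toy_inverter):
        # invert the predicted embedding back to signal space and take the MSE
        # there, as full_end_to_end_loss_fn does; the embedded inputs are
        # prepended so the inverter always sees a full-width window
        e_inputs = embedder(params, inputs)
        e_future = jnp.hstack([e_inputs[:, targets.shape[1]:], e_target_hat])
        targets_hat = inverter(params, e_future)[:, -targets.shape[1]:]
        return jnp.mean((targets - targets_hat) ** 2)

    # sanity check: a perfect forecast of the embedded target drives both losses to 0
    params = (jnp.array(2.0), jnp.array(0.5))
    x = jax.random.normal(jax.random.PRNGKey(0), (8, 288))
    inputs, targets = x[:, :144], x[:, 144:]
    e_target_hat = toy_embedder(params, jnp.hstack([inputs[:, targets.shape[1]:], targets]))[:, -targets.shape[1]:]
    print(quasi_loss(params, e_target_hat, inputs, targets))  # ~0.0
    print(full_loss(params, e_target_hat, inputs, targets))   # ~0.0

The persistence normalisation in the quasi loss plausibly matters because the embedding scale is itself learned: a raw embedding-space MSE could be driven down by shrinking the embedding, whereas a ratio to the persistence error is invariant to that rescaling.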