From 88ea3959d446994e7fe7b96e18ed327aa77bde1a Mon Sep 17 00:00:00 2001
From: nepslor
Date: Tue, 23 Jan 2024 13:45:52 +0100
Subject: [PATCH] Added implicit regularization while optimizing

---
 .../forecasting_models/neural_forecasters.py | 51 ++++++++++++-----
 tests/test_nns.py                            | 13 +++--
 2 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index c0fd1b3..492e337 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -67,9 +67,11 @@ def loss_fn(params, inputs, targets, model=None):
 
 def embedded_loss_fn(params, inputs, targets, model=None):
     predictions, ctrl_embedding, ctrl_reconstruction = model(params, inputs)
+    predictions_from_ctrl_reconstr, _, _ = model(params, [inputs[0], ctrl_reconstruction])
     target_loss = jnp.mean((predictions - targets) ** 2)
     ctrl_reconstruction_loss = jnp.mean((ctrl_reconstruction - inputs[1]) ** 2)
-    return target_loss + ctrl_reconstruction_loss
+    obj_reconstruction_loss = jnp.mean((predictions - predictions_from_ctrl_reconstr) ** 2)
+    return target_loss + ctrl_reconstruction_loss + obj_reconstruction_loss
 
 
 def probabilistic_loss(y_hat, y, sigma_square, kind='maximum_likelihood', distribution='normal'):
@@ -550,8 +552,15 @@ def setup(self):
         ctrl_embedding_len = self.encoder_neurons[-1]
         features_y = 2*ctrl_embedding_len if self.augment_ctrl_inputs else ctrl_embedding_len
 
-        self.encoder = FeedForwardModule(n_layers=self.encoder_neurons, name='encoder')
-        self.decoder = FeedForwardModule(n_layers=self.decoder_neurons, name='decoder')
+        self.encoder = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=self.features_y,
+                                     features_out=self.encoder_neurons[-1], features_latent=self.features_latent,
+                                     init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs,
+                                     probabilistic=self.probabilistic, z_min=self.z_min, z_max=self.z_max, name='encoder')
+
+        self.decoder = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
+                                     features_out=self.decoder_neurons[-1], features_latent=self.features_latent,
+                                     init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs,
+                                     probabilistic=self.probabilistic, z_min=self.z_min, z_max=self.z_max, name='decoder')
 
 
         self.picnn = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
@@ -561,23 +570,32 @@ def setup(self):
 
 
     def __call__(self, x, y):
-        ctrl_embedding = self.encoder(jnp.hstack([x, y]))
+        ctrl_embedding = self.encoder(x, y)
         z = self.picnn(x, ctrl_embedding)
-        ctrl_reconstruction = self.decoder(ctrl_embedding)
+        ctrl_reconstruction = self.decoder(x, ctrl_embedding)
         return z, ctrl_embedding, ctrl_reconstruction
 
-    def decode(self, ctrl_embedding):
-        return self.decoder(ctrl_embedding)
+    def decode(self, x, ctrl_embedding):
+        return self.decoder(x, ctrl_embedding)
 
+    def latent_pred(self, x, ctrl_embedding):
+        return self.picnn(x, ctrl_embedding)
 
 
-def decode(params, model, ctrl_embedding):
+def decode(params, model, x, ctrl_embedding):
     def decoder(lpicnn ):
-        return lpicnn.decode(ctrl_embedding)
+        return lpicnn.decode(x, ctrl_embedding)
 
     return nn.apply(decoder, model)(params)
 
 
+def latent_pred(params, model, x, ctrl_embedding):
+    def _latent_pred(lpicnn):
+        return lpicnn.latent_pred(x, ctrl_embedding)
+
+    return nn.apply(_latent_pred, model)(params)
+
+
 class PartiallyIQCNN(nn.Module):
     num_layers: int
     features_x: int
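The new obj_reconstruction_loss term in embedded_loss_fn is a cycle-consistency penalty: predictions computed from the true controls and predictions computed from the decoded control reconstruction must agree, otherwise the latent space the optimizer later searches is one the PICNN head cannot be trusted on. The toy sketch below illustrates the three-term loss with a hypothetical linear stand-in for the model; only the (predictions, ctrl_embedding, ctrl_reconstruction) output convention is taken from the patch, every other name and shape is illustrative.

import jax
import jax.numpy as jnp

# Hypothetical linear "model" following the (predictions, ctrl_embedding,
# ctrl_reconstruction) convention of LatentStructuredPICNN; illustrative only.
def toy_model(params, inputs):
    x, ctrl = inputs
    W_enc, W_dec, W_out = params
    ctrl_embedding = ctrl @ W_enc                  # "encoder"
    ctrl_reconstruction = ctrl_embedding @ W_dec   # "decoder"
    predictions = x @ W_out + ctrl_embedding       # stand-in for the PICNN head
    return predictions, ctrl_embedding, ctrl_reconstruction

def embedded_loss(params, inputs, targets, model=toy_model):
    predictions, _, ctrl_reconstruction = model(params, inputs)
    # re-run the model on the reconstructed controls: if the autoencoder is
    # faithful, the predictions should barely move
    predictions_cycled, _, _ = model(params, [inputs[0], ctrl_reconstruction])
    target_loss = jnp.mean((predictions - targets) ** 2)
    ctrl_reconstruction_loss = jnp.mean((ctrl_reconstruction - inputs[1]) ** 2)
    obj_reconstruction_loss = jnp.mean((predictions - predictions_cycled) ** 2)
    return target_loss + ctrl_reconstruction_loss + obj_reconstruction_loss

key = jax.random.PRNGKey(0)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
params = (jax.random.normal(k1, (4, 2)),           # W_enc
          jax.random.normal(k2, (2, 4)),           # W_dec
          jax.random.normal(k3, (3, 2)))           # W_out
x, ctrl = jax.random.normal(k4, (8, 3)), jax.random.normal(k5, (8, 4))
print(embedded_loss(params, [x, ctrl], jnp.zeros((8, 2))))
grads = jax.grad(embedded_loss)(params, [x, ctrl], jnp.zeros((8, 2)))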
@@ -1061,6 +1079,8 @@ class LatentStructuredPICNN(PICNN):
     rec_stable: bool = False
     monotone: bool = True
     objective_fun=None
+    encoder_neurons: np.array = None
+    decoder_neurons: np.array = None
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
                  n_hidden_x: int = 100, n_out: int = 1, n_latent:int = 1, n_layers: int = 3, pars: dict = None,
                  q_vect=None, val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
@@ -1109,8 +1129,11 @@ def optimize(self, inputs, objective, n_iter=200, rel_tol=1e-4, recompile_obj=Tr
         x, y = normalized_inputs
 
         def _objective(ctrl_embedding, x, **objective_kwargs):
-            ctrl = decode(self.pars, self.model, ctrl_embedding)
-            return objective(self.predict_batch(self.pars, [x, ctrl_embedding]), ctrl, **objective_kwargs)
+            preds = latent_pred(self.pars, self.model, x, ctrl_embedding)
+            ctrl_reconstruct = decode(self.pars, self.model, x, ctrl_embedding)
+            preds_reconstruct, _, _ = self.predict_batch_training(self.pars, [x, ctrl_reconstruct])
+            implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
+            return objective(preds, ctrl_embedding, **objective_kwargs) + implicit_regularization_loss
 
         # if the objective changes from one call to another, you need to recompile it. Slower but necessary
         if recompile_obj or self.iterate is None:
@@ -1126,8 +1149,8 @@ def iterate(x, y, opt_state, **objective_kwargs):
                 return y, values
             self.iterate = iterate
 
-        opt_state = self.inverter_optimizer.init(y)
         _, ctrl_embedding, ctrl_reconstruct = self.predict_batch_training(self.pars, [x, y])
+        opt_state = self.inverter_optimizer.init(ctrl_embedding)
         ctrl_embedding, values_old = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
         values_init = np.copy(values_old)
 
@@ -1144,9 +1167,9 @@
                                                    (values_init-values)/(np.abs(values_init)+1e-12)))
 
-        y = decode(self.pars, self.model, ctrl_embedding)
+        ctrl = decode(self.pars, self.model, x, ctrl_embedding)
 
-        inputs.loc[:, self.optimization_vars] = y.ravel()
+        inputs.loc[:, self.optimization_vars] = ctrl.ravel()
         inputs.loc[:, [c for c in inputs.columns if c not in self.optimization_vars]] = x.ravel()
         inputs.loc[:, self.to_be_normalized] = self.input_scaler.inverse_transform(inputs[self.to_be_normalized].values)
         target_opt = self.predict(inputs)
 
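With this patch, optimize() descends directly in the control-embedding space and augments the user objective with the same cycle-consistency idea: predictions from the embedding and predictions from its decoded reconstruction are kept close, discouraging the optimizer from wandering into regions where the decoder is unreliable. A minimal sketch of the loop follows, with hypothetical linear stand-ins for latent_pred, decode and predict_batch_training; only the structure of _objective and the optimizer initialization on the embedding are taken from the patch, and optax.adabelief mirrors the inverter optimizer set in the updated test.

import jax
import jax.numpy as jnp
import optax

# Hypothetical linear stand-ins; the real functions are the PICNN ones above.
A = jnp.ones((3, 2))
latent_pred_fn = lambda x, e: x @ A + e    # prediction from the latent controls
decode_fn = lambda x, e: 0.9 * e           # slightly lossy decoder
predict_fn = lambda x, c: x @ A + c        # prediction from decoded controls

def _objective(ctrl_embedding, x):
    preds = latent_pred_fn(x, ctrl_embedding)
    ctrl_reconstruct = decode_fn(x, ctrl_embedding)
    preds_reconstruct = predict_fn(x, ctrl_reconstruct)
    # keep the descent inside the region where decoding is self-consistent
    implicit_regularization_loss = jnp.mean((preds_reconstruct - preds) ** 2)
    return jnp.mean(preds ** 2) + implicit_regularization_loss

x = jnp.ones((1, 3))
ctrl_embedding = jnp.ones((1, 2))
optimizer = optax.adabelief(learning_rate=1e-2)
opt_state = optimizer.init(ctrl_embedding)  # init on the embedding, the actual decision variable
for _ in range(200):
    values, grads = jax.value_and_grad(_objective)(ctrl_embedding, x)
    updates, opt_state = optimizer.update(grads, opt_state, ctrl_embedding)
    ctrl_embedding = optax.apply_updates(ctrl_embedding, updates)
print(values)  # the regularized objective decreases as the embedding is optimized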
diff --git a/tests/test_nns.py b/tests/test_nns.py
index de8a872..8170d59 100644
--- a/tests/test_nns.py
+++ b/tests/test_nns.py
@@ -372,15 +372,15 @@ def test_latent_picnn(self):
         m = LatentStructuredPICNN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200, n_out=y_tr.shape[1],
                                   n_layers=3, optimization_vars=optimization_vars, inverter_learning_rate=1e-3,
                                   augment_ctrl_inputs=True, layer_normalization=True, unnormalized_inputs=optimization_vars,
-                                  n_first_encoder=20, n_last_encoder=100, n_first_decoder=100).fit(x_tr, y_tr,
-                                                                                                   n_epochs=1,
+                                  n_first_encoder=20, n_last_encoder=200, n_first_decoder=100).fit(x_tr, y_tr,
+                                                                                                   n_epochs=2,
                                                                                                    savepath_tr_plots=savepath_tr_plots,
-                                                                                                   stats_step=40 )
+                                                                                                   stats_step=40)
 
-        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2) + 0.0001*jnp.sum(ctrl**2)
-        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective,n_iter=5000)
+        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2)
+        m.inverter_optimizer = optax.adabelief(learning_rate=1e-3)
+        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective, n_iter=500)
 
-        plt.plot(y_hat_opt.values.ravel())
         rnd_idxs = np.random.choice(x_te.shape[0], 1)
         rnd_idxs = [100]
         for r in rnd_idxs:
@@ -389,6 +389,7 @@ def test_latent_picnn(self):
             plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
             plt.plot(y_hat.values.ravel(), label='y_hat')
             plt.legend()
+            plt.plot(y_hat_opt.values.ravel())
 
 if __name__ == '__main__':
     unittest.main()
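Distilled from the updated test, a usage sketch of the end-to-end flow; x_tr, y_tr, x_te and optimization_vars are assumed to be prepared as earlier in tests/test_nns.py, and the keyword values simply mirror the test above.

import jax.numpy as jnp
import optax
from pyforecaster.forecasting_models.neural_forecasters import LatentStructuredPICNN

# x_tr, y_tr, x_te, optimization_vars: assumed prepared as in tests/test_nns.py
m = LatentStructuredPICNN(learning_rate=1e-3, batch_size=1000, load_path=None,
                          n_hidden_x=200, n_out=y_tr.shape[1], n_layers=3,
                          optimization_vars=optimization_vars,
                          inverter_learning_rate=1e-3, augment_ctrl_inputs=True,
                          layer_normalization=True,
                          unnormalized_inputs=optimization_vars,
                          n_first_encoder=20, n_last_encoder=200,
                          n_first_decoder=100).fit(x_tr, y_tr, n_epochs=2)

# simple user objective; the implicit regularization term is added internally
objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2)
m.inverter_optimizer = optax.adabelief(learning_rate=1e-3)
ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :],
                                                    objective=objective, n_iter=500)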