added implicit regularization while optimizing
nepslor committed Jan 23, 2024
1 parent 4c73b27 commit 88ea395
Showing 2 changed files with 50 additions and 20 deletions.
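This commit introduces an implicit regularization term in two places: the autoencoder training loss (embedded_loss_fn) now also penalizes the distance between the model's predictions and the predictions obtained when the decoder's control reconstruction is fed back through the model, and the latent-space inverter in LatentStructuredPICNN.optimize adds the same consistency penalty to the user objective. In addition, the encoder and decoder are promoted from FeedForwardModule to PartiallyICNN instances, and the inversion loop now descends on the control embedding rather than on the raw controls.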
57 changes: 43 additions & 14 deletions pyforecaster/forecasting_models/neural_forecasters.py
@@ -67,9 +67,11 @@ def loss_fn(params, inputs, targets, model=None):

def embedded_loss_fn(params, inputs, targets, model=None):
    predictions, ctrl_embedding, ctrl_reconstruction = model(params, inputs)
+    predictions_from_ctrl_reconstr, _, _ = model(params, [inputs[0], ctrl_reconstruction])
    target_loss = jnp.mean((predictions - targets) ** 2)
    ctrl_reconstruction_loss = jnp.mean((ctrl_reconstruction - inputs[1]) ** 2)
-    return target_loss + ctrl_reconstruction_loss
+    obj_reconstruction_loss = jnp.mean((predictions - predictions_from_ctrl_reconstr) ** 2)
+    return target_loss + ctrl_reconstruction_loss + obj_reconstruction_loss
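As a quick sanity check of the three-term loss above, here is a minimal, self-contained sketch with a stand-in model (toy_model, its parameter names, and all shapes are assumptions for illustration; the real model is the latent PICNN defined further down):

import jax
import jax.numpy as jnp

def toy_model(params, inputs):
    # stand-in returning (predictions, ctrl_embedding, ctrl_reconstruction)
    x, ctrl = inputs
    ctrl_embedding = jnp.tanh(ctrl @ params['enc'])
    predictions = x @ params['out_x'] + ctrl_embedding @ params['out_e']
    ctrl_reconstruction = ctrl_embedding @ params['dec']
    return predictions, ctrl_embedding, ctrl_reconstruction

keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {'enc': 0.1 * jax.random.normal(keys[0], (3, 4)),
          'dec': 0.1 * jax.random.normal(keys[1], (4, 3)),
          'out_x': 0.1 * jax.random.normal(keys[2], (5, 2)),
          'out_e': 0.1 * jax.random.normal(keys[3], (4, 2))}
inputs = [jnp.ones((8, 5)), jnp.ones((8, 3))]  # [x, ctrl]
targets = jnp.zeros((8, 2))

# gradients flow through all three terms, including the reconstruction path
loss, grads = jax.value_and_grad(embedded_loss_fn)(params, inputs, targets, model=toy_model)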

def probabilistic_loss(y_hat, y, sigma_square, kind='maximum_likelihood', distribution='normal'):

@@ -550,8 +552,15 @@ def setup(self):
        ctrl_embedding_len = self.encoder_neurons[-1]
        features_y = 2*ctrl_embedding_len if self.augment_ctrl_inputs else ctrl_embedding_len

-        self.encoder = FeedForwardModule(n_layers=self.encoder_neurons, name='encoder')
-        self.decoder = FeedForwardModule(n_layers=self.decoder_neurons, name='decoder')
+        self.encoder = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=self.features_y,
+                                     features_out=self.encoder_neurons[-1], features_latent=self.features_latent, init_type=self.init_type,
+                                     augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
+                                     z_min=self.z_min, z_max=self.z_max, name='encoder')
+
+        self.decoder = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
+                                     features_out=self.decoder_neurons[-1], features_latent=self.features_latent, init_type=self.init_type,
+                                     augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
+                                     z_min=self.z_min, z_max=self.z_max, name='decoder')


        self.picnn = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
@@ -561,23 +570,32 @@

    def __call__(self, x, y):

-        ctrl_embedding = self.encoder(jnp.hstack([x, y]))
+        ctrl_embedding = self.encoder(x, y)
        z = self.picnn(x, ctrl_embedding)
-        ctrl_reconstruction = self.decoder(ctrl_embedding)
+        ctrl_reconstruction = self.decoder(x, ctrl_embedding)


        return z, ctrl_embedding, ctrl_reconstruction

-    def decode(self, ctrl_embedding):
-        return self.decoder(ctrl_embedding)
+    def decode(self, x, ctrl_embedding):
+        return self.decoder(x, ctrl_embedding)

    def latent_pred(self, x, ctrl_embedding):
        return self.picnn(x, ctrl_embedding)


-def decode(params, model, ctrl_embedding):
+def decode(params, model, x, ctrl_embedding):
    def decoder(lpicnn):
-        return lpicnn.decode(ctrl_embedding)
+        return lpicnn.decode(x, ctrl_embedding)

    return nn.apply(decoder, model)(params)

+def latent_pred(params, model, ctrl_embedding, x):
+    def _latent_pred(lpicnn):
+        return lpicnn.latent_pred(ctrl_embedding, x)
+
+    return nn.apply(_latent_pred, model)(params)
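The decode and latent_pred wrappers above call bound methods of a Flax module functionally via nn.apply. A self-contained sketch of the same pattern with a toy module (Toy and its method names are assumptions, not pyforecaster code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Toy(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)

    def scaled(self, x):
        return 2 * self(x)

model = Toy()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))

def scaled(params, model, x):
    def _scaled(m):
        return m.scaled(x)  # m is the module bound to params
    return nn.apply(_scaled, model)(params)

out = scaled(params, model, jnp.ones((1, 3)))

One detail worth noting: the wrapper latent_pred takes (ctrl_embedding, x) while the module method takes (x, ctrl_embedding). The call sites in optimize pass (x, ctrl_embedding) positionally, so the two swaps cancel and the method receives its arguments in the intended order.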


class PartiallyIQCNN(nn.Module):
    num_layers: int
    features_x: int
@@ -1061,6 +1079,8 @@ class LatentStructuredPICNN(PICNN):
    rec_stable: bool = False
    monotone: bool = True
    objective_fun = None
+    encoder_neurons: np.array = None
+    decoder_neurons: np.array = None
    def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                 val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
@@ -1109,8 +1129,11 @@ def optimize(self, inputs, objective, n_iter=200, rel_tol=1e-4, recompile_obj=Tr
        x, y = normalized_inputs

        def _objective(ctrl_embedding, x, **objective_kwargs):
-            ctrl = decode(self.pars, self.model, ctrl_embedding)
-            return objective(self.predict_batch(self.pars, [x, ctrl_embedding]), ctrl, **objective_kwargs)
+            preds = latent_pred(self.pars, self.model, x, ctrl_embedding)
+            ctrl_reconstruct = decode(self.pars, self.model, x, ctrl_embedding)
+            preds_reconstruct, _, _ = self.predict_batch_training(self.pars, [x, ctrl_reconstruct])
+            implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
+            return objective(preds, ctrl_embedding, **objective_kwargs) + implicit_regularization_loss

        # if the objective changes from one call to another, you need to recompile it. Slower but necessary
        if recompile_obj or self.iterate is None:
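Note the changed contract of _objective: the user objective now receives the predictions and the latent embedding (previously the decoded controls), and the added implicit_regularization_loss penalizes disagreement between predictions computed from the raw embedding and predictions computed from its decoded controls. Objectives written against decoded controls, such as the old 0.0001*jnp.sum(ctrl**2) penalty in the tests, have to be adapted accordingly.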
@@ -1126,8 +1149,8 @@ def iterate(x, y, opt_state, **objective_kwargs):
            return y, values
        self.iterate = iterate

-        opt_state = self.inverter_optimizer.init(y)
        _, ctrl_embedding, ctrl_reconstruct = self.predict_batch_training(self.pars, [x, y])
+        opt_state = self.inverter_optimizer.init(ctrl_embedding)
        ctrl_embedding, values_old = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
        values_init = np.copy(values_old)
@@ -1144,9 +1167,15 @@ def iterate(x, y, opt_state, **objective_kwargs):
                                 (values_init-values)/(np.abs(values_init)+1e-12)))


-        y = decode(self.pars, self.model, ctrl_embedding)
+        ctrl = decode(self.pars, self.model, x, ctrl_embedding)

-        inputs.loc[:, self.optimization_vars] = y.ravel()
+        y_hat_from_latent = latent_pred(self.pars, self.model, x, ctrl_embedding)
+        y_hat_from_ctrl_reconstructed, _, _ = self.predict_batch_training(self.pars, [x, ctrl])
+
+        plt.plot(y_hat_from_latent.ravel())
+        plt.plot(y_hat_from_ctrl_reconstructed.ravel())
+
+        inputs.loc[:, self.optimization_vars] = ctrl.ravel()
        inputs.loc[:, [c for c in inputs.columns if c not in self.optimization_vars]] = x.ravel()
        inputs.loc[:, self.to_be_normalized] = self.input_scaler.inverse_transform(inputs[self.to_be_normalized].values)
        target_opt = self.predict(inputs)
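The two plt.plot calls overlay y_hat_from_latent (predictions evaluated directly on the optimized embedding) with y_hat_from_ctrl_reconstructed (predictions evaluated on the decoded controls); with the new consistency term the two curves should nearly coincide.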
13 changes: 7 additions & 6 deletions tests/test_nns.py
@@ -372,15 +372,15 @@ def test_latent_picnn(self):
        m = LatentStructuredPICNN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200,
                                  n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars, inverter_learning_rate=1e-3,
                                  augment_ctrl_inputs=True, layer_normalization=True, unnormalized_inputs=optimization_vars,
-                                 n_first_encoder=20, n_last_encoder=100, n_first_decoder=100).fit(x_tr, y_tr,
-                                                                                                  n_epochs=1,
+                                 n_first_encoder=20, n_last_encoder=200, n_first_decoder=100).fit(x_tr, y_tr,
+                                                                                                  n_epochs=2,
                                                                                                   savepath_tr_plots=savepath_tr_plots,
-                                                                                                  stats_step=40 )
+                                                                                                  stats_step=40)

-        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2) + 0.0001*jnp.sum(ctrl**2)
-        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective, n_iter=5000)
+        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2)
+        m.inverter_optimizer = optax.adabelief(learning_rate=1e-3)
+        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective, n_iter=500)

-        plt.plot(y_hat_opt.values.ravel())
-        rnd_idxs = np.random.choice(x_te.shape[0], 1)
+        rnd_idxs = [100]
        for r in rnd_idxs:
@@ -389,6 +389,7 @@ def test_latent_picnn(self):
            plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
            plt.plot(y_hat.values.ravel(), label='y_hat')
            plt.legend()
+        plt.plot(y_hat_opt.values.ravel())
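Since optimize now passes the latent embedding as the objective's second argument, a control-magnitude penalty would act on the embedding rather than on the decoded controls; a hypothetical weighted variant of the test objective could read:

objective = lambda y_hat, ctrl_embedding: jnp.mean(y_hat ** 2) + 1e-4 * jnp.sum(ctrl_embedding ** 2)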

if __name__ == '__main__':
    unittest.main()
