pre split
nepslor committed Feb 26, 2024
1 parent 53c867b commit f1c8aab
Showing 2 changed files with 108 additions and 60 deletions.
88 changes: 58 additions & 30 deletions pyforecaster/forecasting_models/neural_forecasters.py
@@ -46,7 +46,7 @@ def reproject_weights(params, rec_stable=False, monotone=False):
for layer_name in params['params']:
if 'PICNNLayer' in layer_name:
if 'wz' in params['params'][layer_name].keys():
- _reproject(params['params'], layer_name, rec_stable=rec_stable, monotone=monotone)
+ _reproject(params['params'], layer_name, rec_stable=rec_stable, monotone=monotone or ('monotone' in layer_name))
else:
for name in params['params'][layer_name].keys():
_reproject(params['params'][layer_name], name, rec_stable=rec_stable, monotone=monotone or ('monotone' in layer_name))
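For context, reproject_weights enforces the constraints the PICNN needs after every optimizer step: convexity in the control input requires the 'wz' kernels to stay nonnegative, and with this change any layer whose name contains 'monotone' (the renamed decoder further down in this diff) gets the same treatment on its 'wy' kernels. A minimal sketch of such a projection on a plain parameter dict, using a hypothetical helper rather than the repository's _reproject:

import jax.numpy as jnp

def project_nonnegative(layer_params, monotone=False):
    # Hypothetical stand-in for _reproject: clip the convexity-critical kernels at zero.
    out = dict(layer_params)
    if 'wz' in out:
        out['wz'] = {**out['wz'], 'kernel': jnp.maximum(out['wz']['kernel'], 0.0)}
    if monotone and 'wy' in out:
        # 'monotone' layers additionally keep the y-path weights nonnegative
        out['wy'] = {**out['wy'], 'kernel': jnp.maximum(out['wy']['kernel'], 0.0)}
    return out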
@@ -81,7 +81,7 @@ def embedded_loss_fn(params, inputs, targets, model=None, full_model=None):
kl_gaussian_loss = jnp.mean(jnp.mean(ctrl_embedding**2, axis=0) + jnp.mean(ctrl_embedding, axis=0)**2
- jnp.log(jnp.mean(ctrl_embedding**2, axis=0)) - 1) # (sigma^2 + mu^2 - log(sigma^2) - 1)/2

- return target_loss + ctrl_reconstruction_loss + obj_reconstruction_loss + kl_gaussian_loss
+ return target_loss + ctrl_reconstruction_loss + kl_gaussian_loss + 100*obj_reconstruction_loss
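For reference, the comment above is the closed-form KL divergence between a univariate Gaussian and a standard normal,

KL( N(mu, sigma^2) || N(0, 1) ) = (sigma^2 + mu^2 - log(sigma^2) - 1) / 2,

which the code evaluates per embedding dimension using the batch mean for mu and the batch second moment in place of sigma^2, without the 1/2 factor. The reordered return line keeps the same terms but weights obj_reconstruction_loss by 100 relative to the others.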



@@ -193,20 +193,21 @@ def __call__(self, y, u, z):
if self.augment_ctrl_inputs:
y = jnp.hstack([y, -y])

+ y_add_kernel_init = nn.initializers.lecun_normal() if self.rec_activation == identity else partial(positive_lecun, init_type=self.init_type)
# Input-Convex component without bias for the element-wise multiplicative interactions
wzu = nn.relu(nn.Dense(features=self.features_latent, use_bias=True, name='wzu')(u))
wyu = self.rec_activation(nn.Dense(features=self.features_y, use_bias=True, name='wyu')(u))
- z_next = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
- y_next = nn.Dense(features=self.features_out, use_bias=False, name='wy')(y * wyu)
+ z_add = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
+ y_add = nn.Dense(features=self.features_out, use_bias=False, name='wy', kernel_init=y_add_kernel_init)(y * wyu)
u_add = nn.Dense(features=self.features_out, use_bias=True, name='wuz')(u)


if self.layer_normalization:
- y_next = nn.LayerNorm()(y_next)
- z_next = nn.LayerNorm()(z_next)
+ y_add = nn.LayerNorm()(y_add)
+ z_add = nn.LayerNorm()(z_add)
u_add = nn.LayerNorm()(u_add)

- z_next = z_next + y_next + u_add
+ z_next = z_add + y_add + u_add
if not self.prediction_layer:
z_next = self.activation(z_next)
# Traditional NN component only if it's not the prediction layer
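For orientation, in the identifiers used above this layer computes

z_next = activation( wz( z * relu(wzu(u)) ) + wy( y * rec_activation(wyu(u)) ) + wuz(u) )

where wz and wy are bias-free Dense layers, wzu, wyu and wuz carry biases, and the outer activation is skipped in the prediction layer. Convexity of the output in y is preserved because wz is reprojected onto nonnegative weights (see reproject_weights above) and the gate relu(wzu(u)) is nonnegative, so each layer is a nonnegative combination of convex functions followed by a convex, nondecreasing activation. The renaming from z_next/y_next to z_add/y_add and the new y_add_kernel_init change how the additive terms are named and initialized, not this structure.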
@@ -514,8 +515,9 @@ def __call__(self, x, y):
for i in range(self.num_layers):
prediction_layer = i == self.num_layers -1
features_out = self.features_out if prediction_layer else self.features_latent
+ features_latent = self.features_latent if self.features_latent is not None else self.features_out
u, z = PICNNLayer(features_x=self.features_x, features_y=self.features_y, features_out=features_out,
- features_latent=self.features_latent,
+ features_latent=features_latent,
n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
rec_activation=self.rec_activation, init_type=self.init_type,
augment_ctrl_inputs=self.augment_ctrl_inputs,
@@ -563,11 +565,11 @@ class LatentPartiallyICNN(nn.Module):


def setup(self):
- features_latent_encoder = np.minimum(self.n_embeddings, self.features_latent)
- features_latent_decoder = np.minimum(self.n_control, self.features_latent)

- features_latent_encoder = self.features_latent
- features_latent_decoder = self.features_latent
+ features_latent_encoder = self.features_latent if self.features_latent is not None else self.n_embeddings
+ features_latent_decoder = self.features_latent if self.features_latent is not None else self.n_control
+ features_latent_preds = self.features_latent if self.features_latent is not None else self.features_out


features_y = self.n_embeddings*2 if self.augment_ctrl_inputs else self.n_embeddings
self.encoder = PartiallyICNN(num_layers=self.n_encoder_layers, features_x=self.features_x, features_y=self.features_y,
@@ -577,12 +579,12 @@ def setup(self):

self.decoder = PartiallyICNN(num_layers=self.n_decoder_layers, features_x=self.features_x, features_y=features_y,
features_out=self.n_control, features_latent=features_latent_decoder, init_type=self.init_type,
- augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
- z_min=self.z_min, z_max=self.z_max, name='PICNNLayer_decoder')
+ augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, rec_activation=nn.relu,
+ z_min=self.z_min, z_max=self.z_max, name='PICNNLayer_decoder_monotone')


self.picnn = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
- features_out=self.features_out, features_latent=self.features_latent, init_type=self.init_type,
+ features_out=self.features_out, features_latent=features_latent_preds, init_type=self.init_type,
augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
z_min=self.z_min, z_max=self.z_max, name='PICNNLayer_picnn')
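Two changes matter for the decoder: its recurrent activation is pinned to nn.relu, and its name now contains 'monotone', so the updated reproject_weights in the first hunk also keeps its 'wy' kernels nonnegative; the apparent intent is a decoder that is nondecreasing in the latent embedding. A rough, necessary-but-not-sufficient finite-difference check one could run on a trained model, assuming the decode helper used later in this diff is importable from the same module (hypothetical usage, not part of the commit):

import jax.numpy as jnp

def looks_monotone(pars, model, x, ctrl_embedding, eps=1e-3):
    # Bump every latent coordinate upward and check the decoded controls do not decrease.
    base = decode(pars, model, x, ctrl_embedding)
    bumped = decode(pars, model, x, ctrl_embedding + eps)
    return bool(jnp.all(bumped - base >= -1e-6))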

@@ -712,16 +714,18 @@ class PICNN(NN):
z_min: jnp.array = None
z_max: jnp.array = None
n_latent: int = 1
+ inverter_class = None
def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
n_hidden_x: int = 100, n_out: int = 1, n_latent:int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
probabilistic_loss_kind='maximum_likelihood', distribution = 'normal', inverter_learning_rate: float = 0.1, optimization_vars: list = (),
target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
- z_min: jnp.array = None, z_max: jnp.array = None,
+ z_min: jnp.array = None, z_max: jnp.array = None, inverter_class=None,
**scengen_kwgs):

+ inverter_class = optax.adabelief if inverter_class is None else inverter_class

self.set_attr({"inverter_learning_rate":inverter_learning_rate,
"optimization_vars":optimization_vars,
@@ -733,10 +737,12 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
"distribution": distribution,
"z_min": z_min,
"z_max": z_max,
"n_latent":n_latent
"n_latent":n_latent,
"inverter_class":inverter_class
})

self.n_hidden_y = 2 * len(self.optimization_vars) if augment_ctrl_inputs else len(self.optimization_vars)
- self.inverter_optimizer = optax.adabelief(learning_rate=self.inverter_learning_rate)
+ self.inverter_optimizer = inverter_class(learning_rate=self.inverter_learning_rate)

super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
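The optimizer used by the inverter (the gradient loop over the control variables in optimize) is now injectable through the constructor instead of being hard-coded; optax.adabelief remains the default. A minimal usage sketch with hypothetical control column names:

import optax
from pyforecaster.forecasting_models.neural_forecasters import PICNN

# hypothetical example: use plain SGD for the inverter instead of the default AdaBelief
model = PICNN(learning_rate=1e-3, inverter_learning_rate=0.1,
              optimization_vars=['p_batt', 'p_grid'],  # hypothetical control columns
              inverter_class=optax.sgd)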
@@ -1110,7 +1116,7 @@ class LatentStructuredPICNN(PICNN):
decoder_neurons: np.array = None
n_embeddings: int = 10
def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
- n_hidden_x: int = 100, n_out: int = 1, n_latent:int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
+ n_hidden_x: int = 100, n_out: int = 1, n_latent:int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -1152,50 +1158,72 @@ def optimize(self, inputs, objective, n_iter=200, rel_tol=1e-4, recompile_obj=Tr
normalized_inputs, _ = self.get_normalized_inputs(inputs)
x, y = normalized_inputs

- def _objective(ctrl_embedding, x, **objective_kwargs):
+ def _preds_and_regularizations(ctrl_embedding, x):
preds = latent_pred(self.pars, self.model, x, ctrl_embedding)
ctrl_reconstruct = decode(self.pars, self.model, x, ctrl_embedding)
preds_reconstruct, _ , _ = self.predict_batch_training(self.pars, [x, ctrl_reconstruct])
- implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
+ # implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
ctrl_embedding_reconstruct = encode(self.pars, self.model, x, ctrl_reconstruct)
- #implicit_regularization_loss = jnp.mean((ctrl_embedding_reconstruct - ctrl_embedding)**2)
- return objective(preds, ctrl_reconstruct, **objective_kwargs) + implicit_regularization_loss
+ implicit_regularization_on_ctrl_loss = jnp.mean((ctrl_embedding_reconstruct - ctrl_embedding)**2)
+ regularization_loss = implicit_regularization_on_ctrl_loss #+ implicit_regularization_loss
+ return preds, ctrl_reconstruct, regularization_loss

+ def _objective(ctrl_embedding, x, weight=0, **objective_kwargs):
+ preds, ctrl_reconstruct, regularization_loss = _preds_and_regularizations(ctrl_embedding, x)
+ return objective(preds, ctrl_reconstruct, **objective_kwargs) + (weight * regularization_loss).ravel()[0]

self._objective = _objective
# if the objective changes from one call to another, you need to recompile it. Slower but necessary
if recompile_obj or self.iterate is None:
@jit
- def iterate(x, y, opt_state, **objective_kwargs):
+ def iterate(x, y, opt_state, lagrangian=0, **objective_kwargs):
+ ctrl_embedding_history = []
+ values_history = []
for i in range(10):
- values, grads = value_and_grad(partial(_objective, **objective_kwargs))(y, x)
+ ctrl_embedding_history.append(y)
+ values, grads = value_and_grad(partial(_objective, weight=lagrangian, **objective_kwargs))(y, x)
+ #lagrangian = lagrangian + 0.01 * self.inverter_learning_rate * jnp.maximum(_preds_and_regularizations(y, x)[-1], 0)
if vanilla_gd:
y -= grads * 1e-1
else:
updates, opt_state = self.inverter_optimizer.update(grads, opt_state, y)
y = optax.apply_updates(y, updates)
- return y, values
+ values_history.append(values)
+ return y, values, ctrl_embedding_history, values_history, lagrangian
self.iterate = iterate

_, ctrl_embedding, ctrl_reconstruct = self.predict_batch_training(self.pars, [x, y])
opt_state = self.inverter_optimizer.init(ctrl_embedding)
- ctrl_embedding, values_old = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
+ ctrl_embedding, values_old, ctrl_embedding_history_0, values_history_0, lagrangian = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
values_init = np.copy(values_old)

+ l = vmap(jit(partial(latent_pred, params=self.pars, model=self.model, x=x)), in_axes=(0))

# do 10 iterations at a time to speed up, check for convergence
+ ctrl_history = [np.vstack(ctrl_embedding_history_0)]
+ sol_history = [np.vstack(l(ctrl_embedding=np.vstack(ctrl_embedding_history_0)))]
+ lagrangian = jnp.array(0)
+ init_constraint = 0
for i in range(n_iter//10):
- ctrl_embedding, values = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
+ ctrl_embedding, values, ctrl_embedding_history, values_history, lagrangian = self.iterate(x, ctrl_embedding, opt_state, lagrangian=lagrangian, **objective_kwargs)
rel_improvement = (values_old - values) / (np.abs(values_old)+ 1e-12)
values_old = values
if i%10==0:
print(values)

if rel_improvement < rel_tol:
break
+ ctrl_history.append(np.vstack(ctrl_embedding_history))
+ sol_history.append(np.vstack(l(ctrl_embedding=np.vstack(ctrl_embedding_history))))

print('optimization terminated at iter {}, final objective value: {:0.2e} '
'rel improvement: {:0.2e}'.format((i+1)*10, values,
(values_init-values)/(np.abs(values_init)+1e-12)))


ctrl = decode(self.pars, self.model, x, ctrl_embedding)

- y_hat_from_latent = latent_pred(self.pars, self.model, x, ctrl_embedding)
+ y_hat_from_latent = l(ctrl_embedding=ctrl_embedding)
y_hat_from_ctrl_reconstructed, _ ,_ = self.predict_batch_training(self.pars, [x, ctrl])


@@ -1205,4 +1233,4 @@ def iterate(x, y, opt_state, **objective_kwargs):
target_opt = self.predict(inputs)

y_opt = inputs.loc[:, self.optimization_vars].values.ravel()
- return y_opt, inputs, target_opt, values
+ return y_opt, inputs, target_opt, values, ctrl_embedding, np.vstack(ctrl_history), np.vstack(sol_history)
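In summary, the inner loop now searches over the control embedding rather than the raw controls, with an optional penalty keeping the embedding consistent with the encoder-decoder pair. Writing f for latent_pred, d for decode and h for encode, the quantity minimized by _objective is, schematically,

objective( f(x, e), d(x, e) ) + weight * mean( ( h(x, d(x, e)) - e )^2 )

where weight is the lagrangian argument, initialized to jnp.array(0) and left unchanged because its update rule is commented out, so the penalty is inactive by default. The loop also records the per-iteration embeddings and the corresponding latent predictions (ctrl_history, sol_history) and returns them together with the decoded optimal controls.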