diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index 4c1c657..6706b59 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -46,7 +46,7 @@ def reproject_weights(params, rec_stable=False, monotone=False):
     for layer_name in params['params']:
         if 'PICNNLayer' in layer_name:
             if 'wz' in params['params'][layer_name].keys():
-                _reproject(params['params'], layer_name, rec_stable=rec_stable, monotone=monotone)
+                _reproject(params['params'], layer_name, rec_stable=rec_stable, monotone=monotone or ('monotone' in layer_name))
             else:
                 for name in params['params'][layer_name].keys():
                     _reproject(params['params'][layer_name], name, rec_stable=rec_stable, monotone=monotone or ('monotone' in layer_name))
@@ -81,7 +81,7 @@ def embedded_loss_fn(params, inputs, targets, model=None, full_model=None):
     kl_gaussian_loss = jnp.mean(jnp.mean(ctrl_embedding**2, axis=0) + jnp.mean(ctrl_embedding, axis=0)**2 - jnp.log(jnp.mean(ctrl_embedding**2, axis=0)) - 1) # (sigma^2 + mu^2 - log(sigma^2) - 1)/2

-    return target_loss + ctrl_reconstruction_loss + obj_reconstruction_loss + kl_gaussian_loss
+    return target_loss + ctrl_reconstruction_loss + kl_gaussian_loss + 100*obj_reconstruction_loss


@@ -193,20 +193,21 @@ def __call__(self, y, u, z):
         if self.augment_ctrl_inputs:
             y = jnp.hstack([y, -y])

+        y_add_kernel_init = nn.initializers.lecun_normal() if self.rec_activation == identity else partial(positive_lecun, init_type=self.init_type)
         # Input-Convex component without bias for the element-wise multiplicative interactions
         wzu = nn.relu(nn.Dense(features=self.features_latent, use_bias=True, name='wzu')(u))
         wyu = self.rec_activation(nn.Dense(features=self.features_y, use_bias=True, name='wyu')(u))
-        z_next = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
-        y_next = nn.Dense(features=self.features_out, use_bias=False, name='wy')(y * wyu)
+        z_add = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
+        y_add = nn.Dense(features=self.features_out, use_bias=False, name='wy', kernel_init=y_add_kernel_init)(y * wyu)
         u_add = nn.Dense(features=self.features_out, use_bias=True, name='wuz')(u)

         if self.layer_normalization:
-            y_next = nn.LayerNorm()(y_next)
-            z_next = nn.LayerNorm()(z_next)
+            y_add = nn.LayerNorm()(y_add)
+            z_add = nn.LayerNorm()(z_add)
             u_add = nn.LayerNorm()(u_add)

-        z_next = z_next + y_next + u_add
+        z_next = z_add + y_add + u_add
         if not self.prediction_layer:
             z_next = self.activation(z_next)
             # Traditional NN component only if it's not the prediction layer
@@ -514,8 +515,9 @@ def __call__(self, x, y):
         for i in range(self.num_layers):
             prediction_layer = i == self.num_layers -1
             features_out = self.features_out if prediction_layer else self.features_latent
+            features_latent = self.features_latent if self.features_latent is not None else self.features_out
             u, z = PICNNLayer(features_x=self.features_x, features_y=self.features_y, features_out=features_out,
-                              features_latent=self.features_latent,
+                              features_latent=features_latent,
                              n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
                              rec_activation=self.rec_activation, init_type=self.init_type,
                              augment_ctrl_inputs=self.augment_ctrl_inputs,
@@ -563,11 +565,11 @@ class LatentPartiallyICNN(nn.Module):

     def setup(self):
-        features_latent_encoder = np.minimum(self.n_embeddings, self.features_latent)
-        features_latent_decoder = np.minimum(self.n_control, self.features_latent)
-        features_latent_encoder = self.features_latent
-        features_latent_decoder = self.features_latent
+        features_latent_encoder = self.features_latent if self.features_latent is not None else self.n_embeddings
+        features_latent_decoder = self.features_latent if self.features_latent is not None else self.n_control
+        features_latent_preds = self.features_latent if self.features_latent is not None else self.features_out
+
         features_y = self.n_embeddings*2 if self.augment_ctrl_inputs else self.n_embeddings

         self.encoder = PartiallyICNN(num_layers=self.n_encoder_layers, features_x=self.features_x, features_y=self.features_y,
@@ -577,12 +579,12 @@ def setup(self):
         self.decoder = PartiallyICNN(num_layers=self.n_decoder_layers, features_x=self.features_x, features_y=features_y,
                                      features_out=self.n_control, features_latent=features_latent_decoder, init_type=self.init_type,
-                                     augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
-                                     z_min=self.z_min, z_max=self.z_max, name='PICNNLayer_decoder')
+                                     augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, rec_activation=nn.relu,
+                                     z_min=self.z_min, z_max=self.z_max, name='PICNNLayer_decoder_monotone')

         self.picnn = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
-                                   features_out=self.features_out, features_latent=self.features_latent, init_type=self.init_type,
+                                   features_out=self.features_out, features_latent=features_latent_preds, init_type=self.init_type,
                                    augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
                                    z_min=self.z_min, z_max=self.z_max, name='PICNNLayer_picnn')
@@ -712,6 +714,7 @@ class PICNN(NN):
     z_min: jnp.array = None
     z_max: jnp.array = None
     n_latent: int = 1
+    inverter_class = None

     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
                  n_hidden_x: int = 100, n_out: int = 1, n_latent:int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
@@ -719,9 +722,10 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
                  probabilistic_loss_kind='maximum_likelihood', distribution = 'normal',
                  inverter_learning_rate: float = 0.1, optimization_vars: list = (), target_columns: list = None,
                  init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
-                 z_min: jnp.array = None, z_max: jnp.array = None,
+                 z_min: jnp.array = None, z_max: jnp.array = None, inverter_class=None,
                  **scengen_kwgs):

+        inverter_class = optax.adabelief if inverter_class is None else inverter_class
         self.set_attr({"inverter_learning_rate":inverter_learning_rate,
                        "optimization_vars":optimization_vars,
@@ -733,10 +737,12 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                        "distribution": distribution,
                        "z_min": z_min,
                        "z_max": z_max,
-                       "n_latent":n_latent
+                       "n_latent":n_latent,
+                       "inverter_class":inverter_class
                        })
+
         self.n_hidden_y = 2 * len(self.optimization_vars) if augment_ctrl_inputs else len(self.optimization_vars)
-        self.inverter_optimizer = optax.adabelief(learning_rate=self.inverter_learning_rate)
+        self.inverter_optimizer = inverter_class(learning_rate=self.inverter_learning_rate)

         super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
                          nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
@@ -1110,7 +1116,7 @@ class LatentStructuredPICNN(PICNN):
     decoder_neurons: np.array = None
     n_embeddings: int = 10
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
-                 n_hidden_x: int = 100, n_out: int = 1, n_latent:int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent:int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
                  stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -1152,42 +1158,64 @@ def optimize(self, inputs, objective, n_iter=200, rel_tol=1e-4, recompile_obj=Tr
         normalized_inputs, _ = self.get_normalized_inputs(inputs)
         x, y = normalized_inputs

-        def _objective(ctrl_embedding, x, **objective_kwargs):
+        def _preds_and_regularizations(ctrl_embedding, x):
             preds = latent_pred(self.pars, self.model, x, ctrl_embedding)
             ctrl_reconstruct = decode(self.pars, self.model, x, ctrl_embedding)
             preds_reconstruct, _ , _ = self.predict_batch_training(self.pars, [x, ctrl_reconstruct])
-            implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
+            # implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
             ctrl_embedding_reconstruct = encode(self.pars, self.model, x, ctrl_reconstruct)
-            #implicit_regularization_loss = jnp.mean((ctrl_embedding_reconstruct - ctrl_embedding)**2)
-            return objective(preds, ctrl_reconstruct, **objective_kwargs) + implicit_regularization_loss
+            implicit_regularization_on_ctrl_loss = jnp.mean((ctrl_embedding_reconstruct - ctrl_embedding)**2)
+            regularization_loss = implicit_regularization_on_ctrl_loss #+ implicit_regularization_loss
+            return preds, ctrl_reconstruct, regularization_loss
+
+        def _objective(ctrl_embedding, x, weight=0, **objective_kwargs):
+            preds, ctrl_reconstruct, regularization_loss = _preds_and_regularizations(ctrl_embedding, x)
+            return objective(preds, ctrl_reconstruct, **objective_kwargs) + (weight * regularization_loss).ravel()[0]
+
         self._objective = _objective

         # if the objective changes from one call to another, you need to recompile it. Slower but necessary
         if recompile_obj or self.iterate is None:
             @jit
-            def iterate(x, y, opt_state, **objective_kwargs):
+            def iterate(x, y, opt_state, lagrangian=0, **objective_kwargs):
+                ctrl_embedding_history = []
+                values_history = []
                 for i in range(10):
-                    values, grads = value_and_grad(partial(_objective, **objective_kwargs))(y, x)
+                    ctrl_embedding_history.append(y)
+                    values, grads = value_and_grad(partial(_objective, weight=lagrangian, **objective_kwargs))(y, x)
+                    #lagrangian = lagrangian + 0.01 * self.inverter_learning_rate * jnp.maximum(_preds_and_regularizations(y, x)[-1], 0)
                     if vanilla_gd:
                         y -= grads * 1e-1
                     else:
                         updates, opt_state = self.inverter_optimizer.update(grads, opt_state, y)
                         y = optax.apply_updates(y, updates)
-                return y, values
+                    values_history.append(values)
+                return y, values, ctrl_embedding_history, values_history, lagrangian

             self.iterate = iterate

         _, ctrl_embedding, ctrl_reconstruct = self.predict_batch_training(self.pars, [x, y])
         opt_state = self.inverter_optimizer.init(ctrl_embedding)
-        ctrl_embedding, values_old = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
+        ctrl_embedding, values_old, ctrl_embedding_history_0, values_history_0, lagrangian = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
         values_init = np.copy(values_old)
+        l = vmap(jit(partial(latent_pred, params=self.pars, model=self.model, x=x)), in_axes=(0))

         # do 10 iterations at a time to speed up, check for convergence
+        ctrl_history = [np.vstack(ctrl_embedding_history_0)]
+        sol_history = [np.vstack(l(ctrl_embedding=np.vstack(ctrl_embedding_history_0)))]
+        lagrangian = jnp.array(0)
+        init_constraint = 0
         for i in range(n_iter//10):
-            ctrl_embedding, values = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
+            ctrl_embedding, values, ctrl_embedding_history, values_history, lagrangian = self.iterate(x, ctrl_embedding, opt_state, lagrangian=lagrangian, **objective_kwargs)
             rel_improvement = (values_old - values) / (np.abs(values_old)+ 1e-12)
             values_old = values
+            if i%10==0:
+                print(values)
+
             if rel_improvement < rel_tol:
                 break
+            ctrl_history.append(np.vstack(ctrl_embedding_history))
+            sol_history.append(np.vstack(l(ctrl_embedding=np.vstack(ctrl_embedding_history))))
+
         print('optimization terminated at iter {}, final objective value: {:0.2e} '
               'rel improvement: {:0.2e}'.format((i+1)*10, values, (values_init-values)/(np.abs(values_init)+1e-12)))
@@ -1195,7 +1223,7 @@ def iterate(x, y, opt_state, **objective_kwargs):

         ctrl = decode(self.pars, self.model, x, ctrl_embedding)

-        y_hat_from_latent = latent_pred(self.pars, self.model, x, ctrl_embedding)
+        y_hat_from_latent = l(ctrl_embedding=ctrl_embedding)
         y_hat_from_ctrl_reconstructed, _ ,_ = self.predict_batch_training(self.pars, [x, ctrl])

@@ -1205,4 +1233,4 @@ def iterate(x, y, opt_state, **objective_kwargs):
         target_opt = self.predict(inputs)

         y_opt = inputs.loc[:, self.optimization_vars].values.ravel()
-        return y_opt, inputs, target_opt, values
\ No newline at end of file
+        return y_opt, inputs, target_opt, values, ctrl_embedding, np.vstack(ctrl_history), np.vstack(sol_history)
\ No newline at end of file
diff --git a/tests/benchmark_funs_tests.py b/tests/benchmark_funs_tests.py
index bc12c95..efa6b7a 100644
--- a/tests/benchmark_funs_tests.py
+++ b/tests/benchmark_funs_tests.py
@@ -9,6 +9,7 @@
 from os.path import join
 from pyforecaster.forecasting_models.neural_forecasters import latent_pred, encode, decode
 from jax import vmap
+import optax


 def make_dataset(fun, lb, ub, n_sample_per_dim=1000):
@@ -43,47 +44,57 @@ def train_test(test_fun, n_dim, forecaster_class=LatentStructuredPICNN, **foreca
     # create dataset
     ub = np.ones(n_dim)
     lb = -np.ones(n_dim)
-    data = make_dataset(test_fun(ndim=n_dim).evaluate, lb, ub, n_sample_per_dim=500)
+    data = make_dataset(test_fun(ndim=n_dim).evaluate, lb, ub, n_sample_per_dim=600)
     data = data.iloc[np.random.permutation(data.shape[0])]
     optimization_vars = ['x{}'.format(i) for i in range(1, n_dim+1)]
     x_names = ['const'] + optimization_vars

     m = forecaster_class(optimization_vars=optimization_vars, **forecaster_kwargs).fit(data[x_names], data[['y']])
-
+    d = vmap(lambda x, y: decode(m.pars, m.model, x, y), in_axes=(0, 0))
     objective_fun = lambda x, ctrl: jnp.mean(x**2) #+ boxconstr(ctrl, ub, lb)

     print('minimum in the training set: {}'.format(data['y'].min()))
     sol = m.optimize(data[x_names].iloc[[0], :], objective=objective_fun,
-                     n_iter=10000, recompile_obj=False, rel_tol=1e-12)
+                     n_iter=10, recompile_obj=False, rel_tol=-1)

+    e_optobj_convexity_test(data[[c for c in data.columns if c !='y']], m, optimization_vars)
     eo_convexity_test(data[[c for c in data.columns if c != 'y']], m, optimization_vars)
     io_convexity_test(data[[c for c in data.columns if c !='y']], m, optimization_vars)
-    e_optobj_convexity_test(data[[c for c in data.columns if c !='y']], m, optimization_vars)
+
+

     # find the global minimum in the latent space of the LatentStructuredPICNN
-    for i in np.random.choice(data.shape[0], 10):
-        sol = m.optimize(data[x_names].iloc[[i],:], objective=objective_fun,n_iter=10000, recompile_obj=False, rel_tol=1e-12)
+    x_opt, obj_opt, ctrl_coords, obj_opts = [], [],[], []
+
+    for i in np.random.choice(data.shape[0], 6):
+        sol = m.optimize(data[x_names].iloc[[i],:], objective=objective_fun, n_iter=12000, recompile_obj=False, rel_tol=-1e14)
         print(sol[0], sol[2].values.ravel(), sol[3])
+        ctrl_coords.append(m.input_scaler.inverse_transform(d(np.ones((sol[-2].shape[0], 1)), sol[-2])))
+        obj_opts.append(m.predict(pd.DataFrame(np.hstack([np.ones((sol[-2].shape[0], 1)), ctrl_coords[-1]]), columns=x_names)).values.ravel())
+        x_opt.append(sol[0])
+        obj_opt.append(sol[2].values.ravel())
+    x_opt = np.vstack(x_opt)
+    obj_opt = np.vstack(obj_opt)

     if n_dim == 2:
-        savepath = 'wp3/global_optimization/'
+        savepath = 'tests/figs/'

         # scatter plot of the learned function m
         y_hat = m.predict(data[x_names])
-        plot_surface(pd.concat([data[optimization_vars], y_hat], axis=1), scatter=True)
-        plt.show()
-
-        y_hat = m.predict(data[x_names])
-        plot_surface(data, scatter=True)
-        plt.savefig(join(savepath, '{}_ground_truth.png'.format(test_fun.__name__)))
-
-        fig, ax = plot_surface(pd.concat([data[optimization_vars], y_hat], axis=1), scatter=True, alpha=0.2)
-        ax.scatter(*np.hstack([sol[0], sol[2].values.ravel()]), c='r', s=1000, marker='*')
+        fig, ax = plot_surface(pd.concat([data[optimization_vars], y_hat], axis=1), scatter=True, alpha=0.05)
+        ax.scatter(*np.hstack([x_opt, obj_opt]).T, c='r', s=1000, marker='*')
+        #ax.plot(*ctrl_coords.T, sol[-1].ravel(), c='green', marker='s', markersize=2, alpha=0.1)
+        [ax.scatter(*ctrl_i.T, obj_i, c=np.linspace(0, 1, len(obj_i)),s=1, alpha=0.2) for ctrl_i, obj_i in zip(ctrl_coords, obj_opts)]
         for ii in range(0,360,36):
             ax.view_init(elev=10., azim=ii)
             plt.savefig(join(savepath, '{}_latent_optimization_{}.png'.format(test_fun.__name__, ii)))
         plt.show()

+        fig, ax = plt.subplots(1, 1)
+        plot_surface(data, scatter=True)
+        plt.savefig(join(savepath, '{}_ground_truth.png'.format(test_fun.__name__)))
+
+

 def boxconstr(x, ub, lb):
     return jnp.sum(jnp.maximum(0, x - ub)**2 + jnp.maximum(0, lb - x)**2)
@@ -102,6 +113,8 @@ def e_optobj_convexity_test(df, forecaster, ctrl_names, **objective_kwargs):
         approx_second_der = np.round(np.diff(preds, 2, axis=0), 5)
         approx_second_der[approx_second_der == 0] = 0 # to fix the sign
         is_convex = np.all(np.sign(approx_second_der) >= 0)
+        if not is_convex:
+            print('WARNING: learned objective is not convex w.r.t. input {}'.format(ctrl_e))
         print('output is convex w.r.t. input {}: {}'.format(ctrl_e, is_convex))
         plt.plot(ctrls[:, ctrl_e], preds, alpha=0.3)
         plt.pause(0.001)
@@ -149,7 +162,7 @@ def eo_convexity_test(df, forecaster, ctrl_names, **objective_kwargs):

 def test1d(forecaster_class, **forecaster_kwargs):
     fun = lambda x: (0.01 * (x * 10) ** 4 - ((x - 0.1) * 10) ** 2)/10
-    x = np.linspace(-1, 1, 5000)
+    x = np.linspace(-1, 1, 10000)
     data = pd.DataFrame(np.hstack([np.ones((x.shape[0], 1)), x.reshape(-1, 1), fun(x).reshape(-1, 1)]),
                         columns=['const', 'x', 'y'])
     data = data.iloc[np.random.permutation(data.shape[0])]
@@ -204,39 +217,46 @@ def test1d(forecaster_class, **forecaster_kwargs):
     ax.scatter(*E_mesh.T, Z_oracle, c='b', s=1)

     # fun from latent reconstruction tr points only
-    ax.scatter(*E_tr.T, Z_oracle_tr, c='b', s=20, marker='+', alpha=1)
+    ax.scatter(*E_tr.T, Z_oracle_tr, c='b', s=50, marker='+', alpha=1)
+    for i in np.random.choice(data.shape[0], 5):
+        sol = m.optimize(data[['const','x']].iloc[[i],:], objective=lambda x, ctrl: jnp.mean(x), n_iter=5000, recompile_obj=False, rel_tol=-10)
+        ax.plot(*sol[-2].T, sol[-1].ravel(), c='green', marker='s', markersize=2, alpha=0.1)
     plt.show()

-
 if __name__ == '__main__':
+
     forecaster_kwargs = dict(
-        augment_ctrl_inputs = False, n_hidden_x = 3, n_latent=30, n_out = 1, batch_size = 2500,
-        n_epochs = 50, stats_step = 20, n_layers = 3, n_encoder_layers = 3, n_decoder_layers = 3, learning_rate = 1e-3,
-        stopping_rounds = 1000, init_type = 'normal', rel_tol = -2, layer_normalization = False, n_embeddings=2, normalize_target=False,
+        augment_ctrl_inputs = False, n_hidden_x = 3, n_latent=200, n_out = 1, batch_size = 2500,
+        n_epochs = 10, stats_step = 20, n_layers = 3, n_encoder_layers = 3, n_decoder_layers = 3, learning_rate = 1e-3,
+        inverter_learning_rate=1e-1,
+        stopping_rounds = 1000, init_type = 'normal', rel_tol = -2, layer_normalization = False, n_embeddings=3, normalize_target=False,
         unnormalized_inputs=['const'])
-    test1d(forecaster_class=LatentStructuredPICNN, **forecaster_kwargs)
+    train_test(CamelSixHump, 2, forecaster_class=LatentStructuredPICNN, **forecaster_kwargs)

     forecaster_kwargs = dict(
         augment_ctrl_inputs = False, n_hidden_x = 3, n_latent=30, n_out = 1, batch_size = 2500,
-        n_epochs = 50, stats_step = 20, n_layers = 3, n_encoder_layers = 3, n_decoder_layers = 3, learning_rate = 1e-3,
+        n_epochs = 300, stats_step = 20, n_layers = 3, n_encoder_layers = 3, n_decoder_layers = 3, learning_rate = 1e-3,
+        inverter_learning_rate=1e-1,
         stopping_rounds = 1000, init_type = 'normal', rel_tol = -2, layer_normalization = False, n_embeddings=2, normalize_target=False,
         unnormalized_inputs=['const'])
+    test1d(forecaster_class=LatentStructuredPICNN, **forecaster_kwargs)

-    train_test(Ackley01, 2, forecaster_class=LatentStructuredPICNN, **forecaster_kwargs)
+    forecaster_kwargs = dict(
+        augment_ctrl_inputs = False, n_hidden_x = 3, n_latent=50, n_out = 1, batch_size = 2500,
+        n_epochs = 20, stats_step = 20, n_layers = 3, n_encoder_layers = 2, n_decoder_layers = 4, learning_rate = 1e-3,
+        inverter_learning_rate=1e-1, inverter_class=optax.adam,
+        stopping_rounds = 1000, init_type = 'normal', rel_tol = -2, layer_normalization = False, n_embeddings=10, normalize_target=False,
+        unnormalized_inputs=['const'])
+    train_test(Ackley01, 2, forecaster_class=LatentStructuredPICNN, **forecaster_kwargs)

-    forecaster_kwargs = dict(
-        augment_ctrl_inputs = True, n_hidden_x = 4, n_out = 1, batch_size = 1000,
-        n_epochs = 10, stats_step = 100, n_layers = 4, learning_rate = 1e-2, n_latent=100,
-        stopping_rounds = 10, init_type = 'normal', rel_tol = -1, layer_normalization = False, normalize_target=False)
     train_test(Ackley01, 2, forecaster_class=LatentStructuredPICNN, **forecaster_kwargs)
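
Usage note (illustrative sketch, not part of the patch): one way the new inverter_class argument and the extended optimize() return tuple introduced above might be exercised. The toy dataset, column names and hyperparameter values below are made up; the constructor keywords mirror the ones already used in tests/benchmark_funs_tests.py.

import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from pyforecaster.forecasting_models.neural_forecasters import LatentStructuredPICNN

# toy dataset: a constant column, two controls and a convex target (assumed layout)
rng = np.random.default_rng(0)
ctrl = rng.uniform(-1, 1, size=(1000, 2))
data = pd.DataFrame(np.hstack([np.ones((1000, 1)), ctrl, (ctrl ** 2).sum(axis=1, keepdims=True)]),
                    columns=['const', 'x1', 'x2', 'y'])

m = LatentStructuredPICNN(optimization_vars=['x1', 'x2'], n_hidden_x=3, n_latent=50, n_embeddings=2,
                          n_epochs=5, inverter_learning_rate=1e-1,
                          inverter_class=optax.adam,  # new: any optax optimizer factory; defaults to optax.adabelief
                          unnormalized_inputs=['const']).fit(data[['const', 'x1', 'x2']], data[['y']])

# optimize() now also returns the final control embedding and the per-iteration
# embedding and latent-prediction trajectories, in addition to the previous four outputs
y_opt, inputs_opt, target_opt, values, ctrl_embedding, ctrl_history, sol_history = m.optimize(
    data[['const', 'x1', 'x2']].iloc[[0], :],
    objective=lambda preds, ctrl_reconstruct: jnp.mean(preds ** 2),
    n_iter=500, recompile_obj=False, rel_tol=-1)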