From 896b686b31861eba7a9e66d52851e2cb808463a9 Mon Sep 17 00:00:00 2001
From: nepslor
Date: Tue, 9 Jan 2024 11:58:00 +0100
Subject: [PATCH] added StructuredPICNN; corrected initialization bug

---
 pyforecaster/dictionaries.py                  |   3 +-
 .../forecasting_models/neural_forecasters.py | 285 ++++++++++++++----
 tests/test_nns.py                             |  70 ++++-
 3 files changed, 294 insertions(+), 64 deletions(-)

diff --git a/pyforecaster/dictionaries.py b/pyforecaster/dictionaries.py
index beee35a..f8fb15b 100644
--- a/pyforecaster/dictionaries.py
+++ b/pyforecaster/dictionaries.py
@@ -31,5 +31,6 @@ def picnn_param_space(trial):
 HYPERPAR_MAP = {'LinearForecaster': linear_param_space,
                 'LGBForecaster': lgb_param_space,
                 'LGBMHybrid': lgb_param_space,
-                "PICNN": picnn_param_space}
+                "PICNN": picnn_param_space,
+                "StructuredPICNN": picnn_param_space}
 
diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index ac8d6b3..4259291 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -204,15 +204,13 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
         if self.load_path is not None:
             self.load(self.load_path)
 
+        self.optimizer = None
+        self.model = None
+        self.loss_fn = None
+        self.train_step = None
+        self.predict_batch = None
+        self.set_arch()
 
-        self.model = self.set_arch()
-        self.optimizer = optax.adamw(learning_rate=self.learning_rate)
-
-
-        self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
-        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
-            jitting_wrapper(loss_fn, self.predict_batch))
-        self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
 
     def get_class_properties_names(cls):
         attributes = []
@@ -235,6 +233,10 @@ def save(self, save_path):
         with open(save_path, 'wb') as f:
             pk.dump(attrdict, f, protocol=pk.HIGHEST_PROTOCOL)
 
+    def set_params(self, **kwargs):
+        [self.__setattr__(k, v) for k, v in kwargs.items() if k in self.__dict__.keys()]
+        self.set_arch()
+
     def set_attr(self, attrdict):
         [self.__setattr__(k, v) for k, v in attrdict.items()]
 
@@ -252,9 +254,14 @@ def init_arch(nn_init, n_inputs_x=1):
         return init_params
 
     def set_arch(self):
-        model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
+        self.optimizer = optax.adamw(learning_rate=self.learning_rate)
+        self.model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
                                   n_out=self.n_out)
-        return model
+        self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
+        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
+            jitting_wrapper(loss_fn, self.predict_batch))
+        self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
+
 
     def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
         self.to_be_normalized = [c for c in inputs.columns if
@@ -327,7 +334,7 @@ def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step
                 if val_loss[-1] > val_loss[-2]:
                     pars = old_pars
         self.pars = pars
-        super().fit(inputs_val_0, targets_val_0)
+        return self
 
     def training_plots(self, inputs, target, tr_loss, te_loss, savepath, k):
@@ -376,18 +383,13 @@ def predict(self, inputs, return_sigma=False, **kwargs):
     def predict_quantiles(self, inputs, normalize=True, **kwargs):
         if normalize:
-            x, _ = self.get_normalized_inputs(inputs)
+            mu_hat, sigma_hat = self.predict(inputs, return_sigma=True)
         else:
             x = inputs
-        y_hat = self.predict_batch(self.pars, x)
-        y_hat = np.array(y_hat)
-        if self.normalize_target and normalize:
-            y_hat[:, :y_hat.shape[1] // 2] = self.target_scaler.inverse_transform(y_hat[:, :y_hat.shape[1] // 2])
-            y_hat[:, y_hat.shape[1] // 2:] = self.target_scaler.inverse_transform(
-                (y_hat[:, y_hat.shape[1] // 2:]))
-            y_hat[:, y_hat.shape[1] // 2:] = (y_hat[:, y_hat.shape[1] // 2:])** 0.5
-        mu_hat = y_hat[:, :y_hat.shape[1]//2]
-        sigma_hat = y_hat[:, y_hat.shape[1]//2:]
+            y_hat = self.predict_batch(self.pars, x)
+            y_hat = np.array(y_hat)
+            mu_hat = y_hat[:, :y_hat.shape[1]//2]
+            sigma_hat = (y_hat[:, y_hat.shape[1] // 2:]) ** 0.5
 
         preds = np.expand_dims(mu_hat, -1) * np.ones((1, 1, len(self.q_vect)))
         for i, q in enumerate(self.q_vect):
@@ -405,6 +407,8 @@ def get_normalized_inputs(self, inputs, target=None):
             target = target.copy()
             target = self.target_scaler.transform(target)
             target = target.values
+        elif target is not None:
+            target = target.values
 
         return inputs.values, target
 
@@ -427,6 +431,7 @@ class PartiallyICNN(nn.Module):
     augment_ctrl_inputs: bool = False
     layer_normalization:bool = False
     probabilistic: bool = False
+    structured: bool = False
     @nn.compact
     def __call__(self, x, y):
         u = x.copy()
@@ -440,11 +445,12 @@ def __call__(self, x, y):
                                layer_normalization=self.layer_normalization)(y, u, z)
         if self.probabilistic:
             u = x.copy()
-            sigma = jnp.zeros(self.features_out)  # Initialize z_0 to be the same shape as y
+            sigma_len = 1 if self.structured else self.features_out
+            sigma = jnp.zeros(sigma_len)  # initialize the sigma head; a single output in the structured case
             for i in range(self.num_layers):
                 prediction_layer = i == self.num_layers - 1
                 u, sigma = PICNNLayer(features_x=self.features_x, features_y=self.features_y,
-                                      features_out=self.features_out,
+                                      features_out=sigma_len,
                                       n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
                                       rec_activation=self.rec_activation, init_type=self.init_type,
                                       augment_ctrl_inputs=self.augment_ctrl_inputs,
@@ -539,7 +545,7 @@ class PICNN(NN):
     init_type: str = 'normal'
     augment_ctrl_inputs: bool = False
     layer_normalization: bool = False
-
+    probabilistic_loss_kind: str = 'maximum_likelihood'
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
                  n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
@@ -547,44 +553,42 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
                  probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1,
                  optimization_vars: list = (), target_columns: list = None, init_type='normal',
                  augment_ctrl_inputs=False, layer_normalization=False,
-                 optimizer=None, **scengen_kwgs):
-
-        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
-                         nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
-                         normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df,
-                         probabilistic, probabilistic_loss_kind, **scengen_kwgs)
+                 **scengen_kwgs):
 
         self.set_attr({"inverter_learning_rate":inverter_learning_rate,
                        "optimization_vars":optimization_vars,
                        "target_columns":target_columns,
                        "init_type":init_type,
"augment_ctrl_inputs":augment_ctrl_inputs, - "layer_normalization":layer_normalization + "layer_normalization":layer_normalization, + "probabilistic_loss_kind":probabilistic_loss_kind }) - - if load_path is not None: - self.load(load_path) self.n_hidden_y = 2 * len(self.optimization_vars) if augment_ctrl_inputs else len(self.optimization_vars) - self.model = self.set_arch() - self.optimizer = optax.adamw(learning_rate=self.learning_rate) if optimizer is None else optimizer self.inverter_optimizer = optax.adabelief(learning_rate=self.inverter_learning_rate) + super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio, + nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs, + normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, + probabilistic, probabilistic_loss_kind, **scengen_kwgs) + + + def set_arch(self): + self.optimizer = optax.adamw(learning_rate=self.learning_rate) + self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, + features_out=self.n_out, init_type=self.init_type, + augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic) + self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) - if causal_df is not None: - causal_df = (~causal_df.astype(bool)).astype(float) + if self.causal_df is not None: + causal_df = (~self.causal_df.astype(bool)).astype(float) causal_matrix = np.tile(causal_df.values, 2) if self.probabilistic else causal_df.values - self.loss_fn = jitting_wrapper(causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=probabilistic_loss_kind) \ - if not self.probabilistic else jitting_wrapper(probabilistic_causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=probabilistic_loss_kind) + self.loss_fn = jitting_wrapper(causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=self.probabilistic_loss_kind) \ + if not self.probabilistic else jitting_wrapper(probabilistic_causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=self.probabilistic_loss_kind) else: - self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=probabilistic_loss_kind) + self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind) self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) - def set_arch(self): - model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, - features_out=self.n_out, init_type=self.init_type, - augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic) - return model @staticmethod def init_arch(nn_init, n_inputs_x=1, n_inputs_opt=1): @@ -607,7 +611,8 @@ def get_normalized_inputs(self, inputs, target=None): target = target.copy() target = self.target_scaler.transform(target) target = target.values - + elif target is not None: + target = target.values return (x, y), target def optimize(self, inputs, objective, n_iter=200, rel_tol=1e-4, recompile_obj=True, vanilla_gd=False, **objective_kwargs): @@ -681,20 +686,24 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1, optimization_vars: list = (), 
target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False, - optimizer=None, **scengen_kwgs): + **scengen_kwgs): super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio, nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs, normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic, probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type, - augment_ctrl_inputs, layer_normalization, optimizer, **scengen_kwgs) + augment_ctrl_inputs, layer_normalization, **scengen_kwgs) def set_arch(self): - model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, + self.optimizer = optax.adamw(learning_rate=self.learning_rate) + self.model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, features_out=self.n_out, init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic) - return model + self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) + self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else ( + jitting_wrapper(loss_fn, self.predict_batch)) + self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) class PIQCNNSigmoid(PICNN): @@ -709,21 +718,26 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False, probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1, optimization_vars: list = (), target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False, - optimizer=None, **scengen_kwgs): + **scengen_kwgs): super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio, nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs, normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic, probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type, - augment_ctrl_inputs, layer_normalization, optimizer, **scengen_kwgs) + augment_ctrl_inputs, layer_normalization, **scengen_kwgs) def set_arch(self): - model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, + self.optimizer = optax.adamw(learning_rate=self.learning_rate) + self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, features_out=self.n_out, init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid, rec_activation=nn.sigmoid, probabilistic=self.probabilistic) - return model + self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) + self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else ( + jitting_wrapper(loss_fn, self.predict_batch)) + self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) + class RecStablePICNN(PICNN): reproject: bool = True rec_stable = True @@ -734,16 +748,169 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat stopping_rounds=5, subtract_mean_when_normalizing=False, 
causal_df=None, probabilistic=False, probabilistic_loss_kind='maximum_likelihood',inverter_learning_rate: float = 0.1,
                  optimization_vars: list = (), target_columns: list = None, init_type='normal', augment_ctrl_inputs=False,
-                 layer_normalization=False, optimizer=None, **scengen_kwgs):
+                 layer_normalization=False, **scengen_kwgs):
 
         super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
                          nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
                          normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
                          probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type,
-                         augment_ctrl_inputs, layer_normalization, optimizer, **scengen_kwgs)
+                         augment_ctrl_inputs, layer_normalization, **scengen_kwgs)
 
     def set_arch(self):
-        model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
+        self.optimizer = optax.adamw(learning_rate=self.learning_rate)
+        self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
                               features_out=self.n_out, activation=nn.relu,
                               init_type=self.init_type, probabilistic=self.probabilistic)
-        return model
\ No newline at end of file
+        self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
+        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
+            jitting_wrapper(loss_fn, self.predict_batch))
+        self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
+
+
+def structured_loss_fn(params, inputs, targets, model=None, objective=None):
+    predictions = model(params, inputs)
+    structured_loss = jnp.mean((predictions - targets) ** 2)
+    objs_hat = objective(predictions, inputs[1])
+    objs = objective(targets, inputs[1])
+    objective_loss = jnp.mean((objs - objs_hat) ** 2)
+    monotonic_objective_loss = monotonic_objective_relax(objs_hat, objs)
+    #fourier_loss = jnp.mean((unnormalized_fourier_transform(predictions, 20) - unnormalized_fourier_transform(targets, 20))**2)
+    return structured_loss + monotonic_objective_loss #+ objective_loss
+
+def structured_probabilistic_loss_fn(params, inputs, targets, model=None, kind='maximum_likelihood', objective=None):
+    out = model(params, inputs)
+    predictions = out[:, :-1]
+    sigma_square = out[:, -1]
+    objs_hat = objective(predictions, inputs[1])
+    objs = objective(targets, inputs[1])
+
+    if kind == 'maximum_likelihood':
+        objective_loss = jnp.mean(((objs_hat - objs)**2) / sigma_square + jnp.log(sigma_square))
+    elif kind == 'crps':
+        sigma = jnp.sqrt(sigma_square)
+        u = (objs_hat - objs) / sigma
+        objective_loss = jnp.mean(sigma * (u * (2 * jax.scipy.stats.norm.cdf(u) - 1) + 2 * jax.scipy.stats.norm.pdf(u) - 1 / jnp.sqrt(jnp.pi)))
+
+    structured_loss = jnp.mean((predictions - targets) ** 2)
+    monotonic_objective_loss = monotonic_objective_relax(objs_hat, objs)
+    #fourier_loss = jnp.mean((unnormalized_fourier_transform(predictions, 20) - unnormalized_fourier_transform(targets, 20))**2)
+    return structured_loss + monotonic_objective_loss #+ objective_loss
+
+@partial(vmap, in_axes=(0, None))
+def unnormalized_fourier_transform(predictions, n_freq):
+    """
+    Projects the predictions over sin and cos base functions, and returns the coefficients
+    :param predictions: original predictions in time domain
+    :param n_freq: number of harmonics to project on
+    :return: the unnormalized projection coefficients, sin terms first, then cos terms
+    """
+    # project predictions on cos and sin functions
+    t = np.arange(len(predictions))
+    sin_bases = np.array([np.sin(2 * np.pi * i * t / t[-1]) for i in range(1, n_freq + 1)]).T
+    cos_bases = np.array([np.cos(2 * np.pi * i * t / t[-1]) for i in range(1, n_freq + 1)]).T
+    bases = np.hstack([sin_bases, cos_bases])
+    unnormalized_fc = predictions @ bases
+
+    return unnormalized_fc
+
+def monotonic_objective_right(objs_hat, objs):
+    rank = jnp.argsort(objs)
+    rank_hat = jnp.argsort(objs_hat)
+    return -jnp.mean(jnp.abs(jnp.corrcoef(rank, rank_hat)[0, 1]))
+
+def monotonic_objective_relax(objs_hat, objs):
+    key = random.key(0)
+    random_pairs = random.choice(key, len(objs), (len(objs) * 100, 2))
+    d = objs[random_pairs[:, 0]] - objs[random_pairs[:, 1]]
+    d_hat = objs_hat[random_pairs[:, 0]] - objs_hat[random_pairs[:, 1]]
+    discordant = (d * d_hat) < 0
+    return -jnp.mean(d * discordant)
+
+
+class StructuredPICNN(PICNN):
+    reproject: bool = True
+    rec_stable: bool = False
+    monotone: bool = True
+    objective_fun = None
+    def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
+                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
+                 stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
+                 stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
+                 probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1, optimization_vars: list = (),
+                 target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
+                 objective_fun=None, **scengen_kwgs):
+
+        self.objective_fun = objective_fun
+        self.objective = vmap(objective_fun, in_axes=(0, 0))
+
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
+                         nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
+                         normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
+                         probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type,
+                         augment_ctrl_inputs, layer_normalization, **scengen_kwgs)
+
+    def set_arch(self):
+        self.optimizer = optax.adamw(learning_rate=self.learning_rate)
+        self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
+                                   features_out=self.n_out, init_type=self.init_type,
+                                   augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
+                                   rec_activation=nn.sigmoid, probabilistic=self.probabilistic, structured=True)
+        self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
+        self.loss_fn = jitting_wrapper(structured_loss_fn, self.predict_batch, objective=self.objective) if not self.probabilistic else jitting_wrapper(structured_probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, objective=self.objective)
+        self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
+
+
+    def predict(self, inputs, return_sigma=False, return_obj=False, **kwargs):
+        x, _ = self.get_normalized_inputs(inputs)
+        y_hat = self.predict_batch(self.pars, x)
+        y_hat = np.array(y_hat)
+
+        if self.normalize_target:
+            if self.probabilistic:
+                # note: the target scaler cannot be meaningfully applied to the variance head;
+                # do not set normalize_target=True together with probabilistic=True
+                y_hat[:, :-1] = self.target_scaler.inverse_transform(y_hat[:, :-1])
+            else:
+                y_hat = self.target_scaler.inverse_transform(y_hat)
+
+        if self.probabilistic:
+            preds = pd.DataFrame(y_hat[:, :-1], index=inputs.index, columns=self.target_columns)
+            sigma = pd.DataFrame(y_hat[:, -1]**0.5, index=inputs.index, columns=['sigma'])
+            objs = pd.DataFrame(self.objective(y_hat[:, :-1], x[1]), index=inputs.index, columns=['objective'])
+            if return_obj and return_sigma:
+                return preds, objs, sigma
+            elif return_obj:
+                return preds, objs
+            elif return_sigma:
+                return preds, sigma
+            else:
+                return preds
+
+        else:
+            preds = pd.DataFrame(y_hat, index=inputs.index, columns=self.target_columns)
+            objs = pd.DataFrame(self.objective(y_hat, x[1]), index=inputs.index, columns=['objective'])
+
+            if return_obj:
+                return preds, objs
+            else:
+                return preds
+
+
+    def predict_quantiles(self, inputs, normalize=True, **kwargs):
+        if normalize:
+            y_hat, objs_hat, sigma_hat = self.predict(inputs, return_sigma=True, return_obj=True)
+            mu_hat = objs_hat.values
+            sigma_hat = sigma_hat.values
+        else:
+            y_hat = self.predict_batch(self.pars, inputs)
+            y_hat = np.array(y_hat)
+            sigma_hat = (y_hat[:, -1]) ** 0.5
+            mu_hat = self.objective(y_hat[:, :-1], inputs[1]).reshape(-1, 1)
+
+
+        preds = np.expand_dims(mu_hat, -1) * np.ones((1, 1, len(self.q_vect)))
+        for i, q in enumerate(self.q_vect):
+            qs = sigma_hat * np.sqrt(2) * erfinv(2 * q - 1)
+            preds[:, :, i] += qs.reshape(-1, 1)
+        return preds
\ No newline at end of file
diff --git a/tests/test_nns.py b/tests/test_nns.py
index 28d4530..3d1486a 100644
--- a/tests/test_nns.py
+++ b/tests/test_nns.py
@@ -4,7 +4,7 @@ import pandas as pd
 import numpy as np
 import logging
-from pyforecaster.forecasting_models.neural_forecasters import PICNN, RecStablePICNN, NN, PIQCNN, PIQCNNSigmoid
+from pyforecaster.forecasting_models.neural_forecasters import PICNN, RecStablePICNN, NN, PIQCNN, PIQCNNSigmoid, StructuredPICNN
 from pyforecaster.trainer import hyperpar_optimizer
 from pyforecaster.formatter import Formatter
 from pyforecaster.metrics import nmae
@@ -69,7 +69,7 @@ def test_picnn(self):
         m_1 = PICNN(learning_rate=1e-3, batch_size=500, load_path=None, n_hidden_x=20, n_hidden_y=20,
                     n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars,probabilistic=True,
                     probabilistic_loss_kind='crps', rel_tol=-1,
-                    val_ratio=0.2).fit(x_tr, y_tr,n_epochs=10,stats_step=50,savepath_tr_plots=savepath_tr_plots)
+                    val_ratio=0.2).fit(x_tr, y_tr,n_epochs=2,stats_step=50,savepath_tr_plots=savepath_tr_plots)
 
         y_hat_1 = m_1.predict(x_te)
         m_1.save('tests/results/ffnn_model.pk')
@@ -265,15 +265,77 @@ def test_hyperpar_optimization(self):
         n_folds = 1
         cv_idxs = []
         for i in range(n_folds):
-            tr_idx = np.random.randint(0, 2, len(self.x.index), dtype=bool)
+            tr_idx = np.random.randint(0, 2, len(self.x.index[:1000]), dtype=bool)
             te_idx = ~tr_idx
             cv_idxs.append((tr_idx, te_idx))
 
-        study, replies = hyperpar_optimizer(self.x, self.y, model, n_trials=1, metric=nmae,
+        objective = lambda y, ctrl: (y ** 2).mean()
+        def custom_metric(x, t, agg_index=None, inter_normalization=True, **kwargs):
+            obj = x.apply(lambda x: objective(x, None), axis=1)
+            obj_hat = t.apply(lambda x: objective(x, None), axis=1)
+            rank = np.argsort(obj)
+            rank_hat = np.argsort(obj_hat)
+            corr = np.corrcoef(rank, rank_hat)[0, 1]
+            return np.array(corr)
+
+        study, replies = hyperpar_optimizer(self.x.iloc[:1000, :], self.y.iloc[:1000, :], model, n_trials=1, metric=custom_metric,
                                             cv=(f for f in cv_idxs), param_space_fun=None,
                                             hpo_type='full')
 
+    def test_structured_picnn_sigmoid(self):
+
+        #x_cols = np.random.choice(self.x.columns, 5)
+        x = (self.x - self.x.mean(axis=0)) / (self.x.std(axis=0)+0.01)
+        y = (self.y - self.y.mean(axis=0)) / (self.y.std(axis=0)+0.01)
+
+        n_tr = int(len(x) * 0.8)
+
+        objective = lambda y, ctrl: jnp.mean(y**2)
+
+        x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
+                                  y.iloc[n_tr:].copy()]
+
+        savepath_tr_plots = 'tests/results/figs/convexity'
+
+        # if not there, create directory savepath_tr_plots
+        if not exists(savepath_tr_plots):
+            makedirs(savepath_tr_plots)
+
+
+        optimization_vars = x_tr.columns[:20]
+
+        m_1 = StructuredPICNN(learning_rate=1e-4, batch_size=100, load_path=None, n_hidden_x=250, n_hidden_y=250,
+                              n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars,
+                              stopping_rounds=100, layer_normalization=True, objective_fun=objective,
+                              probabilistic=True, probabilistic_loss_kind='crps',
+                              normalize_target=False).fit(x_tr, y_tr, n_epochs=2,
+                                                          savepath_tr_plots=savepath_tr_plots,
+                                                          stats_step=500, rel_tol=-1)
+        from jax import vmap
+        objs = vmap(objective, in_axes=(0, 0))(y_te.values, x_te.values)
+        # guard against test sets with fewer than 5000 rows
+        rnd_idxs = np.random.choice(x_te.shape[0], min(5000, x_te.shape[0]), replace=False)
+        rnd_idxs = rnd_idxs[np.argsort(objs[rnd_idxs])]
+
+        fig, ax = plt.subplots(2, 1, figsize=(5, 10))
+        ax[0].plot(objs[rnd_idxs], label='y_true')
+        ax[0].plot(m_1.predict(x_te.iloc[rnd_idxs, :], return_obj=True)[1].values.ravel(), label='y_hat')
+        ax[1].scatter(objs[rnd_idxs], m_1.predict(x_te.iloc[rnd_idxs, :], return_obj=True)[1].values.ravel(), s=1)
+        ax[0].plot(np.squeeze(m_1.predict_quantiles(x_te.iloc[rnd_idxs, :], return_obj=True)), color='orange', alpha=0.3)
+
+
+        ordered_idx = np.argsort(np.abs(objs[rnd_idxs] - m_1.predict(x_te.iloc[rnd_idxs, :], return_obj=True)[1].values.ravel()))
+
+
+        for r in rnd_idxs[ordered_idx[-10:]]:
+            y_hat = m_1.predict(x_te.iloc[[r], :])
+            #q_hat = m_1.predict_quantiles(x_te.iloc[[r], :])
+            plt.figure()
+            plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
+            plt.plot(y_hat.values.ravel(), label='y_hat')
+            #plt.plot(np.squeeze(q_hat), label='q_hat', color='orange', alpha=0.3)
+            plt.legend()
+
 if __name__ == '__main__':
     unittest.main()
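
The closed-form Gaussian CRPS used above in structured_probabilistic_loss_fn (kind='crps') can be sanity-checked against its sampling definition, CRPS(F, y) = E|X - y| - 0.5 * E|X - X'| with X, X' ~ F. The snippet below is a minimal standalone sketch, not part of the patch; the helper name crps_gaussian and the sample size are illustrative.

    import jax
    import jax.numpy as jnp

    def crps_gaussian(y, mu, sigma):
        # closed form for X ~ N(mu, sigma^2); same expression as in the loss above
        u = (mu - y) / sigma
        return sigma * (u * (2 * jax.scipy.stats.norm.cdf(u) - 1)
                        + 2 * jax.scipy.stats.norm.pdf(u) - 1 / jnp.sqrt(jnp.pi))

    mu, sigma, y = 1.0, 2.0, 0.3
    k1, k2 = jax.random.split(jax.random.key(0))
    x = mu + sigma * jax.random.normal(k1, (200_000,))
    x_prime = mu + sigma * jax.random.normal(k2, (200_000,))
    # Monte-Carlo estimate of E|X - y| - 0.5 * E|X - X'|
    crps_mc = jnp.mean(jnp.abs(x - y)) - 0.5 * jnp.mean(jnp.abs(x - x_prime))
    print(crps_gaussian(y, mu, sigma), crps_mc)  # the two values should agree to ~1e-2

The same Gaussian assumption underlies predict_quantiles, where the q-th quantile is recovered as mu + sigma * sqrt(2) * erfinv(2q - 1).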