diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py index 9912d0b..f86c9f3 100644 --- a/pyforecaster/forecasting_models/neural_forecasters.py +++ b/pyforecaster/forecasting_models/neural_forecasters.py @@ -470,7 +470,6 @@ class PartiallyICNN(nn.Module): layer_normalization:bool = False probabilistic: bool = False structured: bool = False - distribution: bool = 'normal' z_max: jnp.array = None z_min: jnp.array = None @nn.compact @@ -620,8 +619,7 @@ def set_arch(self): self.optimizer = optax.adamw(learning_rate=self.learning_rate) self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, features_out=self.n_out, init_type=self.init_type, - augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, - distribution=self.distribution, z_min=self.z_min, + augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, z_min=self.z_min, z_max=self.z_max) self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) @@ -631,7 +629,7 @@ def set_arch(self): self.loss_fn = jitting_wrapper(causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=self.probabilistic_loss_kind) \ if not self.probabilistic else jitting_wrapper(probabilistic_causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=self.probabilistic_loss_kind) else: - self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind) + self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) @@ -746,10 +744,9 @@ def set_arch(self): self.optimizer = optax.adamw(learning_rate=self.learning_rate) self.model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, features_out=self.n_out, init_type=self.init_type, - augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, - distribution=self.distribution) + augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic) self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) - self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else ( + self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else ( jitting_wrapper(loss_fn, self.predict_batch)) self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) @@ -781,10 +778,10 @@ def set_arch(self): self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, features_out=self.n_out, init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid, - rec_activation=nn.sigmoid, probabilistic=self.probabilistic, - distribution=self.distribution) + rec_activation=nn.sigmoid, probabilistic=self.probabilistic,z_min=self.z_min, + z_max=self.z_max) self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) - self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else ( + self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else ( jitting_wrapper(loss_fn, self.predict_batch)) self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) @@ -810,9 +807,10 @@ def set_arch(self): self.optimizer = optax.adamw(learning_rate=self.learning_rate) self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y, features_out=self.n_out, activation=nn.relu, - init_type=self.init_type, probabilistic=self.probabilistic, distribution=self.distribution) + init_type=self.init_type, probabilistic=self.probabilistic,z_min=self.z_min, + z_max=self.z_max) self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) - self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else ( + self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else ( jitting_wrapper(loss_fn, self.predict_batch)) self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer) @@ -901,9 +899,10 @@ def set_arch(self): features_out=self.n_out, init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid, rec_activation=nn.sigmoid, probabilistic=self.probabilistic, structured=True, - distribution=self.distribution) + z_min=self.z_min, z_max=self.z_max) self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0)) - self.loss_fn = jitting_wrapper(structured_loss_fn, self.predict_batch, objective=self.objective) if not self.probabilistic else jitting_wrapper(structured_probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, objective=self.objective, distribution=self.distribution) + self.loss_fn = jitting_wrapper(structured_loss_fn, self.predict_batch, objective=self.objective) if not self.probabilistic \ + else jitting_wrapper(structured_probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, objective=self.objective, distribution=self.distribution) self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)