Commit

corrected distribution bug
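The output distribution was declared as a mis-typed field on the network modules (distribution: bool = 'normal', a bool annotation with a string default) and passed to the model constructors. This commit removes the field and instead binds distribution=self.distribution, together with kind=self.probabilistic_loss_kind, into the jitted probabilistic loss functions; several constructors now receive z_min/z_max in its place.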
nepslor committed Jan 19, 2024
1 parent 798b846 commit c76d2bc
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions pyforecaster/forecasting_models/neural_forecasters.py
@@ -470,7 +470,6 @@ class PartiallyICNN(nn.Module):
     layer_normalization: bool = False
     probabilistic: bool = False
     structured: bool = False
-    distribution: bool = 'normal'
     z_max: jnp.array = None
     z_min: jnp.array = None
     @nn.compact
@@ -620,8 +619,7 @@ def set_arch(self):
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
                                    features_out=self.n_out, init_type=self.init_type,
-                                   augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
-                                   distribution=self.distribution, z_min=self.z_min,
+                                   augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, z_min=self.z_min,
                                    z_max=self.z_max)

         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
@@ -631,7 +629,7 @@ def set_arch(self):
             self.loss_fn = jitting_wrapper(causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=self.probabilistic_loss_kind) \
                 if not self.probabilistic else jitting_wrapper(probabilistic_causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=self.probabilistic_loss_kind)
         else:
-            self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind)
+            self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution)

         self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

@@ -746,10 +744,9 @@ def set_arch(self):
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
                                     features_out=self.n_out, init_type=self.init_type,
-                                    augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
-                                    distribution=self.distribution)
+                                    augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
-        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
+        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else (
             jitting_wrapper(loss_fn, self.predict_batch))
         self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

@@ -781,10 +778,10 @@ def set_arch(self):
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
                                    features_out=self.n_out, init_type=self.init_type,
                                    augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
-                                   rec_activation=nn.sigmoid, probabilistic=self.probabilistic,
-                                   distribution=self.distribution)
+                                   rec_activation=nn.sigmoid, probabilistic=self.probabilistic, z_min=self.z_min,
+                                   z_max=self.z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
-        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
+        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else (
             jitting_wrapper(loss_fn, self.predict_batch))
         self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

@@ -810,9 +807,10 @@ def set_arch(self):
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
                                    features_out=self.n_out, activation=nn.relu,
-                                   init_type=self.init_type, probabilistic=self.probabilistic, distribution=self.distribution)
+                                   init_type=self.init_type, probabilistic=self.probabilistic, z_min=self.z_min,
+                                   z_max=self.z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
-        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
+        self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else (
             jitting_wrapper(loss_fn, self.predict_batch))
         self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

@@ -901,9 +899,10 @@ def set_arch(self):
                                    features_out=self.n_out, init_type=self.init_type,
                                    augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
                                    rec_activation=nn.sigmoid, probabilistic=self.probabilistic, structured=True,
-                                   distribution=self.distribution)
+                                   z_min=self.z_min, z_max=self.z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
-        self.loss_fn = jitting_wrapper(structured_loss_fn, self.predict_batch, objective=self.objective) if not self.probabilistic else jitting_wrapper(structured_probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, objective=self.objective, distribution=self.distribution)
+        self.loss_fn = jitting_wrapper(structured_loss_fn, self.predict_batch, objective=self.objective) if not self.probabilistic \
+            else jitting_wrapper(structured_probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, objective=self.objective, distribution=self.distribution)
         self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

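For context, a minimal sketch of the pattern this diff converges on: the distribution is a property of the loss, so it is bound into the jitted loss function rather than stored on the Flax module. The toy normal/Laplace negative log-likelihood and the wrapper below are illustrative assumptions, not pyforecaster's actual probabilistic_loss_fn or jitting_wrapper.

# Illustrative sketch only: the real probabilistic_loss_fn and
# jitting_wrapper are defined elsewhere in pyforecaster and may differ.
import jax
import jax.numpy as jnp

def probabilistic_loss(preds, y, distribution='normal'):
    # Assumes preds stacks location and scale on the last axis.
    mu, sigma = preds[..., 0], jnp.maximum(preds[..., 1], 1e-6)
    if distribution == 'normal':
        nll = 0.5 * jnp.log(2 * jnp.pi * sigma**2) + (y - mu)**2 / (2 * sigma**2)
    else:  # 'laplace'
        nll = jnp.log(2 * sigma) + jnp.abs(y - mu) / sigma
    return nll.mean()

def jitting_wrapper(fn, predictor, **kwargs):
    # Bind static keyword arguments (kind, distribution, ...) before jitting,
    # so the Python-level branch on `distribution` is resolved at trace time.
    def wrapped(params, x, y):
        return fn(predictor(params, x), y, **kwargs)
    return jax.jit(wrapped)

# Usage mirroring the calls in this diff:
# loss_fn = jitting_wrapper(probabilistic_loss, predict_batch,
#                           distribution='normal')

Binding the string before jax.jit matters: a traced argument cannot drive a Python if statement, while a closed-over constant can, and changing the distribution simply produces a freshly compiled loss.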
