Skip to content

Commit

Permalink
Fixed n_out doubling bug: probabilistic output width is now computed in set_arch instead of mutating self.n_out, so repeated model construction (e.g. during Optuna hyperparameter search) no longer compounds the doubling
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Dec 12, 2023
1 parent f2a5724 commit 2d5cc09
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions pyforecaster/forecasting_models/neural_forecasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
if self.load_path is not None:
self.load(self.load_path)

self.n_out = self.n_out*2 if self.probabilistic else self.n_out
self.model = self.set_arch()
self.optimizer = optax.adamw(learning_rate=self.learning_rate)

Expand Down Expand Up @@ -246,7 +245,7 @@ def init_arch(nn_init, n_inputs_x=1):

def set_arch(self):
    """Build and return the feed-forward network for this forecaster.

    Returns:
        FeedForwardModule: network with ``self.n_layers`` layers of
        ``self.n_hidden_x`` neurons each.

    The diff artifact duplicating the ``n_out=`` keyword (old and new diff
    lines both present) is resolved to the post-commit form: the output
    width is doubled here, at construction time, when ``self.probabilistic``
    is set — presumably one location and one scale per target (TODO confirm
    against the loss used for probabilistic training).
    """
    # Compute the doubling locally instead of mutating self.n_out in
    # __init__: that kept self.n_out's original meaning intact and stops
    # repeated set_arch calls (e.g. across Optuna trials) from compounding
    # the *2 on every re-construction.
    n_out = self.n_out * 2 if self.probabilistic else self.n_out
    model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
                              n_out=n_out)
    return model

def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
Expand Down

0 comments on commit 2d5cc09

Please sign in to comment.