Skip to content

Commit

Permalink
removed FastLinReg
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Dec 7, 2023
1 parent c103a2e commit fe95359
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tests/test_nns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pyforecaster.formatter import Formatter
from pyforecaster.metrics import nmae
from os import makedirs
from os.path import exists
from os.path import exists, join
import jax.numpy as jnp

class TestFormatDataset(unittest.TestCase):
Expand Down Expand Up @@ -91,8 +91,7 @@ def test_picnn(self):
plt.figure(layout='tight')
plt.plot(np.tile(x[cc].values.reshape(-1, 1), 96), y_hat.values[:, :96], alpha=0.3)
plt.xlabel(cc)
plt.show()
plt.savefig('wp3/results/figs/convexity/picnn_{}.png'.format(cc), dpi=300)
plt.savefig(join(savepath_tr_plots,'picnn_{}.png'.format(cc)), dpi=300)

n = PICNN(load_path='tests/results/ffnn_model.pk')
y_hat_2 = n.predict(x_te.iloc[:100, :])
Expand Down Expand Up @@ -176,7 +175,6 @@ def test_pqicnn(self):
plt.figure(layout='tight')
plt.plot(np.tile(x[cc].values.reshape(-1, 1), 96), y_hat.values[:, :96], alpha=0.3)
plt.xlabel(cc)
plt.show()
plt.savefig('wp3/results/figs/convexity/{}.png'.format(cc), dpi=300)

def test_optimization(self):
Expand All @@ -200,7 +198,7 @@ def test_optimization(self):
m = PICNN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200, n_hidden_y=200,
n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars, inverter_learning_rate=1e-3,
augment_ctrl_inputs=True, layer_normalization=True, unnormalized_inputs=optimization_vars).fit(x_tr, y_tr,
n_epochs=3,
n_epochs=1,
savepath_tr_plots=savepath_tr_plots,
stats_step=40)

Expand All @@ -225,7 +223,7 @@ def test_hyperpar_optimization(self):

model = PICNN(optimization_vars=self.x.columns[:10], n_out=self.y.shape[1], n_epochs=6)

n_folds = 2
n_folds = 1
cv_idxs = []
for i in range(n_folds):
tr_idx = np.random.randint(0, 2, len(self.x.index), dtype=bool)
Expand Down

0 comments on commit fe95359

Please sign in to comment.