Skip to content

Commit

Permalink
added skip_connections and selectors to FFNN
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Dec 20, 2024
1 parent 0421fe2 commit c8bbdaa
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 23 deletions.
3 changes: 2 additions & 1 deletion pyforecaster/forecasting_models/neural_models/INN.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class EndToEndCausalInvertibleModule(nn.Module):

def setup(self):
self.embedder = CausalInvertibleModule(num_layers=self.num_embedding_layers, features=self.features_embedding, scaling_factor=self.scaling_factor)
self.predictor = FeedForwardModule(n_layers=np.hstack([(np.ones(self.num_prediction_layers-1)*self.features_prediction).astype(int), self.n_out]), split_heads=False)
self.predictor = FeedForwardModule(n_layers=np.hstack([(np.ones(self.num_prediction_layers-1)*self.features_prediction).astype(int), self.n_out]),
skip_connection=True, split_heads=False)
self.invert_fun = jax.jit(partial(self.inverter, embedder=self.embedder))
def __call__(self, x):
x = x.copy()
Expand Down
19 changes: 16 additions & 3 deletions pyforecaster/forecasting_models/neural_models/base_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class FeedForwardModule(nn.Module):
n_out: int=None
n_neurons: int=None
split_heads: bool = False
selector:Union[np.array, None]=None
skip_connection:bool=False
@nn.compact
def __call__(self, x):
if isinstance(self.n_layers, int):
Expand All @@ -105,7 +107,15 @@ def __call__(self, x):
layers = layers.astype(int)
else:
layers = self.n_layers
y = nn.Dense(features=layers[-1], name='dense_-1')(x)

if self.skip_connection or self.selector is not None:
if self.selector is not None:
y = x[self.selector]
else:
y = nn.Dense(features=layers[-1], name='dense_-1')(x)
else:
y = 0

for i, n in enumerate(layers):
if i < len(layers)-1:
x = nn.Dense(features=n, name='dense_{}'.format(i))(x)
Expand Down Expand Up @@ -486,15 +496,18 @@ def get_normalized_inputs(self, inputs, target=None):

class FFNN(NN):
def __init__(self, n_out=None, q_vect=None, n_epochs=10, val_ratio=None, nodes_at_step=None, learning_rate=1e-3,
scengen_dict={}, batch_size=None, split_heads=False, **model_kwargs):
scengen_dict={}, batch_size=None, split_heads=False, selector=None, skip_connection=False,
**model_kwargs):
self.split_heads = split_heads
self.selector = selector
self.skip_connection = skip_connection
super().__init__(n_out=n_out, q_vect=q_vect, n_epochs=n_epochs, val_ratio=val_ratio, nodes_at_step=nodes_at_step, learning_rate=learning_rate,
nn_module=FeedForwardModule, scengen_dict=scengen_dict, batch_size=batch_size, **model_kwargs)

def set_arch(self):
self.optimizer = optax.adamw(learning_rate=self.learning_rate)
self.model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
n_out=self.n_out, split_heads=self.split_heads)
n_out=self.n_out, split_heads=self.split_heads, selector=self.selector, skip_connection=self.skip_connection)
self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
jitting_wrapper(loss_fn, self.predict_batch))
Expand Down
84 changes: 65 additions & 19 deletions tests/test_nns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jax.numpy as jnp
import matplotlib.pyplot as plt

import numpy as np
import optax
import pandas as pd
Expand All @@ -13,11 +14,12 @@
from pyforecaster.forecasting_models.neural_models.ICNN import PICNN, RecStablePICNN, PIQCNN, PIQCNNSigmoid, \
StructuredPICNN, LatentStructuredPICNN, latent_pred
from pyforecaster.forecasting_models.neural_models.INN import CausalInvertibleNN
from pyforecaster.forecasting_models.neural_models.base_nn import NN, FFNN, LSTMNN
from pyforecaster.forecasting_models.neural_models.base_nn import FFNN
from pyforecaster.formatter import Formatter
from pyforecaster.trainer import hyperpar_optimizer

from pyforecaster.forecaster import LinearForecaster

class TestFormatDataset(unittest.TestCase):
def setUp(self) -> None:
self.data = pd.read_pickle('tests/data/test_data.zip').droplevel(0, 1)
Expand All @@ -26,7 +28,7 @@ def setUp(self) -> None:
filename=None)
formatter = Formatter(logger=self.logger).add_transform(['all'], lags=np.arange(144),
relative_lags=True)
formatter.add_transform(['all'], ['min', 'max'], agg_bins=[1, 2, 15, 20])
#formatter.add_transform(['all'], ['min', 'max'], agg_bins=[1, 2, 15, 20])
formatter.add_target_transform(['all'], lags=-np.arange(144))

self.x, self.y = formatter.transform(self.data.iloc[:40000])
Expand Down Expand Up @@ -57,10 +59,52 @@ def test_ffnn(self):
if not exists(savepath_tr_plots):
makedirs(savepath_tr_plots)

m = NN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200,
n_out=y_tr.shape[1], n_layers=3).fit(x_tr,y_tr, n_epochs=3, savepath_tr_plots=savepath_tr_plots, stats_step=40)
# m = NN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200,
# n_out=y_tr.shape[1], n_layers=3).fit(x_tr,y_tr, n_epochs=3, savepath_tr_plots=savepath_tr_plots, stats_step=40)
#
# y_hat_1 = m.predict(x_te.iloc[:100, :])

m_lin = LinearForecaster().fit(x_tr, y_tr)
y_hat_lin = m_lin.predict(x_te.iloc[:1000, :])

n_epochs=10
pars = {'learning_rate': 1e-3, 'batch_size': 1000, 'n_hidden_x': 200, 'n_out': y_tr.shape[1]}
m = FFNN(**pars).fit(x_tr,y_tr, n_epochs=n_epochs, savepath_tr_plots=savepath_tr_plots, stats_step=40)
y_hat = m.predict(x_te.iloc[:1000, :])

pars.update({'skip_connection': True})
m_skip = FFNN(**pars).fit(x_tr,y_tr, n_epochs=n_epochs, savepath_tr_plots=savepath_tr_plots, stats_step=40)
y_hat_skip = m_skip.predict(x_te.iloc[:1000, :])

selector = np.tile(np.argwhere(x_tr.columns=='all').ravel(), y_tr.shape[1])
pars.update({'selector': selector})
m_sel = FFNN(**pars).fit(x_tr, y_tr, n_epochs=n_epochs, savepath_tr_plots=savepath_tr_plots, stats_step=40)
y_hat_sel = m_sel.predict(x_te.iloc[:1000, :])

from pyforecaster.metrics import summary_scores, nmae
maes_nn = summary_scores(y_hat, y_te.iloc[:1000], metrics=[nmae], idxs=pd.DataFrame(y_hat.index.hour, index=y_hat.index))
maes_nn_skip = summary_scores(y_hat_skip, y_te.iloc[:1000], metrics=[nmae], idxs=pd.DataFrame(y_hat.index.hour, index=y_hat.index))
maes_nn_sel = summary_scores(y_hat_sel, y_te.iloc[:1000], metrics=[nmae], idxs=pd.DataFrame(y_hat.index.hour, index=y_hat.index))
maes_lin = summary_scores(y_hat_lin, y_te.iloc[:1000], metrics=[nmae], idxs=pd.DataFrame(y_hat.index.hour, index=y_hat.index))

fig, ax = plt.subplots(1, 1)
ax.plot(maes_nn['nmae'].mean(axis=0), label='nn')
ax.plot(maes_nn_skip['nmae'].mean(axis=0), label='nn skip')
ax.plot(maes_nn_sel['nmae'].mean(axis=0), label='nn sel')
ax.plot(maes_lin['nmae'].mean(axis=0), label='lin')
plt.legend()
plt.show()

for i in range(100):
if i % 5 == 0:
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
ax.plot(y_te.iloc[i, :].values, linestyle='--', linewidth=2)
ax.plot(y_hat.iloc[i, :].values, linewidth=1)
ax.plot(y_hat_skip.iloc[i, :].values, linewidth=1)
ax.plot(y_hat_sel.iloc[i, :].values, linewidth=1)
plt.pause(1e-6)
plt.show()

y_hat_1 = m.predict(x_te.iloc[:100, :])

# def test_lstmnn(self):
# # normalize inputs
Expand Down Expand Up @@ -527,40 +571,42 @@ def test_invertible_causal_nn(self):
plt.pause(1e-6)
"""
m = FFNN(n_hidden_x=20, n_layers=2, learning_rate=1e-3, batch_size=500,
load_path=None, n_out=n_out, rel_tol=-1, stopping_rounds=10, n_epochs=10).fit(e_tr.iloc[:, :n_in], e_tr.iloc[:, -n_out:])
load_path=None, n_out=n_out, rel_tol=-1, stopping_rounds=10, n_epochs=2, skip_connection=True).fit(e_tr.iloc[:, :n_in], e_tr.iloc[:, -n_out:])
y_hat = m.predict(e_te.iloc[:, :n_in])


m = CausalInvertibleNN(learning_rate=1e-3, batch_size=300, load_path=None, n_in=n_in,
n_layers=2, normalize_target=False, n_epochs=2, stopping_rounds=30, rel_tol=-1,
end_to_end='full', n_hidden_y=20, n_prediction_layers=3, n_out=n_out,names_exogenous=['all_lag_000']).fit(e_tr.iloc[:, :n_in], e_tr.iloc[:, -n_out:])
z_hat_ete = m.predict(e_te.iloc[:, :n_in])

from pyforecaster.metrics import summary_scores, nmae

maes_nn = summary_scores(y_hat, e_te.iloc[:, -n_out:], metrics=[nmae],idxs=pd.DataFrame(e_te.index.hour, index=e_te.index))
maes_lin = summary_scores(y_hat_lin, e_te.iloc[:, -n_out:], metrics=[nmae],idxs=pd.DataFrame(e_te.index.hour, index=e_te.index))
maes_ete = summary_scores(z_hat_ete, e_te.iloc[:, -n_out:], metrics=[nmae],idxs=pd.DataFrame(e_te.index.hour, index=e_te.index))

fig, ax = plt.subplots(1, 1)
ax.plot(maes_lin['nmae'].mean(axis=0))
ax.plot(maes_nn['nmae'].mean(axis=0))



m = CausalInvertibleNN(learning_rate=1e-2, batch_size=300, load_path=None, n_in=n_in,
n_layers=2, normalize_target=False, n_epochs=2, stopping_rounds=30, rel_tol=-1,
end_to_end='full', n_hidden_y=300, n_prediction_layers=3, n_out=n_out,names_exogenous=['all_lag_000']).fit(e_tr.iloc[:, :n_in], e_tr.iloc[:, -n_out:])

z_hat_ete = m.predict(e_te.iloc[:, :n_in])
ax.plot(maes_lin['nmae'].mean(axis=0), label='lin')
ax.plot(maes_nn['nmae'].mean(axis=0), label='nn')
ax.plot(maes_ete['nmae'].mean(axis=0), label='ete')
plt.legend()
plt.show()

np.mean((z_hat_ete.values- e_te.iloc[:, -n_out:].values)**2)
np.mean((y_hat.values- e_te.iloc[:, -n_out:].values)**2)
np.mean((y_hat_lin.values- e_te.iloc[:, -n_out:].values)**2)

fig, ax = plt.subplots(1, 1, figsize=(4, 3))
for i in range(100):
if i%5 == 0:
plt.cla()
fig, ax = plt.subplots(1, 1, figsize=(4, 3))

ax.plot(e_te.iloc[i, -n_out:].values)
ax.plot(y_hat_lin.iloc[i, :].values, linewidth=1)
ax.plot(y_hat.iloc[i, :].values, linestyle='--')
ax.plot(z_hat_ete.iloc[i, :].values, linestyle='--')
plt.pause(1e-6)

plt.show()


def boxconstr(x, ub, lb):
Expand Down

0 comments on commit c8bbdaa

Please sign in to comment.