corrected bug in discrete benchmark
nepslor committed Dec 19, 2024
1 parent aa4cef4 commit c11976f
Showing 4 changed files with 161 additions and 21 deletions.
18 changes: 12 additions & 6 deletions pyforecaster/forecasting_models/neural_models/INN.py
@@ -9,6 +9,7 @@
from pyforecaster.forecasting_models.neural_models.neural_utils import jitting_wrapper
from functools import partial
import numpy as np
import matplotlib.pyplot as plt

def identity(params, err):
return err
@@ -52,9 +53,10 @@ def full_end_to_end_loss_fn(params, inputs, targets, model=None, embedder=None,
class CausalInvertibleModule(nn.Module):
num_layers: int = 3
features: int = 32
scaling_factor: float = 0.1

def setup(self):
self.layers = [CausalInvertibleLayer(prediction_layer=l==self.num_layers-1, features=self.features) for l in range(self.num_layers)]
self.layers = [CausalInvertibleLayer(prediction_layer=l==self.num_layers-1, features=self.features, scaling_factor=self.scaling_factor) for l in range(self.num_layers)]
def __call__(self, x):
for i in range(self.num_layers):
x = self.layers[i](x)
@@ -78,8 +80,10 @@ class EndToEndCausalInvertibleModule(nn.Module):
n_exogenous_features: int = 32
n_out: int = 32
activation: callable = nn.relu
scaling_factor: float = 0.1

def setup(self):
self.embedder = CausalInvertibleModule(num_layers=self.num_embedding_layers, features=self.features_embedding)
self.embedder = CausalInvertibleModule(num_layers=self.num_embedding_layers, features=self.features_embedding, scaling_factor=self.scaling_factor)
self.predictor = FeedForwardModule(n_layers=np.hstack([(np.ones(self.num_prediction_layers-1)*self.features_prediction).astype(int), self.n_out]), split_heads=False)
self.invert_fun = jax.jit(partial(self.inverter, embedder=self.embedder))
def __call__(self, x):
@@ -112,20 +116,22 @@ class CausalInvertibleNN(NN):
n_hidden_y = 200
names_exogenous = None
n_exogenous = 0
scaling_factor = 0.1
def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
n_in: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=False,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
probabilistic_loss_kind='maximum_likelihood', end_to_end='none', n_prediction_layers=3,
n_hidden_y=200, names_exogenous=None,
n_hidden_y=200, names_exogenous=None, scaling_factor=0.1,
**scengen_kwgs):

self.set_attr({"names_exogenous": names_exogenous,
"end_to_end": end_to_end,
"n_prediction_layers": n_prediction_layers,
"num_embedding_layers": n_hidden_y,
"n_exogenous":len(names_exogenous) if names_exogenous is not None else 0})
"n_exogenous":len(names_exogenous) if names_exogenous is not None else 0,
"scaling_factor": scaling_factor})
assert n_in - self.n_exogenous >= n_out, ('the history length must be at least as long as the forecast horizon to '
'learn an efficiently invertible causal transformation')
super().__init__(learning_rate, batch_size, load_path, n_in, n_out, n_layers, pars, q_vect, val_ratio,
@@ -140,7 +146,7 @@ def set_arch(self):
features_prediction=self.n_hidden_y,
features_embedding=self.n_hidden_x - self.n_exogenous,
n_exogenous_features=self.n_exogenous,
n_out=self.n_out) if (self.end_to_end in ['full', 'quasi']) \
n_out=self.n_out, scaling_factor=self.scaling_factor) if (self.end_to_end in ['full', 'quasi']) \
else CausalInvertibleModule(num_layers=self.n_layers, features=self.n_hidden_x)

self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
@@ -194,7 +200,7 @@ def predict(self, inputs, return_sigma=False, **kwargs):
y_hat = invert(self.pars, e_future, self.model)[:, -embeddings_hat.shape[1]:]

# embedding-predicted embedding distributions
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 1, figsize = (10, 6))
ax[0].hist(np.array(embeddings.ravel()), bins=100, alpha=0.5, density=True, label='past embedding')
ax[0].hist(np.array(embeddings_hat.ravel()), bins=100, alpha=0.5, density=True, label='forecasted embedding')
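For orientation, a minimal usage sketch of the new scaling_factor keyword, which is accepted by CausalInvertibleNN and forwarded through CausalInvertibleModule to each CausalInvertibleLayer. This is not part of the commit; the hyperparameter values and shapes below are illustrative assumptions only.

# illustrative sketch -- values are assumptions, not taken from this commit
from pyforecaster.forecasting_models.neural_models.INN import CausalInvertibleNN

m = CausalInvertibleNN(learning_rate=1e-2, batch_size=300, n_in=144, n_out=24, n_layers=2,
                       end_to_end='full', n_hidden_y=300, n_prediction_layers=3,
                       scaling_factor=0.05)   # new keyword, default 0.1
# m.fit(x_tr, y_tr)   # x_tr: n_in history columns, y_tr: n_out target columns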
129 changes: 118 additions & 11 deletions pyforecaster/forecasting_models/neural_models/base_nn.py
@@ -3,7 +3,7 @@
from inspect import getmro
from os.path import join
from typing import Union

from time import time
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
@@ -23,6 +23,13 @@
from pyforecaster.forecasting_models.neural_models.neural_utils import jitting_wrapper, reproject_weights


@jax.jit
def compute_mean_losses(te_loss_i, tr_loss_i):
"""Compute mean training and validation losses."""
mean_te = jnp.mean(te_loss_i)
mean_tr = jnp.mean(tr_loss_i)
return mean_te, mean_tr
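As a side note, a minimal sketch (not part of this commit) of the pattern the training loop uses with this jitted reduction: block_until_ready() forces the asynchronous device computation to finish before the result is timed and pulled to the host. Array sizes below are illustrative.

import jax.numpy as jnp
from time import time
from pyforecaster.forecasting_models.neural_models.base_nn import compute_mean_losses

te_i, tr_i = jnp.ones(512), jnp.ones(512)            # stand-ins for per-sample losses
t0 = time()
mean_te, mean_tr = compute_mean_losses(te_i, tr_i)   # dispatched asynchronously by JAX
mean_te.block_until_ready()                          # wait for the device before timing/transfer
print(f'val {float(mean_te):.2e}, tr {float(mean_tr):.2e}, took {time() - t0:.2e} s')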

def train_step(params, optimizer_state, inputs_batch, targets_batch, model=None, loss_fn=None, **kwargs):
values, grads = value_and_grad(loss_fn)(params, inputs_batch, targets_batch, **kwargs)
updates, opt_state = model.update(grads, optimizer_state, params)
@@ -128,6 +135,51 @@ def __call__(self, x):

return x + y

class LSTMModule(nn.Module):
n_layers: int
n_neurons: int
n_out: int
@nn.compact
def __call__(self, x):

if isinstance(self.n_layers, int):
layers = np.ones(self.n_layers) * self.n_neurons
layers[-1] = self.n_out
else:
layers = np.array(self.n_layers)

layers = layers.astype(int).tolist()
ScanLSTM = nn.scan(
nn.LSTMCell,
variable_broadcast="params",
split_rngs={"params": False},
in_axes=0, # scan over first axis of x (time)
out_axes=0 # produce outputs with time on the first axis
)

for i, n in enumerate(layers):
if i < len(layers) - 1:
# scanned LSTM cell that iterates over the time dimension
lstm = ScanLSTM(n)
# the carry is the state of the LSTM cell, containing the c and h states
carry = lstm.initialize_carry(random.key(0), (x.shape[-1],))
_, x = lstm(carry, x)  # keep the output sequence, throw away the final carry

else:
# Create the LSTM cell
lstm = ScanLSTM(n)
carry = lstm.initialize_carry(random.key(0), (x.shape[-1],))
_, x = lstm(carry, x)
x = nn.Dense(features=n)(x[-1])  # take only the last temporal step and project it to n_out

return x
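
A small smoke-test sketch of the module above on a single sequence (not part of this commit; sizes are illustrative assumptions):

import jax.numpy as jnp
from jax import random
from pyforecaster.forecasting_models.neural_models.base_nn import LSTMModule

module = LSTMModule(n_layers=2, n_neurons=32, n_out=24)
x = jnp.ones((48, 1))                     # one sample: 48 time steps, 1 feature
params = module.init(random.key(0), x)
y = module.apply(params, x)               # last time step projected by the final Dense, shape (24,)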




class NN(ScenarioGenerator):
input_scaler: StandardScaler = None
target_scaler: StandardScaler = None
@@ -245,6 +297,8 @@ def set_arch(self):
jitting_wrapper(loss_fn, self.predict_batch))
self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

def check_inputs(self, x, y):
return x, y

def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
self.to_be_normalized = [c for c in inputs.columns if
@@ -269,12 +323,20 @@ def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step
num_batches = inputs.shape[0] // batch_size

inputs, targets = self.get_normalized_inputs(inputs, targets)

inputs, targets = self.check_inputs(inputs, targets)

inputs_val, targets_val = self.get_normalized_inputs(inputs_val_0, targets_val_0)
inputs_len = [i.shape[1] for i in inputs] if isinstance(inputs, tuple) else inputs.shape[1]
#inputs_len = [i.shape[1] for i in inputs] if isinstance(inputs, tuple) else inputs.shape[1]

inputs_val, targets_val = self.check_inputs(inputs_val, targets_val)
inputs_len = [np.squeeze(i.shape[1:]) for i in inputs] if isinstance(inputs, tuple) else inputs.shape[1:]

pars = self.init_arch(self.model, *np.atleast_1d(inputs_len))
opt_state = self.optimizer.init(pars)



tr_loss, val_loss = [np.inf], [np.inf]
k = 0
finished = False
@@ -283,31 +345,46 @@
for i in tqdm(range(num_batches),
desc='epoch {}/{}, val loss={:0.3e}'.format(epoch, n_epochs, val_loss[-1] if val_loss[-1] is not np.inf else np.nan)):
rand_idx = rand_idx_all[i * batch_size:(i + 1) * batch_size]
inputs_batch = [i[rand_idx, :] for i in inputs] if isinstance(inputs, tuple) else inputs[rand_idx, :]
inputs_batch = [i[rand_idx, :] for i in inputs] if isinstance(inputs, tuple) else inputs[rand_idx]
targets_batch = targets[rand_idx, :]

pars, opt_state, values = self.train_step(pars, opt_state, inputs_batch, targets_batch)
if self.reproject:
pars = reproject_weights(pars, rec_stable=self.rec_stable, monotone=self.monotone)

if k % stats_step == 0 and k > 0:

old_pars = self.pars
self.pars = pars
rand_idx_val = np.random.choice(validation_len, np.minimum(batch_size, validation_len), replace=False)
inputs_val_sampled = [i[rand_idx_val, :] for i in inputs_val] if isinstance(inputs_val, tuple) else inputs_val[rand_idx_val, :]
te_loss_i = self.loss_fn(pars, inputs_val_sampled, targets_val[rand_idx_val, :])
inputs_val_sampled = [i[rand_idx_val] for i in inputs_val] if isinstance(inputs_val, tuple) else inputs_val[rand_idx_val]
t_0 = time()
te_loss_i = self.loss_fn(pars, inputs_val_sampled, targets_val[rand_idx_val])
tr_loss_i = self.loss_fn(pars, inputs_batch, targets_batch)
val_loss.append(np.array(jnp.mean(te_loss_i)))
tr_loss.append(np.array(jnp.mean(tr_loss_i)))
t_1 = time()
mean_te, mean_tr = compute_mean_losses(te_loss_i, tr_loss_i)

# Ensure computations are complete before transferring to host
mean_te = mean_te.block_until_ready()
mean_tr = mean_tr.block_until_ready()

# Convert to Python scalars
val_loss.append(float(mean_te))
tr_loss.append(float(mean_tr))

self.logger.info('tr loss: {:0.2e}, te loss: {:0.2e}'.format(tr_loss[-1], val_loss[-1]))
print(f'tr loss: {tr_loss[-1]:.2e}, te loss: {val_loss[-1]:.2e}, eval took: {t_1 - t_0:.2e} s, averaging took: {time() - t_1:.2e} s')

# val_loss.append(np.array(jnp.mean(te_loss_i)))
# tr_loss.append(np.array(jnp.mean(tr_loss_i)))
#print('tr loss: {:0.2e}, te loss: {:0.2e}, eval took:{:0.2e} s'.format(tr_loss[-1], val_loss[-1], time()-t_0))
#self.logger.info('tr loss: {:0.2e}, te loss: {:0.2e}, eval took:{:0.2e} s'.format(tr_loss[-1], val_loss[-1], time()-t_0))
if len(tr_loss) > 2:
if savepath_tr_plots is not None or self.savepath_tr_plots is not None:
savepath_tr_plots = savepath_tr_plots if savepath_tr_plots is not None else self.savepath_tr_plots

rand_idx_plt = np.random.choice(validation_len, 9)
self.training_plots([i[rand_idx_plt, :] for i in inputs_val] if isinstance(inputs_val, tuple) else inputs_val[rand_idx_plt, :],
targets_val[rand_idx_plt, :], tr_loss[1:], val_loss[1:], savepath_tr_plots, k)
self.training_plots([i[rand_idx_plt] for i in inputs_val] if isinstance(inputs_val, tuple) else inputs_val[rand_idx_plt],
targets_val[rand_idx_plt], tr_loss[1:], val_loss[1:], savepath_tr_plots, k)
plt.close("all")
rel_te_err = (val_loss[-2] - val_loss[-1]) / np.abs(val_loss[-2] + 1e-6)
last_improvement = k // stats_step - np.argwhere(np.array(val_loss) == np.min(val_loss)).ravel()[-1]
@@ -350,6 +427,7 @@ def training_plots(self, inputs, target, tr_loss, te_loss, savepath, k):

def predict(self, inputs, return_sigma=False, **kwargs):
x, _ = self.get_normalized_inputs(inputs)
x, _ = self.check_inputs(x, None)
y_hat = self.predict_batch(self.pars, x)
y_hat = np.array(y_hat)
if self.normalize_target:
@@ -369,6 +447,7 @@ def predict(self, inputs, return_sigma=False, **kwargs):
return pd.DataFrame(y_hat, index=inputs.index, columns=self.target_columns)

def predict_quantiles(self, inputs, normalize=True, **kwargs):
inputs, _ = self.check_inputs(inputs, None)
if self.probabilistic:
if normalize:
mu_hat, sigma_hat = self.predict(inputs, return_sigma=True)
@@ -419,4 +498,32 @@ def set_arch(self):
self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
jitting_wrapper(loss_fn, self.predict_batch))
self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

class LSTMNN(NN):
def __init__(self, n_out=None, q_vect=None, n_epochs=10, val_ratio=None, nodes_at_step=None, learning_rate=1e-3,
scengen_dict={}, batch_size=None, **model_kwargs):
super().__init__(n_out=n_out, q_vect=q_vect, n_epochs=n_epochs, val_ratio=val_ratio, nodes_at_step=nodes_at_step, learning_rate=learning_rate,
nn_module=FeedForwardModule, scengen_dict=scengen_dict, batch_size=batch_size, **model_kwargs)

def set_arch(self):
self.optimizer = optax.adamw(learning_rate=self.learning_rate)
self.model = LSTMModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
n_out=self.n_out)
self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch) if self.probabilistic else (
jitting_wrapper(loss_fn, self.predict_batch))
self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)


def check_inputs(self, x, y):
x = np.atleast_3d(x)
return x, y

@staticmethod
def init_arch(nn_init, *n_inputs_x):
"divides data into training and test sets "
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (*n_inputs_x,)) # Dummy input data (for the first input)
init_params = nn_init.init(key2, x) # Initialization call
return init_params
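A hedged end-to-end sketch of the new LSTMNN wrapper (not part of this commit; data and hyperparameters are illustrative, mirroring the commented-out test further below): it reuses the generic NN training loop, with check_inputs reshaping 2-D tabular inputs to (batch, time, 1), presumably treating each lag column as a time step, before the vmapped LSTMModule is applied.

import numpy as np
import pandas as pd
from pyforecaster.forecasting_models.neural_models.base_nn import LSTMNN

x = pd.DataFrame(np.random.randn(500, 24))   # 24 lagged inputs per sample
y = pd.DataFrame(np.random.randn(500, 6))    # 6-step-ahead targets
m = LSTMNN(learning_rate=1e-3, batch_size=100, n_hidden_x=20, n_out=y.shape[1],
           n_layers=2, n_epochs=2)
m.fit(x, y)
y_hat = m.predict(x.iloc[:10, :])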
2 changes: 1 addition & 1 deletion requirements.txt
@@ -17,5 +17,5 @@ jax>=0.4.1
jaxlib>=0.4.1
quantile-forest>=1.3.10
optax>=0.1.7
flax>=0.7.4
flax>=0.10
statsmodels>=0.14.2
33 changes: 30 additions & 3 deletions tests/test_nns.py
@@ -13,7 +13,7 @@
from pyforecaster.forecasting_models.neural_models.ICNN import PICNN, RecStablePICNN, PIQCNN, PIQCNNSigmoid, \
StructuredPICNN, LatentStructuredPICNN, latent_pred
from pyforecaster.forecasting_models.neural_models.INN import CausalInvertibleNN
from pyforecaster.forecasting_models.neural_models.base_nn import NN, FFNN
from pyforecaster.forecasting_models.neural_models.base_nn import NN, FFNN, LSTMNN
from pyforecaster.formatter import Formatter
from pyforecaster.trainer import hyperpar_optimizer

@@ -48,17 +48,43 @@ def test_ffnn(self):
x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
y.iloc[n_tr:].copy()]

x_tr = x_tr[np.flip(x_tr.columns)]
x_te = x_te[np.flip(x_te.columns)]

savepath_tr_plots = 'tests/results/ffnn_tr_plots'

# if not there, create directory savepath_tr_plots
if not exists(savepath_tr_plots):
makedirs(savepath_tr_plots)

m = NN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200,
n_out=y_tr.shape[1], n_layers=3).fit(x_tr,y_tr, n_epochs=1, savepath_tr_plots=savepath_tr_plots, stats_step=40)
n_out=y_tr.shape[1], n_layers=3).fit(x_tr,y_tr, n_epochs=3, savepath_tr_plots=savepath_tr_plots, stats_step=40)

y_hat_1 = m.predict(x_te.iloc[:100, :])

# def test_lstmnn(self):
# # normalize inputs
# x = (self.x - self.x.mean(axis=0)) / (self.x.std(axis=0) + 0.01)
# y = (self.y - self.y.mean(axis=0)) / (self.y.std(axis=0) + 0.01)
#
# n_tr = int(len(x) * 0.8)
# x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
# y.iloc[n_tr:].copy()]
#
# x_tr = x_tr[np.flip(x_tr.columns)]
# x_te = x_te[np.flip(x_te.columns)]
#
# savepath_tr_plots = 'tests/results/ffnn_tr_plots'
#
# # if not there, create directory savepath_tr_plots
# if not exists(savepath_tr_plots):
# makedirs(savepath_tr_plots)
#
# m = LSTMNN(learning_rate=1e-3, batch_size=100, load_path=None, n_hidden_x=20,
# n_out=y_tr.shape[1], n_layers=2).fit(x_tr,y_tr, n_epochs=5, savepath_tr_plots=savepath_tr_plots, stats_step=40)
#
# y_hat_1 = m.predict(x_te.iloc[:100, :])


def test_picnn(self):
# normalize inputs
@@ -516,7 +542,7 @@ def test_invertible_causal_nn(self):


m = CausalInvertibleNN(learning_rate=1e-2, batch_size=300, load_path=None, n_in=n_in,
n_layers=2, normalize_target=False, n_epochs=5, stopping_rounds=30, rel_tol=-1,
n_layers=2, normalize_target=False, n_epochs=2, stopping_rounds=30, rel_tol=-1,
end_to_end='full', n_hidden_y=300, n_prediction_layers=3, n_out=n_out,names_exogenous=['all_lag_000']).fit(e_tr.iloc[:, :n_in], e_tr.iloc[:, -n_out:])

z_hat_ete = m.predict(e_te.iloc[:, :n_in])
@@ -531,6 +557,7 @@
plt.cla()
ax.plot(e_te.iloc[i, -n_out:].values)
ax.plot(y_hat_lin.iloc[i, :].values, linewidth=1)
ax.plot(y_hat.iloc[i, :].values, linestyle='--')
ax.plot(z_hat_ete.iloc[i, :].values, linestyle='--')
plt.pause(1e-6)

