
Commit
refactoring
nepslor committed Feb 27, 2024
1 parent 01a9de5 commit a70069c
Showing 8 changed files with 1,263 additions and 1,228 deletions.
1,218 changes: 0 additions & 1,218 deletions pyforecaster/forecasting_models/neural_forecasters.py

Large diffs are not rendered by default.

775 changes: 775 additions & 0 deletions pyforecaster/forecasting_models/neural_models/ICNN.py

Large diffs are not rendered by default.

Empty file.
Empty file.
375 changes: 375 additions & 0 deletions pyforecaster/forecasting_models/neural_models/base_nn.py

Large diffs are not rendered by default.

52 changes: 52 additions & 0 deletions pyforecaster/forecasting_models/neural_models/layers.py
@@ -0,0 +1,52 @@
from functools import partial

import jax.numpy as jnp
from flax import linen as nn

from pyforecaster.forecasting_models.neural_models.neural_utils import positive_lecun, identity


class PICNNLayer(nn.Module):
    features_x: int
    features_y: int
    features_out: int
    features_latent: int
    prediction_layer: bool = False
    activation: callable = nn.relu
    rec_activation: callable = identity
    init_type: str = 'normal'
    augment_ctrl_inputs: bool = False
    layer_normalization: bool = False
    z_min: jnp.ndarray = None
    z_max: jnp.ndarray = None

    @nn.compact
    def __call__(self, y, u, z):
        if self.augment_ctrl_inputs:
            y = jnp.hstack([y, -y])

        y_add_kernel_init = nn.initializers.lecun_normal() if self.rec_activation == identity \
            else partial(positive_lecun, init_type=self.init_type)
        # Gating signals computed from the conditioning input u
        wzu = nn.relu(nn.Dense(features=self.features_latent, use_bias=True, name='wzu')(u))
        wyu = self.rec_activation(nn.Dense(features=self.features_y, use_bias=True, name='wyu')(u))
        # Input-convex components: bias-free element-wise multiplicative interactions
        z_add = nn.Dense(features=self.features_out, use_bias=False, name='wz',
                         kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
        y_add = nn.Dense(features=self.features_out, use_bias=False, name='wy',
                         kernel_init=y_add_kernel_init)(y * wyu)
        u_add = nn.Dense(features=self.features_out, use_bias=True, name='wuz')(u)

        if self.layer_normalization:
            y_add = nn.LayerNorm()(y_add)
            z_add = nn.LayerNorm()(z_add)
            u_add = nn.LayerNorm()(u_add)

        z_next = z_add + y_add + u_add
        if not self.prediction_layer:
            z_next = self.activation(z_next)
            # Traditional NN component, only needed when this is not the prediction layer
            u_next = nn.Dense(features=self.features_x, name='u_dense')(u)
            u_next = self.activation(u_next)
            return u_next, z_next
        else:
            if self.z_min is not None:
                z_next = nn.sigmoid(z_next) * (self.z_max - self.z_min) + self.z_min
            return None, z_next
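
Note (illustration, not part of this commit): a minimal sketch of wiring up a single PICNNLayer; the feature sizes and batch shapes below are assumptions chosen for the example.

import jax
import jax.numpy as jnp
from pyforecaster.forecasting_models.neural_models.layers import PICNNLayer

layer = PICNNLayer(features_x=8, features_y=4, features_out=16, features_latent=16)
y = jnp.ones((1, 4))    # convex input path
u = jnp.ones((1, 8))    # conditioning input path
z = jnp.zeros((1, 16))  # latent convex state; width must match features_latent
params = layer.init(jax.random.PRNGKey(0), y, u, z)
u_next, z_next = layer.apply(params, y, u, z)  # updated conditioning and convex paths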

47 changes: 47 additions & 0 deletions pyforecaster/forecasting_models/neural_models/neural_utils.py
@@ -0,0 +1,47 @@
from functools import partial

from jax import jit
from jax import numpy as jnp
from jax import random


def jitting_wrapper(fun, model, **kwargs):
    # Bind the model (and any extra kwargs) into fun, then jit the result
    return jit(partial(fun, model=model, **kwargs))
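
Note (usage sketch, not from this commit): jitting_wrapper turns a function that takes a model keyword argument into a jitted callable with the model baked in. The loss function and names below are hypothetical placeholders.

def loss_fn(params, inputs, targets, model=None):
    predictions = model.apply(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

jitted_loss = jitting_wrapper(loss_fn, model)  # == jit(partial(loss_fn, model=model))
loss_value = jitted_loss(params, inputs, targets)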


def positive_lecun(key, shape, dtype=jnp.float32, init_type='normal'):
    # LeCun-style scaling of the standard deviation
    stddev = 1. / jnp.sqrt(shape[1])
    if init_type == 'normal':
        weights = random.normal(key, shape, dtype) * stddev
    elif init_type == 'uniform':
        weights = random.uniform(key, shape, dtype) * stddev
    else:
        raise NotImplementedError('init_type {} not implemented'.format(init_type))
    # Take the absolute value so all weights are non-negative, and shrink them
    return jnp.abs(weights) / 10
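
Note (minimal sketch, assumed shapes): positive_lecun follows the Flax initializer signature (key, shape, dtype), so it can be passed as kernel_init, as PICNNLayer does in layers.py above.

from functools import partial
import jax
from flax import linen as nn

dense = nn.Dense(features=4, kernel_init=partial(positive_lecun, init_type='uniform'))
variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
assert (variables['params']['kernel'] >= 0).all()  # all kernel entries non-negative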


def identity(x):
    return x

def reproject_weights(params, rec_stable=False, monotone=False):
    # Loop through each layer and reproject the input-convex weights
    for layer_name in params['params']:
        if 'PICNNLayer' in layer_name:
            if 'wz' in params['params'][layer_name].keys():
                _reproject(params['params'], layer_name, rec_stable=rec_stable,
                           monotone=monotone or ('monotone' in layer_name))
            else:
                for name in params['params'][layer_name].keys():
                    _reproject(params['params'][layer_name], name, rec_stable=rec_stable,
                               monotone=monotone or ('monotone' in layer_name))
    return params


def _reproject(params, layer_name, rec_stable=False, monotone=False):
    # Clip the convexity-critical kernels to the non-negative orthant
    if ('monotone' in layer_name) or monotone:
        for name in {'wz', 'wy'} & set(params[layer_name].keys()):
            params[layer_name][name]['kernel'] = jnp.maximum(0, params[layer_name][name]['kernel'])
    else:
        params[layer_name]['wz']['kernel'] = jnp.maximum(0, params[layer_name]['wz']['kernel'])
        if rec_stable:
            params[layer_name]['wy']['kernel'] = jnp.maximum(0, params[layer_name]['wy']['kernel'])
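
Note (hedged sketch, not from this commit): reproject_weights restores the non-negativity constraints on the input-convex kernels, e.g. after a gradient step. The optimizer setup and loss below are illustrative placeholders, not the repo's training loop.

import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

grads = jax.grad(loss_fn)(params, inputs, targets)   # hypothetical loss
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
params = reproject_weights(params, rec_stable=False, monotone=False)
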
24 changes: 14 additions & 10 deletions tests/test_nns.py
@@ -1,18 +1,22 @@
import unittest
import matplotlib.pyplot as plt
import optax
import pandas as pd
import numpy as np
import logging
from pyforecaster.forecasting_models.neural_forecasters import PICNN, RecStablePICNN, NN, PIQCNN, PIQCNNSigmoid, StructuredPICNN, LatentStructuredPICNN
from pyforecaster.trainer import hyperpar_optimizer
from pyforecaster.formatter import Formatter
from pyforecaster.metrics import nmae
import unittest
from os import makedirs
from os.path import exists, join

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from jax import vmap
from pyforecaster.forecasting_models.neural_forecasters import latent_pred

from pyforecaster.forecasting_models.neural_models.ICNN import PICNN, RecStablePICNN, PIQCNN, PIQCNNSigmoid, \
    StructuredPICNN, LatentStructuredPICNN, latent_pred
from pyforecaster.forecasting_models.neural_models.base_nn import NN
from pyforecaster.formatter import Formatter
from pyforecaster.trainer import hyperpar_optimizer


class TestFormatDataset(unittest.TestCase):
    def setUp(self) -> None:
        self.data = pd.read_pickle('tests/data/test_data.zip').droplevel(0, 1)
