-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
1,263 additions
and
1,228 deletions.
There are no files selected for viewing
1,218 changes: 0 additions & 1,218 deletions
1,218
pyforecaster/forecasting_models/neural_forecasters.py
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
Empty file.
375 changes: 375 additions & 0 deletions
375
pyforecaster/forecasting_models/neural_models/base_nn.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from functools import partial | ||
|
||
import jax.numpy as jnp | ||
from flax import linen as nn | ||
|
||
from pyforecaster.forecasting_models.neural_models.neural_utils import positive_lecun, identity | ||
|
||
|
||
class PICNNLayer(nn.Module): | ||
features_x: int | ||
features_y: int | ||
features_out: int | ||
features_latent: int | ||
prediction_layer: bool = False | ||
activation: callable = nn.relu | ||
rec_activation: callable = identity | ||
init_type: str = 'normal' | ||
augment_ctrl_inputs: bool = False | ||
layer_normalization: bool = False | ||
z_min: jnp.array = None | ||
z_max: jnp.array = None | ||
@nn.compact | ||
def __call__(self, y, u, z): | ||
if self.augment_ctrl_inputs: | ||
y = jnp.hstack([y, -y]) | ||
|
||
y_add_kernel_init = nn.initializers.lecun_normal() if self.rec_activation == identity else partial(positive_lecun, init_type=self.init_type) | ||
# Input-Convex component without bias for the element-wise multiplicative interactions | ||
wzu = nn.relu(nn.Dense(features=self.features_latent, use_bias=True, name='wzu')(u)) | ||
wyu = self.rec_activation(nn.Dense(features=self.features_y, use_bias=True, name='wyu')(u)) | ||
z_add = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu) | ||
y_add = nn.Dense(features=self.features_out, use_bias=False, name='wy', kernel_init=y_add_kernel_init)(y * wyu) | ||
u_add = nn.Dense(features=self.features_out, use_bias=True, name='wuz')(u) | ||
|
||
|
||
if self.layer_normalization: | ||
y_add = nn.LayerNorm()(y_add) | ||
z_add = nn.LayerNorm()(z_add) | ||
u_add = nn.LayerNorm()(u_add) | ||
|
||
z_next = z_add + y_add + u_add | ||
if not self.prediction_layer: | ||
z_next = self.activation(z_next) | ||
# Traditional NN component only if it's not the prediction layer | ||
u_next = nn.Dense(features=self.features_x, name='u_dense')(u) | ||
u_next = self.activation(u_next) | ||
return u_next, z_next | ||
else: | ||
if self.z_min is not None: | ||
z_next = nn.sigmoid(z_next) * (self.z_max - self.z_min) + self.z_min | ||
return None, z_next | ||
|
47 changes: 47 additions & 0 deletions
47
pyforecaster/forecasting_models/neural_models/neural_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from functools import partial | ||
|
||
from jax import jit | ||
from jax import numpy as jnp | ||
from jax import random | ||
|
||
|
||
def jitting_wrapper(fun, model, **kwargs): | ||
return jit(partial(fun, model=model, **kwargs)) | ||
|
||
|
||
def positive_lecun(key, shape, dtype=jnp.float32, init_type='normal'): | ||
# Start with standard lecun_normal initialization | ||
stddev = 1. / jnp.sqrt(shape[1]) | ||
if init_type == 'normal': | ||
weights = random.normal(key, shape, dtype) * stddev | ||
elif init_type == 'uniform': | ||
weights = random.uniform(key, shape, dtype) * stddev | ||
else: | ||
raise NotImplementedError('init_type {} not implemented'.format(init_type)) | ||
# Ensure weights are non-negative | ||
return jnp.abs(weights)/10 | ||
|
||
|
||
def identity(x): | ||
return x | ||
|
||
def reproject_weights(params, rec_stable=False, monotone=False): | ||
# Loop through each layer and reproject the input-convex weights | ||
for layer_name in params['params']: | ||
if 'PICNNLayer' in layer_name: | ||
if 'wz' in params['params'][layer_name].keys(): | ||
_reproject(params['params'], layer_name, rec_stable=rec_stable, monotone=monotone or ('monotone' in layer_name)) | ||
else: | ||
for name in params['params'][layer_name].keys(): | ||
_reproject(params['params'][layer_name], name, rec_stable=rec_stable, monotone=monotone or ('monotone' in layer_name)) | ||
|
||
return params | ||
|
||
def _reproject(params, layer_name, rec_stable=False, monotone=False): | ||
if ('monotone' in layer_name) or monotone: | ||
for name in {'wz', 'wy'} & set(params[layer_name].keys()): | ||
params[layer_name][name]['kernel'] = jnp.maximum(0, params[layer_name][name]['kernel']) | ||
else: | ||
params[layer_name]['wz']['kernel'] = jnp.maximum(0, params[layer_name]['wz']['kernel']) | ||
if rec_stable: | ||
params[layer_name]['wy']['kernel'] = jnp.maximum(0, params[layer_name]['wy']['kernel']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters