Commit

added crps, made two separate networks for mu and sigma in an attempt to increase expressivity
nepslor committed Dec 15, 2023
1 parent 0110491 commit 706fc11
Showing 2 changed files with 68 additions and 35 deletions.
91 changes: 64 additions & 27 deletions pyforecaster/forecasting_models/neural_forecasters.py
@@ -65,13 +65,19 @@ def loss_fn(params, inputs, targets, model=None):
predictions = model(params, inputs)
return jnp.mean((predictions - targets) ** 2)

def probabilistic_loss_fn(params, inputs, targets, model=None):
def probabilistic_loss_fn(params, inputs, targets, model=None, kind='maximum_likelihood'):
out = model(params, inputs)
predictions = out[:, :out.shape[1]//2]
sigma_square = out[:, out.shape[1]//2:]
ll = jnp.mean(((predictions - targets)**2) / sigma_square + jnp.log(sigma_square))
if kind == 'maximum_likelihood':
ll = jnp.mean(((predictions - targets)**2) / sigma_square + jnp.log(sigma_square))
elif kind == 'crps':
sigma = jnp.sqrt(sigma_square)
u = (targets - predictions) / sigma
ll = jnp.mean(sigma * (u * (2 * jax.scipy.stats.norm.cdf(u) - 1) + 2 * jax.scipy.stats.norm.pdf(u) - 1 / jnp.sqrt(jnp.pi)))
return ll


def train_step(params, optimizer_state, inputs_batch, targets_batch, model=None, loss_fn=None, **kwargs):
values, grads = value_and_grad(loss_fn)(params, inputs_batch, targets_batch, **kwargs)
updates, opt_state = model.update(grads, optimizer_state, params)
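
For reference, the new kind='crps' branch implements the closed-form CRPS of a Gaussian predictive distribution, CRPS(N(mu, sigma^2), y) = sigma * [u * (2*Phi(u) - 1) + 2*phi(u) - 1/sqrt(pi)] with u = (y - mu)/sigma. A minimal standalone sanity check against the Monte-Carlo definition E|X - y| - 0.5 E|X - X'| (the names below are illustrative and not part of pyforecaster):

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def gaussian_crps(mu, sigma, y):
    # closed-form CRPS of N(mu, sigma^2) evaluated at observation y
    u = (y - mu) / sigma
    return sigma * (u * (2 * norm.cdf(u) - 1) + 2 * norm.pdf(u) - 1 / jnp.sqrt(jnp.pi))

# Monte-Carlo estimate with X, X' ~ N(mu, sigma^2); the two numbers should agree to ~2 decimals
mu, sigma, y = 1.0, 2.0, 0.3
x = mu + sigma * jax.random.normal(jax.random.PRNGKey(0), (200_000,))
x_prime = mu + sigma * jax.random.normal(jax.random.PRNGKey(1), (200_000,))
mc_crps = jnp.mean(jnp.abs(x - y)) - 0.5 * jnp.mean(jnp.abs(x - x_prime))
print(gaussian_crps(mu, sigma, y), mc_crps)
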
@@ -171,6 +177,7 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
probabilistic_loss_kind='maximum_likelihood',
**scengen_kwgs):

super().__init__(q_vect, val_ratio=val_ratio, nodes_at_step=nodes_at_step, **scengen_kwgs)
@@ -191,7 +198,8 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
"stopping_rounds":stopping_rounds,
"subtract_mean_when_normalizing":subtract_mean_when_normalizing,
"causal_df":causal_df,
"probabilistic":probabilistic
"probabilistic":probabilistic,
"probabilistic_loss_kind":probabilistic_loss_kind
})

if self.load_path is not None:
@@ -245,7 +253,7 @@ def init_arch(nn_init, n_inputs_x=1):

def set_arch(self):
model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
n_out=self.n_out*2 if self.probabilistic else self.n_out)
n_out=self.n_out)
return model

def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
@@ -325,13 +333,19 @@ def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step
def training_plots(self, inputs, target, tr_loss, te_loss, savepath, k):
n_instances = target.shape[0]
y_hat = self.predict_batch(self.pars, inputs)

if self.probabilistic:
q_hat = self.predict_quantiles(inputs, normalize=False)
# make the appropriate number of subplots, arranged in a square grid
fig, ax = plt.subplots(int(np.ceil(np.sqrt(n_instances))), int(np.ceil(np.sqrt(n_instances))))
fig, ax = plt.subplots(int(np.ceil(np.sqrt(n_instances))), int(np.ceil(np.sqrt(n_instances))), figsize=(10, 10),
layout='tight')
for i, a in enumerate(ax.ravel()):
a.plot(y_hat[i, :target.shape[1]])
l = a.plot(y_hat[i, :target.shape[1]])
if self.probabilistic:
a.plot(q_hat[i, :target.shape[1], :], '--', color=l[0].get_color(), alpha=0.2)
a.plot(target[i, :])
a.set_title('instance {}, iter {}'.format(i, k))


plt.savefig(join(savepath, 'examples_iter_{:05d}.png'.format(k)))

fig, ax = plt.subplots(1, 1)
@@ -360,8 +374,11 @@ def predict(self, inputs, return_sigma=False, **kwargs):
else:
return pd.DataFrame(y_hat, index=inputs.index, columns=self.target_columns)

def predict_quantiles(self, inputs, **kwargs):
x, _ = self.get_normalized_inputs(inputs)
def predict_quantiles(self, inputs, normalize=True, **kwargs):
if normalize:
x, _ = self.get_normalized_inputs(inputs)
else:
x = inputs
y_hat = self.predict_batch(self.pars, x)
y_hat = np.array(y_hat)
if self.normalize_target:
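
The new normalize flag lets training_plots pass arrays that are already normalized, while external callers keep the previous behaviour. A hypothetical usage sketch (m_1 and x_te as in the test file below; the variable names are illustrative):

q_hat_raw = m_1.predict_quantiles(x_te)                      # default: raw inputs are normalized internally
x_norm, _ = m_1.get_normalized_inputs(x_te)
q_hat_pre = m_1.predict_quantiles(x_norm, normalize=False)   # inputs already normalized by the caller
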
@@ -421,7 +438,17 @@ def __call__(self, x, y):
augment_ctrl_inputs=self.augment_ctrl_inputs,
layer_normalization=self.layer_normalization)(y, u, z)
if self.probabilistic:
return jnp.hstack([z[:self.features_out//2], nn.softplus(z[self.features_out//2:]) + 1e-8])
u = x  # restart the hidden path from the exogenous inputs for the independent sigma network
sigma = jnp.zeros(self.features_out)  # initialize the sigma path with the same shape as the output
for i in range(self.num_layers):
prediction_layer = i == self.num_layers - 1
u, sigma = PICNNLayer(features_x=self.features_x, features_y=self.features_y,
features_out=self.features_out,
n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
rec_activation=self.rec_activation, init_type=self.init_type,
augment_ctrl_inputs=self.augment_ctrl_inputs,
layer_normalization=self.layer_normalization)(y, u, sigma)
return jnp.hstack([z, nn.softplus(sigma) + 1e-8])
return z


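This is the "two separate networks for mu and sigma" change from the commit message: instead of splitting a single output vector in half, the mean path z and the variance path sigma are now produced by two independent stacks of PICNNLayers, and softplus(sigma) + 1e-8 keeps the predicted variance strictly positive. A minimal sketch of the same idea using plain dense layers (illustrative only; the actual model keeps the PICNN structure and its convexity constraints):

import flax.linen as nn
import jax.numpy as jnp

class TwoTowerGaussianHead(nn.Module):  # illustrative name, not part of pyforecaster
    n_hidden: int
    n_out: int

    @nn.compact
    def __call__(self, x):
        mu = nn.Dense(self.n_out)(nn.relu(nn.Dense(self.n_hidden)(x)))   # mean tower
        raw = nn.Dense(self.n_out)(nn.relu(nn.Dense(self.n_hidden)(x)))  # independent variance tower
        sigma_square = nn.softplus(raw) + 1e-8                           # strictly positive variance
        return jnp.hstack([mu, sigma_square])                            # same packing the loss functions expect
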
@@ -482,13 +509,21 @@ def causal_loss_fn(params, inputs, targets, model=None, causal_matrix=None):
return mse + jnp.mean(causal_loss)


def probabilistic_causal_loss_fn(params, inputs, targets, model=None, causal_matrix=None):
def probabilistic_causal_loss_fn(params, inputs, targets, model=None, causal_matrix=None, kind='maximum_likelihood'):
ex_inputs, ctrl_inputs = inputs[0], inputs[1]
out = vmap(model.apply, in_axes=(None, 0, 0))(params, ex_inputs, ctrl_inputs)
predictions = out[:, :out.shape[1]//2]
sigma_square = out[:, out.shape[1]//2:]
causal_loss = vmap(_my_jmp, in_axes=(None, None, 0, 0, None))(model, params, ex_inputs, ctrl_inputs, causal_matrix.T)
ll = jnp.mean(((predictions - targets) ** 2) / sigma_square + jnp.log(sigma_square))

if kind == 'maximum_likelihood':
ll = jnp.mean(((predictions - targets) ** 2) / sigma_square + jnp.log(sigma_square))
elif kind == 'crps':
sigma = jnp.sqrt(sigma_square)
u = (targets - predictions) / sigma
ll = jnp.mean(sigma * (
u * (2 * jax.scipy.stats.norm.cdf(u) - 1) + 2 * jax.scipy.stats.norm.pdf(u) - 1 / jnp.sqrt(jnp.pi)))

return ll + jnp.mean(causal_loss)


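For completeness, the maximum_likelihood branch here and in probabilistic_loss_fn above is the Gaussian negative log-likelihood up to constant factors and offsets:

-\log \mathcal{N}(y \mid \mu, \sigma^2) = \tfrac{1}{2}\left(\frac{(y-\mu)^2}{\sigma^2} + \log\sigma^2\right) + \tfrac{1}{2}\log 2\pi

so minimizing ((predictions - targets)**2) / sigma_square + jnp.log(sigma_square) maximizes the likelihood of a Gaussian with mean predictions and variance sigma_square, while the crps branch optimizes the closed-form Gaussian CRPS shown earlier.
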
@@ -509,14 +544,14 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
inverter_learning_rate: float = 0.1, optimization_vars: list = (),
probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1, optimization_vars: list = (),
target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
optimizer=None, **scengen_kwgs):

super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df,
probabilistic, **scengen_kwgs)
probabilistic, probabilistic_loss_kind, **scengen_kwgs)

self.set_attr({"inverter_learning_rate":inverter_learning_rate,
"optimization_vars":optimization_vars,
@@ -537,15 +572,16 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
if causal_df is not None:
causal_df = (~causal_df.astype(bool)).astype(float)
causal_matrix = np.tile(causal_df.values, 2) if self.probabilistic else causal_df.values
self.loss_fn = jitting_wrapper(causal_loss_fn, self.model, causal_matrix=causal_matrix) if not self.probabilistic else jitting_wrapper(probabilistic_causal_loss_fn, self.model, causal_matrix=causal_matrix)
self.loss_fn = jitting_wrapper(causal_loss_fn, self.model, causal_matrix=causal_matrix) \
if not self.probabilistic else jitting_wrapper(probabilistic_causal_loss_fn, self.model, causal_matrix=causal_matrix, kind=probabilistic_loss_kind)
else:
self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=probabilistic_loss_kind)

self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)

def set_arch(self):
model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
features_out=self.n_out*2 if self.probabilistic else self.n_out, init_type=self.init_type,
features_out=self.n_out, init_type=self.init_type,
augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic)
return model

@@ -641,20 +677,21 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
probabilistic_loss_kind='maximum_likelihood',
inverter_learning_rate: float = 0.1, optimization_vars: list = (),
target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
optimizer=None, **scengen_kwgs):

super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
inverter_learning_rate, optimization_vars, target_columns, init_type, augment_ctrl_inputs,
layer_normalization, optimizer, **scengen_kwgs)
probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type,
augment_ctrl_inputs, layer_normalization, optimizer, **scengen_kwgs)


def set_arch(self):
model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
features_out=self.n_out*2 if self.probabilistic else self.n_out, init_type=self.init_type,
features_out=self.n_out, init_type=self.init_type,
augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic)
return model

@@ -669,20 +706,20 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
inverter_learning_rate: float = 0.1, optimization_vars: list = (),
probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1, optimization_vars: list = (),
target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
optimizer=None, **scengen_kwgs):

super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
inverter_learning_rate, optimization_vars, target_columns, init_type, augment_ctrl_inputs,
layer_normalization, optimizer, **scengen_kwgs)
probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type,
augment_ctrl_inputs, layer_normalization, optimizer, **scengen_kwgs)


def set_arch(self):
model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
features_out=self.n_out*2 if self.probabilistic else self.n_out, init_type=self.init_type,
features_out=self.n_out, init_type=self.init_type,
augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
rec_activation=nn.sigmoid, probabilistic=self.probabilistic)
return model
@@ -694,18 +731,18 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
inverter_learning_rate: float = 0.1, optimization_vars: list = (),
probabilistic_loss_kind='maximum_likelihood', inverter_learning_rate: float = 0.1, optimization_vars: list = (),
target_columns: list = None, init_type='normal', augment_ctrl_inputs=False,
layer_normalization=False, optimizer=None, **scengen_kwgs):

super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
inverter_learning_rate, optimization_vars, target_columns, init_type, augment_ctrl_inputs,
layer_normalization, optimizer, **scengen_kwgs)
probabilistic_loss_kind, inverter_learning_rate, optimization_vars, target_columns, init_type,
augment_ctrl_inputs, layer_normalization, optimizer, **scengen_kwgs)

def set_arch(self):
model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
features_out=self.n_out*2 if self.probabilistic else self.n_out, activation=nn.relu,
features_out=self.n_out, activation=nn.relu,
init_type=self.init_type, probabilistic=self.probabilistic)
return model
12 changes: 4 additions & 8 deletions tests/test_nns.py
@@ -68,12 +68,8 @@ def test_picnn(self):


m_1 = PICNN(learning_rate=1e-3, batch_size=500, load_path=None, n_hidden_x=200, n_hidden_y=200,
n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars,probabilistic=True, rel_tol=-1,
val_ratio=0.2).fit(x_tr,
y_tr,
n_epochs=1,
stats_step=200,
savepath_tr_plots=savepath_tr_plots)
n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars, probabilistic=True, probabilistic_loss_kind='crps', rel_tol=-1,
val_ratio=0.2).fit(x_tr, y_tr, n_epochs=1, stats_step=100, savepath_tr_plots=savepath_tr_plots)

y_hat_1 = m_1.predict(x_te)
m_1.save('tests/results/ffnn_model.pk')
@@ -83,9 +79,9 @@ def test_picnn(self):
y_hat = m_1.predict(x_te.iloc[[r], :])
q_hat = m_1.predict_quantiles(x_te.iloc[[r], :])
plt.figure()
plt.plot(y_hat.values.ravel(), label='y_hat')
plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
plt.plot(np.squeeze(q_hat), label='q_hat', color='red', alpha=0.2)
plt.plot(y_hat.values.ravel(), label='y_hat')
plt.plot(np.squeeze(q_hat), label='q_hat', color='orange', alpha=0.3)
plt.legend()

n = PICNN(load_path='tests/results/ffnn_model.pk')
