added corr reordering in global model formatting
nepslor committed Nov 2, 2023
1 parent 17e9e25 commit 70d9bcc
Showing 4 changed files with 220 additions and 13 deletions.
140 changes: 139 additions & 1 deletion pyforecaster/forecasting_models/neural_forecasters.py
@@ -29,6 +29,144 @@ def positive_lecun(key, shape, dtype=jnp.float32, init_type='normal'):
def identity(x):
return x

class LinRegModule(nn.Module):
n_out: int
@nn.compact
def __call__(self, x):
x = nn.Dense(features=self.n_out, name='dense')(x)
return x

class FastLinReg(ScenarioGenerator):
scaler: StandardScaler = None
learning_rate: float = 0.01
batch_size: int = None
load_path: str = None
n_out: int = None
n_epochs: int = 10
savepath_tr_plots: str = None
stats_step: int = 50
rel_tol: float = 1e-4


def __init__(self, n_out=1, q_vect=None, n_epochs=10, val_ratio=None, nodes_at_step=None, learning_rate=1e-3, **scengen_kwgs):
super().__init__(q_vect, val_ratio=val_ratio, nodes_at_step=nodes_at_step, **scengen_kwgs)
model = LinRegModule(n_out)
self.learning_rate = learning_rate
self.model = model
self.optimizer = optax.adam(learning_rate=self.learning_rate)
self.n_epochs = n_epochs

@jit
def loss_fn(params, x, y):
predictions = model.apply(params, x)
return jnp.mean((predictions - y) ** 2)
@jit
def train_step(params, optimizer_state, x_batch, y_batch):
values, grads = value_and_grad(loss_fn)(params, x_batch, y_batch)
updates, opt_state = self.optimizer.update(grads, optimizer_state, params)
return optax.apply_updates(params, updates), opt_state, values

@jit
@partial(vmap, in_axes=(None, 0))
def predict_batch(pars, x):
return model.apply(pars, x)

self.train_step = train_step
self.loss_fn = loss_fn
self.predict_batch = predict_batch
self.iterate = None

@staticmethod
def init_arch(nn_init, n_inputs_x=1):
"divides data into training and test sets "
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (n_inputs_x, )) # Dummy input data (for the first input)
init_params = nn_init.init(key2, x) # Initialization call
return init_params

def fit(self, x, y, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
rel_tol = rel_tol if rel_tol is not None else self.rel_tol
n_epochs = n_epochs if n_epochs is not None else self.n_epochs
stats_step = stats_step if stats_step is not None else self.stats_step
self.scaler = StandardScaler().set_output(transform='pandas').fit(x)

x, y, x_val_0, y_val = self.train_val_split(x, y)
self.target_columns = y.columns

batch_size = self.batch_size if self.batch_size is not None else x.shape[0] // 10
num_batches = y.shape[0] // batch_size
y = y.values
x = self.get_inputs(x)
x_val = self.get_inputs(x_val_0)

pars = self.init_arch(self.model, x.shape[1])
opt_state = self.optimizer.init(pars)

tr_loss, val_loss = [], []
k = 0
finished = False
for epoch in range(n_epochs):
rand_idx_all = np.random.choice(x.shape[0], x.shape[0], replace=False)
for i in tqdm(range(num_batches), desc='epoch {}/{}'.format(epoch, n_epochs)):
rand_idx = rand_idx_all[i*batch_size:(i+1)*batch_size]
x_batch = x[rand_idx, :]
y_batch = y[rand_idx, :]

pars, opt_state, values = self.train_step(pars, opt_state, x_batch, y_batch)

if k % stats_step == 0 and k > 0:
self.pars = pars

te_loss_i = self.loss_fn(pars, x_val, y_val.values)
tr_loss_i = self.loss_fn(pars, x, y)
val_loss.append(np.array(jnp.mean(te_loss_i)))
tr_loss.append(np.array(jnp.mean(tr_loss_i)))

self.logger.warning('tr loss: {:0.2e}, te loss: {:0.2e}'.format(tr_loss[-1], val_loss[-1]))
if len(tr_loss) > 1:
if savepath_tr_plots is not None or self.savepath_tr_plots is not None:
savepath_tr_plots = savepath_tr_plots if savepath_tr_plots is not None else self.savepath_tr_plots
rand_idx_plt = np.random.choice(x_val.shape[0], 9)
self.training_plots(x_val[rand_idx_plt, :],
y_val.values[rand_idx_plt, :], tr_loss, val_loss, savepath_tr_plots, k)

rel_te_err = (val_loss[-2] - val_loss[-1]) / np.abs(val_loss[-2] + 1e-6)
if rel_te_err < rel_tol:
finished = True
break
k += 1
if finished:
break

self.pars = pars
super().fit(x_val_0, y_val)
return self
def get_inputs(self, inputs):
inputs = self.scaler.transform(inputs)
return inputs.values

def training_plots(self, x, target, tr_loss, te_loss, savepath, k):
n_instances = x.shape[0]
y_hat = self.predict_batch(self.pars, x)

# create the appropriate number of subplots, arranged in a square grid
fig, ax = plt.subplots(int(np.ceil(np.sqrt(n_instances))), int(np.ceil(np.sqrt(n_instances))))
for i, a in enumerate(ax.ravel()):
a.plot(y_hat[i, :])
a.plot(target[i, :])
a.set_title('instance {}, iter {}'.format(i, k))
plt.savefig(join(savepath, 'examples_iter_{:05d}.png'.format(k)))

fig, ax = plt.subplots(1, 1)
ax.plot(np.array(tr_loss), label='tr_loss')
ax.plot(np.array(te_loss), label='te_loss')
ax.legend()
plt.savefig(join(savepath, 'losses_iter_{:05d}.png'.format(k)))

def predict(self, inputs, **kwargs):
x = self.get_inputs(inputs)
y_hat = self.predict_batch(self.pars, x)
return pd.DataFrame(y_hat, index=inputs.index, columns=self.target_columns)

class PICNNLayer(nn.Module):

@@ -304,7 +442,7 @@ def iterate(x, y, opt_state):
return y, values
self.iterate = iterate

opt_state = self.optimizer.init(y)
opt_state = self.inverter_optimizer.init(y)
y, values_old = self.iterate(x, y, opt_state)
values_init = np.copy(values_old)

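For context, a minimal usage sketch of the new FastLinReg forecaster defined above. The toy data, column names and hyperparameters below are illustrative assumptions, not part of the commit; the constructor arguments and the fit/predict calls follow the class definition in this diff.

import numpy as np
import pandas as pd
from pyforecaster.forecasting_models.neural_forecasters import FastLinReg

# toy dataset: 500 timestamps, 10 features, 3 forecast steps (all names are made up)
t = pd.date_range('2020-01-01', periods=500, freq='30min')
x = pd.DataFrame(np.random.randn(500, 10), index=t, columns=['feat_{}'.format(i) for i in range(10)])
y = pd.DataFrame(np.random.randn(500, 3), index=t, columns=['step_{}'.format(i) for i in range(3)])

# n_out must match the number of target columns; val_ratio carves out the validation split used for early stopping
m = FastLinReg(n_out=y.shape[1], learning_rate=1e-3, val_ratio=0.2)
m = m.fit(x.iloc[:400], y.iloc[:400], n_epochs=10)
y_hat = m.predict(x.iloc[400:])  # DataFrame indexed like the inputs, one column per target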
43 changes: 32 additions & 11 deletions pyforecaster/formatter.py
@@ -125,7 +125,7 @@ def add_target_normalizer(self, target, function:str=None, agg_freq:str=None, la
self.target_normalizers.append(transformer)

def transform(self, x, time_features=True, holidays=False, return_target=True, global_form=False, parallel=False,
reduce_memory=True, **holidays_kwargs):
reduce_memory=True, corr_reorder=False, **holidays_kwargs):
"""
Takes the DataFrame x and applies the specified transformations stored in the transformers in order to obtain
the pre-fold-transformed dataset: this dataset has the correct final dimensions, but fold-specific
@@ -139,6 +139,14 @@ def transform(self, x, time_features=True, holidays=False, return_target=True, g
forecasted with a global model. In this case, all target transforms must refer to a "target"
column, which is the stacking of the independent signals. An additional column "name" is
added to the transformed dataset, containing the name of the signal to be forecasted.
This is useful for stacking models.
:param parallel: if True, parallelize the transformation over the independent signals. This helps when
there are many signals to transform and many cores are available; with few cores, set
parallel=False and the signals are transformed sequentially in a single pass over the dataset.
:param reduce_memory: if True, reduce memory usage by casting float64 to float32 and int64 to int32
:param corr_reorder: if True, reorder the feature columns of each transformation by their correlation with the first target column
:return x, target: the transformed dataset and the target DataFrame with correct dimensions
"""
@@ -155,20 +163,13 @@ def transform(self, x, time_features=True, holidays=False, return_target=True, g
x, y = fdf_parallel(f=partial(self._transform, time_features=time_features, holidays=holidays,
return_target=return_target, **holidays_kwargs),
df=dfs[n_cpu * i:n_cpu * (i + 1)])
if reduce_memory:
x = reduce_mem_usage(x, use_ray=True)
y = reduce_mem_usage(y, use_ray=True)
xs.append(x)
ys.append(y)
xs, ys = self.global_form_postprocess(x, y, xs, ys, reduce_memory=reduce_memory, corr_reorder=corr_reorder)
else:
for df_i in dfs:
x, y = self._transform(df_i, time_features=time_features, holidays=holidays,
return_target=return_target, **holidays_kwargs)
if reduce_memory:
x = reduce_mem_usage(x, use_ray=False, parallel=False)
y = reduce_mem_usage(y, use_ray=False, parallel=False)
xs.append(x)
ys.append(y)
xs, ys = self.global_form_postprocess(x, y, xs, ys, reduce_memory=reduce_memory, corr_reorder=corr_reorder)

x = pd.concat(xs)
target = pd.concat(ys)
else:
@@ -509,6 +510,26 @@ def global_form_preprocess(self, x):
pd.DataFrame(c, columns=['name'], index=x.index)],
axis=1))
return dfs

def global_form_postprocess(self, x, y, xs, ys, reduce_memory=False, corr_reorder=False):

if reduce_memory:
x = reduce_mem_usage(x, use_ray=True)
y = reduce_mem_usage(y, use_ray=True)

# for all transformations
for tr in self.transformers:
# for all the features of the transformation
for v in np.unique(tr.metadata.name):
# reorder columns by correlation with first target
transformed_cols_names = tr.metadata.loc[tr.metadata['name']==v].index
if corr_reorder:
corr = x[transformed_cols_names].corrwith(y.iloc[:, 0])
transformed_cols_names_reordered = corr.sort_values(ascending=False).index
x.loc[:, transformed_cols_names] = x.loc[:, transformed_cols_names_reordered].values
xs.append(x)
ys.append(y)
return xs, ys
class Transformer:
"""
Defines and applies transformations through rolling time windows and lags
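To illustrate the new flag end to end, here is a hedged sketch of a global-form transform with correlation reordering enabled. It mirrors the new test below; the MultiIndex signal names are placeholders and only calls that appear elsewhere in this commit are used.

import numpy as np
import pandas as pd
from pyforecaster.formatter import Formatter

# two buildings with three signals each; global_form stacks them into a single "target" column
t = pd.date_range('01-01-2020', '01-05-2020', 500, tz='Europe/Zurich')
df = pd.DataFrame(np.random.randn(500, 6), index=t,
                  columns=pd.MultiIndex.from_product([['b1', 'b2'], ['a', 'b', 'c']]))

formatter = Formatter().add_transform(['a', 'b', 'c'], lags=np.arange(10), agg_freq='20min',
                                      relative_lags=True)
formatter.add_target_transform(['target'], ['mean'], agg_bins=[-10, -15, -20])

# corr_reorder=True reorders each transformer's feature block by correlation with the first target column
x, y = formatter.transform(df, global_form=True, corr_reorder=True, parallel=False, reduce_memory=False)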
14 changes: 14 additions & 0 deletions tests/test_formatter.py
@@ -238,6 +238,20 @@ def test_global_multiindex(self):
formatter.add_target_transform(['target'], ['mean'], agg_bins=[-10, -15, -20])
df = formatter.transform(df_mi, time_features=True, holidays=True, prov='ZH',global_form=True)

def test_global_multiindex_with_col_reordering(self):
x_private = pd.DataFrame(np.random.randn(500, 15), index=pd.date_range('01-01-2020', '01-05-2020', 500, tz='Europe/Zurich'), columns=pd.MultiIndex.from_product([['b1', 'b2', 'b3'], ['a', 'b', 'c', 'd', 'e']]))
x_shared = pd.DataFrame(np.random.randn(500, 5), index=pd.date_range('01-01-2020', '01-05-2020', 500, tz='Europe/Zurich'), columns=pd.MultiIndex.from_product([['shared'], [0, 1, 2, 3, 4]]))

df_mi = pd.concat([x_private, x_shared], axis=1)

formatter = pyf.Formatter().add_transform([0, 1, 2, 3, 4], lags=np.arange(10), agg_freq='20min',
relative_lags=True)
formatter.add_transform(['a', 'b', 'c', 'd'], lags=np.arange(10),
agg_freq='20min',
relative_lags=True)
formatter.add_target_transform(['target'], ['mean'], agg_bins=[-10, -15, -20])
df = formatter.transform(df_mi, time_features=True, holidays=True, prov='ZH', global_form=True,
corr_reorder=True, parallel=False, reduce_memory=False)


def test_normalizers(self):
df = pd.DataFrame(np.random.randn(100, 5), index=pd.date_range('01-01-2020', freq='20min', periods=100, tz='Europe/Zurich'), columns=['a', 'b', 'c', 'd', 'e'])
36 changes: 35 additions & 1 deletion tests/test_models.py
@@ -9,7 +9,7 @@
from pyforecaster.forecaster import LinearForecaster, LGBForecaster
from pyforecaster.plot_utils import plot_quantiles
from pyforecaster.formatter import Formatter

from pyforecaster.forecasting_models.neural_forecasters import FastLinReg

class TestFormatDataset(unittest.TestCase):
def setUp(self) -> None:
@@ -49,6 +49,40 @@ def test_hw(self):
plt.pause(0.0001)
plt.close('all')

def test_fast_linreg(self):

formatter = Formatter(logger=self.logger).add_transform(['all'], lags=np.arange(24),
relative_lags=True)
formatter.add_transform(['all'], ['min', 'max'], agg_bins=[1, 2, 15, 20])
formatter.add_target_transform(['all'], lags=-np.arange(6))
x, y = formatter.transform(self.data.iloc[:1000])
x.columns = x.columns.astype(str)
y.columns = y.columns.astype(str)
n_tr = int(len(x) * 0.8)
x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
y.iloc[n_tr:].copy()]

formatter_fast = Formatter(logger=self.logger).add_transform(['all'], lags=np.arange(24),
relative_lags=True)
x_fast, y_fast = formatter.transform(self.data.iloc[:1000])
x_fast.columns = x_fast.columns.astype(str)
y_fast.columns = y_fast.columns.astype(str)
n_tr = int(len(x_fast) * 0.8)
x_fast_tr, x_fast_te, y_fast_tr, y_fast_te = [x_fast.iloc[:n_tr, :].copy(), x_fast.iloc[n_tr:, :].copy(), y_fast.iloc[:n_tr].copy(),
y_fast.iloc[n_tr:].copy()]


m_lin = LinearForecaster(val_ratio=0.2, fit_intercept=False, normalize=False).fit(x_tr, y_tr)
m_fast_lin = FastLinReg(val_ratio=0.2, fit_intercept=False, normalize=False, n_out=y_tr.shape[1], learning_rate=10).fit(x_fast_tr, y_fast_tr, n_epochs=100)

y_hat = m_lin.predict(x_te)
y_hat_fast = m_fast_lin.predict(x_fast_te)

s_a = 5
y_te.iloc[:, s_a].plot()
y_hat.iloc[:, s_a].plot()
(y_hat_fast.iloc[:, s_a]).plot()

def test_hw_difficult(self):

n_tr = int(len(self.x) * 0.5)
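As a small follow-up to the visual check in test_fast_linreg, the two forecasters could also be compared numerically on the held-out slice; a sketch reusing the variables defined in that test (not part of the commit):

# mean squared error over all horizons, aligned on index and columns
mse_lin = ((y_hat - y_te) ** 2).mean().mean()
mse_fast = ((y_hat_fast - y_fast_te) ** 2).mean().mean()
print('LinearForecaster MSE: {:.3f}, FastLinReg MSE: {:.3f}'.format(mse_lin, mse_fast))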
