diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index ac25913..24bd2b5 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -29,6 +29,144 @@ def positive_lecun(key, shape, dtype=jnp.float32, init_type='normal'):
 def identity(x):
     return x
 
+class LinRegModule(nn.Module):
+    n_out: int
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Dense(features=self.n_out, name='dense')(x)
+        return x
+
+class FastLinReg(ScenarioGenerator):
+    scaler: StandardScaler = None
+    learning_rate: float = 0.01
+    batch_size: int = None
+    load_path: str = None
+    n_out: int = None
+    n_epochs: int = 10
+    savepath_tr_plots: str = None
+    stats_step: int = 50
+    rel_tol: float = 1e-4
+
+
+    def __init__(self, n_out=1, q_vect=None, n_epochs=10, val_ratio=None, nodes_at_step=None, learning_rate=1e-3, **scengen_kwgs):
+        super().__init__(q_vect, val_ratio=val_ratio, nodes_at_step=nodes_at_step, **scengen_kwgs)
+        model = LinRegModule(n_out)
+        self.learning_rate = learning_rate
+        self.model = model
+        self.optimizer = optax.adam(learning_rate=self.learning_rate)
+        self.n_epochs = n_epochs
+
+        @jit
+        def loss_fn(params, x, y):
+            predictions = model.apply(params, x)
+            return jnp.mean((predictions - y) ** 2)
+        @jit
+        def train_step(params, optimizer_state, x_batch, y_batch):
+            values, grads = value_and_grad(loss_fn)(params, x_batch, y_batch)
+            updates, opt_state = self.optimizer.update(grads, optimizer_state, params)
+            return optax.apply_updates(params, updates), opt_state, values
+
+        @jit
+        @partial(vmap, in_axes=(None, 0))
+        def predict_batch(pars, x):
+            return model.apply(pars, x)
+
+        self.train_step = train_step
+        self.loss_fn = loss_fn
+        self.predict_batch = predict_batch
+        self.iterate = None
+
+    @staticmethod
+    def init_arch(nn_init, n_inputs_x=1):
+        "initializes the architecture parameters from a dummy input of the right shape"
+        key1, key2 = random.split(random.key(0))
+        x = random.normal(key1, (n_inputs_x, ))  # Dummy input data (for the first input)
+        init_params = nn_init.init(key2, x)  # Initialization call
+        return init_params
+
+    def fit(self, x, y, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
+        rel_tol = rel_tol if rel_tol is not None else self.rel_tol
+        n_epochs = n_epochs if n_epochs is not None else self.n_epochs
+        stats_step = stats_step if stats_step is not None else self.stats_step
+        self.scaler = StandardScaler().set_output(transform='pandas').fit(x)
+
+        x, y, x_val_0, y_val = self.train_val_split(x, y)
+        self.target_columns = y.columns
+
+        batch_size = self.batch_size if self.batch_size is not None else x.shape[0] // 10
+        num_batches = y.shape[0] // batch_size
+        y = y.values
+        x = self.get_inputs(x)
+        x_val = self.get_inputs(x_val_0)
+
+        pars = self.init_arch(self.model, x.shape[1])
+        opt_state = self.optimizer.init(pars)
+
+        tr_loss, val_loss = [], []
+        k = 0
+        finished = False
+        for epoch in range(n_epochs):
+            rand_idx_all = np.random.choice(x.shape[0], x.shape[0], replace=False)
+            for i in tqdm(range(num_batches), desc='epoch {}/{}'.format(epoch, n_epochs)):
+                rand_idx = rand_idx_all[i*batch_size:(i+1)*batch_size]
+                x_batch = x[rand_idx, :]
+                y_batch = y[rand_idx, :]
+
+                pars, opt_state, values = self.train_step(pars, opt_state, x_batch, y_batch)
+
+                if k % stats_step == 0 and k > 0:
+                    self.pars = pars
+
+                    te_loss_i = self.loss_fn(pars, x_val, y_val.values)
+                    tr_loss_i = self.loss_fn(pars, x, y)
+                    val_loss.append(np.array(jnp.mean(te_loss_i)))
+                    tr_loss.append(np.array(jnp.mean(tr_loss_i)))
+
+                    self.logger.warning('tr loss: {:0.2e}, te loss: {:0.2e}'.format(tr_loss[-1], val_loss[-1]))
+                    if len(tr_loss) > 1:
+                        if savepath_tr_plots is not None or self.savepath_tr_plots is not None:
+                            savepath_tr_plots = savepath_tr_plots if savepath_tr_plots is not None else self.savepath_tr_plots
+                            rand_idx_plt = np.random.choice(x_val.shape[0], 9)
+                            self.training_plots(x_val[rand_idx_plt, :],
+                                                y_val.values[rand_idx_plt, :], tr_loss, val_loss, savepath_tr_plots, k)
+
+                        rel_te_err = (val_loss[-2] - val_loss[-1]) / np.abs(val_loss[-2] + 1e-6)
+                        if rel_te_err < rel_tol:
+                            finished = True
+                            break

@@ -49,6 +49,40 @@ def test_hw(self):
         plt.pause(0.0001)
         plt.close('all')
 
+    def test_fast_linreg(self):
+
+        formatter = Formatter(logger=self.logger).add_transform(['all'], lags=np.arange(24),
+                                                                relative_lags=True)
+        formatter.add_transform(['all'], ['min', 'max'], agg_bins=[1, 2, 15, 20])
+        formatter.add_target_transform(['all'], lags=-np.arange(6))
+        x, y = formatter.transform(self.data.iloc[:1000])
+        x.columns = x.columns.astype(str)
+        y.columns = y.columns.astype(str)
+        n_tr = int(len(x) * 0.8)
+        x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
+                                  y.iloc[n_tr:].copy()]
+
+        formatter_fast = Formatter(logger=self.logger).add_transform(['all'], lags=np.arange(24),
+                                                                     relative_lags=True)
+        x_fast, y_fast = formatter.transform(self.data.iloc[:1000])
+        x_fast.columns = x_fast.columns.astype(str)
+        y_fast.columns = y_fast.columns.astype(str)
+        n_tr = int(len(x_fast) * 0.8)
+        x_fast_tr, x_fast_te, y_fast_tr, y_fast_te = [x_fast.iloc[:n_tr, :].copy(), x_fast.iloc[n_tr:, :].copy(), y_fast.iloc[:n_tr].copy(),
+                                                      y_fast.iloc[n_tr:].copy()]
+
+
+        m_lin = LinearForecaster(val_ratio=0.2, fit_intercept=False, normalize=False).fit(x_tr, y_tr)
+        m_fast_lin = FastLinReg(val_ratio=0.2, fit_intercept=False, normalize=False, n_out=y_tr.shape[1], learning_rate=10).fit(x_fast_tr, y_fast_tr, n_epochs=100)
+
+        y_hat = m_lin.predict(x_te)
+        y_hat_fast = m_fast_lin.predict(x_fast_te)
+
+        s_a = 5
+        y_te.iloc[:, s_a].plot()
+        y_hat.iloc[:, s_a].plot()
+        (y_hat_fast.iloc[:, s_a]).plot()
+
 
     def test_hw_difficult(self):
         n_tr = int(len(self.x) * 0.5)
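
For reference, the training logic added to neural_forecasters.py above follows the standard Flax/optax pattern (a jitted MSE loss, value_and_grad, and an adam update applied with optax.apply_updates). The sketch below reproduces that pattern in isolation so it can be run outside pyforecaster; the toy data, the LinReg/demo names and the hyper-parameters are illustrative assumptions, not part of the diff.

# Standalone sketch of the FastLinReg training pattern (illustrative, not repo code).
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn

class LinReg(nn.Module):
    n_out: int
    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.n_out)(x)          # single dense layer = linear regression

def demo():
    # Made-up linear data: y = x @ w_true
    x = jax.random.normal(jax.random.PRNGKey(0), (256, 8))
    w_true = jax.random.normal(jax.random.PRNGKey(1), (8, 2))
    y = x @ w_true

    model = LinReg(n_out=2)
    params = model.init(jax.random.PRNGKey(2), x[0])   # init from a dummy single sample
    optimizer = optax.adam(1e-2)
    opt_state = optimizer.init(params)

    @jax.jit
    def loss_fn(params, x, y):
        pred = model.apply(params, x)
        return jnp.mean((pred - y) ** 2)

    @jax.jit
    def train_step(params, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    for _ in range(500):
        params, opt_state, loss = train_step(params, opt_state, x, y)
    return loss

if __name__ == '__main__':
    print(demo())   # loss should approach zero on this noiseless toy problem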