diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index d899069..294769f 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -12,6 +12,8 @@ import matplotlib.pyplot as plt
 from sklearn.preprocessing import StandardScaler
 from typing import Union
+from jax import jit
+from inspect import getmro


 def positive_lecun(key, shape, dtype=jnp.float32, init_type='normal'):
     # Start with standard lecun_normal initialization
@@ -45,6 +47,7 @@ def __call__(self, x):
         if isinstance(self.n_layers, int):
             layers = np.ones(self.n_layers) * self.n_neurons
             layers[-1] = self.n_out
+            layers = layers.astype(int)
         else:
             layers = self.n_layers
         for i, n in enumerate(layers):
@@ -53,6 +56,23 @@ def __call__(self, x):
             x = nn.relu(x)
         return x

+def jitting_wrapper(fun, model):
+    return jit(partial(fun, model=model))
+
+def loss_fn(params, inputs, targets, model=None):
+    predictions = model(params, inputs)
+    return jnp.mean((predictions - targets) ** 2)
+
+def train_step(params, optimizer_state, inputs_batch, targets_batch, model=None, loss_fn=None):
+    values, grads = value_and_grad(loss_fn)(params, inputs_batch, targets_batch)
+    updates, opt_state = model.update(grads, optimizer_state, params)
+    return optax.apply_updates(params, updates), opt_state, values
+
+def predict_batch(pars, inputs, model=None):
+    return model.apply(pars, inputs)
+
+def predict_batch_picnn(pars, inputs, model=None):
+    return model.apply(pars, *inputs)

 class NN(ScenarioGenerator):
     scaler: StandardScaler = None
@@ -60,98 +80,143 @@ class NN(ScenarioGenerator):
     batch_size: int = None
     load_path: str = None
     n_out: int = None
-    n_epochs: int = 10
-    savepath_tr_plots: str = None
+    n_layers: int = None
+    n_hidden_x: int = None
+    pars: dict = None
+    n_epochs:int = 10
+    savepath_tr_plots:str = None
     stats_step: int = 50
     rel_tol: float = 1e-4
+    unnormalized_inputs: list = None
+    to_be_normalized:list = None
+    target_columns: list = None
+    iterate = None
+
+    def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
+                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
+                 stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, **scengen_kwgs):
+        super().__init__(q_vect, val_ratio=val_ratio, nodes_at_step=nodes_at_step, **scengen_kwgs)
+        self.set_attr({"learning_rate": learning_rate,
+                       "batch_size": batch_size,
+                       "load_path": load_path,
+                       "n_hidden_x": n_hidden_x,
+                       "n_out": n_out,
+                       "n_layers": n_layers,
+                       "pars": pars,
+                       "n_epochs": n_epochs,
+                       "savepath_tr_plots": savepath_tr_plots,
+                       "stats_step": stats_step,
+                       "rel_tol": rel_tol,
+                       "unnormalized_inputs": unnormalized_inputs
+                       })
-    def __init__(self, n_out=1, q_vect=None, n_epochs=10, val_ratio=None, nodes_at_step=None, learning_rate=1e-3,
-                 nn_module=None, scengen_dict={}, batch_size=None, **model_kwargs):
-        super().__init__(q_vect, val_ratio=val_ratio, nodes_at_step=nodes_at_step, **scengen_dict)
-        model = nn_module(n_out=n_out, **model_kwargs)
-        self.batch_size = batch_size
-        self.learning_rate = learning_rate
-        self.model = model
-        self.optimizer = optax.adam(learning_rate=self.learning_rate)
-        self.n_epochs = n_epochs
+        if self.load_path is not None:
+            self.load(self.load_path)
+
+        self.model = self.set_arch()
+        self.optimizer = optax.adamw(learning_rate=self.learning_rate)
-        @jit
-        def loss_fn(params, x, y):
-            predictions = model.apply(params, x)
-            return jnp.mean((predictions - y) ** 2)
-        @jit
-        def train_step(params, optimizer_state, x_batch, y_batch):
-            values, grads = value_and_grad(loss_fn)(params, x_batch, y_batch)
-            updates, opt_state = self.optimizer.update(grads, optimizer_state, params)
-            return optax.apply_updates(params, updates), opt_state, values
-        @jit
-        @partial(vmap, in_axes=(None, 0))
-        def predict_batch(pars, x):
-            return model.apply(pars, x)
+        self.predict_batch = vmap(jitting_wrapper(predict_batch, self.model), in_axes=(None, 0))
+        self.loss_fn = jitting_wrapper(loss_fn, self.predict_batch)
+        self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
-        self.train_step = train_step
-        self.loss_fn = loss_fn
-        self.predict_batch = predict_batch
-        self.iterate = None
+    def get_class_properties_names(cls):
+        attributes = []
+        cls = cls if isinstance(cls, type) else cls.__class__
+        # Loop through the MRO (Method Resolution Order) to include parent classes
+        for base_class in getmro(cls):
+            attributes.extend(
+                key for key, value in base_class.__dict__.items()
+                if not callable(value)
+                and not key.startswith('__')
+                and not isinstance(value, (classmethod, staticmethod))
+            )
+        return list(set(attributes))  # Remove duplicates
+
+    def get_class_properties(self):
+        return {k: getattr(self, k) for k in self.get_class_properties_names()}
+
+    def save(self, save_path):
+        attrdict = self.get_class_properties()
+        with open(save_path, 'wb') as f:
+            pk.dump(attrdict, f, protocol=pk.HIGHEST_PROTOCOL)
+
+    def set_attr(self, attrdict):
+        [self.__setattr__(k, v) for k, v in attrdict.items()]
+
+    def load(self, save_path):
+        with open(save_path, 'rb') as f:
+            attrdict = pk.load(f)
+        self.set_attr(attrdict)

     @staticmethod
     def init_arch(nn_init, n_inputs_x=1):
         "divides data into training and test sets "
         key1, key2 = random.split(random.key(0))
-        x = random.normal(key1, (n_inputs_x, ))  # Dummy input data (for the first input)
+        x = random.normal(key1, (n_inputs_x,))  # Dummy input data (for the first input)
         init_params = nn_init.init(key2, x)  # Initialization call
         return init_params

-    def fit(self, x, y, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
+    def set_arch(self):
+        model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
+                                  n_out=self.n_out)
+        return model
+
+    def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):
+        self.to_be_normalized = [c for c in inputs.columns if
+                                 c not in self.unnormalized_inputs] if self.unnormalized_inputs is not None else inputs.columns
         rel_tol = rel_tol if rel_tol is not None else self.rel_tol
         n_epochs = n_epochs if n_epochs is not None else self.n_epochs
         stats_step = stats_step if stats_step is not None else self.stats_step
-        self.scaler = StandardScaler().set_output(transform='pandas').fit(x)
+        self.scaler = StandardScaler().set_output(transform='pandas').fit(inputs[self.to_be_normalized])
+
+        inputs, targets, inputs_val_0, targets_val = self.train_val_split(inputs, targets)
+        training_len = inputs.shape[0]
+        validation_len = inputs_val_0.shape[0]
-        x, y, x_val_0, y_val = self.train_val_split(x, y)
-        self.target_columns = y.columns
+        self.target_columns = targets.columns
+        batch_size = self.batch_size if self.batch_size is not None else inputs.shape[0] // 10
+        num_batches = inputs.shape[0] // batch_size
-        batch_size = self.batch_size if self.batch_size is not None else x.shape[0] // 10
-        num_batches = y.shape[0] // batch_size
-        y = y.values
-        x = self.get_normalized_inputs(x)
-        x_val = self.get_normalized_inputs(x_val_0)
+        inputs = self.get_normalized_inputs(inputs)
+        inputs_val = self.get_normalized_inputs(inputs_val_0)
+        inputs_len = [i.shape[1] for i in inputs] if isinstance(inputs, tuple) else inputs.shape[1]
-        pars = self.init_arch(self.model, x.shape[1])
+        pars = self.init_arch(self.model, *np.atleast_1d(inputs_len))
         opt_state = self.optimizer.init(pars)
         tr_loss, val_loss = [], []
         k = 0
         finished = False
         for epoch in range(n_epochs):
-            rand_idx_all = np.random.choice(x.shape[0], x.shape[0], replace=False)
+            rand_idx_all = np.random.choice(training_len, training_len, replace=False)
             for i in tqdm(range(num_batches), desc='epoch {}/{}'.format(epoch, n_epochs)):
-                rand_idx = rand_idx_all[i*batch_size:(i+1)*batch_size]
-                x_batch = x[rand_idx, :]
-                y_batch = y[rand_idx, :]
+                rand_idx = rand_idx_all[i * batch_size:(i + 1) * batch_size]
+                inputs_batch = [i[rand_idx, :] for i in inputs] if isinstance(inputs, tuple) else inputs[rand_idx, :]
+                targets_batch = targets.values[rand_idx, :]
-                pars, opt_state, values = self.train_step(pars, opt_state, x_batch, y_batch)
+                pars, opt_state, values = self.train_step(pars, opt_state, inputs_batch, targets_batch)
                 if k % stats_step == 0 and k > 0:
                     self.pars = pars
-                    te_loss_i = self.loss_fn(pars, x_val, y_val.values)
-                    tr_loss_i = self.loss_fn(pars, x, y)
+                    te_loss_i = self.loss_fn(pars, inputs_val, targets_val.values)
+                    tr_loss_i = self.loss_fn(pars, inputs, targets.values)
                     val_loss.append(np.array(jnp.mean(te_loss_i)))
                     tr_loss.append(np.array(jnp.mean(tr_loss_i)))
-                    self.logger.warning('tr loss: {:0.2e}, te loss: {:0.2e}'.format(tr_loss[-1], val_loss[-1]))
+                    self.logger.info('tr loss: {:0.2e}, te loss: {:0.2e}'.format(tr_loss[-1], val_loss[-1]))
                     if len(tr_loss) > 1:
                         if savepath_tr_plots is not None or self.savepath_tr_plots is not None:
                             savepath_tr_plots = savepath_tr_plots if savepath_tr_plots is not None else self.savepath_tr_plots
-                            rand_idx_plt = np.random.choice(x_val.shape[0], 9)
-                            self.training_plots(x_val[rand_idx_plt, :],
-                                                y_val.values[rand_idx_plt, :], tr_loss, val_loss, savepath_tr_plots, k)
+
+                            rand_idx_plt = np.random.choice(validation_len, 9)
+                            self.training_plots([i[rand_idx_plt, :] for i in inputs] if isinstance(inputs, tuple) else inputs_val[rand_idx_plt, :],
+                                                targets_val.values[rand_idx_plt, :], tr_loss, val_loss, savepath_tr_plots, k)

                         rel_te_err = (val_loss[-2] - val_loss[-1]) / np.abs(val_loss[-2] + 1e-6)
-                        if rel_te_err
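
The core of this refactor is that loss_fn, train_step and predict_batch become module-level functions whose model / loss_fn arguments are bound once with jitting_wrapper (jit(partial(fun, model=...))), so NN.__init__ only wires them together instead of redefining jitted closures per instance. Below is a minimal, self-contained sketch of that pattern, not pyforecaster code: linear_apply, the shapes and the training loop are made up for illustration, and the predictor is a plain function rather than a flax module, so predict_batch here calls model(...) instead of model.apply(...).

from functools import partial

import jax.numpy as jnp
import numpy as np
import optax
from jax import jit, value_and_grad, vmap


def jitting_wrapper(fun, model):
    # Bind the static object (predictor or optimizer) via partial, then jit the result,
    # so the compiled callable closes over it instead of receiving it as a traced argument.
    return jit(partial(fun, model=model))

def predict_batch(pars, inputs, model=None):
    # `model` is a toy pure function (params, one sample) -> prediction;
    # in the diff a flax module is bound and model.apply is called instead.
    return model(pars, inputs)

def loss_fn(params, inputs, targets, model=None):
    # `model` is the already-vmapped batch predictor.
    predictions = model(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

def train_step(params, optimizer_state, inputs_batch, targets_batch, model=None, loss_fn=None):
    # `model` is the optax optimizer here, mirroring how NN.__init__ wires the wrappers.
    values, grads = value_and_grad(loss_fn)(params, inputs_batch, targets_batch)
    updates, opt_state = model.update(grads, optimizer_state, params)
    return optax.apply_updates(params, updates), opt_state, values


def linear_apply(pars, x):  # hypothetical stand-in for FeedForwardModule.apply
    return x @ pars

optimizer = optax.adamw(learning_rate=1e-2)
predict = vmap(jitting_wrapper(predict_batch, linear_apply), in_axes=(None, 0))
loss = jitting_wrapper(loss_fn, predict)
step = jitting_wrapper(partial(train_step, loss_fn=loss), optimizer)

pars = jnp.zeros((3, 2))                     # made-up shapes: 3 features, 2 outputs
opt_state = optimizer.init(pars)
x, y = np.random.randn(32, 3), np.random.randn(32, 2)
for _ in range(100):
    pars, opt_state, value = step(pars, opt_state, x, y)

Because the wrapped functions only take arrays and pytrees as runtime arguments, the same loss_fn/train_step can be reused by subclasses (e.g. the PICNN variant via predict_batch_picnn) by binding a different model, without re-tracing logic inside each __init__.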