Skip to content

Commit

Permalink
added parallel tuning of hyperpars
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Apr 9, 2024
1 parent 6474478 commit 44bcf14
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
50 changes: 32 additions & 18 deletions pyforecaster/forecasting_models/holtwinters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,24 @@ def hankel(x, n, generate_me=None):
h.append(x[w + g].reshape(-1, 1))
return np.hstack(h)

def tune_hyperpars(x, model_class, hyperpars, n_trials=100, targets_names=None, verbose=True, return_model=True, **model_init_kwargs):
def fit_sample(p, hyperpars,model_class, model_init_kwargs, x, targets_names=None):
pars_dict = {k: v for k, v in zip(hyperpars.keys(), p)}
model_init_kwargs.update(pars_dict)
model = model_class(**model_init_kwargs)
if targets_names is not None:
subscores = []
for c in targets_names:
model_init_kwargs.update({'target_name': c, 'targets_names': None})
model = model_class(**model_init_kwargs)
# features are all x columns except the targets and current target c
features_names = [n for n in x.columns if n not in set(targets_names) - {c}]
subscores.append(score_autoregressive(model, x[features_names], target_name=c))
score = np.mean(subscores)
else:
score = score_autoregressive(model, x)
return score

def tune_hyperpars(x, model_class, hyperpars, n_trials=100, targets_names=None, verbose=True, return_model=True, parallel=True, **model_init_kwargs):
"""
:param x: pd.DataFrame (n, n_cov)
:param y: pd.Series (n)
Expand All @@ -37,23 +54,20 @@ def tune_hyperpars(x, model_class, hyperpars, n_trials=100, targets_names=None,
pars_cartridge = np.random.rand(n_trials, n_pars) * (par_maxima-par_minima) + par_minima
model_init_kwargs.update({'optimize_hyperpars':False})
model_init_kwargs_0 = deepcopy(model_init_kwargs)
scores = []
generator = tqdm(pars_cartridge) if verbose else pars_cartridge
for i, p in enumerate(generator):
pars_dict = {k:v for k, v in zip(hyperpars.keys(), p)}
model_init_kwargs.update(pars_dict)
model = model_class(**model_init_kwargs)
if targets_names is not None:
subscores = []
for c in targets_names:
model_init_kwargs.update({'target_name':c, 'targets_names':None})
model = model_class(**model_init_kwargs)
# features are all x columns except the targets and current target c
features_names = [n for n in x.columns if n not in set(targets_names) - {c}]
subscores.append(score_autoregressive(model, x[features_names], target_name=c))
scores.append(np.mean(subscores))
else:
scores.append(score_autoregressive(model, x))


if parallel:
from concurrent.futures import ProcessPoolExecutor
with ProcessPoolExecutor() as executor:
scores = list(tqdm(executor.map(partial(fit_sample, hyperpars=hyperpars, model_class=model_class,
model_init_kwargs=model_init_kwargs, x=x, targets_names=targets_names),
pars_cartridge), total=n_trials, desc='Tuning hyperpars for {}'.format(model_class.__name__)))

else:
generator = tqdm(pars_cartridge) if verbose else pars_cartridge
scores = []
for i, p in enumerate(generator):
scores.append(fit_sample(p, hyperpars, model_class, model_init_kwargs, x, targets_names=targets_names))
if verbose:
plt.figure()
t = np.linspace(0.01, 0.99, 30)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nns.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,8 @@ def test_invertible_causal_nn(self):
y_hat = m.predict(e_te.iloc[:, :144])

m = CausalInvertibleNN(learning_rate=1e-2, batch_size=200, load_path=None, n_hidden_x=144,
n_layers=2, normalize_target=False, n_epochs=5, stopping_rounds=20, rel_tol=-1,
end_to_end='full', n_hidden_y=300, n_prediction_layers=3, n_out=144).fit(e_tr.iloc[:, :144], e_tr.iloc[:, -144:])
n_layers=3, normalize_target=False, n_epochs=5, stopping_rounds=20, rel_tol=-1,
end_to_end='quasi', n_hidden_y=300, n_prediction_layers=3, n_out=144).fit(e_tr.iloc[:, :144], e_tr.iloc[:, -144:])

z_hat_ete = m.predict(e_te.iloc[:, :144])

Expand Down

0 comments on commit 44bcf14

Please sign in to comment.