Skip to content

Commit

Permalink
check if parallel is true, also for predict (predict was slower in pa…
Browse files Browse the repository at this point in the history
…rallel mode)
  • Loading branch information
vascomedici committed Aug 8, 2024
1 parent c2a47af commit 4c2f05b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
9 changes: 7 additions & 2 deletions pyforecaster/forecasting_models/gradientboosters.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,13 @@ def predict_single(self, i, x):
return self.multi_step_model.predict(x,num_threads=1).reshape(-1, 1)

def predict_parallel(self, x):
with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
y_hat = list(executor.map(partial(self.predict_single, x=x), np.arange(self.n_multistep)))
if self.parallel:
with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
y_hat = list(executor.map(partial(self.predict_single, x=x), np.arange(self.n_multistep)))
else:
y_hat = []
for i in range(self.n_multistep):
y_hat.append(self.predict_single(i, x))
return np.hstack(y_hat)

@staticmethod
Expand Down
15 changes: 13 additions & 2 deletions tests/test_boosters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,19 @@ def test_linear_val_split(self):
y_hat_lin = m_lin.predict(x_te)
q = m_lin.predict_quantiles(x_te)

m_lgbhybrid = LGBMHybrid(red_frac_multistep=0.1, val_ratio=0.3, lgb_pars={'num_leaves': 300, 'n_estimators': 10, 'learning_rate':0.05},
n_single=10, parallel=True, formatter=formatter, metadata_features=['minuteofday', 'utc_offset', 'dayofweek', 'hour'],tol_period='1h', keep_last_seconds=3600).fit(x_tr, y_tr)
m_lgbhybrid = LGBMHybrid(red_frac_multistep=0.1, val_ratio=0.3,
lgb_pars={'num_leaves': 300, 'n_estimators': 10, 'learning_rate': 0.05},
n_single=10, parallel=False, formatter=formatter,
metadata_features=['minuteofday', 'utc_offset', 'dayofweek', 'hour'], tol_period='1h',
keep_last_seconds=3600).fit(x_tr, y_tr)
y_hat_lgbh = m_lgbhybrid.predict(x_te)
q = m_lgbhybrid.predict_quantiles(x_te)

m_lgbhybrid = LGBMHybrid(red_frac_multistep=0.1, val_ratio=0.3,
lgb_pars={'num_leaves': 300, 'n_estimators': 10, 'learning_rate': 0.05},
n_single=10, parallel=True, formatter=formatter,
metadata_features=['minuteofday', 'utc_offset', 'dayofweek', 'hour'], tol_period='1h',
keep_last_seconds=3600).fit(x_tr, y_tr)
y_hat_lgbh = m_lgbhybrid.predict(x_te)
q = m_lgbhybrid.predict_quantiles(x_te)

Expand Down

0 comments on commit 4c2f05b

Please sign in to comment.