diff --git a/pyforecaster/forecasting_models/gradientboosters.py b/pyforecaster/forecasting_models/gradientboosters.py index 27d16b8..522623e 100644 --- a/pyforecaster/forecasting_models/gradientboosters.py +++ b/pyforecaster/forecasting_models/gradientboosters.py @@ -145,8 +145,13 @@ def predict_single(self, i, x): return self.multi_step_model.predict(x,num_threads=1).reshape(-1, 1) def predict_parallel(self, x): - with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: - y_hat = list(executor.map(partial(self.predict_single, x=x), np.arange(self.n_multistep))) + if self.parallel: + with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: + y_hat = list(executor.map(partial(self.predict_single, x=x), np.arange(self.n_multistep))) + else: + y_hat = [] + for i in range(self.n_multistep): + y_hat.append(self.predict_single(i, x)) return np.hstack(y_hat) @staticmethod diff --git a/tests/test_boosters.py b/tests/test_boosters.py index bb61ebd..62771e0 100644 --- a/tests/test_boosters.py +++ b/tests/test_boosters.py @@ -32,8 +32,19 @@ def test_linear_val_split(self): y_hat_lin = m_lin.predict(x_te) q = m_lin.predict_quantiles(x_te) - m_lgbhybrid = LGBMHybrid(red_frac_multistep=0.1, val_ratio=0.3, lgb_pars={'num_leaves': 300, 'n_estimators': 10, 'learning_rate':0.05}, - n_single=10, parallel=True, formatter=formatter, metadata_features=['minuteofday', 'utc_offset', 'dayofweek', 'hour'],tol_period='1h', keep_last_seconds=3600).fit(x_tr, y_tr) + m_lgbhybrid = LGBMHybrid(red_frac_multistep=0.1, val_ratio=0.3, + lgb_pars={'num_leaves': 300, 'n_estimators': 10, 'learning_rate': 0.05}, + n_single=10, parallel=False, formatter=formatter, + metadata_features=['minuteofday', 'utc_offset', 'dayofweek', 'hour'], tol_period='1h', + keep_last_seconds=3600).fit(x_tr, y_tr) + y_hat_lgbh = m_lgbhybrid.predict(x_te) + q = m_lgbhybrid.predict_quantiles(x_te) + + m_lgbhybrid = LGBMHybrid(red_frac_multistep=0.1, val_ratio=0.3, + lgb_pars={'num_leaves': 300, 'n_estimators': 10, 'learning_rate': 0.05}, + n_single=10, parallel=True, formatter=formatter, + metadata_features=['minuteofday', 'utc_offset', 'dayofweek', 'hour'], tol_period='1h', + keep_last_seconds=3600).fit(x_tr, y_tr) y_hat_lgbh = m_lgbhybrid.predict(x_te) q = m_lgbhybrid.predict_quantiles(x_te)