From 2a0cb8170c31ab17d5d18bc1a3060181566ef2f4 Mon Sep 17 00:00:00 2001 From: nepslor Date: Wed, 25 Sep 2024 16:09:05 +0200 Subject: [PATCH] added sample_weight to QRF --- pyforecaster/forecasting_models/randomforests.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyforecaster/forecasting_models/randomforests.py b/pyforecaster/forecasting_models/randomforests.py index 6e0fdbc..6045d64 100644 --- a/pyforecaster/forecasting_models/randomforests.py +++ b/pyforecaster/forecasting_models/randomforests.py @@ -85,23 +85,23 @@ def __init__(self, n_estimators=100, q_vect=None, val_ratio=None, nodes_at_step= "ccp_alpha":ccp_alpha } - def _fit(self, i, x, y): + def _fit(self, i, x, y, sample_weight=None): x_i = self.dataset_at_stepahead(x, i, self.metadata_features, formatter=self.formatter, logger=self.logger, method='periodic', keep_last_n_lags=self.keep_last_n_lags, keep_last_seconds=self.keep_last_seconds, tol_period=self.tol_period) - model = RandomForestQuantileRegressor(**self.qrf_pars).fit(x_i, y.iloc[:, i]) + model = RandomForestQuantileRegressor(**self.qrf_pars).fit(x_i, y.iloc[:, i], sample_weight=sample_weight) return model @encode_categorical - def fit(self, x, y): + def fit(self, x, y, sample_weight=None): x, y, x_val, y_val = self.train_val_split(x, y) if self.parallel: with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_parallel_workers) as executor: - self.models = [i for i in tqdm(executor.map(partial(self._fit, x=x, y=y), range(self.n_single)),total=self.n_single)] + self.models = [i for i in tqdm(executor.map(partial(self._fit, x=x, y=y, sample_weight=sample_weight), range(self.n_single)),total=self.n_single)] else: for i in tqdm(range(self.n_single)): - model = self._fit(i, x, y) + model = self._fit(i, x, y, sample_weight=sample_weight) self.models.append(model) n_sa = y.shape[1]