diff --git a/pyforecaster/forecasting_models/randomforests.py b/pyforecaster/forecasting_models/randomforests.py index f49293a..bf035c9 100644 --- a/pyforecaster/forecasting_models/randomforests.py +++ b/pyforecaster/forecasting_models/randomforests.py @@ -2,7 +2,7 @@ import pandas as pd from pyforecaster.utilities import get_logger from tqdm import tqdm -from pyforecaster.forecaster import ScenarioGenerator +from pyforecaster.forecaster import ScenarioGenerator, encode_categorical import numpy as np import concurrent.futures from time import time @@ -94,6 +94,7 @@ def _fit(self, i, x, y): model = RandomForestQuantileRegressor(**self.qrf_pars).fit(x_i, y.iloc[:, i], sparse_pickle=True) return model + @encode_categorical def fit(self, x, y): x, y, x_val, y_val = self.train_val_split(x, y) if self.parallel: @@ -135,6 +136,7 @@ def fit(self, x, y): super().fit(x_val, y_val) return self + @encode_categorical def predict(self, x, **kwargs): preds = [] period = kwargs['period'] if 'period' in kwargs else '24h'