Skip to content

Commit

Permalink
added default parallelism on transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Oct 20, 2023
1 parent ebc7df4 commit 5eefefa
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions pyforecaster/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from itertools import product
from multiprocessing import cpu_count
from typing import Union
import concurrent.futures

import holidays as holidays_api
import numpy as np
Expand All @@ -23,7 +24,7 @@ class Formatter:
:param augment: if true, doesn't discard the original columns of the dataset. Could be helpful to discard the most
recent data if you don't have it at prediction time.
"""
def __init__(self, logger=None, augment=True, dt=None):
def __init__(self, logger=None, augment=True, dt=None, n_parallel=None):
self.logger = get_logger(level=logging.WARNING, name='Formatter') if logger is None else logger
self.transformers = []
self.fold_transformers = []
Expand All @@ -32,6 +33,7 @@ def __init__(self, logger=None, augment=True, dt=None):
self.augment = augment
self.timezone = None
self.dt = dt
self.n_parallel = n_parallel if n_parallel is not None else cpu_count()

def add_time_features(self, x):
tz = x.index[0].tz
Expand Down Expand Up @@ -173,6 +175,9 @@ def transform(self, x, time_features=True, holidays=False, return_target=True, g
return_target=return_target, **holidays_kwargs)
return x, target

@staticmethod
def _transform_(tr, x):
    """
    Apply a single transformer to x with augmentation switched off.

    Kept as a static, two-argument helper so that it can be bound with
    functools.partial and mapped over the list of transformers.

    :param tr: object exposing a transform(x, augment=...) method
    :param x: data handed unchanged to the transformer
    :return: whatever tr.transform(x, augment=False) returns
    """
    # augment=False is always forced here; presumably the transformer then
    # returns only its newly generated columns -- confirm against the
    # transformer implementation.
    transformed = tr.transform(x, augment=False)
    return transformed
def _transform(self, x, time_features=True, holidays=False, return_target=True, **holidays_kwargs):
"""
Takes the DataFrame x and applies the specified transformations stored in the transformers in order to obtain
Expand All @@ -190,8 +195,14 @@ def _transform(self, x, time_features=True, holidays=False, return_target=True,
self.logger.warning('There are {} nans in x, nans are not supported yet, '
'get over it. I have more important things to do.'.format(x.isna().sum()))

for tr in self.transformers:
x = tr.transform(x)
if len(self.transformers)>0:
if self.n_parallel>1:
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel) as executor:
x_tr = pd.concat([i for i in executor.map(partial(self._transform_, x=x), self.transformers)], axis=1)
x = pd.concat([x, x_tr], axis=1)
else:
for tr in self.transformers:
x = tr.transform(x)
transformed_columns = [c for c in x.columns if c not in original_columns]
target = pd.DataFrame(index=x.index)
if return_target:
Expand Down

0 comments on commit 5eefefa

Please sign in to comment.