Skip to content

Commit

Permalink
feat: ✨ implement single job version and fix concatenation
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Bury committed Nov 28, 2023
1 parent 4964c04 commit d7594fc
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions src/arfs/parallel.py
Original file line number Diff line number Diff line change
def parallel_matrix_entries(func, df, comb_list, sample_weight=None, n_jobs=-1):
    """Compute matrix entries (e.g. pairwise association values) in parallel.

    ``func`` is called as ``func(X=..., sample_weight=..., comb_list=...)`` and
    is expected to return a list of DataFrames, each carrying a ``"val"``
    column (this is what the single-job path concatenates and sorts on).

    Parameters
    ----------
    func : callable
        worker invoked as ``func(X=df, sample_weight=sample_weight,
        comb_list=chunk)``; returns a list of DataFrames with a ``"val"`` column.
    df : pd.DataFrame
        the data the entries are computed from.
    comb_list : sequence
        the combinations (presumably column pairs — confirm against callers)
        to evaluate; split into ``n_jobs`` chunks for the parallel path.
    sample_weight : array-like, optional
        per-row weights forwarded untouched to ``func``.
    n_jobs : int, default -1
        number of parallel jobs; ``-1`` uses all available cores, any other
        value is clamped to the machine's core count.

    Returns
    -------
    pd.DataFrame
        concatenation of all per-chunk results, sorted by ``"val"`` descending.
    """
    # Clamp the requested job count to the machine's core count.
    n_jobs = cpu_count() if n_jobs == -1 else min(cpu_count(), n_jobs)

    if n_jobs == 1:
        # Single-job path: call the worker once, no process pool involved.
        lst = func(X=df, sample_weight=sample_weight, comb_list=comb_list)
        return pd.concat(lst, ignore_index=True).sort_values("val", ascending=False)

    # Split the work into one chunk per job; each worker returns a list of
    # DataFrames, so `lst` below is a list of lists.
    comb_chunks = np.array_split(comb_list, n_jobs)
    lst = Parallel(n_jobs=n_jobs)(
        delayed(func)(X=df, sample_weight=sample_weight, comb_list=comb_chunk)
        for comb_chunk in comb_chunks
    )

    # BUGFIX: the previous code returned ``lst[0]`` (a *list* of DataFrames)
    # when only one chunk came back, leaking an inconsistent return type.
    # Always flatten and concatenate, and sort like the single-job path so
    # callers get the same ordering regardless of n_jobs.
    return pd.concat(list(chain(*lst)), ignore_index=True).sort_values(
        "val", ascending=False
    )


def parallel_df(func, df, series, sample_weight=None, n_jobs=-1):
    """Apply ``func`` to column chunks of ``df`` in parallel.

    ``func`` is called as ``func(df_chunk, series, sample_weight)`` and is
    expected to return a pandas Series (it must support
    ``sort_values(ascending=False)`` and ``pd.concat``).

    Parameters
    ----------
    func : callable
        worker invoked as ``func(df.iloc[:, col_chunk], series, sample_weight)``;
        returns a Series of per-column values.
    df : pd.DataFrame
        the data whose columns are processed in chunks.
    series : pd.Series
        target/reference series forwarded untouched to ``func``.
    sample_weight : array-like, optional
        per-row weights forwarded untouched to ``func``.
    n_jobs : int, default -1
        number of parallel jobs; ``-1`` uses all available cores, any other
        value is clamped to the machine's core count.

    Returns
    -------
    pd.Series
        concatenation of all per-chunk results, sorted by value descending.
    """
    # Clamp the requested job count to the machine's core count.
    n_jobs = cpu_count() if n_jobs == -1 else min(cpu_count(), n_jobs)

    if n_jobs == 1:
        # Single-job path. The previous code bound the sorted result and then
        # tested ``isinstance(lst, list)`` before ``pd.concat``-ing it — but a
        # plain list has no ``sort_values``, so that branch was unreachable.
        # Return the sorted worker result directly (same reachable behavior).
        return func(df, series, sample_weight).sort_values(ascending=False)

    # Split the columns into one chunk of positions per job.
    col_chunks = np.array_split(range(len(df.columns)), n_jobs)
    lst = Parallel(n_jobs=n_jobs)(
        delayed(func)(df.iloc[:, col_chunk], series, sample_weight)
        for col_chunk in col_chunks
    )
    return pd.concat(lst).sort_values(ascending=False)


def _compute_series(
Expand Down

0 comments on commit d7594fc

Please sign in to comment.