Implement engine="thread" in ChunkRecordingExecutor #3526

Open · wants to merge 13 commits into base: main
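
For reference, a minimal usage sketch of the new parameter. The func/init_func bodies and their arguments below are hypothetical placeholders; only the ChunkRecordingExecutor call mirrors the signature and tests in this PR.

from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import ChunkRecordingExecutor


def init_func(label):
    # build the per-worker state once; with pool_engine="thread" this single dict
    # is shared by every worker thread
    return dict(label=label)


def func(segment_index, start_frame, end_frame, worker_dict):
    # process one chunk of the recording and return something small
    return (worker_dict["label"], segment_index, start_frame, end_frame)


recording = generate_recording(num_channels=2, durations=[2.0])
processor = ChunkRecordingExecutor(
    recording,
    func,
    init_func,
    ("a",),                 # init_args, unpacked into init_func
    handle_returns=True,
    pool_engine="thread",   # new in this PR; "process" keeps the previous behaviour
    n_jobs=2,
    chunk_duration="200ms",
)
results = processor.run()
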
117 changes: 79 additions & 38 deletions src/spikeinterface/core/job_tools.py
@@ -12,7 +12,7 @@
import sys
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
from threadpoolctl import threadpool_limits

@@ -39,6 +39,7 @@


job_keys = (
"pool_engine",
"n_jobs",
"total_memory",
"chunk_size",
@@ -292,6 +293,8 @@ class ChunkRecordingExecutor:
gather_func : None or callable, default: None
Optional function that is called in the main thread and retrieves the results of each worker.
This function can be used instead of `handle_returns` to implement custom storage on-the-fly.
pool_engine : "process" | "thread"
If n_jobs > 1, whether chunks are dispatched with a ProcessPoolExecutor ("process") or a ThreadPoolExecutor ("thread")
n_jobs : int, default: 1
Number of jobs to be used. Use -1 to use as many jobs as number of cores
total_memory : str, default: None
@@ -329,6 +332,7 @@ def __init__(
progress_bar=False,
handle_returns=False,
gather_func=None,
pool_engine="process",
n_jobs=1,
total_memory=None,
chunk_size=None,
@@ -370,6 +374,8 @@ def __init__(
self.job_name = job_name
self.max_threads_per_process = max_threads_per_process

self.pool_engine = pool_engine

if verbose:
chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize
total_memory = chunk_memory * self.n_jobs
@@ -380,6 +386,7 @@ def __init__(
print(
self.job_name,
"\n"
f"engine={self.pool_engine} - "
f"n_jobs={self.n_jobs} - "
f"samples_per_chunk={self.chunk_size:,} - "
f"chunk_memory={chunk_memory_str} - "
@@ -402,7 +409,7 @@ def run(self, recording_slices=None):

if self.n_jobs == 1:
if self.progress_bar:
recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)
recording_slices = tqdm(recording_slices, desc=self.job_name, total=len(recording_slices))

worker_ctx = self.init_func(*self.init_args)
for segment_index, frame_start, frame_stop in recording_slices:
@@ -411,60 +418,94 @@ def run(self, recording_slices=None):
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

else:
n_jobs = min(self.n_jobs, len(recording_slices))

# parallel
with ProcessPoolExecutor(
max_workers=n_jobs,
initializer=worker_initializer,
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
) as executor:
results = executor.map(function_wrapper, recording_slices)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))
if self.pool_engine == "process":

# parallel
with ProcessPoolExecutor(
max_workers=n_jobs,
initializer=process_worker_initializer,
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
) as executor:
results = executor.map(process_function_wrapper, recording_slices)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

elif self.pool_engine == "thread":
# thread engine: a single shared worker_dict is used by all threads

worker_dict = self.init_func(*self.init_args)
thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process)

with ThreadPoolExecutor(
max_workers=n_jobs,
) as executor:
results = executor.map(thread_func, recording_slices)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

else:
raise ValueError("If n_jobs > 1, pool_engine must be 'process' or 'thread'")

return returns

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

return returns

class WorkerFuncWrapper:
def __init__(self, func, worker_dict, max_threads_per_process):
self.func = func
self.worker_dict = worker_dict
self.max_threads_per_process = max_threads_per_process

def __call__(self, args):
segment_index, start_frame, end_frame = args
if self.max_threads_per_process is None:
return self.func(segment_index, start_frame, end_frame, self.worker_dict)
else:
with threadpool_limits(limits=self.max_threads_per_process):
return self.func(segment_index, start_frame, end_frame, self.worker_dict)
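
# Note: WorkerFuncWrapper serves both engines. With pool_engine="thread" the run() method above
# builds a single instance, so one worker_dict is shared by every thread and whatever init_func
# creates must tolerate concurrent calls. With pool_engine="process" one instance is built per
# worker process by process_worker_initializer below. In both cases threadpool_limits caps the
# BLAS/OpenMP threads used inside each call when max_threads_per_process is set.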

# see
# https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool
# the trick is: these globals are set independently in each worker process,
# so they are not shared between processes
global _worker_ctx
global _func
# global _worker_ctx
# global _func
global _process_func_wrapper


def worker_initializer(func, init_func, init_args, max_threads_per_process):
global _worker_ctx
def process_worker_initializer(func, init_func, init_args, max_threads_per_process):
global _process_func_wrapper
if max_threads_per_process is None:
_worker_ctx = init_func(*init_args)
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_process):
_worker_ctx = init_func(*init_args)
_worker_ctx["max_threads_per_process"] = max_threads_per_process
global _func
_func = func
worker_dict = init_func(*init_args)
_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)


def function_wrapper(args):
segment_index, start_frame, end_frame = args
global _func
global _worker_ctx
max_threads_per_process = _worker_ctx["max_threads_per_process"]
if max_threads_per_process is None:
return _func(segment_index, start_frame, end_frame, _worker_ctx)
else:
with threadpool_limits(limits=max_threads_per_process):
return _func(segment_index, start_frame, end_frame, _worker_ctx)
def process_function_wrapper(args):
global _process_func_wrapper
return _process_func_wrapper(args)
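
# Process engine flow: process_worker_initializer runs once in each worker process and stores a
# WorkerFuncWrapper in the module-level global _process_func_wrapper; process_function_wrapper is
# the picklable entry point that executor.map calls for every chunk, simply forwarding the
# (segment_index, start_frame, end_frame) tuple to that per-process wrapper.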



# Here some utils copy/paste from DART (Charlie Windolf)
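Because "pool_engine" is added to job_keys at the top of this file, it should travel through the job-kwargs helpers like any other job keyword. A hedged sketch (fix_job_kwargs is the existing helper in this module; the exact output shape is assumed):

from spikeinterface.core.job_tools import fix_job_kwargs

# normalize user-facing job kwargs; pool_engine rides along with n_jobs, chunk_duration, ...
job_kwargs = fix_job_kwargs(dict(pool_engine="thread", n_jobs=-1, chunk_duration="1s", progress_bar=True))
print(job_kwargs)  # pool_engine kept as "thread", n_jobs=-1 resolved to the available core count
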
27 changes: 21 additions & 6 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -97,8 +97,6 @@ def init_func(arg1, arg2, arg3):

def test_ChunkRecordingExecutor():
recording = generate_recording(num_channels=2)
# make serializable
recording = recording.save()

init_args = "a", 120, "yep"

@@ -139,7 +137,7 @@ def __call__(self, res):

gathering_func2 = GatherClass()

# chunk + parallel + gather_func
# process + gather_func
processor = ChunkRecordingExecutor(
recording,
func,
@@ -148,6 +146,7 @@ def __call__(self, res):
verbose=True,
progress_bar=True,
gather_func=gathering_func2,
pool_engine="process",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
@@ -157,21 +156,37 @@

assert gathering_func2.pos == num_chunks

# chunk + parallel + spawn
# process spawn
processor = ChunkRecordingExecutor(
recording,
func,
init_func,
init_args,
verbose=True,
progress_bar=True,
pool_engine="process",
mp_context="spawn",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
)
processor.run()

# thread
processor = ChunkRecordingExecutor(
recording,
func,
init_func,
init_args,
verbose=True,
progress_bar=True,
pool_engine="thread",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
)
processor.run()


def test_fix_job_kwargs():
# test negative n_jobs
@@ -224,6 +239,6 @@ def test_split_job_kwargs():
# test_divide_segment_into_chunks()
# test_ensure_n_jobs()
# test_ensure_chunk_size()
# test_ChunkRecordingExecutor()
test_fix_job_kwargs()
test_ChunkRecordingExecutor()
# test_fix_job_kwargs()
# test_split_job_kwargs()