diff --git a/storey/__init__.py b/storey/__init__.py index 0f0664c0..6017a30e 100644 --- a/storey/__init__.py +++ b/storey/__init__.py @@ -50,6 +50,8 @@ from .flow import Map # noqa: F401 from .flow import MapClass # noqa: F401 from .flow import MapWithState # noqa: F401 +from .flow import ParallelExecution # noqa: F401 +from .flow import ParallelExecutionRunnable # noqa: F401 from .flow import Recover # noqa: F401 from .flow import Reduce # noqa: F401 from .flow import Rename # noqa: F401 diff --git a/storey/flow.py b/storey/flow.py index 6956df30..70d566e1 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -16,6 +16,8 @@ import copy import datetime import inspect +import multiprocessing +import os import pickle import time import traceback @@ -1435,3 +1437,183 @@ def get_table(self, key): def set_table(self, key, table): self._tables[key] = table + + +class _ParallelExecutionRunnableResult: + def __init__(self, runnable_name: str, data: Any, runtime: float): + self.runnable_name = runnable_name + self.data = data + self.runtime = runtime + + +parallel_execution_mechanisms = ("multiprocessing", "threading", "asyncio", "naive") + + +class ParallelExecutionRunnable: + """ + Runnable to be run by a ParallelExecution step. Subclasses must assign execution_mechanism with one of: + * "multiprocessing" – To run in a separate process. This is appropriate for CPU or GPU intensive tasks as they + would otherwise block the main process by holding Python's Global Interpreter Lock (GIL). + * "threading" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would otherwise + block the main event loop thread. + * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the event + loop to continue running while waiting for a response. + * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O. It + means that the runnable will not actually be run in parallel to anything else. + + Subclasses must also override the run() method, or run_async() when execution_mechanism="asyncio", with user code + that handles the event and returns a result. + + Subclasses may optionally override the init() method if the user's implementation of run() requires prior + initialization. + + :param name: Runnable name + """ + + execution_mechanism: Optional[str] = None + + # ignore unused keyword arguments such as context which may be passed in by mlrun + def __init__(self, name: str, **kwargs): + if self.execution_mechanism not in parallel_execution_mechanisms: + raise ValueError( + "ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: " + '"multiprocessing", "threading", "asyncio", "naive"' + ) + self.name = name + + def init(self) -> None: + """Override this method to add initialization logic.""" + pass + + def run(self, body: Any, path: str) -> Any: + """ + Override this method with the code this runnable should run. If execution_mechanism is "asyncio", override + run_async() instead. + + :param body: Event body + :param path: Event path + """ + return body + + async def run_async(self, body: Any, path: str) -> Any: + """ + If execution_mechanism is "asyncio", override this method with the code this runnable should run. Otherwise, + override run() instead. + + :param body: Event body + :param path: Event path + """ + return body + + def _run(self, body: Any, path: str) -> Any: + start = time.monotonic() + body = self.run(body, path) + end = time.monotonic() + return _ParallelExecutionRunnableResult(self.name, body, end - start) + + async def _async_run(self, body: Any, path: str) -> Any: + start = time.monotonic() + body = await self.run_async(body, path) + end = time.monotonic() + return _ParallelExecutionRunnableResult(self.name, body, end - start) + + +class ParallelExecution(Flow): + """ + Runs multiple jobs in parallel for each event. + + :param runnables: A list of ParallelExecutionRunnable instances. + :param max_processes: Maximum number of processes to spawn. Defaults to the number of available CPUs, or 16 if + number of CPUs can't be determined. + :param max_threads: Maximum number of threads to start. Defaults to 32. + """ + + def __init__( + self, + runnables: list[ParallelExecutionRunnable], + max_processes: Optional[int] = None, + max_threads: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + + if not runnables: + raise ValueError("ParallelExecution cannot be instantiated without at least one runnable") + + self.runnables = runnables + self._runnable_by_name = {} + + self.max_processes = max_processes or os.cpu_count() or 16 + self.max_threads = max_threads or 32 + + def select_runnables(self, event) -> Optional[Union[list[str], list[ParallelExecutionRunnable]]]: + """ + Given an event, returns a list of runnables (or a list of runnable names) to execute on it. It can also return + None, in which case all runnables are executed on the event, which is also the default. + + :param event: Event object + """ + pass + + def _init(self): + super()._init() + num_processes = 0 + num_threads = 0 + for runnable in self.runnables: + if runnable.name in self._runnable_by_name: + raise ValueError(f"ParallelExecutionRunnable name '{runnable.name}' is not unique") + self._runnable_by_name[runnable.name] = runnable + runnable.init() + if runnable.execution_mechanism == "multiprocessing": + num_processes += 1 + elif runnable.execution_mechanism == "threading": + num_threads += 1 + elif runnable.execution_mechanism not in ("asyncio", "naive"): + raise ValueError(f"Unsupported execution mechanism: {runnable.execution_mechanism}") + + # enforce max + num_processes = min(num_processes, self.max_processes) + num_threads = min(num_threads, self.max_threads) + + self._executors = {} + if num_processes: + mp_context = multiprocessing.get_context("spawn") + self._executors["multiprocessing"] = ProcessPoolExecutor(max_workers=num_processes, mp_context=mp_context) + if num_threads: + self._executors["threading"] = ThreadPoolExecutor(max_workers=num_threads) + + async def _do(self, event): + if event is _termination_obj: + return await self._do_downstream(_termination_obj) + else: + runnables = self.select_runnables(event) + if runnables is None: + runnables = self.runnables + futures = [] + runnables_encountered = set() + for runnable in runnables: + if isinstance(runnable, str): + runnable = self._runnable_by_name[runnable] + if id(runnable) in runnables_encountered: + raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'") + input = event.body if runnable.execution_mechanism == "multiprocessing" else copy.deepcopy(event.body) + runnables_encountered.add(id(runnable)) + if runnable.execution_mechanism == "asyncio": + future = asyncio.get_running_loop().create_task(runnable._async_run(input, event.path)) + elif runnable.execution_mechanism == "naive": + future = asyncio.get_running_loop().create_future() + future.set_result(runnable._run(input, event.path)) + else: + executor = self._executors[runnable.execution_mechanism] + future = asyncio.get_running_loop().run_in_executor( + executor, + runnable._run, + input, + event.path, + ) + futures.append(future) + results: list[_ParallelExecutionRunnableResult] = await asyncio.gather(*futures) + event.body = {"input": event.body, "results": {}} + for result in results: + event.body["results"][result.runnable_name] = {"runtime": result.runtime, "output": result.data} + return await self._do_downstream(event) diff --git a/tests/test_flow.py b/tests/test_flow.py index 9911efab..84dba5ae 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -70,7 +70,14 @@ V3ioDriver, build_flow, ) -from storey.flow import Context, ReifyMetadata, Rename, _ConcurrentJobExecution +from storey.flow import ( + Context, + ParallelExecution, + ParallelExecutionRunnable, + ReifyMetadata, + Rename, + _ConcurrentJobExecution, +) class ATestException(Exception): @@ -4686,3 +4693,172 @@ def test_filters_type(): additional_filters=[[("city", "=", "Tel Aviv")], [("age", ">=", "40")]], filter_column="start_time", ) + + +class RunnableBusyWait(ParallelExecutionRunnable): + execution_mechanism = "multiprocessing" + _result = 0 + + def init(self): + self._result = 1 + + def run(self, data, path): + start = time.monotonic() + while time.monotonic() - start < 1: + pass + return self._result + + +class RunnableSleep(ParallelExecutionRunnable): + execution_mechanism = "threading" + _result = 0 + + def init(self): + self._result = 1 + + def run(self, data, path): + time.sleep(1) + return self._result + + +class RunnableAsyncSleep(ParallelExecutionRunnable): + execution_mechanism = "asyncio" + _result = 0 + + def init(self): + self._result = 1 + + async def run_async(self, data, path): + await asyncio.sleep(1) + print(f"{self.name} returning {self._result}") + return self._result + + +class RunnableNaiveNoOp(ParallelExecutionRunnable): + execution_mechanism = "naive" + _result = 0 + + def init(self): + self._result = 1 + + def run(self, data, path): + return self._result + + +class RunnableWithError(ParallelExecutionRunnable): + execution_mechanism = "naive" + + def run(self, data, path): + raise Exception("This shouldn't run!") + + +def test_parallel_execution_runnable_uniqueness(): + runnables = [ + RunnableBusyWait("x"), + RunnableBusyWait("x"), + ] + parallel_execution = ParallelExecution(runnables) + with pytest.raises(ValueError, match="ParallelExecutionRunnable name 'x' is not unique"): + parallel_execution._init() + + +def test_select_runnable_uniqueness(): + runnables = [ + RunnableNaiveNoOp("x"), + RunnableNaiveNoOp("y"), + ] + + class MyParallelExecution(ParallelExecution): + def select_runnables(self, event): + return ["x", "x"] + + parallel_execution = MyParallelExecution(runnables) + + source = SyncEmitSource() + source.to(parallel_execution) + + controller = source.run() + controller.emit(0) + controller.terminate() + with pytest.raises(ValueError, match=r"select_runnables\(\) returned more than one outlet named 'x'"): + controller.await_termination() + + +def test_parallel_execution(): + runnables = [ + RunnableWithError("error"), + RunnableBusyWait("busy1"), + RunnableBusyWait("busy2"), + RunnableSleep("sleep1"), + RunnableSleep("sleep2"), + RunnableAsyncSleep("asleep1"), + RunnableAsyncSleep("asleep2"), + RunnableAsyncSleep("naive"), + ] + + class MyParallelExecution(ParallelExecution): + def select_runnables(self, event): + return [runnable.name for runnable in runnables if runnable.name != "error"] + + parallel_execution = MyParallelExecution(runnables) + reduce = Reduce([], lambda acc, x: acc + [x]) + + source = SyncEmitSource() + source.to(parallel_execution).to(reduce) + + start = time.monotonic() + controller = source.run() + controller.emit(0) + controller.terminate() + termination_result = controller.await_termination() + end = time.monotonic() + + assert end - start < 6 + termination_result = termination_result[0] + assert termination_result.keys() == {"input", "results"} + assert termination_result["input"] == 0 + results = termination_result["results"] + assert results.keys() == {"busy1", "busy2", "sleep1", "sleep2", "asleep1", "asleep2", "naive"} + for result in results.values(): + assert result["output"] == 1 + assert 1 < result["runtime"] < 2 + + +def test_invalid_runnable(): + with pytest.raises( + ValueError, + match="ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: " + '"multiprocessing", "threading", "asyncio", "naive"', + ): + ParallelExecutionRunnable("my_runnable") + + +class RunnableNaiveWithMutation(ParallelExecutionRunnable): + execution_mechanism = "naive" + + def run(self, data, path): + data["n"] += 1 + return data + + +def test_event_input_preservation(): + runnables = [ + RunnableNaiveWithMutation("x"), + ] + reduce = Reduce([], lambda acc, x: acc + [x]) + + source = SyncEmitSource() + source.to(ParallelExecution(runnables)).to(reduce) + + controller = source.run() + controller.emit({"n": 1}) + controller.terminate() + termination_result = controller.await_termination() + termination_result = termination_result[0] + assert termination_result.keys() == {"input", "results"} + assert termination_result["input"] == {"n": 1} + results = termination_result["results"] + assert results.keys() == {"x"} + result = results["x"] + assert result.keys() == {"runtime", "output"} + assert result["output"] == {"n": 2}