Skip to content

Commit

Permalink
ParallelExecution step (#543)
Browse files Browse the repository at this point in the history
* `ParallelExecution` step

[ML-7689](https://iguazio.atlassian.net/browse/ML-7689)

* Add runnable names, change output format

* Various improvements and additions

* Add error on duplicate runnable selection, similar to #545

* Make ParallelExecution friendly to mlrun serialization

* Process and thread limits, always spawn, docs

* Runtime, output format, docs

* Export ParallelExecution

* Improve error, add test

* Type hints

* Copy event body to protect against mutation, pass path

* select_runnables can return None to select all runnables

* Add run_async, expose supported mechanisms

* Fix test following async change

* Fix result gathering on selection

* Rename parameters, add comment, add kwargs

* Add more docstrings, rename parameter

* Add type annotations

* Move mechanism list out of class

* Add explicit max processes and threads defaults

* Remove redundant ifs

* Remove print, improve docstring

---------

Co-authored-by: Gal Topper <[email protected]>
  • Loading branch information
gtopper and Gal Topper authored Dec 9, 2024
1 parent c1710a2 commit 77e57d8
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 1 deletion.
2 changes: 2 additions & 0 deletions storey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
from .flow import Map # noqa: F401
from .flow import MapClass # noqa: F401
from .flow import MapWithState # noqa: F401
from .flow import ParallelExecution # noqa: F401
from .flow import ParallelExecutionRunnable # noqa: F401
from .flow import Recover # noqa: F401
from .flow import Reduce # noqa: F401
from .flow import Rename # noqa: F401
Expand Down
182 changes: 182 additions & 0 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import copy
import datetime
import inspect
import multiprocessing
import os
import pickle
import time
import traceback
Expand Down Expand Up @@ -1435,3 +1437,183 @@ def get_table(self, key):

def set_table(self, key, table):
    """Assign ``table`` to ``self._tables[key]``.

    :param key: Key under which the table is stored.
    :param table: Object to store. NOTE(review): presumably a table instance -- confirm against callers.
    """
    self._tables[key] = table


class _ParallelExecutionRunnableResult:
    """
    Plain value holder describing the outcome of one runnable invocation.

    :param runnable_name: Name of the runnable that produced the result.
    :param data: Value the runnable returned.
    :param runtime: How long the runnable took, as measured by whoever constructs this object.
    """

    def __init__(self, runnable_name: str, data: Any, runtime: float):
        self.runnable_name, self.data, self.runtime = runnable_name, data, runtime


# Names of the execution mechanisms a ParallelExecutionRunnable subclass may declare.
parallel_execution_mechanisms = (
    "multiprocessing",
    "threading",
    "asyncio",
    "naive",
)


class ParallelExecutionRunnable:
    """
    Base class for a unit of work that a ParallelExecution step runs on each event.

    A subclass chooses how it is executed by setting the execution_mechanism class attribute to one of:
    * "multiprocessing" – Run in a separate process. Suitable for CPU or GPU intensive work, which would otherwise
      block the main process by holding Python's Global Interpreter Lock (GIL).
    * "threading" – Run in a separate thread. Suitable for blocking I/O, which would otherwise block the main event
      loop thread.
    * "asyncio" – Run as an asyncio task. Suitable for I/O that uses asyncio, letting the event loop keep running
      while a response is awaited.
    * "naive" – Run directly in the main event loop. Suitable only for trivial computation and/or file I/O, since
      the runnable is then not actually run in parallel to anything else.

    The user code goes in run(), or in run_async() when execution_mechanism is "asyncio". It receives the event
    body and path and returns a result. If that code needs prior initialization, override init() as well.

    :param name: Runnable name
    """

    execution_mechanism: Optional[str] = None

    # ignore unused keyword arguments such as context which may be passed in by mlrun
    def __init__(self, name: str, **kwargs):
        mechanism_is_supported = self.execution_mechanism in parallel_execution_mechanisms
        if not mechanism_is_supported:
            raise ValueError(
                "ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: "
                '"multiprocessing", "threading", "asyncio", "naive"'
            )
        self.name = name

    def init(self) -> None:
        """Initialization hook; does nothing unless overridden."""

    def run(self, body: Any, path: str) -> Any:
        """
        User code for every execution mechanism except "asyncio" (for which run_async() is used instead).
        The default implementation returns the body unchanged.

        :param body: Event body
        :param path: Event path
        """
        return body

    async def run_async(self, body: Any, path: str) -> Any:
        """
        User code for the "asyncio" execution mechanism (other mechanisms use run() instead).
        The default implementation returns the body unchanged.

        :param body: Event body
        :param path: Event path
        """
        return body

    def _run(self, body: Any, path: str) -> Any:
        # Time the user's run() with a monotonic clock and wrap its output together with this runnable's name.
        started_at = time.monotonic()
        output = self.run(body, path)
        elapsed = time.monotonic() - started_at
        return _ParallelExecutionRunnableResult(self.name, output, elapsed)

    async def _async_run(self, body: Any, path: str) -> Any:
        # Same as _run(), but awaits the user's run_async().
        started_at = time.monotonic()
        output = await self.run_async(body, path)
        elapsed = time.monotonic() - started_at
        return _ParallelExecutionRunnableResult(self.name, output, elapsed)


class ParallelExecution(Flow):
    """
    Runs multiple jobs in parallel for each event.

    Each incoming event is handed to the selected runnables (all of them unless select_runnables() says otherwise).
    Once every selected runnable has finished, the event body is replaced with
    ``{"input": <original body>, "results": {<runnable name>: {"runtime": ..., "output": ...}}}`` and the event is
    passed downstream.

    :param runnables: A list of ParallelExecutionRunnable instances.
    :param max_processes: Maximum number of processes to spawn. Defaults to the number of available CPUs, or 16 if
        number of CPUs can't be determined.
    :param max_threads: Maximum number of threads to start. Defaults to 32.
    """

    def __init__(
        self,
        runnables: list[ParallelExecutionRunnable],
        max_processes: Optional[int] = None,
        max_threads: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if not runnables:
            raise ValueError("ParallelExecution cannot be instantiated without at least one runnable")

        self.runnables = runnables
        self._runnable_by_name = {}

        self.max_processes = max_processes or os.cpu_count() or 16
        self.max_threads = max_threads or 32

    def select_runnables(self, event) -> Optional[Union[list[str], list[ParallelExecutionRunnable]]]:
        """
        Given an event, returns a list of runnables (or a list of runnable names) to execute on it. It can also return
        None, in which case all runnables are executed on the event, which is also the default.

        :param event: Event object
        """
        return None

    def _init(self):
        """
        Index the runnables by name, call each runnable's init(), and create the process/thread pools they need.

        :raises ValueError: If two runnables share a name, or a runnable declares an unsupported mechanism.
        """
        super()._init()

        # _init() may be called more than once on the same instance (presumably each time the flow is started
        # -- TODO confirm against Flow). Discard state left by a previous call; otherwise every runnable would be
        # reported as a duplicate of itself and the previous executors would be dropped without being shut down.
        self._runnable_by_name = {}
        for stale_executor in getattr(self, "_executors", {}).values():
            stale_executor.shutdown(wait=False)
        self._executors = {}

        num_processes = 0
        num_threads = 0
        for runnable in self.runnables:
            if runnable.name in self._runnable_by_name:
                raise ValueError(f"ParallelExecutionRunnable name '{runnable.name}' is not unique")
            self._runnable_by_name[runnable.name] = runnable
            runnable.init()
            if runnable.execution_mechanism == "multiprocessing":
                num_processes += 1
            elif runnable.execution_mechanism == "threading":
                num_threads += 1
            elif runnable.execution_mechanism not in ("asyncio", "naive"):
                raise ValueError(f"Unsupported execution mechanism: {runnable.execution_mechanism}")

        # enforce max
        num_processes = min(num_processes, self.max_processes)
        num_threads = min(num_threads, self.max_threads)

        # Pools are only created for the mechanisms that are actually in use.
        if num_processes:
            mp_context = multiprocessing.get_context("spawn")
            self._executors["multiprocessing"] = ProcessPoolExecutor(max_workers=num_processes, mp_context=mp_context)
        if num_threads:
            self._executors["threading"] = ThreadPoolExecutor(max_workers=num_threads)

    def _resolve_runnables(self, event) -> list[ParallelExecutionRunnable]:
        """Return the runnable instances selected for the event, rejecting a runnable that is selected twice."""
        selected = self.select_runnables(event)
        if selected is None:
            selected = self.runnables

        resolved = []
        seen_runnable_ids = set()
        for runnable in selected:
            if isinstance(runnable, str):
                runnable = self._runnable_by_name[runnable]
            if id(runnable) in seen_runnable_ids:
                # NOTE(review): "outlet" looks like wording carried over from another step; the text is kept
                # unchanged because callers and tests may match on it.
                raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")
            seen_runnable_ids.add(id(runnable))
            resolved.append(runnable)
        return resolved

    async def _do(self, event):
        """
        Run the selected runnables on the event and pass the combined results downstream.

        :param event: Event to process, or the termination object, which is forwarded downstream untouched.
        :raises ValueError: If select_runnables() selects the same runnable more than once.
        """
        if event is _termination_obj:
            return await self._do_downstream(_termination_obj)

        # Validate the whole selection before starting any work, so that an invalid selection does not leave
        # already-started tasks behind.
        runnables = self._resolve_runnables(event)

        loop = asyncio.get_running_loop()
        futures = []
        for runnable in runnables:
            mechanism = runnable.execution_mechanism
            # Every mechanism other than multiprocessing gets its own deep copy of the body, protecting the
            # event's original body from mutation by a runnable.
            body = event.body if mechanism == "multiprocessing" else copy.deepcopy(event.body)
            if mechanism == "asyncio":
                future = loop.create_task(runnable._async_run(body, event.path))
            elif mechanism == "naive":
                # Runs synchronously, right here in the event loop.
                future = loop.create_future()
                future.set_result(runnable._run(body, event.path))
            else:
                future = loop.run_in_executor(self._executors[mechanism], runnable._run, body, event.path)
            futures.append(future)

        results: list[_ParallelExecutionRunnableResult] = await asyncio.gather(*futures)
        event.body = {
            "input": event.body,
            "results": {
                result.runnable_name: {"runtime": result.runtime, "output": result.data} for result in results
            },
        }
        return await self._do_downstream(event)
178 changes: 177 additions & 1 deletion tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@
V3ioDriver,
build_flow,
)
from storey.flow import Context, ReifyMetadata, Rename, _ConcurrentJobExecution
from storey.flow import (
Context,
ParallelExecution,
ParallelExecutionRunnable,
ReifyMetadata,
Rename,
_ConcurrentJobExecution,
)


class ATestException(Exception):
Expand Down Expand Up @@ -4686,3 +4693,172 @@ def test_filters_type():
additional_filters=[[("city", "=", "Tel Aviv")], [("age", ">=", "40")]],
filter_column="start_time",
)


class RunnableBusyWait(ParallelExecutionRunnable):
    """Test runnable for the "multiprocessing" mechanism: spins for one second, then returns its result."""

    execution_mechanism = "multiprocessing"
    _result = 0

    def init(self):
        # The tests expect 1 as output, so a run that skipped init() would be detected.
        self._result = 1

    def run(self, data, path):
        began = time.monotonic()
        # Busy-wait (rather than sleep) to keep the CPU occupied for the whole second.
        while time.monotonic() - began < 1:
            continue
        return self._result


class RunnableSleep(ParallelExecutionRunnable):
    """Test runnable for the "threading" mechanism: blocks for one second, then returns its result."""

    execution_mechanism = "threading"
    _result = 0

    def init(self):
        # The tests expect 1 as output, so a run that skipped init() would be detected.
        self._result = 1

    def run(self, data, path):
        # Blocking sleep stands in for blocking I/O.
        time.sleep(1)
        return self._result


class RunnableAsyncSleep(ParallelExecutionRunnable):
    """Test runnable for the "asyncio" mechanism: awaits a one-second sleep, then returns its result."""

    execution_mechanism = "asyncio"
    _result = 0

    def init(self):
        # The tests expect 1 as output, so a run that skipped init() would be detected.
        self._result = 1

    async def run_async(self, data, path):
        # Non-blocking sleep stands in for asyncio-based I/O. The leftover debug print was removed.
        await asyncio.sleep(1)
        return self._result


class RunnableNaiveNoOp(ParallelExecutionRunnable):
    """Test runnable for the "naive" mechanism: returns its result immediately."""

    execution_mechanism = "naive"
    _result = 0

    def init(self):
        # The tests expect 1 as output, so a run that skipped init() would be detected.
        self._result = 1

    def run(self, data, path):
        return self._result


class RunnableWithError(ParallelExecutionRunnable):
    """Test runnable that always fails; used to check that runnables which are not selected are never run."""

    execution_mechanism = "naive"

    def run(self, data, path):
        raise Exception("This shouldn't run!")


def test_parallel_execution_runnable_uniqueness():
    """Initializing a step whose runnables share a name must fail with a descriptive error."""
    step = ParallelExecution([RunnableBusyWait("x"), RunnableBusyWait("x")])
    expected_error = "ParallelExecutionRunnable name 'x' is not unique"
    with pytest.raises(ValueError, match=expected_error):
        step._init()


def test_select_runnable_uniqueness():
    """Selecting the same runnable twice for one event must surface as a ValueError at termination."""

    class MyParallelExecution(ParallelExecution):
        def select_runnables(self, event):
            # Deliberately invalid: the same runnable name appears twice.
            return ["x", "x"]

    step = MyParallelExecution([RunnableNaiveNoOp("x"), RunnableNaiveNoOp("y")])

    source = SyncEmitSource()
    source.to(step)

    controller = source.run()
    controller.emit(0)
    controller.terminate()

    expected_error = r"select_runnables\(\) returned more than one outlet named 'x'"
    with pytest.raises(ValueError, match=expected_error):
        controller.await_termination()


def test_parallel_execution():
    """
    Run one event through runnables of every execution mechanism and check the combined output.

    The "error" runnable is excluded by select_runnables(), so it must never run. The "naive" runnable is a real
    naive runnable (it used to be an asyncio one that was merely named "naive"), so the naive mechanism is
    actually exercised here.
    """
    runnables = [
        RunnableWithError("error"),
        RunnableBusyWait("busy1"),
        RunnableBusyWait("busy2"),
        RunnableSleep("sleep1"),
        RunnableSleep("sleep2"),
        RunnableAsyncSleep("asleep1"),
        RunnableAsyncSleep("asleep2"),
        RunnableNaiveNoOp("naive"),
    ]

    class MyParallelExecution(ParallelExecution):
        def select_runnables(self, event):
            return [runnable.name for runnable in runnables if runnable.name != "error"]

    parallel_execution = MyParallelExecution(runnables)
    reduce = Reduce([], lambda acc, x: acc + [x])

    source = SyncEmitSource()
    source.to(parallel_execution).to(reduce)

    start = time.monotonic()
    controller = source.run()
    controller.emit(0)
    controller.terminate()
    termination_result = controller.await_termination()
    end = time.monotonic()

    # Six one-second runnables finishing in under six seconds overall means they did not run serially.
    assert end - start < 6
    termination_result = termination_result[0]
    assert termination_result.keys() == {"input", "results"}
    assert termination_result["input"] == 0
    results = termination_result["results"]
    assert results.keys() == {"busy1", "busy2", "sleep1", "sleep2", "asleep1", "asleep2", "naive"}
    for name, result in results.items():
        assert result["output"] == 1
        if name == "naive":
            # The naive no-op returns immediately.
            assert 0 <= result["runtime"] < 1
        else:
            # Every other runnable waits for one second.
            assert 1 < result["runtime"] < 2


def test_invalid_runnable():
    """The base runnable class declares no execution mechanism, so instantiating it directly must fail."""
    expected_error = (
        "ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: "
        '"multiprocessing", "threading", "asyncio", "naive"'
    )
    with pytest.raises(ValueError, match=expected_error):
        ParallelExecutionRunnable("my_runnable")


class RunnableNaiveWithMutation(ParallelExecutionRunnable):
    """Test runnable that mutates the body it receives, to check that the event's original body is protected."""

    execution_mechanism = "naive"

    def run(self, data, path):
        # In-place mutation of the received body is the point of this runnable.
        data["n"] += 1
        return data


def test_event_input_preservation():
    """A runnable that mutates its input must not affect the "input" reported in the step's output."""
    step = ParallelExecution([RunnableNaiveWithMutation("x")])
    collector = Reduce([], lambda acc, x: acc + [x])

    source = SyncEmitSource()
    source.to(step).to(collector)

    controller = source.run()
    controller.emit({"n": 1})
    controller.terminate()
    collected = controller.await_termination()

    output = collected[0]
    assert output.keys() == {"input", "results"}
    # The original body is reported untouched...
    assert output["input"] == {"n": 1}
    per_runnable = output["results"]
    assert per_runnable.keys() == {"x"}
    x_result = per_runnable["x"]
    assert x_result.keys() == {"runtime", "output"}
    # ...while the runnable's own output reflects its mutation.
    assert x_result["output"] == {"n": 2}

0 comments on commit 77e57d8

Please sign in to comment.