Skip to content

Commit

Permalink
FEAT-modin-project#6990: Implement lazy execution for the Ray virtual…
Browse files Browse the repository at this point in the history
… partitions.
  • Loading branch information
AndreyPavlenko committed Mar 14, 2024
1 parent 8710994 commit aa5bfa9
Show file tree
Hide file tree
Showing 4 changed files with 478 additions and 300 deletions.
112 changes: 76 additions & 36 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ class DeferredExecution:
The execution input.
func : callable or ObjectRefType
A function to be executed.
args : list or tuple
args : list or tuple, optional
Additional positional arguments to be passed in `func`.
kwargs : dict
kwargs : dict, optional
Additional keyword arguments to be passed in `func`.
num_returns : int
num_returns : int, default: 1
The number of the return values.
flat_data : bool
True means that the data is neither DeferredExecution nor list.
flat_args : bool
True means that there are no lists or DeferredExecution objects in `args`.
In this case, no arguments processing is performed and `args` is passed
Expand All @@ -88,26 +90,29 @@ class DeferredExecution:

def __init__(
self,
data: Union[
ObjectRefType,
"DeferredExecution",
List[Union[ObjectRefType, "DeferredExecution"]],
],
data: Any,
func: Union[Callable, ObjectRefType],
args: Union[List[Any], Tuple[Any]],
kwargs: Dict[str, Any],
args: Union[List[Any], Tuple[Any]] = None,
kwargs: Dict[str, Any] = None,
num_returns=1,
):
if isinstance(data, DeferredExecution):
data.subscribe()
self.flat_data = self._flat_args((data,))
self.data = data
self.func = func
self.args = args
self.kwargs = kwargs
self.num_returns = num_returns
self.flat_args = self._flat_args(args)
self.flat_kwargs = self._flat_args(kwargs.values())
self.subscribers = 0
if args is not None:
self.args = args
self.flat_args = self._flat_args(args)
else:
self.args = ()
self.flat_args = True
if kwargs is not None:
self.kwargs = kwargs
self.flat_kwargs = self._flat_args(kwargs.values())
else:
self.kwargs = {}
self.flat_kwargs = True

@classmethod
def _flat_args(cls, args: Iterable):
Expand All @@ -134,7 +139,7 @@ def _flat_args(cls, args: Iterable):

def exec(
self,
) -> Tuple[ObjectRefOrListType, Union["MetaList", List], Union[int, List[int]]]:
) -> Tuple[ObjectRefOrListType, "MetaList", Union[int, List[int]]]:
"""
Execute this task, if required.
Expand All @@ -150,11 +155,29 @@ def exec(
return self.data, self.meta, self.meta_offset

if (
not isinstance(self.data, DeferredExecution)
self.flat_data
and self.flat_args
and self.flat_kwargs
and self.num_returns == 1
):
# self.data = RayWrapper.materialize(self.data)
# self.args = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in self.args
# ]
# self.kwargs = {
# k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for k, o in self.kwargs.items()
# }
# obj = _REMOTE_EXEC.exec_func(
# RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
# )
# result, length, width, ip = (
# obj,
# len(obj) if hasattr(obj, "__len__") else 0,
# len(obj.columns) if hasattr(obj, "columns") else 0,
# "",
# )
result, length, width, ip = remote_exec_func.remote(
self.func, self.data, *self.args, **self.kwargs
)
Expand All @@ -166,14 +189,23 @@ def exec(
# it back. After the execution, the result is saved and the counter has no effect.
self.subscribers += 2
consumers, output = self._deconstruct()

# assert not any(isinstance(o, ListOrTuple) for o in output)
# tmp = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in output
# ]
# list(_REMOTE_EXEC.construct(tmp))

# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
results = self._remote_exec_chain(num_returns, *output)
meta = MetaList(results.pop())
meta_offset = 0
results = iter(results)
for de in consumers:
if de.num_returns == 1:
num_returns = de.num_returns
if num_returns == 1:
de._set_result(next(results), meta, meta_offset)
meta_offset += 2
else:
Expand Down Expand Up @@ -318,6 +350,7 @@ def _deconstruct_chain(
break
elif not isinstance(data := de.data, DeferredExecution):
if isinstance(data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
data, output, stack, result_consumers, out_append
)
Expand Down Expand Up @@ -394,7 +427,13 @@ def _deconstruct_list(
if out_pos := getattr(obj, "out_pos", None):
obj.unsubscribe()
if obj.has_result:
out_append(obj.data)
if isinstance(obj.data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
obj.data, output, stack, result_consumers, out_append
)
else:
out_append(obj.data)
else:
out_append(_Tag.REF)
out_append(out_pos)
Expand Down Expand Up @@ -432,13 +471,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
list
The execution results. The last element of this list is the ``MetaList``.
"""
# Prefer _remote_exec_single_chain(). It has fewer arguments and
# does not require the num_returns to be specified in options.
# Prefer _remote_exec_single_chain(). It does not require the num_returns
# to be specified in options.
if num_returns == 2:
return _remote_exec_single_chain.remote(*args)
else:
return _remote_exec_multi_chain.options(num_returns=num_returns).remote(
num_returns, *args
*args
)

def _set_result(
Expand All @@ -456,7 +495,7 @@ def _set_result(
meta : MetaList
meta_offset : int or list of int
"""
del self.func, self.args, self.kwargs, self.flat_args, self.flat_kwargs
del self.func, self.args, self.kwargs
self.data = result
self.meta = meta
self.meta_offset = meta_offset
Expand All @@ -478,6 +517,10 @@ class MetaList:
def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
self._obj = obj

def materialize(self):
"""Materialized the list, if required."""
self._obj = RayWrapper.materialize(self._obj)

def __getitem__(self, index):
"""
Get item at the specified index.
Expand Down Expand Up @@ -605,7 +648,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
raise err

@classmethod
def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
def construct(cls, args: Tuple): # pragma: no cover
"""
Construct and execute the specified chain.
Expand All @@ -615,7 +658,6 @@ def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
Parameters
----------
num_returns : int
args : tuple
Yields
Expand Down Expand Up @@ -687,7 +729,7 @@ def construct_chain(

while chain:
fn = pop()
if fn == tg_e:
if fn is tg_e:
lst.append(obj)
break

Expand Down Expand Up @@ -717,10 +759,10 @@ def construct_chain(

itr = iter([obj] if num_returns == 1 else obj)
for _ in range(num_returns):
obj = next(itr)
meta.append(len(obj) if hasattr(obj, "__len__") else 0)
meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
yield obj
o = next(itr)
meta.append(len(o) if hasattr(o, "__len__") else 0)
meta.append(len(o.columns) if hasattr(o, "columns") else 0)
yield o

@classmethod
def construct_list(
Expand Down Expand Up @@ -834,20 +876,18 @@ def _remote_exec_single_chain(
-------
Generator
"""
return remote_executor.construct(num_returns=2, args=args)
return remote_executor.construct(args=args)


@ray.remote
def _remote_exec_multi_chain(
num_returns: int, *args: Tuple, remote_executor=_REMOTE_EXEC
*args: Tuple, remote_executor=_REMOTE_EXEC
) -> Generator: # pragma: no cover
"""
Execute the deconstructed chain with a multiple return values in a worker process.
Parameters
----------
num_returns : int
The number of return values.
*args : tuple
A deconstructed chain to be executed.
remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
Expand All @@ -857,4 +897,4 @@ def _remote_exec_multi_chain(
-------
Generator
"""
return remote_executor.construct(num_returns, args)
return remote_executor.construct(args)
13 changes: 13 additions & 0 deletions modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,19 @@ def post_materialize(self, materialized):
"""
raise NotImplementedError()

def __reduce__(self):
"""
Replace this hook with the materialized object on serialization.
Returns
-------
tuple
"""
data = RayWrapper.materialize(self)
if not isinstance(data, int):
raise NotImplementedError("Only integers are currently supported")
return int, (data,)


RayObjectRefTypes = (ray.ObjectRef, ClientObjectRef)
ObjectRefTypes = (*RayObjectRefTypes, MaterializationHook)
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def add_to_apply_calls(
def drain_call_queue(self):
data = self._data_ref
if not isinstance(data, DeferredExecution):
return data
return

log = get_logger()
self._is_debug(log) and log.debug(
Expand Down
Loading

0 comments on commit aa5bfa9

Please sign in to comment.