Skip to content

Commit

Permalink
FEAT-modin-project#4909: Properly implement map operator
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Shi <[email protected]>
  • Loading branch information
noloerino committed Dec 6, 2022
1 parent 2ebc9cf commit da415ad
Show file tree
Hide file tree
Showing 15 changed files with 1,704 additions and 1,548 deletions.
33 changes: 24 additions & 9 deletions modin/core/dataframe/algebra/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def register(cls, func, join_type="outer", labels="replace"):
"""

def caller(
query_compiler, other, broadcast=False, *args, dtypes=None, **kwargs
query_compiler,
other,
broadcast=False,
*args,
dtypes=None,
copy_dtypes=False,
**kwargs
):
"""
Apply binary `func` to passed operands.
Expand All @@ -61,8 +67,14 @@ def caller(
at the query compiler level, so this parameter is a hint that is passed from a high level API.
*args : args,
Arguments that will be passed to `func`.
dtypes : "copy" or None, default: None
Whether to keep old dtypes or infer new dtypes from data.
dtypes : pandas.Series or scalar type, optional
The data types for the result. This is an optimization
because there are functions that always result in a particular data
type, and this allows us to avoid (re)computing it.
If the argument is a scalar type, then that type is assigned to each result column.
copy_dtypes : bool, default: False
If True, the dtypes of the resulting dataframe are copied from the original,
and the ``dtypes`` argument is ignored.
**kwargs : kwargs,
Arguments that will be passed to `func`.
Expand All @@ -76,21 +88,22 @@ def caller(
if broadcast:
assert (
len(other.columns) == 1
), "Invalid broadcast argument for `broadcast_apply`, too many columns: {}".format(
), "Invalid broadcast argument for `map` with broadcast, too many columns: {}".format(
len(other.columns)
)
# Transpose on `axis=1` because we always represent an individual
# column or row as a single-column Modin DataFrame
if axis == 1:
other = other.transpose()
return query_compiler.__constructor__(
query_compiler._modin_frame.broadcast_apply(
axis,
query_compiler._modin_frame.map(
lambda l, r: func(l, r.squeeze(), *args, **kwargs),
other._modin_frame,
axis=axis,
other=other._modin_frame,
join_type=join_type,
labels=labels,
dtypes=dtypes,
copy_dtypes=copy_dtypes,
)
)
else:
Expand All @@ -105,17 +118,19 @@ def caller(
# TODO: it's possible to chunk the `other` and broadcast them to partitions
# accordingly, in that way we will be able to use more efficient `._modin_frame.map()`
if isinstance(other, (dict, list, np.ndarray, pandas.Series)):
new_modin_frame = query_compiler._modin_frame.apply_full_axis(
axis,
new_modin_frame = query_compiler._modin_frame.map_full_axis(
lambda df: func(df, other, *args, **kwargs),
axis=axis,
new_index=query_compiler.index,
new_columns=query_compiler.columns,
dtypes=dtypes,
copy_dtypes=copy_dtypes,
)
else:
new_modin_frame = query_compiler._modin_frame.map(
lambda df: func(df, other, *args, **kwargs),
dtypes=dtypes,
copy_dtypes=copy_dtypes,
)
return query_compiler.__constructor__(new_modin_frame)

Expand Down
22 changes: 16 additions & 6 deletions modin/core/dataframe/base/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

from abc import ABC, abstractmethod
from typing import List, Hashable, Optional, Callable, Union, Dict

import pandas

from modin.core.dataframe.base.dataframe.utils import Axis, JoinType


Expand Down Expand Up @@ -91,8 +94,10 @@ def filter_by_types(self, types: List[Hashable]) -> "ModinDataframe":
def map(
self,
function: Callable,
*,
axis: Optional[Union[int, Axis]] = None,
dtypes: Optional[str] = None,
dtypes: Optional[Union[pandas.Series, type]] = None,
copy_dtypes: bool = False,
) -> "ModinDataframe":
"""
Apply a user-defined function row-wise if `axis`=0, column-wise if `axis`=1, and cell-wise if `axis` is None.
Expand All @@ -102,11 +107,14 @@ def map(
function : callable(row|col|cell) -> row|col|cell
The function to map across the dataframe.
axis : int or modin.core.dataframe.base.utils.Axis, optional
The axis to map over.
dtypes : str, optional
The axis to map over. If None, the map will be performed element-wise.
dtypes : pandas.Series or scalar type, optional
The data types for the result. This is an optimization
because there are functions that always result in a particular data
type, and this allows us to avoid (re)computing it.
copy_dtypes : bool, default: False
If True, the dtypes of the resulting dataframe are copied from the original,
and the ``dtypes`` argument is ignored.
Returns
-------
Expand Down Expand Up @@ -258,7 +266,8 @@ def reduce(
self,
axis: Union[int, Axis],
function: Callable,
dtypes: Optional[str] = None,
*,
dtypes: Optional[pandas.Series] = None,
) -> "ModinDataframe":
"""
Perform a user-defined aggregation on the specified axis, where the axis reduces down to a singleton.
Expand All @@ -269,7 +278,7 @@ def reduce(
The axis to perform the reduce over.
function : callable(row|col) -> single value
The reduce function to apply to each column.
dtypes : str, optional
dtypes : pandas.Series, optional
The data types for the result. This is an optimization
because there are functions that always result in a particular data
type, and this allows us to avoid (re)computing it.
Expand All @@ -291,7 +300,8 @@ def tree_reduce(
axis: Union[int, Axis],
map_func: Callable,
reduce_func: Optional[Callable] = None,
dtypes: Optional[str] = None,
*,
dtypes: Optional[pandas.Series] = None,
) -> "ModinDataframe":
"""
Perform a user-defined aggregation on the specified axis, where the axis reduces down to a singleton using a tree-reduce computation pattern.
Expand Down
10 changes: 10 additions & 0 deletions modin/core/dataframe/base/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
"""

from enum import Enum
import sys

if sys.version_info.minor < 8:
from typing_extensions import Literal
else:
from typing import Literal


class Axis(Enum): # noqa: PR01
Expand All @@ -36,6 +42,10 @@ class Axis(Enum): # noqa: PR01
CELL_WISE = None


AxisInt = Literal[0, 1]
"""Type for the two possible integer values of an axis argument (0 or 1)."""


class JoinType(Enum): # noqa: PR01
"""
An enum that represents the `join_type` argument provided to the algebra operators.
Expand Down
Loading

0 comments on commit da415ad

Please sign in to comment.