Skip to content

Commit

Permalink
FEAT-modin-project#4909: Properly implement map operator
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Shi <[email protected]>
  • Loading branch information
noloerino committed Nov 1, 2022
1 parent a93399c commit f4fc351
Show file tree
Hide file tree
Showing 15 changed files with 1,601 additions and 1,548 deletions.
34 changes: 25 additions & 9 deletions modin/core/dataframe/algebra/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def register(cls, func, join_type="outer", labels="replace"):
"""

def caller(
query_compiler, other, broadcast=False, *args, dtypes=None, **kwargs
query_compiler,
other,
broadcast=False,
*args,
dtypes=None,
copy_dtypes=False,
**kwargs
):
"""
Apply binary `func` to passed operands.
Expand All @@ -61,8 +67,14 @@ def caller(
at the query compiler level, so this parameter is a hint that passed from a high level API.
*args : args,
Arguments that will be passed to `func`.
dtypes : "copy" or None, default: None
Whether to keep old dtypes or infer new dtypes from data.
dtypes : pandas.Series or scalar type, optional
The data types for the result. This is an optimization
because there are functions that always result in a particular data
type, and this allows us to avoid (re)computing it.
If the argument is a scalar type, then that type is assigned to each result column.
copy_dtypes : bool, default False
If True, the dtypes of the resulting dataframe are copied from the original,
and the ``dtypes`` argument is ignored.
**kwargs : kwargs,
Arguments that will be passed to `func`.
Expand All @@ -76,21 +88,22 @@ def caller(
if broadcast:
assert (
len(other.columns) == 1
), "Invalid broadcast argument for `broadcast_apply`, too many columns: {}".format(
), "Invalid broadcast argument for `map` with broadcast, too many columns: {}".format(
len(other.columns)
)
# Transpose on `axis=1` because we always represent an individual
# column or row as a single-column Modin DataFrame
if axis == 1:
other = other.transpose()
return query_compiler.__constructor__(
query_compiler._modin_frame.broadcast_apply(
axis,
query_compiler._modin_frame.map(
lambda l, r: func(l, r.squeeze(), *args, **kwargs),
other._modin_frame,
axis=axis,
other=other._modin_frame,
join_type=join_type,
labels=labels,
dtypes=dtypes,
copy_dtypes=copy_dtypes,
)
)
else:
Expand All @@ -103,17 +116,20 @@ def caller(
)
else:
if isinstance(other, (list, np.ndarray, pandas.Series)):
new_modin_frame = query_compiler._modin_frame.apply_full_axis(
axis,
new_modin_frame = query_compiler._modin_frame.map(
lambda df: func(df, other, *args, **kwargs),
axis=axis,
full_axis=True,
new_index=query_compiler.index,
new_columns=query_compiler.columns,
dtypes=dtypes,
copy_dtypes=copy_dtypes,
)
else:
new_modin_frame = query_compiler._modin_frame.map(
lambda df: func(df, other, *args, **kwargs),
dtypes=dtypes,
copy_dtypes=copy_dtypes,
)
return query_compiler.__constructor__(new_modin_frame)

Expand Down
38 changes: 28 additions & 10 deletions modin/core/dataframe/base/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,20 @@
"""

from abc import ABC, abstractmethod
from typing import List, Hashable, Optional, Callable, Union, Dict
from typing import List, Hashable, Optional, Callable, Union, Dict, TYPE_CHECKING

import pandas

from modin.core.dataframe.base.dataframe.utils import Axis, JoinType

if TYPE_CHECKING:
from modin.core.dataframe.pandas.partitioning import (
PandasDataframePartition,
PandasDataframeAxisPartition,
)

Partition = Union[PandasDataframeAxisPartition, PandasDataframePartition]


class ModinDataframe(ABC):
"""
Expand Down Expand Up @@ -91,8 +102,10 @@ def filter_by_types(self, types: List[Hashable]) -> "ModinDataframe":
def map(
self,
function: Callable,
*,
axis: Optional[Union[int, Axis]] = None,
dtypes: Optional[str] = None,
dtypes: Optional[Union[pandas.Series, type]] = None,
copy_dtypes: bool = False,
) -> "ModinDataframe":
"""
Apply a user-defined function row-wise if `axis`=0, column-wise if `axis`=1, and cell-wise if `axis` is None.
Expand All @@ -102,11 +115,14 @@ def map(
function : callable(row|col|cell) -> row|col|cell
The function to map across the dataframe.
axis : int or modin.core.dataframe.base.utils.Axis, optional
The axis to map over.
dtypes : str, optional
The axis to map over. If None, the map will be performed element-wise.
dtypes : pandas.Series or scalar type, optional
The data types for the result. This is an optimization
because there are functions that always result in a particular data
type, and this allows us to avoid (re)computing it.
copy_dtypes : bool, default: False
If True, the dtypes of the resulting dataframe are copied from the original,
and the ``dtypes`` argument is ignored.
Returns
-------
Expand Down Expand Up @@ -257,8 +273,9 @@ def groupby(
def reduce(
self,
axis: Union[int, Axis],
function: Callable,
dtypes: Optional[str] = None,
function: Callable[["Partition"], object],
*,
dtypes: Optional[pandas.Series] = None,
) -> "ModinDataframe":
"""
Perform a user-defined aggregation on the specified axis, where the axis reduces down to a singleton.
Expand All @@ -269,7 +286,7 @@ def reduce(
The axis to perform the reduce over.
function : callable(row|col) -> single value
The reduce function to apply to each column.
dtypes : str, optional
dtypes : pandas.Series, optional
The data types for the result. This is an optimization
because there are functions that always result in a particular data
type, and this allows us to avoid (re)computing it.
Expand All @@ -289,9 +306,10 @@ def reduce(
def tree_reduce(
self,
axis: Union[int, Axis],
map_func: Callable,
reduce_func: Optional[Callable] = None,
dtypes: Optional[str] = None,
map_func: Callable[["Partition"], "Partition"],
reduce_func: Optional[Callable[["Partition"], object]] = None,
*,
dtypes: Optional[pandas.Series] = None,
) -> "ModinDataframe":
"""
Perform a user-defined aggregation on the specified axis, where the axis reduces down to a singleton using a tree-reduce computation pattern.
Expand Down
10 changes: 10 additions & 0 deletions modin/core/dataframe/base/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
"""

from enum import Enum
import sys

if sys.version_info.minor < 8:
from typing_extensions import Literal
else:
from typing import Literal


class Axis(Enum): # noqa: PR01
Expand All @@ -36,6 +42,10 @@ class Axis(Enum): # noqa: PR01
CELL_WISE = None


AxisInt = Literal[0, 1]
"""Type for the two possible integer values of an axis argument (0 or 1)."""


class JoinType(Enum): # noqa: PR01
"""
An enum that represents the `join_type` argument provided to the algebra operators.
Expand Down
Loading

0 comments on commit f4fc351

Please sign in to comment.