Skip to content

Commit

Permalink
PERF: merge on monotonic keys (#56523)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Dec 17, 2023
1 parent d77d5e5 commit 061c2e9
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 65 deletions.
38 changes: 24 additions & 14 deletions asv_bench/benchmarks/join_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,21 @@ def time_merge_dataframes_cross(self, sort):

class MergeEA:
params = [
"Int64",
"Int32",
"Int16",
"UInt64",
"UInt32",
"UInt16",
"Float64",
"Float32",
[
"Int64",
"Int32",
"Int16",
"UInt64",
"UInt32",
"UInt16",
"Float64",
"Float32",
],
[True, False],
]
param_names = ["dtype"]
param_names = ["dtype", "monotonic"]

def setup(self, dtype):
def setup(self, dtype, monotonic):
N = 10_000
indices = np.arange(1, N)
key = np.tile(indices[:8000], 10)
Expand All @@ -299,8 +302,11 @@ def setup(self, dtype):
"value2": np.random.randn(7999),
}
)
if monotonic:
self.left = self.left.sort_values("key")
self.right = self.right.sort_values("key")

def time_merge(self, dtype):
def time_merge(self, dtype, monotonic):
merge(self.left, self.right)


Expand Down Expand Up @@ -330,10 +336,11 @@ class MergeDatetime:
("ns", "ms"),
],
[None, "Europe/Brussels"],
[True, False],
]
param_names = ["units", "tz"]
param_names = ["units", "tz", "monotonic"]

def setup(self, units, tz):
def setup(self, units, tz, monotonic):
unit_left, unit_right = units
N = 10_000
keys = Series(date_range("2012-01-01", freq="min", periods=N, tz=tz))
Expand All @@ -349,8 +356,11 @@ def setup(self, units, tz):
"value2": np.random.randn(8000),
}
)
if monotonic:
self.left = self.left.sort_values("key")
self.right = self.right.sort_values("key")

def time_merge(self, units, tz):
def time_merge(self, units, tz, monotonic):
merge(self.left, self.right)


Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ Performance improvements
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` (:issue:`55949`, :issue:`55971`)
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
- Performance improvement in :func:`get_dummies` (:issue:`56089`)
- Performance improvement in :func:`merge` and :func:`merge_ordered` when joining on sorted ascending keys (:issue:`56115`)
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)
- Performance improvement in :meth:`DataFrame.groupby` when aggregating pyarrow timestamp and duration dtypes (:issue:`55031`)
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,8 @@ def value_counts(
_, idx = get_join_indexers(
left, right, sort=False, how="left" # type: ignore[arg-type]
)
out = np.where(idx != -1, out[idx], 0)
if idx is not None:
out = np.where(idx != -1, out[idx], 0)

if sort:
sorter = np.lexsort((out if ascending else -out, left[0]))
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4783,13 +4783,13 @@ def _join_multi(self, other: Index, how: JoinHow):
def _join_non_unique(
self, other: Index, how: JoinHow = "left", sort: bool = False
) -> tuple[Index, npt.NDArray[np.intp], npt.NDArray[np.intp]]:
from pandas.core.reshape.merge import get_join_indexers
from pandas.core.reshape.merge import get_join_indexers_non_unique

# We only get here if dtypes match
assert self.dtype == other.dtype

left_idx, right_idx = get_join_indexers(
[self._values], [other._values], how=how, sort=sort
left_idx, right_idx = get_join_indexers_non_unique(
self._values, other._values, how=how, sort=sort
)
mask = left_idx == -1

Expand Down
160 changes: 114 additions & 46 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,10 +1019,14 @@ def _maybe_add_join_keys(
take_left, take_right = None, None

if name in result:
if left_indexer is not None and right_indexer is not None:
if left_indexer is not None or right_indexer is not None:
if name in self.left:
if left_has_missing is None:
left_has_missing = (left_indexer == -1).any()
left_has_missing = (
False
if left_indexer is None
else (left_indexer == -1).any()
)

if left_has_missing:
take_right = self.right_join_keys[i]
Expand All @@ -1032,21 +1036,27 @@ def _maybe_add_join_keys(

elif name in self.right:
if right_has_missing is None:
right_has_missing = (right_indexer == -1).any()
right_has_missing = (
False
if right_indexer is None
else (right_indexer == -1).any()
)

if right_has_missing:
take_left = self.left_join_keys[i]

if result[name].dtype != self.right[name].dtype:
take_right = self.right[name]._values

elif left_indexer is not None:
else:
take_left = self.left_join_keys[i]
take_right = self.right_join_keys[i]

if take_left is not None or take_right is not None:
if take_left is None:
lvals = result[name]._values
elif left_indexer is None:
lvals = take_left
else:
# TODO: can we pin down take_left's type earlier?
take_left = extract_array(take_left, extract_numpy=True)
Expand All @@ -1055,6 +1065,8 @@ def _maybe_add_join_keys(

if take_right is None:
rvals = result[name]._values
elif right_indexer is None:
rvals = take_right
else:
# TODO: can we pin down take_right's type earlier?
taker = extract_array(take_right, extract_numpy=True)
Expand All @@ -1063,16 +1075,17 @@ def _maybe_add_join_keys(

# if we have an all missing left_indexer
# make sure to just use the right values or vice-versa
mask_left = left_indexer == -1
# error: Item "bool" of "Union[Any, bool]" has no attribute "all"
if mask_left.all(): # type: ignore[union-attr]
if left_indexer is not None and (left_indexer == -1).all():
key_col = Index(rvals)
result_dtype = rvals.dtype
elif right_indexer is not None and (right_indexer == -1).all():
key_col = Index(lvals)
result_dtype = lvals.dtype
else:
key_col = Index(lvals).where(~mask_left, rvals)
key_col = Index(lvals)
if left_indexer is not None:
mask_left = left_indexer == -1
key_col = key_col.where(~mask_left, rvals)
result_dtype = find_common_type([lvals.dtype, rvals.dtype])
if (
lvals.dtype.kind == "M"
Expand Down Expand Up @@ -1103,7 +1116,9 @@ def _maybe_add_join_keys(
else:
result.insert(i, name or f"key_{i}", key_col)

def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
def _get_join_indexers(
self,
) -> tuple[npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
"""return the join indexers"""
# make mypy happy
assert self.how != "asof"
Expand Down Expand Up @@ -1143,6 +1158,8 @@ def _get_join_info(
left_indexer,
how="right",
)
elif right_indexer is None:
join_index = right_ax.copy()
else:
join_index = right_ax.take(right_indexer)
elif self.left_index:
Expand All @@ -1162,10 +1179,13 @@ def _get_join_info(
right_indexer,
how="left",
)
elif left_indexer is None:
join_index = left_ax.copy()
else:
join_index = left_ax.take(left_indexer)
else:
join_index = default_index(len(left_indexer))
n = len(left_ax) if left_indexer is None else len(left_indexer)
join_index = default_index(n)

return join_index, left_indexer, right_indexer

Expand All @@ -1174,17 +1194,20 @@ def _create_join_index(
self,
index: Index,
other_index: Index,
indexer: npt.NDArray[np.intp],
indexer: npt.NDArray[np.intp] | None,
how: JoinHow = "left",
) -> Index:
"""
Create a join index by rearranging one index to match another
Parameters
----------
index : Index being rearranged
other_index : Index used to supply values not found in index
indexer : np.ndarray[np.intp] how to rearrange index
index : Index
index being rearranged
other_index : Index
used to supply values not found in index
indexer : np.ndarray[np.intp] or None
how to rearrange index
how : str
Replacement is only necessary if indexer based on other_index.
Expand All @@ -1202,6 +1225,8 @@ def _create_join_index(
if np.any(mask):
fill_value = na_value_for_dtype(index.dtype, compat=False)
index = index.append(Index([fill_value]))
if indexer is None:
return index.copy()
return index.take(indexer)

@final
Expand Down Expand Up @@ -1660,7 +1685,7 @@ def get_join_indexers(
right_keys: list[ArrayLike],
sort: bool = False,
how: JoinHow = "inner",
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
) -> tuple[npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
"""
Parameters
Expand All @@ -1672,9 +1697,9 @@ def get_join_indexers(
Returns
-------
np.ndarray[np.intp]
np.ndarray[np.intp] or None
Indexer into the left_keys.
np.ndarray[np.intp]
np.ndarray[np.intp] or None
Indexer into the right_keys.
"""
assert len(left_keys) == len(
Expand All @@ -1695,37 +1720,77 @@ def get_join_indexers(
elif not sort and how in ["left", "outer"]:
return _get_no_sort_one_missing_indexer(left_n, False)

# get left & right join labels and num. of levels at each location
mapped = (
_factorize_keys(left_keys[n], right_keys[n], sort=sort)
for n in range(len(left_keys))
)
zipped = zip(*mapped)
llab, rlab, shape = (list(x) for x in zipped)
lkey: ArrayLike
rkey: ArrayLike
if len(left_keys) > 1:
# get left & right join labels and num. of levels at each location
mapped = (
_factorize_keys(left_keys[n], right_keys[n], sort=sort)
for n in range(len(left_keys))
)
zipped = zip(*mapped)
llab, rlab, shape = (list(x) for x in zipped)

# get flat i8 keys from label lists
lkey, rkey = _get_join_keys(llab, rlab, tuple(shape), sort)
# get flat i8 keys from label lists
lkey, rkey = _get_join_keys(llab, rlab, tuple(shape), sort)
else:
lkey = left_keys[0]
rkey = right_keys[0]

# factorize keys to a dense i8 space
# `count` is the num. of unique keys
# set(lkey) | set(rkey) == range(count)
left = Index(lkey)
right = Index(rkey)

lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort)
# preserve left frame order if how == 'left' and sort == False
kwargs = {}
if how in ("inner", "left", "right"):
kwargs["sort"] = sort
join_func = {
"inner": libjoin.inner_join,
"left": libjoin.left_outer_join,
"right": lambda x, y, count, **kwargs: libjoin.left_outer_join(
y, x, count, **kwargs
)[::-1],
"outer": libjoin.full_outer_join,
}[how]

# error: Cannot call function of unknown type
return join_func(lkey, rkey, count, **kwargs) # type: ignore[operator]
if (
left.is_monotonic_increasing
and right.is_monotonic_increasing
and (left.is_unique or right.is_unique)
):
_, lidx, ridx = left.join(right, how=how, return_indexers=True, sort=sort)
else:
lidx, ridx = get_join_indexers_non_unique(
left._values, right._values, sort, how
)

if lidx is not None and is_range_indexer(lidx, len(left)):
lidx = None
if ridx is not None and is_range_indexer(ridx, len(right)):
ridx = None
return lidx, ridx


def get_join_indexers_non_unique(
left: ArrayLike,
right: ArrayLike,
sort: bool = False,
how: JoinHow = "inner",
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
"""
Get join indexers for left and right.
Parameters
----------
left : ArrayLike
right : ArrayLike
sort : bool, default False
how : {'inner', 'outer', 'left', 'right'}, default 'inner'
Returns
-------
np.ndarray[np.intp]
Indexer into left.
np.ndarray[np.intp]
Indexer into right.
"""
lkey, rkey, count = _factorize_keys(left, right, sort=sort)
if how == "left":
lidx, ridx = libjoin.left_outer_join(lkey, rkey, count, sort=sort)
elif how == "right":
ridx, lidx = libjoin.left_outer_join(rkey, lkey, count, sort=sort)
elif how == "inner":
lidx, ridx = libjoin.inner_join(lkey, rkey, count, sort=sort)
elif how == "outer":
lidx, ridx = libjoin.full_outer_join(lkey, rkey, count)
return lidx, ridx


def restore_dropped_levels_multijoin(
Expand Down Expand Up @@ -1860,7 +1925,10 @@ def get_result(self, copy: bool | None = True) -> DataFrame:
left_indexer = cast("npt.NDArray[np.intp]", left_indexer)
right_indexer = cast("npt.NDArray[np.intp]", right_indexer)
left_join_indexer = libjoin.ffill_indexer(left_indexer)
right_join_indexer = libjoin.ffill_indexer(right_indexer)
if right_indexer is None:
right_join_indexer = None
else:
right_join_indexer = libjoin.ffill_indexer(right_indexer)
elif self.fill_method is None:
left_join_indexer = left_indexer
right_join_indexer = right_indexer
Expand Down
Loading

0 comments on commit 061c2e9

Please sign in to comment.