Skip to content

Commit

Permalink
BUG: outer join on equal indexes not sorting (#56426)
Browse files Browse the repository at this point in the history
* outer join on equal indexes to sort by default

* whatsnew

* fix test

* remove Index._join_precedence
  • Loading branch information
lukemanley authored Dec 9, 2023
1 parent cb56347 commit 23c20de
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 52 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ These are bug fixes that might have notable behavior changes.

In previous versions of pandas, :func:`merge` and :meth:`DataFrame.join` did not
always return a result that followed the documented sort behavior. pandas now
- follows the documented sort behavior in merge and join operations (:issue:`54611`).
+ follows the documented sort behavior in merge and join operations (:issue:`54611`, :issue:`56426`).

As documented, ``sort=True`` sorts the join keys lexicographically in the resulting
:class:`DataFrame`. With ``sort=False``, the order of the join keys depends on the
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/computation/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _align_core(terms):
ax, itm = axis, items

if not axes[ax].is_(itm):
- axes[ax] = axes[ax].join(itm, how="outer")
+ axes[ax] = axes[ax].union(itm)

for i, ndim in ndims.items():
for axis, items in zip(range(ndim), axes):
Expand Down
31 changes: 11 additions & 20 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,6 @@ class Index(IndexOpsMixin, PandasObject):
Index([1, 2, 3], dtype='uint8')
"""

- # To hand over control to subclasses
- _join_precedence = 1

# similar to __array_priority__, positions Index after Series and DataFrame
# but before ExtensionArray. Should NOT be overridden by subclasses.
__pandas_priority__ = 2000
Expand Down Expand Up @@ -4564,6 +4561,7 @@ def join(
Index([1, 2, 3, 4, 5, 6], dtype='int64')
"""
other = ensure_index(other)
+ sort = sort or how == "outer"

if isinstance(self, ABCDatetimeIndex) and isinstance(other, ABCDatetimeIndex):
if (self.tz is None) ^ (other.tz is None):
Expand Down Expand Up @@ -4614,15 +4612,6 @@ def join(
rindexer = np.array([])
return join_index, None, rindexer

- if self._join_precedence < other._join_precedence:
- flip: dict[JoinHow, JoinHow] = {"right": "left", "left": "right"}
- how = flip.get(how, how)
- join_index, lidx, ridx = other.join(
- self, how=how, level=level, return_indexers=True
- )
- lidx, ridx = ridx, lidx
- return join_index, lidx, ridx

if self.dtype != other.dtype:
dtype = self._find_common_type_compat(other)
this = self.astype(dtype, copy=False)
Expand Down Expand Up @@ -4666,18 +4655,20 @@ def _join_via_get_indexer(
# Note: at this point we have checked matching dtypes

if how == "left":
- join_index = self
+ join_index = self.sort_values() if sort else self
elif how == "right":
- join_index = other
+ join_index = other.sort_values() if sort else other
elif how == "inner":
join_index = self.intersection(other, sort=sort)
elif how == "outer":
- # TODO: sort=True here for backwards compat. It may
- # be better to use the sort parameter passed into join
- join_index = self.union(other)
-
- if sort and how in ["left", "right"]:
- join_index = join_index.sort_values()
+ try:
+ join_index = self.union(other, sort=sort)
+ except TypeError:
+ join_index = self.union(other)
+ try:
+ join_index = _maybe_try_sort(join_index, sort)
+ except TypeError:
+ pass

if join_index is self:
lindexer = None
Expand Down
2 changes: 0 additions & 2 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,6 @@ class DatetimeTimedeltaMixin(DatetimeIndexOpsMixin, ABC):
_is_monotonic_decreasing = Index.is_monotonic_decreasing
_is_unique = Index.is_unique

- _join_precedence = 10

@property
def unit(self) -> str:
return self._data.unit
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def __init__(
self.on = com.maybe_make_list(on)

self.suffixes = suffixes
- self.sort = sort
+ self.sort = sort or how == "outer"

self.left_index = left_index
self.right_index = right_index
Expand Down Expand Up @@ -1694,9 +1694,6 @@ def get_join_indexers(
elif not sort and how in ["left", "outer"]:
return _get_no_sort_one_missing_indexer(left_n, False)

- if not sort and how == "outer":
- sort = True

# get left & right join labels and num. of levels at each location
mapped = (
_factorize_keys(left_keys[n], right_keys[n], sort=sort)
Expand Down
13 changes: 5 additions & 8 deletions pandas/tests/indexes/multi/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ def test_join_level_corner_case(idx):


def test_join_self(idx, join_type):
- joined = idx.join(idx, how=join_type)
- tm.assert_index_equal(joined, idx)
+ result = idx.join(idx, how=join_type)
+ expected = idx
+ if join_type == "outer":
+ expected = expected.sort_values()
+ tm.assert_index_equal(result, expected)


def test_join_multi():
Expand Down Expand Up @@ -89,12 +92,6 @@ def test_join_multi():
tm.assert_numpy_array_equal(ridx, exp_ridx)


- def test_join_self_unique(idx, join_type):
- if idx.is_unique:
- joined = idx.join(idx, how=join_type)
- assert (idx == joined).all()


def test_join_multi_wrong_order():
# GH 25760
# GH 28956
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,8 +987,11 @@ def test_slice_keep_name(self):
indirect=True,
)
def test_join_self(self, index, join_type):
- joined = index.join(index, how=join_type)
- assert index is joined
+ result = index.join(index, how=join_type)
+ expected = index
+ if join_type == "outer":
+ expected = expected.sort_values()
+ tm.assert_index_equal(result, expected)

@pytest.mark.parametrize("method", ["strip", "rstrip", "lstrip"])
def test_str_attribute(self, method):
Expand Down Expand Up @@ -1072,10 +1075,8 @@ def test_outer_join_sort(self):
with tm.assert_produces_warning(RuntimeWarning):
result = left_index.join(right_index, how="outer")

- # right_index in this case because DatetimeIndex has join precedence
- # over int64 Index
with tm.assert_produces_warning(RuntimeWarning):
- expected = right_index.astype(object).union(left_index.astype(object))
+ expected = left_index.astype(object).union(right_index.astype(object))

tm.assert_index_equal(result, expected)

Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/indexes/test_old_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
period_range,
)
import pandas._testing as tm
+ import pandas.core.algorithms as algos
from pandas.core.arrays import BaseMaskedArray


Expand Down Expand Up @@ -653,7 +654,10 @@ def test_join_self_unique(self, join_type, simple_index):
idx = simple_index
if idx.is_unique:
joined = idx.join(idx, how=join_type)
- assert (idx == joined).all()
+ expected = simple_index
+ if join_type == "outer":
+ expected = algos.safe_sort(expected)
+ tm.assert_index_equal(joined, expected)

def test_map(self, simple_index):
# callable
Expand Down
20 changes: 10 additions & 10 deletions pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,9 @@ def test_merge_empty(self, left_empty, how, exp):
elif exp == "empty_cross":
expected = DataFrame(columns=["A_x", "B", "A_y", "C"], dtype="int64")

+ if how == "outer":
+ expected = expected.sort_values("A", ignore_index=True)

tm.assert_frame_equal(result, expected)


Expand Down Expand Up @@ -2913,16 +2916,13 @@ def test_merge_combinations(
expected = expected["key"].repeat(repeats.values)
expected = expected.to_frame()
elif how == "outer":
- if on_index and left_unique and left["key"].equals(right["key"]):
- expected = DataFrame({"key": left["key"]})
- else:
- left_counts = left["key"].value_counts()
- right_counts = right["key"].value_counts()
- expected_counts = left_counts.mul(right_counts, fill_value=1)
- expected_counts = expected_counts.astype(np.intp)
- expected = expected_counts.index.values.repeat(expected_counts.values)
- expected = DataFrame({"key": expected})
- expected = expected.sort_values("key")
+ left_counts = left["key"].value_counts()
+ right_counts = right["key"].value_counts()
+ expected_counts = left_counts.mul(right_counts, fill_value=1)
+ expected_counts = expected_counts.astype(np.intp)
+ expected = expected_counts.index.values.repeat(expected_counts.values)
+ expected = DataFrame({"key": expected})
+ expected = expected.sort_values("key")

if on_index:
expected = expected.set_index("key")
Expand Down
3 changes: 3 additions & 0 deletions pandas/tests/series/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,9 @@ def test_series_add_tz_mismatch_converts_to_utc(self):
uts2 = ser2.tz_convert("utc")
expected = uts1 + uts2

+ # sort since input indexes are not equal
+ expected = expected.sort_index()

assert result.index.tz is timezone.utc
tm.assert_series_equal(result, expected)

Expand Down

0 comments on commit 23c20de

Please sign in to comment.