Skip to content

Commit

Permalink
String dtype: use ObjectEngine for indexing for now correctness over …
Browse files Browse the repository at this point in the history
…performance (#60329)
  • Loading branch information
jorisvandenbossche authored Nov 26, 2024
1 parent fd570f4 commit 98f7e4d
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 14 deletions.
3 changes: 3 additions & 0 deletions pandas/_libs/index.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class MaskedUInt16Engine(MaskedIndexEngine): ...
class MaskedUInt8Engine(MaskedIndexEngine): ...
class MaskedBoolEngine(MaskedUInt8Engine): ...

class StringObjectEngine(ObjectEngine):
def __init__(self, values: object, na_value) -> None: ...

class BaseMultiIndexCodesEngine:
levels: list[np.ndarray]
offsets: np.ndarray # np.ndarray[..., ndim=1]
Expand Down
25 changes: 25 additions & 0 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,31 @@ cdef class StringEngine(IndexEngine):
raise KeyError(val)
return str(val)

cdef class StringObjectEngine(ObjectEngine):

cdef:
object na_value
bint uses_na

def __init__(self, ndarray values, na_value):
super().__init__(values)
self.na_value = na_value
self.uses_na = na_value is C_NA

cdef bint _checknull(self, object val):
if self.uses_na:
return val is C_NA
else:
return util.is_nan(val)

cdef _check_type(self, object val):
if isinstance(val, str):
return val
elif self._checknull(val):
return self.na_value
else:
raise KeyError(val)


cdef class DatetimeEngine(Int64Engine):

Expand Down
3 changes: 1 addition & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def _engine(
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
target_values = self._data._ndarray # type: ignore[union-attr]
elif is_string_dtype(self.dtype) and not is_object_dtype(self.dtype):
return libindex.StringEngine(target_values)
return libindex.StringObjectEngine(target_values, self.dtype.na_value) # type: ignore[union-attr]

# error: Argument 1 to "ExtensionEngine" has incompatible type
# "ndarray[Any, Any]"; expected "ExtensionArray"
Expand Down Expand Up @@ -5974,7 +5974,6 @@ def _should_fallback_to_positional(self) -> bool:
def get_indexer_non_unique(
self, target
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
target = ensure_index(target)
target = self._maybe_cast_listlike_indexer(target)

if not self._should_compare(target) and not self._should_partial_index(target):
Expand Down
104 changes: 93 additions & 11 deletions pandas/tests/indexes/string/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,51 @@
import pandas._testing as tm


def _isnan(val):
try:
return val is not pd.NA and np.isnan(val)
except TypeError:
return False


class TestGetLoc:
def test_get_loc(self, any_string_dtype):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
assert index.get_loc("b") == 1

def test_get_loc_raises(self, any_string_dtype):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
with pytest.raises(KeyError, match="d"):
index.get_loc("d")

def test_get_loc_invalid_value(self, any_string_dtype):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
with pytest.raises(KeyError, match="1"):
index.get_loc(1)

def test_get_loc_non_unique(self, any_string_dtype):
index = Index(["a", "b", "a"], dtype=any_string_dtype)
result = index.get_loc("a")
expected = np.array([True, False, True])
tm.assert_numpy_array_equal(result, expected)

def test_get_loc_non_missing(self, any_string_dtype, nulls_fixture):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
with pytest.raises(KeyError):
index.get_loc(nulls_fixture)

def test_get_loc_missing(self, any_string_dtype, nulls_fixture):
index = Index(["a", "b", nulls_fixture], dtype=any_string_dtype)
if any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and nulls_fixture is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(nulls_fixture))
):
with pytest.raises(KeyError):
index.get_loc(nulls_fixture)
else:
assert index.get_loc(nulls_fixture) == 2


class TestGetIndexer:
@pytest.mark.parametrize(
"method,expected",
Expand Down Expand Up @@ -41,23 +86,60 @@ def test_get_indexer_strings_raises(self, any_string_dtype):
["a", "b", "c", "d"], method="pad", tolerance=[2, 2, 2, 2]
)

@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
def test_get_indexer_missing(self, any_string_dtype, null, using_infer_string):
# NaT and Decimal("NaN") from null_fixture are not supported for string dtype
index = Index(["a", "b", null], dtype=any_string_dtype)
result = index.get_indexer(["a", null, "c"])
if using_infer_string:
expected = np.array([0, 2, -1], dtype=np.intp)
elif any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
):
expected = np.array([0, -1, -1], dtype=np.intp)
else:
expected = np.array([0, 2, -1], dtype=np.intp)

class TestGetIndexerNonUnique:
@pytest.mark.xfail(reason="TODO(infer_string)", strict=False)
def test_get_indexer_non_unique_nas(self, any_string_dtype, nulls_fixture):
index = Index(["a", "b", None], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
tm.assert_numpy_array_equal(result, expected)

expected_indexer = np.array([2], dtype=np.intp)
expected_missing = np.array([], dtype=np.intp)

class TestGetIndexerNonUnique:
@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
def test_get_indexer_non_unique_nas(
self, any_string_dtype, null, using_infer_string
):
index = Index(["a", "b", null], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique(["a", null])

if using_infer_string:
expected_indexer = np.array([0, 2], dtype=np.intp)
expected_missing = np.array([], dtype=np.intp)
elif any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
):
expected_indexer = np.array([0, -1], dtype=np.intp)
expected_missing = np.array([1], dtype=np.intp)
else:
expected_indexer = np.array([0, 2], dtype=np.intp)
expected_missing = np.array([], dtype=np.intp)
tm.assert_numpy_array_equal(indexer, expected_indexer)
tm.assert_numpy_array_equal(missing, expected_missing)

# actually non-unique
index = Index(["a", None, "b", None], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique([nulls_fixture])

expected_indexer = np.array([1, 3], dtype=np.intp)
index = Index(["a", null, "b", null], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique(["a", null])

if using_infer_string:
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
elif any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
):
pass
else:
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
tm.assert_numpy_array_equal(indexer, expected_indexer)
tm.assert_numpy_array_equal(missing, expected_missing)

Expand Down
3 changes: 2 additions & 1 deletion pandas/tests/io/parser/common/test_common_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW
from pandas.errors import (
EmptyDataError,
ParserError,
Expand Down Expand Up @@ -766,7 +767,7 @@ def test_dict_keys_as_names(all_parsers):
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.xfail(using_string_dtype() and HAS_PYARROW, reason="TODO(infer_string)")
@xfail_pyarrow # UnicodeDecodeError: 'utf-8' codec can't decode byte 0xed in position 0
def test_encoding_surrogatepass(all_parsers):
# GH39017
Expand Down

0 comments on commit 98f7e4d

Please sign in to comment.