Skip to content

Commit

Permalink
REF: Move checks to object into a variable (pandas-dev#54536)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored and mroeschke committed Aug 18, 2023
1 parent 28e0366 commit 89a7c34
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
2 changes: 2 additions & 0 deletions pandas/tests/strings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Needed for new arrow string dtype
object_pyarrow_numpy = ("object",)
31 changes: 16 additions & 15 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Series,
_testing as tm,
)
from pandas.tests.strings import object_pyarrow_numpy

# --------------------------------------------------------------------------------------
# str.contains
Expand All @@ -25,7 +26,7 @@ def test_contains(any_string_dtype):
pat = "mmm[_]+"

result = values.str.contains(pat)
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(
np.array([False, np.nan, True, True, False], dtype=np.object_),
dtype=expected_dtype,
Expand All @@ -44,7 +45,7 @@ def test_contains(any_string_dtype):
dtype=any_string_dtype,
)
result = values.str.contains(pat)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -71,14 +72,14 @@ def test_contains(any_string_dtype):
pat = "mmm[_]+"

result = values.str.contains(pat)
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
)
tm.assert_series_equal(result, expected)

result = values.str.contains(pat, na=False)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -163,7 +164,7 @@ def test_contains_moar(any_string_dtype):
)

result = s.str.contains("a")
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(
[False, False, False, True, True, False, np.nan, False, False, True],
dtype=expected_dtype,
Expand Down Expand Up @@ -204,7 +205,7 @@ def test_contains_nan(any_string_dtype):
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)

result = s.str.contains("foo", na=False)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([False, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -220,7 +221,7 @@ def test_contains_nan(any_string_dtype):
tm.assert_series_equal(result, expected)

result = s.str.contains("foo")
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -648,7 +649,7 @@ def test_replace_regex_single_character(regex, any_string_dtype):

def test_match(any_string_dtype):
# New match behavior introduced in 0.13
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"

values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
result = values.str.match(".*(BAD[_]+).*(BAD)")
Expand Down Expand Up @@ -703,20 +704,20 @@ def test_match_na_kwarg(any_string_dtype):
s = Series(["a", "b", np.nan], dtype=any_string_dtype)

result = s.str.match("a", na=False)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([True, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

result = s.str.match("a")
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([True, False, np.nan], dtype=expected_dtype)
tm.assert_series_equal(result, expected)


def test_match_case_kwarg(any_string_dtype):
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
result = values.str.match("ab", case=False)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([True, True, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -732,7 +733,7 @@ def test_fullmatch(any_string_dtype):
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
)
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -742,14 +743,14 @@ def test_fullmatch_na_kwarg(any_string_dtype):
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
)
result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series([True, False, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)


def test_fullmatch_case_kwarg(any_string_dtype):
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"

expected = Series([True, False, False, False], dtype=expected_dtype)

Expand Down Expand Up @@ -877,7 +878,7 @@ def test_find_nan(any_string_dtype):
ser = Series(
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
)
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"

result = ser.str.find("EF")
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)
Expand Down
17 changes: 9 additions & 8 deletions pandas/tests/strings/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
import pandas._testing as tm
from pandas.core.strings.accessor import StringMethods
from pandas.tests.strings import object_pyarrow_numpy


@pytest.mark.parametrize("pattern", [0, True, Series(["foo", "bar"])])
Expand All @@ -40,7 +41,7 @@ def test_iter_raises():
def test_count(any_string_dtype):
ser = Series(["foo", "foofoo", np.nan, "foooofooofommmfoo"], dtype=any_string_dtype)
result = ser.str.count("f[o]+")
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
expected = Series([1, 2, np.nan, 4], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -91,7 +92,7 @@ def test_repeat_with_null(any_string_dtype, arg, repeat):

def test_empty_str_methods(any_string_dtype):
empty_str = empty = Series(dtype=any_string_dtype)
if any_string_dtype == "object":
if any_string_dtype in object_pyarrow_numpy:
empty_int = Series(dtype="int64")
empty_bool = Series(dtype=bool)
else:
Expand Down Expand Up @@ -205,7 +206,7 @@ def test_ismethods(method, expected, any_string_dtype):
ser = Series(
["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "], dtype=any_string_dtype
)
expected_dtype = "bool" if any_string_dtype == "object" else "boolean"
expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(expected, dtype=expected_dtype)
result = getattr(ser.str, method)()
tm.assert_series_equal(result, expected)
Expand All @@ -230,7 +231,7 @@ def test_isnumeric_unicode(method, expected, any_string_dtype):
ser = Series(
["A", "3", "¼", "★", "፸", "3", "four"], dtype=any_string_dtype # noqa: RUF001
)
expected_dtype = "bool" if any_string_dtype == "object" else "boolean"
expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(expected, dtype=expected_dtype)
result = getattr(ser.str, method)()
tm.assert_series_equal(result, expected)
Expand All @@ -250,7 +251,7 @@ def test_isnumeric_unicode(method, expected, any_string_dtype):
def test_isnumeric_unicode_missing(method, expected, any_string_dtype):
values = ["A", np.nan, "¼", "★", np.nan, "3", "four"] # noqa: RUF001
ser = Series(values, dtype=any_string_dtype)
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected = Series(expected, dtype=expected_dtype)
result = getattr(ser.str, method)()
tm.assert_series_equal(result, expected)
Expand Down Expand Up @@ -280,7 +281,7 @@ def test_len(any_string_dtype):
dtype=any_string_dtype,
)
result = ser.str.len()
expected_dtype = "float64" if any_string_dtype == "object" else "Int64"
expected_dtype = "float64" if any_string_dtype in object_pyarrow_numpy else "Int64"
expected = Series([3, 4, 6, np.nan, 8, 4, 1], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -309,7 +310,7 @@ def test_index(method, sub, start, end, index_or_series, any_string_dtype, expec
obj = index_or_series(
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"], dtype=any_string_dtype
)
expected_dtype = np.int64 if any_string_dtype == "object" else "Int64"
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
expected = index_or_series(expected, dtype=expected_dtype)

result = getattr(obj.str, method)(sub, start, end)
Expand Down Expand Up @@ -350,7 +351,7 @@ def test_index_wrong_type_raises(index_or_series, any_string_dtype, method):
)
def test_index_missing(any_string_dtype, method, exp):
ser = Series(["abcb", "ab", "bcbe", np.nan], dtype=any_string_dtype)
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"

result = getattr(ser.str, method)("b")
expected = Series(exp + [np.nan], dtype=expected_dtype)
Expand Down

0 comments on commit 89a7c34

Please sign in to comment.