Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust tests in strings folder for new string option #56159

Merged
merged 13 commits into from
Dec 9, 2023
43 changes: 34 additions & 9 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays import ExtensionArray
from pandas.core.base import NoNewAttributesMixin
from pandas.core.construction import extract_array

Expand Down Expand Up @@ -456,7 +457,7 @@ def _get_series_list(self, others):
# in case of list-like `others`, all elements must be
# either Series/Index/np.ndarray (1-dim)...
if all(
isinstance(x, (ABCSeries, ABCIndex))
isinstance(x, (ABCSeries, ABCIndex, ExtensionArray))
or (isinstance(x, np.ndarray) and x.ndim == 1)
for x in others
):
Expand Down Expand Up @@ -690,12 +691,15 @@ def cat(
out: Index | Series
if isinstance(self._orig, ABCIndex):
# add dtype for case that result is all-NA
dtype = None
if isna(result).all():
dtype = object

out = Index(result, dtype=object, name=self._orig.name)
out = Index(result, dtype=dtype, name=self._orig.name)
else: # Series
if isinstance(self._orig.dtype, CategoricalDtype):
# We need to infer the new categories.
dtype = None
dtype = self._orig.dtype.categories.dtype # type: ignore[assignment]
else:
dtype = self._orig.dtype
res_ser = Series(
Expand Down Expand Up @@ -914,7 +918,13 @@ def split(
if is_re(pat):
regex = True
result = self._data.array._str_split(pat, n, expand, regex)
return self._wrap_result(result, returns_string=expand, expand=expand)
if self._data.dtype == "category":
dtype = self._data.dtype.categories.dtype
else:
dtype = object if self._data.dtype == object else None
return self._wrap_result(
result, expand=expand, returns_string=expand, dtype=dtype
)

@Appender(
_shared_docs["str_split"]
Expand All @@ -932,7 +942,10 @@ def split(
@forbid_nonstring_types(["bytes"])
def rsplit(self, pat=None, *, n=-1, expand: bool = False):
result = self._data.array._str_rsplit(pat, n=n)
return self._wrap_result(result, expand=expand, returns_string=expand)
dtype = object if self._data.dtype == object else None
return self._wrap_result(
result, expand=expand, returns_string=expand, dtype=dtype
)

_shared_docs[
"str_partition"
Expand Down Expand Up @@ -1028,7 +1041,13 @@ def rsplit(self, pat=None, *, n=-1, expand: bool = False):
@forbid_nonstring_types(["bytes"])
def partition(self, sep: str = " ", expand: bool = True):
result = self._data.array._str_partition(sep, expand)
return self._wrap_result(result, expand=expand, returns_string=expand)
if self._data.dtype == "category":
dtype = self._data.dtype.categories.dtype
else:
dtype = object if self._data.dtype == object else None
return self._wrap_result(
result, expand=expand, returns_string=expand, dtype=dtype
)

@Appender(
_shared_docs["str_partition"]
Expand All @@ -1042,7 +1061,13 @@ def partition(self, sep: str = " ", expand: bool = True):
@forbid_nonstring_types(["bytes"])
def rpartition(self, sep: str = " ", expand: bool = True):
result = self._data.array._str_rpartition(sep, expand)
return self._wrap_result(result, expand=expand, returns_string=expand)
if self._data.dtype == "category":
dtype = self._data.dtype.categories.dtype
else:
dtype = object if self._data.dtype == object else None
return self._wrap_result(
result, expand=expand, returns_string=expand, dtype=dtype
)

def get(self, i):
"""
Expand Down Expand Up @@ -2748,7 +2773,7 @@ def extract(
else:
name = _get_single_group_name(regex)
result = self._data.array._str_extract(pat, flags=flags, expand=returns_df)
return self._wrap_result(result, name=name)
return self._wrap_result(result, name=name, dtype=result_dtype)

@forbid_nonstring_types(["bytes"])
def extractall(self, pat, flags: int = 0) -> DataFrame:
Expand Down Expand Up @@ -3488,7 +3513,7 @@ def str_extractall(arr, pat, flags: int = 0) -> DataFrame:
raise ValueError("pattern contains no capture groups")

if isinstance(arr, ABCIndex):
arr = arr.to_series().reset_index(drop=True)
arr = arr.to_series().reset_index(drop=True).astype(arr.dtype)

columns = _get_group_names(regex)
match_list = []
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/strings/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import pytest

from pandas import (
CategoricalDtype,
DataFrame,
Index,
MultiIndex,
Series,
_testing as tm,
option_context,
)
from pandas.core.strings.accessor import StringMethods

Expand Down Expand Up @@ -162,7 +164,8 @@ def test_api_per_method(

if inferred_dtype in allowed_types:
# xref GH 23555, GH 23556
method(*args, **kwargs) # works!
with option_context("future.no_silent_downcasting", True):
method(*args, **kwargs) # works!
else:
# GH 23011, GH 23163
msg = (
Expand All @@ -178,6 +181,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype):
s = Series(list("aabb"), dtype=any_string_dtype)
s = s + " " + s
c = s.astype("category")
c = c.astype(CategoricalDtype(c.dtype.categories.astype("object")))
assert isinstance(c.str, StringMethods)

method_name, args, kwargs = any_string_method
Expand Down
35 changes: 24 additions & 11 deletions pandas/tests/strings/test_case_justify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def test_title_mixed_object():
s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0])
result = s.str.title()
expected = Series(
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan]
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan],
dtype=object,
)
tm.assert_almost_equal(result, expected)

Expand All @@ -41,11 +42,15 @@ def test_lower_upper_mixed_object():
s = Series(["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0])

result = s.str.upper()
expected = Series(["A", np.nan, "B", np.nan, np.nan, "FOO", None, np.nan, np.nan])
expected = Series(
["A", np.nan, "B", np.nan, np.nan, "FOO", None, np.nan, np.nan], dtype=object
)
tm.assert_series_equal(result, expected)

result = s.str.lower()
expected = Series(["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan])
expected = Series(
["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object
)
tm.assert_series_equal(result, expected)


Expand All @@ -71,7 +76,8 @@ def test_capitalize_mixed_object():
s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0])
result = s.str.capitalize()
expected = Series(
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan]
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan],
dtype=object,
)
tm.assert_series_equal(result, expected)

Expand All @@ -87,7 +93,8 @@ def test_swapcase_mixed_object():
s = Series(["FOO", np.nan, "bar", True, datetime.today(), "Blah", None, 1, 2.0])
result = s.str.swapcase()
expected = Series(
["foo", np.nan, "BAR", np.nan, np.nan, "bLAH", None, np.nan, np.nan]
["foo", np.nan, "BAR", np.nan, np.nan, "bLAH", None, np.nan, np.nan],
dtype=object,
)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -138,19 +145,22 @@ def test_pad_mixed_object():

result = s.str.pad(5, side="left")
expected = Series(
[" a", np.nan, " b", np.nan, np.nan, " ee", None, np.nan, np.nan]
[" a", np.nan, " b", np.nan, np.nan, " ee", None, np.nan, np.nan],
dtype=object,
)
tm.assert_series_equal(result, expected)

result = s.str.pad(5, side="right")
expected = Series(
["a ", np.nan, "b ", np.nan, np.nan, "ee ", None, np.nan, np.nan]
["a ", np.nan, "b ", np.nan, np.nan, "ee ", None, np.nan, np.nan],
dtype=object,
)
tm.assert_series_equal(result, expected)

result = s.str.pad(5, side="both")
expected = Series(
[" a ", np.nan, " b ", np.nan, np.nan, " ee ", None, np.nan, np.nan]
[" a ", np.nan, " b ", np.nan, np.nan, " ee ", None, np.nan, np.nan],
dtype=object,
)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -238,7 +248,8 @@ def test_center_ljust_rjust_mixed_object():
None,
np.nan,
np.nan,
]
],
dtype=object,
)
tm.assert_series_equal(result, expected)

Expand All @@ -255,7 +266,8 @@ def test_center_ljust_rjust_mixed_object():
None,
np.nan,
np.nan,
]
],
dtype=object,
)
tm.assert_series_equal(result, expected)

Expand All @@ -272,7 +284,8 @@ def test_center_ljust_rjust_mixed_object():
None,
np.nan,
np.nan,
]
],
dtype=object,
)
tm.assert_series_equal(result, expected)

Expand Down
Loading
Loading