Skip to content

Commit

Permalink
Fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Dec 9, 2023
1 parent 329edb3 commit 1c6c3e3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
11 changes: 5 additions & 6 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,19 +689,18 @@ def cat(
result = cat_safe(all_cols, sep)

out: Index | Series
if isinstance(self._orig.dtype, CategoricalDtype):
# We need to infer the new categories.
dtype = self._orig.dtype.categories.dtype # type: ignore[assignment]
else:
dtype = self._orig.dtype
if isinstance(self._orig, ABCIndex):
# add dtype for case that result is all-NA
dtype = None
if isna(result).all():
dtype = object

out = Index(result, dtype=dtype, name=self._orig.name)
else: # Series
if isinstance(self._orig.dtype, CategoricalDtype):
# We need to infer the new categories.
dtype = self._orig.dtype.categories.dtype # type: ignore[assignment]
else:
dtype = self._orig.dtype
res_ser = Series(
result, dtype=dtype, index=data.index, name=self._orig.name, copy=False
)
Expand Down
21 changes: 16 additions & 5 deletions pandas/tests/strings/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,18 @@ def test_str_cat_categorical(

with option_context("future.infer_string", infer_string):
s = Index(["a", "a", "b", "a"], dtype=dtype_caller)
s = s if box == Index else Series(s, index=s)
s = s if box == Index else Series(s, index=s, dtype=s.dtype)
t = Index(["b", "a", "b", "c"], dtype=dtype_target)

expected = Index(["ab", "aa", "bb", "ac"])
expected = Index(
["ab", "aa", "bb", "ac"], dtype=object if dtype_caller == "object" else None
)
expected = (
expected
if box == Index
else Series(expected, index=Index(s, dtype=dtype_caller))
else Series(
expected, index=Index(s, dtype=dtype_caller), dtype=expected.dtype
)
)

# Series/Index with unaligned Index -> t.values
Expand All @@ -123,12 +127,19 @@ def test_str_cat_categorical(

# Series/Index with Series having different Index
t = Series(t.values, index=t.values)
expected = Index(["aa", "aa", "bb", "bb", "aa"])
expected = Index(
["aa", "aa", "bb", "bb", "aa"],
dtype=object if dtype_caller == "object" else None,
)
dtype = object if dtype_caller == "object" else s.dtype.categories.dtype
expected = (
expected
if box == Index
else Series(expected, index=Index(expected.str[:1], dtype=dtype))
else Series(
expected,
index=Index(expected.str[:1], dtype=dtype),
dtype=expected.dtype,
)
)

result = s.str.cat(t, sep=sep)
Expand Down

0 comments on commit 1c6c3e3

Please sign in to comment.