From 1c6c3e3a930c0db92294445c25f213efb6ec1fe7 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Sat, 9 Dec 2023 23:55:10 +0100 Subject: [PATCH] Fixup --- pandas/core/strings/accessor.py | 11 +++++------ pandas/tests/strings/test_cat.py | 21 ++++++++++++++++----- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 75866c6f6013a..631030d637fae 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -689,19 +689,18 @@ def cat( result = cat_safe(all_cols, sep) out: Index | Series + if isinstance(self._orig.dtype, CategoricalDtype): + # We need to infer the new categories. + dtype = self._orig.dtype.categories.dtype # type: ignore[assignment] + else: + dtype = self._orig.dtype if isinstance(self._orig, ABCIndex): # add dtype for case that result is all-NA - dtype = None if isna(result).all(): dtype = object out = Index(result, dtype=dtype, name=self._orig.name) else: # Series - if isinstance(self._orig.dtype, CategoricalDtype): - # We need to infer the new categories. - dtype = self._orig.dtype.categories.dtype # type: ignore[assignment] - else: - dtype = self._orig.dtype res_ser = Series( result, dtype=dtype, index=data.index, name=self._orig.name, copy=False ) diff --git a/pandas/tests/strings/test_cat.py b/pandas/tests/strings/test_cat.py index 284932491a65e..c1e7ad6e02779 100644 --- a/pandas/tests/strings/test_cat.py +++ b/pandas/tests/strings/test_cat.py @@ -98,14 +98,18 @@ def test_str_cat_categorical( with option_context("future.infer_string", infer_string): s = Index(["a", "a", "b", "a"], dtype=dtype_caller) - s = s if box == Index else Series(s, index=s) + s = s if box == Index else Series(s, index=s, dtype=s.dtype) t = Index(["b", "a", "b", "c"], dtype=dtype_target) - expected = Index(["ab", "aa", "bb", "ac"]) + expected = Index( + ["ab", "aa", "bb", "ac"], dtype=object if dtype_caller == "object" else None + ) expected = ( expected if box == Index - else Series(expected, index=Index(s, dtype=dtype_caller)) + else Series( + expected, index=Index(s, dtype=dtype_caller), dtype=expected.dtype + ) ) # Series/Index with unaligned Index -> t.values @@ -123,12 +127,19 @@ def test_str_cat_categorical( # Series/Index with Series having different Index t = Series(t.values, index=t.values) - expected = Index(["aa", "aa", "bb", "bb", "aa"]) + expected = Index( + ["aa", "aa", "bb", "bb", "aa"], + dtype=object if dtype_caller == "object" else None, + ) dtype = object if dtype_caller == "object" else s.dtype.categories.dtype expected = ( expected if box == Index - else Series(expected, index=Index(expected.str[:1], dtype=dtype)) + else Series( + expected, + index=Index(expected.str[:1], dtype=dtype), + dtype=expected.dtype, + ) ) result = s.str.cat(t, sep=sep)