diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 043646457f604..7df6e7d0e3166 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -740,6 +740,7 @@ Categorical ^^^^^^^^^^^ - :meth:`Categorical.isin` raising ``InvalidIndexError`` for categorical containing overlapping :class:`Interval` values (:issue:`34974`) - Bug in :meth:`CategoricalDtype.__eq__` returning ``False`` for unordered categorical data with mixed types (:issue:`55468`) +- Bug when casting ``pa.dictionary`` to :class:`CategoricalDtype` using a ``pa.DictionaryArray`` as categories (:issue:`56672`) Datetimelike ^^^^^^^^^^^^ diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 065a942cae768..b87c5375856dc 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -44,7 +44,9 @@ pandas_dtype, ) from pandas.core.dtypes.dtypes import ( + ArrowDtype, CategoricalDtype, + CategoricalDtypeType, ExtensionDtype, ) from pandas.core.dtypes.generic import ( @@ -443,24 +445,32 @@ def __init__( values = arr if dtype.categories is None: - if not isinstance(values, ABCIndex): - # in particular RangeIndex xref test_index_equal_range_categories - values = sanitize_array(values, None) - try: - codes, categories = factorize(values, sort=True) - except TypeError as err: - codes, categories = factorize(values, sort=False) - if dtype.ordered: - # raise, as we don't have a sortable data structure and so - # the user should give us one by specifying categories - raise TypeError( - "'values' is not ordered, please " - "explicitly specify the categories order " - "by passing in a categories argument." - ) from err - - # we're inferring from values - dtype = CategoricalDtype(categories, dtype.ordered) + if isinstance(values.dtype, ArrowDtype) and issubclass( + values.dtype.type, CategoricalDtypeType + ): + arr = values._pa_array.combine_chunks() + categories = arr.dictionary.to_pandas(types_mapper=ArrowDtype) + codes = arr.indices.to_numpy() + dtype = CategoricalDtype(categories, values.dtype.pyarrow_dtype.ordered) + else: + if not isinstance(values, ABCIndex): + # in particular RangeIndex xref test_index_equal_range_categories + values = sanitize_array(values, None) + try: + codes, categories = factorize(values, sort=True) + except TypeError as err: + codes, categories = factorize(values, sort=False) + if dtype.ordered: + # raise, as we don't have a sortable data structure and so + # the user should give us one by specifying categories + raise TypeError( + "'values' is not ordered, please " + "explicitly specify the categories order " + "by passing in a categories argument." + ) from err + + # we're inferring from values + dtype = CategoricalDtype(categories, dtype.ordered) elif isinstance(values.dtype, CategoricalDtype): old_codes = extract_array(values)._codes diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index e709e6fcfe456..6689fb34f2ae3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -3234,6 +3234,22 @@ def test_factorize_chunked_dictionary(): tm.assert_index_equal(res_uniques, exp_uniques) +def test_dictionary_astype_categorical(): + # GH#56672 + arrs = [ + pa.array(np.array(["a", "x", "c", "a"])).dictionary_encode(), + pa.array(np.array(["a", "d", "c"])).dictionary_encode(), + ] + ser = pd.Series(ArrowExtensionArray(pa.chunked_array(arrs))) + result = ser.astype("category") + categories = pd.Index(["a", "x", "c", "d"], dtype=ArrowDtype(pa.string())) + expected = pd.Series( + ["a", "x", "c", "a", "a", "d", "c"], + dtype=pd.CategoricalDtype(categories=categories), + ) + tm.assert_series_equal(result, expected) + + def test_arrow_floordiv(): # GH 55561 a = pd.Series([-7], dtype="int64[pyarrow]")