From 13926e5e298acf328b0c1347f008ef3f9c4eb078 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 31 Oct 2024 17:12:43 +0100 Subject: [PATCH] BUG: preserve (object) dtype in factorize (#60118) * BUG: preserve (object) dtype in factorize * add fallback for float16 --- pandas/core/base.py | 12 +++++++++--- pandas/tests/test_algos.py | 1 - 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pandas/core/base.py b/pandas/core/base.py index 863cf978426e2..58572aab5b20f 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -44,6 +44,7 @@ from pandas.core.dtypes.generic import ( ABCDataFrame, ABCIndex, + ABCMultiIndex, ABCSeries, ) from pandas.core.dtypes.missing import ( @@ -1287,13 +1288,18 @@ def factorize( if uniques.dtype == np.float16: uniques = uniques.astype(np.float32) - if isinstance(self, ABCIndex): - # preserve e.g. MultiIndex + if isinstance(self, ABCMultiIndex): + # preserve MultiIndex uniques = self._constructor(uniques) else: from pandas import Index - uniques = Index(uniques) + try: + uniques = Index(uniques, dtype=self.dtype) + except NotImplementedError: + # not all dtypes are supported in Index that are allowed for Series + # e.g. float16 or bytes + uniques = Index(uniques) return codes, uniques _shared_docs["searchsorted"] = """ diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index 81e7d3774b613..dac74a0e32a42 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -65,7 +65,6 @@ def test_factorize_complex(self): expected_uniques = np.array([(1 + 0j), (2 + 0j), (2 + 1j)], dtype=complex) tm.assert_numpy_array_equal(uniques, expected_uniques) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) def test_factorize(self, index_or_series_obj, sort): obj = index_or_series_obj result_codes, result_uniques = obj.factorize(sort=sort)