From 34c39e9078ea8af12871a92bdcea2058553c9869 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 14 Nov 2024 18:21:17 +0100 Subject: [PATCH] BUG (string dtype): let fillna with invalid value upcast to object dtype (#60296) * BUG (string dtype): let fillna with invalid value upcast to object dtype * fix fillna limit case + update tests for no longer raising --- pandas/core/internals/blocks.py | 9 +++++---- pandas/tests/frame/indexing/test_where.py | 8 +------- pandas/tests/series/indexing/test_setitem.py | 5 ----- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 3264676771d5d..3c207e8c14b5b 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -108,6 +108,7 @@ PeriodArray, TimedeltaArray, ) +from pandas.core.arrays.string_ import StringDtype from pandas.core.base import PandasObject import pandas.core.common as com from pandas.core.computation import expressions @@ -1336,7 +1337,7 @@ def fillna( return [self.copy(deep=False)] if limit is not None: - mask[mask.cumsum(self.ndim - 1) > limit] = False + mask[mask.cumsum(self.values.ndim - 1) > limit] = False if inplace: nbs = self.putmask(mask.T, value) @@ -1684,7 +1685,7 @@ def where(self, other, cond) -> list[Block]: res_values = arr._where(cond, other).T except (ValueError, TypeError): if self.ndim == 1 or self.shape[0] == 1: - if isinstance(self.dtype, IntervalDtype): + if isinstance(self.dtype, (IntervalDtype, StringDtype)): # TestSetitemFloatIntervalWithIntIntervalValues blk = self.coerce_to_target_dtype(orig_other, raise_on_upcast=False) return blk.where(orig_other, orig_cond) @@ -1854,9 +1855,9 @@ def fillna( limit: int | None = None, inplace: bool = False, ) -> list[Block]: - if isinstance(self.dtype, IntervalDtype): + if isinstance(self.dtype, (IntervalDtype, StringDtype)): # Block.fillna handles coercion (test_fillna_interval) - if limit is not None: + if isinstance(self.dtype, IntervalDtype) and limit is not None: raise ValueError("limit must be None") return super().fillna( value=value, diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 86b39ddd19ec1..d6570fcda2ee8 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -1025,15 +1025,9 @@ def test_where_producing_ea_cond_for_np_dtype(): @pytest.mark.parametrize( "replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)] ) -def test_where_int_overflow(replacement, using_infer_string): +def test_where_int_overflow(replacement): # GH 31687 df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]]) - if using_infer_string and replacement not in (None, "snake"): - with pytest.raises( - TypeError, match=f"Invalid value '{replacement}' for dtype 'str'" - ): - df.where(pd.notnull(df), replacement) - return result = df.where(pd.notnull(df), replacement) expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]]) diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index 0d62317893326..158198239ba75 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -839,11 +839,6 @@ def test_series_where(self, obj, key, expected, raises, val, is_inplace): obj = obj.copy() arr = obj._values - if raises and obj.dtype == "string": - with pytest.raises(TypeError, match="Invalid value"): - obj.where(~mask, val) - return - res = obj.where(~mask, val) if val is NA and res.dtype == object: