From cae21a7de6a669177e0318d485487525092fbcc3 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 14 Nov 2024 18:21:17 +0100 Subject: [PATCH] BUG (string dtype): let fillna with invalid value upcast to object dtype (#60296) * BUG (string dtype): let fillna with invalid value upcast to object dtype * fix fillna limit case + update tests for no longer raising (cherry picked from commit 34c39e9078ea8af12871a92bdcea2058553c9869) --- pandas/core/internals/blocks.py | 6 +++--- pandas/tests/frame/indexing/test_where.py | 8 +------- pandas/tests/series/indexing/test_setitem.py | 6 ------ 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 7ee1361912c05..6ae591a5d4ac8 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1710,7 +1710,7 @@ def fillna( return nbs if limit is not None: - mask[mask.cumsum(self.ndim - 1) > limit] = False + mask[mask.cumsum(self.values.ndim - 1) > limit] = False if inplace: nbs = self.putmask( @@ -2136,7 +2136,7 @@ def where( res_values = arr._where(cond, other).T except (ValueError, TypeError): if self.ndim == 1 or self.shape[0] == 1: - if isinstance(self.dtype, IntervalDtype): + if isinstance(self.dtype, (IntervalDtype, StringDtype)): # TestSetitemFloatIntervalWithIntIntervalValues blk = self.coerce_to_target_dtype(orig_other) nbs = blk.where(orig_other, orig_cond, using_cow=using_cow) @@ -2338,7 +2338,7 @@ def fillna( using_cow: bool = False, already_warned=None, ) -> list[Block]: - if isinstance(self.dtype, IntervalDtype): + if isinstance(self.dtype, (IntervalDtype, StringDtype)): # Block.fillna handles coercion (test_fillna_interval) return super().fillna( value=value, diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 5fd3796d0255a..356257bbfec98 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -1086,15 +1086,9 @@ def test_where_producing_ea_cond_for_np_dtype(): @pytest.mark.parametrize( "replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)] ) -def test_where_int_overflow(replacement, using_infer_string): +def test_where_int_overflow(replacement): # GH 31687 df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]]) - if using_infer_string and replacement not in (None, "snake"): - with pytest.raises( - TypeError, match=f"Invalid value '{replacement}' for dtype 'str'" - ): - df.where(pd.notnull(df), replacement) - return result = df.where(pd.notnull(df), replacement) expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]]) diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index a1263e2d30853..85558e85494eb 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -34,7 +34,6 @@ concat, date_range, interval_range, - isna, period_range, timedelta_range, ) @@ -865,11 +864,6 @@ def test_series_where(self, obj, key, expected, warn, val, is_inplace): obj = obj.copy() arr = obj._values - if obj.dtype == "string" and not (isinstance(val, str) or isna(val)): - with pytest.raises(TypeError, match="Invalid value"): - obj.where(~mask, val) - return - res = obj.where(~mask, val) if val is NA and res.dtype == object: