From d862ecaf3aa05082d36512f994d99106cc7c10b1 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Wed, 16 Aug 2023 18:53:57 +0200 Subject: [PATCH] Na return value --- pandas/core/arrays/string_.py | 5 ++- pandas/tests/strings/__init__.py | 14 +++++++ pandas/tests/strings/test_find_replace.py | 9 ++-- pandas/tests/strings/test_split_partition.py | 44 +++++++------------- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 1e285f90e9fea..4c28360c732a3 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -101,7 +101,10 @@ class StringDtype(StorageExtensionDtype): #: StringDtype().na_value uses pandas.NA @property def na_value(self) -> libmissing.NAType: - return libmissing.NA + if self.storage == "pyarrow_numpy": + return np.nan + else: + return libmissing.NA _metadata = ("storage",) diff --git a/pandas/tests/strings/__init__.py b/pandas/tests/strings/__init__.py index 326ae24410502..01b49b5e5b633 100644 --- a/pandas/tests/strings/__init__.py +++ b/pandas/tests/strings/__init__.py @@ -1 +1,15 @@ +import numpy as np + +import pandas as pd + object_pyarrow_numpy = ("object", "string[pyarrow_numpy]") + + +def _convert_na_value(ser, expected): + if ser.dtype != object: + if ser.dtype.storage == "pyarrow_numpy": + expected = expected.fillna(np.nan) + else: + # GH#18463 + expected = expected.fillna(pd.NA) + return expected diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index f62299f53aebe..2320ab4ed8b02 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -11,7 +11,10 @@ Series, _testing as tm, ) -from pandas.tests.strings import object_pyarrow_numpy +from pandas.tests.strings import ( + _convert_na_value, + object_pyarrow_numpy, +) # -------------------------------------------------------------------------------------- # str.contains @@ -780,9 +783,7 @@ def test_findall(any_string_dtype): ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype) result = ser.str.findall("BAD[_]*") expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]]) - if ser.dtype != object: - # GH#18463 - expected = expected.fillna(pd.NA) + expected = _convert_na_value(ser, expected) tm.assert_series_equal(result, expected) diff --git a/pandas/tests/strings/test_split_partition.py b/pandas/tests/strings/test_split_partition.py index 0298694ccaf71..0a7d409773dd6 100644 --- a/pandas/tests/strings/test_split_partition.py +++ b/pandas/tests/strings/test_split_partition.py @@ -12,6 +12,10 @@ Series, _testing as tm, ) +from pandas.tests.strings import ( + _convert_na_value, + object_pyarrow_numpy, +) @pytest.mark.parametrize("method", ["split", "rsplit"]) @@ -20,9 +24,7 @@ def test_split(any_string_dtype, method): result = getattr(values.str, method)("_") exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]]) - if values.dtype != object: - # GH#18463 - exp = exp.fillna(pd.NA) + exp = _convert_na_value(values, exp) tm.assert_series_equal(result, exp) @@ -32,9 +34,7 @@ def test_split_more_than_one_char(any_string_dtype, method): values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype) result = getattr(values.str, method)("__") exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]]) - if values.dtype != object: - # GH#18463 - exp = exp.fillna(pd.NA) + exp = _convert_na_value(values, exp) tm.assert_series_equal(result, exp) result = getattr(values.str, method)("__", expand=False) @@ -46,9 +46,7 @@ def test_split_more_regex_split(any_string_dtype): values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype) result = values.str.split("[,_]") exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]]) - if values.dtype != object: - # GH#18463 - exp = exp.fillna(pd.NA) + exp = _convert_na_value(values, exp) tm.assert_series_equal(result, exp) @@ -118,8 +116,8 @@ def test_split_object_mixed(expand, method): def test_split_n(any_string_dtype, method, n): s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype) expected = Series([["a", "b"], pd.NA, ["b", "c"]]) - result = getattr(s.str, method)(" ", n=n) + expected = _convert_na_value(s, expected) tm.assert_series_equal(result, expected) @@ -128,9 +126,7 @@ def test_rsplit(any_string_dtype): values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype) result = values.str.rsplit("[,_]") exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]]) - if values.dtype != object: - # GH#18463 - exp = exp.fillna(pd.NA) + exp = _convert_na_value(values, exp) tm.assert_series_equal(result, exp) @@ -139,9 +135,7 @@ def test_rsplit_max_number(any_string_dtype): values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype) result = values.str.rsplit("_", n=1) exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]]) - if values.dtype != object: - # GH#18463 - exp = exp.fillna(pd.NA) + exp = _convert_na_value(values, exp) tm.assert_series_equal(result, exp) @@ -390,7 +384,7 @@ def test_split_nan_expand(any_string_dtype): # check that these are actually np.nan/pd.NA and not None # TODO see GH 18463 # tm.assert_frame_equal does not differentiate - if any_string_dtype == "object": + if any_string_dtype in object_pyarrow_numpy: assert all(np.isnan(x) for x in result.iloc[1]) else: assert all(x is pd.NA for x in result.iloc[1]) @@ -455,9 +449,7 @@ def test_partition_series_more_than_one_char(method, exp, any_string_dtype): s = Series(["a__b__c", "c__d__e", np.nan, "f__g__h", None], dtype=any_string_dtype) result = getattr(s.str, method)("__", expand=False) expected = Series(exp) - if s.dtype != object: - # GH#18463 - expected = expected.fillna(pd.NA) + expected = _convert_na_value(s, expected) tm.assert_series_equal(result, expected) @@ -480,9 +472,7 @@ def test_partition_series_none(any_string_dtype, method, exp): s = Series(["a b c", "c d e", np.nan, "f g h", None], dtype=any_string_dtype) result = getattr(s.str, method)(expand=False) expected = Series(exp) - if s.dtype != object: - # GH#18463 - expected = expected.fillna(pd.NA) + expected = _convert_na_value(s, expected) tm.assert_series_equal(result, expected) @@ -505,9 +495,7 @@ def test_partition_series_not_split(any_string_dtype, method, exp): s = Series(["abc", "cde", np.nan, "fgh", None], dtype=any_string_dtype) result = getattr(s.str, method)("_", expand=False) expected = Series(exp) - if s.dtype != object: - # GH#18463 - expected = expected.fillna(pd.NA) + expected = _convert_na_value(s, expected) tm.assert_series_equal(result, expected) @@ -531,9 +519,7 @@ def test_partition_series_unicode(any_string_dtype, method, exp): result = getattr(s.str, method)("_", expand=False) expected = Series(exp) - if s.dtype != object: - # GH#18463 - expected = expected.fillna(pd.NA) + expected = _convert_na_value(s, expected) tm.assert_series_equal(result, expected)