Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Return numpy types from ArrowExtensionArray.to_numpy for temporal types when possible #56459

Merged
merged 3 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ documentation.

.. _whatsnew_220.enhancements.to_numpy_ea:

ExtensionArray.to_numpy converts to suitable NumPy dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:meth:`ExtensionArray.to_numpy` will now convert to a suitable NumPy dtype instead
of ``object`` dtype for nullable extension dtypes.
``to_numpy`` for NumPy nullable and Arrow types will now convert to a
suitable NumPy dtype instead of ``object`` dtype for nullable extension dtypes.

*Old behavior:*

Expand All @@ -128,13 +128,17 @@ of ``object`` dtype for nullable extension dtypes.
ser = pd.Series([1, 2, 3], dtype="Int64")
ser.to_numpy()

ser = pd.Series([1, 2, 3], dtype="timestamp[ns][pyarrow]")
ser.to_numpy()

The default NumPy dtype (without any arguments) is determined as follows:

- float dtypes are cast to NumPy floats
- integer dtypes without missing values are cast to NumPy integer dtypes
- integer dtypes with missing values are cast to NumPy float dtypes and ``NaN`` is used as missing value indicator
- boolean dtypes without missing values are cast to NumPy bool dtype
- boolean dtypes with missing values keep object dtype
- datetime and timedelta types are cast to Numpy datetime64 and timedelta64 types respectively and ``NaT`` is used as missing value indicator

.. _whatsnew_220.enhancements.struct_accessor:

Expand Down
29 changes: 24 additions & 5 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pandas.core import (
algorithms as algos,
missing,
ops,
roperator,
)
from pandas.core.arraylike import OpsMixin
Expand Down Expand Up @@ -660,7 +661,11 @@ def _cmp_method(self, other, op):
mask = isna(self) | isna(other)
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
np_array = np.array(self)
try:
result[valid] = op(np_array[valid], other)
except TypeError:
result = ops.invalid_comparison(np_array, other, op)
result = pa.array(result, type=pa.bool_())
result = pc.if_else(valid, result, None)
else:
Expand Down Expand Up @@ -1135,7 +1140,16 @@ def searchsorted(
if isinstance(value, ExtensionArray):
value = value.astype(object)
# Base class searchsorted would cast to object, which is *much* slower.
return self.to_numpy().searchsorted(value, side=side, sorter=sorter)
dtype = None
if isinstance(self.dtype, ArrowDtype):
pa_dtype = self.dtype.pyarrow_dtype
if (
pa.types.is_timestamp(pa_dtype) or pa.types.is_duration(pa_dtype)
) and pa_dtype.unit == "ns":
# np.array[datetime/timedelta].searchsorted(datetime/timedelta)
# erroneously fails when numpy type resolution is nanoseconds
dtype = object
return self.to_numpy(dtype=dtype).searchsorted(value, side=side, sorter=sorter)

def take(
self,
Expand Down Expand Up @@ -1286,10 +1300,15 @@ def to_numpy(

if pa.types.is_timestamp(pa_type) or pa.types.is_duration(pa_type):
result = data._maybe_convert_datelike_array()
if dtype is None or dtype.kind == "O":
result = result.to_numpy(dtype=object, na_value=na_value)
if (pa.types.is_timestamp(pa_type) and pa_type.tz is not None) or (
dtype is not None and dtype.kind == "O"
):
dtype = object
else:
result = result.to_numpy(dtype=dtype)
# GH 55997
dtype = None
na_value = pa_type.to_pandas_dtype().type("nat", pa_type.unit)
result = result.to_numpy(dtype=dtype, na_value=na_value)
elif pa.types.is_time(pa_type) or pa.types.is_date(pa_type):
# convert to list of python datetime.time objects before
# wrapping in ndarray
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/arrays/sparse/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def from_spmatrix(cls, data, index=None, columns=None) -> DataFrame:
>>> mat = scipy.sparse.eye(3, dtype=float)
>>> pd.DataFrame.sparse.from_spmatrix(mat)
0 1 2
0 1.0 0.0 0.0
1 0.0 1.0 0.0
2 0.0 0.0 1.0
0 1.0 0 0
1 0 1.0 0
2 0 0 1.0
"""
from pandas._libs.sparse import IntIndex

Expand Down
5 changes: 1 addition & 4 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
)

from pandas.core.arrays import (
BaseMaskedArray,
Categorical,
DatetimeArray,
ExtensionArray,
Expand Down Expand Up @@ -1528,10 +1527,8 @@ def _format_strings(self) -> list[str]:
if isinstance(values, Categorical):
# Categorical is special for now, so that we can preserve tzinfo
array = values._internal_get_values()
elif isinstance(values, BaseMaskedArray):
array = values.to_numpy(dtype=object)
else:
array = np.asarray(values)
array = np.asarray(values, dtype=object)

fmt_values = format_array(
array,
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
other = 42

if op_name not in ["__eq__", "__ne__"]:
with pytest.raises(TypeError, match="not supported between"):
with pytest.raises(TypeError, match="Invalid comparison|not supported between"):
getattr(a, op_name)(other)

return
Expand Down
81 changes: 61 additions & 20 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
pa_version_under13p0,
pa_version_under14p0,
)
import pandas.util._test_decorators as td

from pandas.core.dtypes.dtypes import (
ArrowDtype,
Expand Down Expand Up @@ -266,6 +267,19 @@ def data_for_twos(data):


class TestArrowArray(base.ExtensionTests):
def test_compare_scalar(self, data, comparison_op):
ser = pd.Series(data)
self._compare_other(ser, data, comparison_op, data[0])

@pytest.mark.parametrize("na_action", [None, "ignore"])
def test_map(self, data_missing, na_action):
if data_missing.dtype.kind in "mM":
result = data_missing.map(lambda x: x, na_action=na_action)
expected = data_missing.to_numpy(dtype=object)
tm.assert_numpy_array_equal(result, expected)
else:
super().test_map(data_missing, na_action)

def test_astype_str(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_binary(pa_dtype):
Expand All @@ -274,8 +288,35 @@ def test_astype_str(self, data, request):
reason=f"For {pa_dtype} .astype(str) decodes.",
)
)
elif (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
) or pa.types.is_duration(pa_dtype):
request.applymarker(
pytest.mark.xfail(
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
)
)
super().test_astype_str(data)

@pytest.mark.parametrize(
"nullable_string_dtype",
[
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
],
)
def test_astype_string(self, data, nullable_string_dtype, request):
pa_dtype = data.dtype.pyarrow_dtype
if (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
) or pa.types.is_duration(pa_dtype):
request.applymarker(
pytest.mark.xfail(
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
)
)
super().test_astype_string(data, nullable_string_dtype)

def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
Expand Down Expand Up @@ -1511,11 +1552,9 @@ def test_to_numpy_with_defaults(data):
result = data.to_numpy()

pa_type = data._pa_array.type
if (
pa.types.is_duration(pa_type)
or pa.types.is_timestamp(pa_type)
or pa.types.is_date(pa_type)
):
if pa.types.is_duration(pa_type) or pa.types.is_timestamp(pa_type):
pytest.skip("Tested in test_to_numpy_temporal")
elif pa.types.is_date(pa_type):
expected = np.array(list(data))
else:
expected = np.array(data._pa_array)
Expand Down Expand Up @@ -2937,26 +2976,28 @@ def test_groupby_series_size_returns_pa_int(data):


@pytest.mark.parametrize(
"pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES
"pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES, ids=repr
)
def test_to_numpy_temporal(pa_type):
@pytest.mark.parametrize("dtype", [None, object])
def test_to_numpy_temporal(pa_type, dtype):
# GH 53326
# GH 55997: Return datetime64/timedelta64 types with NaT if possible
arr = ArrowExtensionArray(pa.array([1, None], type=pa_type))
result = arr.to_numpy()
result = arr.to_numpy(dtype=dtype)
if pa.types.is_duration(pa_type):
expected = [
pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit),
pd.NA,
]
assert isinstance(result[0], pd.Timedelta)
value = pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit)
else:
expected = [
pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit),
pd.NA,
]
assert isinstance(result[0], pd.Timestamp)
expected = np.array(expected, dtype=object)
assert result[0].unit == expected[0].unit
value = pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit)

if dtype == object or (pa.types.is_timestamp(pa_type) and pa_type.tz is not None):
na = pd.NA
expected = np.array([value, na], dtype=object)
assert result[0].unit == value.unit
else:
na = pa_type.to_pandas_dtype().type("nat", pa_type.unit)
value = value.to_numpy()
expected = np.array([value, na])
assert np.datetime_data(result[0])[0] == pa_type.unit
tm.assert_numpy_array_equal(result, expected)


Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/io/formats/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,7 @@ def dtype(self):
series = Series(ExtTypeStub(), copy=False)
res = repr(series) # This line crashed before #33770 was fixed.
expected = "\n".join(
["0 [False True]", "1 [ True False]", "dtype: DtypeStub"]
["0 [False True]", "1 [True False]", "dtype: DtypeStub"]
)
assert res == expected

Expand Down