BUG: Return numpy types from ArrowExtensionArray.to_numpy for temporal types when possible
mroeschke committed Dec 11, 2023
1 parent b0ffccd commit d1d6f06
Showing 6 changed files with 96 additions and 35 deletions.
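
The user-facing effect, in short: for tz-naive timestamp and duration Arrow types, ``to_numpy`` now returns a NumPy datetime64/timedelta64 array with ``NaT`` instead of an object array of ``pd.Timestamp``/``pd.Timedelta`` and ``pd.NA``. A minimal sketch, assuming a pandas build that includes this commit and pyarrow installed (outputs shown roughly, as comments):

import pandas as pd

ser = pd.Series([1, None], dtype="timestamp[ns][pyarrow]")
ser.to_numpy()
# before (roughly): array([Timestamp('1970-01-01 00:00:00.000000001'), <NA>], dtype=object)
# after (roughly):  array(['1970-01-01T00:00:00.000000001', 'NaT'], dtype='datetime64[ns]')
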
12 changes: 8 additions & 4 deletions doc/source/whatsnew/v2.2.0.rst
@@ -107,11 +107,11 @@ documentation.

.. _whatsnew_220.enhancements.to_numpy_ea:

ExtensionArray.to_numpy converts to suitable NumPy dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:meth:`ExtensionArray.to_numpy` will now convert to a suitable NumPy dtype instead
of ``object`` dtype for nullable extension dtypes.
``to_numpy`` for NumPy nullable and Arrow types will now convert to a
suitable NumPy dtype instead of ``object`` dtype for nullable extension dtypes.

*Old behavior:*

@@ -128,13 +128,17 @@ of ``object`` dtype for nullable extension dtypes.
ser = pd.Series([1, 2, 3], dtype="Int64")
ser.to_numpy()
ser = pd.Series([1, 2, 3], dtype="timestamp[ns][pyarrow]")
ser.to_numpy()
The default NumPy dtype (without any arguments) is determined as follows, as illustrated in the sketch after this list:

- float dtypes are cast to NumPy floats
- integer dtypes without missing values are cast to NumPy integer dtypes
- integer dtypes with missing values are cast to NumPy float dtypes and ``NaN`` is used as missing value indicator
- boolean dtypes without missing values are cast to NumPy bool dtype
- boolean dtypes with missing values keep object dtype
- datetime and timedelta types are cast to NumPy datetime64 and timedelta64 types respectively and ``NaT`` is used as missing value indicator
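
For illustration only (this sketch is not part of the diff), the rules above play out roughly like this for the NumPy nullable dtypes:

import pandas as pd

pd.Series([1, 2, 3], dtype="Int64").to_numpy().dtype       # int64 (no missing values)
pd.Series([1, None], dtype="Int64").to_numpy()             # float64 array, NaN marks the gap
pd.Series([1.5, None], dtype="Float64").to_numpy().dtype   # float64
pd.Series([True, None], dtype="boolean").to_numpy().dtype  # object (missing value present)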

.. _whatsnew_220.enhancements.struct_accessor:

29 changes: 24 additions & 5 deletions pandas/core/arrays/arrow/array.py
@@ -47,6 +47,7 @@
from pandas.core import (
algorithms as algos,
missing,
ops,
roperator,
)
from pandas.core.arraylike import OpsMixin
@@ -655,7 +656,11 @@ def _cmp_method(self, other, op):
mask = isna(self) | isna(other)
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
np_array = np.array(self)
try:
result[valid] = op(np_array[valid], other)
except TypeError:
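# e.g. the scalar is not comparable to the array's values; invalid_comparison
# returns all-False for ==, all-True for !=, and raises TypeError otherwise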
result = ops.invalid_comparison(np_array, other, op)
result = pa.array(result, type=pa.bool_())
result = pc.if_else(valid, result, None)
else:
@@ -1130,7 +1135,16 @@ def searchsorted(
if isinstance(value, ExtensionArray):
value = value.astype(object)
# Base class searchsorted would cast to object, which is *much* slower.
return self.to_numpy().searchsorted(value, side=side, sorter=sorter)
dtype = None
if isinstance(self.dtype, ArrowDtype):
pa_dtype = self.dtype.pyarrow_dtype
if (
pa.types.is_timestamp(pa_dtype) or pa.types.is_duration(pa_dtype)
) and pa_dtype.unit == "ns":
# np.array[datetime/timedelta].searchsorted(datetime/timedelta)
# erroneously fails when numpy type resolution is nanoseconds
dtype = object
return self.to_numpy(dtype=dtype).searchsorted(value, side=side, sorter=sorter)
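
A rough usage sketch of the path this guards (illustrative only, assuming pyarrow is installed): for nanosecond units the search now runs on an object ndarray of Timestamps/Timedeltas rather than on datetime64[ns]/timedelta64[ns].

import pandas as pd

ser = pd.Series(
    pd.to_datetime(["2023-01-01", "2023-01-03"]), dtype="timestamp[ns][pyarrow]"
)
ser.array.searchsorted(pd.Timestamp("2023-01-02"))  # -> 1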

def take(
self,
@@ -1281,10 +1295,15 @@ def to_numpy(

if pa.types.is_timestamp(pa_type) or pa.types.is_duration(pa_type):
result = data._maybe_convert_datelike_array()
if dtype is None or dtype.kind == "O":
result = result.to_numpy(dtype=object, na_value=na_value)
if (pa.types.is_timestamp(pa_type) and pa_type.tz is not None) or (
dtype is not None and dtype.kind == "O"
):
dtype = object
else:
result = result.to_numpy(dtype=dtype)
# GH 55997
dtype = None
na_value = pa_type.to_pandas_dtype().type("nat", pa_type.unit)
result = result.to_numpy(dtype=dtype, na_value=na_value)
elif pa.types.is_time(pa_type) or pa.types.is_date(pa_type):
# convert to list of python datetime.time objects before
# wrapping in ndarray
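
A short sketch of what the reworked branch produces, mirroring the test added further down (illustrative only; assumes pyarrow is installed):

import pyarrow as pa
from pandas.arrays import ArrowExtensionArray

arr = ArrowExtensionArray(pa.array([1, None], type=pa.timestamp("ns")))
arr.to_numpy()              # datetime64[ns] ndarray; the missing entry is np.datetime64("NaT")
arr.to_numpy(dtype=object)  # object ndarray of pd.Timestamp and pd.NA, as before

tz_arr = ArrowExtensionArray(pa.array([1, None], type=pa.timestamp("ns", tz="UTC")))
tz_arr.to_numpy()           # tz-aware values cannot be represented as datetime64, so this stays object
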
5 changes: 1 addition & 4 deletions pandas/io/formats/format.py
@@ -61,7 +61,6 @@
)

from pandas.core.arrays import (
BaseMaskedArray,
Categorical,
DatetimeArray,
ExtensionArray,
@@ -1528,10 +1527,8 @@ def _format_strings(self) -> list[str]:
if isinstance(values, Categorical):
# Categorical is special for now, so that we can preserve tzinfo
array = values._internal_get_values()
elif isinstance(values, BaseMaskedArray):
array = values.to_numpy(dtype=object)
else:
array = np.asarray(values)
array = np.asarray(values, dtype=object)

fmt_values = format_array(
array,
2 changes: 1 addition & 1 deletion pandas/tests/arrays/string_/test_string.py
@@ -263,7 +263,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
other = 42

if op_name not in ["__eq__", "__ne__"]:
with pytest.raises(TypeError, match="not supported between"):
with pytest.raises(TypeError, match="Invalid comparison|not supported between"):
getattr(a, op_name)(other)

return
81 changes: 61 additions & 20 deletions pandas/tests/extension/test_arrow.py
@@ -41,6 +41,7 @@
pa_version_under13p0,
pa_version_under14p0,
)
import pandas.util._test_decorators as td

from pandas.core.dtypes.dtypes import (
ArrowDtype,
@@ -266,6 +267,19 @@ def data_for_twos(data):


class TestArrowArray(base.ExtensionTests):
def test_compare_scalar(self, data, comparison_op):
ser = pd.Series(data)
self._compare_other(ser, data, comparison_op, data[0])

@pytest.mark.parametrize("na_action", [None, "ignore"])
def test_map(self, data_missing, na_action):
if data_missing.dtype.kind in "mM":
result = data_missing.map(lambda x: x, na_action=na_action)
expected = data_missing.to_numpy(dtype=object)
tm.assert_numpy_array_equal(result, expected)
else:
super().test_map(data_missing, na_action)

def test_astype_str(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_binary(pa_dtype):
@@ -274,8 +288,35 @@ def test_astype_str(self, data, request):
reason=f"For {pa_dtype} .astype(str) decodes.",
)
)
elif (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
) or pa.types.is_duration(pa_dtype):
request.applymarker(
pytest.mark.xfail(
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
)
)
super().test_astype_str(data)

@pytest.mark.parametrize(
"nullable_string_dtype",
[
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
],
)
def test_astype_string(self, data, nullable_string_dtype, request):
pa_dtype = data.dtype.pyarrow_dtype
if (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
) or pa.types.is_duration(pa_dtype):
request.applymarker(
pytest.mark.xfail(
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
)
)
super().test_astype_string(data, nullable_string_dtype)

def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
@@ -1511,11 +1552,9 @@ def test_to_numpy_with_defaults(data):
result = data.to_numpy()

pa_type = data._pa_array.type
if (
pa.types.is_duration(pa_type)
or pa.types.is_timestamp(pa_type)
or pa.types.is_date(pa_type)
):
if pa.types.is_duration(pa_type) or pa.types.is_timestamp(pa_type):
pytest.skip("Tested in test_to_numpy_temporal")
elif pa.types.is_date(pa_type):
expected = np.array(list(data))
else:
expected = np.array(data._pa_array)
@@ -2937,26 +2976,28 @@ def test_groupby_series_size_returns_pa_int(data):


@pytest.mark.parametrize(
"pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES
"pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES, ids=repr
)
def test_to_numpy_temporal(pa_type):
@pytest.mark.parametrize("dtype", [None, object])
def test_to_numpy_temporal(pa_type, dtype):
# GH 53326
# GH 55997: Return datetime64/timedelta64 types with NaT if possible
arr = ArrowExtensionArray(pa.array([1, None], type=pa_type))
result = arr.to_numpy()
result = arr.to_numpy(dtype=dtype)
if pa.types.is_duration(pa_type):
expected = [
pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit),
pd.NA,
]
assert isinstance(result[0], pd.Timedelta)
value = pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit)
else:
expected = [
pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit),
pd.NA,
]
assert isinstance(result[0], pd.Timestamp)
expected = np.array(expected, dtype=object)
assert result[0].unit == expected[0].unit
value = pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit)

if dtype == object or (pa.types.is_timestamp(pa_type) and pa_type.tz is not None):
na = pd.NA
expected = np.array([value, na], dtype=object)
assert result[0].unit == value.unit
else:
na = pa_type.to_pandas_dtype().type("nat", pa_type.unit)
value = value.to_numpy()
expected = np.array([value, na])
assert np.datetime_data(result[0])[0] == pa_type.unit
tm.assert_numpy_array_equal(result, expected)


2 changes: 1 addition & 1 deletion pandas/tests/io/formats/test_format.py
@@ -1921,7 +1921,7 @@ def dtype(self):
series = Series(ExtTypeStub(), copy=False)
res = repr(series) # This line crashed before #33770 was fixed.
expected = "\n".join(
["0 [False True]", "1 [ True False]", "dtype: DtypeStub"]
["0 [False True]", "1 [True False]", "dtype: DtypeStub"]
)
assert res == expected
