Skip to content

Commit

Permalink
API (string): return str dtype for .dt methods, DatetimeIndex methods (#59526)
Browse files Browse the repository at this point in the history

* API (string): return str dtype for .dt methods, DatetimeIndex methods

* mypy fixup
  • Loading branch information
jbrockmendel authored and jorisvandenbossche committed Oct 9, 2024
1 parent be6354b commit 3bb9ae6
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 16 deletions.
6 changes: 6 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import numpy as np

from pandas._config import using_string_dtype

from pandas._libs import (
algos,
lib,
Expand Down Expand Up @@ -1789,6 +1791,10 @@ def strftime(self, date_format: str) -> npt.NDArray[np.object_]:
dtype='object')
"""
result = self._format_native_types(date_format=date_format, na_rep=np.nan)
if using_string_dtype():
from pandas import StringDtype

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result.astype(object, copy=False)


Expand Down
17 changes: 17 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import numpy as np

from pandas._config import using_string_dtype

from pandas._libs import (
lib,
tslib,
Expand Down Expand Up @@ -1306,6 +1308,13 @@ def month_name(self, locale=None) -> npt.NDArray[np.object_]:
values, "month_name", locale=locale, reso=self._creso
)
result = self._maybe_mask_results(result, fill_value=None)
if using_string_dtype():
from pandas import (
StringDtype,
array as pd_array,
)

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result

def day_name(self, locale=None) -> npt.NDArray[np.object_]:
Expand Down Expand Up @@ -1363,6 +1372,14 @@ def day_name(self, locale=None) -> npt.NDArray[np.object_]:
values, "day_name", locale=locale, reso=self._creso
)
result = self._maybe_mask_results(result, fill_value=None)
if using_string_dtype():
# TODO: no tests that check for dtype of result as of 2024-08-15
from pandas import (
StringDtype,
array as pd_array,
)

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result

@property
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _engine_type(self) -> type[libindex.DatetimeEngine]:
@doc(DatetimeArray.strftime)
def strftime(self, date_format) -> Index:
arr = self._data.strftime(date_format)
return Index(arr, name=self.name, dtype=object)
return Index(arr, name=self.name, dtype=arr.dtype)

@doc(DatetimeArray.tz_convert)
def tz_convert(self, tz) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def fget(self):
return type(self)._simple_new(result, name=self.name)
elif isinstance(result, ABCDataFrame):
return result.set_index(self)
return Index(result, name=self.name)
return Index(result, name=self.name, dtype=result.dtype)
return result

def fset(self, value) -> None:
Expand All @@ -98,7 +98,7 @@ def method(self, *args, **kwargs): # type: ignore[misc]
return type(self)._simple_new(result, name=self.name)
elif isinstance(result, ABCDataFrame):
return result.set_index(self)
return Index(result, name=self.name)
return Index(result, name=self.name, dtype=result.dtype)
return result

# error: "property" has no attribute "__name__"
Expand Down
24 changes: 16 additions & 8 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,20 +889,24 @@ def test_concat_same_type_different_freq(self, unit):

tm.assert_datetime_array_equal(result, expected)

def test_strftime(self, arr1d):
def test_strftime(self, arr1d, using_infer_string):
arr = arr1d

result = arr.strftime("%Y %b")
expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)

def test_strftime_nat(self):
def test_strftime_nat(self, using_infer_string):
# GH 29578
arr = DatetimeIndex(["2019-01-01", NaT])._data

result = arr.strftime("%Y-%m-%d")
expected = np.array(["2019-01-01", np.nan], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)


class TestTimedeltaArray(SharedTests):
Expand Down Expand Up @@ -1159,20 +1163,24 @@ def test_array_interface(self, arr1d):
expected = np.asarray(arr).astype("S20")
tm.assert_numpy_array_equal(result, expected)

def test_strftime(self, arr1d):
def test_strftime(self, arr1d, using_infer_string):
arr = arr1d

result = arr.strftime("%Y")
expected = np.array([per.strftime("%Y") for per in arr], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)

def test_strftime_nat(self):
def test_strftime_nat(self, using_infer_string):
# GH 29578
arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]"))

result = arr.strftime("%Y-%m-%d")
expected = np.array(["2019-01-01", np.nan], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)


@pytest.mark.parametrize(
Expand Down
8 changes: 3 additions & 5 deletions pandas/tests/series/accessors/test_dt_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Period,
PeriodIndex,
Series,
StringDtype,
TimedeltaIndex,
date_range,
period_range,
Expand Down Expand Up @@ -528,7 +529,6 @@ def test_dt_accessor_datetime_name_accessors(self, time_locale):
ser = pd.concat([ser, Series([pd.NaT])])
assert np.isnan(ser.dt.month_name(locale=time_locale).iloc[-1])

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_strftime(self):
# GH 10086
ser = Series(date_range("20130101", periods=5))
Expand Down Expand Up @@ -599,10 +599,9 @@ def test_strftime_period_days(self, using_infer_string):
dtype="=U10",
)
if using_infer_string:
expected = expected.astype("str")
expected = expected.astype(StringDtype(na_value=np.nan))
tm.assert_index_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_strftime_dt64_microsecond_resolution(self):
ser = Series([datetime(2013, 1, 1, 2, 32, 59), datetime(2013, 1, 2, 14, 32, 1)])
result = ser.dt.strftime("%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -635,7 +634,6 @@ def test_strftime_period_minutes(self):
)
tm.assert_series_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize(
"data",
[
Expand All @@ -658,7 +656,7 @@ def test_strftime_all_nat(self, data):
ser = Series(data)
with tm.assert_produces_warning(None):
result = ser.dt.strftime("%Y-%m-%d")
expected = Series([np.nan], dtype=object)
expected = Series([np.nan], dtype="str")
tm.assert_series_equal(result, expected)

def test_valid_dt_with_missing_values(self):
Expand Down

0 comments on commit 3bb9ae6

Please sign in to comment.