From d1b2c44761e96cbab6413897119f113dc7d657b9 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 1 Nov 2023 15:18:22 -0700 Subject: [PATCH] ENH/POC: infer resolution in array_strptime (#55778) * ENH/POC: infer resolution in array_strptime * creso_changed->creso_ever_changed --- pandas/_libs/tslibs/strptime.pxd | 3 + pandas/_libs/tslibs/strptime.pyx | 103 +++++++++++++++++++++++---- pandas/tests/tslibs/test_strptime.py | 61 ++++++++++++++++ 3 files changed, 152 insertions(+), 15 deletions(-) create mode 100644 pandas/tests/tslibs/test_strptime.py diff --git a/pandas/_libs/tslibs/strptime.pxd b/pandas/_libs/tslibs/strptime.pxd index 32a8edc9ee4a3..64db2b59dfcff 100644 --- a/pandas/_libs/tslibs/strptime.pxd +++ b/pandas/_libs/tslibs/strptime.pxd @@ -14,5 +14,8 @@ cdef class DatetimeParseState: cdef: bint found_tz bint found_naive + bint creso_ever_changed + NPY_DATETIMEUNIT creso cdef tzinfo process_datetime(self, datetime dt, tzinfo tz, bint utc_convert) + cdef bint update_creso(self, NPY_DATETIMEUNIT item_reso) noexcept diff --git a/pandas/_libs/tslibs/strptime.pyx b/pandas/_libs/tslibs/strptime.pyx index 69466511a67fd..6b744ce4c8cdb 100644 --- a/pandas/_libs/tslibs/strptime.pyx +++ b/pandas/_libs/tslibs/strptime.pyx @@ -49,6 +49,10 @@ from numpy cimport ( from pandas._libs.missing cimport checknull_with_nat_and_na from pandas._libs.tslibs.conversion cimport get_datetime64_nanos +from pandas._libs.tslibs.dtypes cimport ( + get_supported_reso, + npy_unit_to_abbrev, +) from pandas._libs.tslibs.nattype cimport ( NPY_NAT, c_nat_strings as nat_strings, @@ -57,6 +61,7 @@ from pandas._libs.tslibs.np_datetime cimport ( NPY_DATETIMEUNIT, NPY_FR_ns, check_dts_bounds, + get_datetime64_unit, import_pandas_datetime, npy_datetimestruct, npy_datetimestruct_to_datetime, @@ -232,9 +237,21 @@ cdef _get_format_regex(str fmt): cdef class DatetimeParseState: - def __cinit__(self): + def __cinit__(self, NPY_DATETIMEUNIT creso=NPY_DATETIMEUNIT.NPY_FR_ns): self.found_tz = False self.found_naive = False + self.creso = creso + self.creso_ever_changed = False + + cdef bint update_creso(self, NPY_DATETIMEUNIT item_reso) noexcept: + # Return a bool indicating whether we bumped to a higher resolution + if self.creso == NPY_DATETIMEUNIT.NPY_FR_GENERIC: + self.creso = item_reso + elif item_reso > self.creso: + self.creso = item_reso + self.creso_ever_changed = True + return True + return False cdef tzinfo process_datetime(self, datetime dt, tzinfo tz, bint utc_convert): if dt.tzinfo is not None: @@ -268,6 +285,7 @@ def array_strptime( bint exact=True, errors="raise", bint utc=False, + NPY_DATETIMEUNIT creso=NPY_FR_ns, ): """ Calculates the datetime structs represented by the passed array of strings @@ -278,6 +296,8 @@ def array_strptime( fmt : string-like regex exact : matches must be exact if True, search if False errors : string specifying error handling, {'raise', 'ignore', 'coerce'} + creso : NPY_DATETIMEUNIT, default NPY_FR_ns + Set to NPY_FR_GENERIC to infer a resolution. """ cdef: @@ -291,17 +311,22 @@ def array_strptime( bint is_coerce = errors=="coerce" tzinfo tz_out = None bint iso_format = format_is_iso(fmt) - NPY_DATETIMEUNIT out_bestunit + NPY_DATETIMEUNIT out_bestunit, item_reso int out_local = 0, out_tzoffset = 0 bint string_to_dts_succeeded = 0 - DatetimeParseState state = DatetimeParseState() + bint infer_reso = creso == NPY_DATETIMEUNIT.NPY_FR_GENERIC + DatetimeParseState state = DatetimeParseState(creso) assert is_raise or is_ignore or is_coerce _validate_fmt(fmt) format_regex, locale_time = _get_format_regex(fmt) - result = np.empty(n, dtype="M8[ns]") + if infer_reso: + abbrev = "ns" + else: + abbrev = npy_unit_to_abbrev(creso) + result = np.empty(n, dtype=f"M8[{abbrev}]") iresult = result.view("i8") result_timezone = np.empty(n, dtype="object") @@ -318,20 +343,32 @@ def array_strptime( iresult[i] = NPY_NAT continue elif PyDateTime_Check(val): + if isinstance(val, _Timestamp): + item_reso = val._creso + else: + item_reso = NPY_DATETIMEUNIT.NPY_FR_us + state.update_creso(item_reso) tz_out = state.process_datetime(val, tz_out, utc) if isinstance(val, _Timestamp): - iresult[i] = val.tz_localize(None).as_unit("ns")._value + val = (<_Timestamp>val)._as_creso(state.creso) + iresult[i] = val.tz_localize(None)._value else: - iresult[i] = pydatetime_to_dt64(val.replace(tzinfo=None), &dts) - check_dts_bounds(&dts) + iresult[i] = pydatetime_to_dt64( + val.replace(tzinfo=None), &dts, reso=state.creso + ) + check_dts_bounds(&dts, state.creso) result_timezone[i] = val.tzinfo continue elif PyDate_Check(val): - iresult[i] = pydate_to_dt64(val, &dts) - check_dts_bounds(&dts) + item_reso = NPY_DATETIMEUNIT.NPY_FR_s + state.update_creso(item_reso) + iresult[i] = pydate_to_dt64(val, &dts, reso=state.creso) + check_dts_bounds(&dts, state.creso) continue elif is_datetime64_object(val): - iresult[i] = get_datetime64_nanos(val, NPY_FR_ns) + item_reso = get_supported_reso(get_datetime64_unit(val)) + state.update_creso(item_reso) + iresult[i] = get_datetime64_nanos(val, state.creso) continue elif ( (is_integer_object(val) or is_float_object(val)) @@ -355,7 +392,9 @@ def array_strptime( if string_to_dts_succeeded: # No error reported by string_to_dts, pick back up # where we left off - value = npy_datetimestruct_to_datetime(NPY_FR_ns, &dts) + item_reso = get_supported_reso(out_bestunit) + state.update_creso(item_reso) + value = npy_datetimestruct_to_datetime(state.creso, &dts) if out_local == 1: # Store the out_tzoffset in seconds # since we store the total_seconds of @@ -368,7 +407,9 @@ def array_strptime( check_dts_bounds(&dts) continue - if parse_today_now(val, &iresult[i], utc, NPY_FR_ns): + if parse_today_now(val, &iresult[i], utc, state.creso): + item_reso = NPY_DATETIMEUNIT.NPY_FR_us + state.update_creso(item_reso) continue # Some ISO formats can't be parsed by string_to_dts @@ -380,9 +421,10 @@ def array_strptime( raise ValueError(f"Time data {val} is not ISO8601 format") tz = _parse_with_format( - val, fmt, exact, format_regex, locale_time, &dts + val, fmt, exact, format_regex, locale_time, &dts, &item_reso ) - iresult[i] = npy_datetimestruct_to_datetime(NPY_FR_ns, &dts) + state.update_creso(item_reso) + iresult[i] = npy_datetimestruct_to_datetime(state.creso, &dts) check_dts_bounds(&dts) result_timezone[i] = tz @@ -403,11 +445,34 @@ def array_strptime( raise return values, [] + if infer_reso: + if state.creso_ever_changed: + # We encountered mismatched resolutions, need to re-parse with + # the correct one. + return array_strptime( + values, + fmt=fmt, + exact=exact, + errors=errors, + utc=utc, + creso=state.creso, + ) + + # Otherwise we can use the single reso that we encountered and avoid + # a second pass. + abbrev = npy_unit_to_abbrev(state.creso) + result = iresult.base.view(f"M8[{abbrev}]") return result, result_timezone.base cdef tzinfo _parse_with_format( - str val, str fmt, bint exact, format_regex, locale_time, npy_datetimestruct* dts + str val, + str fmt, + bint exact, + format_regex, + locale_time, + npy_datetimestruct* dts, + NPY_DATETIMEUNIT* item_reso, ): # Based on https://github.com/python/cpython/blob/main/Lib/_strptime.py#L293 cdef: @@ -441,6 +506,8 @@ cdef tzinfo _parse_with_format( f"time data \"{val}\" doesn't match format \"{fmt}\"" ) + item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_s + iso_year = -1 year = 1900 month = day = 1 @@ -527,6 +594,12 @@ cdef tzinfo _parse_with_format( elif parse_code == 10: # e.g. val='10:10:10.100'; fmt='%H:%M:%S.%f' s = found_dict["f"] + if len(s) <= 3: + item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_ms + elif len(s) <= 6: + item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_us + else: + item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_ns # Pad to always return nanoseconds s += "0" * (9 - len(s)) us = long(s) diff --git a/pandas/tests/tslibs/test_strptime.py b/pandas/tests/tslibs/test_strptime.py new file mode 100644 index 0000000000000..0992eecf0eedd --- /dev/null +++ b/pandas/tests/tslibs/test_strptime.py @@ -0,0 +1,61 @@ +from datetime import ( + datetime, + timezone, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit +from pandas._libs.tslibs.strptime import array_strptime + +from pandas import Timestamp +import pandas._testing as tm + +creso_infer = NpyDatetimeUnit.NPY_FR_GENERIC.value + + +class TestArrayStrptimeResolutionInference: + @pytest.mark.parametrize("tz", [None, timezone.utc]) + def test_array_strptime_resolution_inference_homogeneous_strings(self, tz): + dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz) + + fmt = "%Y-%m-%d %H:%M:%S" + dtstr = dt.strftime(fmt) + arr = np.array([dtstr] * 3, dtype=object) + expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[s]") + + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "%Y-%m-%d %H:%M:%S.%f" + dtstr = dt.strftime(fmt) + arr = np.array([dtstr] * 3, dtype=object) + expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[us]") + + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "ISO8601" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + @pytest.mark.parametrize("tz", [None, timezone.utc]) + def test_array_strptime_resolution_mixed(self, tz): + dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz) + + ts = Timestamp(dt).as_unit("ns") + + arr = np.array([dt, ts], dtype=object) + expected = np.array( + [Timestamp(dt).as_unit("ns").asm8, ts.asm8], + dtype="M8[ns]", + ) + + fmt = "%Y-%m-%d %H:%M:%S" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "ISO8601" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected)