From 4145278fa169a79fb44db15c563091e695c54d64 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 4 Oct 2023 11:18:01 -0700 Subject: [PATCH] TYP: format.formats (#55393) * TYP: format.formats * Fix MultiIndex._values --- pandas/core/arrays/categorical.py | 16 +++++---- pandas/core/arrays/timedeltas.py | 2 +- pandas/core/indexes/multi.py | 6 ++-- pandas/io/formats/format.py | 57 +++++++++++++++++-------------- 4 files changed, 44 insertions(+), 37 deletions(-) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 8d2633c10b428..4152d2a50ea63 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -1819,7 +1819,7 @@ def _empty( # type: ignore[override] return arr._from_backing_data(backing) - def _internal_get_values(self): + def _internal_get_values(self) -> ArrayLike: """ Return the values. @@ -1827,15 +1827,19 @@ def _internal_get_values(self): Returns ------- - np.ndarray or Index - A numpy array of the same dtype as categorical.categories.dtype or - Index if datetime / periods. + np.ndarray or ExtensionArray + A numpy array or ExtensionArray of the same dtype as + categorical.categories.dtype. """ # if we are a datetime and period index, return Index to keep metadata if needs_i8_conversion(self.categories.dtype): - return self.categories.take(self._codes, fill_value=NaT) + return self.categories.take(self._codes, fill_value=NaT)._values elif is_integer_dtype(self.categories.dtype) and -1 in self._codes: - return self.categories.astype("object").take(self._codes, fill_value=np.nan) + return ( + self.categories.astype("object") + .take(self._codes, fill_value=np.nan) + ._values + ) return np.array(self) def check_for_ordered(self, op) -> None: diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index b7b81b8271106..931a220d7ab29 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -471,7 +471,7 @@ def _format_native_types( from pandas.io.formats.format import get_format_timedelta64 # Relies on TimeDelta._repr_base - formatter = get_format_timedelta64(self._ndarray, na_rep) + formatter = get_format_timedelta64(self, na_rep) # equiv: np.array([formatter(x) for x in self._ndarray]) # but independent of dimension return np.frompyfunc(formatter, 1, 1)(self._ndarray) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 041ef2d742c16..8955329a7afbf 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -73,9 +73,7 @@ ) from pandas.core.dtypes.generic import ( ABCDataFrame, - ABCDatetimeIndex, ABCSeries, - ABCTimedeltaIndex, ) from pandas.core.dtypes.inference import is_array_like from pandas.core.dtypes.missing import ( @@ -768,8 +766,8 @@ def _values(self) -> np.ndarray: vals = cast("CategoricalIndex", vals) vals = vals._data._internal_get_values() - if isinstance(vals.dtype, ExtensionDtype) or isinstance( - vals, (ABCDatetimeIndex, ABCTimedeltaIndex) + if isinstance(vals.dtype, ExtensionDtype) or lib.is_np_dtype( + vals.dtype, "mM" ): vals = vals.astype(object) diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index 922d0f37bee3a..8b65a09ee5ac5 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -72,6 +72,7 @@ from pandas.core.arrays import ( Categorical, DatetimeArray, + ExtensionArray, TimedeltaArray, ) from pandas.core.arrays.string_ import StringDtype @@ -108,6 +109,7 @@ SequenceNotStr, StorageOptions, WriteBuffer, + npt, ) from pandas import ( @@ -1216,7 +1218,7 @@ def get_buffer( def format_array( - values: Any, + values: ArrayLike, formatter: Callable | None, float_format: FloatFormatType | None = None, na_rep: str = "NaN", @@ -1233,7 +1235,7 @@ def format_array( Parameters ---------- - values + values : np.ndarray or ExtensionArray formatter float_format na_rep @@ -1258,10 +1260,13 @@ def format_array( fmt_klass: type[GenericArrayFormatter] if lib.is_np_dtype(values.dtype, "M"): fmt_klass = Datetime64Formatter + values = cast(DatetimeArray, values) elif isinstance(values.dtype, DatetimeTZDtype): fmt_klass = Datetime64TZFormatter + values = cast(DatetimeArray, values) elif lib.is_np_dtype(values.dtype, "m"): fmt_klass = Timedelta64Formatter + values = cast(TimedeltaArray, values) elif isinstance(values.dtype, ExtensionDtype): fmt_klass = ExtensionArrayFormatter elif lib.is_np_dtype(values.dtype, "fc"): @@ -1300,7 +1305,7 @@ def format_array( class GenericArrayFormatter: def __init__( self, - values: Any, + values: ArrayLike, digits: int = 7, formatter: Callable | None = None, na_rep: str = "NaN", @@ -1622,9 +1627,11 @@ def _format_strings(self) -> list[str]: class Datetime64Formatter(GenericArrayFormatter): + values: DatetimeArray + def __init__( self, - values: np.ndarray | Series | DatetimeIndex | DatetimeArray, + values: DatetimeArray, nat_rep: str = "NaT", date_format: None = None, **kwargs, @@ -1637,21 +1644,23 @@ def _format_strings(self) -> list[str]: """we by definition have DO NOT have a TZ""" values = self.values - if not isinstance(values, DatetimeIndex): - values = DatetimeIndex(values) + dti = DatetimeIndex(values) if self.formatter is not None and callable(self.formatter): - return [self.formatter(x) for x in values] + return [self.formatter(x) for x in dti] - fmt_values = values._data._format_native_types( + fmt_values = dti._data._format_native_types( na_rep=self.nat_rep, date_format=self.date_format ) return fmt_values.tolist() class ExtensionArrayFormatter(GenericArrayFormatter): + values: ExtensionArray + def _format_strings(self) -> list[str]: values = extract_array(self.values, extract_numpy=True) + values = cast(ExtensionArray, values) formatter = self.formatter fallback_formatter = None @@ -1813,13 +1822,10 @@ def get_format_datetime64( def get_format_datetime64_from_values( - values: np.ndarray | DatetimeArray | DatetimeIndex, date_format: str | None + values: DatetimeArray, date_format: str | None ) -> str | None: """given values and a date_format, return a string format""" - if isinstance(values, np.ndarray) and values.ndim > 1: - # We don't actually care about the order of values, and DatetimeIndex - # only accepts 1D values - values = values.ravel() + assert isinstance(values, DatetimeArray) ido = is_dates_only(values) if ido: @@ -1829,6 +1835,8 @@ def get_format_datetime64_from_values( class Datetime64TZFormatter(Datetime64Formatter): + values: DatetimeArray + def _format_strings(self) -> list[str]: """we by definition have a TZ""" ido = is_dates_only(self.values) @@ -1842,9 +1850,11 @@ def _format_strings(self) -> list[str]: class Timedelta64Formatter(GenericArrayFormatter): + values: TimedeltaArray + def __init__( self, - values: np.ndarray | TimedeltaIndex, + values: TimedeltaArray, nat_rep: str = "NaT", box: bool = False, **kwargs, @@ -1861,7 +1871,7 @@ def _format_strings(self) -> list[str]: def get_format_timedelta64( - values: np.ndarray | TimedeltaIndex | TimedeltaArray, + values: TimedeltaArray, nat_rep: str | float = "NaT", box: bool = False, ) -> Callable: @@ -1872,18 +1882,13 @@ def get_format_timedelta64( If box, then show the return in quotes """ values_int = values.view(np.int64) + values_int = cast("npt.NDArray[np.int64]", values_int) consider_values = values_int != iNaT one_day_nanos = 86400 * 10**9 - # error: Unsupported operand types for % ("ExtensionArray" and "int") - not_midnight = values_int % one_day_nanos != 0 # type: ignore[operator] - # error: Argument 1 to "__call__" of "ufunc" has incompatible type - # "Union[Any, ExtensionArray, ndarray]"; expected - # "Union[Union[int, float, complex, str, bytes, generic], - # Sequence[Union[int, float, complex, str, bytes, generic]], - # Sequence[Sequence[Any]], _SupportsArray]" - both = np.logical_and(consider_values, not_midnight) # type: ignore[arg-type] + not_midnight = values_int % one_day_nanos != 0 + both = np.logical_and(consider_values, not_midnight) even_days = both.sum() == 0 if even_days: @@ -1941,7 +1946,7 @@ def just(x: str) -> str: return result -def _trim_zeros_complex(str_complexes: np.ndarray, decimal: str = ".") -> list[str]: +def _trim_zeros_complex(str_complexes: ArrayLike, decimal: str = ".") -> list[str]: """ Separates the real and imaginary parts from the complex number, and executes the _trim_zeros_float method on each of those. @@ -1987,7 +1992,7 @@ def _trim_zeros_single_float(str_float: str) -> str: def _trim_zeros_float( - str_floats: np.ndarray | list[str], decimal: str = "." + str_floats: ArrayLike | list[str], decimal: str = "." ) -> list[str]: """ Trims the maximum number of trailing zeros equally from @@ -2000,7 +2005,7 @@ def _trim_zeros_float( def is_number_with_decimal(x) -> bool: return re.match(number_regex, x) is not None - def should_trim(values: np.ndarray | list[str]) -> bool: + def should_trim(values: ArrayLike | list[str]) -> bool: """ Determine if an array of strings should be trimmed.