Skip to content

Commit

Permalink
TYP: format.formats (#55393)
Browse files Browse the repository at this point in the history
* TYP: format.formats

* Fix MultiIndex._values
  • Loading branch information
jbrockmendel authored Oct 4, 2023
1 parent 364c9cb commit 4145278
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 37 deletions.
16 changes: 10 additions & 6 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,23 +1819,27 @@ def _empty( # type: ignore[override]

return arr._from_backing_data(backing)

def _internal_get_values(self):
def _internal_get_values(self) -> ArrayLike:
"""
Return the values.
For internal compatibility with pandas formatting.
Returns
-------
np.ndarray or Index
A numpy array of the same dtype as categorical.categories.dtype or
Index if datetime / periods.
np.ndarray or ExtensionArray
A numpy array or ExtensionArray of the same dtype as
categorical.categories.dtype.
"""
# if we are a datetime and period index, return Index to keep metadata
if needs_i8_conversion(self.categories.dtype):
return self.categories.take(self._codes, fill_value=NaT)
return self.categories.take(self._codes, fill_value=NaT)._values
elif is_integer_dtype(self.categories.dtype) and -1 in self._codes:
return self.categories.astype("object").take(self._codes, fill_value=np.nan)
return (
self.categories.astype("object")
.take(self._codes, fill_value=np.nan)
._values
)
return np.array(self)

def check_for_ordered(self, op) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def _format_native_types(
from pandas.io.formats.format import get_format_timedelta64

# Relies on TimeDelta._repr_base
formatter = get_format_timedelta64(self._ndarray, na_rep)
formatter = get_format_timedelta64(self, na_rep)
# equiv: np.array([formatter(x) for x in self._ndarray])
# but independent of dimension
return np.frompyfunc(formatter, 1, 1)(self._ndarray)
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCDatetimeIndex,
ABCSeries,
ABCTimedeltaIndex,
)
from pandas.core.dtypes.inference import is_array_like
from pandas.core.dtypes.missing import (
Expand Down Expand Up @@ -768,8 +766,8 @@ def _values(self) -> np.ndarray:
vals = cast("CategoricalIndex", vals)
vals = vals._data._internal_get_values()

if isinstance(vals.dtype, ExtensionDtype) or isinstance(
vals, (ABCDatetimeIndex, ABCTimedeltaIndex)
if isinstance(vals.dtype, ExtensionDtype) or lib.is_np_dtype(
vals.dtype, "mM"
):
vals = vals.astype(object)

Expand Down
57 changes: 31 additions & 26 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from pandas.core.arrays import (
Categorical,
DatetimeArray,
ExtensionArray,
TimedeltaArray,
)
from pandas.core.arrays.string_ import StringDtype
Expand Down Expand Up @@ -108,6 +109,7 @@
SequenceNotStr,
StorageOptions,
WriteBuffer,
npt,
)

from pandas import (
Expand Down Expand Up @@ -1216,7 +1218,7 @@ def get_buffer(


def format_array(
values: Any,
values: ArrayLike,
formatter: Callable | None,
float_format: FloatFormatType | None = None,
na_rep: str = "NaN",
Expand All @@ -1233,7 +1235,7 @@ def format_array(
Parameters
----------
values
values : np.ndarray or ExtensionArray
formatter
float_format
na_rep
Expand All @@ -1258,10 +1260,13 @@ def format_array(
fmt_klass: type[GenericArrayFormatter]
if lib.is_np_dtype(values.dtype, "M"):
fmt_klass = Datetime64Formatter
values = cast(DatetimeArray, values)
elif isinstance(values.dtype, DatetimeTZDtype):
fmt_klass = Datetime64TZFormatter
values = cast(DatetimeArray, values)
elif lib.is_np_dtype(values.dtype, "m"):
fmt_klass = Timedelta64Formatter
values = cast(TimedeltaArray, values)
elif isinstance(values.dtype, ExtensionDtype):
fmt_klass = ExtensionArrayFormatter
elif lib.is_np_dtype(values.dtype, "fc"):
Expand Down Expand Up @@ -1300,7 +1305,7 @@ def format_array(
class GenericArrayFormatter:
def __init__(
self,
values: Any,
values: ArrayLike,
digits: int = 7,
formatter: Callable | None = None,
na_rep: str = "NaN",
Expand Down Expand Up @@ -1622,9 +1627,11 @@ def _format_strings(self) -> list[str]:


class Datetime64Formatter(GenericArrayFormatter):
values: DatetimeArray

def __init__(
self,
values: np.ndarray | Series | DatetimeIndex | DatetimeArray,
values: DatetimeArray,
nat_rep: str = "NaT",
date_format: None = None,
**kwargs,
Expand All @@ -1637,21 +1644,23 @@ def _format_strings(self) -> list[str]:
"""we by definition have DO NOT have a TZ"""
values = self.values

if not isinstance(values, DatetimeIndex):
values = DatetimeIndex(values)
dti = DatetimeIndex(values)

if self.formatter is not None and callable(self.formatter):
return [self.formatter(x) for x in values]
return [self.formatter(x) for x in dti]

fmt_values = values._data._format_native_types(
fmt_values = dti._data._format_native_types(
na_rep=self.nat_rep, date_format=self.date_format
)
return fmt_values.tolist()


class ExtensionArrayFormatter(GenericArrayFormatter):
values: ExtensionArray

def _format_strings(self) -> list[str]:
values = extract_array(self.values, extract_numpy=True)
values = cast(ExtensionArray, values)

formatter = self.formatter
fallback_formatter = None
Expand Down Expand Up @@ -1813,13 +1822,10 @@ def get_format_datetime64(


def get_format_datetime64_from_values(
values: np.ndarray | DatetimeArray | DatetimeIndex, date_format: str | None
values: DatetimeArray, date_format: str | None
) -> str | None:
"""given values and a date_format, return a string format"""
if isinstance(values, np.ndarray) and values.ndim > 1:
# We don't actually care about the order of values, and DatetimeIndex
# only accepts 1D values
values = values.ravel()
assert isinstance(values, DatetimeArray)

ido = is_dates_only(values)
if ido:
Expand All @@ -1829,6 +1835,8 @@ def get_format_datetime64_from_values(


class Datetime64TZFormatter(Datetime64Formatter):
values: DatetimeArray

def _format_strings(self) -> list[str]:
"""we by definition have a TZ"""
ido = is_dates_only(self.values)
Expand All @@ -1842,9 +1850,11 @@ def _format_strings(self) -> list[str]:


class Timedelta64Formatter(GenericArrayFormatter):
values: TimedeltaArray

def __init__(
self,
values: np.ndarray | TimedeltaIndex,
values: TimedeltaArray,
nat_rep: str = "NaT",
box: bool = False,
**kwargs,
Expand All @@ -1861,7 +1871,7 @@ def _format_strings(self) -> list[str]:


def get_format_timedelta64(
values: np.ndarray | TimedeltaIndex | TimedeltaArray,
values: TimedeltaArray,
nat_rep: str | float = "NaT",
box: bool = False,
) -> Callable:
Expand All @@ -1872,18 +1882,13 @@ def get_format_timedelta64(
If box, then show the return in quotes
"""
values_int = values.view(np.int64)
values_int = cast("npt.NDArray[np.int64]", values_int)

consider_values = values_int != iNaT

one_day_nanos = 86400 * 10**9
# error: Unsupported operand types for % ("ExtensionArray" and "int")
not_midnight = values_int % one_day_nanos != 0 # type: ignore[operator]
# error: Argument 1 to "__call__" of "ufunc" has incompatible type
# "Union[Any, ExtensionArray, ndarray]"; expected
# "Union[Union[int, float, complex, str, bytes, generic],
# Sequence[Union[int, float, complex, str, bytes, generic]],
# Sequence[Sequence[Any]], _SupportsArray]"
both = np.logical_and(consider_values, not_midnight) # type: ignore[arg-type]
not_midnight = values_int % one_day_nanos != 0
both = np.logical_and(consider_values, not_midnight)
even_days = both.sum() == 0

if even_days:
Expand Down Expand Up @@ -1941,7 +1946,7 @@ def just(x: str) -> str:
return result


def _trim_zeros_complex(str_complexes: np.ndarray, decimal: str = ".") -> list[str]:
def _trim_zeros_complex(str_complexes: ArrayLike, decimal: str = ".") -> list[str]:
"""
Separates the real and imaginary parts from the complex number, and
executes the _trim_zeros_float method on each of those.
Expand Down Expand Up @@ -1987,7 +1992,7 @@ def _trim_zeros_single_float(str_float: str) -> str:


def _trim_zeros_float(
str_floats: np.ndarray | list[str], decimal: str = "."
str_floats: ArrayLike | list[str], decimal: str = "."
) -> list[str]:
"""
Trims the maximum number of trailing zeros equally from
Expand All @@ -2000,7 +2005,7 @@ def _trim_zeros_float(
def is_number_with_decimal(x) -> bool:
return re.match(number_regex, x) is not None

def should_trim(values: np.ndarray | list[str]) -> bool:
def should_trim(values: ArrayLike | list[str]) -> bool:
"""
Determine if an array of strings should be trimmed.
Expand Down

0 comments on commit 4145278

Please sign in to comment.