Skip to content

Commit

Permalink
REF: improve privatization in io.formats.format (#55389)
Browse files Browse the repository at this point in the history
* REF: privatize where possible in io.formats.format

* REF: de-privatize VALID_JUSTIFY_PARAMETERS, avoid multiple get_option calls
  • Loading branch information
jbrockmendel authored Oct 5, 2023
1 parent d943c26 commit 6c58a21
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 72 deletions.
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3266,7 +3266,7 @@ def to_html(
... </table>'''
>>> assert html_string == df.to_html()
"""
if justify is not None and justify not in fmt._VALID_JUSTIFY_PARAMETERS:
if justify is not None and justify not in fmt.VALID_JUSTIFY_PARAMETERS:
raise ValueError("Invalid value for justify parameter")

formatter = fmt.DataFrameFormatter(
Expand Down
81 changes: 41 additions & 40 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import re
from shutil import get_terminal_size
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Expand Down Expand Up @@ -172,7 +171,7 @@
Character recognized as decimal separator, e.g. ',' in Europe.
"""

_VALID_JUSTIFY_PARAMETERS = (
VALID_JUSTIFY_PARAMETERS = (
"left",
"right",
"center",
Expand All @@ -196,10 +195,15 @@


class SeriesFormatter:
"""
Implement the main logic of Series.to_string, which underlies
Series.__repr__.
"""

def __init__(
self,
series: Series,
buf: IO[str] | None = None,
*,
length: bool | str = True,
header: bool = True,
index: bool = True,
Expand All @@ -211,7 +215,7 @@ def __init__(
min_rows: int | None = None,
) -> None:
self.series = series
self.buf = buf if buf is not None else StringIO()
self.buf = StringIO()
self.name = name
self.na_rep = na_rep
self.header = header
Expand Down Expand Up @@ -355,7 +359,7 @@ def to_string(self) -> str:
return str("".join(result))


class TextAdjustment:
class _TextAdjustment:
def __init__(self) -> None:
self.encoding = get_option("display.encoding")

Expand All @@ -371,7 +375,7 @@ def adjoin(self, space: int, *lists, **kwargs) -> str:
)


class EastAsianTextAdjustment(TextAdjustment):
class _EastAsianTextAdjustment(_TextAdjustment):
def __init__(self) -> None:
super().__init__()
if get_option("display.unicode.ambiguous_as_wide"):
Expand Down Expand Up @@ -410,12 +414,12 @@ def _get_pad(t):
return [x.rjust(_get_pad(x)) for x in texts]


def get_adjustment() -> TextAdjustment:
def get_adjustment() -> _TextAdjustment:
use_east_asian_width = get_option("display.unicode.east_asian_width")
if use_east_asian_width:
return EastAsianTextAdjustment()
return _EastAsianTextAdjustment()
else:
return TextAdjustment()
return _TextAdjustment()


def get_dataframe_repr_params() -> dict[str, Any]:
Expand Down Expand Up @@ -469,16 +473,9 @@ def get_series_repr_params() -> dict[str, Any]:
True
"""
width, height = get_terminal_size()
max_rows = (
height
if get_option("display.max_rows") == 0
else get_option("display.max_rows")
)
min_rows = (
height
if get_option("display.max_rows") == 0
else get_option("display.min_rows")
)
max_rows_opt = get_option("display.max_rows")
max_rows = height if max_rows_opt == 0 else max_rows_opt
min_rows = height if max_rows_opt == 0 else get_option("display.min_rows")

return {
"name": True,
Expand All @@ -490,7 +487,11 @@ def get_series_repr_params() -> dict[str, Any]:


class DataFrameFormatter:
"""Class for processing dataframe formatting options and data."""
"""
Class for processing dataframe formatting options and data.
Used by DataFrame.to_string, which backs DataFrame.__repr__.
"""

__doc__ = __doc__ if __doc__ else ""
__doc__ += common_docstring + return_docstring
Expand Down Expand Up @@ -1102,16 +1103,16 @@ def save_to_buffer(
"""
Perform serialization. Write to buf or return as string if buf is None.
"""
with get_buffer(buf, encoding=encoding) as f:
f.write(string)
with _get_buffer(buf, encoding=encoding) as fd:
fd.write(string)
if buf is None:
# error: "WriteBuffer[str]" has no attribute "getvalue"
return f.getvalue() # type: ignore[attr-defined]
return fd.getvalue() # type: ignore[attr-defined]
return None


@contextmanager
def get_buffer(
def _get_buffer(
buf: FilePath | WriteBuffer[str] | None, encoding: str | None = None
) -> Generator[WriteBuffer[str], None, None] | Generator[StringIO, None, None]:
"""
Expand Down Expand Up @@ -1188,24 +1189,24 @@ def format_array(
-------
List[str]
"""
fmt_klass: type[GenericArrayFormatter]
fmt_klass: type[_GenericArrayFormatter]
if lib.is_np_dtype(values.dtype, "M"):
fmt_klass = Datetime64Formatter
fmt_klass = _Datetime64Formatter
values = cast(DatetimeArray, values)
elif isinstance(values.dtype, DatetimeTZDtype):
fmt_klass = Datetime64TZFormatter
fmt_klass = _Datetime64TZFormatter
values = cast(DatetimeArray, values)
elif lib.is_np_dtype(values.dtype, "m"):
fmt_klass = Timedelta64Formatter
fmt_klass = _Timedelta64Formatter
values = cast(TimedeltaArray, values)
elif isinstance(values.dtype, ExtensionDtype):
fmt_klass = ExtensionArrayFormatter
fmt_klass = _ExtensionArrayFormatter
elif lib.is_np_dtype(values.dtype, "fc"):
fmt_klass = FloatArrayFormatter
elif lib.is_np_dtype(values.dtype, "iu"):
fmt_klass = IntArrayFormatter
fmt_klass = _IntArrayFormatter
else:
fmt_klass = GenericArrayFormatter
fmt_klass = _GenericArrayFormatter

if space is None:
space = 12
Expand Down Expand Up @@ -1233,7 +1234,7 @@ def format_array(
return fmt_obj.get_result()


class GenericArrayFormatter:
class _GenericArrayFormatter:
def __init__(
self,
values: ArrayLike,
Expand Down Expand Up @@ -1315,7 +1316,7 @@ def _format(x):
vals = extract_array(self.values, extract_numpy=True)
if not isinstance(vals, np.ndarray):
raise TypeError(
"ExtensionArray formatting should use ExtensionArrayFormatter"
"ExtensionArray formatting should use _ExtensionArrayFormatter"
)
inferred = lib.map_infer(vals, is_float)
is_float_type = (
Expand Down Expand Up @@ -1345,7 +1346,7 @@ def _format(x):
return fmt_values


class FloatArrayFormatter(GenericArrayFormatter):
class FloatArrayFormatter(_GenericArrayFormatter):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -1546,7 +1547,7 @@ def _format_strings(self) -> list[str]:
return list(self.get_result_as_array())


class IntArrayFormatter(GenericArrayFormatter):
class _IntArrayFormatter(_GenericArrayFormatter):
def _format_strings(self) -> list[str]:
if self.leading_space is False:
formatter_str = lambda x: f"{x:d}".format(x=x)
Expand All @@ -1557,7 +1558,7 @@ def _format_strings(self) -> list[str]:
return fmt_values


class Datetime64Formatter(GenericArrayFormatter):
class _Datetime64Formatter(_GenericArrayFormatter):
values: DatetimeArray

def __init__(
Expand Down Expand Up @@ -1586,7 +1587,7 @@ def _format_strings(self) -> list[str]:
return fmt_values.tolist()


class ExtensionArrayFormatter(GenericArrayFormatter):
class _ExtensionArrayFormatter(_GenericArrayFormatter):
values: ExtensionArray

def _format_strings(self) -> list[str]:
Expand Down Expand Up @@ -1727,7 +1728,7 @@ def get_format_datetime64(
return lambda x: _format_datetime64(x, nat_rep=nat_rep)


class Datetime64TZFormatter(Datetime64Formatter):
class _Datetime64TZFormatter(_Datetime64Formatter):
values: DatetimeArray

def _format_strings(self) -> list[str]:
Expand All @@ -1742,7 +1743,7 @@ def _format_strings(self) -> list[str]:
return fmt_values


class Timedelta64Formatter(GenericArrayFormatter):
class _Timedelta64Formatter(_GenericArrayFormatter):
values: TimedeltaArray

def __init__(
Expand Down Expand Up @@ -1809,7 +1810,7 @@ def _make_fixed_width(
strings: list[str],
justify: str = "right",
minimum: int | None = None,
adj: TextAdjustment | None = None,
adj: _TextAdjustment | None = None,
) -> list[str]:
if len(strings) == 0 or justify == "all":
return strings
Expand Down
Loading

0 comments on commit 6c58a21

Please sign in to comment.