FEAT-modin-project#7401: Implement DataFrame/Series.attrs
Signed-off-by: Jonathan Shi <[email protected]>
noloerino committed Sep 24, 2024
1 parent 1c4d173 commit cee01de
Showing 3 changed files with 47 additions and 24 deletions.
39 changes: 39 additions & 0 deletions modin/pandas/base.py
@@ -16,6 +16,8 @@
from __future__ import annotations

import abc
import copy
import functools
import pickle as pkl
import re
import warnings
@@ -105,6 +107,7 @@
"_ipython_canary_method_should_not_exist_",
"_ipython_display_",
"_repr_mimebundle_",
"_attrs",
}

_DEFAULT_BEHAVIOUR = {
@@ -193,6 +196,26 @@ def _get_repr_axis_label_indexer(labels, num_for_repr):
)


def propagate_self_attrs(method):
"""
Wrap a BasePandasDataset/DataFrame/Series method with a function that automatically
deep-copies `self.attrs` to the result if it is present and non-empty.

This decorator should not be used on methods such as `concat`, `str`, and `groupby`, which may
need to examine multiple sources to reconcile `attrs`.
"""

@functools.wraps(method)
def wrapper(self, *args, **kwargs):
result = method(self, *args, **kwargs)
if isinstance(result, BasePandasDataset) and len(self._attrs):
# If the result of the method call is a modin.pandas object and `self.attrs` is
# not empty, perform a deep copy of `self.attrs`.
result._attrs = copy.deepcopy(self._attrs)
return result

return wrapper


@_inherit_docstrings(pandas.DataFrame, apilink=["pandas.DataFrame", "pandas.Series"])
class BasePandasDataset(ClassLogger):
"""
@@ -208,6 +231,7 @@ class BasePandasDataset(ClassLogger):
_pandas_class = pandas.core.generic.NDFrame
_query_compiler: BaseQueryCompiler
_siblings: list[BasePandasDataset]
_attrs: dict

@cached_property
def _is_dataframe(self) -> bool:
@@ -1125,6 +1149,20 @@ def at(self, axis=None) -> _LocIndexer: # noqa: PR01, RT01, D200

return _LocIndexer(self)

def _set_attrs(self, value: dict) -> None: # noqa: PR01, D200
"""
Set the dictionary of global attributes of this dataset.
"""
self._attrs = value

def _get_attrs(self) -> dict: # noqa: RT01, D200
"""
Get the dictionary of global attributes of this dataset.
"""
return self._attrs

attrs: dict = property(_get_attrs, _set_attrs)

def at_time(self, time, asof=False, axis=None) -> Self: # noqa: PR01, RT01, D200
"""
Select values at particular time of day (e.g., 9:30AM).
@@ -3221,6 +3259,7 @@ def tail(self, n=5) -> Self: # noqa: PR01, RT01, D200
return self.iloc[-n:]
return self.iloc[len(self) :]

@propagate_self_attrs
def take(self, indices, axis=0, **kwargs) -> Self: # noqa: PR01, RT01, D200
"""
Return the elements in the given *positional* indices along an axis.
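A minimal usage sketch (illustrative only, not part of this commit) of how the new attrs storage and the propagate_self_attrs-decorated take are expected to behave, assuming the semantics shown in the hunks above:

import modin.pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})

# The attrs getter returns the dict stored on the object, so item
# assignment is visible on later reads instead of being dropped.
df.attrs["source"] = "sensor-1"
assert df.attrs == {"source": "sensor-1"}

# `take` is decorated with @propagate_self_attrs, so its result carries
# an independent deep copy of the parent's attrs.
taken = df.take([0, 2])
assert taken.attrs == {"source": "sensor-1"}
taken.attrs["source"] = "sensor-2"
assert df.attrs["source"] == "sensor-1"  # the parent is unaffected
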
18 changes: 5 additions & 13 deletions modin/pandas/dataframe.py
@@ -151,8 +151,11 @@ def __init__(
# Siblings are other dataframes that share the same query compiler. We
# use this list to update inplace when there is a shallow copy.
self._siblings = []
self._attrs = {}
if isinstance(data, (DataFrame, Series)):
self._query_compiler = data._query_compiler.copy()
if len(data._attrs):
self._attrs = copy.deepcopy(data._attrs)
if index is not None and any(i not in data.index for i in index):
raise NotImplementedError(
"Passing non-existant columns or index values to constructor not"
@@ -2636,12 +2639,12 @@ def __setattr__(self, key, value) -> None:
# - anything in self.__dict__. This includes any attributes that the
# user has added to the dataframe with, e.g., `df.c = 3`, and
# any attribute that Modin has added to the frame, e.g.
# `_query_compiler` and `_siblings`
# `_query_compiler`, `_siblings`, and `_attrs`
# - `_query_compiler`, which Modin initializes before it appears in
# __dict__
# - `_siblings`, which Modin initializes before it appears in __dict__
# before it appears in __dict__.
if key in ("_query_compiler", "_siblings") or key in self.__dict__:
if key in ("_attrs", "_query_compiler", "_siblings") or key in self.__dict__:
pass
# we have to check for the key in `dir(self)` first in order not to trigger columns computation
elif key not in dir(self) and key in self:
@@ -2938,17 +2941,6 @@ def __dataframe_consortium_standard__(
)
return convert_to_standard_compliant_dataframe(self, api_version=api_version)

@property
def attrs(self) -> dict: # noqa: RT01, D200
"""
Return dictionary of global attributes of this dataset.
"""

def attrs(df):
return df.attrs

return self._default_to_pandas(attrs)

@property
def style(self): # noqa: RT01, D200
"""
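The DataFrame.attrs property removed above defaulted to pandas: each access built a temporary pandas frame and returned that frame's attrs dict, so in-place edits did not persist and there was no setter, which is presumably the motivation for this change. A short sketch (illustrative only) of the constructor behaviour added in this file:

import modin.pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})
df.attrs["owner"] = "etl-job"

# Building a DataFrame from an existing DataFrame/Series deep-copies
# non-empty attrs, so the two objects can diverge afterwards.
df2 = pd.DataFrame(df)
assert df2.attrs == {"owner": "etl-job"}
df2.attrs["owner"] = "reporting"
assert df.attrs["owner"] == "etl-job"
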
14 changes: 3 additions & 11 deletions modin/pandas/series.py
@@ -114,8 +114,11 @@ def __init__(
# Siblings are other dataframes that share the same query compiler. We
# use this list to update inplace when there is a shallow copy.
self._siblings = []
self._attrs = {}
if isinstance(data, type(self)):
query_compiler = data._query_compiler.copy()
if len(data._attrs):
self._attrs = copy.deepcopy(data._attrs)
if index is not None:
if any(i not in data.index for i in index):
raise NotImplementedError(
@@ -2264,17 +2267,6 @@ def where(
level=level,
)

@property
def attrs(self) -> dict: # noqa: RT01, D200
"""
Return dictionary of global attributes of this dataset.
"""

def attrs(df):
return df.attrs

return self._default_to_pandas(attrs)

@property
def array(self) -> ExtensionArray: # noqa: RT01, D200
"""
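Series receives the same treatment as DataFrame. A small sketch (illustrative only, not part of the commit) covering the empty-attrs guard (len(data._attrs)) and the deep copy on construction:

import modin.pandas as pd

s = pd.Series([1, 2, 3])

# Empty attrs are not propagated: the copy in __init__ (and in
# propagate_self_attrs) is guarded by a len() check, so the new object
# simply keeps its own fresh dict.
assert pd.Series(s).attrs == {}

# Non-empty attrs are deep-copied, mirroring the DataFrame constructor.
s.attrs["units"] = "kg"
s2 = pd.Series(s)
assert s2.attrs == {"units": "kg"}
assert s2.attrs is not s.attrs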
