ENH: EA._from_scalars #53089

Merged (18 commits) on Oct 16, 2023
Changes shown below are from 2 of the 18 commits.
16 changes: 16 additions & 0 deletions pandas/core/arrays/categorical.py
@@ -494,6 +494,22 @@ def _from_sequence(
    ) -> Self:
        return cls(scalars, dtype=dtype, copy=copy)

    @classmethod
    def _from_scalars(cls, scalars, dtype=None):
        if dtype is None:
            # The _from_scalars strictness doesn't make much sense in this case.
            raise NotImplementedError

        res = cls._from_sequence(scalars, dtype=dtype)

        # if there are any non-category elements in scalars, these will be
        # converted to NAs in res.
        mask = isna(scalars)
        if not (mask == res.isna()).all():
            # Some non-category element in scalars got converted to NA in res.
            raise ValueError
        return res

    @overload
    def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
        ...
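As an illustration of the strictness added here (a usage sketch, assuming a pandas build that includes this change; _from_scalars is private API and its signature may still change in later commits of this PR):

import pandas as pd

dtype = pd.CategoricalDtype(categories=["a", "b"])

# _from_sequence silently coerces values outside the categories to NaN ...
loose = pd.Categorical._from_sequence(["a", "b", "c"], dtype=dtype)
print(loose)  # ['a', 'b', NaN]; Categories: ['a', 'b']

# ... while _from_scalars raises, because "c" became NA in the result even
# though it was not missing in the input.
try:
    pd.Categorical._from_scalars(["a", "b", "c"], dtype=dtype)
except ValueError:
    print("rejected: 'c' is not a known category")
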
8 changes: 8 additions & 0 deletions pandas/core/arrays/datetimes.py
@@ -258,6 +258,14 @@ def _scalar_type(self) -> type[Timestamp]:
    _freq: BaseOffset | None = None
    _default_dtype = DT64NS_DTYPE  # used in TimeLikeOps.__init__

    @classmethod
    def _from_scalars(cls, scalars, dtype=None):
        if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
            # TODO: require any NAs be valid-for-DTA
            # TODO: if dtype is passed, check for tzawareness compat?
            raise ValueError
        return cls._from_sequence(scalars, dtype=dtype)

    @classmethod
    def _validate_dtype(cls, values, dtype):
        # used in TimeLikeOps.__init__
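A small sketch of the inference gate above (illustrative only, written against the version in this commit):

import numpy as np
import pandas as pd
from pandas.core.arrays import DatetimeArray

scalars = np.array([pd.Timestamp("2023-01-01"), pd.Timestamp("2023-01-02")], dtype=object)
dta = DatetimeArray._from_scalars(scalars)  # inferred as "datetime" -> accepted

try:
    # plain strings infer as "string", not "datetime"/"datetime64", so they are rejected
    DatetimeArray._from_scalars(np.array(["2023-01-01", "2023-01-02"], dtype=object))
except ValueError:
    print("rejected non-datetime scalars")
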
8 changes: 8 additions & 0 deletions pandas/core/arrays/string_.py
@@ -57,6 +57,7 @@
        NumpySorter,
        NumpyValueArrayLike,
        Scalar,
        Self,
        npt,
        type_t,
    )
@@ -228,6 +229,13 @@ def tolist(self):
            return [x.tolist() for x in self]
        return list(self.to_numpy())

    @classmethod
    def _from_scalars(cls, scalars, dtype=None) -> Self:
        if lib.infer_dtype(scalars, skipna=True) != "string":
            # TODO: require any NAs be valid-for-string
            raise ValueError
        return cls._from_sequence(scalars, dtype=dtype)


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
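A similar sketch for the string case (illustrative only; the method is added on BaseStringArray, so it is shared by the Python- and Arrow-backed string arrays):

import numpy as np
import pandas as pd
from pandas.core.arrays.string_ import StringArray

ok = StringArray._from_scalars(
    np.array(["a", "b", pd.NA], dtype=object), dtype=pd.StringDtype()
)

try:
    # integers infer as "integer", not "string", so the conversion is refused
    StringArray._from_scalars(np.array([1, 2], dtype=object), dtype=pd.StringDtype())
except ValueError:
    print("rejected non-string scalars")
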
29 changes: 15 additions & 14 deletions pandas/core/dtypes/cast.py
@@ -455,16 +455,11 @@ def maybe_cast_pointwise_result(
     """

     if isinstance(dtype, ExtensionDtype):
-        if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
-            # TODO: avoid this special-casing
-            # We have to special case categorical so as not to upcast
-            # things like counts back to categorical
-
-            cls = dtype.construct_array_type()
-            if same_dtype:
-                result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
-            else:
-                result = _maybe_cast_to_extension_array(cls, result)
+        cls = dtype.construct_array_type()
+        if same_dtype:
+            result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
+        else:
+            result = _maybe_cast_to_extension_array(cls, result)

     elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
         result = maybe_downcast_to_dtype(result, dtype)
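To see why the special-casing can go (a hedged sketch, not part of the diff): with the strict Categorical._from_scalars above, pointwise results that are not valid categories fail the cast and keep their own dtype, so counts still do not get upcast back to categorical even without the special case.

import pandas as pd

df = pd.DataFrame({"key": [1, 1, 2], "cat": pd.Categorical(["a", "b", "a"])})

# A Python-level aggregation whose results are counts, not category values.
res = df.groupby("key")["cat"].agg(lambda x: len(x))

# Categorical._from_scalars raises (2 and 1 are not categories), so the
# integer result is kept rather than being cast back to CategoricalDtype.
print(res.dtype)  # expected: int64
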
@@ -489,11 +484,17 @@ def _maybe_cast_to_extension_array(
     -------
     ExtensionArray or obj
     """
-    from pandas.core.arrays.string_ import BaseStringArray
-
-    # Everything can be converted to StringArrays, but we may not want to convert
-    if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string":
-        return obj
+    if hasattr(cls, "_from_scalars"):
+        # TODO: get this everywhere!
+        try:
+            result = cls._from_scalars(obj, dtype=dtype)
+        except (TypeError, ValueError, NotImplementedError):
+            # TODO: document that _from_scalars should only raise ValueError
+            # or TypeError; NotImplementedError is here until we decide what
+            # to do for Categorical.
+            return obj
+        return result

     try:
         result = cls._from_sequence(obj, dtype=dtype)

Review thread on the hasattr(cls, "_from_scalars") check:

Member: You can probably add a simple _from_scalars in the base class that just calls _from_sequence. We will need some fallback like that for external EAs anyway (unless we keep this hasattr check, but since we fall back to _from_sequence below anyway, we can better do this in _from_scalars, I think).

Member (author): Agreed, this makes sense. At the moment _from_sequence can raise anything, so we catch anything, which isn't an ideal pattern. Thoughts on documenting _from_scalars as only raising ValueError/TypeError and eventually enforcing that?

Member: Yes, I think that's certainly fine.
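A minimal sketch of the reviewer's suggestion (hypothetical, not part of this diff): a default _from_scalars on the ExtensionArray base class that simply defers to _from_sequence, so external EAs keep working without the hasattr check.

# Sketch only: what a default on pandas.core.arrays.base.ExtensionArray could
# look like, per the review thread; subclasses such as Categorical or
# DatetimeArray would override it with the stricter checks shown earlier.
class ExtensionArray:
    @classmethod
    def _from_scalars(cls, scalars, dtype=None):
        # No extra strictness at the base level: defer to the existing
        # constructor. Per the thread, this would be documented as raising
        # only ValueError/TypeError.
        return cls._from_sequence(scalars, dtype=dtype)
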
2 changes: 1 addition & 1 deletion pandas/tests/resample/test_timedelta.py
@@ -103,7 +103,7 @@ def test_resample_categorical_data_with_timedeltaindex():
         index=pd.TimedeltaIndex([0, 10], unit="s", freq="10s"),
     )
     expected = expected.reindex(["Group_obj", "Group"], axis=1)
-    expected["Group"] = expected["Group_obj"]
+    expected["Group"] = expected["Group_obj"].astype("category")
     tm.assert_frame_equal(result, expected)


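For context, a condensed sketch of the test above: with the generalized casting in cast.py, the aggregated categorical column now comes back with CategoricalDtype, which is why the expected frame needs the explicit astype("category").

import pandas as pd

df = pd.DataFrame(
    {"Group_obj": ["A"] * 20},
    index=pd.to_timedelta(range(20), unit="s"),
)
df["Group"] = df["Group_obj"].astype("category")

result = df.resample("10s").agg(lambda x: x.value_counts().index[0])

# After this PR the "Group" column is expected to keep category dtype,
# while "Group_obj" stays object.
print(result.dtypes)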