Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] case_when function #55390

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@ including other versions of pandas.
Enhancements
~~~~~~~~~~~~

.. _whatsnew_220.enhancements.case_when:

Create a pandas Series based on one or more conditions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`)

.. ipython:: python

import pandas as pd

df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
pd.case_when(
(df.a == 1, 'first'), # condition, replacement
(df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement
default = 'default', # optional
rhshadrach marked this conversation as resolved.
Show resolved Hide resolved
)


.. _whatsnew_220.enhancements.calamine:

Calamine engine for :func:`read_excel`
Expand Down
2 changes: 2 additions & 0 deletions pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
notnull,
# indexes
Index,
case_when,
CategoricalIndex,
RangeIndex,
MultiIndex,
Expand Down Expand Up @@ -252,6 +253,7 @@
__all__ = [
"ArrowDtype",
"BooleanDtype",
"case_when",
"Categorical",
"CategoricalDtype",
"CategoricalIndex",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
UInt64Dtype,
)
from pandas.core.arrays.string_ import StringDtype
from pandas.core.case_when import case_when
from pandas.core.construction import array
from pandas.core.flags import Flags
from pandas.core.groupby import (
Expand Down Expand Up @@ -86,6 +87,7 @@
"bdate_range",
"BooleanDtype",
"Categorical",
"case_when",
"CategoricalDtype",
"CategoricalIndex",
"DataFrame",
Expand Down
218 changes: 218 additions & 0 deletions pandas/core/case_when.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from pandas._libs import lib

from pandas.core.dtypes.cast import (
construct_1d_arraylike_from_scalar,
find_common_type,
infer_dtype_from,
)
from pandas.core.dtypes.common import is_scalar
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.construction import array as pd_array

if TYPE_CHECKING:
from pandas._typing import Series


def case_when(
*args: tuple[tuple],
default=lib.no_default,
) -> Series:
"""
Replace values where the conditions are True.

Parameters
----------
*args : tuple(s) of array-like, scalar
Variable argument of tuples of conditions and expected replacements.
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array.
When multiple boolean conditions are satisfied,
the first replacement is used.
If ``condition`` is a Series, and the equivalent ``replacement``
is a Series, they must have the same index.
If there are multiple replacement options,
and they are Series, they must have the same index.

default : scalar, array-like, default None
If provided, it is the replacement value to use
if all conditions evaluate to False.
If not specified, entries will be filled with the
corresponding NULL value.

.. versionadded:: 2.2.0

Returns
-------
Series

See Also
--------
Series.mask : Replace values where the condition is True.

Examples
--------
>>> df = pd.DataFrame({
... "a": [0,0,1,2],
... "b": [0,3,4,5],
... "c": [6,7,8,9]
... })
>>> df
a b c
0 0 0 6
1 0 3 7
2 1 4 8
3 2 5 9

>>> pd.case_when((df.a.gt(0), df.a), # condition, replacement
... (df.b.gt(0), df.b), # condition, replacement
... default=df.c) # optional
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""
from pandas import Series
from pandas._testing.asserters import assert_index_equal

args = validate_case_when(args=args)

conditions, replacements = zip(*args)
common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements]

if default is not lib.no_default:
arg_dtype, _ = infer_dtype_from(default)
common_dtypes.append(arg_dtype)
else:
default = None
if len(set(common_dtypes)) > 1:
common_dtypes = find_common_type(common_dtypes)
updated_replacements = []
for condition, replacement in zip(conditions, replacements):
if is_scalar(replacement):
replacement = construct_1d_arraylike_from_scalar(
value=replacement, length=len(condition), dtype=common_dtypes
)
elif isinstance(replacement, ABCSeries):
replacement = replacement.astype(common_dtypes)
else:
replacement = pd_array(replacement, dtype=common_dtypes)
updated_replacements.append(replacement)
replacements = updated_replacements
if (default is not None) and isinstance(default, ABCSeries):
default = default.astype(common_dtypes)
else:
common_dtypes = common_dtypes[0]
if not isinstance(default, ABCSeries):
cond_indices = [cond for cond in conditions if isinstance(cond, ABCSeries)]
replacement_indices = [
replacement
for replacement in replacements
if isinstance(replacement, ABCSeries)
]
cond_length = None
if replacement_indices:
for left, right in zip(replacement_indices, replacement_indices[1:]):
try:
assert_index_equal(left.index, right.index, check_order=False)
except AssertionError:
raise AssertionError(
"All replacement objects must have the same index."
samukweku marked this conversation as resolved.
Show resolved Hide resolved
)
if cond_indices:
for left, right in zip(cond_indices, cond_indices[1:]):
try:
assert_index_equal(left.index, right.index, check_order=False)
except AssertionError:
raise AssertionError(
"All condition objects must have the same index."
)
if replacement_indices:
try:
assert_index_equal(
replacement_indices[0].index,
cond_indices[0].index,
check_order=False,
)
except AssertionError:
raise ValueError(
"All replacement objects and condition objects "
"should have the same index."
)
else:
conditions = [
np.asanyarray(cond) if not hasattr(cond, "shape") else cond
for cond in conditions
]
cond_length = {len(cond) for cond in conditions}
if len(cond_length) > 1:
raise ValueError("The boolean conditions should have the same length.")
rhshadrach marked this conversation as resolved.
Show resolved Hide resolved
cond_length = len(conditions[0])
if not is_scalar(default):
if len(default) != cond_length:
raise ValueError(
"length of `default` does not match the length "
"of any of the conditions."
)
if not replacement_indices:
for num, replacement in enumerate(replacements):
if is_scalar(replacement):
continue
if not hasattr(replacement, "shape"):
replacement = np.asanyarray(replacement)
if len(replacement) != cond_length:
raise ValueError(
f"Length of condition{num} does not match "
f"the length of replacement{num}; "
rhshadrach marked this conversation as resolved.
Show resolved Hide resolved
f"{cond_length} != {len(replacement)}"
)
if cond_indices:
default_index = cond_indices[0].index
elif replacement_indices:
default_index = replacement_indices[0].index
else:
default_index = range(cond_length)
default = Series(default, index=default_index, dtype=common_dtypes)
counter = reversed(range(len(conditions)))
for position, condition, replacement in zip(
counter, conditions[::-1], replacements[::-1]
):
try:
default = default.mask(
condition, other=replacement, axis=0, inplace=False, level=None
)
except Exception as error:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would remove this try except and have mask raise it's error normally

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is to keep track of which condition, replacement failed. More like condition1 failed, this is why it failed. Devolving to mask error directly and you lose the error tracking. I assume the error tracking would be useful to the user.

raise ValueError(
f"condition{position} and replacement{position} failed to evaluate."
) from error
return default


def validate_case_when(args: tuple) -> tuple:
"""
Validates the variable arguments for the case_when function.
"""
if not len(args):
raise ValueError(
"provide at least one boolean condition, "
"with a corresponding replacement."
)

for num, entry in enumerate(args):
if not isinstance(entry, tuple):
raise TypeError(f"Argument {num} must be a tuple; got {type(entry)}.")
if len(entry) != 2:
raise ValueError(
f"Argument {num} must have length 2; "
"a condition and replacement; "
f"got length {len(entry)}."
)
return args
78 changes: 78 additions & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.case_when import (
case_when,
validate_case_when,
)
from pandas.core.construction import (
extract_array,
sanitize_array,
Expand Down Expand Up @@ -5554,6 +5558,80 @@ def between(

return lmask & rmask

def case_when(
self,
*args: tuple[tuple],
) -> Series:
"""
Replace values where the conditions are True.

Parameters
----------
*args : tuple(s) of array-like, scalar.
Variable argument of tuples of conditions and expected replacements.
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array-like object
or a callable. If ``condition`` is a callable,
it is computed on the Series
and should return a boolean Series or array.
The callable must not change the input Series
(though pandas doesn`t check it). ``replacement`` should be a
1-D array-like object, a scalar or a callable.
If ``replacement`` is a callable, it is computed on the Series
and should return a scalar or Series. The callable
must not change the input Series
(though pandas doesn`t check it).
If ``condition`` is a Series, and the equivalent ``replacement``
is a Series, they must have the same index.
If there are multiple replacement options,
and they are Series, they must have the same index.

level : int, default None
Alignment level if needed.

.. versionadded:: 2.2.0

Returns
-------
Series

See Also
--------
Series.mask : Replace values where the condition is True.

Examples
--------
>>> df = pd.DataFrame({
... "a": [0,0,1,2],
... "b": [0,3,4,5],
... "c": [6,7,8,9]
... })
>>> df
a b c
0 0 0 6
1 0 3 7
2 1 4 8
3 2 5 9

>>> df.c.case_when((df.a.gt(0), df.a), # condition, replacement
... (df.b.gt(0), df.b))
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""
args = validate_case_when(args)
args = [
(
com.apply_if_callable(condition, self),
com.apply_if_callable(replacement, self),
)
for condition, replacement in args
]
samukweku marked this conversation as resolved.
Show resolved Hide resolved
return case_when(*args, default=self, level=None)

# error: Cannot determine type of 'isna'
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type]
def isna(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class TestPDApi(Base):
funcs = [
"array",
"bdate_range",
"case_when",
"concat",
"crosstab",
"cut",
Expand Down
Loading
Loading