Skip to content

Commit

Permalink
ENH: Add case_when method (pandas-dev#56059)
Browse files Browse the repository at this point in the history
  • Loading branch information
samukweku authored and pmhatre1 committed May 7, 2024
1 parent c270eb3 commit aeaf086
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Reindexing / selection / label manipulation
:toctree: api/

Series.align
Series.case_when
Series.drop
Series.droplevel
Series.drop_duplicates
Expand Down
20 changes: 20 additions & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,26 @@ For a full list of ADBC drivers and their development status, see the `ADBC Driv
Implementation Status <https://arrow.apache.org/adbc/current/driver/status.html>`_
documentation.

.. _whatsnew_220.enhancements.case_when:

Create a pandas Series based on one or more conditions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :meth:`Series.case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`)

.. ipython:: python
import pandas as pd
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
default=pd.Series('default', index=df.index)
default.case_when(
caselist=[
(df.a == 1, 'first'), # condition, replacement
(df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement
],
)
.. _whatsnew_220.enhancements.to_numpy_ea:

``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype
Expand Down
124 changes: 123 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
from pandas.core.dtypes.astype import astype_is_view
from pandas.core.dtypes.cast import (
LossySetitemError,
construct_1d_arraylike_from_scalar,
find_common_type,
infer_dtype_from,
maybe_box_native,
maybe_cast_pointwise_result,
)
Expand All @@ -85,7 +88,10 @@
CategoricalDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
isna,
Expand Down Expand Up @@ -114,6 +120,7 @@
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.construction import (
array as pd_array,
extract_array,
sanitize_array,
)
Expand Down Expand Up @@ -5631,6 +5638,121 @@ def between(

return lmask & rmask

def case_when(
self,
caselist: list[
tuple[
ArrayLike | Callable[[Series], Series | np.ndarray | Sequence[bool]],
ArrayLike | Scalar | Callable[[Series], Series | np.ndarray],
],
],
) -> Series:
"""
Replace values where the conditions are True.
Parameters
----------
caselist : A list of tuples of conditions and expected replacements
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array-like object
or a callable. If ``condition`` is a callable,
it is computed on the Series
and should return a boolean Series or array.
The callable must not change the input Series
(though pandas doesn`t check it). ``replacement`` should be a
1-D array-like object, a scalar or a callable.
If ``replacement`` is a callable, it is computed on the Series
and should return a scalar or Series. The callable
must not change the input Series
(though pandas doesn`t check it).
.. versionadded:: 2.2.0
Returns
-------
Series
See Also
--------
Series.mask : Replace values where the condition is True.
Examples
--------
>>> c = pd.Series([6, 7, 8, 9], name='c')
>>> a = pd.Series([0, 0, 1, 2])
>>> b = pd.Series([0, 3, 4, 5])
>>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement
... (b.gt(0), b)])
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""
if not isinstance(caselist, list):
raise TypeError(
f"The caselist argument should be a list; instead got {type(caselist)}"
)

if not caselist:
raise ValueError(
"provide at least one boolean condition, "
"with a corresponding replacement."
)

for num, entry in enumerate(caselist):
if not isinstance(entry, tuple):
raise TypeError(
f"Argument {num} must be a tuple; instead got {type(entry)}."
)
if len(entry) != 2:
raise ValueError(
f"Argument {num} must have length 2; "
"a condition and replacement; "
f"instead got length {len(entry)}."
)
caselist = [
(
com.apply_if_callable(condition, self),
com.apply_if_callable(replacement, self),
)
for condition, replacement in caselist
]
default = self.copy()
conditions, replacements = zip(*caselist)
common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]]
if len(set(common_dtypes)) > 1:
common_dtype = find_common_type(common_dtypes)
updated_replacements = []
for condition, replacement in zip(conditions, replacements):
if is_scalar(replacement):
replacement = construct_1d_arraylike_from_scalar(
value=replacement, length=len(condition), dtype=common_dtype
)
elif isinstance(replacement, ABCSeries):
replacement = replacement.astype(common_dtype)
else:
replacement = pd_array(replacement, dtype=common_dtype)
updated_replacements.append(replacement)
replacements = updated_replacements
default = default.astype(common_dtype)

counter = reversed(range(len(conditions)))
for position, condition, replacement in zip(
counter, conditions[::-1], replacements[::-1]
):
try:
default = default.mask(
condition, other=replacement, axis=0, inplace=False, level=None
)
except Exception as error:
raise ValueError(
f"Failed to apply condition{position} and replacement{position}."
) from error
return default

# error: Cannot determine type of 'isna'
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type]
def isna(self) -> Series:
Expand Down
148 changes: 148 additions & 0 deletions pandas/tests/series/methods/test_case_when.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
import pytest

from pandas import (
DataFrame,
Series,
array as pd_array,
date_range,
)
import pandas._testing as tm


@pytest.fixture
def df():
"""
base dataframe for testing
"""
return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


def test_case_when_caselist_is_not_a_list(df):
"""
Raise ValueError if caselist is not a list.
"""
msg = "The caselist argument should be a list; "
msg += "instead got.+"
with pytest.raises(TypeError, match=msg): # GH39154
df["a"].case_when(caselist=())


def test_case_when_no_caselist(df):
"""
Raise ValueError if no caselist is provided.
"""
msg = "provide at least one boolean condition, "
msg += "with a corresponding replacement."
with pytest.raises(ValueError, match=msg): # GH39154
df["a"].case_when([])


def test_case_when_odd_caselist(df):
"""
Raise ValueError if no of caselist is odd.
"""
msg = "Argument 0 must have length 2; "
msg += "a condition and replacement; instead got length 3."

with pytest.raises(ValueError, match=msg):
df["a"].case_when([(df["a"].eq(1), 1, df.a.gt(1))])


def test_case_when_raise_error_from_mask(df):
"""
Raise Error from within Series.mask
"""
msg = "Failed to apply condition0 and replacement0."
with pytest.raises(ValueError, match=msg):
df["a"].case_when([(df["a"].eq(1), [1, 2])])


def test_case_when_single_condition(df):
"""
Test output on a single condition.
"""
result = Series([np.nan, np.nan, np.nan]).case_when([(df.a.eq(1), 1)])
expected = Series([1, np.nan, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions(df):
"""
Test output when booleans are derived from a computation
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[(df.a.eq(1), 1), (Series([False, True, False]), 2)]
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions_replacement_list(df):
"""
Test output when replacement is a list
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3])]
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions_replacement_extension_dtype(df):
"""
Test output when replacement has an extension dtype
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[
([True, False, False], 1),
(df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")),
],
)
expected = Series([1, 2, np.nan], dtype="Float64")
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions_replacement_series(df):
"""
Test output when replacement is a Series
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[
(np.array([True, False, False]), 1),
(df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])),
],
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_non_range_index():
"""
Test output if index is not RangeIndex
"""
rng = np.random.default_rng(seed=123)
dates = date_range("1/1/2000", periods=8)
df = DataFrame(
rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"]
)
result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)])
expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5)
tm.assert_series_equal(result, expected)


def test_case_when_callable():
"""
Test output on a callable
"""
# https://numpy.org/doc/stable/reference/generated/numpy.piecewise.html
x = np.linspace(-2.5, 2.5, 6)
ser = Series(x)
result = ser.case_when(
caselist=[
(lambda df: df < 0, lambda df: -df),
(lambda df: df >= 0, lambda df: df),
]
)
expected = np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x])
tm.assert_series_equal(result, Series(expected))

0 comments on commit aeaf086

Please sign in to comment.