Skip to content

Commit

Permalink
use tuples for conditions and replacements; add index checking logic
Browse files Browse the repository at this point in the history
  • Loading branch information
samukweku committed Oct 28, 2023
1 parent cf6d1d1 commit c9a300d
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 65 deletions.
143 changes: 108 additions & 35 deletions pandas/core/case_when.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from __future__ import annotations

from itertools import zip_longest
from typing import TYPE_CHECKING

import numpy as np

from pandas.core.dtypes.cast import (
construct_1d_arraylike_from_scalar,
find_common_type,
infer_dtype_from,
)
from pandas.core.dtypes.common import is_scalar
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCNDFrame,
)

from pandas.core.construction import array as pd_array

Expand All @@ -21,7 +27,7 @@


def case_when(
*args: ListLike | Scalar,
*args: tuple[tuple],
default: Scalar | ListLike | None = None,
level: int | None = None,
) -> Series:
Expand All @@ -33,9 +39,9 @@ def case_when(
Parameters
----------
*args : tuple
Variable argument of conditions and expected replacements.
Takes the form: `condition0`, `replacement0`,
`condition1`, `replacement1`, ... .
Variable argument of tuples of conditions and expected replacements.
Takes the form: (`condition0`, `replacement0`),
(`condition1`, `replacement1`), ... .
`condition` should be a 1-D boolean array.
When multiple boolean conditions are satisfied,
the first replacement is used.
Expand Down Expand Up @@ -74,8 +80,8 @@ def case_when(
2 1 4 8
3 2 5 9
>>> pd.case_when(df.a.gt(0), df.a, # condition, replacement
... df.b.gt(0), df.b,
>>> pd.case_when((df.a.gt(0), df.a), # condition, replacement
... (df.b.gt(0), df.b), # condition, replacement
... default=df.c) # optional
0 6
1 3
Expand All @@ -84,31 +90,41 @@ def case_when(
Name: c, dtype: int64
"""
from pandas import Series
from pandas._testing.asserters import assert_index_equal

# summary of logic -
# get booleans and replacements
# and iteratively get the final output via Series.mask
# iteration starts from the end,
# ensuring that if there are any booleans with the same replacement,
# the first one is taken
# as much as possible the computation is pushed to Series.mask
len_args = len(args)
if not len_args:
# as much as possible the computation is pushed to Series.mask
if not len(args):
raise ValueError(
"Kindly provide at least one boolean condition, "
"with a corresponding replacement."
)
if len_args % 2:
raise ValueError(
"The number of boolean conditions should be equal "
"to the number of replacements. "
"However, the total number of conditions and replacements "
f"is {len(args)}, which is an odd number."
)

conditions = args[-2::-2]
replacements = args[-1::-2]
common_dtype = [infer_dtype_from(replacement)[0] for replacement in replacements]
for num, entry in enumerate(args):
if not isinstance(entry, tuple):
raise TypeError(
f"Argument{num} should be a tuple, instead, got {type(entry)}."
)
if len(entry) != 2:
raise ValueError(
f"Argument{num} should have a condition "
"and a replacement - it should have a length 2; "
f"however, it has a length {len(entry)}."
)
conditions, replacements = zip(*args)
common_dtype = []
for replacement in replacements:
if isinstance(replacement, ABCDataFrame):
_dtype = [arr.dtype for _, arr in replacement.items()]
common_dtype.extend(_dtype)
else:
_dtype = infer_dtype_from(replacement)[0]
common_dtype.append(_dtype)
if default is not None:
arg_dtype, _ = infer_dtype_from(default)
common_dtype.append(arg_dtype)
Expand All @@ -119,26 +135,83 @@ def case_when(
value=replacement, length=len(condition), dtype=common_dtype
)
if is_scalar(replacement)
else replacement.astype(common_dtype, copy=False)
if isinstance(replacement, ABCSeries)
else pd_array(replacement, dtype=common_dtype, copy=False)
else replacement.astype(common_dtype)
if isinstance(replacement, ABCDataFrame)
else pd_array(replacement, dtype=common_dtype)
for condition, replacement in zip(conditions, replacements)
]
if (default is not None) and isinstance(default, ABCSeries):
default = default.astype(common_dtype, copy=False)
if (default is not None) and isinstance(default, ABCNDFrame):
default = default.astype(common_dtype)
else:
common_dtype = common_dtype[0]
# TODO: possibly extend this to a DataFrame?
if not isinstance(default, ABCSeries):
default_index = len(args[0])
for condition in conditions[::-1]:
if isinstance(condition, ABCSeries):
default_index = condition.index
break
if not isinstance(default, ABCNDFrame):
cond_series = [cond for cond in conditions if isinstance(cond, ABCNDFrame)]
replacement_series = [
replacement
for replacement in replacements
if isinstance(replacement, ABCNDFrame)
]
cond_length = None
if replacement_series:
for left, right in zip_longest(replacement_series, replacement_series[1:]):
if right is None:
continue
if assert_index_equal(left.index, right.index, check_order=False):
raise ValueError(
"All replacement objects must have the same index."
)
if cond_series:
for left, right in zip_longest(cond_series, cond_series[1:]):
if right is None:
continue
if assert_index_equal(left.index, right.index, check_order=False):
raise ValueError("All condition objects must have the same index.")
if replacement_series and cond_series:
if assert_index_equal(
replacement_series[0].index, cond_series[0].index, check_order=False
):
raise ValueError(
"All replacement objects and condition objects "
"should have the same index."
)
elif not cond_series:
conditions = [
np.asanyarray(cond) if not hasattr(cond, "shape") else cond
for cond in conditions
]
cond_length = {len(cond) for cond in conditions}
if len(cond_length) > 1:
raise ValueError("The boolean conditions should be of the same length.")
cond_length = len(conditions[0])
if not is_scalar(default):
if len(default) != cond_length:
raise ValueError(
"length of `default` does not match the length "
"of any of the boolean conditions."
)
if not replacement_series:
for num, replacement in enumerate(replacements):
if is_scalar(replacement):
continue
if not hasattr(replacement, "shape"):
replacement = np.asanyarray(replacement)
if len(replacement) != cond_length:
raise ValueError(
f"Length of condition{num} does not match "
f"the length of replacement{num}; "
f"{cond_length} != {len(replacement)}"
)
if cond_series:
default_index = cond_series[0].index
elif replacement_series:
default_index = replacement_series[0].index
elif not cond_series and not replacement_series:
default_index = range(cond_length)
default = Series(default, index=default_index, dtype=common_dtype)
counter = len_args // 2 - 1
counter = range(counter, -1, -1)
for position, condition, replacement in zip(counter, conditions, replacements):
counter = reversed(range(len(conditions)))
for position, condition, replacement in zip(
counter, conditions[::-1], replacements[::-1]
):
try:
default = default.mask(
condition, other=replacement, axis=0, inplace=False, level=level
Expand Down
15 changes: 7 additions & 8 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@
IndexKeyFunc,
IndexLabel,
Level,
ListLike,
MutableMappingT,
NaPosition,
NumpySorter,
Expand Down Expand Up @@ -5555,7 +5554,7 @@ def between(

def case_when(
self,
*args: ListLike | Callable | Scalar,
*args: tuple[tuple],
level: int | None = None,
) -> Series:
"""
Expand All @@ -5565,10 +5564,10 @@ def case_when(
Parameters
----------
*args : array-like, scalar
Variable argument of conditions and expected replacements.
Takes the form: `condition0`, `replacement0`,
`condition1`, `replacement1`, ... .
*args : tuple
Variable argument of tuples of conditions and expected replacements.
Takes the form: (`condition0`, `replacement0`),
(`condition1`, `replacement1`), ... .
`condition` should be a 1-D boolean array-like object
or a callable. If `condition` is a callable,
it is computed on the Series
Expand Down Expand Up @@ -5608,8 +5607,8 @@ def case_when(
2 1 4 8
3 2 5 9
>>> df.c.case_when(df.a.gt(0), df.a, # condition, replacement
... df.b.gt(0), df.b)
>>> df.c.case_when((df.a.gt(0), df.a), # condition, replacement
... (df.b.gt(0), df.b))
0 6
1 3
2 1
Expand Down
37 changes: 15 additions & 22 deletions pandas/tests/test_case_when.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ def test_case_when_odd_args(df):
"""
Raise ValueError if an argument tuple does not have a length of 2.
"""
msg = "The number of boolean conditions should be equal "
msg += "to the number of replacements. "
msg += "However, the total number of conditions and replacements "
msg += "is 3, which is an odd number."
msg = "Argument0 should have a condition "
msg += "and a replacement - it should have a length 2; "
msg += "however, it has a length 3."
with pytest.raises(ValueError, match=msg):
case_when(df["a"].eq(1), 1, df.a.gt(1))
case_when((df["a"].eq(1), 1, df.a.gt(1)))


def test_case_when_raise_error_from_mask(df):
Expand All @@ -48,14 +47,14 @@ def test_case_when_raise_error_from_mask(df):
msg = "condition0 and replacement0 failed to evaluate. "
msg += "Original error message.+"
with pytest.raises(ValueError, match=msg):
case_when(df["a"].eq(1), df)
case_when((df["a"].eq(1), df))


def test_case_when_single_condition(df):
"""
Test output on a single condition.
"""
result = case_when(df.a.eq(1), 1)
result = case_when((df.a.eq(1), 1))
expected = Series([1, np.nan, np.nan])
tm.assert_series_equal(result, expected)

Expand All @@ -64,7 +63,7 @@ def test_case_when_multiple_conditions(df):
"""
Test output when booleans are derived from a computation
"""
result = case_when(df.a.eq(1), 1, Series([False, True, False]), 2)
result = case_when((df.a.eq(1), 1), (Series([False, True, False]), 2))
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)

Expand All @@ -74,7 +73,7 @@ def test_case_when_multiple_conditions_replacement_list(df):
Test output when replacement is a list
"""
result = case_when(
[True, False, False], 1, df["a"].gt(1) & df["b"].eq(5), [1, 2, 3]
([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3])
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)
Expand All @@ -85,10 +84,8 @@ def test_case_when_multiple_conditions_replacement_extension_dtype(df):
Test output when replacement has an extension dtype
"""
result = case_when(
[True, False, False],
1,
df["a"].gt(1) & df["b"].eq(5),
pd_array([1, 2, 3], dtype="Int64"),
([True, False, False], 1),
(df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")),
)
expected = Series([1, 2, np.nan], dtype="Int64")
tm.assert_series_equal(result, expected)
Expand All @@ -99,10 +96,8 @@ def test_case_when_multiple_conditions_replacement_series(df):
Test output when replacement is a Series
"""
result = case_when(
np.array([True, False, False]),
1,
df["a"].gt(1) & df["b"].eq(5),
Series([1, 2, 3]),
(np.array([True, False, False]), 1),
(df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])),
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)
Expand All @@ -113,10 +108,8 @@ def test_case_when_multiple_conditions_default_is_not_none(df):
Test output when default is not None
"""
result = case_when(
[True, False, False],
1,
df["a"].gt(1) & df["b"].eq(5),
Series([1, 2, 3]),
([True, False, False], 1),
(df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])),
default=-1,
)
expected = Series([1, 2, -1])
Expand All @@ -132,7 +125,7 @@ def test_case_when_non_range_index():
df = DataFrame(
rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"]
)
result = case_when(df.A.gt(0), df.B, default=5)
result = case_when((df.A.gt(0), df.B), default=5)
result = Series(result, name="A")
expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5)
tm.assert_series_equal(result, expected)

0 comments on commit c9a300d

Please sign in to comment.