Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add case_when method #56059

Merged
merged 43 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f48502f
updates
samukweku Nov 19, 2023
40057c7
add test for default if Series
samukweku Nov 19, 2023
4a8be16
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Nov 23, 2023
089bbe6
updates based on feedback
samukweku Nov 23, 2023
bcfd458
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Nov 30, 2023
b95ce55
updates based on feedback
samukweku Nov 30, 2023
acc3fdb
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Nov 30, 2023
8be4349
update typing hints for *args, based on feedback
samukweku Nov 30, 2023
8d08458
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Dec 1, 2023
ec18086
update typehints; add caselist argument - based on feedback
samukweku Dec 1, 2023
0b72fbb
cleanup docstrings
samukweku Dec 1, 2023
0085956
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Dec 15, 2023
a441481
support method only for case_when
samukweku Dec 15, 2023
29ad697
minor update
samukweku Dec 15, 2023
bf740f9
fix test
samukweku Dec 15, 2023
264a675
remove redundant tests
samukweku Dec 15, 2023
2a3035e
cleanup docs
samukweku Dec 15, 2023
5e33304
use singular version - common_dtype
samukweku Dec 22, 2023
5c7c287
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Dec 22, 2023
8569cd1
fix doctest failure
samukweku Dec 23, 2023
bbb5887
fix for whatnew
samukweku Dec 23, 2023
e03e3dc
Update doc/source/whatsnew/v2.2.0.rst
samukweku Dec 23, 2023
283488f
Update v2.2.0.rst
phofl Dec 23, 2023
7a8694c
Update v2.2.0.rst
phofl Dec 23, 2023
f6cf725
Merge remote-tracking branch 'upstream/main' into samukweku/case_when…
samukweku Dec 24, 2023
67dfcaa
improve typing and add test for callable
samukweku Dec 24, 2023
3da7cf2
fix typing error
samukweku Dec 24, 2023
bdc54f6
Update pandas/core/series.py
samukweku Dec 25, 2023
649fb84
Merge branch 'main' into samukweku/case_when_function
rhshadrach Dec 27, 2023
b68d20e
Update doc/source/whatsnew/v2.2.0.rst
samukweku Dec 28, 2023
b4de208
PERF: resolution, is_normalized (#56637)
jbrockmendel Dec 27, 2023
5966bfe
TYP: more simple return types from ruff (#56628)
twoertwein Dec 27, 2023
3e404fa
ENH: Update CFF with publication reference, Zenodo DOI, and other det…
cgobat Dec 27, 2023
21659bc
DOC: Fixup CoW userguide (#56636)
phofl Dec 27, 2023
f6d8cd0
REF: check monotonicity inside _can_use_libjoin (#55342)
jbrockmendel Dec 27, 2023
becc626
DOC: Minor fixups for 2.2.0 whatsnew (#56632)
rhshadrach Dec 27, 2023
918a19e
TYP: Fix some PythonParser and Plotting types (#56643)
twoertwein Dec 27, 2023
5744df2
BUG: Series.to_numpy raising for arrow floats to numpy floats (#56644)
phofl Dec 28, 2023
bc6ba0e
updates based on feedback
samukweku Dec 28, 2023
a0f4797
add to API reference
samukweku Dec 28, 2023
cb7d6e3
fix whitespace
samukweku Dec 28, 2023
c8f0e2e
updates
samukweku Dec 28, 2023
9679b9e
Update series.py
samukweku Jan 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ including other versions of pandas.

Enhancements
~~~~~~~~~~~~
.. _whatsnew_220.enhancements.case_when:

Create a pandas Series based on one or more conditions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`)

.. ipython:: python

import pandas as pd

df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
pd.case_when(
(df.a == 1, 'first'), # condition, replacement
(df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement
default = 'default', # optional
)

.. _whatsnew_220.enhancements.adbc_support:

Expand Down
2 changes: 2 additions & 0 deletions pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
notnull,
# indexes
Index,
case_when,
CategoricalIndex,
RangeIndex,
MultiIndex,
Expand Down Expand Up @@ -253,6 +254,7 @@
"ArrowDtype",
"BooleanDtype",
"Categorical",
"case_when",
"CategoricalDtype",
"CategoricalIndex",
"DataFrame",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
UInt64Dtype,
)
from pandas.core.arrays.string_ import StringDtype
from pandas.core.case_when import case_when
from pandas.core.construction import array
from pandas.core.flags import Flags
from pandas.core.groupby import (
Expand Down Expand Up @@ -86,6 +87,7 @@
"bdate_range",
"BooleanDtype",
"Categorical",
"case_when",
"CategoricalDtype",
"CategoricalIndex",
"DataFrame",
Expand Down
212 changes: 212 additions & 0 deletions pandas/core/case_when.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from pandas._libs import lib

from pandas.core.dtypes.cast import (
construct_1d_arraylike_from_scalar,
find_common_type,
infer_dtype_from,
)
from pandas.core.dtypes.common import is_scalar
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.construction import array as pd_array

if TYPE_CHECKING:
from pandas._typing import (
ArrayLike,
Scalar,
Series,
)


def case_when(
*args: tuple[tuple[tuple[ArrayLike], tuple[ArrayLike | Scalar]]],
samukweku marked this conversation as resolved.
Show resolved Hide resolved
default: ArrayLike | Scalar = lib.no_default,
) -> Series:
"""
Replace values where the conditions are True.

Parameters
----------
*args : tuple(s) of array-like, scalar
Variable argument of tuples of conditions and expected replacements.
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array.
When multiple boolean conditions are satisfied,
the first replacement is used.
If ``condition`` is a Series, and the equivalent ``replacement``
is a Series, they must have the same index.
If there are multiple replacement options,
and they are Series, they must have the same index.

default : scalar, array-like, default None
If provided, it is the replacement value to use
if all conditions evaluate to False.
If not specified, entries will be filled with the
corresponding NULL value.

.. versionadded:: 2.2.0

Returns
-------
Series

See Also
--------
Series.mask : Replace values where the condition is True.

Examples
--------
>>> df = pd.DataFrame({
... "a": [0,0,1,2],
... "b": [0,3,4,5],
... "c": [6,7,8,9]
... })
>>> df
a b c
0 0 0 6
1 0 3 7
2 1 4 8
3 2 5 9

>>> pd.case_when((df.a.gt(0), df.a), # condition, replacement
... (df.b.gt(0), df.b), # condition, replacement
... default=df.c) # optional
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""
from pandas import Series

validate_case_when(args=args)

conditions, replacements = zip(*args)
common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements]

if default is not lib.no_default:
arg_dtype, _ = infer_dtype_from(default)
common_dtypes.append(arg_dtype)
else:
default = None
if len(set(common_dtypes)) > 1:
common_dtypes = find_common_type(common_dtypes)
updated_replacements = []
for condition, replacement in zip(conditions, replacements):
if is_scalar(replacement):
replacement = construct_1d_arraylike_from_scalar(
value=replacement, length=len(condition), dtype=common_dtypes
)
elif isinstance(replacement, ABCSeries):
replacement = replacement.astype(common_dtypes)
else:
replacement = pd_array(replacement, dtype=common_dtypes)
updated_replacements.append(replacement)
replacements = updated_replacements
if (default is not None) and isinstance(default, ABCSeries):
default = default.astype(common_dtypes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @MarcoGorelli

This might upcast, thoughts related to PDEP6?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is related to PDEP6, we are creating a new Series and not doing something like setitem.

else:
common_dtypes = common_dtypes[0]
if not isinstance(default, ABCSeries):
cond_indices = [cond for cond in conditions if isinstance(cond, ABCSeries)]
replacement_indices = [
replacement
for replacement in replacements
if isinstance(replacement, ABCSeries)
]
cond_length = None
if replacement_indices:
for left, right in zip(replacement_indices, replacement_indices[1:]):
if not left.index.equals(right.index):
raise AssertionError(
"All replacement objects must have the same index."
)
if cond_indices:
for left, right in zip(cond_indices, cond_indices[1:]):
if not left.index.equals(right.index):
raise AssertionError(
"All condition objects must have the same index."
)
if replacement_indices:
if not replacement_indices[0].index.equals(cond_indices[0].index):
raise AssertionError(
"All replacement objects and condition objects "
"should have the same index."
)
else:
conditions = [
np.asanyarray(cond) if not hasattr(cond, "shape") else cond
for cond in conditions
]
cond_length = {len(cond) for cond in conditions}
if len(cond_length) > 1:
raise ValueError("The boolean conditions should have the same length.")
cond_length = len(conditions[0])
if not is_scalar(default):
if len(default) != cond_length:
raise ValueError(
"length of `default` does not match the length "
"of any of the conditions."
)
if not replacement_indices:
for num, replacement in enumerate(replacements):
if is_scalar(replacement):
continue
if not hasattr(replacement, "shape"):
replacement = np.asanyarray(replacement)
if len(replacement) != cond_length:
raise ValueError(
f"Length of condition{num} does not match "
f"the length of replacement{num}; "
f"{cond_length} != {len(replacement)}"
)
if cond_indices:
default_index = cond_indices[0].index
elif replacement_indices:
default_index = replacement_indices[0].index
else:
default_index = range(cond_length)
default = Series(default, index=default_index, dtype=common_dtypes)
counter = reversed(range(len(conditions)))
for position, condition, replacement in zip(
counter, conditions[::-1], replacements[::-1]
):
try:
default = default.mask(
condition, other=replacement, axis=0, inplace=False, level=None
)
except Exception as error:
raise ValueError(
f"Failed to apply condition{position} and replacement{position}."
) from error
return default


def validate_case_when(args: tuple) -> tuple:
samukweku marked this conversation as resolved.
Show resolved Hide resolved
"""
Validates the variable arguments for the case_when function.
"""

if not len(args):
Dr-Irv marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"provide at least one boolean condition, "
"with a corresponding replacement."
)

for num, entry in enumerate(args):
if not isinstance(entry, tuple):
raise TypeError(f"Argument {num} must be a tuple; got {type(entry)}.")
if len(entry) != 2:
raise ValueError(
f"Argument {num} must have length 2; "
"a condition and replacement; "
f"got length {len(entry)}."
)
return None
78 changes: 78 additions & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.case_when import (
case_when,
validate_case_when,
)
from pandas.core.construction import (
extract_array,
sanitize_array,
Expand Down Expand Up @@ -5582,6 +5586,80 @@ def between(

return lmask & rmask

def case_when(
samukweku marked this conversation as resolved.
Show resolved Hide resolved
self,
*args: tuple[tuple[tuple[ArrayLike], tuple[ArrayLike | Scalar]]],
) -> Series:
"""
Replace values where the conditions are True.

Parameters
----------
*args : tuple(s) of array-like, scalar.
Variable argument of tuples of conditions and expected replacements.
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array-like object
or a callable. If ``condition`` is a callable,
it is computed on the Series
and should return a boolean Series or array.
samukweku marked this conversation as resolved.
Show resolved Hide resolved
The callable must not change the input Series
(though pandas doesn`t check it). ``replacement`` should be a
1-D array-like object, a scalar or a callable.
If ``replacement`` is a callable, it is computed on the Series
and should return a scalar or Series. The callable
must not change the input Series
(though pandas doesn`t check it).
If ``condition`` is a Series, and the equivalent ``replacement``
is a Series, they must have the same index.
If there are multiple replacement options,
and they are Series, they must have the same index.

level : int, default None
Alignment level if needed.

.. versionadded:: 2.2.0

Returns
-------
Series

See Also
--------
Series.mask : Replace values where the condition is True.

Examples
--------
>>> df = pd.DataFrame({
samukweku marked this conversation as resolved.
Show resolved Hide resolved
... "a": [0,0,1,2],
... "b": [0,3,4,5],
... "c": [6,7,8,9]
... })
>>> df
a b c
0 0 0 6
1 0 3 7
2 1 4 8
3 2 5 9

>>> df.c.case_when((df.a.gt(0), df.a), # condition, replacement
... (df.b.gt(0), df.b))
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""
validate_case_when(args)
args = [
(
com.apply_if_callable(condition, self),
com.apply_if_callable(replacement, self),
)
for condition, replacement in args
]
return case_when(*args, default=self, level=None)

# error: Cannot determine type of 'isna'
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type]
def isna(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class TestPDApi(Base):
funcs = [
"array",
"bdate_range",
"case_when",
"concat",
"crosstab",
"cut",
Expand Down
Loading
Loading