Skip to content

Commit

Permalink
String dtype: allow string dtype in query/eval with default numexpr e…
Browse files Browse the repository at this point in the history
…ngine (#59810)

String dtype: allow string dtype in query/eval with default numexpr engine
  • Loading branch information
jorisvandenbossche authored Sep 16, 2024
1 parent 160b3eb commit 013ac67
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 22 deletions.
12 changes: 9 additions & 3 deletions pandas/core/computation/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from pandas.util._exceptions import find_stack_level
from pandas.util._validators import validate_bool_kwarg

from pandas.core.dtypes.common import is_extension_array_dtype
from pandas.core.dtypes.common import (
is_extension_array_dtype,
is_string_dtype,
)

from pandas.core.computation.engines import ENGINES
from pandas.core.computation.expr import (
Expand Down Expand Up @@ -345,10 +348,13 @@ def eval(
parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)

if engine == "numexpr" and (
is_extension_array_dtype(parsed_expr.terms.return_type)
(
is_extension_array_dtype(parsed_expr.terms.return_type)
and not is_string_dtype(parsed_expr.terms.return_type)
)
or getattr(parsed_expr.terms, "operand_types", None) is not None
and any(
is_extension_array_dtype(elem)
(is_extension_array_dtype(elem) and not is_string_dtype(elem))
for elem in parsed_expr.terms.operand_types
)
):
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/computation/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from pandas.errors import UndefinedVariableError

from pandas.core.dtypes.common import is_string_dtype

import pandas.core.common as com
from pandas.core.computation.ops import (
ARITH_OPS_SYMS,
Expand Down Expand Up @@ -524,10 +526,12 @@ def _maybe_evaluate_binop(
elif self.engine != "pytables":
if (
getattr(lhs, "return_type", None) == object
or is_string_dtype(getattr(lhs, "return_type", None))
or getattr(rhs, "return_type", None) == object
or is_string_dtype(getattr(rhs, "return_type", None))
):
# evaluate "==" and "!=" in python if either of our operands
# has an object return type
# has an object or string return type
return self._maybe_eval(res, eval_in_python + maybe_eval_in_python)
return res

Expand Down
24 changes: 6 additions & 18 deletions pandas/tests/frame/test_query_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.errors import (
NumExprClobberingError,
UndefinedVariableError,
Expand Down Expand Up @@ -762,7 +760,6 @@ def test_inf(self, op, f, engine, parser):
result = df.query(q, engine=engine, parser=parser)
tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_check_tz_aware_index_query(self, tz_aware_fixture):
# https://github.com/pandas-dev/pandas/issues/29463
tz = tz_aware_fixture
Expand All @@ -775,6 +772,7 @@ def test_check_tz_aware_index_query(self, tz_aware_fixture):
tm.assert_frame_equal(result, expected)

expected = DataFrame(df_index)
expected.columns = expected.columns.astype(object)
result = df.reset_index().query('"2018-01-03 00:00:00+00" < time')
tm.assert_frame_equal(result, expected)

Expand Down Expand Up @@ -1072,7 +1070,7 @@ def test_query_with_string_columns(self, parser, engine):
with pytest.raises(NotImplementedError, match=msg):
df.query("a in b and c < d", parser=parser, engine=engine)

def test_object_array_eq_ne(self, parser, engine, using_infer_string):
def test_object_array_eq_ne(self, parser, engine):
df = DataFrame(
{
"a": list("aaaabbbbcccc"),
Expand All @@ -1081,14 +1079,11 @@ def test_object_array_eq_ne(self, parser, engine, using_infer_string):
"d": np.random.default_rng(2).integers(9, size=12),
}
)
warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None
with tm.assert_produces_warning(warning):
res = df.query("a == b", parser=parser, engine=engine)
res = df.query("a == b", parser=parser, engine=engine)
exp = df[df.a == df.b]
tm.assert_frame_equal(res, exp)

with tm.assert_produces_warning(warning):
res = df.query("a != b", parser=parser, engine=engine)
res = df.query("a != b", parser=parser, engine=engine)
exp = df[df.a != df.b]
tm.assert_frame_equal(res, exp)

Expand Down Expand Up @@ -1128,15 +1123,13 @@ def test_query_with_nested_special_character(self, parser, engine):
],
)
def test_query_lex_compare_strings(
self, parser, engine, op, func, using_infer_string
self, parser, engine, op, func
):
a = Series(np.random.default_rng(2).choice(list("abcde"), 20))
b = Series(np.arange(a.size))
df = DataFrame({"X": a, "Y": b})

warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None
with tm.assert_produces_warning(warning):
res = df.query(f'X {op} "d"', engine=engine, parser=parser)
res = df.query(f'X {op} "d"', engine=engine, parser=parser)
expected = df[func(df.X, "d")]
tm.assert_frame_equal(res, expected)

Expand Down Expand Up @@ -1400,15 +1393,13 @@ def test_expr_with_column_name_with_backtick(self):
expected = df[df["a`b"] < 2]
tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_expr_with_string_with_backticks(self):
# GH 59285
df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"])
result = df.query("'```' < `#backticks`")
expected = df["```" < df["#backticks"]]
tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_expr_with_string_with_backticked_substring_same_as_column_name(self):
# GH 59285
df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"])
Expand Down Expand Up @@ -1439,7 +1430,6 @@ def test_expr_with_column_names_with_special_characters(self, col1, col2, expr):
expected = df[df[col1] < df[col2]]
tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_expr_with_no_backticks(self):
# GH 59285
df = DataFrame(("aaa", "vvv", "zzz"), columns=["column_name"])
Expand Down Expand Up @@ -1483,15 +1473,13 @@ def test_expr_with_quote_opened_before_backtick_and_quote_is_unmatched(self):
):
df.query("`column-name` < 'It`s that\\'s \"quote\" #hash")

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_at_end(self):
# GH 59285
df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"])
result = df.query("`column-name` < 'It`s that\\'s \"quote\" #hash'")
expected = df[df["column-name"] < 'It`s that\'s "quote" #hash']
tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_in_mid(self):
# GH 59285
df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"])
Expand Down

0 comments on commit 013ac67

Please sign in to comment.