Skip to content

Commit

Permalink
feat(expr): add short circuit checker in is_const (#15758)
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh authored Mar 19, 2024
1 parent 979d710 commit 1ccb354
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 3 deletions.
64 changes: 64 additions & 0 deletions src/frontend/planner_test/tests/testdata/input/short_circuit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# reference: <https://github.com/risingwavelabs/risingwave/issues/15724>

- id: create_table
sql: |
create table t1 (c1 INT, c2 INT, c3 INT);
expected_outputs: []

- id : short_circuit_or_pattern
before:
- create_table
sql: |
select true or 'abc'::int > 1;
expected_outputs:
- logical_plan
- batch_plan

- id : short_circuit_and_pattern
before:
- create_table
sql: |
select false and 'abc'::int > 1;
expected_outputs:
- logical_plan
- batch_plan

- id : short_circuit_or_pattern_with_table
before:
- create_table
sql: |
select true or 'abc'::int > c1 from t1;
expected_outputs:
- logical_plan
- batch_plan

- id : short_circuit_and_pattern_with_table
before:
- create_table
sql: |
select false and 'abc'::int > c1 from t1;
expected_outputs:
- logical_plan
- batch_plan

# should *not* be identified as const
# otherwise the *semantic* will be inconsistent
# ----
# - id : short_circuit_or_panic_pattern
# before:
# - create_table
# sql: |
# select 'abc'::int > c1 or true from t1;
# expected_outputs:
# - logical_plan
# - batch_plan
# ----
# - id : short_circuit_and_panic_pattern
# before:
# - create_table
# sql: |
# select 'abc'::int > c1 and false from t1;
# expected_outputs:
# - logical_plan
# - batch_plan
# ----
2 changes: 1 addition & 1 deletion src/frontend/planner_test/tests/testdata/output/expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
create table t (v1 int);
SELECT 1 in (3, 0.5*2, min(v1)) from t;
batch_plan: |-
BatchProject { exprs: [(true:Boolean OR (1:Int32 = min(min(t.v1)))) as $expr1] }
BatchProject { exprs: [true:Boolean] }
└─BatchSimpleAgg { aggs: [min(min(t.v1))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [min(t.v1)] }
Expand Down
46 changes: 46 additions & 0 deletions src/frontend/planner_test/tests/testdata/output/short_circuit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# This file is automatically generated. See `src/frontend/planner_test/README.md` for more information.
- id: create_table
sql: |
create table t1 (c1 INT, c2 INT, c3 INT);
- id: short_circuit_or_pattern
before:
- create_table
sql: |
select true or 'abc'::int > 1;
logical_plan: |-
LogicalProject { exprs: [(true:Boolean OR ('abc':Varchar::Int32 > 1:Int32)) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
batch_plan: 'BatchValues { rows: [[true:Boolean]] }'
- id: short_circuit_and_pattern
before:
- create_table
sql: |
select false and 'abc'::int > 1;
logical_plan: |-
LogicalProject { exprs: [(false:Boolean AND ('abc':Varchar::Int32 > 1:Int32)) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
batch_plan: 'BatchValues { rows: [[false:Boolean]] }'
- id: short_circuit_or_pattern_with_table
before:
- create_table
sql: |
select true or 'abc'::int > c1 from t1;
logical_plan: |-
LogicalProject { exprs: [(true:Boolean OR ('abc':Varchar::Int32 > t1.c1)) as $expr1] }
└─LogicalScan { table: t1, columns: [t1.c1, t1.c2, t1.c3, t1._row_id] }
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [true:Boolean] }
└─BatchScan { table: t1, columns: [t1.c1], distribution: SomeShard }
- id: short_circuit_and_pattern_with_table
before:
- create_table
sql: |
select false and 'abc'::int > c1 from t1;
logical_plan: |-
LogicalProject { exprs: [(false:Boolean AND ('abc':Varchar::Int32 > t1.c1)) as $expr1] }
└─LogicalScan { table: t1, columns: [t1.c1, t1.c2, t1.c3, t1._row_id] }
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [false:Boolean] }
└─BatchScan { table: t1, columns: [t1.c1], distribution: SomeShard }
30 changes: 28 additions & 2 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use fixedbitset::FixedBitSet;
use futures::FutureExt;
use paste::paste;
use risingwave_common::array::ListValue;
use risingwave_common::types::{DataType, Datum, JsonbVal, Scalar};
use risingwave_common::types::{DataType, Datum, JsonbVal, Scalar, ScalarImpl};
use risingwave_expr::aggregate::AggKind;
use risingwave_expr::expr::build_from_prost;
use risingwave_pb::expr::expr_node::RexNode;
Expand Down Expand Up @@ -614,6 +614,7 @@ impl ExprImpl {
struct HasOthers {
has_others: bool,
}

impl ExprVisitor for HasOthers {
fn visit_expr(&mut self, expr: &ExprImpl) {
match expr {
Expand All @@ -627,14 +628,39 @@ impl ExprImpl {
| ExprImpl::Parameter(_)
| ExprImpl::Now(_) => self.has_others = true,
ExprImpl::Literal(_inner) => {}
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
ExprImpl::FunctionCall(inner) => {
if !self.is_short_circuit(inner) {
// only if the current `func_call` is *not* a short-circuit
// expression, e.g., true or (...) | false and (...),
// shall we proceed to visit it.
self.visit_function_call(inner)
}
}
ExprImpl::FunctionCallWithLambda(inner) => {
self.visit_function_call_with_lambda(inner)
}
}
}
}

impl HasOthers {
fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
/// evaluate the first parameter of `Or` or `And` function call
fn eval_first(e: &ExprImpl, expect: bool) -> bool {
let Some(Ok(Some(scalar))) = e.try_fold_const() else {
return false;
};
scalar == ScalarImpl::Bool(expect)
}

match func_call.func_type {
ExprType::Or => eval_first(&func_call.inputs()[0], true),
ExprType::And => eval_first(&func_call.inputs()[0], false),
_ => false,
}
}
}

let mut visitor = HasOthers { has_others: false };
visitor.visit_expr(self);
!visitor.has_others
Expand Down

0 comments on commit 1ccb354

Please sign in to comment.