Skip to content

Commit

Permalink
feat(optimizer): support logical filter expression simplify rule (#15275
Browse files Browse the repository at this point in the history
)
  • Loading branch information
xzhseh authored Mar 19, 2024
1 parent 1ccb354 commit e4f5eb4
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,15 +385,15 @@
│ ├─StreamProject { exprs: [t1.ts, AddWithTimeZone(t1.ts, '01:00:00':Interval, 'UTC':Varchar) as $expr1, t1._row_id] }
│ │ └─StreamFilter { predicate: Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) }
│ │ └─StreamShare { id: 2 }
│ │ └─StreamFilter { predicate: (Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) OR (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) }
│ │ └─StreamFilter { predicate: IsNotNull(t1.ts) }
│ │ └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
│ └─StreamExchange { dist: Broadcast }
│ └─StreamNow { output: [now] }
└─StreamExchange { dist: HashShard(t1._row_id, 1:Int32) }
└─StreamProject { exprs: [t1.ts, t1._row_id, 1:Int32] }
└─StreamFilter { predicate: (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz) }
└─StreamShare { id: 2 }
└─StreamFilter { predicate: (Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) OR (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) }
└─StreamFilter { predicate: IsNotNull(t1.ts) }
└─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
- name: Temporal filter with or is null
sql: |
Expand All @@ -407,17 +407,15 @@
│ └─StreamDynamicFilter { predicate: ($expr1 > now), output_watermarks: [$expr1], output: [t1.ts, $expr1, t1._row_id], cleaned_by_watermark: true }
│ ├─StreamProject { exprs: [t1.ts, AddWithTimeZone(t1.ts, '01:00:00':Interval, 'UTC':Varchar) as $expr1, t1._row_id] }
│ │ └─StreamFilter { predicate: Not(IsNull(t1.ts)) }
│ │ └─StreamShare { id: 2 }
│ │ └─StreamFilter { predicate: (Not(IsNull(t1.ts)) OR IsNull(t1.ts)) }
│ │ └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
│ │ └─StreamShare { id: 1 }
│ │ └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
│ └─StreamExchange { dist: Broadcast }
│ └─StreamNow { output: [now] }
└─StreamExchange { dist: HashShard(t1._row_id, 1:Int32) }
└─StreamProject { exprs: [t1.ts, t1._row_id, 1:Int32] }
└─StreamFilter { predicate: IsNull(t1.ts) }
└─StreamShare { id: 2 }
└─StreamFilter { predicate: (Not(IsNull(t1.ts)) OR IsNull(t1.ts)) }
└─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
└─StreamShare { id: 1 }
└─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
- name: Temporal filter with or predicate
sql: |
create table t1 (ts timestamp with time zone);
Expand Down Expand Up @@ -459,7 +457,7 @@
│ │ │ └─StreamDynamicFilter { predicate: (t.t > $expr1), output_watermarks: [t.t], output: [t.t, t.a, t._row_id], cleaned_by_watermark: true }
│ │ │ ├─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND Not(IsNull(t.t)) AND Not((t.a < 1:Int32)) }
│ │ │ │ └─StreamShare { id: 2 }
│ │ │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
│ │ │ │ └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
│ │ │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ │ │ └─StreamExchange { dist: Broadcast }
│ │ │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] }
Expand All @@ -468,7 +466,7 @@
│ │ └─StreamProject { exprs: [t.t, t.a, t._row_id, 1:Int32] }
│ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (IsNull(t.t) OR (t.a < 1:Int32)) }
│ │ └─StreamShare { id: 2 }
│ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
│ │ └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
│ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamExchange { dist: Broadcast }
│ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] }
Expand All @@ -483,7 +481,7 @@
│ └─StreamDynamicFilter { predicate: (t.t > $expr1), output_watermarks: [t.t], output: [t.t, t.a, t._row_id], cleaned_by_watermark: true }
│ ├─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND Not(IsNull(t.t)) AND Not((t.a < 1:Int32)) }
│ │ └─StreamShare { id: 2 }
│ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
│ │ └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
│ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamExchange { dist: Broadcast }
│ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] }
Expand All @@ -492,5 +490,5 @@
└─StreamProject { exprs: [t.t, t.a, t._row_id, 1:Int32] }
└─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (IsNull(t.t) OR (t.a < 1:Int32)) }
└─StreamShare { id: 2 }
└─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
└─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) }
└─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
12 changes: 12 additions & 0 deletions src/frontend/src/optimizer/logical_optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,14 @@ static COMMON_SUB_EXPR_EXTRACT: LazyLock<OptimizationStage> = LazyLock::new(|| {
)
});

static LOGICAL_FILTER_EXPRESSION_SIMPLIFY: LazyLock<OptimizationStage> = LazyLock::new(|| {
OptimizationStage::new(
"Logical Filter Expression Simplify",
vec![LogicalFilterExpressionSimplifyRule::create()],
ApplyOrder::TopDown,
)
});

impl LogicalOptimizer {
pub fn predicate_pushdown(
plan: PlanRef,
Expand Down Expand Up @@ -624,6 +632,8 @@ impl LogicalOptimizer {

plan = plan.optimize_by_rules(&COMMON_SUB_EXPR_EXTRACT);

plan = plan.optimize_by_rules(&LOGICAL_FILTER_EXPRESSION_SIMPLIFY);

#[cfg(debug_assertions)]
InputRefValidator.validate(plan.clone());

Expand Down Expand Up @@ -720,6 +730,8 @@ impl LogicalOptimizer {

plan = plan.optimize_by_rules(&DAG_TO_TREE);

plan = plan.optimize_by_rules(&LOGICAL_FILTER_EXPRESSION_SIMPLIFY);

#[cfg(debug_assertions)]
InputRefValidator.validate(plan.clone());

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_expr_visitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

mod expr_counter;
mod input_ref_counter;
mod strong;
pub mod strong;

pub(crate) use expr_counter::CseExprCounter;
pub(crate) use input_ref_counter::InputRefCounter;
Expand Down
3 changes: 2 additions & 1 deletion src/frontend/src/optimizer/plan_expr_visitor/strong.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ impl Strong {
Self { null_columns }
}

/// Returns whether the analyzed expression will definitely return null if
/// Returns whether the analyzed expression will *definitely* return null if
/// all of a given set of input columns are null.
/// Note: we could not assume any null-related property for the input expression if `is_null` returns false
pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
let strong = Strong::new(null_columns);
strong.is_null_visit(expr)
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/expr_rewritable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use super::*;
use crate::expr::ExprRewriter;

/// Rewrites expressions in a `PlanRef`. Due to `Share` operator,
/// the `ExprRewriter` needs to be idempotent i.e. applying it more than once
/// the `ExprRewriter` needs to be idempotent i.e., applying it more than once
/// to the same `ExprImpl` will be a noop on subsequent applications.
/// `rewrite_exprs` should only return a plan with the given node modified.
/// To rewrite recursively, call `rewrite_exprs_recursive` on [`RewriteExprsRecursive`].
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use fixedbitset::FixedBitSet;
use risingwave_common::types::ScalarImpl;
use risingwave_connector::source::DataType;

use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall};
use crate::optimizer::plan_expr_visitor::strong::Strong;
use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter, LogicalShare, PlanTreeNodeUnary};
use crate::optimizer::rule::{BoxedRule, Rule};
use crate::optimizer::PlanRef;

/// Specially for the predicate under `LogicalFilter -> LogicalShare -> LogicalFilter`
pub struct LogicalFilterExpressionSimplifyRule {}
impl Rule for LogicalFilterExpressionSimplifyRule {
/// The pattern we aim to optimize, e.g.,
/// 1. (NOT (e)) OR (e) => True
/// 2. (NOT (e)) AND (e) => False
/// NOTE: `e` should only contain at most a single column
/// otherwise we will not conduct the optimization
fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
let filter: &LogicalFilter = plan.as_logical_filter()?;
let mut rewriter = ExpressionSimplifyRewriter {};
let logical_share_plan = filter.input();
let share: &LogicalShare = logical_share_plan.as_logical_share()?;
let input = share.input().rewrite_exprs(&mut rewriter);
share.replace_input(input);
Some(LogicalFilter::create(
share.clone().into(),
filter.predicate().clone(),
))
}
}

impl LogicalFilterExpressionSimplifyRule {
pub fn create() -> BoxedRule {
Box::new(LogicalFilterExpressionSimplifyRule {})
}
}

/// `True` => the return value of the input `func_type` will *definitely* not be null
/// `False` => vice versa
fn is_not_null(func_type: ExprType) -> bool {
func_type == ExprType::IsNull
|| func_type == ExprType::IsNotNull
|| func_type == ExprType::IsTrue
|| func_type == ExprType::IsFalse
|| func_type == ExprType::IsNotTrue
|| func_type == ExprType::IsNotFalse
}

/// Simply extract every possible `InputRef` out from the input `expr`
fn extract_column(expr: ExprImpl, columns: &mut Vec<ExprImpl>) {
match expr.clone() {
ExprImpl::FunctionCall(func_call) => {
// the functions that *never* return null will be ignored
if is_not_null(func_call.func_type()) {
return;
}
for sub_expr in func_call.inputs() {
extract_column(sub_expr.clone(), columns);
}
}
ExprImpl::InputRef(_) => {
if !columns.contains(&expr) {
// only add the column if not exists
columns.push(expr);
}
}
_ => (),
}
}

/// If ever `Not (e)` and `(e)` appear together
/// First return value indicates if the optimizable pattern exist
/// Second return value indicates if the term `e` should be converted to either `IsNotNull` or `IsNull`
/// If so, it will contain the actual wrapper `ExprImpl` for that; otherwise it will be `None`
fn check_optimizable_pattern(e1: ExprImpl, e2: ExprImpl) -> (bool, Option<ExprImpl>) {
/// Try wrapping inner *column* with `IsNotNull`
fn try_wrap_inner_expression(expr: ExprImpl) -> Option<ExprImpl> {
let mut columns = vec![];

extract_column(expr, &mut columns);

assert!(columns.len() <= 1, "should only contain a single column");

if columns.is_empty() {
return None;
}

// From `c1` to `IsNotNull(c1)`
let Ok(expr) = FunctionCall::new(ExprType::IsNotNull, vec![columns[0].clone()]) else {
return None;
};

Some(expr.into())
}

// Due to constant folding, we only need to consider `FunctionCall` here (presumably)
let ExprImpl::FunctionCall(e1_func) = e1.clone() else {
return (false, None);
};
let ExprImpl::FunctionCall(e2_func) = e2.clone() else {
return (false, None);
};

// No chance to optimize
if e1_func.func_type() != ExprType::Not && e2_func.func_type() != ExprType::Not {
return (false, None);
}

if e1_func.func_type() != ExprType::Not {
// (e1) [op] (Not (e2))
if e2_func.inputs().len() != 1 {
// `not` should only have a single operand, which is `e2` in this case
return (false, None);
}
(
e1 == e2_func.inputs()[0].clone(),
try_wrap_inner_expression(e1),
)
} else {
// (Not (e1)) [op] (e2)
if e1_func.inputs().len() != 1 {
return (false, None);
}
(
e2 == e1_func.inputs()[0].clone(),
try_wrap_inner_expression(e2),
)
}
}

/// 1. True or (...) | (...) or True => True
/// 2. False and (...) | (...) and False => False
/// NOTE: the `True` and `False` here not only represent a single `ExprImpl::Literal`
/// but represent every `ExprImpl` that can be *evaluated* to `ScalarImpl::Bool`
/// during optimization phase as well
fn check_special_pattern(e1: ExprImpl, e2: ExprImpl, op: ExprType) -> Option<bool> {
fn check_special_pattern_inner(e: ExprImpl, op: ExprType) -> Option<bool> {
let Some(Ok(Some(scalar))) = e.try_fold_const() else {
return None;
};
match op {
ExprType::Or => {
if scalar == ScalarImpl::Bool(true) {
Some(true)
} else {
None
}
}
ExprType::And => {
if scalar == ScalarImpl::Bool(false) {
Some(false)
} else {
None
}
}
_ => None,
}
}

if e1.is_const() {
if let Some(res) = check_special_pattern_inner(e1, op) {
return Some(res);
}
}

if e2.is_literal() {
if let Some(res) = check_special_pattern_inner(e2, op) {
return Some(res);
}
}

None
}

struct ExpressionSimplifyRewriter {}
impl ExprRewriter for ExpressionSimplifyRewriter {
fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
// Check if the input expression is *definitely* null
let mut columns = vec![];
extract_column(expr.clone(), &mut columns);

// NOTE: we do NOT optimize cases that involve multiple columns
// for detailed reference: <https://github.com/risingwavelabs/risingwave/pull/15275#issuecomment-1975783856>
if columns.len() > 1 {
return expr;
}

// Eliminate the case where the current expression
// will definitely return null by using `Strong::is_null`
if !columns.is_empty() {
let ExprImpl::InputRef(input_ref) = columns[0].clone() else {
return expr;
};
let index = input_ref.index();
let fixedbitset = FixedBitSet::with_capacity(index);
if Strong::is_null(&expr, fixedbitset) {
return ExprImpl::literal_bool(false);
}
}

let ExprImpl::FunctionCall(func_call) = expr.clone() else {
return expr;
};
if func_call.func_type() != ExprType::Or && func_call.func_type() != ExprType::And {
return expr;
}
assert_eq!(func_call.return_type(), DataType::Boolean);
// Sanity check, the inputs should only contain two branches
if func_call.inputs().len() != 2 {
return expr;
}

let inputs = func_call.inputs();
let e1 = inputs[0].clone();
let e2 = inputs[1].clone();

// Eliminate special pattern
if let Some(res) = check_special_pattern(e1.clone(), e2.clone(), func_call.func_type()) {
return ExprImpl::literal_bool(res);
}

let (optimizable_flag, column) = check_optimizable_pattern(e1, e2);
if optimizable_flag {
match func_call.func_type() {
ExprType::Or => {
if let Some(column) = column {
// IsNotNull(col)
column
} else {
ExprImpl::literal_bool(true)
}
}
// `AND` will always be false, no matter the underlying columns are null or not
// i.e., for `(Not (e)) AND (e)`, since this is filter simplification,
// whether `e` is null or not does NOT matter
ExprType::And => ExprImpl::literal_bool(false),
_ => expr,
}
} else {
expr
}
}
}
Loading

0 comments on commit e4f5eb4

Please sign in to comment.