diff --git a/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml b/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml index 160f34103faf7..d58656d35f18d 100644 --- a/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml +++ b/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml @@ -385,7 +385,7 @@ │ ├─StreamProject { exprs: [t1.ts, AddWithTimeZone(t1.ts, '01:00:00':Interval, 'UTC':Varchar) as $expr1, t1._row_id] } │ │ └─StreamFilter { predicate: Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) } │ │ └─StreamShare { id: 2 } - │ │ └─StreamFilter { predicate: (Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) OR (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) } + │ │ └─StreamFilter { predicate: IsNotNull(t1.ts) } │ │ └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } │ └─StreamExchange { dist: Broadcast } │ └─StreamNow { output: [now] } @@ -393,7 +393,7 @@ └─StreamProject { exprs: [t1.ts, t1._row_id, 1:Int32] } └─StreamFilter { predicate: (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz) } └─StreamShare { id: 2 } - └─StreamFilter { predicate: (Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) OR (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) } + └─StreamFilter { predicate: IsNotNull(t1.ts) } └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } - name: Temporal filter with or is null sql: | @@ -407,17 +407,15 @@ │ └─StreamDynamicFilter { predicate: ($expr1 > now), output_watermarks: [$expr1], output: [t1.ts, $expr1, t1._row_id], cleaned_by_watermark: true } │ ├─StreamProject { exprs: [t1.ts, AddWithTimeZone(t1.ts, '01:00:00':Interval, 'UTC':Varchar) as $expr1, t1._row_id] } │ │ └─StreamFilter { predicate: Not(IsNull(t1.ts)) } - │ │ └─StreamShare { id: 2 } - │ │ └─StreamFilter { predicate: 
(Not(IsNull(t1.ts)) OR IsNull(t1.ts)) } - │ │ └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } + │ │ └─StreamShare { id: 1 } + │ │ └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } │ └─StreamExchange { dist: Broadcast } │ └─StreamNow { output: [now] } └─StreamExchange { dist: HashShard(t1._row_id, 1:Int32) } └─StreamProject { exprs: [t1.ts, t1._row_id, 1:Int32] } └─StreamFilter { predicate: IsNull(t1.ts) } - └─StreamShare { id: 2 } - └─StreamFilter { predicate: (Not(IsNull(t1.ts)) OR IsNull(t1.ts)) } - └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamShare { id: 1 } + └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], stream_scan_type: Backfill, pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } - name: Temporal filter with or predicate sql: | create table t1 (ts timestamp with time zone); @@ -459,7 +457,7 @@ │ │ │ └─StreamDynamicFilter { predicate: (t.t > $expr1), output_watermarks: [t.t], output: [t.t, t.a, t._row_id], cleaned_by_watermark: true } │ │ │ ├─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND Not(IsNull(t.t)) AND Not((t.a < 1:Int32)) } │ │ │ │ └─StreamShare { id: 2 } - │ │ │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + │ │ │ │ └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } │ │ │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } │ │ │ └─StreamExchange { dist: Broadcast } │ │ │ └─StreamProject { exprs: [SubtractWithTimeZone(now, 
'01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } @@ -468,7 +466,7 @@ │ │ └─StreamProject { exprs: [t.t, t.a, t._row_id, 1:Int32] } │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (IsNull(t.t) OR (t.a < 1:Int32)) } │ │ └─StreamShare { id: 2 } - │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + │ │ └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } │ └─StreamExchange { dist: Broadcast } │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } @@ -483,7 +481,7 @@ │ └─StreamDynamicFilter { predicate: (t.t > $expr1), output_watermarks: [t.t], output: [t.t, t.a, t._row_id], cleaned_by_watermark: true } │ ├─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND Not(IsNull(t.t)) AND Not((t.a < 1:Int32)) } │ │ └─StreamShare { id: 2 } - │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + │ │ └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } │ └─StreamExchange { dist: Broadcast } │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } @@ -492,5 +490,5 @@ └─StreamProject { exprs: [t.t, t.a, t._row_id, 1:Int32] } └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND 
(IsNull(t.t) OR (t.a < 1:Int32)) } └─StreamShare { id: 2 } - └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: Backfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } diff --git a/src/frontend/src/optimizer/logical_optimization.rs b/src/frontend/src/optimizer/logical_optimization.rs index 3343702936723..c181d828c3567 100644 --- a/src/frontend/src/optimizer/logical_optimization.rs +++ b/src/frontend/src/optimizer/logical_optimization.rs @@ -416,6 +416,14 @@ static COMMON_SUB_EXPR_EXTRACT: LazyLock = LazyLock::new(|| { ) }); +static LOGICAL_FILTER_EXPRESSION_SIMPLIFY: LazyLock = LazyLock::new(|| { + OptimizationStage::new( + "Logical Filter Expression Simplify", + vec![LogicalFilterExpressionSimplifyRule::create()], + ApplyOrder::TopDown, + ) +}); + impl LogicalOptimizer { pub fn predicate_pushdown( plan: PlanRef, @@ -624,6 +632,8 @@ impl LogicalOptimizer { plan = plan.optimize_by_rules(&COMMON_SUB_EXPR_EXTRACT); + plan = plan.optimize_by_rules(&LOGICAL_FILTER_EXPRESSION_SIMPLIFY); + #[cfg(debug_assertions)] InputRefValidator.validate(plan.clone()); @@ -720,6 +730,8 @@ impl LogicalOptimizer { plan = plan.optimize_by_rules(&DAG_TO_TREE); + plan = plan.optimize_by_rules(&LOGICAL_FILTER_EXPRESSION_SIMPLIFY); + #[cfg(debug_assertions)] InputRefValidator.validate(plan.clone()); diff --git a/src/frontend/src/optimizer/plan_expr_visitor/mod.rs b/src/frontend/src/optimizer/plan_expr_visitor/mod.rs index 1dc4c6980188d..08b842b98875f 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/mod.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/mod.rs @@ -14,7 +14,7 @@ mod expr_counter; mod input_ref_counter; -mod strong; +pub mod strong; pub(crate) 
use expr_counter::CseExprCounter; pub(crate) use input_ref_counter::InputRefCounter; diff --git a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs index a80a696c2c16d..79905a076e426 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs @@ -44,8 +44,9 @@ impl Strong { Self { null_columns } } - /// Returns whether the analyzed expression will definitely return null if + /// Returns whether the analyzed expression will *definitely* return null if /// all of a given set of input columns are null. + /// Note: we could not assume any null-related property for the input expression if `is_null` returns false pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool { let strong = Strong::new(null_columns); strong.is_null_visit(expr) diff --git a/src/frontend/src/optimizer/plan_node/expr_rewritable.rs b/src/frontend/src/optimizer/plan_node/expr_rewritable.rs index 693e14a4c33b4..33deebde1cdf9 100644 --- a/src/frontend/src/optimizer/plan_node/expr_rewritable.rs +++ b/src/frontend/src/optimizer/plan_node/expr_rewritable.rs @@ -18,7 +18,7 @@ use super::*; use crate::expr::ExprRewriter; /// Rewrites expressions in a `PlanRef`. Due to `Share` operator, -/// the `ExprRewriter` needs to be idempotent i.e. applying it more than once +/// the `ExprRewriter` needs to be idempotent i.e., applying it more than once /// to the same `ExprImpl` will be a noop on subsequent applications. /// `rewrite_exprs` should only return a plan with the given node modified. /// To rewrite recursively, call `rewrite_exprs_recursive` on [`RewriteExprsRecursive`]. 
diff --git a/src/frontend/src/optimizer/rule/logical_filter_expression_simplify_rule.rs b/src/frontend/src/optimizer/rule/logical_filter_expression_simplify_rule.rs
new file mode 100644
index 0000000000000..f2baf788e56fe
--- /dev/null
+++ b/src/frontend/src/optimizer/rule/logical_filter_expression_simplify_rule.rs
@@ -0,0 +1,226 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use fixedbitset::FixedBitSet;
+use risingwave_common::types::ScalarImpl;
+use risingwave_connector::source::DataType;
+
+use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall};
+use crate::optimizer::plan_expr_visitor::strong::Strong;
+use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter, LogicalShare, PlanTreeNodeUnary};
+use crate::optimizer::rule::{BoxedRule, Rule};
+use crate::optimizer::PlanRef;
+
+/// Simplifies the expressions of the node directly below a `LogicalShare`.
+///
+/// Specially for the predicate under `LogicalFilter -> LogicalShare -> LogicalFilter`
+pub struct LogicalFilterExpressionSimplifyRule {}
+
+impl Rule for LogicalFilterExpressionSimplifyRule {
+    /// The pattern we aim to optimize, e.g.,
+    /// 1. (NOT (e)) OR (e) => True (or `IsNotNull(col)` when `e` references a column)
+    /// 2. (NOT (e)) AND (e) => False
+    ///
+    /// NOTE: `e` should only contain at most a single column
+    /// otherwise we will not conduct the optimization
+    ///
+    /// Returns `None` unless `plan` is a `LogicalFilter` on top of a `LogicalShare`.
+    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
+        let filter: &LogicalFilter = plan.as_logical_filter()?;
+        let share_plan = filter.input();
+        let share: &LogicalShare = share_plan.as_logical_share()?;
+
+        // Only the expressions of the node below the share are rewritten;
+        // the outer filter keeps its predicate untouched.
+        let mut rewriter = ExpressionSimplifyRewriter {};
+        let simplified_input = share.input().rewrite_exprs(&mut rewriter);
+        share.replace_input(simplified_input);
+
+        Some(LogicalFilter::create(
+            share.clone().into(),
+            filter.predicate().clone(),
+        ))
+    }
+}
+
+impl LogicalFilterExpressionSimplifyRule {
+    /// Creates a boxed instance of this rule.
+    pub fn create() -> BoxedRule {
+        Box::new(LogicalFilterExpressionSimplifyRule {})
+    }
+}
+
+/// `True` => the return value of the input `func_type` will *definitely* not be null
+/// `False` => no guarantee, the return value may or may not be null
+fn is_not_null(func_type: ExprType) -> bool {
+    matches!(
+        func_type,
+        ExprType::IsNull
+            | ExprType::IsNotNull
+            | ExprType::IsTrue
+            | ExprType::IsFalse
+            | ExprType::IsNotTrue
+            | ExprType::IsNotFalse
+    )
+}
+
+/// Collects every distinct `InputRef` reachable from the input `expr` into `columns`,
+/// skipping the arguments of functions that *never* return null
+fn extract_column(expr: ExprImpl, columns: &mut Vec<ExprImpl>) {
+    if let ExprImpl::FunctionCall(func_call) = &expr {
+        // the functions that *never* return null will be ignored
+        if is_not_null(func_call.func_type()) {
+            return;
+        }
+        for sub_expr in func_call.inputs() {
+            extract_column(sub_expr.clone(), columns);
+        }
+    } else if matches!(expr, ExprImpl::InputRef(_)) && !columns.contains(&expr) {
+        // only add the column if not exists
+        columns.push(expr);
+    }
+}
+
+/// Checks if ever `Not (e)` and `(e)` appear together, in either order
+/// First return value indicates if the optimizable pattern exist
+/// Second return value is only meaningful when the first one is `true`:
+/// it is `Some(IsNotNull(col))` when `e` references a column `col`,
+/// and `None` when `e` references no column at all
+fn check_optimizable_pattern(e1: ExprImpl, e2: ExprImpl) -> (bool, Option<ExprImpl>) {
+    /// Try wrapping inner *column* with `IsNotNull`
+    fn try_wrap_inner_expression(expr: ExprImpl) -> Option<ExprImpl> {
+        let mut columns = vec![];
+        extract_column(expr, &mut columns);
+
+        assert!(columns.len() <= 1, "should only contain a single column");
+
+        // From `c1` to `IsNotNull(c1)`; `None` when there is no column
+        let column = columns.pop()?;
+        FunctionCall::new(ExprType::IsNotNull, vec![column])
+            .ok()
+            .map(Into::into)
+    }
+
+    // Due to constant folding, we only need to consider `FunctionCall` here (presumably)
+    let (ExprImpl::FunctionCall(e1_func), ExprImpl::FunctionCall(e2_func)) = (&e1, &e2) else {
+        return (false, None);
+    };
+
+    // Split the operands into the `Not` call and the plain term
+    let (not_func, term) = if e1_func.func_type() == ExprType::Not {
+        // (Not (e1)) [op] (e2)
+        (e1_func, &e2)
+    } else if e2_func.func_type() == ExprType::Not {
+        // (e1) [op] (Not (e2))
+        (e2_func, &e1)
+    } else {
+        // No chance to optimize
+        return (false, None);
+    };
+
+    // `not` should only have a single operand, which must equal the plain term
+    let [negated] = not_func.inputs() else {
+        return (false, None);
+    };
+    if negated != term {
+        return (false, None);
+    }
+
+    // Only build the `IsNotNull` wrapper once the pattern is confirmed
+    (true, try_wrap_inner_expression(term.clone()))
+}
+
+/// 1. True or (...) | (...) or True => True
+/// 2. False and (...) | (...) and False => False
+///
+/// NOTE: the `True` and `False` here not only represent a single `ExprImpl::Literal`
+/// but represent every `ExprImpl` that can be *evaluated* to `ScalarImpl::Bool`
+/// during optimization phase as well, hence both operands are checked with `is_const`
+fn check_special_pattern(e1: ExprImpl, e2: ExprImpl, op: ExprType) -> Option<bool> {
+    /// Folds the constant `e` and returns the value that short-circuits `op`, if any
+    fn check_special_pattern_inner(e: ExprImpl, op: ExprType) -> Option<bool> {
+        let Some(Ok(Some(scalar))) = e.try_fold_const() else {
+            return None;
+        };
+        match op {
+            ExprType::Or if scalar == ScalarImpl::Bool(true) => Some(true),
+            ExprType::And if scalar == ScalarImpl::Bool(false) => Some(false),
+            _ => None,
+        }
+    }
+
+    [e1, e2]
+        .into_iter()
+        .filter(|e| e.is_const())
+        .find_map(|e| check_special_pattern_inner(e, op))
+}
+
+/// Rewriter that performs the actual simplification on a single expression
+struct ExpressionSimplifyRewriter {}
+
+impl ExprRewriter for ExpressionSimplifyRewriter {
+    fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
+        let mut columns = vec![];
+        extract_column(expr.clone(), &mut columns);
+
+        // NOTE: we do NOT optimize cases that involve multiple columns
+        if columns.len() > 1 {
+            return expr;
+        }
+
+        // Eliminate the case where the current expression
+        // will definitely return null by using `Strong::is_null`
+        if let Some(column) = columns.first() {
+            let ExprImpl::InputRef(input_ref) = column else {
+                return expr;
+            };
+            // A fresh bitset has no bit set, i.e., no column is assumed to be null
+            let null_columns = FixedBitSet::with_capacity(input_ref.index());
+            if Strong::is_null(&expr, null_columns) {
+                return ExprImpl::literal_bool(false);
+            }
+        }
+
+        let ExprImpl::FunctionCall(func_call) = &expr else {
+            return expr;
+        };
+        let op = func_call.func_type();
+        if !matches!(op, ExprType::Or | ExprType::And) {
+            return expr;
+        }
+        assert_eq!(func_call.return_type(), DataType::Boolean);
+
+        // Sanity check, the inputs should only contain two branches
+        let [e1, e2] = func_call.inputs() else {
+            return expr;
+        };
+        let (e1, e2) = (e1.clone(), e2.clone());
+
+        // Eliminate special pattern
+        if let Some(res) = check_special_pattern(e1.clone(), e2.clone(), op) {
+            return ExprImpl::literal_bool(res);
+        }
+
+        match (check_optimizable_pattern(e1, e2), op) {
+            // IsNotNull(col)
+            ((true, Some(not_null_column)), ExprType::Or) => not_null_column,
+            ((true, None), ExprType::Or) => ExprImpl::literal_bool(true),
+            // `AND` will always be false, no matter the underlying columns are null or not
+            // i.e., for `(Not (e)) AND (e)`, since this is filter simplification,
+            // whether `e` is null or not does NOT matter
+            ((true, _), ExprType::And) => ExprImpl::literal_bool(false),
+            _ => expr,
+        }
+    }
+}
diff --git a/src/frontend/src/optimizer/rule/mod.rs b/src/frontend/src/optimizer/rule/mod.rs
index c77d4f24f1555..dff3f986ce22a 100644
--- a/src/frontend/src/optimizer/rule/mod.rs
+++ b/src/frontend/src/optimizer/rule/mod.rs
@@ -29,6 +29,8 @@ pub trait Description {
 
 pub(super) type BoxedRule = Box<dyn Rule>;
 
+mod logical_filter_expression_simplify_rule;
+pub use logical_filter_expression_simplify_rule::*;
 mod over_window_merge_rule;
 pub use over_window_merge_rule::*;
 mod project_join_merge_rule;
@@ -204,6 +206,7 @@ macro_rules! for_all_rules {
             , { AlwaysFalseFilterRule }
             , { BushyTreeJoinOrderingRule }
             , { StreamProjectMergeRule }
+            , { LogicalFilterExpressionSimplifyRule }
             , { JoinProjectTransposeRule }
             , { LimitPushDownRule }
             , { PullUpHopRule }