diff --git a/src/db.rs b/src/db.rs index 2a7a1670..c0f27aeb 100644 --- a/src/db.rs +++ b/src/db.rs @@ -74,10 +74,10 @@ impl Database { /// Limit(1) /// Project(a,b) let source_plan = binder.bind(&stmts[0])?; - // println!("source_plan plan: {:#?}", source_plan); + //println!("source_plan plan: {:#?}", source_plan); let best_plan = Self::default_optimizer(source_plan).find_best()?; - // println!("best_plan plan: {:#?}", best_plan); + //println!("best_plan plan: {:#?}", best_plan); Ok(build(best_plan, &transaction)) } @@ -92,7 +92,11 @@ impl Database { .batch( "Simplify Filter".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation], + vec![ + RuleImpl::LikeRewrite, + RuleImpl::SimplifyFilter, + RuleImpl::ConstantCalculation, + ], ) .batch( "Predicate Pushdown".to_string(), @@ -305,6 +309,12 @@ mod test { let _ = kipsql .run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)") .await?; + let _ = kipsql + .run("create table t4 (a int primary key, b varchar(100))") + .await?; + let _ = kipsql + .run("insert into t4 (a, b) values (1, 'abc'), (2, 'abdc'), (3, 'abcd'), (4, 'ddabc')") + .await?; println!("show tables:"); let tuples_show_tables = kipsql.run("show tables").await?; @@ -470,6 +480,10 @@ mod test { let tuples_decimal = kipsql.run("select * from t3").await?; println!("{}", create_table(&tuples_decimal)); + println!("like rewrite:"); + let tuples_like_rewrite = kipsql.run("select * from t4 where b like 'abc%'").await?; + println!("{}", create_table(&tuples_like_rewrite)); + Ok(()) } } diff --git a/src/optimizer/rule/mod.rs b/src/optimizer/rule/mod.rs index 1c9bbbed..a908a460 100644 --- a/src/optimizer/rule/mod.rs +++ b/src/optimizer/rule/mod.rs @@ -9,8 +9,8 @@ use crate::optimizer::rule::pushdown_limit::{ }; use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan; use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin; -use crate::optimizer::rule::simplification::ConstantCalculation; use crate::optimizer::rule::simplification::SimplifyFilter; +use crate::optimizer::rule::simplification::{ConstantCalculation, LikeRewrite}; use crate::optimizer::OptimizerError; mod column_pruning; @@ -37,6 +37,7 @@ pub enum RuleImpl { // Simplification SimplifyFilter, ConstantCalculation, + LikeRewrite, } impl Rule for RuleImpl { @@ -53,6 +54,7 @@ impl Rule for RuleImpl { RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(), RuleImpl::SimplifyFilter => SimplifyFilter.pattern(), RuleImpl::ConstantCalculation => ConstantCalculation.pattern(), + RuleImpl::LikeRewrite => LikeRewrite.pattern(), } } @@ -69,6 +71,7 @@ impl Rule for RuleImpl { RuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph), RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(node_id, graph), RuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph), + RuleImpl::LikeRewrite => LikeRewrite.apply(node_id, graph), } } } diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs index 68aadadd..c9961955 100644 --- a/src/optimizer/rule/simplification.rs +++ b/src/optimizer/rule/simplification.rs @@ -1,11 +1,21 @@ +use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::OptimizerError; +use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; +use crate::types::value::{DataValue, ValueRef}; +use crate::types::LogicalType; use lazy_static::lazy_static; lazy_static! { + static ref LIKE_REWRITE_RULE: Pattern = { + Pattern { + predicate: |op| matches!(op, Operator::Filter(_)), + children: PatternChildrenPredicate::None, + } + }; static ref CONSTANT_CALCULATION_RULE: Pattern = { Pattern { predicate: |_| true, @@ -109,6 +119,99 @@ impl Rule for SimplifyFilter { } } +pub struct LikeRewrite; + +impl Rule for LikeRewrite { + fn pattern(&self) -> &Pattern { + &LIKE_REWRITE_RULE + } + + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { + if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() { + if let ScalarExpression::Binary { + op: BinaryOperator::Like, + ref mut left_expr, + ref mut right_expr, + ty, + } = filter_op.predicate.clone() + { + if let ScalarExpression::Constant(value) = right_expr.as_ref() { + if let DataValue::Utf8(value_str) = (**value).clone() { + Self::process_value_str(value_str, left_expr, ty, &mut filter_op); + } + } + } + graph.replace_node(node_id, Operator::Filter(filter_op)) + } + Ok(()) + } +} + +impl LikeRewrite { + fn process_value_str( + value_str: Option, + left_expr: &mut Box, + ty: LogicalType, + filter_op: &mut FilterOperator, + ) { + value_str.map(|value_str| { + if value_str.ends_with('%') { + let left_bound = value_str.trim_end_matches('%'); + let right_bound = increment_last_char(left_bound); + right_bound.map(|rb| { + filter_op.predicate = Self::create_new_expr( + &mut left_expr.clone(), + ty, + left_bound.to_string(), + rb, + ); + }); + } + }); + } + + fn create_new_expr( + left_expr: &mut Box, + ty: LogicalType, + left_bound: String, + right_bound: String, + ) -> ScalarExpression { + let new_expr = ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::GtEq, + left_expr: left_expr.clone(), + right_expr: Box::new(ScalarExpression::Constant(ValueRef::from(DataValue::Utf8( + Some(left_bound), + )))), + ty, + }), + + right_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::Lt, + left_expr: left_expr.clone(), + right_expr: Box::new(ScalarExpression::Constant(ValueRef::from(DataValue::Utf8( + Some(right_bound), + )))), + ty, + }), + ty, + }; + new_expr + } +} + +fn increment_last_char(s: &str) -> Option { + let mut chars: Vec = s.chars().collect(); + for i in (0..chars.len()).rev() { + if let Some(next_char) = std::char::from_u32(chars[i] as u32 + 1) { + chars[i] = next_char; + return Some(chars.into_iter().collect()); + } + } + None +} + #[cfg(test)] mod test { use crate::binder::test::select_sql_run; @@ -118,6 +221,7 @@ mod test { use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; + use crate::optimizer::rule::simplification::increment_last_char; use crate::optimizer::rule::RuleImpl; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::Operator; @@ -127,6 +231,13 @@ mod test { use std::collections::Bound; use std::sync::Arc; + #[test] + fn test_increment_char() { + assert_eq!(increment_last_char("abc"), Some("abd".to_string())); + assert_eq!(increment_last_char("abz"), Some("ab{".to_string())); + assert_eq!(increment_last_char("ab}"), Some("ab~".to_string())); + } + #[tokio::test] async fn test_constant_calculation_omitted() -> Result<(), DatabaseError> { // (2 + (-1)) < -(c1 + 1) @@ -343,7 +454,7 @@ mod test { cb_1_c1, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))), }) ); @@ -353,7 +464,7 @@ mod test { cb_1_c2, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -363,7 +474,7 @@ mod test { cb_2_c1, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -373,7 +484,7 @@ mod test { cb_1_c1, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))), }) ); @@ -383,7 +494,7 @@ mod test { cb_3_c1, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))), }) ); @@ -393,7 +504,7 @@ mod test { cb_3_c2, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -403,7 +514,7 @@ mod test { cb_4_c1, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -413,7 +524,7 @@ mod test { cb_4_c2, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))), }) ); @@ -450,4 +561,85 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_like_rewrite() -> Result<(), DatabaseError> { + let plan = select_sql_run("select * from t1 where c1 like 'abc%%'").await?; + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_like_rewrite".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::LikeRewrite], + ) + .find_best()?; + assert_eq!(best_plan.childrens.len(), 1); + + match best_plan.operator { + Operator::Project(op) => { + assert_eq!(op.exprs.len(), 2); + } + _ => unreachable!(), + } + + match &best_plan.childrens[0].operator { + Operator::Filter(op) => { + assert_eq!( + op.predicate, + ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::GtEq, + left_expr: Box::new(ScalarExpression::ColumnRef(Arc::new( + ColumnCatalog { + summary: ColumnSummary { + id: Some(0), + name: "c1".to_string(), + }, + nullable: false, + desc: ColumnDesc { + column_datatype: LogicalType::Integer, + is_primary: true, + is_unique: false, + default: None, + }, + ref_expr: None, + } + ))), + right_expr: Box::new(ScalarExpression::Constant(Arc::new( + DataValue::Utf8(Some("abc".to_string())) + ))), + ty: LogicalType::Boolean, + }), + right_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::Lt, + left_expr: Box::new(ScalarExpression::ColumnRef(Arc::new( + ColumnCatalog { + summary: ColumnSummary { + id: Some(0), + name: "c1".to_string(), + }, + nullable: false, + desc: ColumnDesc { + column_datatype: LogicalType::Integer, + is_primary: true, + is_unique: false, + default: None, + }, + ref_expr: None, + } + ))), + right_expr: Box::new(ScalarExpression::Constant(Arc::new( + DataValue::Utf8(Some("abd".to_string())) + ))), + ty: LogicalType::Boolean, + }), + ty: LogicalType::Boolean, + } + ); + } + _ => unreachable!(), + } + + Ok(()) + } } diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt index 45cb7fd9..e2decadf 100644 --- a/tests/slt/filter.slt +++ b/tests/slt/filter.slt @@ -112,6 +112,28 @@ select * from t1 where id not in (1, 2) 0 KipSQL 3 Cool! +query II +select * from t1 where v1 like 'Kip%%' +---- +0 KipSQL +1 KipDB +2 KipBlog + +query II +select * from t1 where v1 like 'KC%%' +---- + + +query II +select * from t1 where v1 like 'Co%%' +---- +3 Cool! + +query II +select * from t1 where v1 like 'Cool!%' +---- +3 Cool! + statement ok drop table t