diff --git a/src/query/expression/src/types/number.rs b/src/query/expression/src/types/number.rs index 049c288fdfa8..d10c57fdfe15 100644 --- a/src/query/expression/src/types/number.rs +++ b/src/query/expression/src/types/number.rs @@ -531,6 +531,28 @@ impl NumberScalar { NumberScalar::NUM_TYPE(_) => NumberDataType::NUM_TYPE, }) } + + pub fn is_integer(&self) -> bool { + crate::with_integer_mapped_type!(|NUM_TYPE| match self { + NumberScalar::NUM_TYPE(_) => true, + _ => false, + }) + } + + pub fn integer_to_i128(&self) -> Option { + crate::with_integer_mapped_type!(|NUM_TYPE| match self { + NumberScalar::NUM_TYPE(x) => Some(*x as i128), + _ => None, + }) + } + + pub fn float_to_f64(&self) -> Option { + match self { + NumberScalar::Float32(value) => Some(value.into_inner() as f64), + NumberScalar::Float64(value) => Some(value.into_inner()), + _ => None, + } + } } impl From for NumberScalar diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs index 83389e8ef33f..49283f2e1ac1 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs @@ -385,7 +385,7 @@ fn rewrite_scalar_index( /// [`Range`] is to represent the value range of a column according to the predicates. /// /// Notes that only conjunctions will be parsed, and disjunctions will be ignored. -#[derive(Default, PartialEq, Debug)] +#[derive(Default, Debug)] struct Range<'a> { min: Option<&'a Scalar>, min_close: bool, @@ -393,6 +393,49 @@ struct Range<'a> { max_close: bool, } +impl<'a> PartialEq for Range<'a> { + fn eq(&self, other: &Self) -> bool { + // We cannot compare Scalar directly because when the NumberScalar types + // are different but the internal values are the same, the comparison + // result is false. + // So we need to compare the internal values of the Scalar, for example, + // `NumberScalar(UInt8(1)) == NumberScalar(UInt32(1))` should return true. + fn scalar_equal(left: Option<&Scalar>, right: Option<&Scalar>) -> bool { + match (left, right) { + (Some(left), Some(right)) => { + if let (Scalar::Number(left), Scalar::Number(right)) = (left, right) { + match (left.is_integer(), right.is_integer()) { + (true, true) => { + left.integer_to_i128().unwrap() == right.integer_to_i128().unwrap() + } + (false, false) => { + left.float_to_f64().unwrap() == right.float_to_f64().unwrap() + } + _ => false, + } + } else { + left == right + } + } + (None, None) => true, + _ => false, + } + } + + if self.min_close != other.min_close + || self.max_close != other.max_close + || !scalar_equal(self.min, other.min) + || !scalar_equal(self.max, other.max) + { + return false; + } + + true + } +} + +impl<'a> Eq for Range<'a> {} + impl<'a> Range<'a> { fn new(val: &'a Scalar, op: &str) -> Self { let mut range = Range::default(); @@ -462,9 +505,31 @@ impl<'a> Range<'a> { return false; } + // We need to compare the internal values of the Scalar, for example, + // `NumberScalar(UInt8(2)) > NumberScalar(UInt32(1))` should return true. match (self.min, other.min) { - (Some(m1), Some(m2)) => { - if m1 > m2 || (m1 == m2 && !self.min_close && other.min_close) { + (Some(left), Some(right)) => { + if let (Scalar::Number(left), Scalar::Number(right)) = (left, right) { + match (left.is_integer(), right.is_integer()) { + (true, true) => { + let left = left.integer_to_i128().unwrap(); + let right = right.integer_to_i128().unwrap(); + if left > right || (left == right && self.min_close && !other.min_close) + { + return false; + } + } + (false, false) => { + let left = left.float_to_f64().unwrap(); + let right = right.float_to_f64().unwrap(); + if left > right || (left == right && self.min_close && !other.min_close) + { + return false; + } + } + _ => return false, + } + } else if left > right || (left == right && self.min_close && !other.min_close) { return false; } } @@ -475,8 +540,28 @@ impl<'a> Range<'a> { } match (self.max, other.max) { - (Some(m1), Some(m2)) => { - if m1 < m2 || (m1 == m2 && !self.max_close && other.max_close) { + (Some(left), Some(right)) => { + if let (Scalar::Number(left), Scalar::Number(right)) = (left, right) { + match (left.is_integer(), right.is_integer()) { + (true, true) => { + let left = left.integer_to_i128().unwrap(); + let right = right.integer_to_i128().unwrap(); + if left < right || (left == right && self.min_close && !other.min_close) + { + return false; + } + } + (false, false) => { + let left = left.float_to_f64().unwrap(); + let right = right.float_to_f64().unwrap(); + if left < right || (left == right && self.min_close && !other.min_close) + { + return false; + } + } + _ => return false, + } + } else if left < right || (left == right && self.min_close && !other.min_close) { return false; } }