diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index d27cc86488d0..91233032fe16 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -26,7 +26,6 @@ use ethnum::i256; use ethnum::AsI256; use itertools::Itertools; use num_traits::NumCast; -use num_traits::ToPrimitive; use serde::Deserialize; use serde::Serialize; @@ -305,8 +304,8 @@ pub trait Decimal: fn default_decimal_size() -> DecimalSize; fn from_float(value: f64) -> Self; - fn from_u64(value: u64) -> Self; - fn from_i64(value: i64) -> Self; + fn from_i128>(value: U) -> Self; + fn de_binary(bytes: &mut &[u8]) -> Self; fn to_float32(self, scale: u8) -> f32; @@ -442,12 +441,8 @@ impl Decimal for i128 { } } - fn from_u64(value: u64) -> Self { - value.to_i128().unwrap() - } - - fn from_i64(value: i64) -> Self { - value.to_i128().unwrap() + fn from_i128>(value: U) -> Self { + value.into() } fn de_binary(bytes: &mut &[u8]) -> Self { @@ -611,12 +606,8 @@ impl Decimal for i256 { value.as_i256() } - fn from_u64(value: u64) -> Self { - i256::from(value.to_i128().unwrap()) - } - - fn from_i64(value: i64) -> Self { - i256::from(value.to_i128().unwrap()) + fn from_i128>(value: U) -> Self { + i256::from(value.into()) } fn de_binary(bytes: &mut &[u8]) -> Self { diff --git a/src/query/expression/src/utils/serialize.rs b/src/query/expression/src/utils/serialize.rs index b95a70d53662..3b8831f64a15 100644 --- a/src/query/expression/src/utils/serialize.rs +++ b/src/query/expression/src/utils/serialize.rs @@ -50,9 +50,9 @@ pub fn read_decimal_with_size( // Checking whether numbers need to be added or subtracted to calculate rounding if let Some(r) = n.checked_rem(T::e(scale_diff)) { if let Some(m) = r.checked_div(T::e(scale_diff - 1)) { - if m >= T::from_i64(5i64) { + if m >= T::from_i128(5i64) { round_val = Some(T::one()); - } else if m <= T::from_i64(-5i64) { + } else if m <= T::from_i128(-5i64) { round_val = 
Some(T::minus_one()); } } @@ -140,7 +140,7 @@ pub fn read_decimal( .checked_mul(T::e(zeros + 1)) .ok_or_else(decimal_overflow_error)?; n = n - .checked_add(T::from_u64((v - b'0') as u64)) + .checked_add(T::from_i128((v - b'0') as u64)) .ok_or_else(decimal_overflow_error)?; zeros = 0; } @@ -200,7 +200,7 @@ pub fn read_decimal( .checked_mul(T::e(zeros + 1)) .ok_or_else(decimal_overflow_error)?; n = n - .checked_add(T::from_u64((v - b'0') as u64)) + .checked_add(T::from_i128((v - b'0') as u64)) .ok_or_else(decimal_overflow_error)?; digits += zeros + 1; zeros = 0; @@ -288,11 +288,11 @@ pub fn read_decimal_from_json( match value { serde_json::Value::Number(n) => { if n.is_i64() { - Ok(T::from_i64(n.as_i64().unwrap()) + Ok(T::from_i128(n.as_i64().unwrap()) .with_size(size) .ok_or_else(decimal_overflow_error)?) } else if n.is_u64() { - Ok(T::from_u64(n.as_u64().unwrap()) + Ok(T::from_i128(n.as_u64().unwrap()) .with_size(size) .ok_or_else(decimal_overflow_error)?) } else { diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs index 200460ad075f..2a43d27c6767 100755 --- a/src/query/expression/src/values.rs +++ b/src/query/expression/src/values.rs @@ -59,6 +59,7 @@ use crate::types::decimal::DecimalColumnBuilder; use crate::types::decimal::DecimalDataType; use crate::types::decimal::DecimalScalar; use crate::types::decimal::DecimalSize; +use crate::types::decimal::DecimalType; use crate::types::nullable::NullableColumn; use crate::types::nullable::NullableColumnBuilder; use crate::types::nullable::NullableColumnVec; @@ -284,6 +285,15 @@ impl Value { } } +impl Value> { + pub fn upcast_decimal(self, size: DecimalSize) -> Value { + match self { + Value::Scalar(scalar) => Value::Scalar(T::upcast_scalar(scalar, size)), + Value::Column(col) => Value::Column(T::upcast_column(col, size)), + } + } +} + impl Value { pub fn convert_to_full_column(&self, ty: &DataType, num_rows: usize) -> Column { match self { diff --git 
a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 23efb0eea025..acffce28f419 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -374,7 +374,7 @@ where T: Decimal } let avg_val = match sum .checked_mul(T::e(scale_add as u32)) - .and_then(|v| v.checked_div(T::from_u64(window_size as u64))) + .and_then(|v| v.checked_div(T::from_i128(window_size as u64))) { Some(value) => value, None => { diff --git a/src/query/functions/src/aggregates/aggregate_avg.rs b/src/query/functions/src/aggregates/aggregate_avg.rs index 5304f9751ed4..94cd8dfff9c9 100644 --- a/src/query/functions/src/aggregates/aggregate_avg.rs +++ b/src/query/functions/src/aggregates/aggregate_avg.rs @@ -189,7 +189,7 @@ where match self .value .checked_mul(T::Scalar::e(decimal_avg_data.scale_add as u32)) - .and_then(|v| v.checked_div(T::Scalar::from_u64(self.count))) + .and_then(|v| v.checked_div(T::Scalar::from_i128(self.count))) { Some(value) => { T::push_item(builder, T::to_scalar_ref(&value)); diff --git a/src/query/functions/src/aggregates/aggregate_stddev.rs b/src/query/functions/src/aggregates/aggregate_stddev.rs index 2800e2423cf0..573567aaa91a 100644 --- a/src/query/functions/src/aggregates/aggregate_stddev.rs +++ b/src/query/functions/src/aggregates/aggregate_stddev.rs @@ -173,7 +173,7 @@ where self.count += 1; if self.count > 1 { let t = match value - .checked_mul(T::Scalar::from_u64(self.count)) + .checked_mul(T::Scalar::from_i128(self.count)) .and_then(|v| v.checked_sub(self.sum)) .and_then(|v| v.checked_mul(T::Scalar::e(VARIANCE_PRECISION as u32))) { @@ -204,7 +204,7 @@ where } }; - let count = T::Scalar::from_u64(self.count * (self.count - 1)); + let count = T::Scalar::from_i128(self.count * (self.count - 1)); let add_variance = match t.checked_div(count) { Some(t) => t, @@ -236,8 +236,8 @@ where return Ok(()); } - let 
other_count = T::Scalar::from_u64(other.count); - let self_count = T::Scalar::from_u64(self.count); + let other_count = T::Scalar::from_i128(other.count); + let self_count = T::Scalar::from_i128(self.count); let t = match other_count .checked_mul(self.sum) .and_then(|v| v.checked_mul(T::Scalar::e(VARIANCE_PRECISION as u32))) diff --git a/src/query/functions/src/scalars/arithmetic.rs b/src/query/functions/src/scalars/arithmetic.rs index 4d75ecb3e189..1319dacc70e6 100644 --- a/src/query/functions/src/scalars/arithmetic.rs +++ b/src/query/functions/src/scalars/arithmetic.rs @@ -43,6 +43,7 @@ use databend_common_expression::types::ALL_INTEGER_TYPES; use databend_common_expression::types::ALL_NUMBER_CLASSES; use databend_common_expression::types::ALL_NUMERICS_TYPES; use databend_common_expression::types::ALL_UNSIGNED_INTEGER_TYPES; +use databend_common_expression::types::F32; use databend_common_expression::utils::arithmetics_type::ResultTypeOfBinary; use databend_common_expression::utils::arithmetics_type::ResultTypeOfUnary; use databend_common_expression::values::Value; @@ -72,10 +73,9 @@ use lexical_core::FormattedSize; use num_traits::AsPrimitive; use super::arithmetic_modulo::vectorize_modulo; -use super::decimal::register_decimal_to_float32; -use super::decimal::register_decimal_to_float64; use super::decimal::register_decimal_to_int; use crate::scalars::decimal::register_decimal_arithmetic; +use crate::scalars::decimal::register_decimal_to_float; pub fn register(registry: &mut FunctionRegistry) { registry.register_aliases("plus", &["add"]); @@ -717,10 +717,10 @@ pub fn register_number_to_number(registry: &mut FunctionRegistry) { NumberClass::Decimal128 => { // todo(youngsofun): add decimal try_cast and decimal to int and float if matches!(dest_type, NumberDataType::Float32) { - register_decimal_to_float32(registry); + register_decimal_to_float::(registry); } if matches!(dest_type, NumberDataType::Float64) { - register_decimal_to_float64(registry); + 
register_decimal_to_float::(registry); } with_number_mapped_type!(|DEST_TYPE| match dest_type { diff --git a/src/query/functions/src/scalars/decimal.rs b/src/query/functions/src/scalars/decimal.rs deleted file mode 100644 index 78618cc8720e..000000000000 --- a/src/query/functions/src/scalars/decimal.rs +++ /dev/null @@ -1,1911 +0,0 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::cmp::Ord; -use std::ops::*; -use std::sync::Arc; - -use databend_common_arrow::arrow::bitmap::Bitmap; -use databend_common_arrow::arrow::buffer::Buffer; -use databend_common_expression::serialize::read_decimal_with_size; -use databend_common_expression::type_check::common_super_type; -use databend_common_expression::types::decimal::*; -use databend_common_expression::types::string::StringColumn; -use databend_common_expression::types::*; -use databend_common_expression::with_decimal_mapped_type; -use databend_common_expression::with_integer_mapped_type; -use databend_common_expression::with_number_mapped_type; -use databend_common_expression::Column; -use databend_common_expression::ColumnBuilder; -use databend_common_expression::Domain; -use databend_common_expression::EvalContext; -use databend_common_expression::FromData; -use databend_common_expression::Function; -use databend_common_expression::FunctionContext; -use databend_common_expression::FunctionDomain; -use databend_common_expression::FunctionEval; -use 
databend_common_expression::FunctionRegistry; -use databend_common_expression::FunctionSignature; -use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; -use databend_common_expression::SimpleDomainCmp; -use databend_common_expression::Value; -use databend_common_expression::ValueRef; -use ethnum::i256; -use num_traits::AsPrimitive; -use ordered_float::OrderedFloat; - -macro_rules! op_decimal { - ($a: expr, $b: expr, $ctx: expr, $left: expr, $right: expr, $result_type: expr, $op: ident, $is_divide: expr) => { - match $left { - DecimalDataType::Decimal128(_) => { - binary_decimal!( - $a, - $b, - $ctx, - $left, - $right, - $op, - $result_type.size(), - i128, - Decimal128, - $is_divide - ) - } - DecimalDataType::Decimal256(_) => { - binary_decimal!( - $a, - $b, - $ctx, - $left, - $right, - $op, - $result_type.size(), - i256, - Decimal256, - $is_divide - ) - } - } - }; - ($a: expr, $b: expr, $return_type: expr, $op: ident) => { - match $return_type { - DecimalDataType::Decimal128(_) => { - compare_decimal!($a, $b, $op, Decimal128) - } - DecimalDataType::Decimal256(_) => { - compare_decimal!($a, $b, $op, Decimal256) - } - } - }; -} - -macro_rules! 
compare_decimal { - ($a: expr, $b: expr, $op: ident, $decimal_type: tt) => {{ - match ($a, $b) { - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_b, _))), - ) => { - let result = buffer_a - .iter() - .zip(buffer_b.iter()) - .map(|(a, b)| a.cmp(b).$op()) - .collect(); - - Value::Column(Column::Boolean(result)) - } - - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => { - let result = buffer.iter().map(|a| a.cmp(b).$op()).collect(); - - Value::Column(Column::Boolean(result)) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ) => { - let result = buffer.iter().map(|b| a.cmp(b).$op()).collect(); - - Value::Column(Column::Boolean(result)) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => Value::Scalar(Scalar::Boolean(a.cmp(b).$op())), - - _ => unreachable!("arg type of cmp op is not required decimal"), - } - }}; -} - -macro_rules! binary_decimal { - ($a: expr, $b: expr, $ctx: expr, $left: expr, $right: expr, $op: ident, $size: expr, $type_name: ty, $decimal_type: tt, $is_divide: expr) => {{ - let overflow = $size.precision == <$type_name>::default_decimal_size().precision; - - if $is_divide { - let scale_a = $left.scale(); - let scale_b = $right.scale(); - binary_decimal_div!( - $a, - $b, - $ctx, - scale_a, - scale_b, - $op, - $size, - $type_name, - $decimal_type - ) - } else if overflow { - binary_decimal_check_overflow!($a, $b, $ctx, $op, $size, $type_name, $decimal_type) - } else { - binary_decimal_no_overflow!($a, $b, $ctx, $op, $size, $type_name, $decimal_type) - } - }}; -} - -macro_rules! 
binary_decimal_no_overflow { - ($a: expr, $b: expr, $ctx: expr, $op: ident, $size: expr, $type_name: ty, $decimal_type: tt) => {{ - match ($a, $b) { - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_b, _))), - ) => { - let result: Vec<_> = buffer_a - .iter() - .zip(buffer_b.iter()) - .map(|(a, b)| a.$op(b)) - .collect(); - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => { - let result: Vec<_> = buffer.iter().map(|a| a.$op(b)).collect(); - - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ) => { - let result: Vec<_> = buffer.iter().map(|b| a.$op(b)).collect(); - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => Value::Scalar(Scalar::Decimal(DecimalScalar::$decimal_type( - a.$op(b), - $size, - ))), - - _ => unreachable!("arg type of binary op is not required decimal"), - } - }}; -} - -macro_rules! 
binary_decimal_check_overflow { - ($a: expr, $b: expr, $ctx: expr, $op: ident, $size: expr, $type_name: ty, $decimal_type: tt) => {{ - let one = <$type_name>::one(); - let min_for_precision = <$type_name>::min_for_precision($size.precision); - let max_for_precision = <$type_name>::max_for_precision($size.precision); - - match ($a, $b) { - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_b, _))), - ) => { - let mut result = Vec::with_capacity(buffer_a.len()); - - for (a, b) in buffer_a.iter().zip(buffer_b.iter()) { - let t = a.$op(b); - if t < min_for_precision || t > max_for_precision { - $ctx.set_error( - result.len(), - concat!("Decimal overflow at line : ", line!()), - ); - result.push(one); - } else { - result.push(t); - } - } - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => { - let mut result = Vec::with_capacity(buffer.len()); - - for a in buffer.iter() { - let t = a.$op(b); - if t < min_for_precision || t > max_for_precision { - $ctx.set_error( - result.len(), - concat!("Decimal overflow at line : ", line!()), - ); - result.push(one); - } else { - result.push(t); - } - } - - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ) => { - let mut result = Vec::with_capacity(buffer.len()); - - for b in buffer.iter() { - let t = a.$op(b); - if t < min_for_precision || t > max_for_precision { - $ctx.set_error( - result.len(), - concat!("Decimal overflow at line : ", line!()), - ); - result.push(one); - } else { - result.push(t); - } - } - 
Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => { - let t = a.$op(b); - if t < min_for_precision || t > max_for_precision { - $ctx.set_error(0, concat!("Decimal overflow at line : ", line!())); - } - Value::Scalar(Scalar::Decimal(DecimalScalar::$decimal_type(t, $size))) - } - - _ => unreachable!("arg type of binary op is not required decimal"), - } - }}; -} - -macro_rules! binary_decimal_div { - ($a: expr, $b: expr, $ctx: expr, $scale_a: expr, $scale_b: expr, $op: ident, $size: expr, $type_name: ty, $decimal_type: tt) => {{ - let zero = <$type_name>::zero(); - let one = <$type_name>::one(); - - let (scale_mul, scale_div) = if $scale_b + $size.scale > $scale_a { - ($scale_b + $size.scale - $scale_a, 0) - } else { - (0, $scale_b + $size.scale - $scale_a) - }; - - let multiplier = <$type_name>::e(scale_mul as u32); - let div = <$type_name>::e(scale_div as u32); - - match ($a, $b) { - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_b, _))), - ) => { - let mut result = Vec::with_capacity(buffer_a.len()); - - for (a, b) in buffer_a.iter().zip(buffer_b.iter()) { - if std::intrinsics::unlikely(*b == zero) { - $ctx.set_error(result.len(), "divided by zero"); - result.push(one); - } else { - result.push((a * multiplier).$op(b) / div); - } - } - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => { - let mut result = Vec::with_capacity(buffer.len()); - - for a in buffer.iter() { - if std::intrinsics::unlikely(*b == zero) { - $ctx.set_error(result.len(), 
"divided by zero"); - result.push(one); - } else { - result.push((a * multiplier).$op(b) / div); - } - } - - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), - ) => { - let mut result = Vec::with_capacity(buffer.len()); - - for b in buffer.iter() { - if std::intrinsics::unlikely(*b == zero) { - $ctx.set_error(result.len(), "divided by zero"); - result.push(one); - } else { - result.push((a * multiplier).$op(b) / div); - } - } - Value::Column(Column::Decimal(DecimalColumn::$decimal_type( - result.into(), - $size, - ))) - } - - ( - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), - ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), - ) => { - let mut t = zero; - if std::intrinsics::unlikely(*b == zero) { - $ctx.set_error(0, "divided by zero"); - } else { - t = (a * multiplier).$op(b) / div; - } - Value::Scalar(Scalar::Decimal(DecimalScalar::$decimal_type(t, $size))) - } - - _ => unreachable!("arg type of binary op is not required decimal"), - } - }}; -} - -macro_rules! 
register_decimal_compare_op { - ($registry: expr, $name: expr, $op: ident, $domain_op: tt) => { - $registry.register_function_factory($name, |_, args_type| { - if args_type.len() != 2 { - return None; - } - - let has_nullable = args_type.iter().any(|x| x.is_nullable_or_null()); - let args_type: Vec = args_type.iter().map(|x| x.remove_nullable()).collect(); - - // Only works for one of is decimal types - if !args_type[0].is_decimal() && !args_type[1].is_decimal() { - return None; - } - - let common_type = common_super_type(args_type[0].clone(), args_type[1].clone(), &[])?; - - if !common_type.is_decimal() { - return None; - } - - // Comparison between different decimal types must be same siganature types - let function = Function { - signature: FunctionSignature { - name: $name.to_string(), - args_type: vec![common_type.clone(), common_type.clone()], - return_type: DataType::Boolean, - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(|_, d| { - let new_domain = match (&d[0], &d[1]) { - ( - Domain::Decimal(DecimalDomain::Decimal128(d1, _)), - Domain::Decimal(DecimalDomain::Decimal128(d2, _)), - ) => d1.$domain_op(d2), - ( - Domain::Decimal(DecimalDomain::Decimal256(d1, _)), - Domain::Decimal(DecimalDomain::Decimal256(d2, _)), - ) => d1.$domain_op(d2), - _ => unreachable!("Expect two same decimal domains, got {:?}", d), - }; - new_domain.map(|d| Domain::Boolean(d)) - }), - eval: Box::new(move |args, _ctx| { - op_decimal!(&args[0], &args[1], common_type.as_decimal().unwrap(), $op) - }), - }, - }; - if has_nullable { - Some(Arc::new(function.passthrough_nullable())) - } else { - Some(Arc::new(function)) - } - }); - }; -} - -#[inline(always)] -fn domain_plus( - lhs: &SimpleDomain, - rhs: &SimpleDomain, - precision: u8, -) -> Option> { - // For plus, the scale of the two operands must be the same. 
- let min = T::min_for_precision(precision); - let max = T::max_for_precision(precision); - Some(SimpleDomain { - min: lhs - .min - .checked_add(rhs.min) - .filter(|&m| m >= min && m <= max)?, - max: lhs - .max - .checked_add(rhs.max) - .filter(|&m| m >= min && m <= max)?, - }) -} - -#[inline(always)] -fn domain_minus( - lhs: &SimpleDomain, - rhs: &SimpleDomain, - precision: u8, -) -> Option> { - // For minus, the scale of the two operands must be the same. - let min = T::min_for_precision(precision); - let max = T::max_for_precision(precision); - Some(SimpleDomain { - min: lhs - .min - .checked_sub(rhs.max) - .filter(|&m| m >= min && m <= max)?, - max: lhs - .max - .checked_sub(rhs.min) - .filter(|&m| m >= min && m <= max)?, - }) -} - -#[inline(always)] -fn domain_mul( - lhs: &SimpleDomain, - rhs: &SimpleDomain, - precision: u8, -) -> Option> { - let min = T::min_for_precision(precision); - let max = T::max_for_precision(precision); - - let a = lhs - .min - .checked_mul(rhs.min) - .filter(|&m| m >= min && m <= max)?; - let b = lhs - .min - .checked_mul(rhs.max) - .filter(|&m| m >= min && m <= max)?; - let c = lhs - .max - .checked_mul(rhs.min) - .filter(|&m| m >= min && m <= max)?; - let d = lhs - .max - .checked_mul(rhs.max) - .filter(|&m| m >= min && m <= max)?; - - Some(SimpleDomain { - min: a.min(b).min(c).min(d), - max: a.max(b).max(c).max(d), - }) -} - -#[inline(always)] -fn domain_div( - _lhs: &SimpleDomain, - _rhs: &SimpleDomain, - _precision: u8, -) -> Option> { - // For div, we cannot determine the domain. - None -} - -macro_rules! 
register_decimal_binary_op { - ($registry: expr, $name: expr, $op: ident, $domain_op: ident, $default_domain: expr) => { - $registry.register_function_factory($name, |_, args_type| { - if args_type.len() != 2 { - return None; - } - - let has_nullable = args_type.iter().any(|x| x.is_nullable_or_null()); - let args_type: Vec = args_type.iter().map(|x| x.remove_nullable()).collect(); - - // number X decimal -> decimal - // decimal X number -> decimal - // decimal X decimal -> decimal - if !args_type[0].is_decimal() && !args_type[1].is_decimal() { - return None; - } - - let decimal_a = - DecimalDataType::from_size(args_type[0].get_decimal_properties()?).unwrap(); - let decimal_b = - DecimalDataType::from_size(args_type[1].get_decimal_properties()?).unwrap(); - - let is_multiply = $name == "multiply"; - let is_divide = $name == "divide"; - let is_plus_minus = !is_multiply && !is_divide; - - // left, right will unify to same width decimal, both 256 or both 128 - let (left, right, return_decimal_type) = DecimalDataType::binary_result_type( - &decimal_a, - &decimal_b, - is_multiply, - is_divide, - is_plus_minus, - ) - .ok()?; - - let function = Function { - signature: FunctionSignature { - name: $name.to_string(), - args_type: vec![ - DataType::Decimal(left.clone()), - DataType::Decimal(right.clone()), - ], - return_type: DataType::Decimal(return_decimal_type), - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(move |_ctx, d| { - let lhs = d[0].as_decimal(); - let rhs = d[1].as_decimal(); - - if lhs.is_none() || rhs.is_none() { - return FunctionDomain::Full; - } - - let lhs = lhs.unwrap(); - let rhs = rhs.unwrap(); - - let size = return_decimal_type.size(); - - { - match (lhs, rhs) { - ( - DecimalDomain::Decimal128(d1, _), - DecimalDomain::Decimal128(d2, _), - ) => $domain_op(&d1, &d2, size.precision) - .map(|d| DecimalDomain::Decimal128(d, size)), - ( - DecimalDomain::Decimal256(d1, _), - DecimalDomain::Decimal256(d2, _), - ) => $domain_op(&d1, &d2, 
size.precision) - .map(|d| DecimalDomain::Decimal256(d, size)), - _ => { - unreachable!("unreachable decimal domain {:?} /{:?}", lhs, rhs) - } - } - } - .map(|d| FunctionDomain::Domain(Domain::Decimal(d))) - .unwrap_or($default_domain) - }), - eval: Box::new(move |args, ctx| { - let res = op_decimal!( - &args[0], - &args[1], - ctx, - left, - right, - return_decimal_type, - $op, - is_divide - ); - - res - }), - }, - }; - if has_nullable { - Some(Arc::new(function.passthrough_nullable())) - } else { - Some(Arc::new(function)) - } - }); - }; -} - -pub(crate) fn register_decimal_compare_op(registry: &mut FunctionRegistry) { - register_decimal_compare_op!(registry, "lt", is_lt, domain_lt); - register_decimal_compare_op!(registry, "eq", is_eq, domain_eq); - register_decimal_compare_op!(registry, "gt", is_gt, domain_gt); - register_decimal_compare_op!(registry, "lte", is_le, domain_lte); - register_decimal_compare_op!(registry, "gte", is_ge, domain_gte); - register_decimal_compare_op!(registry, "noteq", is_ne, domain_noteq); -} - -pub(crate) fn register_decimal_arithmetic(registry: &mut FunctionRegistry) { - // TODO checked overflow by default - register_decimal_binary_op!(registry, "plus", add, domain_plus, FunctionDomain::Full); - register_decimal_binary_op!(registry, "minus", sub, domain_minus, FunctionDomain::Full); - register_decimal_binary_op!( - registry, - "divide", - div, - domain_div, - FunctionDomain::MayThrow - ); - register_decimal_binary_op!(registry, "multiply", mul, domain_mul, FunctionDomain::Full); -} - -// int float to decimal -pub fn register(registry: &mut FunctionRegistry) { - let factory = |params: &[usize], args_type: &[DataType]| { - if args_type.len() != 1 { - return None; - } - if params.len() != 2 { - return None; - } - - let from_type = args_type[0].remove_nullable(); - - if !matches!( - from_type, - DataType::Boolean | DataType::Number(_) | DataType::Decimal(_) | DataType::String - ) { - return None; - } - - let decimal_size = DecimalSize { - 
precision: params[0] as u8, - scale: params[1] as u8, - }; - - let decimal_type = DecimalDataType::from_size(decimal_size).ok()?; - - Some(Function { - signature: FunctionSignature { - name: "to_decimal".to_string(), - args_type: vec![from_type.clone()], - return_type: DataType::Decimal(decimal_type), - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(move |ctx, d| { - convert_to_decimal_domain(ctx, d[0].clone(), decimal_type) - .map(|d| FunctionDomain::Domain(Domain::Decimal(d))) - .unwrap_or(FunctionDomain::MayThrow) - }), - eval: Box::new(move |args, ctx| { - convert_to_decimal(&args[0], ctx, &from_type, decimal_type) - }), - }, - }) - }; - - registry.register_function_factory("to_decimal", move |params, args_type| { - Some(Arc::new(factory(params, args_type)?)) - }); - registry.register_function_factory("to_decimal", move |params, args_type| { - let f = factory(params, args_type)?; - Some(Arc::new(f.passthrough_nullable())) - }); - registry.register_function_factory("try_to_decimal", move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = "try_to_decimal".to_string(); - Some(Arc::new(f.error_to_null())) - }); - registry.register_function_factory("try_to_decimal", move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = "try_to_decimal".to_string(); - Some(Arc::new(f.error_to_null().passthrough_nullable())) - }); -} - -pub(crate) fn register_decimal_to_float64(registry: &mut FunctionRegistry) { - let factory = |_params: &[usize], args_type: &[DataType]| { - if args_type.len() != 1 { - return None; - } - - let arg_type = args_type[0].remove_nullable(); - - if !arg_type.is_decimal() { - return None; - } - - let function = Function { - signature: FunctionSignature { - name: "to_float64".to_string(), - args_type: vec![arg_type.clone()], - return_type: Float64Type::data_type(), - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(|_, d| match d[0].as_decimal().unwrap() { - 
DecimalDomain::Decimal128(d, size) => FunctionDomain::Domain(Domain::Number( - NumberDomain::Float64(SimpleDomain { - min: OrderedFloat(d.min.to_float64(size.scale)), - max: OrderedFloat(d.max.to_float64(size.scale)), - }), - )), - DecimalDomain::Decimal256(d, size) => FunctionDomain::Domain(Domain::Number( - NumberDomain::Float64(SimpleDomain { - min: OrderedFloat(d.min.to_float64(size.scale)), - max: OrderedFloat(d.max.to_float64(size.scale)), - }), - )), - }), - eval: Box::new(move |args, tx| decimal_to_float64(&args[0], arg_type.clone(), tx)), - }, - }; - - Some(function) - }; - - registry.register_function_factory("to_float64", move |params, args_type| { - Some(Arc::new(factory(params, args_type)?)) - }); - registry.register_function_factory("to_float64", move |params, args_type| { - let f = factory(params, args_type)?; - Some(Arc::new(f.passthrough_nullable())) - }); - registry.register_function_factory("try_to_float64", move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = "try_to_float64".to_string(); - Some(Arc::new(f.error_to_null())) - }); - registry.register_function_factory("try_to_float64", move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = "try_to_float64".to_string(); - Some(Arc::new(f.error_to_null().passthrough_nullable())) - }); -} - -pub(crate) fn register_decimal_to_float32(registry: &mut FunctionRegistry) { - let factory = |_params: &[usize], args_type: &[DataType]| { - if args_type.len() != 1 { - return None; - } - - let arg_type = args_type[0].remove_nullable(); - if !arg_type.is_decimal() { - return None; - } - - let function = Function { - signature: FunctionSignature { - name: "to_float32".to_string(), - args_type: vec![arg_type.clone()], - return_type: Float32Type::data_type(), - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(|_, d| match d[0].as_decimal().unwrap() { - DecimalDomain::Decimal128(d, size) => FunctionDomain::Domain(Domain::Number( - 
NumberDomain::Float32(SimpleDomain { - min: OrderedFloat(d.min.to_float32(size.scale)), - max: OrderedFloat(d.max.to_float32(size.scale)), - }), - )), - DecimalDomain::Decimal256(d, size) => FunctionDomain::Domain(Domain::Number( - NumberDomain::Float32(SimpleDomain { - min: OrderedFloat(d.min.to_float32(size.scale)), - max: OrderedFloat(d.max.to_float32(size.scale)), - }), - )), - }), - eval: Box::new(move |args, tx| decimal_to_float32(&args[0], arg_type.clone(), tx)), - }, - }; - - Some(function) - }; - - registry.register_function_factory("to_float32", move |params, args_type| { - Some(Arc::new(factory(params, args_type)?)) - }); - registry.register_function_factory("to_float32", move |params, args_type| { - let f = factory(params, args_type)?; - Some(Arc::new(f.passthrough_nullable())) - }); - registry.register_function_factory("try_to_float32", move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = "try_to_float32".to_string(); - Some(Arc::new(f.error_to_null())) - }); - registry.register_function_factory("try_to_float32", move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = "try_to_float32".to_string(); - Some(Arc::new(f.error_to_null().passthrough_nullable())) - }); -} - -pub(crate) fn register_decimal_to_int(registry: &mut FunctionRegistry) { - if T::data_type().is_float() { - return; - } - let name = format!("to_{}", T::data_type().to_string().to_lowercase()); - let try_name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); - - let factory = |_params: &[usize], args_type: &[DataType]| { - if args_type.len() != 1 { - return None; - } - - let name = format!("to_{}", T::data_type().to_string().to_lowercase()); - let arg_type = args_type[0].remove_nullable(); - if !arg_type.is_decimal() { - return None; - } - - let function = Function { - signature: FunctionSignature { - name, - args_type: vec![arg_type.clone()], - return_type: DataType::Number(T::data_type()), - }, - 
eval: FunctionEval::Scalar { - calc_domain: Box::new(|ctx, d| { - let res_fn = move || match d[0].as_decimal().unwrap() { - DecimalDomain::Decimal128(d, size) => Some(SimpleDomain:: { - min: d.min.to_int(size.scale, ctx.rounding_mode)?, - max: d.max.to_int(size.scale, ctx.rounding_mode)?, - }), - DecimalDomain::Decimal256(d, size) => Some(SimpleDomain:: { - min: d.min.to_int(size.scale, ctx.rounding_mode)?, - max: d.max.to_int(size.scale, ctx.rounding_mode)?, - }), - }; - - res_fn() - .map(|d| FunctionDomain::Domain(Domain::Number(T::upcast_domain(d)))) - .unwrap_or(FunctionDomain::MayThrow) - }), - eval: Box::new(move |args, tx| decimal_to_int::(&args[0], arg_type.clone(), tx)), - }, - }; - - Some(function) - }; - - registry.register_function_factory(&name, move |params, args_type| { - Some(Arc::new(factory(params, args_type)?)) - }); - registry.register_function_factory(&name, move |params, args_type| { - let f = factory(params, args_type)?; - Some(Arc::new(f.passthrough_nullable())) - }); - registry.register_function_factory(&try_name, move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); - Some(Arc::new(f.error_to_null())) - }); - registry.register_function_factory(&try_name, move |params, args_type| { - let mut f = factory(params, args_type)?; - f.signature.name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); - Some(Arc::new(f.error_to_null().passthrough_nullable())) - }); -} - -fn convert_to_decimal( - arg: &ValueRef, - ctx: &mut EvalContext, - from_type: &DataType, - dest_type: DecimalDataType, -) -> Value { - match from_type { - DataType::Boolean => boolean_to_decimal(arg, dest_type), - DataType::Number(ty) => { - if ty.is_float() { - float_to_decimal(arg, ctx, *ty, dest_type) - } else { - integer_to_decimal(arg, ctx, *ty, dest_type) - } - } - DataType::Decimal(from) => decimal_to_decimal(arg, ctx, *from, dest_type), - DataType::String => 
string_to_decimal(arg, ctx, dest_type), - _ => unreachable!("to_decimal not support this DataType"), - } -} - -fn convert_to_decimal_domain( - func_ctx: &FunctionContext, - domain: Domain, - dest_type: DecimalDataType, -) -> Option { - // Convert the domain to a Column. - // The first row is the min value, the second row is the max value. - let column = match domain { - Domain::Number(number_domain) => { - with_number_mapped_type!(|NUM_TYPE| match number_domain { - NumberDomain::NUM_TYPE(d) => { - let min = d.min; - let max = d.max; - NumberType::::from_data(vec![min, max]) - } - }) - } - Domain::Boolean(d) => { - let min = !d.has_false; - let max = d.has_true; - BooleanType::from_data(vec![min, max]) - } - Domain::Decimal(d) => { - with_decimal_mapped_type!(|DECIMAL| match d { - DecimalDomain::DECIMAL(d, size) => { - let min = d.min; - let max = d.max; - DecimalType::from_data_with_size(vec![min, max], size) - } - }) - } - Domain::String(d) => { - let min = d.min; - let max = d.max?; - StringType::from_data(vec![min, max]) - } - _ => { - return None; - } - }; - - let from_type = column.data_type(); - let value = Value::::Column(column); - let mut ctx = EvalContext { - generics: &[], - num_rows: 2, - func_ctx, - validity: None, - errors: None, - }; - let dest_size = dest_type.size(); - let res = convert_to_decimal(&value.as_ref(), &mut ctx, &from_type, dest_type); - - if ctx.errors.is_some() { - return None; - } - let decimal_col = res.as_column()?.as_decimal()?; - assert_eq!(decimal_col.len(), 2); - - Some(match decimal_col { - DecimalColumn::Decimal128(buf, size) => { - assert_eq!(&dest_size, size); - let (min, max) = unsafe { (*buf.get_unchecked(0), *buf.get_unchecked(1)) }; - DecimalDomain::Decimal128(SimpleDomain { min, max }, *size) - } - DecimalColumn::Decimal256(buf, size) => { - assert_eq!(&dest_size, size); - let (min, max) = unsafe { (*buf.get_unchecked(0), *buf.get_unchecked(1)) }; - DecimalDomain::Decimal256(SimpleDomain { min, max }, *size) - } - }) 
-} - -fn boolean_to_decimal_column( - boolean_column: &Bitmap, - size: DecimalSize, -) -> DecimalColumn { - let mut values = Vec::::with_capacity(boolean_column.len()); - for val in boolean_column.iter() { - if val { - values.push(T::e(size.scale as u32)); - } else { - values.push(T::zero()); - } - } - T::to_column(values, size) -} - -fn boolean_to_decimal_scalar(val: bool, size: DecimalSize) -> DecimalScalar { - if val { - T::to_scalar(T::e(size.scale as u32), size) - } else { - T::to_scalar(T::zero(), size) - } -} - -fn boolean_to_decimal(arg: &ValueRef, dest_type: DecimalDataType) -> Value { - match arg { - ValueRef::Column(column) => { - let boolean_column = BooleanType::try_downcast_column(column).unwrap(); - let column = match dest_type { - DecimalDataType::Decimal128(size) => { - boolean_to_decimal_column::(&boolean_column, size) - } - DecimalDataType::Decimal256(size) => { - boolean_to_decimal_column::(&boolean_column, size) - } - }; - Value::Column(Column::Decimal(column)) - } - ValueRef::Scalar(scalar) => { - let val = BooleanType::try_downcast_scalar(scalar).unwrap(); - let scalar = match dest_type { - DecimalDataType::Decimal128(size) => boolean_to_decimal_scalar::(val, size), - DecimalDataType::Decimal256(size) => boolean_to_decimal_scalar::(val, size), - }; - Value::Scalar(Scalar::Decimal(scalar)) - } - } -} - -fn string_to_decimal_column( - ctx: &mut EvalContext, - string_column: &StringColumn, - size: DecimalSize, - rounding_mode: bool, -) -> DecimalColumn { - let mut values = Vec::::with_capacity(string_column.len()); - for (row, buf) in string_column.iter().enumerate() { - match read_decimal_with_size::(buf, size, true, rounding_mode) { - Ok((d, _)) => values.push(d), - Err(e) => { - ctx.set_error(row, e.message()); - values.push(T::zero()) - } - } - } - T::to_column(values, size) -} - -fn string_to_decimal_scalar( - ctx: &mut EvalContext, - string_buf: &[u8], - size: DecimalSize, - rounding_mode: bool, -) -> DecimalScalar { - let value = match 
read_decimal_with_size::(string_buf, size, true, rounding_mode) { - Ok((d, _)) => d, - Err(e) => { - ctx.set_error(0, e.message()); - T::zero() - } - }; - T::to_scalar(value, size) -} - -fn string_to_decimal( - arg: &ValueRef, - ctx: &mut EvalContext, - dest_type: DecimalDataType, -) -> Value { - let rounding_mode = ctx.func_ctx.rounding_mode; - match arg { - ValueRef::Column(column) => { - let string_column = StringType::try_downcast_column(column).unwrap(); - let column = match dest_type { - DecimalDataType::Decimal128(size) => { - string_to_decimal_column::(ctx, &string_column, size, rounding_mode) - } - DecimalDataType::Decimal256(size) => { - string_to_decimal_column::(ctx, &string_column, size, rounding_mode) - } - }; - Value::Column(Column::Decimal(column)) - } - ValueRef::Scalar(scalar) => { - let buf = StringType::try_downcast_scalar(scalar).unwrap(); - let scalar = match dest_type { - DecimalDataType::Decimal128(size) => { - string_to_decimal_scalar::(ctx, buf, size, rounding_mode) - } - DecimalDataType::Decimal256(size) => { - string_to_decimal_scalar::(ctx, buf, size, rounding_mode) - } - }; - Value::Scalar(Scalar::Decimal(scalar)) - } - } -} - -fn integer_to_decimal( - arg: &ValueRef, - ctx: &mut EvalContext, - from_type: NumberDataType, - dest_type: DecimalDataType, -) -> Value { - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &DataType::Number(from_type)); - builder.build() - } - }; - - let result = with_integer_mapped_type!(|NUM_TYPE| match from_type { - NumberDataType::NUM_TYPE => { - let column = NumberType::::try_downcast_column(&column).unwrap(); - integer_to_decimal_internal(column, ctx, &dest_type) - } - _ => unreachable!(), - }); - - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(Scalar::Decimal(scalar)) - } else { - Value::Column(Column::Decimal(result)) - } -} - 
-macro_rules! m_integer_to_decimal { - ($from: expr, $size: expr, $type_name: ty, $ctx: expr) => { - let multiplier = <$type_name>::e($size.scale as u32); - let min_for_precision = <$type_name>::min_for_precision($size.precision); - let max_for_precision = <$type_name>::max_for_precision($size.precision); - - let values = $from - .iter() - .enumerate() - .map(|(row, x)| { - let x = x.as_() * <$type_name>::one(); - let x = x.checked_mul(multiplier).and_then(|v| { - if v > max_for_precision || v < min_for_precision { - None - } else { - Some(v) - } - }); - - match x { - Some(x) => x, - None => { - $ctx.set_error(row, concat!("Decimal overflow at line : ", line!())); - <$type_name>::one() - } - } - }) - .collect(); - <$type_name>::to_column(values, $size) - }; -} - -fn integer_to_decimal_internal>( - from: Buffer, - ctx: &mut EvalContext, - dest_type: &DecimalDataType, -) -> DecimalColumn { - match dest_type { - DecimalDataType::Decimal128(size) => { - m_integer_to_decimal! {from, *size, i128, ctx} - } - DecimalDataType::Decimal256(size) => { - m_integer_to_decimal! {from, *size, i256, ctx} - } - } -} - -macro_rules! 
m_float_to_decimal { - ($from: expr, $size: expr, $type_name: ty, $ctx: expr) => { - let multiplier: f64 = (10_f64).powi($size.scale as i32).as_(); - - let min_for_precision = <$type_name>::min_for_precision($size.precision); - let max_for_precision = <$type_name>::max_for_precision($size.precision); - - let values = $from - .iter() - .enumerate() - .map(|(row, x)| { - let mut x = x.as_() * multiplier; - if $ctx.func_ctx.rounding_mode { - x = x.round(); - } - let x = <$type_name>::from_float(x); - if x > max_for_precision || x < min_for_precision { - $ctx.set_error(row, concat!("Decimal overflow at line : ", line!())); - <$type_name>::one() - } else { - x - } - }) - .collect(); - <$type_name>::to_column(values, $size) - }; -} - -fn float_to_decimal( - arg: &ValueRef, - ctx: &mut EvalContext, - from_type: NumberDataType, - dest_type: DecimalDataType, -) -> Value { - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &DataType::Number(from_type)); - builder.build() - } - }; - - let result = match from_type { - NumberDataType::Float32 => { - let column = NumberType::::try_downcast_column(&column).unwrap(); - float_to_decimal_internal(column, ctx, &dest_type) - } - NumberDataType::Float64 => { - let column = NumberType::::try_downcast_column(&column).unwrap(); - float_to_decimal_internal(column, ctx, &dest_type) - } - _ => unreachable!(), - }; - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(Scalar::Decimal(scalar)) - } else { - Value::Column(Column::Decimal(result)) - } -} - -fn float_to_decimal_internal>( - from: Buffer, - ctx: &mut EvalContext, - dest_type: &DecimalDataType, -) -> DecimalColumn { - match dest_type { - DecimalDataType::Decimal128(size) => { - m_float_to_decimal! {from, *size, i128, ctx} - } - DecimalDataType::Decimal256(size) => { - m_float_to_decimal! 
{from, *size, i256, ctx} - } - } -} - -fn get_round_val(x: T, scale: u32, ctx: &mut EvalContext) -> Option { - let mut round_val = None; - if ctx.func_ctx.rounding_mode && scale > 0 { - // Checking whether numbers need to be added or subtracted to calculate rounding - if let Some(r) = x.checked_rem(T::e(scale)) { - if let Some(m) = r.checked_div(T::e(scale - 1)) { - if m >= T::from_i64(5i64) { - round_val = Some(T::one()); - } else if m <= T::from_i64(-5i64) { - round_val = Some(T::minus_one()); - } - } - } - } - round_val -} - -fn decimal_256_to_128( - buffer: Buffer, - from_size: DecimalSize, - dest_size: DecimalSize, - ctx: &mut EvalContext, -) -> DecimalColumn { - let max = i128::max_for_precision(dest_size.precision); - let min = i128::min_for_precision(dest_size.precision); - - let values = if dest_size.scale >= from_size.scale { - let factor = i256::e((dest_size.scale - from_size.scale) as u32); - buffer - .iter() - .enumerate() - .map(|(row, x)| { - let x = x * i128::one(); - match x.checked_mul(factor) { - Some(x) if x <= max && x >= min => *x.low(), - _ => { - ctx.set_error(row, concat!("Decimal overflow at line : ", line!())); - i128::one() - } - } - }) - .collect() - } else { - let scale_diff = (from_size.scale - dest_size.scale) as u32; - let factor = i256::e(scale_diff); - let source_factor = i256::e(from_size.scale as u32); - - buffer - .iter() - .enumerate() - .map(|(row, x)| { - let x = x * i128::one(); - let round_val = get_round_val::(x, scale_diff, ctx); - let y = match (x.checked_div(factor), round_val) { - (Some(x), Some(round_val)) => x.checked_add(round_val), - (Some(x), None) => Some(x), - (None, _) => None, - }; - - match y { - Some(y) if (y <= max && y >= min) && (y != 0 || x / source_factor == 0) => { - *y.low() - } - _ => { - ctx.set_error(row, concat!("Decimal overflow at line : ", line!())); - i128::one() - } - } - }) - .collect() - }; - i128::to_column(values, dest_size) -} - -macro_rules! 
m_decimal_to_decimal { - ($from_size: expr, $dest_size: expr, $buffer: expr, $from_type_name: ty, $dest_type_name: ty, $ctx: expr) => { - // faster path - if $from_size.scale == $dest_size.scale && $from_size.precision <= $dest_size.precision { - if <$from_type_name>::MAX == <$dest_type_name>::MAX { - // 128 -> 128 or 256 -> 256 - <$from_type_name>::to_column_from_buffer($buffer, $dest_size) - } else { - // 128 -> 256 - let buffer = $buffer - .into_iter() - .map(|x| x * <$dest_type_name>::one()) - .collect(); - <$dest_type_name>::to_column(buffer, $dest_size) - } - } else { - let values: Vec<_> = if $from_size.scale > $dest_size.scale { - let scale_diff = ($from_size.scale - $dest_size.scale) as u32; - let factor = <$dest_type_name>::e(scale_diff); - let max = <$dest_type_name>::max_for_precision($dest_size.precision); - let min = <$dest_type_name>::min_for_precision($dest_size.precision); - - let source_factor = <$from_type_name>::e($from_size.scale as u32); - $buffer - .iter() - .enumerate() - .map(|(row, x)| { - let x = x * <$dest_type_name>::one(); - let round_val = get_round_val::<$dest_type_name>(x, scale_diff, $ctx); - let y = match (x.checked_div(factor), round_val) { - (Some(x), Some(round_val)) => x.checked_add(round_val), - (Some(x), None) => Some(x), - (None, _) => None, - }; - match y { - Some(y) - if y <= max && y >= min && (y != 0 || x / source_factor == 0) => - { - y as $dest_type_name - } - _ => { - $ctx.set_error( - row, - concat!("Decimal overflow at line : ", line!()), - ); - <$dest_type_name>::one() - } - } - }) - .collect() - } else { - let factor = <$dest_type_name>::e(($dest_size.scale - $from_size.scale) as u32); - let max = <$dest_type_name>::max_for_precision($dest_size.precision); - let min = <$dest_type_name>::min_for_precision($dest_size.precision); - $buffer - .iter() - .enumerate() - .map(|(row, x)| { - let x = x * <$dest_type_name>::one(); - match x.checked_mul(factor) { - Some(x) if x <= max && x >= min => x as $dest_type_name, - _ 
=> { - $ctx.set_error( - row, - concat!("Decimal overflow at line : ", line!()), - ); - <$dest_type_name>::one() - } - } - }) - .collect() - }; - <$dest_type_name>::to_column(values, $dest_size) - } - }; -} - -fn decimal_to_decimal( - arg: &ValueRef, - ctx: &mut EvalContext, - from_type: DecimalDataType, - dest_type: DecimalDataType, -) -> Value { - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &DataType::Decimal(from_type)); - builder.build() - } - }; - - let result: DecimalColumn = match (from_type, dest_type) { - (DecimalDataType::Decimal128(_), DecimalDataType::Decimal128(dest_size)) => { - let (buffer, from_size) = i128::try_downcast_column(&column).unwrap(); - m_decimal_to_decimal! {from_size, dest_size, buffer, i128, i128, ctx} - } - (DecimalDataType::Decimal128(_), DecimalDataType::Decimal256(dest_size)) => { - let (buffer, from_size) = i128::try_downcast_column(&column).unwrap(); - m_decimal_to_decimal! {from_size, dest_size, buffer, i128, i256, ctx} - } - (DecimalDataType::Decimal256(_), DecimalDataType::Decimal256(dest_size)) => { - let (buffer, from_size) = i256::try_downcast_column(&column).unwrap(); - m_decimal_to_decimal! 
{from_size, dest_size, buffer, i256, i256, ctx} - } - (DecimalDataType::Decimal256(_), DecimalDataType::Decimal128(dest_size)) => { - let (buffer, from_size) = i256::try_downcast_column(&column).unwrap(); - decimal_256_to_128(buffer, from_size, dest_size, ctx) - } - }; - - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(Scalar::Decimal(scalar)) - } else { - Value::Column(Column::Decimal(result)) - } -} - -fn decimal_to_float64( - arg: &ValueRef, - from_type: DataType, - _ctx: &mut EvalContext, -) -> Value { - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &from_type); - builder.build() - } - }; - - let from_type = from_type.as_decimal().unwrap(); - - let result = match from_type { - DecimalDataType::Decimal128(_) => { - let (buffer, from_size) = i128::try_downcast_column(&column).unwrap(); - - let div = 10_f64.powi(from_size.scale as i32); - - let values: Buffer = buffer.iter().map(|x| (*x as f64 / div).into()).collect(); - Float64Type::upcast_column(values) - } - - DecimalDataType::Decimal256(_) => { - let (buffer, from_size) = i256::try_downcast_column(&column).unwrap(); - - let div = 10_f64.powi(from_size.scale as i32); - - let values: Buffer = buffer - .iter() - .map(|x| (f64::from(*x) / div).into()) - .collect(); - Float64Type::upcast_column(values) - } - }; - - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(scalar.to_owned()) - } else { - Value::Column(result) - } -} - -fn decimal_to_float32( - arg: &ValueRef, - from_type: DataType, - _ctx: &mut EvalContext, -) -> Value { - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &from_type); - builder.build() - } - }; - - let from_type = from_type.as_decimal().unwrap(); - - let result = 
match from_type { - DecimalDataType::Decimal128(_) => { - let (buffer, from_size) = i128::try_downcast_column(&column).unwrap(); - - let div = 10_f32.powi(from_size.scale as i32); - - let values: Buffer = buffer.iter().map(|x| (*x as f32 / div).into()).collect(); - Float32Type::upcast_column(values) - } - - DecimalDataType::Decimal256(_) => { - let (buffer, from_size) = i256::try_downcast_column(&column).unwrap(); - - let div = 10_f32.powi(from_size.scale as i32); - - let values: Buffer = buffer - .iter() - .map(|x| (f32::from(*x) / div).into()) - .collect(); - Float32Type::upcast_column(values) - } - }; - - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(scalar.to_owned()) - } else { - Value::Column(result) - } -} - -fn decimal_to_int( - arg: &ValueRef, - from_type: DataType, - ctx: &mut EvalContext, -) -> Value { - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &from_type); - builder.build() - } - }; - - let from_type = from_type.as_decimal().unwrap(); - - let result = match from_type { - DecimalDataType::Decimal128(_) => { - let (buffer, from_size) = i128::try_downcast_column(&column).unwrap(); - - let mut values = Vec::with_capacity(ctx.num_rows); - - for (i, x) in buffer.iter().enumerate() { - match x.to_int(from_size.scale, ctx.func_ctx.rounding_mode) { - Some(x) => values.push(x), - None => { - ctx.set_error(i, "decimal cast to int overflow"); - values.push(T::default()) - } - } - } - - NumberType::::upcast_column(Buffer::from(values)) - } - - DecimalDataType::Decimal256(_) => { - let (buffer, from_size) = i256::try_downcast_column(&column).unwrap(); - let mut values = Vec::with_capacity(ctx.num_rows); - - for (i, x) in buffer.iter().enumerate() { - match x.to_int(from_size.scale, ctx.func_ctx.rounding_mode) { - Some(x) => values.push(x), - None => { - ctx.set_error(i, "decimal cast to int 
overflow"); - values.push(T::default()) - } - } - } - NumberType::::upcast_column(Buffer::from(values)) - } - }; - - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(scalar.to_owned()) - } else { - Value::Column(result) - } -} - -pub fn register_decimal_math(registry: &mut FunctionRegistry) { - let factory = |params: &[usize], args_type: &[DataType], round_mode: RoundMode| { - if args_type.is_empty() { - return None; - } - - let from_type = args_type[0].remove_nullable(); - if !matches!(from_type, DataType::Decimal(_)) { - return None; - } - - let from_decimal_type = from_type.as_decimal().unwrap(); - - let scale = if params.is_empty() { - 0 - } else { - params[0] as i64 - 76 - }; - - let decimal_size = DecimalSize { - precision: from_decimal_type.precision(), - scale: scale.clamp(0, from_decimal_type.scale() as i64) as u8, - }; - - let dest_decimal_type = DecimalDataType::from_size(decimal_size).ok()?; - let name = format!("{:?}", round_mode).to_lowercase(); - - let mut sig_args_type = args_type.to_owned(); - sig_args_type[0] = from_type.clone(); - let f = Function { - signature: FunctionSignature { - name, - args_type: sig_args_type, - return_type: DataType::Decimal(dest_decimal_type), - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(move |_ctx, _d| FunctionDomain::Full), - eval: Box::new(move |args, _ctx| { - decimal_round_truncate( - &args[0], - from_type.clone(), - dest_decimal_type, - scale, - round_mode, - ) - }), - }, - }; - - if args_type[0].is_nullable() { - Some(f.passthrough_nullable()) - } else { - Some(f) - } - }; - - for m in [ - RoundMode::Round, - RoundMode::Truncate, - RoundMode::Ceil, - RoundMode::Floor, - ] { - let name = format!("{:?}", m).to_lowercase(); - registry.register_function_factory(&name, move |params, args_type| { - Some(Arc::new(factory(params, args_type, m)?)) - }); - } -} - -#[derive(Copy, Clone, Debug)] -enum RoundMode { - Round, - Truncate, - Floor, - Ceil, -} - -fn 
decimal_round_positive(values: &[T], source_scale: i64, target_scale: i64) -> Vec -where T: Decimal + From + DivAssign + Div + Add + Sub { - let power_of_ten = T::e((source_scale - target_scale) as u32); - let addition = power_of_ten / T::from(2); - - values - .iter() - .map(|input| { - let input = if input < &T::zero() { - *input - addition - } else { - *input + addition - }; - input / power_of_ten - }) - .collect() -} - -fn decimal_round_negative(values: &[T], source_scale: i64, target_scale: i64) -> Vec -where T: Decimal - + From - + DivAssign - + Div - + Add - + Sub - + Mul { - let divide_power_of_ten = T::e((source_scale - target_scale) as u32); - let addition = divide_power_of_ten / T::from(2); - let multiply_power_of_ten = T::e((-target_scale) as u32); - - values - .iter() - .map(|input| { - let input = if input < &T::zero() { - *input - addition - } else { - *input + addition - }; - input / divide_power_of_ten * multiply_power_of_ten - }) - .collect() -} - -// if round mode is ceil, truncate should add one value -fn decimal_truncate_positive(values: &[T], source_scale: i64, target_scale: i64) -> Vec -where T: Decimal + From + DivAssign + Div + Add + Sub { - let power_of_ten = T::e((source_scale - target_scale) as u32); - - values.iter().map(|input| *input / power_of_ten).collect() -} - -fn decimal_truncate_negative(values: &[T], source_scale: i64, target_scale: i64) -> Vec -where T: Decimal - + From - + DivAssign - + Div - + Add - + Sub - + Mul { - let divide_power_of_ten = T::e((source_scale - target_scale) as u32); - let multiply_power_of_ten = T::e((-target_scale) as u32); - - values - .iter() - .map(|input| *input / divide_power_of_ten * multiply_power_of_ten) - .collect() -} - -fn decimal_floor(values: &[T], source_scale: i64) -> Vec -where T: Decimal - + From - + DivAssign - + Div - + Add - + Sub - + Mul { - let power_of_ten = T::e(source_scale as u32); - - values - .iter() - .map(|input| { - if input < &T::zero() { - // below 0 we ceil the number 
(e.g. -10.5 -> -11) - ((*input + T::one()) / power_of_ten) - T::one() - } else { - *input / power_of_ten - } - }) - .collect() -} - -fn decimal_ceil(values: &[T], source_scale: i64) -> Vec -where T: Decimal - + From - + DivAssign - + Div - + Add - + Sub - + Mul { - let power_of_ten = T::e(source_scale as u32); - - values - .iter() - .map(|input| { - if input <= &T::zero() { - *input / power_of_ten - } else { - ((*input - T::one()) / power_of_ten) + T::one() - } - }) - .collect() -} - -fn decimal_round_truncate( - arg: &ValueRef, - from_type: DataType, - dest_type: DecimalDataType, - target_scale: i64, - mode: RoundMode, -) -> Value { - let from_decimal_type = from_type.as_decimal().unwrap(); - let source_scale = from_decimal_type.scale() as i64; - - if source_scale < target_scale { - return arg.clone().to_owned(); - } - - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &from_type); - builder.build() - } - }; - - let none_negative = target_scale >= 0; - - let result = match from_decimal_type { - DecimalDataType::Decimal128(_) => { - let (buffer, _) = i128::try_downcast_column(&column).unwrap(); - - let result = match (none_negative, mode) { - (true, RoundMode::Round) => { - decimal_round_positive::<_>(&buffer, source_scale, target_scale) - } - (true, RoundMode::Truncate) => { - decimal_truncate_positive::<_>(&buffer, source_scale, target_scale) - } - (false, RoundMode::Round) => { - decimal_round_negative::<_>(&buffer, source_scale, target_scale) - } - (false, RoundMode::Truncate) => { - decimal_truncate_negative::<_>(&buffer, source_scale, target_scale) - } - (_, RoundMode::Floor) => decimal_floor::<_>(&buffer, source_scale), - (_, RoundMode::Ceil) => decimal_ceil::<_>(&buffer, source_scale), - }; - i128::to_column(result, dest_type.size()) - } - - DecimalDataType::Decimal256(_) => { - let (buffer, _) = 
i256::try_downcast_column(&column).unwrap(); - let result = match (none_negative, mode) { - (true, RoundMode::Round) => { - decimal_round_positive::<_>(&buffer, source_scale, target_scale) - } - (true, RoundMode::Truncate) => { - decimal_truncate_positive::<_>(&buffer, source_scale, target_scale) - } - (false, RoundMode::Round) => { - decimal_round_negative::<_>(&buffer, source_scale, target_scale) - } - (false, RoundMode::Truncate) => { - decimal_truncate_negative::<_>(&buffer, source_scale, target_scale) - } - (_, RoundMode::Floor) => decimal_floor::<_>(&buffer, source_scale), - (_, RoundMode::Ceil) => decimal_ceil::<_>(&buffer, source_scale), - }; - i256::to_column(result, dest_type.size()) - } - }; - - let result = Column::Decimal(result); - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(scalar.to_owned()) - } else { - Value::Column(result) - } -} diff --git a/src/query/functions/src/scalars/decimal/arithmetic.rs b/src/query/functions/src/scalars/decimal/arithmetic.rs new file mode 100644 index 000000000000..dec2bfecb87f --- /dev/null +++ b/src/query/functions/src/scalars/decimal/arithmetic.rs @@ -0,0 +1,363 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::ops::*; +use std::sync::Arc; + +use databend_common_expression::types::decimal::*; +use databend_common_expression::types::*; +use databend_common_expression::vectorize_2_arg; +use databend_common_expression::vectorize_with_builder_2_arg; +use databend_common_expression::Domain; +use databend_common_expression::EvalContext; +use databend_common_expression::Function; +use databend_common_expression::FunctionDomain; +use databend_common_expression::FunctionEval; +use databend_common_expression::FunctionRegistry; +use databend_common_expression::FunctionSignature; +use ethnum::i256; + +#[derive(Copy, Clone, Debug)] +enum ArithmeticOp { + Plus, + Minus, + Multiply, + Divide, +} + +macro_rules! op_decimal { + ($a: expr, $b: expr, $ctx: expr, $left: expr, $right: expr, $result_type: expr, $op: ident, $arithmetic_op: expr) => { + match $left { + DecimalDataType::Decimal128(_) => { + binary_decimal!( + $a, + $b, + $ctx, + $left, + $right, + $op, + $result_type.size(), + i128, + $arithmetic_op + ) + } + DecimalDataType::Decimal256(_) => { + binary_decimal!( + $a, + $b, + $ctx, + $left, + $right, + $op, + $result_type.size(), + i256, + $arithmetic_op + ) + } + } + }; +} + +macro_rules! 
binary_decimal { + ($a: expr, $b: expr, $ctx: expr, $left: expr, $right: expr, $op: ident, $size: expr, $type_name: ty, $arithmetic_op: expr) => {{ + type T = $type_name; + + let overflow = $size.precision == T::default_decimal_size().precision; + + let a = $a.try_downcast().unwrap(); + let b = $b.try_downcast().unwrap(); + + let zero = T::zero(); + let one = T::one(); + + let result = if matches!($arithmetic_op, ArithmeticOp::Divide) { + let scale_a = $left.scale(); + let scale_b = $right.scale(); + + let (scale_mul, scale_div) = if scale_b + $size.scale > scale_a { + (scale_b + $size.scale - scale_a, 0) + } else { + (0, scale_b + $size.scale - scale_a) + }; + + let multiplier = T::e(scale_mul as u32); + let div = T::e(scale_div as u32); + + let func = |a: T, b: T, result: &mut Vec, ctx: &mut EvalContext| { + if std::intrinsics::unlikely(b == zero) { + ctx.set_error(result.len(), "divided by zero"); + result.push(one); + } else { + result.push((a * multiplier).div(b) / div); + } + }; + + vectorize_with_builder_2_arg::, DecimalType, DecimalType>(func)( + a, b, $ctx, + ) + } else { + if overflow { + let min_for_precision = T::min_for_precision($size.precision); + let max_for_precision = T::max_for_precision($size.precision); + + let func = |a: T, b: T, result: &mut Vec, ctx: &mut EvalContext| { + let t = a.$op(b); + if t < min_for_precision || t > max_for_precision { + ctx.set_error( + result.len(), + concat!("Decimal overflow at line : ", line!()), + ); + result.push(one); + } else { + result.push(t); + } + }; + + vectorize_with_builder_2_arg::, DecimalType, DecimalType>(func)( + a, b, $ctx + ) + } else { + let func = |l: T, r: T, _ctx: &mut EvalContext| l.$op(r); + + vectorize_2_arg::, DecimalType, DecimalType>(func)( + a, b, $ctx + ) + } + }; + result.upcast_decimal($size) + }}; +} + +#[inline(always)] +fn domain_plus( + lhs: &SimpleDomain, + rhs: &SimpleDomain, + precision: u8, +) -> Option> { + // For plus, the scale of the two operands must be the same. 
+ let min = T::min_for_precision(precision); + let max = T::max_for_precision(precision); + Some(SimpleDomain { + min: lhs + .min + .checked_add(rhs.min) + .filter(|&m| m >= min && m <= max)?, + max: lhs + .max + .checked_add(rhs.max) + .filter(|&m| m >= min && m <= max)?, + }) +} + +#[inline(always)] +fn domain_minus( + lhs: &SimpleDomain, + rhs: &SimpleDomain, + precision: u8, +) -> Option> { + // For minus, the scale of the two operands must be the same. + let min = T::min_for_precision(precision); + let max = T::max_for_precision(precision); + Some(SimpleDomain { + min: lhs + .min + .checked_sub(rhs.max) + .filter(|&m| m >= min && m <= max)?, + max: lhs + .max + .checked_sub(rhs.min) + .filter(|&m| m >= min && m <= max)?, + }) +} + +#[inline(always)] +fn domain_mul( + lhs: &SimpleDomain, + rhs: &SimpleDomain, + precision: u8, +) -> Option> { + let min = T::min_for_precision(precision); + let max = T::max_for_precision(precision); + + let a = lhs + .min + .checked_mul(rhs.min) + .filter(|&m| m >= min && m <= max)?; + let b = lhs + .min + .checked_mul(rhs.max) + .filter(|&m| m >= min && m <= max)?; + let c = lhs + .max + .checked_mul(rhs.min) + .filter(|&m| m >= min && m <= max)?; + let d = lhs + .max + .checked_mul(rhs.max) + .filter(|&m| m >= min && m <= max)?; + + Some(SimpleDomain { + min: a.min(b).min(c).min(d), + max: a.max(b).max(c).max(d), + }) +} + +#[inline(always)] +fn domain_div( + _lhs: &SimpleDomain, + _rhs: &SimpleDomain, + _precision: u8, +) -> Option> { + // For div, we cannot determine the domain. + None +} + +macro_rules! 
register_decimal_binary_op { + ($registry: expr, $arithmetic_op: expr, $op: ident, $domain_op: ident, $default_domain: expr) => { + let name = format!("{:?}", $arithmetic_op).to_lowercase(); + + $registry.register_function_factory(&name, |_, args_type| { + if args_type.len() != 2 { + return None; + } + + let has_nullable = args_type.iter().any(|x| x.is_nullable_or_null()); + let args_type: Vec = args_type.iter().map(|x| x.remove_nullable()).collect(); + + // number X decimal -> decimal + // decimal X number -> decimal + // decimal X decimal -> decimal + if !args_type[0].is_decimal() && !args_type[1].is_decimal() { + return None; + } + + let decimal_a = + DecimalDataType::from_size(args_type[0].get_decimal_properties()?).unwrap(); + let decimal_b = + DecimalDataType::from_size(args_type[1].get_decimal_properties()?).unwrap(); + + let is_multiply = matches!($arithmetic_op, ArithmeticOp::Multiply); + let is_divide = matches!($arithmetic_op, ArithmeticOp::Divide); + let is_plus_minus = !is_multiply && !is_divide; + + // left, right will unify to same width decimal, both 256 or both 128 + let (left, right, return_decimal_type) = DecimalDataType::binary_result_type( + &decimal_a, + &decimal_b, + is_multiply, + is_divide, + is_plus_minus, + ) + .ok()?; + + let function = Function { + signature: FunctionSignature { + name: format!("{:?}", $arithmetic_op).to_lowercase(), + args_type: vec![ + DataType::Decimal(left.clone()), + DataType::Decimal(right.clone()), + ], + return_type: DataType::Decimal(return_decimal_type), + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(move |_ctx, d| { + let lhs = d[0].as_decimal(); + let rhs = d[1].as_decimal(); + + if lhs.is_none() || rhs.is_none() { + return FunctionDomain::Full; + } + + let lhs = lhs.unwrap(); + let rhs = rhs.unwrap(); + + let size = return_decimal_type.size(); + + { + match (lhs, rhs) { + ( + DecimalDomain::Decimal128(d1, _), + DecimalDomain::Decimal128(d2, _), + ) => $domain_op(&d1, &d2, size.precision) + 
.map(|d| DecimalDomain::Decimal128(d, size)), + ( + DecimalDomain::Decimal256(d1, _), + DecimalDomain::Decimal256(d2, _), + ) => $domain_op(&d1, &d2, size.precision) + .map(|d| DecimalDomain::Decimal256(d, size)), + _ => { + unreachable!("unreachable decimal domain {:?} /{:?}", lhs, rhs) + } + } + } + .map(|d| FunctionDomain::Domain(Domain::Decimal(d))) + .unwrap_or($default_domain) + }), + eval: Box::new(move |args, ctx| { + let res = op_decimal!( + &args[0], + &args[1], + ctx, + left, + right, + return_decimal_type, + $op, + $arithmetic_op + ); + + res + }), + }, + }; + if has_nullable { + Some(Arc::new(function.passthrough_nullable())) + } else { + Some(Arc::new(function)) + } + }); + }; +} + +pub(crate) fn register_decimal_arithmetic(registry: &mut FunctionRegistry) { + // TODO checked overflow by default + register_decimal_binary_op!( + registry, + ArithmeticOp::Plus, + add, + domain_plus, + FunctionDomain::Full + ); + + register_decimal_binary_op!( + registry, + ArithmeticOp::Minus, + sub, + domain_minus, + FunctionDomain::Full + ); + register_decimal_binary_op!( + registry, + ArithmeticOp::Divide, + div, + domain_div, + FunctionDomain::MayThrow + ); + register_decimal_binary_op!( + registry, + ArithmeticOp::Multiply, + mul, + domain_mul, + FunctionDomain::Full + ); +} diff --git a/src/query/functions/src/scalars/decimal/cast.rs b/src/query/functions/src/scalars/decimal/cast.rs new file mode 100644 index 000000000000..58b3d9d83d60 --- /dev/null +++ b/src/query/functions/src/scalars/decimal/cast.rs @@ -0,0 +1,746 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Mul; +use std::sync::Arc; + +use databend_common_expression::serialize::read_decimal_with_size; +use databend_common_expression::types::decimal::*; +use databend_common_expression::types::*; +use databend_common_expression::vectorize_1_arg; +use databend_common_expression::vectorize_with_builder_1_arg; +use databend_common_expression::with_decimal_mapped_type; +use databend_common_expression::with_integer_mapped_type; +use databend_common_expression::with_number_mapped_type; +use databend_common_expression::Domain; +use databend_common_expression::EvalContext; +use databend_common_expression::FromData; +use databend_common_expression::Function; +use databend_common_expression::FunctionContext; +use databend_common_expression::FunctionDomain; +use databend_common_expression::FunctionEval; +use databend_common_expression::FunctionRegistry; +use databend_common_expression::FunctionSignature; +use databend_common_expression::Value; +use databend_common_expression::ValueRef; +use ethnum::i256; +use num_traits::AsPrimitive; +use ordered_float::OrderedFloat; + +// int float to decimal +pub fn register_to_decimal(registry: &mut FunctionRegistry) { + let factory = |params: &[usize], args_type: &[DataType]| { + if args_type.len() != 1 { + return None; + } + if params.len() != 2 { + return None; + } + + let from_type = args_type[0].remove_nullable(); + + if !matches!( + from_type, + DataType::Boolean | DataType::Number(_) | DataType::Decimal(_) | DataType::String + ) { + return None; + } + + let decimal_size = DecimalSize { + precision: 
params[0] as u8, + scale: params[1] as u8, + }; + + let decimal_type = DecimalDataType::from_size(decimal_size).ok()?; + + Some(Function { + signature: FunctionSignature { + name: "to_decimal".to_string(), + args_type: vec![from_type.clone()], + return_type: DataType::Decimal(decimal_type), + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(move |ctx, d| { + convert_to_decimal_domain(ctx, d[0].clone(), decimal_type) + .map(|d| FunctionDomain::Domain(Domain::Decimal(d))) + .unwrap_or(FunctionDomain::MayThrow) + }), + eval: Box::new(move |args, ctx| { + convert_to_decimal(&args[0], ctx, &from_type, decimal_type) + }), + }, + }) + }; + + registry.register_function_factory("to_decimal", move |params, args_type| { + Some(Arc::new(factory(params, args_type)?)) + }); + registry.register_function_factory("to_decimal", move |params, args_type| { + let f = factory(params, args_type)?; + Some(Arc::new(f.passthrough_nullable())) + }); + registry.register_function_factory("try_to_decimal", move |params, args_type| { + let mut f = factory(params, args_type)?; + f.signature.name = "try_to_decimal".to_string(); + Some(Arc::new(f.error_to_null())) + }); + registry.register_function_factory("try_to_decimal", move |params, args_type| { + let mut f = factory(params, args_type)?; + f.signature.name = "try_to_decimal".to_string(); + Some(Arc::new(f.error_to_null().passthrough_nullable())) + }); +} + +pub(crate) fn register_decimal_to_float(registry: &mut FunctionRegistry) { + let data_type = NumberType::::data_type(); + debug_assert!(data_type.is_floating()); + + let is_f32 = matches!(data_type, DataType::Number(NumberDataType::Float32)); + + let factory = |_params: &[usize], args_type: &[DataType], data_type: DataType| { + if args_type.len() != 1 { + return None; + } + + let arg_type = args_type[0].remove_nullable(); + if !arg_type.is_decimal() { + return None; + } + let is_f32 = matches!(data_type, DataType::Number(NumberDataType::Float32)); + let name = if is_f32 { 
"to_float32" } else { "to_float64" }; + let calc_domain = if is_f32 { + Box::new(|_: &_, d: &[Domain]| { + with_decimal_mapped_type!(|DECIMAL_TYPE| match d[0].as_decimal().unwrap() { + DecimalDomain::DECIMAL_TYPE(d, size) => { + FunctionDomain::Domain(Domain::Number(NumberDomain::Float32( + SimpleDomain { + min: OrderedFloat(d.min.to_float32(size.scale)), + max: OrderedFloat(d.max.to_float32(size.scale)), + }, + ))) + } + }) + }) as _ + } else { + Box::new(|_: &_, d: &[Domain]| { + with_decimal_mapped_type!(|DECIMAL_TYPE| match d[0].as_decimal().unwrap() { + DecimalDomain::DECIMAL_TYPE(d, size) => { + FunctionDomain::Domain(Domain::Number(NumberDomain::Float64( + SimpleDomain { + min: OrderedFloat(d.min.to_float64(size.scale)), + max: OrderedFloat(d.max.to_float64(size.scale)), + }, + ))) + } + }) + }) as _ + }; + + let eval = if is_f32 { + let arg_type = arg_type.clone(); + Box::new(move |args: &[ValueRef], tx: &mut EvalContext| { + decimal_to_float::(&args[0], arg_type.clone(), tx) + }) as _ + } else { + let arg_type = arg_type.clone(); + Box::new(move |args: &[ValueRef], tx: &mut EvalContext| { + decimal_to_float::(&args[0], arg_type.clone(), tx) + }) as _ + }; + + let function = Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![arg_type.clone()], + return_type: data_type.clone(), + }, + eval: FunctionEval::Scalar { calc_domain, eval }, + }; + + Some(function) + }; + + let name = if is_f32 { "to_float32" } else { "to_float64" }; + + registry.register_function_factory(name, move |params, args_type| { + let data_type = NumberType::::data_type(); + Some(Arc::new(factory(params, args_type, data_type)?)) + }); + registry.register_function_factory(name, move |params, args_type| { + let data_type = NumberType::::data_type(); + let f = factory(params, args_type, data_type)?; + Some(Arc::new(f.passthrough_nullable())) + }); + registry.register_function_factory(&format!("try_{name}"), move |params, args_type| { + let data_type = 
NumberType::::data_type(); + let mut f = factory(params, args_type, data_type)?; + f.signature.name = format!("try_{name}"); + Some(Arc::new(f.error_to_null())) + }); + registry.register_function_factory(&format!("try_{name}"), move |params, args_type| { + let data_type = NumberType::::data_type(); + let mut f = factory(params, args_type, data_type)?; + f.signature.name = format!("try_{name}"); + Some(Arc::new(f.error_to_null().passthrough_nullable())) + }); +} + +pub(crate) fn register_decimal_to_int(registry: &mut FunctionRegistry) { + if T::data_type().is_float() { + return; + } + let name = format!("to_{}", T::data_type().to_string().to_lowercase()); + let try_name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); + + let factory = |_params: &[usize], args_type: &[DataType]| { + if args_type.len() != 1 { + return None; + } + + let name = format!("to_{}", T::data_type().to_string().to_lowercase()); + let arg_type = args_type[0].remove_nullable(); + if !arg_type.is_decimal() { + return None; + } + + let function = Function { + signature: FunctionSignature { + name, + args_type: vec![arg_type.clone()], + return_type: DataType::Number(T::data_type()), + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(|ctx, d| { + let res_fn = move || match d[0].as_decimal().unwrap() { + DecimalDomain::Decimal128(d, size) => Some(SimpleDomain:: { + min: d.min.to_int(size.scale, ctx.rounding_mode)?, + max: d.max.to_int(size.scale, ctx.rounding_mode)?, + }), + DecimalDomain::Decimal256(d, size) => Some(SimpleDomain:: { + min: d.min.to_int(size.scale, ctx.rounding_mode)?, + max: d.max.to_int(size.scale, ctx.rounding_mode)?, + }), + }; + + res_fn() + .map(|d| FunctionDomain::Domain(Domain::Number(T::upcast_domain(d)))) + .unwrap_or(FunctionDomain::MayThrow) + }), + eval: Box::new(move |args, tx| decimal_to_int::(&args[0], arg_type.clone(), tx)), + }, + }; + + Some(function) + }; + + registry.register_function_factory(&name, move |params, args_type| { + 
Some(Arc::new(factory(params, args_type)?)) + }); + registry.register_function_factory(&name, move |params, args_type| { + let f = factory(params, args_type)?; + Some(Arc::new(f.passthrough_nullable())) + }); + registry.register_function_factory(&try_name, move |params, args_type| { + let mut f = factory(params, args_type)?; + f.signature.name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); + Some(Arc::new(f.error_to_null())) + }); + registry.register_function_factory(&try_name, move |params, args_type| { + let mut f = factory(params, args_type)?; + f.signature.name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); + Some(Arc::new(f.error_to_null().passthrough_nullable())) + }); +} + +fn convert_to_decimal( + arg: &ValueRef, + ctx: &mut EvalContext, + from_type: &DataType, + dest_type: DecimalDataType, +) -> Value { + if let DataType::Decimal(f) = from_type { + return decimal_to_decimal(arg, ctx, *f, dest_type); + } + + with_decimal_mapped_type!(|DECIMAL_TYPE| match dest_type { + DecimalDataType::DECIMAL_TYPE(size) => { + type T = DECIMAL_TYPE; + let result = match from_type { + DataType::Boolean => { + let arg = arg.try_downcast().unwrap(); + vectorize_1_arg::>(|a: bool, _| { + if a { + T::e(size.scale as u32) + } else { + T::zero() + } + })(arg, ctx) + } + + DataType::Number(ty) => { + if ty.is_float() { + match ty { + NumberDataType::Float32 => { + let arg = arg.try_downcast().unwrap(); + float_to_decimal::>(arg, ctx, size) + } + NumberDataType::Float64 => { + let arg = arg.try_downcast().unwrap(); + float_to_decimal::>(arg, ctx, size) + } + _ => unreachable!(), + } + } else { + with_integer_mapped_type!(|NUM_TYPE| match ty { + NumberDataType::NUM_TYPE => { + let arg = arg.try_downcast().unwrap(); + integer_to_decimal::>(arg, ctx, size) + } + _ => unreachable!(), + }) + } + } + DataType::String => { + let arg = arg.try_downcast().unwrap(); + string_to_decimal::(arg, ctx, size) + } + _ => unreachable!("to_decimal not support 
this DataType"), + }; + result.upcast_decimal(size) + } + }) +} + +fn convert_to_decimal_domain( + func_ctx: &FunctionContext, + domain: Domain, + dest_type: DecimalDataType, +) -> Option { + // Convert the domain to a Column. + // The first row is the min value, the second row is the max value. + let column = match domain { + Domain::Number(number_domain) => { + with_number_mapped_type!(|NUM_TYPE| match number_domain { + NumberDomain::NUM_TYPE(d) => { + let min = d.min; + let max = d.max; + NumberType::::from_data(vec![min, max]) + } + }) + } + Domain::Boolean(d) => { + let min = !d.has_false; + let max = d.has_true; + BooleanType::from_data(vec![min, max]) + } + Domain::Decimal(d) => { + with_decimal_mapped_type!(|DECIMAL| match d { + DecimalDomain::DECIMAL(d, size) => { + let min = d.min; + let max = d.max; + DecimalType::from_data_with_size(vec![min, max], size) + } + }) + } + Domain::String(d) => { + let min = d.min; + let max = d.max?; + StringType::from_data(vec![min, max]) + } + _ => { + return None; + } + }; + + let from_type = column.data_type(); + let value = Value::::Column(column); + let mut ctx = EvalContext { + generics: &[], + num_rows: 2, + func_ctx, + validity: None, + errors: None, + }; + let dest_size = dest_type.size(); + let res = convert_to_decimal(&value.as_ref(), &mut ctx, &from_type, dest_type); + + if ctx.errors.is_some() { + return None; + } + let decimal_col = res.as_column()?.as_decimal()?; + assert_eq!(decimal_col.len(), 2); + + Some(match decimal_col { + DecimalColumn::Decimal128(buf, size) => { + assert_eq!(&dest_size, size); + let (min, max) = unsafe { (*buf.get_unchecked(0), *buf.get_unchecked(1)) }; + DecimalDomain::Decimal128(SimpleDomain { min, max }, *size) + } + DecimalColumn::Decimal256(buf, size) => { + assert_eq!(&dest_size, size); + let (min, max) = unsafe { (*buf.get_unchecked(0), *buf.get_unchecked(1)) }; + DecimalDomain::Decimal256(SimpleDomain { min, max }, *size) + } + }) +} + +fn string_to_decimal( + from: ValueRef, 
+ ctx: &mut EvalContext, + size: DecimalSize, +) -> Value> +where + T: Decimal + Mul, +{ + let f = |x: &[u8], builder: &mut Vec, ctx: &mut EvalContext| { + let value = match read_decimal_with_size::(x, size, true, ctx.func_ctx.rounding_mode) { + Ok((d, _)) => d, + Err(e) => { + ctx.set_error(builder.len(), e.message()); + T::zero() + } + }; + + builder.push(value); + }; + + vectorize_with_builder_1_arg::>(f)(from, ctx) +} + +fn integer_to_decimal( + from: ValueRef, + ctx: &mut EvalContext, + size: DecimalSize, +) -> Value> +where + T: Decimal + Mul, + for<'a> S::ScalarRef<'a>: Number + AsPrimitive, +{ + let multiplier = T::e(size.scale as u32); + + let min_for_precision = T::min_for_precision(size.precision); + let max_for_precision = T::max_for_precision(size.precision); + + let f = |x: S::ScalarRef<'_>, builder: &mut Vec, ctx: &mut EvalContext| { + let x = T::from_i128(x.as_()) * multiplier; + + if x > max_for_precision || x < min_for_precision { + ctx.set_error( + builder.len(), + concat!("Decimal overflow at line : ", line!()), + ); + builder.push(T::one()); + } else { + builder.push(x); + } + }; + + vectorize_with_builder_1_arg(f)(from, ctx) +} + +fn float_to_decimal( + from: ValueRef, + ctx: &mut EvalContext, + size: DecimalSize, +) -> Value> +where + for<'a> S::ScalarRef<'a>: Number + AsPrimitive, +{ + let multiplier: f64 = (10_f64).powi(size.scale as i32).as_(); + + let min_for_precision = T::min_for_precision(size.precision); + let max_for_precision = T::max_for_precision(size.precision); + + let f = |x: S::ScalarRef<'_>, builder: &mut Vec, ctx: &mut EvalContext| { + let mut x = x.as_() * multiplier; + if ctx.func_ctx.rounding_mode { + x = x.round(); + } + let x = T::from_float(x); + if x > max_for_precision || x < min_for_precision { + ctx.set_error( + builder.len(), + concat!("Decimal overflow at line : ", line!()), + ); + builder.push(T::one()); + } else { + builder.push(x); + } + }; + + vectorize_with_builder_1_arg(f)(from, ctx) +} + +#[inline] +fn 
get_round_val(x: T, scale: u32, ctx: &mut EvalContext) -> Option { + let mut round_val = None; + if ctx.func_ctx.rounding_mode && scale > 0 { + // Checking whether numbers need to be added or subtracted to calculate rounding + if let Some(r) = x.checked_rem(T::e(scale)) { + if let Some(m) = r.checked_div(T::e(scale - 1)) { + if m >= T::from_i128(5i64) { + round_val = Some(T::one()); + } else if m <= T::from_i128(-5i64) { + round_val = Some(T::minus_one()); + } + } + } + } + round_val +} + +fn decimal_256_to_128( + buffer: &ValueRef, + from_size: DecimalSize, + dest_size: DecimalSize, + ctx: &mut EvalContext, +) -> Value> { + let max = i128::max_for_precision(dest_size.precision); + let min = i128::min_for_precision(dest_size.precision); + + let buffer = buffer.try_downcast::>().unwrap(); + if dest_size.scale >= from_size.scale { + let factor = i256::e((dest_size.scale - from_size.scale) as u32); + + vectorize_with_builder_1_arg::, DecimalType>( + |x: i256, builder: &mut Vec, ctx: &mut EvalContext| match x.checked_mul(factor) { + Some(x) if x <= max && x >= min => builder.push(*x.low()), + _ => { + ctx.set_error( + builder.len(), + concat!("Decimal overflow at line : ", line!()), + ); + builder.push(i128::one()); + } + }, + )(buffer, ctx) + } else { + let scale_diff = (from_size.scale - dest_size.scale) as u32; + let factor = i256::e(scale_diff); + let source_factor = i256::e(from_size.scale as u32); + + vectorize_with_builder_1_arg::, DecimalType>( + |x: i256, builder: &mut Vec, ctx: &mut EvalContext| { + let round_val = get_round_val::(x, scale_diff, ctx); + let y = match (x.checked_div(factor), round_val) { + (Some(x), Some(round_val)) => x.checked_add(round_val), + (Some(x), None) => Some(x), + (None, _) => None, + }; + + match y { + Some(y) if (y <= max && y >= min) && (y != 0 || x / source_factor == 0) => { + builder.push(*y.low()); + } + _ => { + ctx.set_error( + builder.len(), + concat!("Decimal overflow at line : ", line!()), + ); + + 
builder.push(i128::one()); + } + } + }, + )(buffer, ctx) + } +} + +macro_rules! m_decimal_to_decimal { + ($from_size: expr, $dest_size: expr, $value: expr, $from_type_name: ty, $dest_type_name: ty, $ctx: expr) => { + type F = $from_type_name; + type T = $dest_type_name; + + let buffer: ValueRef> = $value.try_downcast().unwrap(); + // faster path + let result: Value> = if $from_size.scale == $dest_size.scale + && $from_size.precision <= $dest_size.precision + { + if F::MAX == T::MAX { + // 128 -> 128 or 256 -> 256 + return buffer.clone().to_owned().upcast_decimal($dest_size); + } else { + // 128 -> 256 + vectorize_1_arg::, DecimalType>(|x: F, _: &mut EvalContext| { + T::from(x) + })(buffer, $ctx) + } + } else if $from_size.scale > $dest_size.scale { + let scale_diff = ($from_size.scale - $dest_size.scale) as u32; + let factor = T::e(scale_diff); + let max = T::max_for_precision($dest_size.precision); + let min = T::min_for_precision($dest_size.precision); + + let source_factor = F::e($from_size.scale as u32); + + vectorize_with_builder_1_arg::, DecimalType>( + |x: F, builder: &mut Vec, ctx: &mut EvalContext| { + let x = T::from(x); + let round_val = get_round_val::(x, scale_diff, ctx); + let y = match (x.checked_div(factor), round_val) { + (Some(x), Some(round_val)) => x.checked_add(round_val), + (Some(x), None) => Some(x), + (None, _) => None, + }; + + match y { + Some(y) if y <= max && y >= min && (y != 0 || x / source_factor == 0) => { + builder.push(y as T); + } + _ => { + ctx.set_error( + builder.len(), + concat!("Decimal overflow at line : ", line!()), + ); + builder.push(T::one()); + } + } + }, + )(buffer, $ctx) + } else { + let factor = T::e(($dest_size.scale - $from_size.scale) as u32); + let min = T::min_for_precision($dest_size.precision); + let max = T::max_for_precision($dest_size.precision); + + vectorize_with_builder_1_arg::, DecimalType>( + |x: F, builder: &mut Vec, ctx: &mut EvalContext| { + let x = T::from(x); + match x.checked_mul(factor) { + 
Some(x) if x <= max && x >= min => { + builder.push(x as T); + } + _ => { + ctx.set_error( + builder.len(), + concat!("Decimal overflow at line : ", line!()), + ); + builder.push(T::one()); + } + } + }, + )(buffer, $ctx) + }; + + result.upcast_decimal($dest_size) + }; +} + +fn decimal_to_decimal( + arg: &ValueRef, + ctx: &mut EvalContext, + from_type: DecimalDataType, + dest_type: DecimalDataType, +) -> Value { + let from_size = from_type.size(); + let dest_size = dest_type.size(); + match (from_type, dest_type) { + (DecimalDataType::Decimal128(_), DecimalDataType::Decimal128(_)) => { + m_decimal_to_decimal! {from_size, dest_size, arg, i128, i128, ctx} + } + (DecimalDataType::Decimal128(_), DecimalDataType::Decimal256(_)) => { + m_decimal_to_decimal! {from_size, dest_size, arg, i128, i256, ctx} + } + (DecimalDataType::Decimal256(_), DecimalDataType::Decimal128(_)) => { + let value = decimal_256_to_128(arg, from_size, dest_size, ctx); + value.upcast_decimal(dest_size) + } + (DecimalDataType::Decimal256(_), DecimalDataType::Decimal256(_)) => { + m_decimal_to_decimal! 
{from_size, dest_size, arg, i256, i256, ctx} + } + } +} + +trait DecimalConvert { + fn convert(t: T, _scale: i32) -> U; +} + +impl DecimalConvert for F32 { + fn convert(t: i128, scale: i32) -> F32 { + let div = 10f32.powi(scale); + ((t as f32) / div).into() + } +} + +impl DecimalConvert for F64 { + fn convert(t: i128, scale: i32) -> F64 { + let div = 10f64.powi(scale); + ((t as f64) / div).into() + } +} + +impl DecimalConvert for F32 { + fn convert(t: i256, scale: i32) -> F32 { + let div = 10f32.powi(scale); + (f32::from(t) / div).into() + } +} + +impl DecimalConvert for F64 { + fn convert(t: i256, scale: i32) -> F64 { + let div = 10f64.powi(scale); + (f64::from(t) / div).into() + } +} + +fn decimal_to_float( + arg: &ValueRef, + from_type: DataType, + ctx: &mut EvalContext, +) -> Value +where + T: Number, + T: DecimalConvert, + T: DecimalConvert, +{ + let from_type = from_type.as_decimal().unwrap(); + + let result = with_decimal_mapped_type!(|DECIMAL_TYPE| match from_type { + DecimalDataType::DECIMAL_TYPE(from_size) => { + let value = arg.try_downcast().unwrap(); + let scale = from_size.scale as i32; + vectorize_1_arg::, NumberType>( + |x, _ctx: &mut EvalContext| T::convert(x, scale), + )(value, ctx) + } + }); + + result.upcast() +} + +fn decimal_to_int( + arg: &ValueRef, + from_type: DataType, + ctx: &mut EvalContext, +) -> Value { + let from_type = from_type.as_decimal().unwrap(); + + let result = with_decimal_mapped_type!(|DECIMAL_TYPE| match from_type { + DecimalDataType::DECIMAL_TYPE(from_size) => { + let value = arg.try_downcast().unwrap(); + vectorize_with_builder_1_arg::, NumberType>( + |x, builder: &mut Vec, ctx: &mut EvalContext| match x + .to_int(from_size.scale, ctx.func_ctx.rounding_mode) + { + Some(x) => builder.push(x), + None => { + ctx.set_error(builder.len(), "decimal cast to int overflow"); + builder.push(T::default()) + } + }, + )(value, ctx) + } + }); + + result.upcast() +} diff --git a/src/query/functions/src/scalars/decimal/comparison.rs 
b/src/query/functions/src/scalars/decimal/comparison.rs new file mode 100644 index 000000000000..3c49b6630dab --- /dev/null +++ b/src/query/functions/src/scalars/decimal/comparison.rs @@ -0,0 +1,130 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Ord; +use std::ops::*; +use std::sync::Arc; + +use databend_common_expression::type_check::common_super_type; +use databend_common_expression::types::decimal::*; +use databend_common_expression::types::*; +use databend_common_expression::vectorize_2_arg; +use databend_common_expression::Domain; +use databend_common_expression::EvalContext; +use databend_common_expression::Function; +use databend_common_expression::FunctionEval; +use databend_common_expression::FunctionRegistry; +use databend_common_expression::FunctionSignature; +use databend_common_expression::SimpleDomainCmp; +use databend_common_expression::Value; +use databend_common_expression::ValueRef; +use ethnum::i256; + +macro_rules! 
register_decimal_compare_op { + ($registry: expr, $name: expr, $op: ident, $domain_op: tt) => { + $registry.register_function_factory($name, |_, args_type| { + if args_type.len() != 2 { + return None; + } + + let has_nullable = args_type.iter().any(|x| x.is_nullable_or_null()); + let args_type: Vec = args_type.iter().map(|x| x.remove_nullable()).collect(); + + // Only applies when at least one of the arguments is a decimal type + if !args_type[0].is_decimal() && !args_type[1].is_decimal() { + return None; + } + + let common_type = common_super_type(args_type[0].clone(), args_type[1].clone(), &[])?; + + if !common_type.is_decimal() { + return None; + } + + // Comparison between different decimal types must use the same signature types + let function = Function { + signature: FunctionSignature { + name: $name.to_string(), + args_type: vec![common_type.clone(), common_type.clone()], + return_type: DataType::Boolean, + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(|_, d| { + let new_domain = match (&d[0], &d[1]) { + ( + Domain::Decimal(DecimalDomain::Decimal128(d1, _)), + Domain::Decimal(DecimalDomain::Decimal128(d2, _)), + ) => d1.$domain_op(d2), + ( + Domain::Decimal(DecimalDomain::Decimal256(d1, _)), + Domain::Decimal(DecimalDomain::Decimal256(d2, _)), + ) => d1.$domain_op(d2), + _ => unreachable!("Expect two same decimal domains, got {:?}", d), + }; + new_domain.map(|d| Domain::Boolean(d)) + }), + eval: Box::new(move |args, ctx| { + op_decimal! { &args[0], &args[1], common_type, $op, ctx} + }), + }, + }; + if has_nullable { + Some(Arc::new(function.passthrough_nullable())) + } else { + Some(Arc::new(function)) + } + }); + }; +} + +macro_rules!
op_decimal { + ($a: expr, $b: expr, $common_type: expr, $op: ident, $ctx: expr) => { + match $common_type { + DataType::Decimal(DecimalDataType::Decimal128(_)) => { + let f = |a: i128, b: i128, _: &mut EvalContext| -> bool { a.cmp(&b).$op() }; + compare_decimal($a, $b, f, $ctx) + } + DataType::Decimal(DecimalDataType::Decimal256(_)) => { + let f = |a: i256, b: i256, _: &mut EvalContext| -> bool { a.cmp(&b).$op() }; + compare_decimal($a, $b, f, $ctx) + } + _ => unreachable!(), + } + }; +} + +fn compare_decimal( + a: &ValueRef, + b: &ValueRef, + f: F, + ctx: &mut EvalContext, +) -> Value +where + T: Decimal, + F: Fn(T, T, &mut EvalContext) -> bool + Copy + Send + Sync, +{ + let a = a.try_downcast().unwrap(); + let b = b.try_downcast().unwrap(); + let value = vectorize_2_arg::, DecimalType, BooleanType>(f)(a, b, ctx); + value.upcast() +} + +pub fn register_decimal_compare_op(registry: &mut FunctionRegistry) { + register_decimal_compare_op!(registry, "lt", is_lt, domain_lt); + register_decimal_compare_op!(registry, "eq", is_eq, domain_eq); + register_decimal_compare_op!(registry, "gt", is_gt, domain_gt); + register_decimal_compare_op!(registry, "lte", is_le, domain_lte); + register_decimal_compare_op!(registry, "gte", is_ge, domain_gte); + register_decimal_compare_op!(registry, "noteq", is_ne, domain_noteq); +} diff --git a/src/query/functions/src/scalars/decimal/math.rs b/src/query/functions/src/scalars/decimal/math.rs new file mode 100644 index 000000000000..7ed5976b858b --- /dev/null +++ b/src/query/functions/src/scalars/decimal/math.rs @@ -0,0 +1,289 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Ord; +use std::ops::*; +use std::sync::Arc; + +use databend_common_expression::types::decimal::*; +use databend_common_expression::types::*; +use databend_common_expression::vectorize_1_arg; +use databend_common_expression::with_decimal_mapped_type; +use databend_common_expression::EvalContext; +use databend_common_expression::Function; +use databend_common_expression::FunctionDomain; +use databend_common_expression::FunctionEval; +use databend_common_expression::FunctionRegistry; +use databend_common_expression::FunctionSignature; +use databend_common_expression::Value; +use databend_common_expression::ValueRef; +use ethnum::i256; + +pub fn register_decimal_math(registry: &mut FunctionRegistry) { + let factory = |params: &[usize], args_type: &[DataType], round_mode: RoundMode| { + if args_type.is_empty() { + return None; + } + + let from_type = args_type[0].remove_nullable(); + if !matches!(from_type, DataType::Decimal(_)) { + return None; + } + + let from_decimal_type = from_type.as_decimal().unwrap(); + + let scale = if params.is_empty() { + 0 + } else { + params[0] as i64 - 76 + }; + + let decimal_size = DecimalSize { + precision: from_decimal_type.precision(), + scale: scale.clamp(0, from_decimal_type.scale() as i64) as u8, + }; + + let dest_decimal_type = DecimalDataType::from_size(decimal_size).ok()?; + let name = format!("{:?}", round_mode).to_lowercase(); + + let mut sig_args_type = args_type.to_owned(); + sig_args_type[0] = from_type.clone(); + let f = Function { + signature: FunctionSignature { + name, + args_type: 
sig_args_type, + return_type: DataType::Decimal(dest_decimal_type), + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(move |_ctx, _d| FunctionDomain::Full), + eval: Box::new(move |args, ctx| { + decimal_math( + &args[0], + ctx, + from_type.clone(), + dest_decimal_type, + scale, + round_mode, + ) + }), + }, + }; + + if args_type[0].is_nullable() { + Some(f.passthrough_nullable()) + } else { + Some(f) + } + }; + + for m in [ + RoundMode::Round, + RoundMode::Truncate, + RoundMode::Ceil, + RoundMode::Floor, + ] { + let name = format!("{:?}", m).to_lowercase(); + registry.register_function_factory(&name, move |params, args_type| { + Some(Arc::new(factory(params, args_type, m)?)) + }); + } +} + +#[derive(Copy, Clone, Debug)] +enum RoundMode { + Round, + Truncate, + Floor, + Ceil, +} + +fn decimal_round_positive( + value: ValueRef>, + source_scale: i64, + target_scale: i64, + ctx: &mut EvalContext, +) -> Value> +where + T: Decimal + From + DivAssign + Div + Add + Sub, +{ + let power_of_ten = T::e((source_scale - target_scale) as u32); + let addition = power_of_ten / T::from(2); + vectorize_1_arg::, DecimalType>(|a, _| { + if a < T::zero() { + (a - addition) / power_of_ten + } else { + (a + addition) / power_of_ten + } + })(value, ctx) +} + +fn decimal_round_negative( + value: ValueRef>, + source_scale: i64, + target_scale: i64, + ctx: &mut EvalContext, +) -> Value> +where + T: Decimal + + From + + DivAssign + + Div + + Add + + Sub + + Mul, +{ + let divide_power_of_ten = T::e((source_scale - target_scale) as u32); + let addition = divide_power_of_ten / T::from(2); + let multiply_power_of_ten = T::e((-target_scale) as u32); + + vectorize_1_arg::, DecimalType>(|a, _| { + let a = if a < T::zero() { + a - addition + } else { + a + addition + }; + a / divide_power_of_ten * multiply_power_of_ten + })(value, ctx) +} + +// if round mode is ceil, truncate should add one value +fn decimal_truncate_positive( + value: ValueRef>, + source_scale: i64, + target_scale: i64, + 
ctx: &mut EvalContext, +) -> Value> +where + T: Decimal + From + DivAssign + Div + Add + Sub, +{ + let power_of_ten = T::e((source_scale - target_scale) as u32); + vectorize_1_arg::, DecimalType>(|a, _| a / power_of_ten)(value, ctx) +} + +fn decimal_truncate_negative( + value: ValueRef>, + source_scale: i64, + target_scale: i64, + ctx: &mut EvalContext, +) -> Value> +where + T: Decimal + + From + + DivAssign + + Div + + Add + + Sub + + Mul, +{ + let divide_power_of_ten = T::e((source_scale - target_scale) as u32); + let multiply_power_of_ten = T::e((-target_scale) as u32); + + vectorize_1_arg::, DecimalType>(|a, _| { + a / divide_power_of_ten * multiply_power_of_ten + })(value, ctx) +} + +fn decimal_floor( + value: ValueRef>, + source_scale: i64, + ctx: &mut EvalContext, +) -> Value> +where + T: Decimal + + From + + DivAssign + + Div + + Add + + Sub + + Mul, +{ + let power_of_ten = T::e(source_scale as u32); + + vectorize_1_arg::, DecimalType>(|a, _| { + if a < T::zero() { + // below 0 we ceil the number (e.g. 
-10.5 -> -11) + ((a + T::one()) / power_of_ten) - T::one() + } else { + a / power_of_ten + } + })(value, ctx) +} + +fn decimal_ceil( + value: ValueRef>, + source_scale: i64, + ctx: &mut EvalContext, +) -> Value> +where + T: Decimal + + From + + DivAssign + + Div + + Add + + Sub + + Mul, +{ + let power_of_ten = T::e(source_scale as u32); + + vectorize_1_arg::, DecimalType>(|a, _| { + if a <= T::zero() { + a / power_of_ten + } else { + ((a - T::one()) / power_of_ten) + T::one() + } + })(value, ctx) +} + +fn decimal_math( + arg: &ValueRef, + ctx: &mut EvalContext, + from_type: DataType, + dest_type: DecimalDataType, + target_scale: i64, + mode: RoundMode, +) -> Value { + let from_decimal_type = from_type.as_decimal().unwrap(); + let source_scale = from_decimal_type.scale() as i64; + + if source_scale < target_scale { + return arg.clone().to_owned(); + } + + let none_negative = target_scale >= 0; + + with_decimal_mapped_type!(|DECIMAL_TYPE| match from_decimal_type { + DecimalDataType::DECIMAL_TYPE(_) => { + let value = arg.try_downcast::>().unwrap(); + + let result = match (none_negative, mode) { + (true, RoundMode::Round) => { + decimal_round_positive::<_>(value, source_scale, target_scale, ctx) + } + (true, RoundMode::Truncate) => { + decimal_truncate_positive::<_>(value, source_scale, target_scale, ctx) + } + (false, RoundMode::Round) => { + decimal_round_negative::<_>(value, source_scale, target_scale, ctx) + } + (false, RoundMode::Truncate) => { + decimal_truncate_negative::<_>(value, source_scale, target_scale, ctx) + } + (_, RoundMode::Floor) => decimal_floor::<_>(value, source_scale, ctx), + (_, RoundMode::Ceil) => decimal_ceil::<_>(value, source_scale, ctx), + }; + + result.upcast_decimal(dest_type.size()) + } + }) +} diff --git a/src/query/functions/src/scalars/decimal/mod.rs b/src/query/functions/src/scalars/decimal/mod.rs new file mode 100644 index 000000000000..12407ca97278 --- /dev/null +++ b/src/query/functions/src/scalars/decimal/mod.rs @@ -0,0 +1,25 @@ 
+// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod arithmetic; +mod cast; +mod comparison; +mod math; + +pub(crate) use arithmetic::register_decimal_arithmetic; +pub(crate) use cast::register_decimal_to_float; +pub(crate) use cast::register_decimal_to_int; +pub(crate) use cast::register_to_decimal; +pub(crate) use comparison::register_decimal_compare_op; +pub(crate) use math::register_decimal_math; diff --git a/src/query/functions/src/scalars/mod.rs b/src/query/functions/src/scalars/mod.rs index 0e482e58f513..464ee93a5aa0 100644 --- a/src/query/functions/src/scalars/mod.rs +++ b/src/query/functions/src/scalars/mod.rs @@ -57,7 +57,7 @@ pub fn register(registry: &mut FunctionRegistry) { geo_h3::register(registry); hash::register(registry); other::register(registry); - decimal::register(registry); + decimal::register_to_decimal(registry); vector::register(registry); bitmap::register(registry); } diff --git a/src/query/functions/tests/it/scalars/cast.rs b/src/query/functions/tests/it/scalars/cast.rs index ec360d1cf77e..9ed9a6a79a5a 100644 --- a/src/query/functions/tests/it/scalars/cast.rs +++ b/src/query/functions/tests/it/scalars/cast.rs @@ -665,6 +665,11 @@ fn test_cast_between_string_and_decimal(file: &mut impl Write, is_try: bool) { format!("{prefix}CAST('-1.0e+10' AS DECIMAL(11, 0))"), &[], ); + run_ast( + file, + format!("{prefix}CAST('-0.000000' AS DECIMAL(11, 0))"), + &[], + ); } fn gen_bitmap_data() -> Column { diff 
--git a/src/query/functions/tests/it/scalars/comparison.rs b/src/query/functions/tests/it/scalars/comparison.rs index 7a3f8a7c4564..c93cc6046858 100644 --- a/src/query/functions/tests/it/scalars/comparison.rs +++ b/src/query/functions/tests/it/scalars/comparison.rs @@ -40,6 +40,7 @@ fn test_eq(file: &mut impl Write) { run_ast(file, "null=null", &[]); run_ast(file, "1=2", &[]); run_ast(file, "1.0=1", &[]); + run_ast(file, "2.222>2.11", &[]); run_ast(file, "true=null", &[]); run_ast(file, "true=false", &[]); run_ast(file, "false=false", &[]); diff --git a/src/query/functions/tests/it/scalars/testdata/cast.txt b/src/query/functions/tests/it/scalars/testdata/cast.txt index 084da87a5bf5..cf09cc4fe637 100644 --- a/src/query/functions/tests/it/scalars/testdata/cast.txt +++ b/src/query/functions/tests/it/scalars/testdata/cast.txt @@ -1080,6 +1080,15 @@ output domain : {-10000000000..=-10000000000} output : -10000000000 +ast : CAST('-0.000000' AS DECIMAL(11, 0)) +raw expr : CAST('-0.000000' AS Decimal(11, 0)) +checked expr : to_decimal(11, 0)("-0.000000") +optimized expr : 0_d128(11,0) +output type : Decimal(11, 0) +output domain : {0..=0} +output : 0 + + ast : CAST(0 AS BOOLEAN) raw expr : CAST(0 AS Boolean) checked expr : to_boolean(0_u8) @@ -3070,6 +3079,15 @@ output domain : {-10000000000..=-10000000000} output : -10000000000 +ast : TRY_CAST('-0.000000' AS DECIMAL(11, 0)) +raw expr : TRY_CAST('-0.000000' AS Decimal(11, 0)) +checked expr : try_to_decimal(11, 0)("-0.000000") +optimized expr : 0_d128(11,0) +output type : Decimal(11, 0) NULL +output domain : {0..=0} +output : 0 + + ast : TRY_CAST(0 AS BOOLEAN) raw expr : TRY_CAST(0 AS Boolean) checked expr : try_to_boolean(0_u8) diff --git a/src/query/functions/tests/it/scalars/testdata/comparison.txt b/src/query/functions/tests/it/scalars/testdata/comparison.txt index 866182f55989..97194c73aef2 100644 --- a/src/query/functions/tests/it/scalars/testdata/comparison.txt +++ 
b/src/query/functions/tests/it/scalars/testdata/comparison.txt @@ -34,6 +34,15 @@ output domain : {TRUE} output : true +ast : 2.222>2.11 +raw expr : gt(2.222, 2.11) +checked expr : gt(2.222_d128(4,3), to_decimal(4, 3)(2.11_d128(3,2))) +optimized expr : true +output type : Boolean +output domain : {TRUE} +output : true + + ast : true=null raw expr : eq(true, NULL) checked expr : eq(CAST(true AS Boolean NULL), CAST(NULL AS Boolean NULL)) diff --git a/src/query/storages/parquet/src/parquet_rs/statistics/column.rs b/src/query/storages/parquet/src/parquet_rs/statistics/column.rs index da2f92fca7b5..416c12022c4e 100644 --- a/src/query/storages/parquet/src/parquet_rs/statistics/column.rs +++ b/src/query/storages/parquet/src/parquet_rs/statistics/column.rs @@ -58,14 +58,8 @@ pub fn convert_column_statistics(s: &Statistics, typ: &TableDataType) -> ColumnS Scalar::Decimal(DecimalScalar::Decimal128(i128::from(min), *size)), ), TableDataType::Decimal(DecimalDataType::Decimal256(size)) => ( - Scalar::Decimal(DecimalScalar::Decimal256( - I256::from_i64(max as i64), - *size, - )), - Scalar::Decimal(DecimalScalar::Decimal256( - I256::from_i64(min as i64), - *size, - )), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(max), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(min), *size)), ), _ => (Scalar::Null, Scalar::Null), } @@ -95,8 +89,8 @@ pub fn convert_column_statistics(s: &Statistics, typ: &TableDataType) -> ColumnS Scalar::Decimal(DecimalScalar::Decimal128(i128::from(min), *size)), ), TableDataType::Decimal(DecimalDataType::Decimal256(size)) => ( - Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i64(max), *size)), - Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i64(min), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(max), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(min), *size)), ), _ => (Scalar::Null, Scalar::Null), } diff --git 
a/src/query/storages/parquet/src/parquet_rs/statistics/page.rs b/src/query/storages/parquet/src/parquet_rs/statistics/page.rs index 27186c6ab0f0..d4e12246af2e 100644 --- a/src/query/storages/parquet/src/parquet_rs/statistics/page.rs +++ b/src/query/storages/parquet/src/parquet_rs/statistics/page.rs @@ -151,8 +151,8 @@ fn convert_page_index_int32( Scalar::Decimal(DecimalScalar::Decimal128(i128::from(min), *size)), ), TableDataType::Decimal(DecimalDataType::Decimal256(size)) => ( - Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i64(max as i64), *size)), - Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i64(min as i64), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(max), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(min), *size)), ), _ => unreachable!(), }; @@ -187,8 +187,8 @@ fn convert_page_index_int64( Scalar::Decimal(DecimalScalar::Decimal128(i128::from(min), *size)), ), TableDataType::Decimal(DecimalDataType::Decimal256(size)) => ( - Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i64(max), *size)), - Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i64(min), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(max), *size)), + Scalar::Decimal(DecimalScalar::Decimal256(I256::from_i128(min), *size)), ), _ => unreachable!(), };