Skip to content

Commit

Permalink
fix: Remove castting to decimal256
Browse files Browse the repository at this point in the history
  • Loading branch information
kazuyukitanimura committed Jul 31, 2024
1 parent a14f888 commit af1fdfb
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

//! Converts Spark physical plan to DataFusion physical plan
use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
use arrow_schema::{
DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
};
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
Expand Down Expand Up @@ -652,6 +654,16 @@ impl PhysicalPlanner {
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let left = self.create_expr(left, input_schema.clone())?;
let right = self.create_expr(right, input_schema.clone())?;

#[inline(always)]
fn div_mul_pow(s1: i8, s2: i8) -> i8 {
DECIMAL128_MAX_SCALE.min(s1 + 4) - s1 + s2
}
#[inline(always)]
fn res_precision(p1: u8, s1: i8, p2: u8, s2: i8) -> u8 {
max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
}

match (
&op,
left.data_type(&input_schema),
Expand All @@ -665,12 +677,10 @@ impl PhysicalPlanner {
Ok(DataType::Decimal128(p1, s1)),
Ok(DataType::Decimal128(p2, s2)),
) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
>= DECIMAL128_MAX_PRECISION)
&& res_precision(p1, s1, p2, s2) >= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Modulo
&& max(p1 - s1 as u8, p2 - s2 as u8) + max(s1, s2) as u8
> DECIMAL128_MAX_PRECISION) =>
&& res_precision(p1, s1, p2, s2) > DECIMAL128_MAX_PRECISION) =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
// For some Decimal128 operations, we need wider internal digits.
Expand All @@ -694,9 +704,13 @@ impl PhysicalPlanner {
}
(
DataFusionOperator::Divide,
Ok(DataType::Decimal128(_p1, _s1)),
Ok(DataType::Decimal128(_p2, _s2)),
) => {
Ok(DataType::Decimal128(p1, s1)),
Ok(DataType::Decimal128(p2, s2)),
) if (div_mul_pow(s1, s2) > 0
&& p1 + div_mul_pow(s1, s2) as u8 > DECIMAL128_MAX_PRECISION)
|| (div_mul_pow(s1, s2) < 0
&& p2 + div_mul_pow(s1, s2) as u8 > DECIMAL128_MAX_PRECISION) =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
let fun_expr = create_comet_physical_fun(
"decimal_div",
Expand Down

0 comments on commit af1fdfb

Please sign in to comment.