diff --git a/core/src/execution/datafusion/expressions/covariance.rs b/core/src/execution/datafusion/expressions/covariance.rs index 5a10371c4..dd3f81beb 100644 --- a/core/src/execution/datafusion/expressions/covariance.rs +++ b/core/src/execution/datafusion/expressions/covariance.rs @@ -17,21 +17,22 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::sync::Arc; +use std::{any::Any, sync::Arc}; -use arrow::array::Float64Array; use arrow::{ - array::{ArrayRef, Int64Array}, + array::{ArrayRef, Float64Array}, compute::cast, - datatypes::DataType, - datatypes::Field, + datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, expressions::format_state_name, AggregateExpr, PhysicalExpr}; -use datafusion_physical_expr::expressions::StatsType; +use datafusion_common::{ + downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_physical_expr::{ + aggregate::utils::down_cast_any_ref, + expressions::{format_state_name, StatsType}, + AggregateExpr, PhysicalExpr, +}; /// COVAR and COVAR_SAMP aggregate expression #[derive(Debug, Clone)] @@ -85,7 +86,7 @@ impl AggregateExpr for Covariance { Ok(vec![ Field::new( format_state_name(&self.name, "count"), - DataType::Int64, + DataType::Float64, true, ), Field::new( @@ -119,9 +120,7 @@ impl PartialEq for Covariance { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) + .map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2)) .unwrap_or(false) } } @@ -164,7 +163,7 @@ impl AggregateExpr for CovariancePop { Ok(vec![ Field::new( format_state_name(&self.name, "count"), - DataType::Int64, + DataType::Float64, true, ), Field::new( @@ -198,9 +197,7 @@ impl PartialEq for CovariancePop { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) + .map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2)) .unwrap_or(false) } } @@ -211,7 +208,7 @@ pub struct CovarianceAccumulator { algo_const: f64, mean1: f64, mean2: f64, - count: i64, + count: f64, stats_type: StatsType, } @@ -222,12 +219,12 @@ impl CovarianceAccumulator { algo_const: 0_f64, mean1: 0_f64, mean2: 0_f64, - count: 0_i64, + count: 0_f64, stats_type: s_type, }) } - pub fn get_count(&self) -> i64 { + pub fn get_count(&self) -> f64 { self.count } @@ -279,14 +276,14 @@ impl Accumulator for CovarianceAccumulator { let value1 = unwrap_or_internal_err!(value1); let value2 = unwrap_or_internal_err!(value2); - let new_count = self.count + 1; + let new_count = self.count + 1.0; let delta1 = value1 - self.mean1; let new_mean1 = delta1 / new_count as f64 + self.mean1; let delta2 = value2 - self.mean2; let new_mean2 = delta2 / new_count as f64 + self.mean2; let new_c = delta1 * (value2 - new_mean2) + self.algo_const; - self.count += 1; + self.count += 1.0; self.mean1 = new_mean1; self.mean2 = new_mean2; self.algo_const = new_c; @@ -320,14 +317,14 @@ impl Accumulator for CovarianceAccumulator { let value1 = unwrap_or_internal_err!(value1); let value2 = unwrap_or_internal_err!(value2); - let new_count = self.count - 1; + let new_count = self.count - 1.0; let delta1 = self.mean1 - value1; let new_mean1 = delta1 / new_count as f64 + self.mean1; let delta2 = self.mean2 - value2; let new_mean2 = delta2 / new_count as f64 + self.mean2; let new_c = self.algo_const - delta1 * (new_mean2 - value2); - self.count -= 1; + self.count -= 1.0; self.mean1 = new_mean1; self.mean2 = new_mean2; self.algo_const = new_c; @@ -337,14 +334,14 @@ impl Accumulator for CovarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); + let counts = downcast_value!(states[0], Float64Array); let means1 = downcast_value!(states[1], Float64Array); let means2 = downcast_value!(states[2], Float64Array); let cs = downcast_value!(states[3], Float64Array); for i in 0..counts.len() { let c = counts.value(i); - if c == 0 { + if c == 0.0 { continue; } let new_count = self.count + c; @@ -369,19 +366,18 @@ impl Accumulator for CovarianceAccumulator { } fn evaluate(&mut self) -> Result { - println!("evaluate evaluate evaluate"); let count = match self.stats_type { datafusion_physical_expr::expressions::StatsType::Population => self.count, StatsType::Sample => { - if self.count > 0 { - self.count - 1 + if self.count > 0.0 { + self.count - 1.0 } else { self.count } } }; - if count == 0 { + if count == 0.0 { Ok(ScalarValue::Float64(None)) } else { Ok(ScalarValue::Float64(Some(self.algo_const / count as f64))) diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 7cbfd4465..80294d81a 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -27,9 +27,9 @@ pub use normalize_nan::NormalizeNaNAndZero; pub mod avg; pub mod avg_decimal; pub mod bloom_filter_might_contain; +pub mod covariance; pub mod strings; pub mod subquery; pub mod sum_decimal; pub mod temporal; mod utils; -pub mod covariance; diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index bdb7081be..556735f78 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -59,12 +59,11 @@ use crate::{ expressions::{ avg::Avg, avg_decimal::AvgDecimal, - covariance::Covariance, - covariance::CovariancePop, bitwise_not::BitwiseNotExpr, bloom_filter_might_contain::BloomFilterMightContain, cast::Cast, checkoverflow::CheckOverflow, + covariance::{Covariance, CovariancePop}, if_expr::IfExpr, scalar_funcs::create_comet_physical_fun, strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec}, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 68a111746..65afe76e3 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -998,30 +998,38 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("covar_pop and covar_samp") { - withSQLConf( - CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { - Seq(false).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql( - s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double, col6 double, col7 int) using parquet") - sql( - s"insert into $table values(1, 4, null, 1.1, 2.2, null, 1), (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null, 2)") - val expectedNumOfCometAggregates = 2 - checkSparkAnswerAndNumOfAggregates( - "SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5), covar_samp(col4, col6) FROM test", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5), covar_pop(col4, col6) FROM test", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5), covar_samp(col4, col6) FROM test GROUP BY col7", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5), covar_pop(col4, col6) FROM test GROUP BY col7", - expectedNumOfCometAggregates) + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { cometColumnShuffleEnabled => + withSQLConf( + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) { + Seq(true, false).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "test" + withTable(table) { + sql( + s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double," + + s" col6 double, col7 int) using parquet") + sql(s"insert into $table values(1, 4, null, 1.1, 2.2, null, 1)," + + s" (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null, 2)") + val expectedNumOfCometAggregates = 2 + checkSparkAnswerAndNumOfAggregates( + "SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5)," + + " covar_samp(col4, col6) FROM test", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5)," + + " covar_pop(col4, col6) FROM test", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5)," + + " covar_samp(col4, col6) FROM test GROUP BY col7", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5)," + + " covar_pop(col4, col6) FROM test GROUP BY col7", + expectedNumOfCometAggregates) + } + } } } }