From 2d288e1d49e81a8ecc98194cdb36f1e102d5583c Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Tue, 21 May 2024 08:00:54 -0700
Subject: [PATCH] feat: correlation support

---
 .../datafusion/expressions/correlation.rs     | 256 ++++++++++++++++++
 .../execution/datafusion/expressions/mod.rs   |   1 +
 core/src/execution/datafusion/planner.rs      |  13 +
 core/src/execution/proto/expr.proto           |   8 +
 docs/source/user-guide/expressions.md         |   1 +
 .../apache/comet/serde/QueryPlanSerde.scala   |  22 +-
 .../comet/exec/CometAggregateSuite.scala      | 151 +++++++++++
 7 files changed, 451 insertions(+), 1 deletion(-)
 create mode 100644 core/src/execution/datafusion/expressions/correlation.rs

diff --git a/core/src/execution/datafusion/expressions/correlation.rs b/core/src/execution/datafusion/expressions/correlation.rs
new file mode 100644
index 000000000..5b12c96aa
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/correlation.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::compute::{and, filter, is_not_null};
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{
+    covariance::CovarianceAccumulator, stats::StatsType, stddev::StddevAccumulator,
+    utils::down_cast_any_ref,
+};
+use arrow::{
+    array::ArrayRef,
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{internal_err, Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
+
+/// CORR aggregate expression
+/// The implementation is mostly the same as DataFusion's. The reason we have our
+/// own implementation is that DataFusion uses UInt64 for the `count` state field,
+/// while Spark uses Double. We have also added `null_on_divide_by_zero` to stay
+/// consistent with Spark's implementation.
+#[derive(Debug)]
+pub struct Correlation {
+    name: String,
+    expr1: Arc<dyn PhysicalExpr>,
+    expr2: Arc<dyn PhysicalExpr>,
+    null_on_divide_by_zero: bool,
+}
+
+impl Correlation {
+    pub fn new(
+        expr1: Arc<dyn PhysicalExpr>,
+        expr2: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        null_on_divide_by_zero: bool,
+    ) -> Self {
+        // The result of correlation only supports the Float64 data type.
+        assert!(matches!(data_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr1,
+            expr2,
+            null_on_divide_by_zero,
+        }
+    }
+}
+
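+// The six state fields declared in `state_fields` below mirror Spark's Corr
+// aggregation buffer, which holds (n, xAvg, yAvg, ck, xMk, yMk) as doubles.
+// Keeping `count` as Float64, rather than DataFusion's UInt64, is what allows
+// this partial-aggregation state to be exchanged with Spark's.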
+impl AggregateExpr for Correlation {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(CorrelationAccumulator::try_new(
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean1"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean2"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "algo_const"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "m2_1"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "m2_2"),
+                DataType::Float64,
+                true,
+            ),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr1.clone(), self.expr2.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Correlation {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Correlation>()
+            .map(|x| {
+                self.name == x.name
+                    && self.expr1.eq(&x.expr1)
+                    && self.expr2.eq(&x.expr2)
+                    && self.null_on_divide_by_zero == x.null_on_divide_by_zero
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// An accumulator to compute correlation
+#[derive(Debug)]
+pub struct CorrelationAccumulator {
+    covar: CovarianceAccumulator,
+    stddev1: StddevAccumulator,
+    stddev2: StddevAccumulator,
+    null_on_divide_by_zero: bool,
+}
+
+impl CorrelationAccumulator {
+    /// Creates a new `CorrelationAccumulator`
+    pub fn try_new(null_on_divide_by_zero: bool) -> Result<Self> {
+        Ok(Self {
+            covar: CovarianceAccumulator::try_new(StatsType::Population)?,
+            stddev1: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
+            stddev2: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
+            null_on_divide_by_zero,
+        })
+    }
+}
+
+impl Accumulator for CorrelationAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.covar.get_count()),
+            ScalarValue::from(self.covar.get_mean1()),
+            ScalarValue::from(self.covar.get_mean2()),
+            ScalarValue::from(self.covar.get_algo_const()),
+            ScalarValue::from(self.stddev1.get_m2()),
+            ScalarValue::from(self.stddev2.get_m2()),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        // Keep only the rows where both inputs are non-null, matching Spark's
+        // null handling for corr.
+        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
+            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
+            let values1 = filter(&values[0], &mask)?;
+            let values2 = filter(&values[1], &mask)?;
+
+            vec![values1, values2]
+        } else {
+            values.to_vec()
+        };
+
+        if !values[0].is_empty() && !values[1].is_empty() {
+            self.covar.update_batch(&values)?;
+            self.stddev1.update_batch(&values[0..1])?;
+            self.stddev2.update_batch(&values[1..2])?;
+        }
+
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
+            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
+            let values1 = filter(&values[0], &mask)?;
+            let values2 = filter(&values[1], &mask)?;
+
+            vec![values1, values2]
+        } else {
+            values.to_vec()
+        };
+
+        self.covar.retract_batch(&values)?;
+        self.stddev1.retract_batch(&values[0..1])?;
+        self.stddev2.retract_batch(&values[1..2])?;
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        // Route the shared state columns to each sub-accumulator: covariance
+        // takes (count, mean1, mean2, algo_const), each stddev takes
+        // (count, mean, m2).
+        let states_c = [
+            states[0].clone(),
+            states[1].clone(),
+            states[2].clone(),
+            states[3].clone(),
+        ];
+        let states_s1 = [states[0].clone(), states[1].clone(), states[4].clone()];
+        let states_s2 = [states[0].clone(), states[2].clone(), states[5].clone()];
+
+        if states[0].len() > 0 && states[1].len() > 0 && states[2].len() > 0 {
+            self.covar.merge_batch(&states_c)?;
+            self.stddev1.merge_batch(&states_s1)?;
+            self.stddev2.merge_batch(&states_s2)?;
+        }
+        Ok(())
+    }
+
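+    // With population statistics the 1/n factors cancel, so the Pearson
+    // correlation falls out of the accumulated moments directly:
+    //
+    //   corr(x, y) = cov_pop(x, y) / (stddev_pop(x) * stddev_pop(y))
+    //
+    // For example, x = [1, 2, 3] and y = [4, 5, 6] give cov_pop = 2/3 and both
+    // population stddevs equal to sqrt(2/3), so corr = 1.0.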
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let covar = self.covar.evaluate()?;
+        let stddev1 = self.stddev1.evaluate()?;
+        let stddev2 = self.stddev2.evaluate()?;
+
+        match (covar, stddev1, stddev2) {
+            (
+                ScalarValue::Float64(Some(c)),
+                ScalarValue::Float64(Some(s1)),
+                ScalarValue::Float64(Some(s2)),
+            ) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))),
+            _ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)),
+            _ => {
+                if self.covar.get_count() == 1.0 {
+                    return Ok(ScalarValue::Float64(Some(f64::NAN)));
+                }
+                Ok(ScalarValue::Float64(None))
+            }
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size()
+            - std::mem::size_of_val(&self.stddev1)
+            + self.stddev1.size()
+            - std::mem::size_of_val(&self.stddev2)
+            + self.stddev2.size()
+    }
+}
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 10cac1696..9db4b65b3 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -27,6 +27,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
+pub mod correlation;
 pub mod covariance;
 pub mod stats;
 pub mod stddev;
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 59818857e..01d892381 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -67,6 +67,7 @@ use crate::{
         bloom_filter_might_contain::BloomFilterMightContain,
         cast::{Cast, EvalMode},
         checkoverflow::CheckOverflow,
+        correlation::Correlation,
         covariance::Covariance,
         if_expr::IfExpr,
         scalar_funcs::create_comet_physical_fun,
@@ -1310,6 +1311,18 @@ impl PhysicalPlanner {
                 ))),
             }
         }
+        AggExprStruct::Correlation(expr) => {
+            let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?;
+            let child2 = self.create_expr(expr.child2.as_ref().unwrap(), schema.clone())?;
+            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+            Ok(Arc::new(Correlation::new(
+                child1,
+                child2,
+                "correlation",
+                datatype,
+                expr.null_on_divide_by_zero,
+            )))
+        }
     }
 }
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index ee3de865a..be85e8a92 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -96,6 +96,7 @@ message AggExpr {
     CovPopulation covPopulation = 13;
     Variance variance = 14;
     Stddev stddev = 15;
+    Correlation correlation = 16;
   }
 }
@@ -186,6 +187,13 @@ message Stddev {
   StatisticsType stats_type = 4;
 }
 
+message Correlation {
+  Expr child1 = 1;
+  Expr child2 = 2;
+  bool null_on_divide_by_zero = 3;
+  DataType datatype = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index 38c86c727..521699d34 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -109,3 +109,4 @@ The following Spark expressions are currently available:
 - VarianceSamp
 - StddevPop
 - StddevSamp
+- Corr
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7238990ad..5e59f86fe 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -547,6 +547,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         withInfo(aggExpr, child)
         None
       }
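+      // Spark's Corr declares DoubleType inputs, so the analyzer should already
+      // have cast both children to Double by the time they are serialized here;
+      // the result type is likewise always Double.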
+      case corr @ Corr(child1, child2, nullOnDivideByZero) =>
+        val child1Expr = exprToProto(child1, inputs, binding)
+        val child2Expr = exprToProto(child2, inputs, binding)
+        val dataType = serializeDataType(corr.dataType)
+
+        if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) {
+          val corrBuilder = ExprOuterClass.Correlation.newBuilder()
+          corrBuilder.setChild1(child1Expr.get)
+          corrBuilder.setChild2(child2Expr.get)
+          corrBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          corrBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setCorrelation(corrBuilder)
+              .build())
+        } else {
+          None
+        }
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 310a24ee3..f52070335 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1212,6 +1212,157 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("correlation") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+          Seq(false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              Seq(true).foreach { nullOnDivideByZero =>
+                withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
+                  val table = "test"
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (2, 5, 1), (3, 6, 2)")
+                    val expectedNumOfCometAggregates = 2
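+
+                    // Two Comet aggregates are expected in each plan: the
+                    // partial aggregation before the shuffle and the final
+                    // aggregation after it.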
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1, 4, 3), (2, -5, 3), (3, 6, 1)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1.1, 4.1, 2.3), (2, 5, 1.5), (3, 6, 2.3)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (2, 5, 2), (3, 6, 3), (1.1, 4.4, 1), (2.2, 5.5, 2), (3.3, 6.6, 3)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (2, 5, 2), (3, 6, 3)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
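+                  // The remaining cases exercise null handling: rows where
+                  // either input is null are ignored, matching the filtering in
+                  // the accumulator's update path.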
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(
+                      s"insert into $table values(1, 4, 2), (null, null, 2), (3, 6, 1), (3, 3, 1)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (null, 5, 1), (2, 5, 2), (9, null, 2), (3, 6, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(s"insert into $table values(null, null, 1), (1, 2, 1), (null, null, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(
+                      s"insert into $table values(null, null, 1), (null, null, 1), (null, null, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT corr(col1, col2) FROM test GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)