diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md
index bc03be6ae..45c36844b 100644
--- a/EXPRESSIONS.md
+++ b/EXPRESSIONS.md
@@ -101,3 +101,6 @@ The following Spark expressions are currently available:
 + BitAnd
 + BitOr
 + BitXor
++ BoolAnd
++ BoolOr
++ Covariance
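Editor's note: the `Covariance` expression added below keeps four running statistics per group (`count`, `mean1`, `mean2`, and the co-moment `algo_const`) and updates them in a single pass. As a standalone sketch (not part of the patch; all names here are hypothetical), the recurrence used by `update_batch` looks like this, checked against the closed-form result for a small input:

```rust
// Standalone sketch of the single-pass covariance update used by
// `CovarianceAccumulator::update_batch` below. `co_moment` plays the role
// of `algo_const`, the running sum((x - mean_x) * (y - mean_y)).
fn online_covariance(pairs: &[(f64, f64)]) -> (f64, f64) {
    let (mut count, mut mean1, mut mean2, mut co_moment) = (0.0_f64, 0.0, 0.0, 0.0);
    for &(x, y) in pairs {
        let new_count = count + 1.0;
        let delta1 = x - mean1;
        let new_mean1 = delta1 / new_count + mean1;
        let delta2 = y - mean2;
        let new_mean2 = delta2 / new_count + mean2;
        // The co-moment update deliberately uses the *new* mean of y,
        // which is what makes the single-pass increment exact.
        co_moment += delta1 * (y - new_mean2);
        count = new_count;
        mean1 = new_mean1;
        mean2 = new_mean2;
    }
    (count, co_moment)
}

fn main() {
    // col1 = (1, 2, 3) and col2 = (4, 5, 6), as in the test data in this diff.
    let (count, co_moment) = online_covariance(&[(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)]);
    // `evaluate` divides by n for covar_pop and by n - 1 for covar_samp:
    assert!((co_moment / count - 2.0 / 3.0).abs() < 1e-12);
    assert!((co_moment / (count - 1.0) - 1.0).abs() < 1e-12);
}
```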
diff --git a/core/src/execution/datafusion/expressions/covariance.rs b/core/src/execution/datafusion/expressions/covariance.rs
new file mode 100644
index 000000000..5d0e550fa
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/covariance.rs
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::stats::StatsType;
+use arrow::{
+    array::{ArrayRef, Float64Array},
+    compute::cast,
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{
+    downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue,
+};
+use datafusion_physical_expr::{
+    aggregate::utils::down_cast_any_ref, expressions::format_state_name, AggregateExpr,
+    PhysicalExpr,
+};
+
+/// COVAR_SAMP and COVAR_POP aggregate expression
+/// The implementation is mostly the same as DataFusion's. The reason we have
+/// our own is that DataFusion uses UInt64 for the `count` state field, while
+/// Spark uses Double.
+#[derive(Debug, Clone)]
+pub struct Covariance {
+    name: String,
+    expr1: Arc<dyn PhysicalExpr>,
+    expr2: Arc<dyn PhysicalExpr>,
+    stats_type: StatsType,
+}
+
+impl Covariance {
+    /// Create a new COVAR aggregate function
+    pub fn new(
+        expr1: Arc<dyn PhysicalExpr>,
+        expr2: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        stats_type: StatsType,
+    ) -> Self {
+        // The covariance result only supports the Float64 data type.
+        assert!(matches!(data_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr1,
+            expr2,
+            stats_type,
+        }
+    }
+}
+
+impl AggregateExpr for Covariance {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(CovarianceAccumulator::try_new(self.stats_type)?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean1"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean2"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "algo_const"),
+                DataType::Float64,
+                true,
+            ),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr1.clone(), self.expr2.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Covariance {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.expr1.eq(&x.expr1)
+                    && self.expr2.eq(&x.expr2)
+                    && self.stats_type == x.stats_type
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// An accumulator to compute covariance
+#[derive(Debug)]
+pub struct CovarianceAccumulator {
+    algo_const: f64,
+    mean1: f64,
+    mean2: f64,
+    count: f64,
+    stats_type: StatsType,
+}
+
+impl CovarianceAccumulator {
+    /// Creates a new `CovarianceAccumulator`
+    pub fn try_new(s_type: StatsType) -> Result<Self> {
+        Ok(Self {
+            algo_const: 0_f64,
+            mean1: 0_f64,
+            mean2: 0_f64,
+            count: 0_f64,
+            stats_type: s_type,
+        })
+    }
+
+    pub fn get_count(&self) -> f64 {
+        self.count
+    }
+
+    pub fn get_mean1(&self) -> f64 {
+        self.mean1
+    }
+
+    pub fn get_mean2(&self) -> f64 {
+        self.mean2
+    }
+
+    pub fn get_algo_const(&self) -> f64 {
+        self.algo_const
+    }
+}
+
+impl Accumulator for CovarianceAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.mean1),
+            ScalarValue::from(self.mean2),
+            ScalarValue::from(self.algo_const),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values1 = &cast(&values[0], &DataType::Float64)?;
+        let values2 = &cast(&values[1], &DataType::Float64)?;
+
+        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+        for i in 0..values1.len() {
+            let value1 = if values1.is_valid(i) {
+                arr1.next()
+            } else {
+                None
+            };
+            let value2 = if values2.is_valid(i) {
+                arr2.next()
+            } else {
+                None
+            };
+
+            if value1.is_none() || value2.is_none() {
+                continue;
+            }
+
+            let value1 = unwrap_or_internal_err!(value1);
+            let value2 = unwrap_or_internal_err!(value2);
+            let new_count = self.count + 1.0;
+            let delta1 = value1 - self.mean1;
+            let new_mean1 = delta1 / new_count + self.mean1;
+            let delta2 = value2 - self.mean2;
+            let new_mean2 = delta2 / new_count + self.mean2;
+            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
+
+            self.count += 1.0;
+            self.mean1 = new_mean1;
+            self.mean2 = new_mean2;
+            self.algo_const = new_c;
+        }
+
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values1 = &cast(&values[0], &DataType::Float64)?;
+        let values2 = &cast(&values[1], &DataType::Float64)?;
+        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+        for i in 0..values1.len() {
+            let value1 = if values1.is_valid(i) {
+                arr1.next()
+            } else {
+                None
+            };
+            let value2 = if values2.is_valid(i) {
+                arr2.next()
+            } else {
+                None
+            };
+
+            if value1.is_none() || value2.is_none() {
+                continue;
+            }
+
+            let value1 = unwrap_or_internal_err!(value1);
+            let value2 = unwrap_or_internal_err!(value2);
+
+            let new_count = self.count - 1.0;
+            let delta1 = self.mean1 - value1;
+            let new_mean1 = delta1 / new_count + self.mean1;
+            let delta2 = self.mean2 - value2;
+            let new_mean2 = delta2 / new_count + self.mean2;
+            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
+
+            self.count -= 1.0;
+            self.mean1 = new_mean1;
+            self.mean2 = new_mean2;
+            self.algo_const = new_c;
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], Float64Array);
+        let means1 = downcast_value!(states[1], Float64Array);
+        let means2 = downcast_value!(states[2], Float64Array);
+        let cs = downcast_value!(states[3], Float64Array);
+
+        for i in 0..counts.len() {
+            let c = counts.value(i);
+            if c == 0.0 {
+                continue;
+            }
+            let new_count = self.count + c;
+            let new_mean1 = self.mean1 * self.count / new_count + means1.value(i) * c / new_count;
+            let new_mean2 = self.mean2 * self.count / new_count + means2.value(i) * c / new_count;
+            let delta1 = self.mean1 - means1.value(i);
+            let delta2 = self.mean2 - means2.value(i);
+            let new_c =
+                self.algo_const + cs.value(i) + delta1 * delta2 * self.count * c / new_count;
+
+            self.count = new_count;
+            self.mean1 = new_mean1;
+            self.mean2 = new_mean2;
+            self.algo_const = new_c;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let count = match self.stats_type {
+            StatsType::Population => self.count,
+            StatsType::Sample => {
+                if self.count > 0.0 {
+                    self.count - 1.0
+                } else {
+                    self.count
+                }
+            }
+        };
+
+        if count == 0.0 {
+            Ok(ScalarValue::Float64(None))
+        } else {
+            Ok(ScalarValue::Float64(Some(self.algo_const / count)))
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
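Editor's note: `merge_batch` above combines partial aggregations with the standard parallel co-moment formula (Chan et al.). A standalone sketch (not part of the patch; names hypothetical) showing that merging two partial states agrees with a single pass over all rows:

```rust
// Standalone sketch of the state merge performed by `merge_batch` above.
#[derive(Clone, Copy, Default)]
struct CovState {
    count: f64,
    mean1: f64,
    mean2: f64,
    co_moment: f64, // `algo_const` in the patch
}

// Same recurrence as `update_batch`.
fn update(mut s: CovState, x: f64, y: f64) -> CovState {
    let n = s.count + 1.0;
    let d1 = x - s.mean1;
    s.mean1 += d1 / n;
    let d2 = y - s.mean2;
    s.mean2 += d2 / n;
    s.co_moment += d1 * (y - s.mean2);
    s.count = n;
    s
}

// Same formula as `merge_batch`: weighted means plus a cross term.
fn merge(a: CovState, b: CovState) -> CovState {
    if b.count == 0.0 {
        return a;
    }
    let n = a.count + b.count;
    CovState {
        count: n,
        mean1: a.mean1 * a.count / n + b.mean1 * b.count / n,
        mean2: a.mean2 * a.count / n + b.mean2 * b.count / n,
        co_moment: a.co_moment
            + b.co_moment
            + (a.mean1 - b.mean1) * (a.mean2 - b.mean2) * a.count * b.count / n,
    }
}

fn main() {
    let data = [(1.0, 4.0), (2.0, 5.0), (3.0, 6.0), (7.9, 2.4)];
    // One accumulator over all rows...
    let whole = data.iter().fold(CovState::default(), |s, &(x, y)| update(s, x, y));
    // ...must agree with two partial accumulators that are merged.
    let left = data[..2].iter().fold(CovState::default(), |s, &(x, y)| update(s, x, y));
    let right = data[2..].iter().fold(CovState::default(), |s, &(x, y)| update(s, x, y));
    let merged = merge(left, right);
    assert!((whole.co_moment - merged.co_moment).abs() < 1e-9);
    assert!((whole.mean1 - merged.mean1).abs() < 1e-12);
    assert!((whole.mean2 - merged.mean2).abs() < 1e-12);
}
```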
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 69cdf3e99..799790c9f 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -27,6 +27,8 @@ pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
+pub mod covariance;
+pub mod stats;
 pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/expressions/stats.rs b/core/src/execution/datafusion/expressions/stats.rs
new file mode 100644
index 000000000..1f4e64d0b
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/stats.rs
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/// Enum used for differentiating population and sample for statistical functions
+#[derive(PartialEq, Eq, Debug, Clone, Copy)]
+pub enum StatsType {
+    /// Population
+    Population,
+    /// Sample
+    Sample,
+}
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index ab83872c3..0b515f8ee 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -67,8 +67,10 @@ use crate::{
         bloom_filter_might_contain::BloomFilterMightContain,
         cast::Cast,
         checkoverflow::CheckOverflow,
+        covariance::Covariance,
         if_expr::IfExpr,
         scalar_funcs::create_comet_physical_fun,
+        stats::StatsType,
         strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec},
         subquery::Subquery,
         sum_decimal::SumDecimal,
@@ -1180,6 +1182,30 @@ impl PhysicalPlanner {
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 Ok(Arc::new(BitXor::new(child, "bit_xor", datatype)))
             }
+            AggExprStruct::CovSample(expr) => {
+                let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?;
+                let child2 = self.create_expr(expr.child2.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Covariance::new(
+                    child1,
+                    child2,
+                    "covariance",
+                    datatype,
+                    StatsType::Sample,
+                )))
+            }
+            AggExprStruct::CovPopulation(expr) => {
+                let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?;
+                let child2 = self.create_expr(expr.child2.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Covariance::new(
+                    child1,
+                    child2,
+                    "covariance_pop",
+                    datatype,
+                    StatsType::Population,
+                )))
+            }
         }
     }
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 58f607fc0..1a6c29cf3 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -92,6 +92,8 @@ message AggExpr {
     BitAndAgg bitAndAgg = 9;
     BitOrAgg bitOrAgg = 10;
     BitXorAgg bitXorAgg = 11;
+    CovSample covSample = 12;
+    CovPopulation covPopulation = 13;
   }
 }
@@ -149,6 +151,20 @@ message BitXorAgg {
   DataType datatype = 2;
 }
+message CovSample {
+  Expr child1 = 1;
+  Expr child2 = 2;
+  bool null_on_divide_by_zero = 3;
+  DataType datatype = 4;
+}
+
+message CovPopulation {
+  Expr child1 = 1;
+  Expr child2 = 2;
+  bool null_on_divide_by_zero = 3;
+  DataType datatype = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
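Editor's note: both proto messages map onto the same Rust `Covariance` expression; the planner only varies the `StatsType`, which in turn only changes the divisor chosen in `evaluate`. (Both messages also carry a `null_on_divide_by_zero` field that the serde and planner changes in this diff neither set nor read.) A standalone sketch of the divisor logic, not part of the patch and using hypothetical names:

```rust
// Standalone sketch mirroring `CovarianceAccumulator::evaluate`.
enum StatsType {
    Population,
    Sample,
}

fn evaluate(count: f64, co_moment: f64, stats_type: StatsType) -> Option<f64> {
    let divisor = match stats_type {
        StatsType::Population => count,
        // Sample covariance divides by n - 1; guard n == 0 like the patch does.
        StatsType::Sample if count > 0.0 => count - 1.0,
        StatsType::Sample => count,
    };
    if divisor == 0.0 {
        None // fewer rows than required, e.g. covar_samp over a single row
    } else {
        Some(co_moment / divisor)
    }
}

fn main() {
    // co-moment 2.0 over 3 rows (the col1/col2 data from the test below):
    assert_eq!(evaluate(3.0, 2.0, StatsType::Population), Some(2.0 / 3.0));
    assert_eq!(evaluate(3.0, 2.0, StatsType::Sample), Some(1.0));
    // A single row has no sample covariance:
    assert_eq!(evaluate(1.0, 0.0, StatsType::Sample), None);
}
```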
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 26fc708ff..ca5a2cfd7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -388,7 +388,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
+      case cov @ CovSample(child1, child2, _) =>
+        val child1Expr = exprToProto(child1, inputs, binding)
+        val child2Expr = exprToProto(child2, inputs, binding)
+        val dataType = serializeDataType(cov.dataType)
+
+        if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) {
+          val covBuilder = ExprOuterClass.CovSample.newBuilder()
+          covBuilder.setChild1(child1Expr.get)
+          covBuilder.setChild2(child2Expr.get)
+          covBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setCovSample(covBuilder)
+              .build())
+        } else {
+          None
+        }
+      case cov @ CovPopulation(child1, child2, _) =>
+        val child1Expr = exprToProto(child1, inputs, binding)
+        val child2Expr = exprToProto(child2, inputs, binding)
+        val dataType = serializeDataType(cov.dataType)
+
+        if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) {
+          val covBuilder = ExprOuterClass.CovPopulation.newBuilder()
+          covBuilder.setChild1(child1Expr.get)
+          covBuilder.setChild2(child2Expr.get)
+          covBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setCovPopulation(covBuilder)
+              .build())
+        } else {
+          None
+        }
       case fn =>
         emitWarning(s"unsupported Spark aggregate function: $fn")
         None
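Editor's note: the test below deliberately mixes NULL and non-NULL columns, because covariance only consumes rows where *both* inputs are non-null, matching the validity checks in `update_batch`. A standalone sketch of those semantics (not part of the patch; a naive two-pass formula rather than the accumulator's single-pass one):

```rust
// Standalone sketch of the pairwise null semantics the tests exercise.
fn covar(pairs: &[(Option<f64>, Option<f64>)], sample: bool) -> Option<f64> {
    // A row contributes only when both sides are non-null.
    let complete: Vec<(f64, f64)> = pairs
        .iter()
        .filter_map(|&(x, y)| Some((x?, y?)))
        .collect();
    let n = complete.len() as f64;
    let divisor = if sample { n - 1.0 } else { n };
    if n == 0.0 || divisor == 0.0 {
        return None;
    }
    let mean_x = complete.iter().map(|p| p.0).sum::<f64>() / n;
    let mean_y = complete.iter().map(|p| p.1).sum::<f64>() / n;
    let co_moment: f64 = complete
        .iter()
        .map(|&(x, y)| (x - mean_x) * (y - mean_y))
        .sum();
    Some(co_moment / divisor)
}

fn main() {
    // col1 = (1, 2, 3), col3 = (null, 6, null) from the test table below:
    let pairs = [(Some(1.0), None), (Some(2.0), Some(6.0)), (Some(3.0), None)];
    assert_eq!(covar(&pairs, false), Some(0.0)); // one complete row: covar_pop is 0.0
    assert_eq!(covar(&pairs, true), None); // covar_samp needs at least two rows
}
```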
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 982d39f1f..134223a73 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -997,6 +997,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
+  test("covar_pop and covar_samp") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              val table = "test"
+              withTable(table) {
+                sql(
+                  s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double," +
+                    " col6 double, col7 int) using parquet")
+                sql(s"insert into $table values(1, 4, null, 1.1, 2.2, null, 1)," +
+                  " (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null, 2)")
+                val expectedNumOfCometAggregates = 2
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5)," +
+                    " covar_samp(col4, col6) FROM test",
+                  expectedNumOfCometAggregates)
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5)," +
+                    " covar_pop(col4, col6) FROM test",
+                  expectedNumOfCometAggregates)
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_samp(col1, col2), covar_samp(col1, col3), covar_samp(col4, col5)," +
+                    " covar_samp(col4, col6) FROM test GROUP BY col7",
+                  expectedNumOfCometAggregates)
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_pop(col1, col2), covar_pop(col1, col3), covar_pop(col4, col5)," +
+                    " covar_pop(col4, col6) FROM test GROUP BY col7",
+                  expectedNumOfCometAggregates)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)