From 421f0e03b1f2719752eb354ce30210b894ed6e77 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 11 Apr 2024 19:00:53 -0700 Subject: [PATCH] fix: Average expression in Comet Final should handle all null inputs from partial Spark aggregation (#261) --- core/src/execution/datafusion/expressions/avg.rs | 12 +++++++++--- .../apache/comet/exec/CometAggregateSuite.scala | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/core/src/execution/datafusion/expressions/avg.rs b/core/src/execution/datafusion/expressions/avg.rs index e35ff6120..1ff276e5d 100644 --- a/core/src/execution/datafusion/expressions/avg.rs +++ b/core/src/execution/datafusion/expressions/avg.rs @@ -176,9 +176,15 @@ impl Accumulator for AvgAccumulator { } fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Float64( - self.sum.map(|f| f / self.count as f64), - )) + if self.count == 0 { + // If all input are nulls, count will be 0 and we will get null after the division. + // This is consistent with Spark Average implementation. + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } } fn size(&self) -> usize { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 982d39f1f..b95ce9b19 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -40,6 +40,21 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + test( + "Average expression in Comet Final should handle " + + "all null inputs from partial Spark aggregation") { + withTempView("allNulls") { + allNulls.createOrReplaceTempView("allNulls") + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + val df = sql("select sum(a), avg(a) from allNulls") + checkSparkAnswer(df) + } + } + } + test("Aggregation without aggregate expressions should use correct result expressions") { withSQLConf( CometConf.COMET_ENABLED.key -> "true",