diff --git a/core/src/execution/datafusion/expressions/avg.rs b/core/src/execution/datafusion/expressions/avg.rs
index 847acae650..1ff276e5d0 100644
--- a/core/src/execution/datafusion/expressions/avg.rs
+++ b/core/src/execution/datafusion/expressions/avg.rs
@@ -179,7 +179,7 @@ impl Accumulator for AvgAccumulator {
         if self.count == 0 {
             // If all input are nulls, count will be 0 and we will get null after the division.
             // This is consistent with Spark Average implementation.
-            return Ok(ScalarValue::Float64(None));
+            Ok(ScalarValue::Float64(None))
         } else {
             Ok(ScalarValue::Float64(
                 self.sum.map(|f| f / self.count as f64),
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 982d39f1f0..62b9411779 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -40,6 +40,23 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
 
+  test(
+    "Average expression in Comet Final should handle " +
+      "all null inputs from partial Spark aggregation") {
+    withTempView("allNulls") {
+      allNulls.createOrReplaceTempView("allNulls")
+      withSQLConf(
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        val df = sql("select sum(a), avg(a) from allNulls")
+        // Every value of `a` is null here, exercising the all-null path of sum/avg.
+        // checkSparkAnswer executes the query itself, so no collect()/explain() is needed.
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
   test("Aggregation without aggregate expressions should use correct result expressions") {
     withSQLConf(
       CometConf.COMET_ENABLED.key -> "true",