Skip to content

Commit

Permalink
fix: Average expression in Comet Final should handle all null inputs …
Browse files Browse the repository at this point in the history
…from partial Spark aggregation (#261)
  • Loading branch information
viirya authored Apr 12, 2024
1 parent b7d2c63 commit 421f0e0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
12 changes: 9 additions & 3 deletions core/src/execution/datafusion/expressions/avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,15 @@ impl Accumulator for AvgAccumulator {
}

/// Produces the final average as a nullable `Float64` scalar.
///
/// NOTE(review): this span comes from a rendered diff whose `+`/`-` markers were
/// stripped. Going by the "9 additions & 3 deletions" summary for this file, the
/// first three body lines look like the removed pre-fix code and the `if`/`else`
/// that follows looks like the added replacement; confirm against the repository
/// before treating the span as a single compilable function.
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
if self.count == 0 {
// If all inputs are null, `count` is 0 and the result must be null, which is
// what Spark's Average implementation yields from the division in that case.
Ok(ScalarValue::Float64(None))
} else {
// `map` keeps a `None` sum as `None`; otherwise the sum is divided by the
// count converted to `f64`.
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
}
}

fn size(&self) -> usize {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

// Regression test for averaging over all-null input: runs `sum(a)` and `avg(a)`
// over the `allNulls` view with Comet, Comet exec shuffle and Comet columnar
// shuffle all enabled.
// NOTE(review): `allNulls` is defined outside this hunk, presumably a fixture
// whose column `a` contains only nulls; verify against the test base class.
test(
"Average expression in Comet Final should handle " +
"all null inputs from partial Spark aggregation") {
withTempView("allNulls") {
// Register the fixture under the name the SQL query below refers to.
allNulls.createOrReplaceTempView("allNulls")
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
val df = sql("select sum(a), avg(a) from allNulls")
// Presumably compares the Comet result with vanilla Spark's answer; confirm
// in CometTestBase.
checkSparkAnswer(df)
}
}
}

test("Aggregation without aggregate expressions should use correct result expressions") {
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
Expand Down

0 comments on commit 421f0e0

Please sign in to comment.