diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index ab83872c3..e53ebe763 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -1094,8 +1094,16 @@ impl PhysicalPlanner {
     ) -> Result<Arc<dyn AggregateExpr>, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
             AggExprStruct::Count(expr) => {
-                let child = self.create_expr(&expr.children[0], schema)?;
-                Ok(Arc::new(Count::new(child, "count", DataType::Int64)))
+                let children = expr
+                    .children
+                    .iter()
+                    .map(|child| self.create_expr(child, schema.clone()))
+                    .collect::<Result<Vec<_>, _>>()?;
+                Ok(Arc::new(Count::new_with_multiple_exprs(
+                    children,
+                    "count",
+                    DataType::Int64,
+                )))
             }
             AggExprStruct::Min(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 89681d3df..230ac36b0 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
 import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.sum
+import org.apache.spark.sql.functions.{count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
@@ -40,6 +40,25 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
 
+  test("multiple column distinct count") {
+    withSQLConf(
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      val df1 = Seq(
+        ("a", "b", "c"),
+        ("a", "b", "c"),
+        ("a", "b", "d"),
+        ("x", "y", "z"),
+        ("x", "q", null.asInstanceOf[String]))
+        .toDF("key1", "key2", "key3")
+
+      checkSparkAnswer(df1.agg(count_distinct($"key1", $"key2")))
+      checkSparkAnswer(df1.agg(count_distinct($"key1", $"key2", $"key3")))
+      checkSparkAnswer(df1.groupBy($"key1").agg(count_distinct($"key2", $"key3")))
+    }
+  }
+
   test("Only trigger Comet Final aggregation on Comet partial aggregation") {
     withTempView("lowerCaseData") {
       lowerCaseData.createOrReplaceTempView("lowerCaseData")