diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8b4eaa609..56b66909e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1799,43 +1799,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case _ => return None } - val output = mode match { - case CometAggregateMode.Partial => child.output - case CometAggregateMode.Final => - // Assuming `Final` always follows `Partial` aggregation, this find the first - // `Partial` aggregation and get the input attributes from it. - // During finding partial aggregation, we must ensure all traversed op are - // native operators. If not, we should fallback to Spark. - var seenNonNativeOp = false - var partialAggInput: Option[Seq[Attribute]] = None - child.transformDown { - case op if !op.isInstanceOf[CometPlan] => - seenNonNativeOp = true - op - case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _, _) => - if (!seenNonNativeOp && partialAggInput.isEmpty) { - partialAggInput = Some(input) - } - op - } - - if (partialAggInput.isDefined) { - partialAggInput.get - } else { - return None - } - case _ => return None - } - - val binding = if (mode == CometAggregateMode.Final) { - // In final mode, the aggregate expressions are bound to the output of the - // child and partial aggregate expressions buffer attributes produced by partial - // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, - // we don't have to do this because we don't use the merging expression. - false - } else { - true - } + // In final mode, the aggregate expressions are bound to the output of the + // child and partial aggregate expressions buffer attributes produced by partial + // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, + // we don't have to do this because we don't use the merging expression. + val binding = mode != CometAggregateMode.Final + // `output` is only used when `binding` is true (i.e., non-Final) + val output = child.output val aggExprs = aggregateExpressions.map(aggExprToProto(_, output, binding)) if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 3b4fb1c99..bc645cb6a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -845,7 +845,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { s" SUM(distinct col2) FROM $table group by col1", expectedNumOfCometAggregates) - expectedNumOfCometAggregates = 1 + expectedNumOfCometAggregates = if (cometColumnShuffleEnabled) 2 else 1 checkSparkAnswerAndNumOfAggregates( "SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2), SUM(col2)," + s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM $table group by col1",