diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 69d1fb367..f4f56f04f 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -292,14 +292,18 @@ class CometSparkSessionExtensions newOp match { case Some(nativeOp) => val modes = aggExprs.map(_.mode).distinct - assert(modes.length == 1) + // The aggExprs could be empty. For example, if the aggregate functions only have + // distinct aggregate functions or only have group by, the aggExprs is empty and + // modes is empty too. If aggExprs is not empty, we need to verify all the aggregates + // have the same mode. + assert(modes.length == 1 || modes.length == 0) CometHashAggregateExec( nativeOp, op, groupingExprs, aggExprs, child.output, - modes.head, + if (modes.nonEmpty) Some(modes.head) else None, child) case None => op diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e52018b38..cacadbbed 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometHashAggregateExec, CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -1632,60 +1632,97 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { _, groupingExpressions, aggregateExpressions, - _, + aggregateAttributes, _, resultExpressions, child) if isCometOperatorEnabled(op.conf, "aggregate") => - val modes = aggregateExpressions.map(_.mode).distinct - - if (modes.size != 1) { - // This shouldn't happen as all aggregation expressions should share the same mode. - // Fallback to Spark nevertheless here. + if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { return None } - val mode = modes.head match { - case Partial => CometAggregateMode.Partial - case Final => CometAggregateMode.Final - case _ => return None - } - - val output = mode match { - case CometAggregateMode.Partial => child.output - case CometAggregateMode.Final => - // Assuming `Final` always follows `Partial` aggregation, this find the first - // `Partial` aggregation and get the input attributes from it. - child.collectFirst { case CometHashAggregateExec(_, _, _, _, input, Partial, _) => - input - } match { - case Some(input) => input - case _ => return None - } - case _ => return None - } - - val aggExprs = aggregateExpressions.map(aggExprToProto(_, output)) val groupingExprs = groupingExpressions.map(exprToProto(_, child.output)) - if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && - aggExprs.forall(_.isDefined)) { + // In some of the cases, the aggregateExpressions could be empty. + // For example, if the aggregate functions only have group by or if the aggregate + // functions only have distinct aggregate functions: + // + // SELECT COUNT(distinct col2), col1 FROM test group by col1 + // +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] ) + // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36] + // +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] ) + // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) + // +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ... + // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) + // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ...... + // If the aggregateExpressions is empty, we only want to build groupingExpressions, + // and skip processing of aggregateExpressions. + if (aggregateExpressions.isEmpty) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) - if (mode == CometAggregateMode.Final) { - val attributes = groupingExpressions.map(_.toAttribute) ++ - aggregateExpressions.map(_.resultAttribute) - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - emitWarning(s"Unsupported result expressions found in: ${resultExpressions}") - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - } - hashAggBuilder.setModeValue(mode.getNumber) Some(result.setHashAgg(hashAggBuilder).build()) } else { - None + val modes = aggregateExpressions.map(_.mode).distinct + + if (modes.size != 1) { + // This shouldn't happen as all aggregation expressions should share the same mode. + // Fallback to Spark nevertheless here. + return None + } + + val mode = modes.head match { + case Partial => CometAggregateMode.Partial + case Final => CometAggregateMode.Final + case _ => return None + } + + val output = mode match { + case CometAggregateMode.Partial => child.output + case CometAggregateMode.Final => + // Assuming `Final` always follows `Partial` aggregation, this find the first + // `Partial` aggregation and get the input attributes from it. + // During finding partial aggregation, we must ensure all traversed op are + // native operators. If not, we should fallback to Spark. + var seenNonNativeOp = false + var partialAggInput: Option[Seq[Attribute]] = None + child.transformDown { + case op if !op.isInstanceOf[CometPlan] => + seenNonNativeOp = true + op + case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _) => + if (!seenNonNativeOp && partialAggInput.isEmpty) { + partialAggInput = Some(input) + } + op + } + + if (partialAggInput.isDefined) { + partialAggInput.get + } else { + return None + } + case _ => return None + } + + val aggExprs = aggregateExpressions.map(aggExprToProto(_, output)) + if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && + aggExprs.forall(_.isDefined)) { + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) + if (mode == CometAggregateMode.Final) { + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + emitWarning(s"Unsupported result expressions found in: ${resultExpressions}") + return None + } + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + } + hashAggBuilder.setModeValue(mode.getNumber) + Some(result.setHashAgg(hashAggBuilder).build()) + } else { + None + } } case op if isCometSink(op) => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index eac013e83..7ac1084d8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -421,7 +421,7 @@ case class CometHashAggregateExec( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], input: Seq[Attribute], - mode: AggregateMode, + mode: Option[AggregateMode], child: SparkPlan) extends CometUnaryExec { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index da096e56e..a5f2dd210 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -381,7 +381,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { "tbl", dictionaryEnabled) { checkSparkAnswer( - "SELECT _2, SUM(_1), MIN(_1), MAX(_1), COUNT(_1), AVG(_1) FROM tbl GROUP BY _2") + "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1), COUNT(DISTINCT _1), AVG(_1) FROM tbl GROUP BY _2") } } } @@ -402,6 +402,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withParquetTable(path.toUri.toString, "tbl") { checkSparkAnswer("SELECT _g1, _g2, SUM(_3) FROM tbl GROUP BY _g1, _g2") checkSparkAnswer("SELECT _g1, _g2, COUNT(_3) FROM tbl GROUP BY _g1, _g2") + checkSparkAnswer("SELECT _g1, _g2, SUM(DISTINCT _3) FROM tbl GROUP BY _g1, _g2") + checkSparkAnswer("SELECT _g1, _g2, COUNT(DISTINCT _3) FROM tbl GROUP BY _g1, _g2") checkSparkAnswer("SELECT _g1, _g2, MIN(_3), MAX(_3) FROM tbl GROUP BY _g1, _g2") checkSparkAnswer("SELECT _g1, _g2, AVG(_3) FROM tbl GROUP BY _g1, _g2") } @@ -432,8 +434,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFile(path, numValues, numGroups, dictionaryEnabled) withParquetTable(path.toUri.toString, "tbl") { checkSparkAnswer("SELECT _g3, _g4, SUM(_3), SUM(_4) FROM tbl GROUP BY _g3, _g4") + checkSparkAnswer( + "SELECT _g3, _g4, SUM(DISTINCT _3), SUM(DISTINCT _4) FROM tbl GROUP BY _g3, _g4") checkSparkAnswer( "SELECT _g3, _g4, COUNT(_3), COUNT(_4) FROM tbl GROUP BY _g3, _g4") + checkSparkAnswer( + "SELECT _g3, _g4, COUNT(DISTINCT _3), COUNT(DISTINCT _4) FROM tbl GROUP BY _g3, _g4") checkSparkAnswer( "SELECT _g3, _g4, MIN(_3), MAX(_3), MIN(_4), MAX(_4) FROM tbl GROUP BY _g3, _g4") checkSparkAnswer("SELECT _g3, _g4, AVG(_3), AVG(_4) FROM tbl GROUP BY _g3, _g4") @@ -461,7 +467,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (1 to 4).foreach { col => (1 to 14).foreach { gCol => checkSparkAnswer(s"SELECT _g$gCol, SUM(_$col) FROM tbl GROUP BY _g$gCol") + checkSparkAnswer( + s"SELECT _g$gCol, SUM(DISTINCT _$col) FROM tbl GROUP BY _g$gCol") checkSparkAnswer(s"SELECT _g$gCol, COUNT(_$col) FROM tbl GROUP BY _g$gCol") + checkSparkAnswer( + s"SELECT _g$gCol, COUNT(DISTINCT _$col) FROM tbl GROUP BY _g$gCol") checkSparkAnswer( s"SELECT _g$gCol, MIN(_$col), MAX(_$col) FROM tbl GROUP BY _g$gCol") checkSparkAnswer(s"SELECT _g$gCol, AVG(_$col) FROM tbl GROUP BY _g$gCol") @@ -701,6 +711,61 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("distinct") { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { bosonColumnShuffleEnabled => + withSQLConf( + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> bosonColumnShuffleEnabled.toString) { + Seq(true, false).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col1 int, col2 int, col3 int) using parquet") + sql( + s"insert into $table values(1, 1, 1), (1, 1, 1), (1, 3, 1), (1, 4, 2), (5, 3, 2)") + + var expectedNumOfBosonAggregates = 2 + + checkSparkAnswerAndNumOfAggregates( + s"SELECT DISTINCT(col2) FROM $table", + expectedNumOfBosonAggregates) + + expectedNumOfBosonAggregates = 4 + + checkSparkAnswerAndNumOfAggregates( + s"SELECT COUNT(distinct col2) FROM $table", + expectedNumOfBosonAggregates) + + checkSparkAnswerAndNumOfAggregates( + s"SELECT COUNT(distinct col2), col1 FROM $table group by col1", + expectedNumOfBosonAggregates) + + checkSparkAnswerAndNumOfAggregates( + s"SELECT SUM(distinct col2) FROM $table", + expectedNumOfBosonAggregates) + + checkSparkAnswerAndNumOfAggregates( + s"SELECT SUM(distinct col2), col1 FROM $table group by col1", + expectedNumOfBosonAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT COUNT(distinct col2), SUM(distinct col2), col1, COUNT(distinct col2)," + + s" SUM(distinct col2) FROM $table group by col1", + expectedNumOfBosonAggregates) + + expectedNumOfBosonAggregates = 1 + checkSparkAnswerAndNumOfAggregates( + "SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2), SUM(col2)," + + s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM $table group by col1", + expectedNumOfBosonAggregates) + } + } + } + } + } + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df)