diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index f2aba74a0d..6a683e24d3 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -65,6 +65,9 @@ class CometSparkSessionExtensions
 
   case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
     override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)
+
+    override def postColumnarTransitions: Rule[SparkPlan] =
+      EliminateRedundantColumnarToRow(session)
   }
 
   case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
@@ -278,6 +281,20 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: CollectLimitExec
+            if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
+              && isCometShuffleEnabled(conf)
+              && getOffset(op).getOrElse(0) == 0 =>
+          QueryPlanSerde.operator2Proto(op) match {
+            case Some(nativeOp) =>
+              val offset = getOffset(op).getOrElse(0)
+              val cometOp =
+                CometCollectLimitExec(op, op.limit, offset, op.child)
+              CometSinkPlaceHolder(nativeOp, op, cometOp)
+            case None =>
+              op
+          }
+
         case op: ExpandExec =>
           val newOp = transform1(op)
           newOp match {
@@ -451,6 +468,23 @@ class CometSparkSessionExtensions
       }
     }
   }
+
+  // CometExec already wraps a `ColumnarToRowExec` for row-based operators. Therefore, the
+  // `ColumnarToRowExec` that Spark adds on top of it is redundant and can be eliminated.
+  //
+  // The redundant `ColumnarToRowExec` is added during the `insertTransitions` phase of
+  // `ApplyColumnarRulesAndInsertTransitions` when Spark requests row-based output, such as a
+  // `collect` call. That is normally harmless for `CometExec`. However, for operators such as
+  // `CometCollectLimitExec`, which overrides `executeCollect`, the redundant `ColumnarToRowExec`
+  // makes the override ineffective. The purpose of this rule is to eliminate the redundant
+  // `ColumnarToRowExec` for such operators.
+  case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan = {
+      plan.transform { case ColumnarToRowExec(child: CometCollectLimitExec) =>
+        child
+      }
+    }
+  }
 }
 
 object CometSparkSessionExtensions extends Logging {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index fcc0ca9c5e..46eb1b0035 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1835,6 +1835,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case s if isCometScan(s) => true
       case _: CometSinkPlaceHolder => true
       case _: CoalesceExec => true
+      case _: CollectLimitExec => true
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
      case _: TakeOrderedAndProjectExec => true
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index 8afed84ff5..3ccac5456e 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -20,6 +20,7 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.LimitExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 
 trait ShimCometSparkSessionExtensions {
@@ -32,4 +33,9 @@ trait ShimCometSparkSessionExtensions {
     .map { a => a.setAccessible(true); a }
     .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
     .headOption
+
+  def getOffset(limit: LimitExec): Option[Int] = limit.getClass.getDeclaredFields
+    .filter(_.getName == "offset")
+    .map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] }
+    .headOption
 }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
index fc4f90f897..532eb91a99 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
@@ -19,7 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
@@ -42,7 +41,7 @@ case class CometCoalesceExec(
     if (numPartitions == 1 && rdd.getNumPartitions < 1) {
       // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
       // `SinglePartition`.
-      new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+      CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext)
     } else {
       rdd.coalesce(numPartitions, shuffle = false)
     }
@@ -67,20 +66,3 @@ case class CometCoalesceExec(
 
   override def hashCode(): Int = Objects.hashCode(numPartitions: java.lang.Integer, child)
 }
-
-object CometCoalesceExec {
-
-  /** A simple RDD with no data, but with the given number of partitions. */
-  class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
-      extends RDD[ColumnarBatch](sc, Nil) {
-
-    override def getPartitions: Array[Partition] =
-      Array.tabulate(numPartitions)(i => EmptyPartition(i))
-
-    override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-      Iterator.empty
-    }
-  }
-
-  case class EmptyPartition(index: Int) extends Partition
-}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
new file mode 100644
index 0000000000..5727d0cce7
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
@@ -0,0 +1,113 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with the License. You may
+ * obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the
+ * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.comet
+
+import java.util.Objects
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Comet physical plan node for Spark `CollectLimitExec`.
+ *
+ * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions separated by a
+ * Comet shuffle.
+ */
+case class CometCollectLimitExec(
+    override val originalPlan: SparkPlan,
+    limit: Int,
+    offset: Int,
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  private lazy val serializer: Serializer =
+    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
+
+  override def executeCollect(): Array[InternalRow] = {
+    ColumnarToRowExec(child).executeTake(limit)
+  }
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+    if (childRDD.getNumPartitions == 0) {
+      CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val localLimitedRDD = if (limit >= 0) {
+          childRDD.mapPartitionsInternal { iter =>
+            val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
+            CometExec.getCometIterator(Seq(iter), limitOp)
+          }
+        } else {
+          childRDD
+        }
+        // Shuffle to a single partition using Comet native shuffle
+        val dep = CometShuffleExchangeExec.prepareShuffleDependency(
+          localLimitedRDD,
+          child.output,
+          outputPartitioning,
+          serializer,
+          metrics)
+        metrics("numPartitions").set(dep.partitioner.numPartitions)
+
+        new CometShuffledBatchRDD(dep, readMetrics)
+      }
+
+      // TODO: support offset later
+      singlePartitionRDD.mapPartitionsInternal { iter =>
+        val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
+        CometExec.getCometIterator(Seq(iter), limitOp)
+      }
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+  override def stringArgs: Iterator[Any] = Iterator(limit, offset, child)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometCollectLimitExec =>
+        this.limit == other.limit && this.offset == other.offset &&
+          this.child == other.child
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child)
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 9f8f2150c1..b77fc5c608 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -21,8 +21,11 @@ package org.apache.spark.sql.comet
 
 import scala.collection.JavaConverters.asJavaIterableConverter
 
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.serde.OperatorOuterClass
 import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -30,6 +33,11 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
 
 object CometExecUtils {
 
+  def createEmptyColumnarRDDWithSinglePartition(
+      sparkContext: SparkContext): RDD[ColumnarBatch] = {
+    new EmptyRDDWithPartitions(sparkContext, 1)
+  }
+
   /**
    * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
    */
@@ -119,3 +127,17 @@ object CometExecUtils {
     }
   }
 }
+
+/** A simple RDD with no data, but with the given number of partitions. */
+private class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
+    extends RDD[ColumnarBatch](sc, Nil) {
+
+  override def getPartitions: Array[Partition] =
+    Array.tabulate(numPartitions)(i => EmptyPartition(i))
+
+  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+    Iterator.empty
+  }
+}
+
+private case class EmptyPartition(index: Int) extends Partition
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0414671c2c..83a1b4bba0 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
@@ -1071,6 +1071,32 @@ class CometExecSuite extends CometTestBase {
       }
     })
   }
+
+  test("collect limit") {
+    Seq("true", "false").foreach(aqe => {
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) {
+        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
+          val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2")
+          assert(df.queryExecution.executedPlan.execute().getNumPartitions === 1)
+          checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec]))
+          assert(df.collect().length === 2)
+
+          // check that CometCollectLimitExec.doExecuteColumnar is indirectly called
+          val qe = df.queryExecution
+          SQLExecution.withNewExecutionId(qe, Some("count")) {
+            qe.executedPlan.resetMetrics()
+            assert(qe.executedPlan.execute().count() === 2)
+          }
+
+          assert(df.isEmpty === false)
+
+          // a follow-up native operation is possible
+          val df3 = df.groupBy("id").sum("value")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    })
+  }
 }
 
 case class BucketedTableTestSpec(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 2e523fa5a4..e85fbcfba9 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -150,7 +150,15 @@ abstract class CometTestBase
   protected def checkSparkAnswerAndOperator(
      df: => DataFrame,
      excludedClasses: Class[_]*): Unit = {
+    checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*)
+  }
+
+  protected def checkSparkAnswerAndOperator(
+      df: => DataFrame,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): Unit = {
     checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*)
+    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*)
     checkSparkAnswer(df)
   }
 
@@ -173,6 +181,17 @@ abstract class CometTestBase
     }
   }
 
+  protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): Unit = {
+    includePlans.foreach { planClass =>
+      if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) {
+        assert(
+          false,
+          s"Expected plan to contain ${planClass.getSimpleName}.\n" +
+            s"plan: $plan")
+      }
+    }
+  }
+
   /**
    * Check the answer of a Comet SQL query with Spark result using absolute tolerance.
   */
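
Note on the reflection shim above: `ShimCometSparkSessionExtensions.getOffset` looks up an `offset` field by name so that the same source compiles against Spark versions that do and do not carry an offset on `CollectLimitExec`. Below is a minimal, self-contained sketch of that pattern (not part of the patch); `OffsetShimSketch`, `LimitWithOffset`, and `LimitWithoutOffset` are hypothetical stand-ins for the Spark plan classes.

// Sketch only: mirrors the reflective field lookup used by getOffset, applied
// to hypothetical stand-in classes instead of Spark's LimitExec hierarchy.
object OffsetShimSketch {
  final case class LimitWithOffset(limit: Int, offset: Int) // declares an "offset" field
  final case class LimitWithoutOffset(limit: Int)           // no "offset" field

  // Return Some(offset) when the runtime class declares an `offset` field, else None.
  def getOffset(op: Any): Option[Int] =
    op.getClass.getDeclaredFields
      .filter(_.getName == "offset")
      .map { f => f.setAccessible(true); f.get(op).asInstanceOf[Int] }
      .headOption

  def main(args: Array[String]): Unit = {
    assert(getOffset(LimitWithOffset(10, 5)).contains(5)) // field present
    assert(getOffset(LimitWithoutOffset(10)).isEmpty)     // field absent
    println("offset shim sketch OK")
  }
}

This is also why the `CollectLimitExec` case in `CometExecRule` only fires when `getOffset(op).getOrElse(0) == 0`: on Spark versions without an `offset` field the shim returns `None` (treated as offset 0), while a non-zero offset keeps the operator on the Spark path, matching the `// TODO: support offset later` note in `CometCollectLimitExec`.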