diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 10c332801..dae9f3fd5 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -65,6 +65,9 @@ class CometSparkSessionExtensions
 
   case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
     override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)
+
+    override def postColumnarTransitions: Rule[SparkPlan] =
+      EliminateRedundantColumnarToRow(session)
   }
 
   case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
@@ -284,6 +287,20 @@ class CometSparkSessionExtensions
             op
           }
 
+        case op: CollectLimitExec
+            if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
+              && isCometShuffleEnabled(conf)
+              && getOffset(op) == 0 =>
+          QueryPlanSerde.operator2Proto(op) match {
+            case Some(nativeOp) =>
+              val offset = getOffset(op)
+              val cometOp =
+                CometCollectLimitExec(op, op.limit, offset, op.child)
+              CometSinkPlaceHolder(nativeOp, op, cometOp)
+            case None =>
+              op
+          }
+
        case op: ExpandExec =>
           val newOp = transform1(op)
           newOp match {
@@ -457,6 +474,26 @@ class CometSparkSessionExtensions
       }
     }
   }
+
+  // CometExec already wraps a `ColumnarToRowExec` for row-based operators, so the extra
+  // `ColumnarToRowExec` is redundant and can be eliminated.
+  //
+  // It is added during the `insertTransitions` phase of `ApplyColumnarRulesAndInsertTransitions`
+  // when Spark requests row-based output, such as a `collect` call. Adding a redundant
+  // `ColumnarToRowExec` on top of a `CometExec` is normally harmless. However, for operators
+  // such as `CometCollectLimitExec`, which overrides `executeCollect`, the redundant
+  // `ColumnarToRowExec` makes the override ineffective. The purpose of this rule is to
+  // eliminate the redundant `ColumnarToRowExec` for such operators.
+  case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan = {
+      plan match {
+        case ColumnarToRowExec(child: CometCollectLimitExec) =>
+          child
+        case other =>
+          other
+      }
+    }
+  }
 }
 
 object CometSparkSessionExtensions extends Logging {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index fcc0ca9c5..46eb1b003 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1835,6 +1835,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case s if isCometScan(s) => true
       case _: CometSinkPlaceHolder => true
       case _: CoalesceExec => true
+      case _: CollectLimitExec => true
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
       case _: TakeOrderedAndProjectExec => true
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index 8afed84ff..85c6413e1 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -20,9 +20,11 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.{LimitExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 
 trait ShimCometSparkSessionExtensions {
+  import org.apache.comet.shims.ShimCometSparkSessionExtensions._
 
   /**
    * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate
@@ -32,4 +34,19 @@ trait ShimCometSparkSessionExtensions {
     .map { a => a.setAccessible(true); a }
     .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
     .headOption
+
+  /**
+   * TODO: delete after dropping Spark 3.2 and 3.3 support
+   */
+  def getOffset(limit: LimitExec): Int = getOffsetOpt(limit).getOrElse(0)
+
+}
+
+object ShimCometSparkSessionExtensions {
+  private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields
+    .filter(_.getName == "offset")
+    .map { a => a.setAccessible(true); a.get(plan) }
+    .filter(_.isInstanceOf[Int])
+    .map(_.asInstanceOf[Int])
+    .headOption
 }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
index fc4f90f89..cc635d739 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
@@ -19,7 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
@@ -42,7 +41,7 @@ case class CometCoalesceExec(
     if (numPartitions == 1 && rdd.getNumPartitions < 1) {
       // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
       // `SinglePartition`.
-      new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+      CometExecUtils.emptyRDDWithPartitions(sparkContext, 1)
     } else {
       rdd.coalesce(numPartitions, shuffle = false)
     }
@@ -67,20 +66,3 @@ case class CometCoalesceExec(
   override def hashCode(): Int =
     Objects.hashCode(numPartitions: java.lang.Integer, child)
 }
-
-object CometCoalesceExec {
-
-  /** A simple RDD with no data, but with the given number of partitions. */
-  class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
-      extends RDD[ColumnarBatch](sc, Nil) {
-
-    override def getPartitions: Array[Partition] =
-      Array.tabulate(numPartitions)(i => EmptyPartition(i))
-
-    override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-      Iterator.empty
-    }
-  }
-
-  case class EmptyPartition(index: Int) extends Partition
-}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
new file mode 100644
index 000000000..83126a7ba
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import java.util.Objects
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Comet physical plan node for Spark `CollectLimitExec`.
+ *
+ * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions separated by a
+ * Comet shuffle.
+ *
+ * TODO: support offset semantics
+ */
+case class CometCollectLimitExec(
+    override val originalPlan: SparkPlan,
+    limit: Int,
+    offset: Int,
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  private lazy val serializer: Serializer =
+    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
+
+  override def executeCollect(): Array[InternalRow] = {
+    ColumnarToRowExec(child).executeTake(limit)
+  }
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+    if (childRDD.getNumPartitions == 0) {
+      CometExecUtils.emptyRDDWithPartitions(sparkContext, 1)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val localLimitedRDD = if (limit >= 0) {
+          CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
+        } else {
+          childRDD
+        }
+        // Shuffle to a single partition using Comet shuffle
+        val dep = CometShuffleExchangeExec.prepareShuffleDependency(
+          localLimitedRDD,
+          child.output,
+          outputPartitioning,
+          serializer,
+          metrics)
+        metrics("numPartitions").set(dep.partitioner.numPartitions)
+
+        new CometShuffledBatchRDD(dep, readMetrics)
+      }
+      CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit)
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+  override def stringArgs: Iterator[Any] = Iterator(limit, offset, child)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometCollectLimitExec =>
+        this.limit == other.limit && this.offset == other.offset &&
+          this.child == other.child
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child)
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 9f8f2150c..5931920a2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -20,9 +20,13 @@
 package org.apache.spark.sql.comet
 
 import scala.collection.JavaConverters.asJavaIterableConverter
+import scala.reflect.ClassTag
 
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.serde.OperatorOuterClass
 import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -30,6 +34,29 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
 
 object CometExecUtils {
 
+  /**
+   * Create an empty RDD with the given number of partitions.
+   */
+  def emptyRDDWithPartitions[T: ClassTag](
+      sparkContext: SparkContext,
+      numPartitions: Int): RDD[T] = {
+    new EmptyRDDWithPartitions(sparkContext, numPartitions)
+  }
+
+  /**
+   * Transform the given RDD into a new RDD that takes the first `limit` elements of each
+   * partition. The limit operation is performed on the native side.
+   */
+  def getNativeLimitRDD(
+      childPlan: RDD[ColumnarBatch],
+      outputAttribute: Seq[Attribute],
+      limit: Int): RDD[ColumnarBatch] = {
+    childPlan.mapPartitionsInternal { iter =>
+      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
+      CometExec.getCometIterator(Seq(iter), limitOp)
+    }
+  }
+
   /**
    * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
    */
@@ -119,3 +146,19 @@ object CometExecUtils {
     }
   }
 }
+
+/** A simple RDD with no data, but with the given number of partitions. */
+private class EmptyRDDWithPartitions[T: ClassTag](
+    @transient private val sc: SparkContext,
+    numPartitions: Int)
+    extends RDD[T](sc, Nil) {
+
+  override def getPartitions: Array[Partition] =
+    Array.tabulate(numPartitions)(i => EmptyPartition(i))
+
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+    Iterator.empty
+  }
+}
+
+private case class EmptyPartition(index: Int) extends Partition
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 88984388f..26ec401ed 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -77,11 +77,7 @@ case class CometTakeOrderedAndProjectExec(
       childRDD
     } else {
       val localTopK = if (orderingSatisfies) {
-        childRDD.mapPartitionsInternal { iter =>
-          val limitOp =
-            CometExecUtils.getLimitNativePlan(output, limit).get
-          CometExec.getCometIterator(Seq(iter), limitOp)
-        }
+        CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
       } else {
         childRDD.mapPartitionsInternal { iter =>
           val topK =
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 05be34c10..d7434d518 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
@@ -1087,6 +1087,34 @@
       }
     })
   }
+
+  test("collect limit") {
+    Seq("true", "false").foreach(aqe => {
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) {
+        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
+          val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2")
+          assert(df.queryExecution.executedPlan.execute().getNumPartitions === 1)
+          checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec]))
+          assert(df.collect().length === 2)
+
+          val qe = df.queryExecution
+          // Make sure the root node is CometCollectLimitExec
+          assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec])
+          // Execute CometCollectLimitExec directly to check the doExecuteColumnar implementation
+          SQLExecution.withNewExecutionId(qe, Some("count")) {
+            qe.executedPlan.resetMetrics()
+            assert(qe.executedPlan.execute().count() === 2)
+          }
+
+          assert(df.isEmpty === false)
+
+          // A follow-up native operation is still possible
+          val df3 = df.groupBy("id").sum("value")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    })
+  }
 }
 
 case class BucketedTableTestSpec(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 2e523fa5a..0d7904c01 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -150,7 +150,15 @@ abstract class CometTestBase
   protected def checkSparkAnswerAndOperator(
       df: => DataFrame,
       excludedClasses: Class[_]*): Unit = {
+    checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*)
+  }
+
+  protected def checkSparkAnswerAndOperator(
+      df: => DataFrame,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): Unit = {
     checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*)
+    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*)
     checkSparkAnswer(df)
   }
 
@@ -173,6 +181,17 @@ abstract class CometTestBase
     }
   }
 
+  protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): Unit = {
+    includePlans.foreach { case planClass =>
+      if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) {
+        assert(
+          false,
+          s"Expected plan to contain ${planClass.getSimpleName}, but it did not.\n" +
+            s"plan: $plan")
+      }
+    }
+  }
+
   /**
    * Check the answer of a Comet SQL query with Spark result using absolute tolerance.
    */
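Note on the shim above: `ShimCometSparkSessionExtensions.getOffsetOpt` reads the `offset` field reflectively because `CollectLimitExec` only declares that field on newer Spark versions; on Spark 3.2/3.3 the lookup finds nothing and `getOffset` falls back to 0. The following minimal, standalone sketch illustrates the same technique outside of Spark. The `LimitWithOffset` and `LimitNoOffset` case classes are hypothetical stand-ins for Spark's limit nodes and are not part of the patch.

// Standalone sketch of the reflection-based lookup used in ShimCometSparkSessionExtensions.
// LimitWithOffset/LimitNoOffset are hypothetical stand-ins for plan nodes that do or do not
// declare an `offset` constructor parameter.
case class LimitWithOffset(limit: Int, offset: Int)
case class LimitNoOffset(limit: Int)

object OffsetShimSketch {
  // Find a field literally named "offset" and read it if it holds an Int.
  private def getOffsetOpt(plan: AnyRef): Option[Int] = plan.getClass.getDeclaredFields
    .filter(_.getName == "offset")
    .map { f => f.setAccessible(true); f.get(plan) }
    .filter(_.isInstanceOf[Int])
    .map(_.asInstanceOf[Int])
    .headOption

  // Mirrors getOffset: default to 0 when the field does not exist (older Spark versions).
  def getOffset(plan: AnyRef): Int = getOffsetOpt(plan).getOrElse(0)

  def main(args: Array[String]): Unit = {
    println(getOffset(LimitWithOffset(limit = 10, offset = 2))) // prints 2
    println(getOffset(LimitNoOffset(limit = 10)))               // prints 0
  }
}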