From 248cd03211d85acc884f8b98f9acf6f7b3289b4f Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Sat, 24 Feb 2024 02:42:09 +0800 Subject: [PATCH 1/4] feat: Support CollectLimit operator --- .../comet/CometSparkSessionExtensions.scala | 34 ++++++ .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../ShimCometSparkSessionExtensions.scala | 6 + .../spark/sql/comet/CometCoalesceExec.scala | 20 +--- .../sql/comet/CometCollectLimitExec.scala | 113 ++++++++++++++++++ .../spark/sql/comet/CometExecUtils.scala | 22 ++++ .../apache/comet/exec/CometExecSuite.scala | 30 ++++- .../org/apache/spark/sql/CometTestBase.scala | 19 +++ 8 files changed, 224 insertions(+), 21 deletions(-) create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index f2aba74a0..6a683e24d 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -65,6 +65,9 @@ class CometSparkSessionExtensions case class CometExecColumnar(session: SparkSession) extends ColumnarRule { override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session) + + override def postColumnarTransitions: Rule[SparkPlan] = + EliminateRedundantColumnarToRow(session) } case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] { @@ -278,6 +281,20 @@ class CometSparkSessionExtensions op } + case op: CollectLimitExec + if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit") + && isCometShuffleEnabled(conf) + && getOffset(op).getOrElse(0) == 0 => + QueryPlanSerde.operator2Proto(op) match { + case Some(nativeOp) => + val offset = getOffset(op).getOrElse(0) + val cometOp = + CometCollectLimitExec(op, op.limit, offset, op.child) + CometSinkPlaceHolder(nativeOp, op, cometOp) + case None => + op + } + case op: ExpandExec => val newOp = transform1(op) newOp match { @@ -451,6 +468,23 @@ class CometSparkSessionExtensions } } } + + // CometExec already wraps a `ColumnarToRowExec` for row-based operators. Therefore, + // `ColumnarToRowExec` is redundant and can be eliminated. + // + // It was added during ApplyColumnarRulesAndInsertTransitions' insertTransitions phase when Spark + // requests row-based output such as `collect` call. It's correct to add a redundant + // `ColumnarToRowExec` for `CometExec`. However, for certain operators such as + // `CometCollectLimitExec` which overrides `executeCollect`, the redundant `ColumnarToRowExec` + // makes the override ineffective. The purpose of this rule is to eliminate the redundant + // `ColumnarToRowExec` for such operators. 
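+ //
+ // For example, a plan fragment such as ColumnarToRowExec(CometCollectLimitExec(...)) is
+ // rewritten to the child CometCollectLimitExec(...), so its `executeCollect` override takes
+ // effect again.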
+ case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + plan.transform { case ColumnarToRowExec(child: CometCollectLimitExec) => + child + } + } + } } object CometSparkSessionExtensions extends Logging { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index fcc0ca9c5..46eb1b003 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1835,6 +1835,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case s if isCometScan(s) => true case _: CometSinkPlaceHolder => true case _: CoalesceExec => true + case _: CollectLimitExec => true case _: UnionExec => true case _: ShuffleExchangeExec => true case _: TakeOrderedAndProjectExec => true diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 8afed84ff..3ccac5456 100644 --- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.execution.LimitExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan trait ShimCometSparkSessionExtensions { @@ -32,4 +33,9 @@ trait ShimCometSparkSessionExtensions { .map { a => a.setAccessible(true); a } .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]]) .headOption + + def getOffset(limit: LimitExec): Option[Int] = limit.getClass.getDeclaredFields + .filter(_.getName == "offset") + .map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] } + .headOption } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala index fc4f90f89..532eb91a9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.comet -import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} @@ -42,7 +41,7 @@ case class CometCoalesceExec( if (numPartitions == 1 && rdd.getNumPartitions < 1) { // Make sure we don't output an RDD with 0 partitions, when claiming that we have a // `SinglePartition`. - new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions) + CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext) } else { rdd.coalesce(numPartitions, shuffle = false) } @@ -67,20 +66,3 @@ case class CometCoalesceExec( override def hashCode(): Int = Objects.hashCode(numPartitions: java.lang.Integer, child) } - -object CometCoalesceExec { - - /** A simple RDD with no data, but with the given number of partitions. 
*/ - class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int) - extends RDD[ColumnarBatch](sc, Nil) { - - override def getPartitions: Array[Partition] = - Array.tabulate(numPartitions)(i => EmptyPartition(i)) - - override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - Iterator.empty - } - } - - case class EmptyPartition(index: Int) extends Partition -} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala new file mode 100644 index 000000000..5727d0cce --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -0,0 +1,113 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the + * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.comet + +import java.util.Objects + +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} +import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Comet physical plan node for Spark `CollectExecNode`. + * + * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions seperated by a + * comet shuffle. 
+ */ +case class CometCollectLimitExec( + override val originalPlan: SparkPlan, + limit: Int, + offset: Int, + child: SparkPlan) + extends CometExec + with UnaryExecNode { + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "shuffleReadElapsedCompute" -> + SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"), + "numPartitions" -> SQLMetrics.createMetric( + sparkContext, + "number of partitions")) ++ readMetrics ++ writeMetrics + + private lazy val serializer: Serializer = + new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(child).executeTake(limit) + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val childRDD = child.executeColumnar() + if (childRDD.getNumPartitions == 0) { + CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext) + } else { + val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { + childRDD + } else { + val localLimitedRDD = if (limit >= 0) { + childRDD.mapPartitionsInternal { iter => + val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get + CometExec.getCometIterator(Seq(iter), limitOp) + } + } else { + childRDD + } + // Shuffle to Single Partition using Comet native shuffle + val dep = CometShuffleExchangeExec.prepareShuffleDependency( + localLimitedRDD, + child.output, + outputPartitioning, + serializer, + metrics) + metrics("numPartitions").set(dep.partitioner.numPartitions) + + new CometShuffledBatchRDD(dep, readMetrics) + } + + // todo: supports offset later + singlePartitionRDD.mapPartitionsInternal { iter => + val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get + CometExec.getCometIterator(Seq(iter), limitOp) + } + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + this.copy(child = newChild) + + override def stringArgs: Iterator[Any] = Iterator(limit, offset, child) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometCollectLimitExec => + this.limit == other.limit && this.offset == other.offset && + this.child == other.child + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child) +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index 9f8f2150c..b77fc5c60 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -21,8 +21,11 @@ package org.apache.spark.sql.comet import scala.collection.JavaConverters.asJavaIterableConverter +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.serde.OperatorOuterClass import org.apache.comet.serde.OperatorOuterClass.Operator @@ -30,6 +33,11 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} object 
CometExecUtils { + def createEmptyColumnarRDDWithSinglePartition( + sparkContext: SparkContext): RDD[ColumnarBatch] = { + new EmptyRDDWithPartitions(sparkContext, 1) + } + /** * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec. */ @@ -119,3 +127,17 @@ object CometExecUtils { } } } + +/** A simple RDD with no data, but with the given number of partitions. */ +private class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int) + extends RDD[ColumnarBatch](sc, Nil) { + + override def getPartitions: Array[Partition] = + Array.tabulate(numPartitions)(i => EmptyPartition(i)) + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + Iterator.empty + } +} + +private case class EmptyPartition(index: Int) extends Partition diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 29b6e120a..36e7db4b5 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec @@ -1073,6 +1073,32 @@ class CometExecSuite extends CometTestBase { } }) } + + test("collect limit") { + Seq("true", "false").foreach(aqe => { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) { + withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { + val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2") + assert(df.queryExecution.executedPlan.execute().getNumPartitions === 1) + checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec])) + assert(df.collect().length === 2) + + // checks CometCollectExec.doExecuteColumnar is indirectly called + val qe = df.queryExecution + SQLExecution.withNewExecutionId(qe, Some("count")) { + qe.executedPlan.resetMetrics() + assert(qe.executedPlan.execute().count() === 2) + } + + assert(df.isEmpty === false) + + // follow up native operation is possible + val df3 = df.groupBy("id").sum("value") + checkSparkAnswerAndOperator(df3) + } + } + }) + } } case class BucketedTableTestSpec( diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 2e523fa5a..e85fbcfba 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -150,7 +150,15 @@ abstract class CometTestBase protected def checkSparkAnswerAndOperator( df: => DataFrame, excludedClasses: Class[_]*): Unit = { + checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*) + } + + protected def checkSparkAnswerAndOperator( + df: => DataFrame, + includeClasses: Seq[Class[_]], + excludedClasses: Class[_]*): Unit = { checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*) + checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*) checkSparkAnswer(df) } @@ -173,6 +181,17 @@ abstract class CometTestBase } } + protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): Unit = { + includePlans.foreach { case planClass => + if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) { + assert( + false, + s"Expected plan to contain ${planClass.getSimpleName}.\n" + + s"plan: $plan") + } + } + } + /** * Check the answer of a Comet SQL query with Spark result using absolute tolerance. */ From afb651307fe7d582677873d13b456e43378f1d72 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 27 Feb 2024 23:04:49 +0800 Subject: [PATCH 2/4] Address comments --- .../comet/CometSparkSessionExtensions.scala | 4 +- .../ShimCometSparkSessionExtensions.scala | 17 +++++-- .../sql/comet/CometCollectLimitExec.scala | 45 +++++++++---------- .../spark/sql/comet/CometExecUtils.scala | 17 +++++++ .../CometTakeOrderedAndProjectExec.scala | 6 +-- .../apache/comet/exec/CometExecSuite.scala | 2 + 6 files changed, 58 insertions(+), 33 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 6a683e24d..4ee33c0f1 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -284,10 +284,10 @@ class CometSparkSessionExtensions case op: CollectLimitExec if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit") && isCometShuffleEnabled(conf) - && getOffset(op).getOrElse(0) == 0 => + && getOffset(op) == 0 => QueryPlanSerde.operator2Proto(op) match { case Some(nativeOp) => - val offset = getOffset(op).getOrElse(0) + val offset = getOffset(op) val cometOp = CometCollectLimitExec(op, op.limit, offset, op.child) CometSinkPlaceHolder(nativeOp, op, cometOp) diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 3ccac5456..85c6413e1 100644 --- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -20,10 +20,11 @@ package org.apache.comet.shims import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.execution.LimitExec +import org.apache.spark.sql.execution.{LimitExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan trait ShimCometSparkSessionExtensions { + import org.apache.comet.shims.ShimCometSparkSessionExtensions._ /** * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate @@ -34,8 +35,18 @@ trait ShimCometSparkSessionExtensions { .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]]) .headOption - def getOffset(limit: LimitExec): Option[Int] = 
limit.getClass.getDeclaredFields + /** + * TODO: delete after dropping Spark 3.2 and 3.3 support + */ + def getOffset(limit: LimitExec): Int = getOffsetOpt(limit).getOrElse(0) + +} + +object ShimCometSparkSessionExtensions { + private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields .filter(_.getName == "offset") - .map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] } + .map { a => a.setAccessible(true); a.get(plan) } + .filter(_.isInstanceOf[Int]) + .map(_.asInstanceOf[Int]) .headOption } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala index 5727d0cce..1afa2d6ba 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -1,17 +1,22 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more contributor license - * agreements. See the NOTICE file distributed with this work for additional information regarding - * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with the License. You may - * obtain a copy of the License at +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the - * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, - * either express or implied. See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ + package org.apache.spark.sql.comet import java.util.Objects @@ -25,10 +30,12 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleR import org.apache.spark.sql.vectorized.ColumnarBatch /** - * Comet physical plan node for Spark `CollectExecNode`. + * Comet physical plan node for Spark `CollectLimitExec`. * * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions seperated by a - * comet shuffle. + * Comet shuffle. 
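+ * The first native execution applies a local limit within each input partition; the second
+ * applies the final limit after the Comet shuffle into a single partition.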
+ * + * TODO: support offset semantics */ case class CometCollectLimitExec( override val originalPlan: SparkPlan, @@ -66,10 +73,7 @@ case class CometCollectLimitExec( childRDD } else { val localLimitedRDD = if (limit >= 0) { - childRDD.mapPartitionsInternal { iter => - val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) - } + CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit) } else { childRDD } @@ -84,12 +88,7 @@ case class CometCollectLimitExec( new CometShuffledBatchRDD(dep, readMetrics) } - - // todo: supports offset later - singlePartitionRDD.mapPartitionsInternal { iter => - val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) - } + CometExecUtils.toNativeLimitedPerPartition(singlePartitionRDD, output, limit) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index b77fc5c60..1856bc728 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -33,11 +33,28 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} object CometExecUtils { + /** + * Create an empty ColumnarBatch RDD with a single partition. + */ def createEmptyColumnarRDDWithSinglePartition( sparkContext: SparkContext): RDD[ColumnarBatch] = { new EmptyRDDWithPartitions(sparkContext, 1) } + /** + * Transform the given RDD into a new RDD that takes the first `limit` elements of each + * partition. The limit operation is performed on the native side. + */ + def toNativeLimitedPerPartition( + childPlan: RDD[ColumnarBatch], + outputAttribute: Seq[Attribute], + limit: Int): RDD[ColumnarBatch] = { + childPlan.mapPartitionsInternal { iter => + val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get + CometExec.getCometIterator(Seq(iter), limitOp) + } + } + /** * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec. 
*/ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 88984388f..e8597e8b5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -77,11 +77,7 @@ case class CometTakeOrderedAndProjectExec( childRDD } else { val localTopK = if (orderingSatisfies) { - childRDD.mapPartitionsInternal { iter => - val limitOp = - CometExecUtils.getLimitNativePlan(output, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) - } + CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit) } else { childRDD.mapPartitionsInternal { iter => val topK = diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 36e7db4b5..659f883d7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1085,6 +1085,8 @@ class CometExecSuite extends CometTestBase { // checks CometCollectExec.doExecuteColumnar is indirectly called val qe = df.queryExecution + // make sure the root node is CometCollectLimitExec + assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec]) SQLExecution.withNewExecutionId(qe, Some("count")) { qe.executedPlan.resetMetrics() assert(qe.executedPlan.execute().count() === 2) From e15852f3973cd56240ef8933b4f2fa6be5cf102f Mon Sep 17 00:00:00 2001 From: Xianjin Date: Wed, 28 Feb 2024 19:10:44 +0800 Subject: [PATCH 3/4] address review comments --- .../comet/CometSparkSessionExtensions.scala | 7 +++++-- .../spark/sql/comet/CometCoalesceExec.scala | 2 +- .../sql/comet/CometCollectLimitExec.scala | 8 ++++---- .../spark/sql/comet/CometExecUtils.scala | 20 +++++++++++-------- .../CometTakeOrderedAndProjectExec.scala | 2 +- .../apache/comet/exec/CometExecSuite.scala | 2 +- .../org/apache/spark/sql/CometTestBase.scala | 2 +- 7 files changed, 25 insertions(+), 18 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 4ee33c0f1..f88081387 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -480,8 +480,11 @@ class CometSparkSessionExtensions // `ColumnarToRowExec` for such operators. case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = { - plan.transform { case ColumnarToRowExec(child: CometCollectLimitExec) => - child + plan match { + case ColumnarToRowExec(child: CometCollectLimitExec) => + child + case other => + other } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala index 532eb91a9..cc635d739 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala @@ -41,7 +41,7 @@ case class CometCoalesceExec( if (numPartitions == 1 && rdd.getNumPartitions < 1) { // Make sure we don't output an RDD with 0 partitions, when claiming that we have a // `SinglePartition`. 
- CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext) + CometExecUtils.emptyRDDWithPartitions(sparkContext, 1) } else { rdd.coalesce(numPartitions, shuffle = false) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala index 1afa2d6ba..83126a7ba 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -67,17 +67,17 @@ case class CometCollectLimitExec( protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { val childRDD = child.executeColumnar() if (childRDD.getNumPartitions == 0) { - CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext) + CometExecUtils.emptyRDDWithPartitions(sparkContext, 1) } else { val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { childRDD } else { val localLimitedRDD = if (limit >= 0) { - CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit) + CometExecUtils.getNativeLimitRDD(childRDD, output, limit) } else { childRDD } - // Shuffle to Single Partition using Comet native shuffle + // Shuffle to Single Partition using Comet shuffle val dep = CometShuffleExchangeExec.prepareShuffleDependency( localLimitedRDD, child.output, @@ -88,7 +88,7 @@ case class CometCollectLimitExec( new CometShuffledBatchRDD(dep, readMetrics) } - CometExecUtils.toNativeLimitedPerPartition(singlePartitionRDD, output, limit) + CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index 1856bc728..5931920a2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.comet import scala.collection.JavaConverters.asJavaIterableConverter +import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD @@ -34,18 +35,19 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} object CometExecUtils { /** - * Create an empty ColumnarBatch RDD with a single partition. + * Create an empty RDD with the given number of partitions. */ - def createEmptyColumnarRDDWithSinglePartition( - sparkContext: SparkContext): RDD[ColumnarBatch] = { - new EmptyRDDWithPartitions(sparkContext, 1) + def emptyRDDWithPartitions[T: ClassTag]( + sparkContext: SparkContext, + numPartitions: Int): RDD[T] = { + new EmptyRDDWithPartitions(sparkContext, numPartitions) } /** * Transform the given RDD into a new RDD that takes the first `limit` elements of each * partition. The limit operation is performed on the native side. */ - def toNativeLimitedPerPartition( + def getNativeLimitRDD( childPlan: RDD[ColumnarBatch], outputAttribute: Seq[Attribute], limit: Int): RDD[ColumnarBatch] = { @@ -146,13 +148,15 @@ object CometExecUtils { } /** A simple RDD with no data, but with the given number of partitions. 
*/ -private class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int) - extends RDD[ColumnarBatch](sc, Nil) { +private class EmptyRDDWithPartitions[T: ClassTag]( + @transient private val sc: SparkContext, + numPartitions: Int) + extends RDD[T](sc, Nil) { override def getPartitions: Array[Partition] = Array.tabulate(numPartitions)(i => EmptyPartition(i)) - override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + override def compute(split: Partition, context: TaskContext): Iterator[T] = { Iterator.empty } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index e8597e8b5..26ec401ed 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -77,7 +77,7 @@ case class CometTakeOrderedAndProjectExec( childRDD } else { val localTopK = if (orderingSatisfies) { - CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit) + CometExecUtils.getNativeLimitRDD(childRDD, output, limit) } else { childRDD.mapPartitionsInternal { iter => val topK = diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 659f883d7..cf184528a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1083,10 +1083,10 @@ class CometExecSuite extends CometTestBase { checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec])) assert(df.collect().length === 2) - // checks CometCollectExec.doExecuteColumnar is indirectly called val qe = df.queryExecution // make sure the root node is CometCollectLimitExec assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec]) + // executes CometCollectExec directly to check doExecuteColumnar implementation SQLExecution.withNewExecutionId(qe, Some("count")) { qe.executedPlan.resetMetrics() assert(qe.executedPlan.execute().count() === 2) diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index e85fbcfba..0d7904c01 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -186,7 +186,7 @@ abstract class CometTestBase if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) { assert( false, - s"Expected plan to contain ${planClass.getSimpleName}.\n" + + s"Expected plan to contain ${planClass.getSimpleName}, but not.\n" + s"plan: $plan") } } From f1c77184868e9ffe324c062fa3e9cb5c0475ff0d Mon Sep 17 00:00:00 2001 From: Xianjin Date: Wed, 28 Feb 2024 20:04:45 +0800 Subject: [PATCH 4/4] spotless:apply --- .../src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index d4049d4fe..d7434d518 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -22,8 +22,10 @@ package org.apache.comet.exec import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random + import 
org.scalactic.source.Position import org.scalatest.Tag + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame, DataFrameWriter, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier @@ -40,6 +42,7 @@ import org.apache.spark.sql.functions.{date_add, expr, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.unsafe.types.UTF8String + import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
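A minimal way to exercise the new operator end to end (illustrative only, not part of the patches; the config keys below are assumed to follow Comet's spark.comet.exec.<operator>.enabled naming, and the Parquet path is a placeholder):

    // e.g. in spark-shell started with the Comet extensions registered
    spark.conf.set("spark.comet.enabled", "true")
    spark.conf.set("spark.comet.exec.enabled", "true")
    spark.conf.set("spark.comet.exec.shuffle.enabled", "true")
    spark.conf.set("spark.comet.exec.collectLimit.enabled", "true")

    // Write and read a small Parquet table so the child of the limit is a Comet-native scan
    spark.range(0, 5).toDF("id").write.mode("overwrite").parquet("/tmp/comet_collect_limit")
    val df = spark.read.parquet("/tmp/comet_collect_limit").limit(2)

    // With the rules above, the root of the executed plan should be CometCollectLimitExec,
    // so collect() is served by its executeCollect override
    println(df.queryExecution.executedPlan)
    assert(df.collect().length == 2)

Note that the operator is only planned when the child plan is Comet-native and Comet shuffle is enabled (and the limit carries no offset), mirroring the conditions checked in CometSparkSessionExtensions above.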