diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 6a683e24d..4ee33c0f1 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -284,10 +284,10 @@ class CometSparkSessionExtensions case op: CollectLimitExec if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit") && isCometShuffleEnabled(conf) - && getOffset(op).getOrElse(0) == 0 => + && getOffset(op) == 0 => QueryPlanSerde.operator2Proto(op) match { case Some(nativeOp) => - val offset = getOffset(op).getOrElse(0) + val offset = getOffset(op) val cometOp = CometCollectLimitExec(op, op.limit, offset, op.child) CometSinkPlaceHolder(nativeOp, op, cometOp) diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 3ccac5456..85c6413e1 100644 --- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -20,10 +20,11 @@ package org.apache.comet.shims import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.execution.LimitExec +import org.apache.spark.sql.execution.{LimitExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan trait ShimCometSparkSessionExtensions { + import org.apache.comet.shims.ShimCometSparkSessionExtensions._ /** * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate @@ -34,8 +35,18 @@ trait ShimCometSparkSessionExtensions { .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]]) .headOption - def getOffset(limit: LimitExec): Option[Int] = limit.getClass.getDeclaredFields + /** + * TODO: delete after dropping Spark 3.2 and 3.3 support + */ + def getOffset(limit: LimitExec): Int = getOffsetOpt(limit).getOrElse(0) + +} + +object ShimCometSparkSessionExtensions { + private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields .filter(_.getName == "offset") - .map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] } + .map { a => a.setAccessible(true); a.get(plan) } + .filter(_.isInstanceOf[Int]) + .map(_.asInstanceOf[Int]) .headOption } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala index 5727d0cce..1afa2d6ba 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -1,17 +1,22 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more contributor license - * agreements. See the NOTICE file distributed with this work for additional information regarding - * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with the License. You may - * obtain a copy of the License at +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the - * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, - * either express or implied. See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ + package org.apache.spark.sql.comet import java.util.Objects @@ -25,10 +30,12 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleR import org.apache.spark.sql.vectorized.ColumnarBatch /** - * Comet physical plan node for Spark `CollectExecNode`. + * Comet physical plan node for Spark `CollectLimitExec`. * * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions seperated by a - * comet shuffle. + * Comet shuffle. + * + * TODO: support offset semantics */ case class CometCollectLimitExec( override val originalPlan: SparkPlan, @@ -66,10 +73,7 @@ case class CometCollectLimitExec( childRDD } else { val localLimitedRDD = if (limit >= 0) { - childRDD.mapPartitionsInternal { iter => - val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) - } + CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit) } else { childRDD } @@ -84,12 +88,7 @@ case class CometCollectLimitExec( new CometShuffledBatchRDD(dep, readMetrics) } - - // todo: supports offset later - singlePartitionRDD.mapPartitionsInternal { iter => - val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) - } + CometExecUtils.toNativeLimitedPerPartition(singlePartitionRDD, output, limit) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index b77fc5c60..1856bc728 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -33,11 +33,28 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} object CometExecUtils { + /** + * Create an empty ColumnarBatch RDD with a single partition. + */ def createEmptyColumnarRDDWithSinglePartition( sparkContext: SparkContext): RDD[ColumnarBatch] = { new EmptyRDDWithPartitions(sparkContext, 1) } + /** + * Transform the given RDD into a new RDD that takes the first `limit` elements of each + * partition. The limit operation is performed on the native side. + */ + def toNativeLimitedPerPartition( + childPlan: RDD[ColumnarBatch], + outputAttribute: Seq[Attribute], + limit: Int): RDD[ColumnarBatch] = { + childPlan.mapPartitionsInternal { iter => + val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get + CometExec.getCometIterator(Seq(iter), limitOp) + } + } + /** * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec. */ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 88984388f..e8597e8b5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -77,11 +77,7 @@ case class CometTakeOrderedAndProjectExec( childRDD } else { val localTopK = if (orderingSatisfies) { - childRDD.mapPartitionsInternal { iter => - val limitOp = - CometExecUtils.getLimitNativePlan(output, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) - } + CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit) } else { childRDD.mapPartitionsInternal { iter => val topK = diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 36e7db4b5..659f883d7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1085,6 +1085,8 @@ class CometExecSuite extends CometTestBase { // checks CometCollectExec.doExecuteColumnar is indirectly called val qe = df.queryExecution + // make sure the root node is CometCollectLimitExec + assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec]) SQLExecution.withNewExecutionId(qe, Some("count")) { qe.executedPlan.resetMetrics() assert(qe.executedPlan.execute().count() === 2)