Commit afb6513: Address comments

advancedxy committed Feb 27, 2024
1 parent 248cd03 commit afb6513
Showing 6 changed files with 58 additions and 33 deletions.
CometSparkSessionExtensions.scala
@@ -284,10 +284,10 @@ class CometSparkSessionExtensions
       case op: CollectLimitExec
           if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
             && isCometShuffleEnabled(conf)
-            && getOffset(op).getOrElse(0) == 0 =>
+            && getOffset(op) == 0 =>
         QueryPlanSerde.operator2Proto(op) match {
           case Some(nativeOp) =>
-            val offset = getOffset(op).getOrElse(0)
+            val offset = getOffset(op)
             val cometOp =
               CometCollectLimitExec(op, op.limit, offset, op.child)
             CometSinkPlaceHolder(nativeOp, op, cometOp)
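The guard and the body both use the simplified getOffset, which now returns an Int directly instead of an Option. A minimal, self-contained sketch of the rewrite pattern, using toy plan nodes (CollectLimit, NativeCollectLimit, and LimitRule are illustrative stand-ins, not Comet's real API):

    // Toy model of the rule above: match a collect-limit node whose offset is 0
    // and replace it with a "native" placeholder. Not Comet's actual classes.
    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class CollectLimit(limit: Int, offset: Int, child: Plan) extends Plan
    case class NativeCollectLimit(limit: Int, offset: Int, child: Plan) extends Plan

    object LimitRule {
      // Mirrors the new shim contract: a plain Int, 0 when the running Spark
      // version has no offset field.
      def getOffset(op: CollectLimit): Int = op.offset

      def rewrite(plan: Plan): Plan = plan match {
        // Offset semantics are not supported yet, so only offset == 0 is rewritten.
        case op: CollectLimit if getOffset(op) == 0 =>
          NativeCollectLimit(op.limit, getOffset(op), op.child)
        case other => other
      }

      def main(args: Array[String]): Unit = {
        println(rewrite(CollectLimit(10, 0, Scan("t")))) // rewritten to native
        println(rewrite(CollectLimit(10, 5, Scan("t")))) // left as-is
      }
    }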
ShimCometSparkSessionExtensions.scala
@@ -20,10 +20,11 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.execution.LimitExec
+import org.apache.spark.sql.execution.{LimitExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 
 trait ShimCometSparkSessionExtensions {
+  import org.apache.comet.shims.ShimCometSparkSessionExtensions._
 
   /**
    * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate
@@ -34,8 +35,18 @@ trait ShimCometSparkSessionExtensions {
     .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
     .headOption
 
-  def getOffset(limit: LimitExec): Option[Int] = limit.getClass.getDeclaredFields
+  /**
+   * TODO: delete after dropping Spark 3.2 and 3.3 support
+   */
+  def getOffset(limit: LimitExec): Int = getOffsetOpt(limit).getOrElse(0)
+
+}
+
+object ShimCometSparkSessionExtensions {
+  private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields
     .filter(_.getName == "offset")
-    .map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] }
+    .map { a => a.setAccessible(true); a.get(plan) }
+    .filter(_.isInstanceOf[Int])
+    .map(_.asInstanceOf[Int])
     .headOption
 }
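The shim reads the offset field reflectively because CollectLimitExec only gained an offset in newer Spark versions. A standalone sketch of that technique, with toy classes standing in for SparkPlan (NewerPlan and OlderPlan are illustrative, not Spark types):

    // Reflection-based version shim, as in getOffsetOpt above: look for a field
    // named "offset", read it if present and Int-typed, otherwise return None.
    object OffsetShim {
      class NewerPlan { val offset: Int = 5 } // has an `offset` field
      class OlderPlan                         // predates the field

      def getOffsetOpt(plan: AnyRef): Option[Int] = plan.getClass.getDeclaredFields
        .filter(_.getName == "offset")
        .map { f => f.setAccessible(true); f.get(plan) }
        .filter(_.isInstanceOf[Int]) // a boxed java.lang.Integer matches Int here
        .map(_.asInstanceOf[Int])
        .headOption

      def main(args: Array[String]): Unit = {
        println(getOffsetOpt(new NewerPlan)) // Some(5)
        println(getOffsetOpt(new OlderPlan)) // None
      }
    }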
CometCollectLimitExec.scala
@@ -1,17 +1,22 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
- * agreements. See the NOTICE file distributed with this work for additional information regarding
- * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with the License. You may
- * obtain a copy of the License at
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
  *
- * http://www.apache.org/licenses/LICENSE-2.0
+ *   http://www.apache.org/licenses/LICENSE-2.0
  *
- * Unless required by applicable law or agreed to in writing, software distributed under the
- * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
- * either express or implied. See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
  */
 
 package org.apache.spark.sql.comet
 
+import java.util.Objects
@@ -25,10 +30,12 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleR
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
- * Comet physical plan node for Spark `CollectExecNode`.
+ * Comet physical plan node for Spark `CollectLimitExec`.
  *
  * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions seperated by a
- * comet shuffle.
+ * Comet shuffle.
+ *
+ * TODO: support offset semantics
  */
 case class CometCollectLimitExec(
     override val originalPlan: SparkPlan,
@@ -66,10 +73,7 @@ case class CometCollectLimitExec(
       childRDD
     } else {
       val localLimitedRDD = if (limit >= 0) {
-        childRDD.mapPartitionsInternal { iter =>
-          val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
-          CometExec.getCometIterator(Seq(iter), limitOp)
-        }
+        CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit)
       } else {
         childRDD
       }
@@ -84,12 +88,7 @@ case class CometCollectLimitExec(

       new CometShuffledBatchRDD(dep, readMetrics)
     }
-
-    // todo: supports offset later
-    singlePartitionRDD.mapPartitionsInternal { iter =>
-      val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
-      CometExec.getCometIterator(Seq(iter), limitOp)
-    }
+    CometExecUtils.toNativeLimitedPerPartition(singlePartitionRDD, output, limit)
   }
 }
 
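For intuition, the node's two-phase strategy (per-partition native limit, Comet shuffle into one partition, final native limit) can be modeled with plain collections. This sketch only illustrates the data flow, with Seq standing in for the RDD and shuffle machinery:

    // Two-phase collect-limit over "partitions", mirroring doExecuteColumnar:
    // phase 1 limits each partition locally, phase 2 gathers the survivors into
    // a single partition (the shuffle), phase 3 applies the final limit.
    object CollectLimitModel {
      def collectLimit[T](partitions: Seq[Seq[T]], limit: Int): Seq[T] = {
        require(limit >= 0, "model assumes a non-negative limit")
        val locallyLimited = partitions.map(_.take(limit)) // phase 1: local limit
        val singlePartition = locallyLimited.flatten       // phase 2: "shuffle" to one partition
        singlePartition.take(limit)                        // phase 3: final limit
      }

      def main(args: Array[String]): Unit = {
        // 2 rows survive per partition; the final take keeps 2 overall.
        println(collectLimit(Seq(Seq(1, 2, 3), Seq(4, 5, 6)), limit = 2)) // List(1, 2)
      }
    }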
CometExecUtils.scala
@@ -33,11 +33,28 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}

 object CometExecUtils {
 
+  /**
+   * Create an empty ColumnarBatch RDD with a single partition.
+   */
   def createEmptyColumnarRDDWithSinglePartition(
       sparkContext: SparkContext): RDD[ColumnarBatch] = {
     new EmptyRDDWithPartitions(sparkContext, 1)
   }
 
+  /**
+   * Transform the given RDD into a new RDD that takes the first `limit` elements of each
+   * partition. The limit operation is performed on the native side.
+   */
+  def toNativeLimitedPerPartition(
+      childPlan: RDD[ColumnarBatch],
+      outputAttribute: Seq[Attribute],
+      limit: Int): RDD[ColumnarBatch] = {
+    childPlan.mapPartitionsInternal { iter =>
+      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
+      CometExec.getCometIterator(Seq(iter), limitOp)
+    }
+  }
+
   /**
    * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
    */
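The new helper exists to deduplicate the identical mapPartitions-plus-native-limit block that CometCollectLimitExec and CometTakeOrderedAndProjectExec previously inlined. A self-contained sketch of that factoring, where Batch, NativePlan, and runNative are illustrative stand-ins for ColumnarBatch, the serialized native operator, and the Comet iterator:

    // Shape of the refactor: each "partition" iterator gets wrapped in the same
    // native limit plan by one shared helper.
    object PerPartitionLimit {
      type Batch = Vector[Int]
      case class NativePlan(limit: Int)

      // Stand-in for CometExec.getCometIterator: applies the native plan.
      def runNative(iter: Iterator[Batch], plan: NativePlan): Iterator[Batch] =
        iter.take(plan.limit)

      def toNativeLimitedPerPartition(
          partitions: Seq[Iterator[Batch]],
          limit: Int): Seq[Iterator[Batch]] =
        partitions.map(iter => runNative(iter, NativePlan(limit)))

      def main(args: Array[String]): Unit = {
        val parts = Seq(Iterator(Vector(1), Vector(2), Vector(3)), Iterator(Vector(4)))
        println(toNativeLimitedPerPartition(parts, limit = 2).map(_.toList))
        // List(List(Vector(1), Vector(2)), List(Vector(4)))
      }
    }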
CometTakeOrderedAndProjectExec.scala
@@ -77,11 +77,7 @@ case class CometTakeOrderedAndProjectExec(
       childRDD
     } else {
       val localTopK = if (orderingSatisfies) {
-        childRDD.mapPartitionsInternal { iter =>
-          val limitOp =
-            CometExecUtils.getLimitNativePlan(output, limit).get
-          CometExec.getCometIterator(Seq(iter), limitOp)
-        }
+        CometExecUtils.toNativeLimitedPerPartition(childRDD, output, limit)
       } else {
         childRDD.mapPartitionsInternal { iter =>
           val topK =
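The branch replaced here is the ordering-already-satisfied case, where a plain per-partition limit suffices; the other branch still needs a local top-K. A collections-based sketch of that distinction (illustrative only, not Comet's API):

    // When each partition is already sorted in the required order, the local
    // phase degenerates to a plain limit; otherwise sort locally, then take K.
    object LocalTopKModel {
      def localTopK[T](
          partitions: Seq[Seq[T]],
          limit: Int,
          orderingSatisfies: Boolean)(implicit ord: Ordering[T]): Seq[Seq[T]] =
        if (orderingSatisfies) partitions.map(_.take(limit))
        else partitions.map(_.sorted.take(limit))

      def main(args: Array[String]): Unit = {
        println(localTopK(Seq(Seq(3, 1, 2)), limit = 2, orderingSatisfies = false))
        // List(List(1, 2))
      }
    }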
CometExecSuite.scala
@@ -1085,6 +1085,8 @@ class CometExecSuite extends CometTestBase {

       // checks CometCollectExec.doExecuteColumnar is indirectly called
       val qe = df.queryExecution
+      // make sure the root node is CometCollectLimitExec
+      assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec])
       SQLExecution.withNewExecutionId(qe, Some("count")) {
         qe.executedPlan.resetMetrics()
         assert(qe.executedPlan.execute().count() === 2)
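The added assertion pins down the plan shape before the metrics check, so a silent fallback to Spark's own CollectLimitExec would fail immediately instead of letting the later checks pass against the wrong operator. A generic, self-contained sketch of this fail-fast pattern, using toy plan types rather than the suite's real fixtures:

    // Assert the expected operator is at the root before testing its behavior.
    object PlanShapeCheck {
      sealed trait Plan
      case class CometCollectLimit(limit: Int) extends Plan
      case class SparkCollectLimit(limit: Int) extends Plan

      def main(args: Array[String]): Unit = {
        val executedPlan: Plan = CometCollectLimit(2)
        assert(
          executedPlan.isInstanceOf[CometCollectLimit],
          s"expected CometCollectLimit at the root, got $executedPlan")
        println("plan shape ok")
      }
    }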
