feat: Support CollectLimit operator (#100)
advancedxy authored Feb 28, 2024
1 parent aac4bfd commit 313111d
Showing 9 changed files with 261 additions and 26 deletions.
@@ -65,6 +65,9 @@ class CometSparkSessionExtensions

case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)

override def postColumnarTransitions: Rule[SparkPlan] =
EliminateRedundantColumnarToRow(session)
}

case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
@@ -284,6 +287,20 @@ class CometSparkSessionExtensions
op
}

case op: CollectLimitExec
if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
&& isCometShuffleEnabled(conf)
&& getOffset(op) == 0 =>
QueryPlanSerde.operator2Proto(op) match {
case Some(nativeOp) =>
val offset = getOffset(op)
val cometOp =
CometCollectLimitExec(op, op.limit, offset, op.child)
CometSinkPlaceHolder(nativeOp, op, cometOp)
case None =>
op
}

case op: ExpandExec =>
val newOp = transform1(op)
newOp match {
@@ -457,6 +474,26 @@ class CometSparkSessionExtensions
}
}
}

// CometExec already wraps a `ColumnarToRowExec` for row-based operators, so an extra
// `ColumnarToRowExec` on top of it is redundant and can be eliminated.
//
// The redundant node is added during ApplyColumnarRulesAndInsertTransitions' insertTransitions
// phase when Spark requests row-based output, e.g. for a `collect` call. That is normally
// harmless for a `CometExec`. However, for operators such as `CometCollectLimitExec`, which
// override `executeCollect`, the redundant `ColumnarToRowExec` makes the override ineffective.
// This rule eliminates the redundant `ColumnarToRowExec` for such operators.
case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
plan match {
case ColumnarToRowExec(child: CometCollectLimitExec) =>
child
case other =>
other
}
}
}
}
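The rule above is wired in through Spark's generic `ColumnarRule` hook: `preColumnarTransitions` runs before Spark inserts row/columnar transitions, and `postColumnarTransitions` runs afterwards, which is why the redundant `ColumnarToRowExec` on top of `CometCollectLimitExec` can only be stripped in the post phase. The sketch below shows how such a rule pair is registered through Spark's public `SparkSessionExtensions.injectColumnar` API; the names `MyExtensions`, `MyColumnarRule`, `MyPreRule`, and `MyPostRule` are illustrative placeholders, not Comet's actual entry points.

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}

// Illustrative no-op rules standing in for CometExecRule / EliminateRedundantColumnarToRow.
case class MyPreRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}
case class MyPostRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}

// A ColumnarRule pair: one rule before transitions are inserted, one after.
case class MyColumnarRule(session: SparkSession) extends ColumnarRule {
  override def preColumnarTransitions: Rule[SparkPlan] = MyPreRule(session)
  override def postColumnarTransitions: Rule[SparkPlan] = MyPostRule(session)
}

// Registered via the standard extensions mechanism,
// e.g. spark.sql.extensions=<fully.qualified.MyExtensions>
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(ext: SparkSessionExtensions): Unit =
    ext.injectColumnar(session => MyColumnarRule(session))
}

With the post-transition rule in place, a plan shaped like ColumnarToRowExec(CometCollectLimitExec(...)) collapses to CometCollectLimitExec(...), so its `executeCollect` override is actually reached.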

object CometSparkSessionExtensions extends Logging {
@@ -1835,6 +1835,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case s if isCometScan(s) => true
case _: CometSinkPlaceHolder => true
case _: CoalesceExec => true
case _: CollectLimitExec => true
case _: UnionExec => true
case _: ShuffleExchangeExec => true
case _: TakeOrderedAndProjectExec => true
@@ -20,9 +20,11 @@
package org.apache.comet.shims

import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.execution.{LimitExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan

trait ShimCometSparkSessionExtensions {
import org.apache.comet.shims.ShimCometSparkSessionExtensions._

/**
* TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate
@@ -32,4 +34,19 @@ trait ShimCometSparkSessionExtensions {
.map { a => a.setAccessible(true); a }
.flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
.headOption

/**
* TODO: delete after dropping Spark 3.2 and 3.3 support
*/
def getOffset(limit: LimitExec): Int = getOffsetOpt(limit).getOrElse(0)

}

object ShimCometSparkSessionExtensions {
private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields
.filter(_.getName == "offset")
.map { a => a.setAccessible(true); a.get(plan) }
.filter(_.isInstanceOf[Int])
.map(_.asInstanceOf[Int])
.headOption
}
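The `offset` field only exists on Spark's limit nodes in newer Spark versions, so the shim reads it reflectively instead of referencing it at compile time and falls back to 0 when the field is absent. Below is a minimal, self-contained illustration of the same reflection pattern; `LimitWithOffset`, `LimitWithoutOffset`, and `offsetOf` are hypothetical stand-ins, not Spark or Comet classes.

object OffsetShimSketch extends App {
  // Hypothetical stand-ins for a limit node with and without an `offset` field.
  case class LimitWithOffset(limit: Int, offset: Int)
  case class LimitWithoutOffset(limit: Int)

  // Reflectively look up an Int field named "offset", defaulting to 0 when it is absent.
  def offsetOf(node: Any): Int = node.getClass.getDeclaredFields
    .filter(_.getName == "offset")
    .map { f => f.setAccessible(true); f.get(node) }
    .collectFirst { case i: Int => i }
    .getOrElse(0)

  assert(offsetOf(LimitWithOffset(10, 3)) == 3)  // field present: read it
  assert(offsetOf(LimitWithoutOffset(10)) == 0)  // field absent: default to 0
}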
@@ -19,7 +19,6 @@

package org.apache.spark.sql.comet

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
@@ -42,7 +41,7 @@ case class CometCoalesceExec(
if (numPartitions == 1 && rdd.getNumPartitions < 1) {
// Make sure we don't output an RDD with 0 partitions, when claiming that we have a
// `SinglePartition`.
new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
CometExecUtils.emptyRDDWithPartitions(sparkContext, 1)
} else {
rdd.coalesce(numPartitions, shuffle = false)
}
@@ -67,20 +66,3 @@ case class CometCoalesceExec(

override def hashCode(): Int = Objects.hashCode(numPartitions: java.lang.Integer, child)
}

object CometCoalesceExec {

/** A simple RDD with no data, but with the given number of partitions. */
class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
extends RDD[ColumnarBatch](sc, Nil) {

override def getPartitions: Array[Partition] =
Array.tabulate(numPartitions)(i => EmptyPartition(i))

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
Iterator.empty
}
}

case class EmptyPartition(index: Int) extends Partition
}
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet

import java.util.Objects

import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* Comet physical plan node for Spark `CollectLimitExec`.
*
* Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions separated by a
* Comet shuffle.
*
* TODO: support offset semantics
*/
case class CometCollectLimitExec(
override val originalPlan: SparkPlan,
limit: Int,
offset: Int,
child: SparkPlan)
extends CometExec
with UnaryExecNode {

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics: Map[String, SQLMetric] = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"shuffleReadElapsedCompute" ->
SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
"numPartitions" -> SQLMetrics.createMetric(
sparkContext,
"number of partitions")) ++ readMetrics ++ writeMetrics

private lazy val serializer: Serializer =
new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))

override def executeCollect(): Array[InternalRow] = {
ColumnarToRowExec(child).executeTake(limit)
}

protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val childRDD = child.executeColumnar()
if (childRDD.getNumPartitions == 0) {
CometExecUtils.emptyRDDWithPartitions(sparkContext, 1)
} else {
val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
childRDD
} else {
val localLimitedRDD = if (limit >= 0) {
CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
} else {
childRDD
}
// Shuffle to a single partition using the Comet shuffle
val dep = CometShuffleExchangeExec.prepareShuffleDependency(
localLimitedRDD,
child.output,
outputPartitioning,
serializer,
metrics)
metrics("numPartitions").set(dep.partitioner.numPartitions)

new CometShuffledBatchRDD(dep, readMetrics)
}
CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit)
}
}

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

override def stringArgs: Iterator[Any] = Iterator(limit, offset, child)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometCollectLimitExec =>
this.limit == other.limit && this.offset == other.offset &&
this.child == other.child
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child)
}
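The columnar path above follows the usual two-phase limit shape: limit each partition locally, shuffle the survivors into a single partition, then apply the limit once more globally. The sketch below reproduces that shape with plain RDD operations, where `mapPartitions(_.take(limit))` and `repartition(1)` stand in for Comet's native limit operator and Comet shuffle; it assumes `limit >= 0` and is an illustration of the pattern, not Comet's implementation.

import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag

// Two-phase limit: per-partition local limit, collapse to one partition, then a final
// global limit over that single partition.
def collectLimitSketch[T: ClassTag](rdd: RDD[T], limit: Int): RDD[T] = {
  val locallyLimited = rdd.mapPartitions(_.take(limit))  // local limit, no shuffle
  val singlePartition =                                   // shuffle survivors to one partition
    if (locallyLimited.getNumPartitions <= 1) locallyLimited
    else locallyLimited.repartition(1)
  singlePartition.mapPartitions(_.take(limit))            // global limit
}

The local limit keeps the shuffle small (at most `limit` rows per input partition cross the wire), and the global limit on the single partition enforces the exact row count.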
@@ -20,16 +20,43 @@
package org.apache.spark.sql.comet

import scala.collection.JavaConverters.asJavaIterableConverter
import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}

object CometExecUtils {

/**
* Create an empty RDD with the given number of partitions.
*/
def emptyRDDWithPartitions[T: ClassTag](
sparkContext: SparkContext,
numPartitions: Int): RDD[T] = {
new EmptyRDDWithPartitions(sparkContext, numPartitions)
}

/**
* Transform the given RDD into a new RDD that takes the first `limit` elements of each
* partition. The limit operation is performed on the native side.
*/
def getNativeLimitRDD(
childPlan: RDD[ColumnarBatch],
outputAttribute: Seq[Attribute],
limit: Int): RDD[ColumnarBatch] = {
childPlan.mapPartitionsInternal { iter =>
val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
CometExec.getCometIterator(Seq(iter), limitOp)
}
}

/**
* Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
*/
@@ -119,3 +146,19 @@ object CometExecUtils {
}
}
}

/** A simple RDD with no data, but with the given number of partitions. */
private class EmptyRDDWithPartitions[T: ClassTag](
@transient private val sc: SparkContext,
numPartitions: Int)
extends RDD[T](sc, Nil) {

override def getPartitions: Array[Partition] =
Array.tabulate(numPartitions)(i => EmptyPartition(i))

override def compute(split: Partition, context: TaskContext): Iterator[T] = {
Iterator.empty
}
}

private case class EmptyPartition(index: Int) extends Partition
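The generic `emptyRDDWithPartitions` helper replaces the `ColumnarBatch`-only empty RDD that previously lived in `CometCoalesceExec`, letting both `CometCoalesceExec` and `CometCollectLimitExec` claim a `SinglePartition` output without producing a zero-partition RDD. A quick usage sketch, assuming an active `SparkContext` named `sc`:

import org.apache.spark.sql.vectorized.ColumnarBatch

// One partition is reported, but no rows are ever produced.
val empty = CometExecUtils.emptyRDDWithPartitions[ColumnarBatch](sc, 1)
assert(empty.getNumPartitions == 1)
assert(empty.isEmpty())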
@@ -77,11 +77,7 @@ case class CometTakeOrderedAndProjectExec(
childRDD
} else {
val localTopK = if (orderingSatisfies) {
childRDD.mapPartitionsInternal { iter =>
val limitOp =
CometExecUtils.getLimitNativePlan(output, limit).get
CometExec.getCometIterator(Seq(iter), limitOp)
}
CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
} else {
childRDD.mapPartitionsInternal { iter =>
val topK =
32 changes: 30 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
@@ -1087,6 +1087,34 @@ class CometExecSuite extends CometTestBase {
}
})
}

test("collect limit") {
Seq("true", "false").foreach(aqe => {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2")
assert(df.queryExecution.executedPlan.execute().getNumPartitions === 1)
checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec]))
assert(df.collect().length === 2)

val qe = df.queryExecution
// make sure the root node is CometCollectLimitExec
assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec])
// execute CometCollectLimitExec directly to exercise its doExecuteColumnar implementation
SQLExecution.withNewExecutionId(qe, Some("count")) {
qe.executedPlan.resetMetrics()
assert(qe.executedPlan.execute().count() === 2)
}

assert(df.isEmpty === false)

// a follow-up native operation is still possible
val df3 = df.groupBy("id").sum("value")
checkSparkAnswerAndOperator(df3)
}
}
})
}
}

case class BucketedTableTestSpec(