Skip to content

Commit

Permalink
feat: Support CollectLimit operator
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Feb 27, 2024
1 parent 0a96145 commit 248cd03
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class CometSparkSessionExtensions

case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)

override def postColumnarTransitions: Rule[SparkPlan] =
EliminateRedundantColumnarToRow(session)
}

case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
Expand Down Expand Up @@ -278,6 +281,20 @@ class CometSparkSessionExtensions
op
}

case op: CollectLimitExec
if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
&& isCometShuffleEnabled(conf)
&& getOffset(op).getOrElse(0) == 0 =>
QueryPlanSerde.operator2Proto(op) match {
case Some(nativeOp) =>
val offset = getOffset(op).getOrElse(0)
val cometOp =
CometCollectLimitExec(op, op.limit, offset, op.child)
CometSinkPlaceHolder(nativeOp, op, cometOp)
case None =>
op
}

case op: ExpandExec =>
val newOp = transform1(op)
newOp match {
Expand Down Expand Up @@ -451,6 +468,23 @@ class CometSparkSessionExtensions
}
}
}

// CometExec already wraps a `ColumnarToRowExec` for row-based operators. Therefore,
// `ColumnarToRowExec` is redundant and can be eliminated.
//
// It was added during ApplyColumnarRulesAndInsertTransitions' insertTransitions phase when Spark
// requests row-based output such as `collect` call. It's correct to add a redundant
// `ColumnarToRowExec` for `CometExec`. However, for certain operators such as
// `CometCollectLimitExec` which overrides `executeCollect`, the redundant `ColumnarToRowExec`
// makes the override ineffective. The purpose of this rule is to eliminate the redundant
// `ColumnarToRowExec` for such operators.
case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
plan.transform { case ColumnarToRowExec(child: CometCollectLimitExec) =>
child
}
}
}
}

object CometSparkSessionExtensions extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1835,6 +1835,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case s if isCometScan(s) => true
case _: CometSinkPlaceHolder => true
case _: CoalesceExec => true
case _: CollectLimitExec => true
case _: UnionExec => true
case _: ShuffleExchangeExec => true
case _: TakeOrderedAndProjectExec => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.comet.shims

import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.execution.LimitExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan

trait ShimCometSparkSessionExtensions {
Expand All @@ -32,4 +33,9 @@ trait ShimCometSparkSessionExtensions {
.map { a => a.setAccessible(true); a }
.flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
.headOption

def getOffset(limit: LimitExec): Option[Int] = limit.getClass.getDeclaredFields
.filter(_.getName == "offset")
.map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] }
.headOption
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.apache.spark.sql.comet

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
Expand All @@ -42,7 +41,7 @@ case class CometCoalesceExec(
if (numPartitions == 1 && rdd.getNumPartitions < 1) {
// Make sure we don't output an RDD with 0 partitions, when claiming that we have a
// `SinglePartition`.
new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext)
} else {
rdd.coalesce(numPartitions, shuffle = false)
}
Expand All @@ -67,20 +66,3 @@ case class CometCoalesceExec(

override def hashCode(): Int = Objects.hashCode(numPartitions: java.lang.Integer, child)
}

object CometCoalesceExec {

/** A simple RDD with no data, but with the given number of partitions. */
class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
extends RDD[ColumnarBatch](sc, Nil) {

override def getPartitions: Array[Partition] =
Array.tabulate(numPartitions)(i => EmptyPartition(i))

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
Iterator.empty
}
}

case class EmptyPartition(index: Int) extends Partition
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional information regarding
* copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the
* License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.comet

import java.util.Objects

import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* Comet physical plan node for Spark `CollectExecNode`.
*
* Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions seperated by a
* comet shuffle.
*/
case class CometCollectLimitExec(
override val originalPlan: SparkPlan,
limit: Int,
offset: Int,
child: SparkPlan)
extends CometExec
with UnaryExecNode {

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics: Map[String, SQLMetric] = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"shuffleReadElapsedCompute" ->
SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
"numPartitions" -> SQLMetrics.createMetric(
sparkContext,
"number of partitions")) ++ readMetrics ++ writeMetrics

private lazy val serializer: Serializer =
new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))

override def executeCollect(): Array[InternalRow] = {
ColumnarToRowExec(child).executeTake(limit)
}

protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val childRDD = child.executeColumnar()
if (childRDD.getNumPartitions == 0) {
CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext)
} else {
val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
childRDD
} else {
val localLimitedRDD = if (limit >= 0) {
childRDD.mapPartitionsInternal { iter =>
val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
CometExec.getCometIterator(Seq(iter), limitOp)
}
} else {
childRDD
}
// Shuffle to Single Partition using Comet native shuffle
val dep = CometShuffleExchangeExec.prepareShuffleDependency(
localLimitedRDD,
child.output,
outputPartitioning,
serializer,
metrics)
metrics("numPartitions").set(dep.partitioner.numPartitions)

new CometShuffledBatchRDD(dep, readMetrics)
}

// todo: supports offset later
singlePartitionRDD.mapPartitionsInternal { iter =>
val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
CometExec.getCometIterator(Seq(iter), limitOp)
}
}
}

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

override def stringArgs: Iterator[Any] = Iterator(limit, offset, child)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometCollectLimitExec =>
this.limit == other.limit && this.offset == other.offset &&
this.child == other.child
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,23 @@ package org.apache.spark.sql.comet

import scala.collection.JavaConverters.asJavaIterableConverter

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}

object CometExecUtils {

def createEmptyColumnarRDDWithSinglePartition(
sparkContext: SparkContext): RDD[ColumnarBatch] = {
new EmptyRDDWithPartitions(sparkContext, 1)
}

/**
* Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
*/
Expand Down Expand Up @@ -119,3 +127,17 @@ object CometExecUtils {
}
}
}

/** A simple RDD with no data, but with the given number of partitions. */
private class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
extends RDD[ColumnarBatch](sc, Nil) {

override def getPartitions: Array[Partition] =
Array.tabulate(numPartitions)(i => EmptyPartition(i))

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
Iterator.empty
}
}

private case class EmptyPartition(index: Int) extends Partition
30 changes: 28 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
Expand Down Expand Up @@ -1073,6 +1073,32 @@ class CometExecSuite extends CometTestBase {
}
})
}

test("collect limit") {
Seq("true", "false").foreach(aqe => {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2")
assert(df.queryExecution.executedPlan.execute().getNumPartitions === 1)
checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec]))
assert(df.collect().length === 2)

// checks CometCollectExec.doExecuteColumnar is indirectly called
val qe = df.queryExecution
SQLExecution.withNewExecutionId(qe, Some("count")) {
qe.executedPlan.resetMetrics()
assert(qe.executedPlan.execute().count() === 2)
}

assert(df.isEmpty === false)

// follow up native operation is possible
val df3 = df.groupBy("id").sum("value")
checkSparkAnswerAndOperator(df3)
}
}
})
}
}

case class BucketedTableTestSpec(
Expand Down
19 changes: 19 additions & 0 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,15 @@ abstract class CometTestBase
protected def checkSparkAnswerAndOperator(
df: => DataFrame,
excludedClasses: Class[_]*): Unit = {
checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*)
}

protected def checkSparkAnswerAndOperator(
df: => DataFrame,
includeClasses: Seq[Class[_]],
excludedClasses: Class[_]*): Unit = {
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*)
checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*)
checkSparkAnswer(df)
}

Expand All @@ -173,6 +181,17 @@ abstract class CometTestBase
}
}

protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): Unit = {
includePlans.foreach { case planClass =>
if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) {
assert(
false,
s"Expected plan to contain ${planClass.getSimpleName}.\n" +
s"plan: $plan")
}
}
}

/**
* Check the answer of a Comet SQL query with Spark result using absolute tolerance.
*/
Expand Down

0 comments on commit 248cd03

Please sign in to comment.