From 673bb7fd6d94060a67561ede3c4d4b22777b6fb0 Mon Sep 17 00:00:00 2001
From: Xianjin YE
Date: Sat, 24 Feb 2024 02:42:09 +0800
Subject: [PATCH] feat: Support CollectLimit operator

---
 .../comet/CometSparkSessionExtensions.scala   |  14 +++
 .../apache/comet/serde/QueryPlanSerde.scala   |   1 +
 .../ShimCometSparkSessionExtensions.scala     |   6 ++
 .../spark/sql/comet/CometCoalesceExec.scala   |  20 +---
 .../sql/comet/CometCollectLimitExec.scala     | 102 ++++++++++++++++++
 .../spark/sql/comet/CometExecUtils.scala      |  61 +++++++++++
 .../apache/comet/exec/CometExecSuite.scala    |  21 +++-
 .../org/apache/spark/sql/CometTestBase.scala  |  20 ++++
 8 files changed, 222 insertions(+), 23 deletions(-)
 create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
 create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index ddd5f57ff9..7761c933ab 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -278,6 +278,20 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: CollectLimitExec
+            if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
+              && isCometShuffleEnabled(conf)
+              && getOffset(op).getOrElse(0) == 0 =>
+          QueryPlanSerde.operator2Proto(op) match {
+            case Some(nativeOp) =>
+              val offset = getOffset(op).getOrElse(0)
+              val cometOp =
+                CometCollectLimitExec(op, op.limit, offset, op.child)
+              CometSinkPlaceHolder(nativeOp, op, cometOp)
+            case None =>
+              op
+          }
+
         case op: ExpandExec =>
           val newOp = transform1(op)
           newOp match {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 938e49f73e..0938625bb4 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1799,6 +1799,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case s if isCometScan(s) => true
       case _: CometSinkPlaceHolder => true
       case _: CoalesceExec => true
+      case _: CollectLimitExec => true
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
       case _: TakeOrderedAndProjectExec => true
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index 8afed84ff5..3ccac5456e 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -20,6 +20,7 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.LimitExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 
 trait ShimCometSparkSessionExtensions {
@@ -32,4 +33,9 @@ trait ShimCometSparkSessionExtensions {
     .map { a => a.setAccessible(true); a }
     .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
     .headOption
+
+  def getOffset(limit: LimitExec): Option[Int] = limit.getClass.getDeclaredFields
+    .filter(_.getName == "offset")
+    .map { a => a.setAccessible(true); a.get(limit).asInstanceOf[Int] }
+    .headOption
 }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
index fc4f90f897..532eb91a99 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
@@ -19,7 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
@@ -42,7 +41,7 @@ case class CometCoalesceExec(
     if (numPartitions == 1 && rdd.getNumPartitions < 1) {
       // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
       // `SinglePartition`.
-      new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+      CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext)
     } else {
       rdd.coalesce(numPartitions, shuffle = false)
     }
@@ -67,20 +66,3 @@ case class CometCoalesceExec(
 
   override def hashCode(): Int = Objects.hashCode(numPartitions: java.lang.Integer, child)
 }
-
-object CometCoalesceExec {
-
-  /** A simple RDD with no data, but with the given number of partitions. */
-  class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
-      extends RDD[ColumnarBatch](sc, Nil) {
-
-    override def getPartitions: Array[Partition] =
-      Array.tabulate(numPartitions)(i => EmptyPartition(i))
-
-    override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-      Iterator.empty
-    }
-  }
-
-  case class EmptyPartition(index: Int) extends Partition
-}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
new file mode 100644
index 0000000000..3e377da2f1
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
@@ -0,0 +1,102 @@
+package org.apache.spark.sql.comet
+
+import java.util.Objects
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Comet physical plan node for Spark `CollectLimitExec`.
+ *
+ * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions separated by a
+ * Comet shuffle.
+ */
+case class CometCollectLimitExec(
+    override val originalPlan: SparkPlan,
+    limit: Int,
+    offset: Int,
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  private lazy val serializer: Serializer =
+    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
+
+  override def executeCollect(): Array[InternalRow] = {
+    // Note: this method is never called because `ApplyColumnarRulesAndInsertTransitions` adds a
+    // ColumnarToRowExec on top of this node when it is executed via df.collect().
+    // TODO: find a way to avoid that and leverage executeTake instead.
+    ColumnarToRowExec(child).executeTake(limit)
+  }
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+    if (childRDD.getNumPartitions == 0) {
+      CometExecUtils.createEmptyColumnarRDDWithSinglePartition(sparkContext)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val localLimitedRDD = if (limit >= 0) {
+          childRDD.mapPartitionsInternal { iter =>
+            val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
+            CometExec.getCometIterator(Seq(iter), limitOp)
+          }
+        } else {
+          childRDD
+        }
+        // Shuffle to a single partition using Comet native shuffle
+        val dep = CometShuffleExchangeExec.prepareShuffleDependency(
+          localLimitedRDD,
+          child.output,
+          outputPartitioning,
+          serializer,
+          metrics)
+        metrics("numPartitions").set(dep.partitioner.numPartitions)
+
+        new CometShuffledBatchRDD(dep, readMetrics)
+      }
+
+      // TODO: support offset later
+      singlePartitionRDD.mapPartitionsInternal { iter =>
+        val limitOp = CometExecUtils.getLimitNativePlan(output, limit).get
+        CometExec.getCometIterator(Seq(iter), limitOp)
+      }
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+  override def stringArgs: Iterator[Any] = Iterator(limit, offset, child)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometCollectLimitExec =>
+        this.limit == other.limit && this.offset == other.offset &&
+          this.child == other.child
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child)
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
new file mode 100644
index 0000000000..98bb244a3b
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -0,0 +1,61 @@
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.serializeDataType
+
+object CometExecUtils {
+
+  def createEmptyColumnarRDDWithSinglePartition(
+      sparkContext: SparkContext): RDD[ColumnarBatch] = {
+    new EmptyRDDWithPartitions(sparkContext, 1)
+  }
+
+  /**
+   * Prepare a native Limit plan for Comet operators, which takes the first `limit` elements of
+   * each child partition.
+   */
+  def getLimitNativePlan(outputAttributes: Seq[Attribute], limit: Int): Option[Operator] = {
+    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
+
+    val scanTypes = outputAttributes.flatten { attr =>
+      serializeDataType(attr.dataType)
+    }
+
+    if (scanTypes.length == outputAttributes.length) {
+      scanBuilder.addAllFields(scanTypes.asJava)
+
+      val limitBuilder = OperatorOuterClass.Limit.newBuilder()
+      limitBuilder.setLimit(limit)
+
+      val limitOpBuilder = OperatorOuterClass.Operator
+        .newBuilder()
+        .addChildren(scanOpBuilder.setScan(scanBuilder))
+      Some(limitOpBuilder.setLimit(limitBuilder).build())
+    } else {
+      None
+    }
+  }
+}
+
+/** A simple RDD with no data, but with the given number of partitions. */
+private class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
+    extends RDD[ColumnarBatch](sc, Nil) {
+
+  override def getPartitions: Array[Partition] =
+    Array.tabulate(numPartitions)(i => EmptyPartition(i))
+
+  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+    Iterator.empty
+  }
+}
+
+private case class EmptyPartition(index: Int) extends Partition
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index d3a1bd2c95..d6244a44e7 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -22,16 +22,14 @@ package org.apache.comet.exec
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
-
 import org.scalactic.source.Position
 import org.scalatest.Tag
-
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame, DataFrameWriter, Row, SaveMode}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -40,7 +38,6 @@ import org.apache.spark.sql.functions.{date_add, expr}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
-
 import org.apache.comet.CometConf
 
 class CometExecSuite extends CometTestBase {
@@ -964,6 +961,22 @@ class CometExecSuite extends CometTestBase {
       }
     }
   }
+
+  test("collect limit") {
+    Seq("true", "false").foreach(aqe => {
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) {
+        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
+          val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2")
+          assert(df.rdd.getNumPartitions === 1)
+          checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec]))
+          assert(df.collect().length === 2)
+          // A follow-up native operation is possible
+          val df3 = df.groupBy("id").sum("value")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    })
+  }
 }
 
 case class BucketedTableTestSpec(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 2e523fa5a4..9d12769c2e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -67,6 +67,7 @@ abstract class CometTestBase
     conf.set(SQLConf.SHUFFLE_PARTITIONS, 10) // reduce parallelism in tests
     conf.set("spark.shuffle.manager", shuffleManager)
     conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
+    conf.set("spark.ui.enabled", "true")
     conf
   }
 
@@ -150,7 +151,15 @@ abstract class CometTestBase
 
   protected def checkSparkAnswerAndOperator(
       df: => DataFrame,
       excludedClasses: Class[_]*): Unit = {
+    checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*)
+  }
+
+  protected def checkSparkAnswerAndOperator(
+      df: => DataFrame,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): Unit = {
     checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*)
+    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*)
     checkSparkAnswer(df)
   }
 
@@ -173,6 +182,17 @@ abstract class CometTestBase
     }
   }
 
+  protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): Unit = {
+    includePlans.foreach { planClass =>
+      if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) {
+        assert(
+          false,
+          s"Expected plan to contain ${planClass.getSimpleName}.\n" +
+            s"plan: $plan")
+      }
+    }
+  }
+
   /**
    * Check the answer of a Comet SQL query with Spark result using absolute tolerance.
   */
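
Note for reviewers, outside the patch: below is a minimal, hypothetical sketch of how the new operator is expected to show up when trying the change locally. The extension class, shuffle manager, and spark.comet.* keys are assumptions based on Comet's usual setup and on the isCometOperatorEnabled(conf, "collectLimit") / isCometShuffleEnabled(conf) guards in the rule above; nothing in this patch defines them.

import java.nio.file.Files

import org.apache.spark.sql.SparkSession

object CollectLimitSmokeTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      // Assumed Comet wiring; adjust to whatever the build actually exposes.
      .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
      .config(
        "spark.shuffle.manager",
        "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
      .config("spark.comet.enabled", "true")
      .config("spark.comet.exec.enabled", "true")
      .config("spark.comet.exec.shuffle.enabled", "true")
      .config("spark.comet.exec.collectLimit.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    // Read the data back from Parquet so the child of CollectLimitExec is Comet-native,
    // mirroring withParquetTable in the new test.
    val path = Files.createTempDirectory("comet-collect-limit").resolve("tbl").toString
    (0 until 5).map(i => (i, i + 1)).toDF("_1", "_2").write.parquet(path)
    spark.read.parquet(path).createOrReplaceTempView("tbl")

    val df = spark.sql("SELECT _1 AS id, _2 AS value FROM tbl LIMIT 2")
    // If the new rule matches, the physical plan printed here should contain
    // CometCollectLimitExec in place of Spark's CollectLimitExec.
    df.explain()
    assert(df.collect().length == 2)
    spark.stop()
  }
}

If any of the keys are off, the query simply keeps Spark's CollectLimitExec, which is also the expected fallback when a non-zero offset is present, since the rule only matches when getOffset(op) is 0.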