diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index a3dc1f38a..c78d16bd9 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -139,6 +139,15 @@ object CometConf {
     .booleanConf
     .createWithDefault(false)
 
+  val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
+      .doc(
+        "Whether to enable broadcasting for Comet native operators. By default, " +
+          "this config is false. Note that this feature is not fully supported yet " +
+          "and is only enabled for testing purposes.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_EXEC_SHUFFLE_CODEC: ConfigEntry[String] = conf(
     s"$COMET_EXEC_CONFIG_PREFIX.shuffle.codec")
     .doc(
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 489504def..b8b1a7a2f 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -37,12 +37,12 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
+import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
@@ -331,6 +331,16 @@ class CometSparkSessionExtensions
             u
           }
 
+        case b: BroadcastExchangeExec
+            if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
+              isCometBroadCastEnabled(conf) =>
+          QueryPlanSerde.operator2Proto(b) match {
+            case Some(nativeOp) =>
+              val cometOp = CometBroadcastExchangeExec(b, b.child)
+              CometSinkPlaceHolder(nativeOp, b, cometOp)
+            case None => b
+          }
+
         // Native shuffle for Comet operators
         case s: ShuffleExchangeExec
             if isCometShuffleEnabled(conf) &&
@@ -482,6 +492,10 @@ object CometSparkSessionExtensions extends Logging {
       conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
   }
 
+  private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
+    COMET_EXEC_BROADCAST_ENABLED.get(conf)
+  }
+
   private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =
     COMET_EXEC_SHUFFLE_ENABLED.get(conf) &&
       (conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") ==
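The two files above wire the new feature flag into plan rewriting: `COMET_EXEC_BROADCAST_ENABLED` gates the rule in `CometSparkSessionExtensions` that replaces a `BroadcastExchangeExec` with `CometBroadcastExchangeExec` when its child is already Comet native. A minimal sketch of how the flag could be toggled from a `CometTestBase`-style suite; the table names `t1`/`t2` and the query are hypothetical, not part of this patch:

```scala
import org.apache.comet.CometConf

// Enable Comet execution, all Comet operators, and the new broadcast flag, then run a
// query the planner turns into a broadcast hash join. With the rule above, the
// BroadcastExchangeExec over a Comet-native child should be rewritten into
// CometBroadcastExchangeExec (wrapped in a CometSinkPlaceHolder).
withSQLConf(
  CometConf.COMET_EXEC_ENABLED.key -> "true",
  CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
  CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
  val df = sql("SELECT /*+ BROADCAST(t2) */ t1._1, t2._2 FROM t1 JOIN t2 ON t1._1 = t2._1")
  df.explain() // inspect the plan for CometBroadcastExchangeExec
}
```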
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index a497a448c..0a251d448 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkP
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1802,6 +1802,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
       case _: TakeOrderedAndProjectExec => true
+      case _: BroadcastExchangeExec => true
       case _ => false
     }
   }
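The `QueryPlanSerde` change adds `BroadcastExchangeExec` to the operators treated as Comet sinks, i.e. boundaries where native execution starts reading serialized input. A hedged sketch of how this could be checked, using only the `operator2Proto` call already made by the rewrite rule above:

```scala
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.comet.serde.QueryPlanSerde

// With this change, operator2Proto no longer returns None merely because the operator
// is a broadcast exchange (other unsupported details can still make it return None).
def broadcastSerializes(plan: SparkPlan): Boolean =
  plan
    .collectFirst { case b: BroadcastExchangeExec => b }
    .exists(b => QueryPlanSerde.operator2Proto(b).isDefined)
```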
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
new file mode 100644
index 000000000..3b886a948
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import java.util.UUID
+import java.util.concurrent.{Future, TimeoutException, TimeUnit}
+
+import scala.concurrent.{ExecutionContext, Promise}
+import scala.concurrent.duration.NANOSECONDS
+import scala.util.control.NonFatal
+
+import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.{SparkFatalException, ThreadUtils}
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+import com.google.common.base.Objects
+
+/**
+ * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
+ * transformed SparkPlan. This is a copy of the [[BroadcastExchangeExec]] class with the necessary
+ * changes to support the Comet operator.
+ *
+ * [[CometBroadcastExchangeExec]] will be used in the broadcast join operator.
+ *
+ * Note that, unlike other Comet operators, this class cannot extend `CometExec`: the Spark trait
+ * `BroadcastExchangeLike` extends the abstract class `Exchange`, which makes it impossible to
+ * extend both `CometExec` and `Exchange` at the same time.
+ */
+case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
+    extends BroadcastExchangeLike {
+  import CometBroadcastExchangeExec._
+
+  override val runId: UUID = UUID.randomUUID
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
+    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
+    "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))
+
+  override def doCanonicalize(): SparkPlan = {
+    CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
+  }
+
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    val rowCount = metrics("numOutputRows").value
+    Statistics(dataSize, Some(rowCount))
+  }
+
+  @transient
+  private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
+  @transient
+  override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
+    promise.future
+
+  @transient
+  private val timeout: Long = conf.broadcastTimeout
+
+  @transient
+  private lazy val maxBroadcastRows = 512000000
+
+  @transient
+  override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+      session,
+      CometBroadcastExchangeExec.executionContext) {
+      try {
+        // Setup a job group here so later it may get cancelled by groupId if necessary.
+        sparkContext.setJobGroup(
+          runId.toString,
+          s"broadcast exchange (runId $runId)",
+          interruptOnCancel = true)
+        val beforeCollect = System.nanoTime()
+
+        val countsAndBytes = child.asInstanceOf[CometExec].getByteArrayRdd().collect()
+        val numRows = countsAndBytes.map(_._1).sum
+        val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)
+
+        longMetric("numOutputRows") += numRows
+        if (numRows >= maxBroadcastRows) {
+          throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(
+            maxBroadcastRows,
+            numRows)
+        }
+
+        val beforeBuild = System.nanoTime()
+        longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
+
+        val batches = input.toArray
+
+        val dataSize = batches.map(_.size).sum
+
+        longMetric("dataSize") += dataSize
+        if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
+          throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
+            MAX_BROADCAST_TABLE_BYTES,
+            dataSize)
+        }
+
+        val beforeBroadcast = System.nanoTime()
+        longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
+
+        // SPARK-39983 - Broadcast the relation without caching the unserialized object.
+        val broadcasted = sparkContext
+          .broadcastInternal(batches, serializedOnly = true)
+          .asInstanceOf[broadcast.Broadcast[Any]]
+        longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
+        val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+        promise.trySuccess(broadcasted)
+        broadcasted
+      } catch {
+        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+        // will catch this exception and re-throw the wrapped fatal throwable.
+        case oe: OutOfMemoryError =>
+          val tables = child.collect { case f: FileSourceScanExec => f.tableIdentifier }.flatten
+          val ex = new SparkFatalException(
+            QueryExecutionErrors.notEnoughMemoryToBuildAndBroadcastTableError(oe, tables))
+          promise.tryFailure(ex)
+          throw ex
+        case e if !NonFatal(e) =>
+          val ex = new SparkFatalException(e)
+          promise.tryFailure(ex)
+          throw ex
+        case e: Throwable =>
+          promise.tryFailure(e)
+          throw e
+      }
+    }
+  }
+
+  override protected def doPrepare(): Unit = {
+    // Materialize the future.
+    relationFuture
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    throw QueryExecutionErrors.executeCodePathUnsupportedError("CometBroadcastExchangeExec")
+  }
+
+  override def supportsColumnar: Boolean = true
+
+  // This is basically for unit tests only.
+  override def executeCollect(): Array[InternalRow] =
+    ColumnarToRowExec(this).executeCollect()
+
+  // This is basically for unit tests only, called by `executeCollect` indirectly.
+  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
+
+    new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
+  }
+
+  override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    try {
+      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
+    } catch {
+      case ex: TimeoutException =>
+        logError(s"Could not execute broadcast in $timeout secs.", ex)
+        if (!relationFuture.isDone) {
+          sparkContext.cancelJobGroup(runId.toString)
+          relationFuture.cancel(true)
+        }
+        throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
+    }
+  }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometBroadcastExchangeExec =>
+        this.originalPlan == other.originalPlan &&
+          this.output == other.output && this.child == other.child
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int = Objects.hashCode(output, child)
+
+  override def stringArgs: Iterator[Any] = Iterator(output, child)
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CometBroadcastExchangeExec =
+    copy(child = newChild)
+}
+
+object CometBroadcastExchangeExec {
+  val MAX_BROADCAST_TABLE_BYTES: Long = 8L << 30
+
+  private[comet] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool(
+      "comet-broadcast-exchange",
+      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
+}
+
+/**
+ * [[CometBatchRDD]] is an [[RDD]] of [[ColumnarBatch]]s that are broadcasted to the executors.
+ * It is only used by [[CometBroadcastExchangeExec]] to broadcast the result of a Comet operator.
+ *
+ * @param sc
+ *   SparkContext
+ * @param numPartitions
+ *   number of partitions
+ * @param value
+ *   the broadcasted batches which are serialized into an array of [[ChunkedByteBuffer]]s
+ */
+class CometBatchRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
+    extends RDD[ColumnarBatch](sc, Nil) {
+
+  override def getPartitions: Array[Partition] = (0 until numPartitions).toArray.map { i =>
+    new CometBatchPartition(i, value)
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+    val partition = split.asInstanceOf[CometBatchPartition]
+
+    partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
+  }
+}
+
+class CometBatchPartition(
+    override val index: Int,
+    val value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
+    extends Partition {}
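`relationFuture` collects the child's serialized batches via `getByteArrayRdd()`, checks the row and byte limits, and broadcasts the raw `ChunkedByteBuffer`s; `CometBatchRDD` then decodes them on the executors. A small consumption sketch built only from the APIs above; the `exchange` value is assumed to come from an already executed plan:

```scala
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometExec}
import org.apache.spark.util.io.ChunkedByteBuffer

// Count the rows held by the broadcast relation, decoding it the same way
// doExecuteColumnar/CometBatchRDD.compute do.
def countBroadcastRows(exchange: CometBroadcastExchangeExec): Long = {
  val broadcasted = exchange.executeBroadcast[Array[ChunkedByteBuffer]]()
  broadcasted.value.iterator
    .flatMap(bytes => CometExec.decodeBatches(bytes)) // each buffer is one compressed batch stream
    .map(_.numRows().toLong)
    .sum
}
```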
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 12f086c76..d1d6f8f20 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -71,7 +71,7 @@ abstract class CometExec extends CometPlan {
   /**
    * Executes this Comet operator and serialized output ColumnarBatch into bytes.
    */
-  private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
+  def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
     executeColumnar().mapPartitionsInternal { iter =>
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
@@ -85,28 +85,14 @@ abstract class CometExec extends CometPlan {
     }
   }
 
-  /**
-   * Decodes the byte arrays back to ColumnarBatches and put them into buffer.
-   */
-  private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
-    new ArrowReaderIterator(Channels.newChannel(ins))
-  }
-
   /**
    * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
    */
   def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
     val countsAndBytes = getByteArrayRdd().collect()
     val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
+    val rows = countsAndBytes.iterator
+      .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
     (total, rows)
   }
 }
@@ -133,6 +119,21 @@ object CometExec {
     val bytes = outputStream.toByteArray
     new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
   }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
+   */
+  def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+
+    new ArrowReaderIterator(Channels.newChannel(ins))
+  }
 }
 
 /**
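With `getByteArrayRdd()` made public and `decodeBatches` moved to the `CometExec` companion object, the serialize/decode round trip used by both `executeColumnarCollectIterator` and the broadcast exchange can be exercised directly. A sketch, assuming `cometPlan` is a `CometExec` node taken from an executed query:

```scala
import org.apache.spark.sql.comet.CometExec

// Each partition yields (rowCount, compressed Arrow batches); decoding the bytes
// must give back the same number of rows.
def roundTripRowCount(cometPlan: CometExec): Long = {
  val countsAndBytes = cometPlan.getByteArrayRdd().collect()
  val totalRows = countsAndBytes.map(_._1).sum
  val decodedRows = countsAndBytes.iterator
    .flatMap { case (_, bytes) => CometExec.decodeBatches(bytes) }
    .map(_.numRows().toLong)
    .sum
  assert(totalRows == decodedRows)
  totalRows
}
```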
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index e1f864249..d3a1bd2c9 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.comet.{CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -55,6 +55,31 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("CometBroadcastExchangeExec") {
+    withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
+      withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
+        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") {
+          val df = sql(
+            "SELECT tbl_a._1, tbl_b._2 FROM tbl_a JOIN tbl_b " +
+              "WHERE tbl_a._1 > tbl_a._2 LIMIT 2")
+
+          val nativeBroadcast = find(df.queryExecution.executedPlan) {
+            case _: CometBroadcastExchangeExec => true
+            case _ => false
+          }.get.asInstanceOf[CometBroadcastExchangeExec]
+
+          val numParts = nativeBroadcast.executeColumnar().getNumPartitions
+
+          val rows = nativeBroadcast.executeCollect().toSeq.sortBy(row => row.getInt(0))
+          val rowContents = rows.map(row => row.getInt(0))
+          val expected = (0 until numParts).flatMap(_ => (0 until 5).map(i => i + 1)).sorted
+
+          assert(rowContents === expected)
+        }
+      }
+    }
+  }
+
   test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
     withSQLConf(
       CometConf.COMET_EXEC_ENABLED.key -> "true",
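The test above exercises the exchange node directly (`executeColumnar`/`executeCollect`). A possible follow-up, not part of this patch, would compare a full broadcast join against Spark's own result; `checkSparkAnswer` is assumed here to be the usual `CometTestBase` helper for that comparison:

```scala
test("CometBroadcastExchangeExec - join result matches Spark") {
  withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
    withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
      withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") {
        // Runs the query with and without Comet and compares the resulting rows.
        checkSparkAnswer("SELECT tbl_a._1, tbl_b._2 FROM tbl_a JOIN tbl_b ON tbl_a._1 = tbl_b._1")
      }
    }
  }
}
```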
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 4f2838cfb..2b37ce035 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -35,7 +35,7 @@ import org.apache.parquet.hadoop.ParquetWriter
 import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
-import org.apache.spark.sql.comet.{CometBatchScanExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -79,6 +79,8 @@ abstract class CometTestBase
       CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true",
       CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "2g",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g",
       SQLConf.ANSI_ENABLED.key -> "false") {
       testFun
     }
@@ -157,6 +159,7 @@ abstract class CometTestBase
       case _: CometScanExec | _: CometBatchScanExec => true
       case _: CometSinkPlaceHolder | _: CometScanWrapper => false
       case _: CometExec | _: CometShuffleExchangeExec => true
+      case _: CometBroadcastExchangeExec => true
       case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter => true
       case op =>
         if (excludedClasses.exists(c => c.isAssignableFrom(op.getClass))) {
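Raising both broadcast-join thresholds to 1g in `CometTestBase` makes the planner choose broadcast hash joins for the small test tables, so the new exchange is actually exercised; the extra match arm accepts `CometBroadcastExchangeExec` when validating that a plan consists of Comet operators. For reference, the same thresholds expressed as plain SQL conf keys:

```scala
// Equivalent raw settings; the thresholds apply to the estimated size of the build-side table.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "1g")
spark.conf.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "1g")
```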