feat: Add CometBroadcastExchangeExec to support broadcasting the result of Comet native operator (#80)
viirya authored Feb 21, 2024
1 parent 4bbc307 commit 637dba9
Showing 7 changed files with 331 additions and 22 deletions.
9 changes: 9 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -139,6 +139,15 @@ object CometConf {
.booleanConf
.createWithDefault(false)

val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
.doc(
"Whether to enable broadcasting for Comet native operators. By default, " +
"this config is false. Note that this feature is not fully supported yet " +
"and only enabled for test purpose.")
.booleanConf
.createWithDefault(false)

val COMET_EXEC_SHUFFLE_CODEC: ConfigEntry[String] = conf(
s"$COMET_EXEC_CONFIG_PREFIX.shuffle.codec")
.doc(
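For context, here is a hedged sketch of how this new flag could be turned on together with the other Comet flags. It assumes `COMET_EXEC_CONFIG_PREFIX` resolves to `spark.comet.exec`, and that the per-operator key checked by the planner rule below is `spark.comet.exec.broadcastExchangeExec.enabled`; both keys should be verified against `CometConf` before use.

// Sketch only: enabling the experimental Comet broadcast support in a Spark session.
// The exact config keys are assumptions derived from COMET_EXEC_CONFIG_PREFIX; check CometConf.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("comet-broadcast-sketch")
  .config("spark.comet.enabled", "true")                            // enable Comet overall
  .config("spark.comet.exec.enabled", "true")                       // enable Comet native execution
  .config("spark.comet.exec.broadcastExchangeExec.enabled", "true") // assumed per-operator flag
  .config("spark.comet.exec.broadcast.enabled", "true")             // the new flag (default false)
  .getOrCreate()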
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -37,12 +37,12 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf._
import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
@@ -331,6 +331,16 @@ class CometSparkSessionExtensions
u
}

case b: BroadcastExchangeExec
if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
isCometBroadCastEnabled(conf) =>
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val cometOp = CometBroadcastExchangeExec(b, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}

// Native shuffle for Comet operators
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) &&
@@ -482,6 +492,10 @@ object CometSparkSessionExtensions extends Logging {
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
}

private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
COMET_EXEC_BROADCAST_ENABLED.get(conf)
}

private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =
COMET_EXEC_SHUFFLE_ENABLED.get(conf) &&
(conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") ==
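To make the effect of the new planner case concrete, the lines below give a hedged, hand-written sketch (not actual `explain()` output) of how the build side of a broadcast hash join might be rewritten when its child is already Comet native and `operator2Proto` succeeds; the operator names in the child subtree are illustrative:

// Hypothetical build side before the rule fires:
//   BroadcastExchangeExec
//   +- CometProjectExec          <- Comet native child
//      +- CometScanExec
//
// Hypothetical build side after the rule fires:
//   CometSinkPlaceHolder
//   +- CometBroadcastExchangeExec
//      +- CometProjectExec       <- same native child, now broadcast by Comet
//         +- CometScanExec
//
// If QueryPlanSerde.operator2Proto(b) returns None, the Spark BroadcastExchangeExec is kept as is.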
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkP
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1802,6 +1802,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case _: UnionExec => true
case _: ShuffleExchangeExec => true
case _: TakeOrderedAndProjectExec => true
case _: BroadcastExchangeExec => true
case _ => false
}
}
spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -0,0 +1,256 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet

import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}

import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
import org.apache.spark.util.io.ChunkedByteBuffer

import com.google.common.base.Objects

/**
* A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
* transformed SparkPlan. It is a copy of Spark's [[BroadcastExchangeExec]] with the changes
* necessary to support Comet native operators.
*
* [[CometBroadcastExchangeExec]] is used by the broadcast join operator.
*
* Note that, unlike other Comet operators, this class cannot extend `CometExec`: the Spark trait
* `BroadcastExchangeLike` extends the abstract class `Exchange`, and since `CometExec` is also an
* abstract class, extending both at the same time is not possible.
*/
case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
extends BroadcastExchangeLike {
import CometBroadcastExchangeExec._

override val runId: UUID = UUID.randomUUID

override lazy val metrics: Map[String, SQLMetric] = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

override def doCanonicalize(): SparkPlan = {
CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
val dataSize = metrics("dataSize").value
val rowCount = metrics("numOutputRows").value
Statistics(dataSize, Some(rowCount))
}

@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()

@transient
override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
promise.future

@transient
private val timeout: Long = conf.broadcastTimeout

@transient
private lazy val maxBroadcastRows = 512000000

@transient
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
session,
CometBroadcastExchangeExec.executionContext) {
try {
// Set up a job group here so that it may later be cancelled by groupId if necessary.
sparkContext.setJobGroup(
runId.toString,
s"broadcast exchange (runId $runId)",
interruptOnCancel = true)
val beforeCollect = System.nanoTime()

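// Collect the child's output as per-partition (row count, serialized batch bytes) pairs on the driver.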
val countsAndBytes = child.asInstanceOf[CometExec].getByteArrayRdd().collect()
val numRows = countsAndBytes.map(_._1).sum
val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)

longMetric("numOutputRows") += numRows
if (numRows >= maxBroadcastRows) {
throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(
maxBroadcastRows,
numRows)
}

val beforeBuild = System.nanoTime()
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)

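// Materialize all serialized batches on the driver; their combined size is checked against the broadcast limit below.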
val batches = input.toArray

val dataSize = batches.map(_.size).sum

longMetric("dataSize") += dataSize
if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
MAX_BROADCAST_TABLE_BYTES,
dataSize)
}

val beforeBroadcast = System.nanoTime()
longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)

// SPARK-39983 - Broadcast the relation without caching the unserialized object.
val broadcasted = sparkContext
.broadcastInternal(batches, serializedOnly = true)
.asInstanceOf[broadcast.Broadcast[Any]]
longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
promise.trySuccess(broadcasted)
broadcasted
} catch {
// SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
// SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
// will catch this exception and re-throw the wrapped fatal throwable.
case oe: OutOfMemoryError =>
val tables = child.collect { case f: FileSourceScanExec => f.tableIdentifier }.flatten
val ex = new SparkFatalException(
QueryExecutionErrors.notEnoughMemoryToBuildAndBroadcastTableError(oe, tables))
promise.tryFailure(ex)
throw ex
case e if !NonFatal(e) =>
val ex = new SparkFatalException(e)
promise.tryFailure(ex)
throw ex
case e: Throwable =>
promise.tryFailure(e)
throw e
}
}
}

override protected def doPrepare(): Unit = {
// Materialize the future.
relationFuture
}

override protected def doExecute(): RDD[InternalRow] = {
throw QueryExecutionErrors.executeCodePathUnsupportedError("CometBroadcastExchangeExec")
}

override def supportsColumnar: Boolean = true

// This is basically for unit tests only.
override def executeCollect(): Array[InternalRow] =
ColumnarToRowExec(this).executeCollect()

// This is basically for unit tests only; it is called indirectly by `executeCollect`.
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()

new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
} catch {
case ex: TimeoutException =>
logError(s"Could not execute broadcast in $timeout secs.", ex)
if (!relationFuture.isDone) {
sparkContext.cancelJobGroup(runId.toString)
relationFuture.cancel(true)
}
throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
}
}

override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastExchangeExec =>
this.originalPlan == other.originalPlan &&
this.output == other.output && this.child == other.child
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(output, child)

override def stringArgs: Iterator[Any] = Iterator(output, child)

override protected def withNewChildInternal(newChild: SparkPlan): CometBroadcastExchangeExec =
copy(child = newChild)
}

object CometBroadcastExchangeExec {
val MAX_BROADCAST_TABLE_BYTES: Long = 8L << 30

private[comet] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool(
"comet-broadcast-exchange",
SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
}

/**
* [[CometBatchRDD]] is an [[RDD]] of [[ColumnarBatch]]es that are broadcast to the executors. It
* is only used by [[CometBroadcastExchangeExec]] to broadcast the result of a Comet operator.
*
* @param sc
* SparkContext
* @param numPartitions
* number of partitions
* @param value
* the broadcasted batches which are serialized into an array of [[ChunkedByteBuffer]]s
*/
class CometBatchRDD(
sc: SparkContext,
numPartitions: Int,
value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
extends RDD[ColumnarBatch](sc, Nil) {

override def getPartitions: Array[Partition] = (0 until numPartitions).toArray.map { i =>
new CometBatchPartition(i, value)
}

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]

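// Each partition decodes the full set of broadcast buffers, since every task needs the entire broadcast relation.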
partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
}
}

class CometBatchPartition(
override val index: Int,
val value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
extends Partition {}
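As an end-to-end usage illustration, here is a hedged sketch of a query that should exercise [[CometBroadcastExchangeExec]] when the flags from the earlier configuration sketch are set; the table paths are placeholders and the expected plan shape is an assumption, not verified output:

// Sketch only: exercising CometBroadcastExchangeExec through an explicit broadcast join.
// Paths are hypothetical; the build side must be planned as a Comet native operator (e.g. a Parquet scan).
import org.apache.spark.sql.functions.broadcast

val dim  = spark.read.parquet("/tmp/dim_parquet")   // hypothetical small build-side table
val fact = spark.read.parquet("/tmp/fact_parquet")  // hypothetical large probe-side table

val joined = fact.join(broadcast(dim), "id")
joined.explain() // expect CometSinkPlaceHolder over CometBroadcastExchangeExec on the build side
joined.show()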
35 changes: 18 additions & 17 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -71,7 +71,7 @@ abstract class CometExec extends CometPlan {
/**
* Executes this Comet operator and serializes the output ColumnarBatches into bytes.
*/
private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
executeColumnar().mapPartitionsInternal { iter =>
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
@@ -85,28 +85,14 @@ abstract class CometExec extends CometPlan {
}
}

/**
* Decodes the byte arrays back to ColumnarBatches and put them into buffer.
*/
private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))

new ArrowReaderIterator(Channels.newChannel(ins))
}

/**
* Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
*/
def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
val countsAndBytes = getByteArrayRdd().collect()
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
val rows = countsAndBytes.iterator
.flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
(total, rows)
}
}
@@ -133,6 +119,21 @@ object CometExec {
val bytes = outputStream.toByteArray
new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
}

/**
* Decodes the serialized bytes back into an iterator of ColumnarBatches.
*/
def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))

new ArrowReaderIterator(Channels.newChannel(ins))
}
}

/**
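Since `getByteArrayRdd()` is now public on `CometExec` and `decodeBatches()` has moved to the `CometExec` companion object, the following hedged sketch shows the serialize-collect-decode round trip they implement; `cometPlan` is a hypothetical Comet native plan node, and in the patch this path is driven by `executeColumnarCollectIterator` and by `CometBroadcastExchangeExec`:

// Sketch only: each partition's ColumnarBatches are compressed into a ChunkedByteBuffer
// together with a row count, collected to the driver, and decoded back via an Arrow reader.
val cometPlan: CometExec = ??? // hypothetical Comet native plan node

val countsAndBytes = cometPlan.getByteArrayRdd().collect() // Array[(Long, ChunkedByteBuffer)]
val totalRows = countsAndBytes.map(_._1).sum
val batches = countsAndBytes.iterator.flatMap { case (_, bytes) =>
  CometExec.decodeBatches(bytes) // Iterator[ColumnarBatch]
}

// The convenience wrapper kept on CometExec does the same thing:
val (numRows, iter) = cometPlan.executeColumnarCollectIterator()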