feat: Add COMET_SHUFFLE_MODE config to control Comet shuffle mode
viirya committed May 22, 2024
1 parent 7b0a7e0 commit 1a17c5f
Showing 10 changed files with 167 additions and 177 deletions.
21 changes: 13 additions & 8 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -19,6 +19,7 @@

package org.apache.comet

import java.util.Locale
import java.util.concurrent.TimeUnit

import scala.collection.mutable.ListBuffer
@@ -131,14 +132,18 @@ object CometConf {
.booleanConf
.createWithDefault(false)

val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.columnar.shuffle.enabled")
.doc(
"Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. " +
"If this is enabled, Comet prefers columnar shuffle than native shuffle. " +
"By default, this config is true.")
.booleanConf
.createWithDefault(true)
val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode")
.doc(
"The mode of Comet shuffle. This config is effective only if Comet shuffle " +
"is enabled. Available modes are 'native', 'jvm', and 'auto'. " +
"'native' is for native shuffle, which has the best performance in general. " +
"'jvm' is for JVM-based columnar shuffle, which has higher coverage than native shuffle. " +
"'auto' lets Comet choose the best shuffle mode based on the query plan. " +
"By default, this config is 'jvm'.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
.checkValues(Set("native", "jvm", "auto"))
.createWithDefault("jvm")

val COMET_SHUFFLE_ENFORCE_MODE_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.shuffle.enforceMode.enabled")
196 changes: 89 additions & 107 deletions spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf._
import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos}
import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleMode, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
@@ -194,30 +194,6 @@ class CometSparkSessionExtensions
}

case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case s: ShuffleExchangeExec
if isCometPlan(s.child) && !isCometColumnarShuffleEnabled(conf) &&
QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 =>
logInfo("Comet extension enabled for Native Shuffle")

// Switch to use Decimal128 regardless of precision, since Arrow native execution
// doesn't support Decimal32 and Decimal64 yet.
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)

// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
// (if configured)
case s: ShuffleExchangeExec
if (!s.child.supportsColumnar || isCometPlan(
s.child)) && isCometColumnarShuffleEnabled(conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
}
}

private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]

private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
@@ -641,7 +617,7 @@ class CometSparkSessionExtensions
// Native shuffle for Comet operators
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) &&
!isCometColumnarShuffleEnabled(conf) &&
getCometShuffleMode(conf) != JVMShuffle &&
QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 =>
logInfo("Comet extension enabled for Native Shuffle")

@@ -662,7 +638,7 @@
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
// convert it to CometColumnarShuffle.
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) && isCometColumnarShuffleEnabled(conf) &&
if isCometShuffleEnabled(conf) && getCometShuffleMode(conf) != NativeShuffle &&
QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")
@@ -684,19 +660,19 @@
case s: ShuffleExchangeExec =>
val isShuffleEnabled = isCometShuffleEnabled(conf)
val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
val msg1 = createMessage(!isShuffleEnabled, s"Native shuffle is not enabled: $reason")
val columnarShuffleEnabled = isCometColumnarShuffleEnabled(conf)
val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
val columnarShuffleEnabled = getCometShuffleMode(conf) == JVMShuffle
val msg2 = createMessage(
isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
.supportPartitioning(s.child.output, s.outputPartitioning)
._1,
"Shuffle: " +
"Native shuffle: " +
s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}")
val msg3 = createMessage(
isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
.supportPartitioningTypes(s.child.output)
._1,
s"Columnar shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}")
s"JVM shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}")
withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
s
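
The three conditions above are mutually exclusive (each gates on isShuffleEnabled and the mode), so at most one message survives the flatten; illustratively:

// Illustrative tag contents (placeholders, not real output); the result is one of:
// "Comet shuffle is not enabled: <reason>"
// "Native shuffle: <reason from supportPartitioning>"
// "JVM shuffle: <reason from supportPartitioningTypes>"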

@@ -726,87 +702,79 @@
// We shouldn't transform Spark query plan if Comet is disabled.
if (!isCometEnabled(conf)) return plan

if (!isCometExecEnabled(conf)) {
// Comet exec is disabled, but for Spark shuffle, we still can use Comet columnar shuffle
if (isCometShuffleEnabled(conf)) {
applyCometShuffle(plan)
} else {
plan
}
} else {
var newPlan = transform(plan)

// if the plan cannot be run fully natively then explain why (when appropriate
// config is enabled)
if (CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.get()) {
new ExtendedExplainInfo().extensionInfo(newPlan) match {
case reasons if reasons.size == 1 =>
logWarning(
"Comet cannot execute some parts of this plan natively " +
s"because ${reasons.head}")
case reasons if reasons.size > 1 =>
logWarning(
"Comet cannot execute some parts of this plan natively" +
s" because:\n\t- ${reasons.mkString("\n\t- ")}")
case _ =>
// no reasons recorded
}
}

// Remove placeholders
newPlan = newPlan.transform {
case CometSinkPlaceHolder(_, _, s) => s
case CometScanWrapper(_, s) => s
var newPlan = transform(plan)

// if the plan cannot be run fully natively then explain why (when appropriate
// config is enabled)
if (CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.get()) {
new ExtendedExplainInfo().extensionInfo(newPlan) match {
case reasons if reasons.size == 1 =>
logWarning(
"Comet cannot execute some parts of this plan natively " +
s"because ${reasons.head}")
case reasons if reasons.size > 1 =>
logWarning(
"Comet cannot execute some parts of this plan natively" +
s" because:\n\t- ${reasons.mkString("\n\t- ")}")
case _ =>
// no reasons recorded
}
}

// Set up logical links
newPlan = newPlan.transform {
case op: CometExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
// Remove placeholders
newPlan = newPlan.transform {
case CometSinkPlaceHolder(_, _, s) => s
case CometScanWrapper(_, s) => s
}

case op: CometBroadcastExchangeExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}
// Set up logical links
newPlan = newPlan.transform {
case op: CometExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

// Convert native execution block by linking consecutive native operators.
var firstNativeOp = true
newPlan.transformDown {
case op: CometNativeExec =>
if (firstNativeOp) {
firstNativeOp = false
op.convertBlock()
} else {
op
}
case op =>
firstNativeOp = true
case op: CometBroadcastExchangeExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}

// Convert native execution block by linking consecutive native operators.
var firstNativeOp = true
newPlan.transformDown {
case op: CometNativeExec =>
if (firstNativeOp) {
firstNativeOp = false
op.convertBlock()
} else {
op
}
}
case op =>
firstNativeOp = true
op
}
}

@@ -966,8 +934,16 @@ object CometSparkSessionExtensions extends Logging {
COMET_EXEC_ENABLED.get(conf)
}

private[comet] def getCometShuffleMode(conf: SQLConf): CometShuffleType = {
COMET_SHUFFLE_MODE.get(conf) match {
case "jvm" => JVMShuffle
case "native" => NativeShuffle
case _ => AutoShuffle
}
}
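
The two ShuffleExchangeExec cases in CometExecRule gate on this mode; a condensed sketch of that gating (the helper names here are hypothetical, not in this commit):

// Hypothetical helpers mirroring the guards used in CometExecRule:
// native shuffle is considered unless the mode is 'jvm'
def allowsNativeShuffle(conf: SQLConf): Boolean =
getCometShuffleMode(conf) != JVMShuffle
// JVM columnar shuffle is considered unless the mode is 'native'
def allowsJvmShuffle(conf: SQLConf): Boolean =
getCometShuffleMode(conf) != NativeShuffle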

private[comet] def isCometColumnarShuffleEnabled(conf: SQLConf): Boolean = {
COMET_COLUMNAR_SHUFFLE_ENABLED.get(conf)
COMET_SHUFFLE_MODE.get(conf).equalsIgnoreCase("jvm")
}

private[comet] def isCometAllOperatorEnabled(conf: SQLConf): Boolean = {
@@ -1138,3 +1114,9 @@
}
}
}

sealed abstract class CometShuffleType

case object AutoShuffle extends CometShuffleType
case object JVMShuffle extends CometShuffleType
case object NativeShuffle extends CometShuffleType
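
Since the hierarchy is sealed, matches over CometShuffleType are checked for exhaustiveness by the compiler; a minimal sketch (the function name is hypothetical):

// Hypothetical helper; omitting a case triggers a non-exhaustive match warning.
def describe(mode: CometShuffleType): String = mode match {
case NativeShuffle => "native shuffle"           // best performance in general
case JVMShuffle => "JVM-based columnar shuffle"  // higher coverage
case AutoShuffle => "chosen automatically from the query plan"
}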