Skip to content

Commit

Permalink
Revert "fix: Remove original plan parameter from CometNativeExec"
Browse files Browse the repository at this point in the history
This reverts commit b272551.
  • Loading branch information
viirya committed Jul 8, 2024
1 parent b272551 commit 722dc07
Show file tree
Hide file tree
Showing 145 changed files with 2,285 additions and 2,312 deletions.
178 changes: 60 additions & 118 deletions spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,13 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan = CometProjectExec(
CometProjectExec(
nativeOp,
op.output,
op,
op.projectList,
op.output,
op.child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -343,9 +343,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan =
CometFilterExec(nativeOp, op.output, op.condition, op.child, SerializedPlan(None))
setLogicalLink(newPlan, op)
CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None))
case None =>
op
}
Expand All @@ -354,15 +352,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan =
CometSortExec(
nativeOp,
op.output,
op.outputOrdering,
op.sortOrder,
op.child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None))
case None =>
op
}
Expand All @@ -371,9 +361,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan =
CometLocalLimitExec(nativeOp, op.limit, op.child, SerializedPlan(None))
setLogicalLink(newPlan, op)
CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
case None =>
op
}
Expand All @@ -382,9 +370,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan =
CometGlobalLimitExec(nativeOp, op.limit, op.child, SerializedPlan(None))
setLogicalLink(newPlan, op)
CometGlobalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
case None =>
op
}
Expand All @@ -396,9 +382,8 @@ class CometSparkSessionExtensions
QueryPlanSerde.operator2Proto(op) match {
case Some(nativeOp) =>
val offset = getOffset(op)
val newPlan =
CometCollectLimitExec(op.limit, offset, op.child)
val cometOp = setLogicalLink(newPlan, op)
val cometOp =
CometCollectLimitExec(op, op.limit, offset, op.child)
CometSinkPlaceHolder(nativeOp, op, cometOp)
case None =>
op
Expand All @@ -408,28 +393,12 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan =
CometExpandExec(
nativeOp,
op.output,
op.projections,
op.child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None))
case None =>
op
}

case op @ HashAggregateExec(
_,
_,
_,
groupingExprs,
aggExprs,
_,
_,
resultExpressions,
child) =>
case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) =>
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
Expand All @@ -453,17 +422,15 @@ class CometSparkSessionExtensions
// modes is empty too. If aggExprs is not empty, we need to verify all the
// aggregates have the same mode.
assert(modes.length == 1 || modes.length == 0)
val newPlan = CometHashAggregateExec(
CometHashAggregateExec(
nativeOp,
op.output,
op,
groupingExprs,
aggExprs,
resultExpressions,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -476,10 +443,9 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan = CometHashJoinExec(
CometHashJoinExec(
nativeOp,
op.output,
op.outputOrdering,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
Expand All @@ -488,7 +454,6 @@ class CometSparkSessionExtensions
op.left,
op.right,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -510,10 +475,9 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan = CometBroadcastHashJoinExec(
CometBroadcastHashJoinExec(
nativeOp,
op.output,
op.outputOrdering,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
Expand All @@ -522,7 +486,6 @@ class CometSparkSessionExtensions
op.left,
op.right,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -533,18 +496,16 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val newPlan = CometSortMergeJoinExec(
CometSortMergeJoinExec(
nativeOp,
op.output,
op.outputOrdering,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.left,
op.right,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand Down Expand Up @@ -574,8 +535,7 @@ class CometSparkSessionExtensions
&& isCometNative(child) =>
QueryPlanSerde.operator2Proto(c) match {
case Some(nativeOp) =>
val newPlan = CometCoalesceExec(c.output, numPartitions, child)
val cometOp = setLogicalLink(newPlan, c)
val cometOp = CometCoalesceExec(c, numPartitions, child)
CometSinkPlaceHolder(nativeOp, c, cometOp)
case None =>
c
Expand All @@ -598,14 +558,8 @@ class CometSparkSessionExtensions
CometTakeOrderedAndProjectExec.isSupported(s) =>
QueryPlanSerde.operator2Proto(s) match {
case Some(nativeOp) =>
val newPlan =
CometTakeOrderedAndProjectExec(
s.output,
s.limit,
s.sortOrder,
s.projectList,
s.child)
val cometOp = setLogicalLink(newPlan, s)
val cometOp =
CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child)
CometSinkPlaceHolder(nativeOp, s, cometOp)
case None =>
s
Expand All @@ -625,14 +579,8 @@ class CometSparkSessionExtensions
val newOp = transform1(w)
newOp match {
case Some(nativeOp) =>
val newPlan =
CometWindowExec(
w.output,
w.windowExpression,
w.partitionSpec,
w.orderSpec,
w.child)
val cometOp = setLogicalLink(newPlan, w)
val cometOp =
CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
Expand All @@ -643,8 +591,7 @@ class CometSparkSessionExtensions
u.children.forall(isCometNative) =>
QueryPlanSerde.operator2Proto(u) match {
case Some(nativeOp) =>
val newPlan = CometUnionExec(u.output, u.children)
val cometOp = setLogicalLink(newPlan, u)
val cometOp = CometUnionExec(u, u.children)
CometSinkPlaceHolder(nativeOp, u, cometOp)
case None =>
u
Expand Down Expand Up @@ -684,8 +631,7 @@ class CometSparkSessionExtensions
isSpark34Plus => // Spark 3.4+ only
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val newPlan = CometBroadcastExchangeExec(b.output, b.child)
val cometOp = setLogicalLink(newPlan, b)
val cometOp = CometBroadcastExchangeExec(b, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
Expand Down Expand Up @@ -876,6 +822,40 @@ class CometSparkSessionExtensions
case CometScanWrapper(_, s) => s
}

// Set up logical links
newPlan = newPlan.transform {
case op: CometExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

case op: CometBroadcastExchangeExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}

// Convert native execution block by linking consecutive native operators.
var firstNativeOp = true
newPlan.transformDown {
Expand Down Expand Up @@ -907,44 +887,6 @@ class CometSparkSessionExtensions
}.flatten
}

/**
* Set up logical links for transformed Comet operators.
*/
def setLogicalLink(newPlan: SparkPlan, originalPlan: SparkPlan): SparkPlan = {
newPlan match {
case op: CometExec =>
if (originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

case op: CometBroadcastExchangeExec =>
if (originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}
}

/**
* Returns true if a given spark plan is Comet shuffle operator.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
Expand Down Expand Up @@ -61,9 +60,7 @@ import org.apache.comet.CometRuntimeException
* Note that this only supports Spark 3.4 and later, because the serialization class
* `ChunkedByteBuffer` is only serializable in Spark 3.4 and later.
*/
case class CometBroadcastExchangeExec(
override val output: Seq[Attribute],
override val child: SparkPlan)
case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
extends BroadcastExchangeLike
with ShimCometBroadcastExchangeExec {
import CometBroadcastExchangeExec._
Expand All @@ -77,6 +74,10 @@ case class CometBroadcastExchangeExec(
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

override def doCanonicalize(): SparkPlan = {
CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
val dataSize = metrics("dataSize").value
val rowCount = metrics("numOutputRows").value
Expand Down Expand Up @@ -236,7 +237,7 @@ case class CometBroadcastExchangeExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastExchangeExec =>
this.output == other.output &&
this.originalPlan == other.originalPlan &&
this.child == other.child
case _ =>
false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.apache.spark.sql.comet

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand All @@ -32,7 +31,7 @@ import com.google.common.base.Objects
* more efficient when including it in a Comet query plan.
*/
case class CometCoalesceExec(
override val output: Seq[Attribute],
override val originalPlan: SparkPlan,
numPartitions: Int,
child: SparkPlan)
extends CometExec
Expand Down
Loading

0 comments on commit 722dc07

Please sign in to comment.