fix: Avoid creating huge duplicate of canonicalized plans for CometNativeExec (#639)

* fix: Remove original plan parameter from CometNativeExec

* Revert "fix: Remove original plan parameter from CometNativeExec"

This reverts commit b272551.

* More

* Revert "Revert "fix: Remove original plan parameter from CometNativeExec""

This reverts commit 722dc07.

* More

* More

* Fix

* Fix diffs

* Update
viirya authored Jul 8, 2024
1 parent 8f4427a commit b924aeb
Showing 415 changed files with 6,380 additions and 6,302 deletions.
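The shape of the fix, as a minimal sketch: before this commit, Comet operators derived their output (and related plan properties) from the Spark plan they wrap, so canonicalizing a Comet node also had to build a full canonicalized copy of originalPlan. After this commit the caller passes output (and, where relevant, outputOrdering) explicitly, and the canonical form can drop originalPlan entirely. The sketch below uses a hypothetical ExampleCometExec to illustrate the pattern; the real operators extend CometExec/CometNativeExec and take more parameters.

// Illustrative sketch only; ExampleCometExec is hypothetical.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

case class ExampleCometExec(
    originalPlan: SparkPlan,              // kept for explain/metadata only
    override val output: Seq[Attribute],  // now supplied explicitly by the caller
    override val child: SparkPlan)
    extends UnaryExecNode {

  // Output no longer depends on originalPlan, so the canonical form can drop
  // the wrapped Spark plan instead of keeping a canonicalized duplicate of it.
  override def doCanonicalize(): SparkPlan =
    copy(originalPlan = null, output = Nil, child = child.canonicalized)

  override protected def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException("sketch only")

  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}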
4 changes: 2 additions & 2 deletions dev/diffs/3.4.3.diff
@@ -2491,8 +2491,8 @@ index dd55fcfe42c..293e9dc2986 100644
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
-+ case CometFilterExec(_, _, _, child, _) => child
-+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, child, _), _) => child
++ case CometFilterExec(_, _, _, _, child, _) => child
++ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
4 changes: 2 additions & 2 deletions dev/diffs/3.5.1.diff
@@ -2650,8 +2650,8 @@ index dd55fcfe42c..293e9dc2986 100644
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
-+ case CometFilterExec(_, _, _, child, _) => child
-+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, child, _), _) => child
++ case CometFilterExec(_, _, _, _, child, _) => child
++ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
4 changes: 2 additions & 2 deletions dev/diffs/4.0.0-preview1.diff
@@ -2651,8 +2651,8 @@ index 5fbf379644f..32711763ec1 100644
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
-+ case CometFilterExec(_, _, _, child, _) => child
-+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, child, _), _) => child
++ case CometFilterExec(_, _, _, _, child, _) => child
++ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
CometSparkSessionExtensions.scala
@@ -331,8 +331,8 @@ class CometSparkSessionExtensions
CometProjectExec(
nativeOp,
op,
- op.projectList,
op.output,
+ op.projectList,
op.child,
SerializedPlan(None))
case None =>
@@ -343,7 +343,13 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None))
+ CometFilterExec(
+ nativeOp,
+ op,
+ op.output,
+ op.condition,
+ op.child,
+ SerializedPlan(None))
case None =>
op
}
@@ -352,7 +358,14 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None))
+ CometSortExec(
+ nativeOp,
+ op,
+ op.output,
+ op.outputOrdering,
+ op.sortOrder,
+ op.child,
+ SerializedPlan(None))
case None =>
op
}
@@ -393,12 +406,27 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None))
+ CometExpandExec(
+ nativeOp,
+ op,
+ op.output,
+ op.projections,
+ op.child,
+ SerializedPlan(None))
case None =>
op
}

- case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) =>
+ case op @ HashAggregateExec(
+ _,
+ _,
+ _,
+ groupingExprs,
+ aggExprs,
+ _,
+ _,
+ resultExpressions,
+ child) =>
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
@@ -425,8 +453,10 @@ class CometSparkSessionExtensions
CometHashAggregateExec(
nativeOp,
op,
+ op.output,
groupingExprs,
aggExprs,
+ resultExpressions,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
@@ -446,6 +476,8 @@ class CometSparkSessionExtensions
CometHashJoinExec(
nativeOp,
op,
+ op.output,
+ op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
@@ -478,6 +510,8 @@ class CometSparkSessionExtensions
CometBroadcastHashJoinExec(
nativeOp,
op,
+ op.output,
+ op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
@@ -499,6 +533,8 @@ class CometSparkSessionExtensions
CometSortMergeJoinExec(
nativeOp,
op,
+ op.output,
+ op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
@@ -535,7 +571,7 @@ class CometSparkSessionExtensions
&& isCometNative(child) =>
QueryPlanSerde.operator2Proto(c) match {
case Some(nativeOp) =>
- val cometOp = CometCoalesceExec(c, numPartitions, child)
+ val cometOp = CometCoalesceExec(c, c.output, numPartitions, child)
CometSinkPlaceHolder(nativeOp, c, cometOp)
case None =>
c
@@ -559,7 +595,13 @@ class CometSparkSessionExtensions
QueryPlanSerde.operator2Proto(s) match {
case Some(nativeOp) =>
val cometOp =
- CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child)
+ CometTakeOrderedAndProjectExec(
+ s,
+ s.output,
+ s.limit,
+ s.sortOrder,
+ s.projectList,
+ s.child)
CometSinkPlaceHolder(nativeOp, s, cometOp)
case None =>
s
@@ -580,7 +622,13 @@ class CometSparkSessionExtensions
newOp match {
case Some(nativeOp) =>
val cometOp =
- CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
+ CometWindowExec(
+ w,
+ w.output,
+ w.windowExpression,
+ w.partitionSpec,
+ w.orderSpec,
+ w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
@@ -591,7 +639,7 @@ class CometSparkSessionExtensions
u.children.forall(isCometNative) =>
QueryPlanSerde.operator2Proto(u) match {
case Some(nativeOp) =>
- val cometOp = CometUnionExec(u, u.children)
+ val cometOp = CometUnionExec(u, u.output, u.children)
CometSinkPlaceHolder(nativeOp, u, cometOp)
case None =>
u
@@ -631,7 +679,7 @@ class CometSparkSessionExtensions
isSpark34Plus => // Spark 3.4+ only
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
- val cometOp = CometBroadcastExchangeExec(b, b.child)
+ val cometOp = CometBroadcastExchangeExec(b, b.output, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
CometBroadcastExchangeExec.scala
@@ -31,6 +31,7 @@ import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
@@ -60,7 +61,10 @@ import org.apache.comet.CometRuntimeException
* Note that this only supports Spark 3.4 and later, because the serialization class
* `ChunkedByteBuffer` is only serializable in Spark 3.4 and later.
*/
- case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
+ case class CometBroadcastExchangeExec(
+ originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
+ override val child: SparkPlan)
extends BroadcastExchangeLike
with ShimCometBroadcastExchangeExec {
import CometBroadcastExchangeExec._
@@ -75,7 +79,7 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

override def doCanonicalize(): SparkPlan = {
- CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
+ CometBroadcastExchangeExec(null, null, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
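Why the duplicated canonical plan mattered (general Spark behavior, assumed here rather than shown in this diff): Spark's reuse rules and plan-equality checks compare canonicalized plans, so every node's canonical form is materialized and held for comparison. A rough sketch of that comparison path:

import org.apache.spark.sql.execution.SparkPlan

// Sketch only: QueryPlan.sameResult compares canonical forms
// (this.canonicalized == other.canonicalized). A Comet node whose canonical
// form also carried originalPlan.canonicalized roughly doubled the tree that
// had to be built and held; the doCanonicalize above now drops it.
def plansAreEquivalent(a: SparkPlan, b: SparkPlan): Boolean =
  a.sameResult(b)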
CometCoalesceExec.scala
@@ -20,6 +20,7 @@
package org.apache.spark.sql.comet

import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,6 +33,7 @@ import com.google.common.base.Objects
*/
case class CometCoalesceExec(
override val originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
numPartitions: Int,
child: SparkPlan)
extends CometExec
CometCollectLimitExec.scala
@@ -22,6 +22,8 @@ package org.apache.spark.sql.comet
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.Attribute
+ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec.{METRIC_NATIVE_TIME_DESCRIPTION, METRIC_NATIVE_TIME_NAME}
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
@@ -45,6 +47,8 @@ case class CometCollectLimitExec(
child: SparkPlan)
extends CometExec
with UnaryExecNode {
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = SinglePartition

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
CometTakeOrderedAndProjectExec.scala
@@ -42,14 +42,13 @@ import org.apache.comet.shims.ShimCometTakeOrderedAndProjectExec
*/
case class CometTakeOrderedAndProjectExec(
override val originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
limit: Int,
sortOrder: Seq[SortOrder],
projectList: Seq[NamedExpression],
child: SparkPlan)
extends CometExec
with UnaryExecNode {
- override def output: Seq[Attribute] = projectList.map(_.toAttribute)
-
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
CometWindowExec.scala
@@ -43,6 +43,7 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, wi
*/
case class CometWindowExec(
override val originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
@@ -52,8 +53,6 @@ case class CometWindowExec(

override def nodeName: String = "CometWindowExec"

- override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute)
-
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =