fix: Avoid creating huge duplicate of canonicalized plans for CometNativeExec #639

Merged · 9 commits · Jul 8, 2024
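Context for the diff below (a sketch of my own, not code from this PR): before this change, Comet operators derived `output` and related attributes from the retained `originalPlan`, so `doCanonicalize()` had to canonicalize `originalPlan` as well, duplicating a potentially huge plan tree for every native operator. The PR instead captures `output` (and, for sorts and joins, `outputOrdering`) as explicit constructor fields, so the canonical copy can drop `originalPlan` entirely. `BeforeExec`/`AfterExec` are illustrative names, using Spark's stock `UnaryExecNode` for self-containment:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

// Before: `output` is read off originalPlan, so even the canonical copy must
// keep a full canonicalized clone of the original (possibly huge) plan.
case class BeforeExec(originalPlan: SparkPlan, child: SparkPlan) extends UnaryExecNode {
  override def output: Seq[Attribute] = originalPlan.output
  override def doCanonicalize(): SparkPlan =
    BeforeExec(originalPlan.canonicalized, child.canonicalized) // duplicates the whole tree
  override protected def doExecute(): RDD[InternalRow] = child.execute()
  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}

// After: `output` is captured once at construction, so the canonical copy can
// null out originalPlan and recurse only into the child.
case class AfterExec(
    originalPlan: SparkPlan,
    override val output: Seq[Attribute],
    child: SparkPlan)
    extends UnaryExecNode {
  override def doCanonicalize(): SparkPlan = AfterExec(null, null, child.canonicalized)
  override protected def doExecute(): RDD[InternalRow] = child.execute()
  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}
```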
dev/diffs/3.4.3.diff (4 changes: 2 additions & 2 deletions)
@@ -2491,8 +2491,8 @@ index dd55fcfe42c..293e9dc2986 100644
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
-+ case CometFilterExec(_, _, _, child, _) => child
-+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, child, _), _) => child
++ case CometFilterExec(_, _, _, _, child, _) => child
++ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
dev/diffs/3.5.1.diff (4 changes: 2 additions & 2 deletions)
@@ -2650,8 +2650,8 @@ index dd55fcfe42c..293e9dc2986 100644
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
-+ case CometFilterExec(_, _, _, child, _) => child
-+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, child, _), _) => child
++ case CometFilterExec(_, _, _, _, child, _) => child
++ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
dev/diffs/4.0.0-preview1.diff (4 changes: 2 additions & 2 deletions)
@@ -2651,8 +2651,8 @@ index 5fbf379644f..32711763ec1 100644
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
-+ case CometFilterExec(_, _, _, child, _) => child
-+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, child, _), _) => child
++ case CometFilterExec(_, _, _, _, child, _) => child
++ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
@@ -331,8 +331,8 @@ class CometSparkSessionExtensions
CometProjectExec(
nativeOp,
op,
- op.projectList,
op.output,
+ op.projectList,
op.child,
SerializedPlan(None))
case None =>
@@ -343,7 +343,13 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None))
+ CometFilterExec(
+   nativeOp,
+   op,
+   op.output,
+   op.condition,
+   op.child,
+   SerializedPlan(None))
case None =>
op
}
@@ -352,7 +358,14 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None))
+ CometSortExec(
+   nativeOp,
+   op,
+   op.output,
+   op.outputOrdering,
+   op.sortOrder,
+   op.child,
+   SerializedPlan(None))
case None =>
op
}
@@ -393,12 +406,27 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None))
+ CometExpandExec(
+   nativeOp,
+   op,
+   op.output,
+   op.projections,
+   op.child,
+   SerializedPlan(None))
case None =>
op
}

- case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) =>
+ case op @ HashAggregateExec(
+     _,
+     _,
+     _,
+     groupingExprs,
+     aggExprs,
+     _,
+     _,
+     resultExpressions,
+     child) =>
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
@@ -425,8 +453,10 @@ class CometSparkSessionExtensions
CometHashAggregateExec(
nativeOp,
op,
+ op.output,
groupingExprs,
aggExprs,
+ resultExpressions,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
@@ -446,6 +476,8 @@ class CometSparkSessionExtensions
CometHashJoinExec(
nativeOp,
op,
+ op.output,
+ op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
@@ -478,6 +510,8 @@ class CometSparkSessionExtensions
CometBroadcastHashJoinExec(
nativeOp,
op,
+ op.output,
+ op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
@@ -499,6 +533,8 @@ class CometSparkSessionExtensions
CometSortMergeJoinExec(
nativeOp,
op,
+ op.output,
+ op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
@@ -535,7 +571,7 @@ class CometSparkSessionExtensions
&& isCometNative(child) =>
QueryPlanSerde.operator2Proto(c) match {
case Some(nativeOp) =>
- val cometOp = CometCoalesceExec(c, numPartitions, child)
+ val cometOp = CometCoalesceExec(c, c.output, numPartitions, child)
CometSinkPlaceHolder(nativeOp, c, cometOp)
case None =>
c
@@ -559,7 +595,13 @@ class CometSparkSessionExtensions
QueryPlanSerde.operator2Proto(s) match {
case Some(nativeOp) =>
val cometOp =
- CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child)
+ CometTakeOrderedAndProjectExec(
+   s,
+   s.output,
+   s.limit,
+   s.sortOrder,
+   s.projectList,
+   s.child)
CometSinkPlaceHolder(nativeOp, s, cometOp)
case None =>
s
@@ -580,7 +622,13 @@ class CometSparkSessionExtensions
newOp match {
case Some(nativeOp) =>
val cometOp =
- CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
+ CometWindowExec(
+   w,
+   w.output,
+   w.windowExpression,
+   w.partitionSpec,
+   w.orderSpec,
+   w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
@@ -591,7 +639,7 @@ class CometSparkSessionExtensions
u.children.forall(isCometNative) =>
QueryPlanSerde.operator2Proto(u) match {
case Some(nativeOp) =>
- val cometOp = CometUnionExec(u, u.children)
+ val cometOp = CometUnionExec(u, u.output, u.children)
CometSinkPlaceHolder(nativeOp, u, cometOp)
case None =>
u
@@ -631,7 +679,7 @@ class CometSparkSessionExtensions
isSpark34Plus => // Spark 3.4+ only
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
- val cometOp = CometBroadcastExchangeExec(b, b.child)
+ val cometOp = CometBroadcastExchangeExec(b, b.output, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
@@ -31,6 +31,7 @@ import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
@@ -60,7 +61,10 @@ import org.apache.comet.CometRuntimeException
* Note that this only supports Spark 3.4 and later, because the serialization class
* `ChunkedByteBuffer` is only serializable in Spark 3.4 and later.
*/
- case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
+ case class CometBroadcastExchangeExec(
+     originalPlan: SparkPlan,
+     override val output: Seq[Attribute],
+     override val child: SparkPlan)
extends BroadcastExchangeLike
with ShimCometBroadcastExchangeExec {
import CometBroadcastExchangeExec._
@@ -75,7 +79,7 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

override def doCanonicalize(): SparkPlan = {
-   CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
+   CometBroadcastExchangeExec(null, null, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
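A note on the null placeholders above (my reading, not stated in the PR): canonicalized plans are used by Spark only for semantic comparison, for example when deciding exchange reuse, and are never executed, so the canonical copy can drop originalPlan and output without affecting correctness. A minimal sketch against the standard SparkPlan API:

```scala
import org.apache.spark.sql.execution.SparkPlan

// Canonicalized plans are compared, not executed, so fields that do not
// affect plan semantics can be nulled out in the canonical copy.
def plansMatch(a: SparkPlan, b: SparkPlan): Boolean =
  a.canonicalized == b.canonicalized // essentially what a.sameResult(b) checks
```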
@@ -20,6 +20,7 @@
package org.apache.spark.sql.comet

import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,6 +33,7 @@ import com.google.common.base.Objects
*/
case class CometCoalesceExec(
override val originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
numPartitions: Int,
child: SparkPlan)
extends CometExec
@@ -22,6 +22,8 @@ package org.apache.spark.sql.comet
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.Attribute
+ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec.{METRIC_NATIVE_TIME_DESCRIPTION, METRIC_NATIVE_TIME_NAME}
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
@@ -45,6 +47,8 @@ case class CometCollectLimitExec(
child: SparkPlan)
extends CometExec
with UnaryExecNode {
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = SinglePartition

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
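Note that CometCollectLimitExec takes the opposite approach from the operators above: output stays a derived def rather than a constructor field. Deriving it from child.output (instead of, presumably, originalPlan.output before this change) is enough here, since the child survives canonicalization while originalPlan does not.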
@@ -42,14 +42,13 @@ import org.apache.comet.shims.ShimCometTakeOrderedAndProjectExec
*/
case class CometTakeOrderedAndProjectExec(
override val originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
limit: Int,
sortOrder: Seq[SortOrder],
projectList: Seq[NamedExpression],
child: SparkPlan)
extends CometExec
with UnaryExecNode {
- override def output: Seq[Attribute] = projectList.map(_.toAttribute)

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
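This move is behavior-preserving at the call site shown earlier: Spark's TakeOrderedAndProjectExec defines output as projectList.map(_.toAttribute), which is exactly the value s.output now passes into the constructor.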
@@ -43,6 +43,7 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, wi
*/
case class CometWindowExec(
override val originalPlan: SparkPlan,
+ override val output: Seq[Attribute],
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
@@ -52,8 +53,6 @@ case class CometWindowExec(

override def nodeName: String = "CometWindowExec"

- override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute)

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
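Likewise for CometWindowExec: Spark's WindowExec computes output as child.output ++ windowExpression.map(_.toAttribute), so the w.output now passed at the call site matches what the removed def computed, while no longer requiring the operator to recompute it from retained plan state.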