diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index f2aba74a0..10c332801 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -237,7 +237,13 @@ class CometSparkSessionExtensions
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>
-            CometProjectExec(nativeOp, op, op.projectList, op.output, op.child, None)
+            CometProjectExec(
+              nativeOp,
+              op,
+              op.projectList,
+              op.output,
+              op.child,
+              SerializedPlan(None))
           case None =>
             op
         }
@@ -246,7 +252,7 @@ class CometSparkSessionExtensions
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>
-            CometFilterExec(nativeOp, op, op.condition, op.child, None)
+            CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None))
           case None =>
             op
         }
@@ -255,7 +261,7 @@ class CometSparkSessionExtensions
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>
-            CometSortExec(nativeOp, op, op.sortOrder, op.child, None)
+            CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None))
           case None =>
             op
         }
@@ -264,7 +270,7 @@ class CometSparkSessionExtensions
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>
-            CometLocalLimitExec(nativeOp, op, op.limit, op.child, None)
+            CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
           case None =>
             op
         }
@@ -273,7 +279,7 @@ class CometSparkSessionExtensions
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>
-            CometGlobalLimitExec(nativeOp, op, op.limit, op.child, None)
+            CometGlobalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
           case None =>
             op
         }
@@ -282,7 +288,7 @@ class CometSparkSessionExtensions
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>
-            CometExpandExec(nativeOp, op, op.projections, op.child, None)
+            CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None))
           case None =>
             op
         }
@@ -305,7 +311,7 @@ class CometSparkSessionExtensions
               child.output,
               if (modes.nonEmpty) Some(modes.head) else None,
               child,
-              None)
+              SerializedPlan(None))
           case None =>
             op
         }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0298bc643..e75f9a4a5 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -150,7 +150,7 @@ abstract class CometNativeExec extends CometExec {
    * The serialized native query plan, optional. This is only defined when the current node is the
    * "boundary" node between native and Spark.
    */
-  def serializedPlanOpt: Option[Array[Byte]]
+  def serializedPlanOpt: SerializedPlan
 
   /** The Comet native operator */
   def nativeOp: Operator
@@ -200,7 +200,7 @@ abstract class CometNativeExec extends CometExec {
   }
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    serializedPlanOpt match {
+    serializedPlanOpt.plan match {
       case None =>
         // This is in the middle of a native execution, it should not be executed directly.
        throw new CometRuntimeException(
@@ -282,11 +282,11 @@ abstract class CometNativeExec extends CometExec {
    */
  def convertBlock(): CometNativeExec = {
    def transform(arg: Any): AnyRef = arg match {
-      case serializedPlan: Option[Array[Byte]] if serializedPlan.isEmpty =>
+      case serializedPlan: SerializedPlan if serializedPlan.isEmpty =>
        val out = new ByteArrayOutputStream()
        nativeOp.writeTo(out)
        out.close()
-        Some(out.toByteArray)
+        SerializedPlan(Some(out.toByteArray))
      case other: AnyRef => other
      case null => null
    }
@@ -300,8 +300,8 @@ abstract class CometNativeExec extends CometExec {
    */
  def cleanBlock(): CometNativeExec = {
    def transform(arg: Any): AnyRef = arg match {
-      case serializedPlan: Option[Array[Byte]] if serializedPlan.isDefined =>
-        None
+      case serializedPlan: SerializedPlan if serializedPlan.isDefined =>
+        SerializedPlan(None)
      case other: AnyRef => other
      case null => null
    }
@@ -323,13 +323,23 @@ abstract class CometNativeExec extends CometExec {
 
 abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode
 
+/**
+ * Represents the serialized plan of Comet native operators. Only the first operator in a block of
+ * contiguous Comet native operators has defined plan bytes, which contain the serialization of
+ * the block's plan tree.
+ */
+case class SerializedPlan(plan: Option[Array[Byte]]) {
+  def isDefined: Boolean = plan.isDefined
+  def isEmpty: Boolean = plan.isEmpty
+}
+
 case class CometProjectExec(
     override val nativeOp: Operator,
     override val originalPlan: SparkPlan,
     projectList: Seq[NamedExpression],
     override val output: Seq[Attribute],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override def producedAttributes: AttributeSet = outputSet
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -356,7 +366,7 @@ case class CometFilterExec(
     override val originalPlan: SparkPlan,
     condition: Expression,
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -390,7 +400,7 @@ case class CometSortExec(
     override val originalPlan: SparkPlan,
     sortOrder: Seq[SortOrder],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -422,7 +432,7 @@ case class CometLocalLimitExec(
     override val originalPlan: SparkPlan,
     limit: Int,
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -449,7 +459,7 @@ case class CometGlobalLimitExec(
     override val originalPlan: SparkPlan,
     limit: Int,
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -474,7 +484,7 @@ case class CometExpandExec(
     override val originalPlan: SparkPlan,
     projections: Seq[Seq[Expression]],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override def producedAttributes: AttributeSet = outputSet
 
@@ -538,7 +548,7 @@ case class CometHashAggregateExec(
     input: Seq[Attribute],
     mode: Option[AggregateMode],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -576,7 +586,7 @@ case class CometHashAggregateExec(
 case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
     extends CometNativeExec
     with LeafExecNode {
-  override val serializedPlanOpt: Option[Array[Byte]] = None
+  override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
   override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, originalPlan)
 }
 
@@ -592,7 +602,7 @@ case class CometSinkPlaceHolder(
     override val originalPlan: SparkPlan,
     child: SparkPlan)
     extends CometUnaryExec {
-  override val serializedPlanOpt: Option[Array[Byte]] = None
+  override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
     this.copy(child = newChild)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0414671c2..6dafb2792 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,13 +31,14 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr}
+import org.apache.spark.sql.functions.{date_add, expr, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
@@ -56,6 +57,19 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("Fix corrupted AggregateMode when transforming plan parameters") {
+    withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
+      val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
+      val agg = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
+        case s: CometHashAggregateExec => s
+      }.get
+
+      assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
+      val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
+      assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
+    }
+  }
+
   test("CometBroadcastExchangeExec") {
     withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
       withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
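
Why a dedicated SerializedPlan wrapper rather than Option[Array[Byte]]: the convertBlock()/cleanBlock() transforms above pattern-match on the runtime class of each constructor parameter, and JVM type erasure makes a match on Option[Array[Byte]] accept any Option. Other Option-typed parameters, such as CometHashAggregateExec's mode: Option[AggregateMode], could therefore be silently overwritten, which is the corruption the new test checks for. The following is a minimal standalone sketch of that erasure pitfall, not part of the patch; the Option[String] field is only a stand-in for Option[AggregateMode].

object ErasureDemo extends App {
  // Generic parameter rewrite in the style of the old convertBlock():
  // fill any empty Option[Array[Byte]] slot with serialized plan bytes.
  def rewrite(arg: Any): Any = arg match {
    case opt: Option[Array[Byte]] if opt.isEmpty => Some("plan bytes".getBytes)
    case other => other
  }

  val planSlot: Option[Array[Byte]] = None // the parameter we meant to rewrite
  val modeSlot: Option[String] = None      // stand-in for mode: Option[AggregateMode]

  println(rewrite(planSlot)) // Some(...) -- intended
  println(rewrite(modeSlot)) // Some(...) -- also rewritten, because the type argument is erased
}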