
fix: Fix corrupted AggregateMode when transforming plan parameters #118

Merged (5 commits) on Feb 28, 2024

Changes from 3 commits
28 changes: 24 additions & 4 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -149,6 +149,9 @@ abstract class CometNativeExec extends CometExec {
/**
* The serialized native query plan, optional. This is only defined when the current node is the
* "boundary" node between native and Spark.
*
* Note that derived classes of `CometNativeExec` must have `serializedPlanOpt` as the last product
* parameter.
*/
def serializedPlanOpt: Option[Array[Byte]]
Contributor

I think a better approach might be to define a new case class to hold the serialized plan, which could then be pattern matched, such as:

case class SerializedPlan(val bytes: Array[Byte])

or

case class SerializedPlan(val bytesOpt: Option[Array[Byte]])

Contributor

@viirya what about this approach?

When the bytes are held in a case class, it can be pattern matched.
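
For illustration, a minimal, self-contained sketch of why the wrapper helps (only `SerializedPlan` comes from the suggestion above; the helper names here are made up): the element type of an `Option` is erased at runtime, so a match on `Option[Array[Byte]]` also fires for other `Option`-typed constructor parameters such as the aggregate mode, while a dedicated case class is matched exactly.

```scala
// Standalone sketch, assuming the SerializedPlan wrapper proposed above.
case class SerializedPlan(bytesOpt: Option[Array[Byte]])

// Erased match, mirroring the current cleanBlock(): the Array[Byte] type argument
// is erased at runtime, so this case matches ANY defined Option -- including a
// Some(aggregateMode) parameter, which is how the mode gets corrupted to None.
def cleanErased(arg: Any): Any = arg match {
  case opt: Option[Array[Byte]] if opt.isDefined => None
  case other => other
}

// Wrapper match: only the dedicated case class is cleared; every other
// Option-typed constructor parameter passes through untouched.
def cleanWrapped(arg: Any): Any = arg match {
  case SerializedPlan(Some(_)) => SerializedPlan(None)
  case other => other
}

// cleanErased(Some("not a plan"))  returns None                 (the corruption)
// cleanWrapped(Some("not a plan")) returns Some("not a plan")   (unchanged)
```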

Member Author

I think this approach should work. I'm just trying a possible reflection-based approach first, if there is one. If it doesn't work, I will add a wrapper case class like that.


@@ -276,13 +279,30 @@ abstract class CometNativeExec extends CometExec {
}
}

/**
* Maps over the product elements except the last one. The last element will be transformed using
* the provided function. This is used to transform the `serializedPlanOpt` parameter in case
* classes of Comet native operators, where `serializedPlanOpt` is always the last product
* element. That is because we cannot match `Option[Array[Byte]]` directly due to type erasure.
*/
private def mapProduct(f: Any => AnyRef): Array[AnyRef] = {
val arr = Array.ofDim[AnyRef](productArity)
var i = 0
while (i < arr.length - 1) {
arr(i) = productElement(i).asInstanceOf[AnyRef]
i += 1
}
arr(arr.length - 1) = f(productElement(arr.length - 1))
arr
}

/**
* Converts this native Comet operator and its children into a native block which can be
* executed as a whole (i.e., in a single JNI call) from the native side.
*/
def convertBlock(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
case serializedPlan: Option[Array[Byte]] if serializedPlan.isEmpty =>
case serializedPlan: Option[_] if serializedPlan.isEmpty =>
val out = new ByteArrayOutputStream()
nativeOp.writeTo(out)
out.close()
@@ -291,7 +311,7 @@
case null => null
}

val newArgs = mapProductIterator(transform)
val newArgs = mapProduct(transform)
makeCopy(newArgs).asInstanceOf[CometNativeExec]
}

@@ -300,13 +320,13 @@
*/
def cleanBlock(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
case serializedPlan: Option[Array[Byte]] if serializedPlan.isDefined =>
case serializedPlan: Option[_] if serializedPlan.isDefined =>
None
case other: AnyRef => other
case null => null
}

val newArgs = mapProductIterator(transform)
val newArgs = mapProduct(transform)
makeCopy(newArgs).asInstanceOf[CometNativeExec]
}

16 changes: 15 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,13 +31,14 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.functions.{date_add, expr}
import org.apache.spark.sql.functions.{date_add, expr, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String
@@ -56,6 +57,19 @@ class CometExecSuite extends CometTestBase {
}
}

test("Fix corrupted AggregateMode when transforming plan parameters") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
val agg = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
case s: CometHashAggregateExec => s
}.get

assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
}
}

test("CometBroadcastExchangeExec") {
withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {