
fix: Fix corrupted AggregateMode when transforming plan parameters #118

Merged · 5 commits · Feb 28, 2024
Changes from 4 commits
29 changes: 25 additions & 4 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -149,6 +149,9 @@ abstract class CometNativeExec extends CometExec {
/**
* The serialized native query plan, optional. This is only defined when the current node is the
* "boundary" node between native and Spark.
*
* Note that derived classes of `CometNativeExec` must have `serializedPlanOpt` as the last product
Contributor:

I'm not sure about this change. It seems like a tight constraint, and currently there's no way to enforce it.

It's possible for developers to add a new CometNativeExec without serializedPlanOpt, or to put it in a different order.

Member:

Alternatively, we could use reflection to look for the serializedPlanOpt field in the product and only apply the map function to it. With this, we wouldn't have to enforce the ordering.

Member Author (@viirya), Feb 27, 2024:

Although you can find the serializedPlanOpt field among all declared fields of the product, you don't know its position in the constructor (the order of getDeclaredFields is not guaranteed). We need these arguments in the exact order so we can reconstruct the product by calling makeCopy.

Member:

Can't we just iterate over the product and check each field one by one? If it is not serializedPlanOpt, we simply copy it to the output array; if it is, we apply the function.

Member Author:

Hmm, None is a bit of trouble, since there is no inner value to inspect via reflection (it simply doesn't have one).

Member Author:

Yeah, it seems getDeclaredFields doesn't guarantee the order of the returned fields. It is a bit dangerous to assume its order is the same as the product elements.
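
A minimal, stand-alone sketch of this ordering concern, assuming a made-up case class `Node` (not a Comet operator): the JDK documents that `getDeclaredFields` returns fields in no particular order, so a reflected field index cannot be assumed to match the product/constructor position that `makeCopy` needs.

```scala
// Illustration only: Node is hypothetical, not an actual Comet operator.
case class Node(a: Int, serializedPlanOpt: Option[Array[Byte]], c: String)

object FieldOrderCheck extends App {
  // getDeclaredFields makes no ordering guarantee, so any agreement with
  // the product order below is incidental rather than something to rely on.
  val reflectedOrder = classOf[Node].getDeclaredFields.map(_.getName).toSeq
  val productOrder = Seq("a", "serializedPlanOpt", "c")
  println(s"reflected order: $reflectedOrder")
  println(s"product order:   $productOrder")
}
```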

Member:

I see, thanks for looking into this! I think the current approach is OK too, as long as we are aware that serializedPlanOpt should be declared as the last field.

Member:

Not sure if this is useful: Scala 2.13 added productElementNames, but we still need to handle 2.12.
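
A hedged sketch of how productElementNames could be used on Scala 2.13 only (the `Example` class is hypothetical): it pairs element names with `productIterator`, so the `serializedPlanOpt` element can be found by name rather than by position.

```scala
// Scala 2.13+ only; Example is a made-up case class, not an actual Comet operator.
case class Example(child: String, serializedPlanOpt: Option[Array[Byte]])

object ProductNamesDemo extends App {
  val node = Example("scan", None)
  val newArgs: Array[AnyRef] = node.productElementNames
    .zip(node.productIterator)
    .map {
      case ("serializedPlanOpt", _) => Some(Array[Byte](1, 2, 3)) // transform only this element
      case (_, value)               => value.asInstanceOf[AnyRef] // copy everything else as-is
    }
    .toArray
  println(newArgs.mkString(", "))
}
```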

Member Author:

I couldn't make the reflection approach work. Instead of spending more time on it, I used the wrapper case class approach.

Member:

OK np. Thanks for trying it.

* parameter.
*/
def serializedPlanOpt: Option[Array[Byte]]
Contributor:

I think a possibly better approach would be to define a new case class to hold the serialized plan, which could then be pattern matched, such as:

case class SerializedPlan(val bytes: Array[Byte])

or

case class SerializedPlan(val bytesOpt: Option[Array[Byte]])

Contributor:

@viirya what about this approach?

When the bytes are held in a case class, it can be pattern matched.

Member Author:

I think this approach should work. I'm just trying a possible reflection approach first; if that doesn't work, I will add a wrapper case class like that.
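
A brief sketch of the wrapper idea suggested above, with illustrative names (not necessarily the final API): because `SerializedPlan` is a concrete runtime class, the pattern match no longer depends on the erased type argument of `Option`, and other Option-typed fields pass through untouched.

```scala
// Illustrative wrapper; field and object names are assumptions, not the merged Comet API.
case class SerializedPlan(bytesOpt: Option[Array[Byte]])

object WrapperDemo extends App {
  def transform(planBytes: Array[Byte])(arg: Any): AnyRef = arg match {
    case SerializedPlan(None) => SerializedPlan(Some(planBytes)) // fill in the serialized plan
    case other: AnyRef        => other                           // e.g. Option[AggregateMode] is untouched
    case null                 => null
  }

  println(transform(Array[Byte](42))(SerializedPlan(None))) // bytes filled in
  println(transform(Array[Byte](42))(Some("Partial")))      // unrelated Option passes through
}
```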


@@ -276,13 +279,31 @@ abstract class CometNativeExec extends CometExec {
}
}

/**
* Copies product elements to the output array except the last one. The last element is
* transformed using the provided function. This is used to transform the `serializedPlanOpt`
* parameter in case classes of Comet native operators, where `serializedPlanOpt` is always
* the last product element. That is because we cannot match `Option[Array[Byte]]` due to type
* erasure.
*/
private def mapProduct(f: Any => AnyRef): Array[AnyRef] = {
val arr = Array.ofDim[AnyRef](productArity)
var i = 0
while (i < arr.length - 1) {
arr(i) = productElement(i).asInstanceOf[AnyRef]
i += 1
}
arr(arr.length - 1) = f(productElement(arr.length - 1))
arr
}

/**
* Converts this native Comet operator and its children into a native block which can be
* executed as a whole (i.e., in a single JNI call) from the native side.
*/
def convertBlock(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
case serializedPlan: Option[Array[Byte]] if serializedPlan.isEmpty =>
case serializedPlan: Option[_] if serializedPlan.isEmpty =>
val out = new ByteArrayOutputStream()
nativeOp.writeTo(out)
out.close()
@@ -291,7 +312,7 @@
case null => null
}

val newArgs = mapProductIterator(transform)
val newArgs = mapProduct(transform)
makeCopy(newArgs).asInstanceOf[CometNativeExec]
}

@@ -300,13 +321,13 @@
*/
def cleanBlock(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
case serializedPlan: Option[Array[Byte]] if serializedPlan.isDefined =>
case serializedPlan: Option[_] if serializedPlan.isDefined =>
None
case other: AnyRef => other
case null => null
}

val newArgs = mapProductIterator(transform)
val newArgs = mapProduct(transform)
makeCopy(newArgs).asInstanceOf[CometNativeExec]
}

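The doc comment on mapProduct cites type erasure as the reason `Option[Array[Byte]]` cannot be matched directly. A minimal sketch of that pitfall, with a string merely standing in for an AggregateMode value:

```scala
// After erasure, a type pattern like `case x: Option[Array[Byte]]` matches ANY Option at
// runtime, so an Option[AggregateMode] product element would also be rewritten and corrupted.
object ErasureDemo extends App {
  val modeOpt: Any = Some("Partial") // stand-in for Option[AggregateMode]

  val looksLikePlanBytes = modeOpt match {
    case _: Option[Array[Byte]] => true // compiles with an "unchecked" warning, still matches
    case _                      => false
  }

  println(looksLikePlanBytes) // prints true: the element type is not checked at runtime
}
```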
16 changes: 15 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,13 +31,14 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.functions.{date_add, expr}
import org.apache.spark.sql.functions.{date_add, expr, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String
@@ -56,6 +57,19 @@ class CometExecSuite extends CometTestBase {
}
}

test("Fix corrupted AggregateMode when transforming plan parameters") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
val agg = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
case s: CometHashAggregateExec => s
}.get

assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
}
}

test("CometBroadcastExchangeExec") {
withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {