Skip to content

Commit

Permalink
fix: SortMergeJoin with unsupported key type should fall back to Spark (
Browse files Browse the repository at this point in the history
apache#355)

* fix: SortMergeJoin with unsupported key type should fall back to Spark

* Fix

* For review

* For review
  • Loading branch information
viirya authored and Steve Vaughan Jr committed May 1, 2024
1 parent 7a8b14b commit 546204c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
26 changes: 26 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2104,6 +2104,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
expression
}

/**
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
*/
def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
true
// `TimestampNTZType` is private in Spark 3.2/3.3.
case dt if dt.typeName == "timestamp_ntz" => true
case _ => false
}

/**
* Convert a Spark plan operator to a protobuf Comet operator.
*
Expand Down Expand Up @@ -2410,6 +2422,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
return None
}

// Checks if the join keys are supported by DataFusion SortMergeJoin.
val errorMsgs = join.leftKeys.flatMap { key =>
if (!supportedSortMergeJoinEqualType(key.dataType)) {
Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
} else {
None
}
}

if (errorMsgs.nonEmpty) {
withInfo(op, errorMsgs.flatten.mkString("\n"))
return None
}

val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))

Expand Down
19 changes: 19 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ class CometJoinSuite extends CometTestBase {
}
}

test("SortMergeJoin with unsupported key type should fall back to Spark") {
withSQLConf(
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")

sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")

val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
val (sparkPlan, cometPlan) = checkSparkAnswer(df)
assert(sparkPlan.canonicalized === cometPlan.canonicalized)
}
}
}

test("Broadcast HashJoin without join filter") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
withSQLConf(
Expand Down
14 changes: 12 additions & 2 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,28 @@ abstract class CometTestBase
}
}

protected def checkSparkAnswer(query: String): Unit = {
protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = {
checkSparkAnswer(sql(query))
}

protected def checkSparkAnswer(df: => DataFrame): Unit = {
/**
* Check the answer of a Comet SQL query with Spark result.
* @param df
* The DataFrame of the query.
* @return
* A tuple of the SparkPlan of the query and the SparkPlan of the Comet query.
*/
protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = {
var expected: Array[Row] = Array.empty
var sparkPlan = null.asInstanceOf[SparkPlan]
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
expected = dfSpark.collect()
sparkPlan = dfSpark.queryExecution.executedPlan
}
val dfComet = Dataset.ofRows(spark, df.logicalPlan)
checkAnswer(dfComet, expected)
(sparkPlan, dfComet.queryExecution.executedPlan)
}

protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {
Expand Down

0 comments on commit 546204c

Please sign in to comment.