fix: SortMergeJoin with unsupported key type should fall back to Spark
viirya committed Apr 30, 2024
1 parent 21717eb commit 1181dc4
Showing 2 changed files with 36 additions and 0 deletions.
18 changes: 18 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2012,6 +2012,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
    expression
  }

  /**
   * Returns true if the given data type is supported as a key in a DataFusion sort-merge join.
   */
  def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
        _: DoubleType | _: StringType | _: DateType | _: TimestampNTZType | _: DecimalType |
        _: BooleanType =>
      true
    case _ => false
  }
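
A minimal sketch of what the new predicate accepts and rejects (assuming the usual org.apache.spark.sql.types imports; note that the zoned TimestampType is absent from the match, while TimestampNTZType is supported):

  import org.apache.spark.sql.types._

  QueryPlanSerde.supportedSortMergeJoinEqualType(LongType)         // true
  QueryPlanSerde.supportedSortMergeJoinEqualType(TimestampNTZType) // true
  QueryPlanSerde.supportedSortMergeJoinEqualType(TimestampType)    // false: unsupported, triggers fallback
  QueryPlanSerde.supportedSortMergeJoinEqualType(BinaryType)       // false: unsupported, triggers fallback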

  /**
   * Convert a Spark plan operator to a protobuf Comet operator.
   *
@@ -2318,6 +2329,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
        return None
      }

      for (key <- join.leftKeys) {
        if (!supportedSortMergeJoinEqualType(key.dataType)) {
          withInfo(op, s"Unsupported join key type ${key.dataType}")
          return None
        }
      }

      val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
      val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
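
Only join.leftKeys are inspected above. A hedged reading: for an equi-join, Spark's analyzer coerces left and right key expressions to matching data types before planning, so checking one side suffices. A defensive variant covering both sides could look like this (a sketch under that assumption):

  val allKeysSupported =
    (join.leftKeys ++ join.rightKeys).forall(k => supportedSortMergeJoinEqualType(k.dataType))
  if (!allKeysSupported) {
    withInfo(op, "Unsupported join key type")
    return None
  }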

18 changes: 18 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -40,6 +40,24 @@ class CometJoinSuite extends CometTestBase {
}
}

test("SortMergeJoin with unsupported key type should fall back to Spark") {
withSQLConf(
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")

sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")

val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
checkSparkAnswer(df)
}
}
}
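
The test pins the session time zone and disables broadcast joins so the planner picks a sort-merge join on the TIMESTAMP keys, and checkSparkAnswer only compares results against vanilla Spark. A stricter variant could also assert that the executed plan kept Spark's SortMergeJoinExec rather than a Comet operator. A sketch, assuming the suite mixes in Spark's AdaptiveSparkPlanHelper for its collect helper:

  import org.apache.spark.sql.execution.joins.SortMergeJoinExec

  // Non-empty means the join stayed on the Spark operator, i.e. the fallback happened.
  val sparkSMJ = collect(df.queryExecution.executedPlan) { case j: SortMergeJoinExec => j }
  assert(sparkSMJ.nonEmpty, "expected SortMergeJoin to fall back to Spark")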

test("Broadcast HashJoin without join filter") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
withSQLConf(