diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 0bc2c1d3c..45add0ef5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2012,6 +2012,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
     expression
   }
 
+  /**
+   * Returns true if the given data type is supported as a key in a DataFusion sort merge join.
+   */
+  def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+        _: DoubleType | _: StringType | _: DateType | _: TimestampNTZType | _: DecimalType |
+        _: BooleanType =>
+      true
+    case _ => false
+  }
+
   /**
    * Convert a Spark plan operator to a protobuf Comet operator.
    *
@@ -2318,6 +2329,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
           return None
         }
 
+        // Spark's analyzer coerces equi-join keys to matching types on both sides,
+        // so checking the left-side key types is sufficient
+        for (key <- join.leftKeys) {
+          if (!supportedSortMergeJoinEqualType(key.dataType)) {
+            withInfo(op, s"Unsupported join key type ${key.dataType}")
+            return None
+          }
+        }
+
         val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
         val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 54c0baf16..123a2fd70 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -40,6 +40,26 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  test("SortMergeJoin with unsupported key type should fall back to Spark") {
+    withSQLConf(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTable("t1", "t2") {
+        sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
+        sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")
+
+        sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
+        sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")
+
+        // TIMESTAMP is Spark's TimestampType (session time zone), which is not a
+        // supported sort merge join key type, so Comet must fall back to Spark here
+        val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
   test("Broadcast HashJoin without join filter") {
     assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
     withSQLConf(