From 1181dc48ee46e8a20cd59572c735915d57108997 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 29 Apr 2024 23:53:06 -0700 Subject: [PATCH 1/4] fix: SortMergeJoin with unsupported key type should fall back to Spark --- .../apache/comet/serde/QueryPlanSerde.scala | 18 ++++++++++++++++++ .../org/apache/comet/exec/CometJoinSuite.scala | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 0bc2c1d3c..45add0ef5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2012,6 +2012,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { expression } + /** + * Returns true if given datatype is supported as a key in DataFusion sort merge join. + */ + def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: StringType | _: DateType | _: TimestampNTZType | _: DecimalType | + _: BooleanType => + true + case _ => false + } + /** * Convert a Spark plan operator to a protobuf Comet operator. 
* @@ -2318,6 +2329,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { return None } + for (key <- join.leftKeys) { + if (!supportedSortMergeJoinEqualType(key.dataType)) { + withInfo(op, s"Unsupported join key type ${key.dataType}") + return None + } + } + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 54c0baf16..123a2fd70 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -40,6 +40,24 @@ class CometJoinSuite extends CometTestBase { } } + test("SortMergeJoin with unsupported key type should fall back to Spark") { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTable("t1", "t2") { + sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET") + sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')") + + sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET") + sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')") + + val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time") + checkSparkAnswer(df) + } + } + } + test("Broadcast HashJoin without join filter") { assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") withSQLConf( From 82c63c4cf9d11e337b12cde7365e26a2ea5a4a59 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Apr 2024 07:05:27 -0700 Subject: [PATCH 2/4] Fix --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 45add0ef5..6ebb7190f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2017,9 +2017,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { */ def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | - _: DoubleType | _: StringType | _: DateType | _: TimestampNTZType | _: DecimalType | - _: BooleanType => + _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => true + // `TimestampNTZType` is private in Spark 3.2/3.3. + case dt if dt.typeName == "timestamp_ntz" => true case _ => false } From 2d1df3999775481dc6c25319ed16b9d8721dbdd1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Apr 2024 07:52:27 -0700 Subject: [PATCH 3/4] For review --- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6ebb7190f..2bf87c423 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2332,7 +2332,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { for (key <- join.leftKeys) { if (!supportedSortMergeJoinEqualType(key.dataType)) { - withInfo(op, s"Unsupported join key type ${key.dataType}") + withInfo(op, s"Unsupported join key type ${key.dataType} on key: ${key.sql}") return None } } From e7d41eb4ab7053537fe57b28cea8cb93c9768475 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Apr 2024 09:53:15 -0700 Subject: [PATCH 4/4] For review --- .../org/apache/comet/serde/QueryPlanSerde.scala | 13 ++++++++++--- 
.../org/apache/comet/exec/CometJoinSuite.scala | 3 ++- .../scala/org/apache/spark/sql/CometTestBase.scala | 14 ++++++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 2bf87c423..496ef26d9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2330,13 +2330,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { return None } - for (key <- join.leftKeys) { + // Checks if the join keys are supported by DataFusion SortMergeJoin. + val errorMsgs = join.leftKeys.flatMap { key => if (!supportedSortMergeJoinEqualType(key.dataType)) { - withInfo(op, s"Unsupported join key type ${key.dataType} on key: ${key.sql}") - return None + Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") + } else { + None } } + if (errorMsgs.nonEmpty) { + withInfo(op, errorMsgs.mkString("\n")) + return None + } + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 123a2fd70..91d88c76e 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -53,7 +53,8 @@ class CometJoinSuite extends CometTestBase { sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')") val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time") - checkSparkAnswer(df) + val (sparkPlan, cometPlan) = checkSparkAnswer(df) + assert(sparkPlan.canonicalized === cometPlan.canonicalized) } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index ef64d666e..27428b8e9 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -127,18 +127,28 @@ abstract class CometTestBase } } - protected def checkSparkAnswer(query: String): Unit = { + protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = { checkSparkAnswer(sql(query)) } - protected def checkSparkAnswer(df: => DataFrame): Unit = { + /** + * Check the answer of a Comet SQL query with Spark result. + * @param df + * The DataFrame of the query. + * @return + * A tuple of the SparkPlan of the query and the SparkPlan of the Comet query. + */ + protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = { var expected: Array[Row] = Array.empty + var sparkPlan = null.asInstanceOf[SparkPlan] withSQLConf(CometConf.COMET_ENABLED.key -> "false") { val dfSpark = Dataset.ofRows(spark, df.logicalPlan) expected = dfSpark.collect() + sparkPlan = dfSpark.queryExecution.executedPlan } val dfComet = Dataset.ofRows(spark, df.logicalPlan) checkAnswer(dfComet, expected) + (sparkPlan, dfComet.queryExecution.executedPlan) } protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {