From 546204c3f32ff3c046589aa53772a53a95e2c7be Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Apr 2024 10:59:07 -0700 Subject: [PATCH] fix: SortMergeJoin with unsupported key type should fall back to Spark (#355) * fix: SortMergeJoin with unsupported key type should fall back to Spark * Fix * For review * For review --- .../apache/comet/serde/QueryPlanSerde.scala | 26 +++++++++++++++++++ .../apache/comet/exec/CometJoinSuite.scala | 19 ++++++++++++++ .../org/apache/spark/sql/CometTestBase.scala | 14 ++++++++-- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e1e7a7117c..6eda0547f2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2104,6 +2104,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { expression } + /** + * Returns true if given datatype is supported as a key in DataFusion sort merge join. + */ + def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => + true + // `TimestampNTZType` is private in Spark 3.2/3.3. + case dt if dt.typeName == "timestamp_ntz" => true + case _ => false + } + /** * Convert a Spark plan operator to a protobuf Comet operator. * @@ -2410,6 +2422,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { return None } + // Checks if the join keys are supported by DataFusion SortMergeJoin. + val errorMsgs = join.leftKeys.flatMap { key => + if (!supportedSortMergeJoinEqualType(key.dataType)) { + Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") + } else { + None + } + } + + if (errorMsgs.nonEmpty) { + withInfo(op, errorMsgs.flatten.mkString("\n")) + return None + } + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 54c0baf161..91d88c76e8 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -40,6 +40,25 @@ class CometJoinSuite extends CometTestBase { } } + test("SortMergeJoin with unsupported key type should fall back to Spark") { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTable("t1", "t2") { + sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET") + sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')") + + sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET") + sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')") + + val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time") + val (sparkPlan, cometPlan) = checkSparkAnswer(df) + assert(sparkPlan.canonicalized === cometPlan.canonicalized) + } + } + } + test("Broadcast HashJoin without join filter") { assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") withSQLConf( diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index ef64d666ea..27428b8e99 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -127,18 +127,28 @@ abstract class CometTestBase } } - protected def checkSparkAnswer(query: String): Unit = { + protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = { checkSparkAnswer(sql(query)) } - protected def checkSparkAnswer(df: => DataFrame): Unit = { + /** + * Check the answer of a Comet SQL query with Spark result. + * @param df + * The DataFrame of the query. + * @return + * A tuple of the SparkPlan of the query and the SparkPlan of the Comet query. + */ + protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = { var expected: Array[Row] = Array.empty + var sparkPlan = null.asInstanceOf[SparkPlan] withSQLConf(CometConf.COMET_ENABLED.key -> "false") { val dfSpark = Dataset.ofRows(spark, df.logicalPlan) expected = dfSpark.collect() + sparkPlan = dfSpark.queryExecution.executedPlan } val dfComet = Dataset.ofRows(spark, df.logicalPlan) checkAnswer(dfComet, expected) + (sparkPlan, dfComet.queryExecution.executedPlan) } protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {