From e90af0126b95fa933631d3cae26183516e76719b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 18 Mar 2024 15:25:15 -0700
Subject: [PATCH] Move tests

---
 core/src/execution/datafusion/planner.rs      |  5 +-
 .../apache/comet/exec/CometExecSuite.scala    | 86 -------------------
 .../apache/comet/exec/CometJoinSuite.scala    | 86 +++++++++++++++++++
 3 files changed, 88 insertions(+), 89 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 889fceeee..5efae8ef4 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -927,14 +927,13 @@ impl PhysicalPlanner {
                 // DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need
                 // to copy the input batch to avoid the data corruption from reusing the input
                 // batch.
-                let left = if crate::execution::datafusion::planner::can_reuse_input_batch(&left) {
+                let left = if can_reuse_input_batch(&left) {
                     Arc::new(CopyExec::new(left))
                 } else {
                     left
                 };
 
-                let right = if crate::execution::datafusion::planner::can_reuse_input_batch(&right)
-                {
+                let right = if can_reuse_input_batch(&right) {
                     Arc::new(CopyExec::new(right))
                 } else {
                     right
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 532c6cdfd..4172c7caa 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -58,92 +58,6 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
-  test("HashJoin without join filter") {
-    withSQLConf(
-      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
-      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
-        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
-          // Inner join: build left
-          val df1 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          checkSparkAnswerAndOperator(df1)
-
-          // Right join: build left
-          val df2 =
-            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          checkSparkAnswerAndOperator(df2)
-
-          // Full join: build left
-          val df3 =
-            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          checkSparkAnswerAndOperator(df3)
-
-          // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
-          // Left join with build left and right join with build right in hash join is only supported
-          // in Spark 3.5 or above. See SPARK-36612.
-          //
-          // Left join: build left
-          // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-
-          // TODO: DataFusion HashJoin doesn't support build right yet.
-          // Inner join: build right
-          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          //
-          // Left join: build right
-          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          //
-          // Right join: build right
-          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          //
-          // Full join: build right
-          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
-          //
-          // val left = sql("SELECT * FROM tbl_a")
-          // val right = sql("SELECT * FROM tbl_b")
-          //
-          // Left semi and anti joins are only supported with build right in Spark.
-          // left.join(right, left("_2") === right("_1"), "leftsemi")
-          // left.join(right, left("_2") === right("_1"), "leftanti")
-        }
-      }
-    }
-  }
-
-  test("HashJoin with join filter") {
-    withSQLConf(
-      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
-      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
-        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
-          // Inner join: build left
-          val df1 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
-                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
-          checkSparkAnswerAndOperator(df1)
-
-          // Right join: build left
-          val df2 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
-                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
-          checkSparkAnswerAndOperator(df2)
-
-          // Full join: build left
-          val df3 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " +
-                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
-          checkSparkAnswerAndOperator(df3)
-        }
-      }
-    }
-  }
-
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 73ce0e1fd..a64ec8749 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -38,6 +38,92 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  test("HashJoin without join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df3)
+
+          // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
+          // Left join with build left and right join with build right in hash join is only supported
+          // in Spark 3.5 or above. See SPARK-36612.
+          //
+          // Left join: build left
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+
+          // TODO: DataFusion HashJoin doesn't support build right yet.
+          // Inner join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Left join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Right join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Full join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // val left = sql("SELECT * FROM tbl_a")
+          // val right = sql("SELECT * FROM tbl_b")
+          //
+          // Left semi and anti joins are only supported with build right in Spark.
+          // left.join(right, left("_2") === right("_1"), "leftsemi")
+          // left.join(right, left("_2") === right("_1"), "leftanti")
+        }
+      }
+    }
+  }
+
+  test("HashJoin with join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    }
+  }
+
   // TODO: Add a test for SortMergeJoin with join filter after new DataFusion release
   test("SortMergeJoin without join filter") {
     withSQLConf(