Skip to content

Commit

Permalink
Move tests
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 18, 2024
1 parent a236051 commit e90af01
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 89 deletions.
5 changes: 2 additions & 3 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -927,14 +927,13 @@ impl PhysicalPlanner {
// DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need
// to copy the input batch to avoid the data corruption from reusing the input
// batch.
let left = if crate::execution::datafusion::planner::can_reuse_input_batch(&left) {
let left = if can_reuse_input_batch(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if crate::execution::datafusion::planner::can_reuse_input_batch(&right)
{
let right = if can_reuse_input_batch(&right) {
Arc::new(CopyExec::new(right))
} else {
right
Expand Down
86 changes: 0 additions & 86 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,92 +58,6 @@ class CometExecSuite extends CometTestBase {
}
}

// Exercises Comet's hash-join path with a plain equi-join condition (no extra
// join filter). Sort-merge join and broadcast joins are disabled via SQLConf so
// that Spark is forced to pick ShuffledHashJoin; each query is checked against
// vanilla Spark for both answer and operator coverage.
test("HashJoin without join filter") {
  withSQLConf(
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
      withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
        // Inner, right outer, and full outer joins, all with the left side as
        // the build side (SHUFFLE_HASH hint on tbl_a).
        val buildLeftQueries = Seq(
          "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1",
          "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1",
          "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        buildLeftQueries.foreach(query => checkSparkAnswerAndOperator(sql(query)))

        // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
        // Left join with build left and right join with build right in hash join is only
        // supported in Spark 3.5 or above. See SPARK-36612.
        //
        // Left join: build left
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")

        // TODO: DataFusion HashJoin doesn't support build right yet.
        // Inner join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // Left join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // Right join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // Full join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // val left = sql("SELECT * FROM tbl_a")
        // val right = sql("SELECT * FROM tbl_b")
        //
        // Left semi and anti joins are only supported with build right in Spark.
        // left.join(right, left("_2") === right("_1"), "leftsemi")
        // left.join(right, left("_2") === right("_1"), "leftanti")
      }
    }
  }
}

// Exercises Comet's hash-join path when the join carries an additional
// non-equi filter (`tbl_a._1 > tbl_b._2`) on top of the equi-join key.
// Sort-merge join and broadcast joins are disabled so Spark selects
// ShuffledHashJoin; results and operators are compared against vanilla Spark.
test("HashJoin with join filter") {
  withSQLConf(
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
      withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
        // Shared equi-key plus extra filter predicate, appended to each join shape.
        val joinCondition = "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2"
        // Inner, right outer, and full outer joins, all building on the left side.
        Seq("JOIN", "RIGHT JOIN", "FULL JOIN").foreach { joinType =>
          val df =
            sql(
              s"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a $joinType tbl_b " +
                joinCondition)
          checkSparkAnswerAndOperator(df)
        }
      }
    }
  }
}

test("Fix corrupted AggregateMode when transforming plan parameters") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
Expand Down
86 changes: 86 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,92 @@ class CometJoinSuite extends CometTestBase {
}
}

// Exercises Comet's hash-join path with a plain equi-join condition (no extra
// join filter). Sort-merge join and broadcast joins are disabled via SQLConf so
// that Spark is forced to pick ShuffledHashJoin; each query is checked against
// vanilla Spark for both answer and operator coverage.
test("HashJoin without join filter") {
  withSQLConf(
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
      withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
        // Inner, right outer, and full outer joins, all with the left side as
        // the build side (SHUFFLE_HASH hint on tbl_a).
        val buildLeftQueries = Seq(
          "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1",
          "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1",
          "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        buildLeftQueries.foreach(query => checkSparkAnswerAndOperator(sql(query)))

        // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
        // Left join with build left and right join with build right in hash join is only
        // supported in Spark 3.5 or above. See SPARK-36612.
        //
        // Left join: build left
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")

        // TODO: DataFusion HashJoin doesn't support build right yet.
        // Inner join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // Left join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // Right join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // Full join: build right
        // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
        //
        // val left = sql("SELECT * FROM tbl_a")
        // val right = sql("SELECT * FROM tbl_b")
        //
        // Left semi and anti joins are only supported with build right in Spark.
        // left.join(right, left("_2") === right("_1"), "leftsemi")
        // left.join(right, left("_2") === right("_1"), "leftanti")
      }
    }
  }
}

// Exercises Comet's hash-join path when the join carries an additional
// non-equi filter (`tbl_a._1 > tbl_b._2`) on top of the equi-join key.
// Sort-merge join and broadcast joins are disabled so Spark selects
// ShuffledHashJoin; results and operators are compared against vanilla Spark.
test("HashJoin with join filter") {
  withSQLConf(
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
      withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
        // Shared equi-key plus extra filter predicate, appended to each join shape.
        val joinCondition = "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2"
        // Inner, right outer, and full outer joins, all building on the left side.
        Seq("JOIN", "RIGHT JOIN", "FULL JOIN").foreach { joinType =>
          val df =
            sql(
              s"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a $joinType tbl_b " +
                joinCondition)
          checkSparkAnswerAndOperator(df)
        }
      }
    }
  }
}

// TODO: Add a test for SortMergeJoin with join filter after new DataFusion release
test("SortMergeJoin without join filter") {
withSQLConf(
Expand Down

0 comments on commit e90af01

Please sign in to comment.