Commit

Init

viirya committed Jun 1, 2024
1 parent 1f81c38 commit 40bed97
Showing 4 changed files with 78 additions and 6 deletions.
core/src/execution/datafusion/planner.rs (2 changes: 1 addition & 1 deletion)
@@ -926,7 +926,7 @@ impl PhysicalPlanner {
&join.left_join_keys,
&join.right_join_keys,
join.join_type,
- &None,
+ &join.condition,
)?;

let sort_options = join
core/src/execution/proto/operator.proto (1 change: 1 addition & 0 deletions)
@@ -102,6 +102,7 @@ message SortMergeJoin {
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
repeated spark.spark_expression.Expr sort_options = 4;
+ optional spark.spark_expression.Expr condition = 5;
}

enum JoinType {
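Note on the proto change: singular message fields in proto3 carry explicit field presence, so the Java bindings generated for the new condition field expose setCondition/hasCondition/getCondition accessors. That is what lets the Scala serde below attach a filter only when one exists, and lets the native side tell "no join filter" apart from a present one. A minimal hedged sketch of the builder side, assuming the generated classes are named OperatorOuterClass and ExprOuterClass under org.apache.comet.serde (names assumed for illustration, not taken from this diff):

// Hypothetical usage of the bindings generated from operator.proto; the
// OperatorOuterClass/ExprOuterClass package and class names are assumptions.
import org.apache.comet.serde.OperatorOuterClass.SortMergeJoin
import org.apache.comet.serde.ExprOuterClass.Expr

def buildSortMergeJoin(condProto: Option[Expr]): SortMergeJoin = {
  val joinBuilder = SortMergeJoin.newBuilder()
  // ... join keys, join type and sort options are added as in the serde diff below ...
  condProto.foreach(joinBuilder.setCondition) // set only when a condition was serialized
  joinBuilder.build()                         // hasCondition() then reflects whether it was set
}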
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (12 changes: 8 additions & 4 deletions)
@@ -2510,10 +2510,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
}

- // TODO: Support SortMergeJoin with join condition after new DataFusion release
- if (join.condition.isDefined) {
- withInfo(op, "Sort merge join with a join condition is not supported")
- return None
+ val condition = join.condition.map { cond =>
+ val condProto = exprToProto(cond, join.left.output ++ join.right.output)
+ if (condProto.isEmpty) {
+ withInfo(join, cond)
+ return None
+ }
+ condProto.get
}

val joinType = join.joinType match {
@@ -2559,6 +2562,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.addAllSortOptions(sortOptions.map(_.get).asJava)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
+ condition.map(joinBuilder.setCondition)
Some(result.setSortMergeJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
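A side note on the serde change above: the return None inside the .map closure is a non-local return, so a join condition that fails to serialize aborts the whole operator conversion (and Comet falls back to Spark) rather than just exiting the closure. A small self-contained sketch of that pattern, using placeholder types instead of the real Catalyst expressions and protobuf messages:

// Standalone illustration of the Option-mapping / early-exit pattern used in
// QueryPlanSerde; Expr, Proto and toProto are placeholders, not Comet APIs.
object ConditionSerdeSketch {
  case class Expr(sql: String)
  case class Proto(repr: String)

  // Pretend serializer: an empty expression stands in for "unsupported".
  def toProto(e: Expr): Option[Proto] =
    if (e.sql.nonEmpty) Some(Proto(e.sql)) else None

  def serialize(condition: Option[Expr]): Option[Seq[Proto]] = {
    val condProto = condition.map { cond =>
      val p = toProto(cond)
      if (p.isEmpty) {
        // The real code records the fallback reason via withInfo(join, cond) here.
        return None // non-local return: exits serialize(), not just the closure
      }
      p.get
    }
    Some(condProto.toSeq) // the real code also serializes keys, join type, etc.
  }

  def main(args: Array[String]): Unit = {
    println(serialize(Some(Expr("a > b")))) // Some(List(Proto(a > b)))
    println(serialize(Some(Expr(""))))      // None: whole operator falls back
    println(serialize(None))                // Some(List()): join without a filter
  }
}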
spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala (69 changes: 68 additions & 1 deletion)
@@ -205,7 +205,6 @@ class CometJoinSuite extends CometTestBase {
}
}

- // TODO: Add a test for SortMergeJoin with join filter after new DataFusion release
test("SortMergeJoin without join filter") {
withSQLConf(
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
@@ -251,4 +250,72 @@
}
}
}
+
+ test("SortMergeJoin with join filter") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+ withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+ val df1 = sql(
+ "SELECT * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1 AND " +
+ "tbl_a._1 > tbl_b._2")
+ df1.explain()
+ checkSparkAnswerAndOperator(df1)
+
+ val df2 = sql(
+ "SELECT * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1 " +
+ "AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(df2)
+
+ val df3 = sql(
+ "SELECT * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1 " +
+ "AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(df3)
+
+ val df4 = sql(
+ "SELECT * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1 " +
+ "AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(df4)
+
+ val df5 = sql(
+ "SELECT * FROM tbl_b RIGHT JOIN tbl_a ON tbl_a._2 = tbl_b._1 " +
+ "AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(df5)
+
+ val df6 = sql(
+ "SELECT * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1 " +
+ "AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(df6)
+
+ val df7 = sql(
+ "SELECT * FROM tbl_b FULL JOIN tbl_a ON tbl_a._2 = tbl_b._1 " +
+ "AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(df7)
+
+ /*
+ val left = sql("SELECT * FROM tbl_a")
+ val right = sql("SELECT * FROM tbl_b")
+ val df8 =
+ left.join(right, left("_2") === right("_1") && left("_2") >= right("_1"), "leftsemi")
+ checkSparkAnswerAndOperator(df8)
+ val df9 =
+ right.join(left, left("_2") === right("_1") && left("_2") >= right("_1"), "leftsemi")
+ checkSparkAnswerAndOperator(df9)
+ val df10 =
+ left.join(right, left("_2") === right("_1") && left("_2") >= right("_1"), "leftanti")
+ checkSparkAnswerAndOperator(df10)
+ val df11 =
+ right.join(left, left("_2") === right("_1") && left("_2") >= right("_1"), "leftanti")
+ checkSparkAnswerAndOperator(df11)
+ */
+ }
+ }
+ }
+ }
}
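For context on what the new test exercises: each query combines an equi-join key (tbl_a._2 = tbl_b._1) with a non-equi predicate (tbl_a._1 > tbl_b._2); the equality becomes the sort-merge join keys, and the extra predicate becomes the join filter carried by the new condition field. A rough standalone Spark sketch of the same query shape outside the Comet test harness (local session and inline data are assumptions; Comet is not required just to see the join condition in the plan):

// Reproduces the plan shape of the new test with plain Spark SQL.
import org.apache.spark.sql.SparkSession

object SortMergeJoinFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("smj-filter-sketch")
      // Disable broadcast joins, as the test does, so the planner picks SortMergeJoin.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .config("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    // Same shape as the test tables: an equi-join key plus a column for the extra filter.
    (0 until 10).map(i => (i, i % 5)).toDF("_1", "_2").createOrReplaceTempView("tbl_a")
    (0 until 10).map(i => (i % 10, i + 2)).toDF("_1", "_2").createOrReplaceTempView("tbl_b")

    val df = spark.sql(
      "SELECT * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
    df.explain() // SortMergeJoin with the non-equi predicate shown as its join condition
    df.show()
    spark.stop()
  }
}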
