feat: Add HashJoin support for BuildRight #437

Merged · 20 commits · Jun 8, 2024

Changes from 11 commits
20 changes: 10 additions & 10 deletions core/Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions core/Cargo.toml
@@ -66,10 +66,10 @@ itertools = "0.11.0"
chrono = { version = "0.4", default-features = false, features = ["clock"] }
chrono-tz = { version = "0.8" }
paste = "1.0.14"
datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions", "crypto_expressions"] }
datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["crypto_expressions"]}
datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", default-features = false, features = ["unicode_expressions"] }
datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "debb2f2" }
datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "debb2f2", features = ["unicode_expressions", "crypto_expressions"] }
datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "debb2f2", features = ["crypto_expressions"] }
datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "debb2f2", default-features = false, features = ["unicode_expressions"] }
unicode-segmentation = "^1.10.1"
once_cell = "1.18.0"
regex = "1.9.6"
13 changes: 11 additions & 2 deletions core/src/execution/datafusion/planner.rs
@@ -37,6 +37,7 @@ use datafusion::{
},
AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
},
physical_optimizer::join_selection::swap_hash_join,
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
@@ -966,7 +967,7 @@ impl PhysicalPlanner {
join.join_type,
&join.condition,
)?;
let join = Arc::new(HashJoinExec::try_new(
let hash_join = Arc::new(HashJoinExec::try_new(
join_params.left,
join_params.right,
join_params.join_on,
Expand All @@ -978,7 +979,15 @@ impl PhysicalPlanner {
// `EqualNullSafe`, Spark will rewrite it during planning.
false,
)?);
Ok((scans, join))

// If the hash join is build right, we need to swap the left and right
let hash_join = if join.build_side == 0 {
hash_join
} else {
swap_hash_join(hash_join.as_ref(), PartitionMode::Partitioned)?
};

Ok((scans, hash_join))
}
}
}
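For context on the planner change above: DataFusion's HashJoinExec always builds the hash table from its left input, so a Spark plan that requests BuildRight is handled by constructing the join as usual and then swapping its sides with swap_hash_join. Below is a minimal sketch of that decision, not part of the PR; it assumes the proto encoding added in the next file (BuildLeft = 0, BuildRight = 1), and the import paths are a best guess for the pinned DataFusion revision. apply_build_side and BUILD_LEFT are hypothetical names used only for illustration.

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_optimizer::join_selection::swap_hash_join;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::ExecutionPlan;

// Matches the BuildSide enum added to operator.proto: BuildLeft = 0, BuildRight = 1.
const BUILD_LEFT: i32 = 0;

// DataFusion's HashJoinExec builds its hash table from the left child, so a
// BuildRight join from Spark is expressed by swapping the inputs; swap_hash_join
// also swaps the join keys and, for most join types, re-projects the output so
// the column order matches the original plan.
fn apply_build_side(
    hash_join: Arc<HashJoinExec>,
    build_side: i32,
) -> Result<Arc<dyn ExecutionPlan>> {
    let plan: Arc<dyn ExecutionPlan> = if build_side == BUILD_LEFT {
        hash_join
    } else {
        swap_hash_join(hash_join.as_ref(), PartitionMode::Partitioned)?
    };
    Ok(plan)
}

The hunk above does the same thing inline, with the build side taken from the decoded HashJoin message.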
6 changes: 6 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -95,6 +95,7 @@ message HashJoin {
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
optional spark.spark_expression.Expr condition = 4;
BuildSide build_side = 5;
}

message SortMergeJoin {
@@ -114,3 +115,8 @@ enum JoinType {
LeftAnti = 6;
RightAnti = 7;
}

enum BuildSide {
BuildLeft = 0;
BuildRight = 1;
}
108 changes: 100 additions & 8 deletions dev/diffs/3.4.2.diff
@@ -352,7 +352,7 @@ index 7dec558f8df..840dda15033 100644
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index f33432ddb6f..060f874ea72 100644
index f33432ddb6f..9cf7a9dd4e3 100644
Contributor: I guess we need the same change for 3.4.3.diff?

Member (author): Yes
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
@@ -373,7 +373,37 @@ index f33432ddb6f..060f874ea72 100644
case _ => Nil
}
}
@@ -1187,7 +1191,8 @@ abstract class DynamicPartitionPruningSuiteBase
@@ -665,7 +669,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("partition pruning in broadcast hash joins with aliases") {
+ test("partition pruning in broadcast hash joins with aliases",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("alias with simple join condition, using attribute names only")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
@@ -755,7 +760,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("partition pruning in broadcast hash joins") {
+ test("partition pruning in broadcast hash joins",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("disable broadcast pruning and disable subquery duplication")
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
@@ -990,7 +996,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("different broadcast subqueries with identical children") {
+ test("different broadcast subqueries with identical children",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
withTable("fact", "dim") {
spark.range(100).select(
@@ -1187,7 +1194,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

@@ -383,7 +413,7 @@ index f33432ddb6f..060f874ea72 100644
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
"""
@@ -1238,7 +1243,8 @@ abstract class DynamicPartitionPruningSuiteBase
@@ -1238,7 +1246,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

@@ -393,7 +423,27 @@ index f33432ddb6f..060f874ea72 100644
Given("dynamic pruning filter on the build side")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
@@ -1485,7 +1491,7 @@ abstract class DynamicPartitionPruningSuiteBase
@@ -1311,7 +1320,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("SPARK-32817: DPP throws error when the broadcast side is empty") {
+ test("SPARK-32817: DPP throws error when the broadcast side is empty",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
@@ -1470,7 +1480,8 @@ abstract class DynamicPartitionPruningSuiteBase
checkAnswer(df, Row(3, 2) :: Row(3, 2) :: Row(3, 2) :: Row(3, 2) :: Nil)
}

- test("SPARK-36444: Remove OptimizeSubqueries from batch of PartitionPruning") {
+ test("SPARK-36444: Remove OptimizeSubqueries from batch of PartitionPruning",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
val df = sql(
"""
@@ -1485,7 +1496,7 @@ abstract class DynamicPartitionPruningSuiteBase
}

test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " +
@@ -402,7 +452,37 @@ index f33432ddb6f..060f874ea72 100644
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq(
"f.store_id = 1" -> false,
@@ -1729,6 +1735,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
@@ -1557,7 +1568,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("SPARK-38674: Remove useless deduplicate in SubqueryBroadcastExec") {
+ test("SPARK-38674: Remove useless deduplicate in SubqueryBroadcastExec",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withTable("duplicate_keys") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq[(Int, String)]((1, "NL"), (1, "NL"), (3, "US"), (3, "US"), (3, "US"))
@@ -1588,7 +1600,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("SPARK-39338: Remove dynamic pruning subquery if pruningKey's references is empty") {
+ test("SPARK-39338: Remove dynamic pruning subquery if pruningKey's references is empty",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
val df = sql(
"""
@@ -1617,7 +1630,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("SPARK-39217: Makes DPP support the pruning side has Union") {
+ test("SPARK-39217: Makes DPP support the pruning side has Union",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
val df = sql(
"""
@@ -1729,6 +1743,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
@@ -966,13 +1046,15 @@ index 4b3d3a4b805..56e1e0e6f16 100644

setupTestData()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
index 9e9d717db3b..91a4f9a38d5 100644
index 9e9d717db3b..c1a7caf56e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
@@ -18,6 +18,7 @@
@@ -17,7 +17,8 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest, Row}
+import org.apache.spark.sql.comet.CometProjectExec
import org.apache.spark.sql.connector.SimpleWritableDataSource
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
@@ -989,6 +1071,16 @@ index 9e9d717db3b..91a4f9a38d5 100644
assert(actual == expected)
}
}
@@ -112,7 +116,8 @@ abstract class RemoveRedundantProjectsSuiteBase
assertProjectExec(query, 1, 3)
}

- test("join with ordering requirement") {
+ test("join with ordering requirement",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
val query = "select * from (select key, a, c, b from testView) as t1 join " +
"(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
assertProjectExec(query, 2, 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index 30ce940b032..0d3f6c6c934 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
12 changes: 6 additions & 6 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
@@ -46,7 +46,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isC
import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
import org.apache.comet.shims.CometExprShim
import org.apache.comet.shims.ShimQueryPlanSerde

@@ -2438,10 +2438,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

if (join.buildSide == BuildRight) {
// DataFusion HashJoin assumes build side is always left.
// TODO: support BuildRight
withInfo(join, "BuildRight is not supported")
if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
withInfo(join, "BuildRight with LeftAnti is not supported")
return None
}

@@ -2478,6 +2476,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setJoinType(joinType)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
.setBuildSide(
if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight)
condition.foreach(joinBuilder.setCondition)
Some(result.setHashJoin(joinBuilder).build())
} else {
@@ -702,7 +702,7 @@ case class CometHashJoinExec(
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(leftKeys, rightKeys, joinType, condition, left, right)
Iterator(leftKeys, rightKeys, joinType, buildSide, condition, left, right)

override def equals(obj: Any): Boolean = {
obj match {
@@ -836,7 +836,7 @@ case class CometBroadcastHashJoinExec(
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(leftKeys, rightKeys, joinType, condition, left, right)
Iterator(leftKeys, rightKeys, joinType, condition, buildSide, left, right)

override def equals(obj: Any): Boolean = {
obj match {