Skip to content

Commit

Permalink
feat: Support sort merge join (apache#178)
Browse files Browse the repository at this point in the history
* feat: Support sort merge join

* Update PlanStability

* Update Spark diff

* For review

* Fix format

* For review

* Add CometJoinSuite

* Remove uuid

* Fix diff
  • Loading branch information
viirya authored Mar 18, 2024
1 parent 5f58010 commit 8aab44c
Show file tree
Hide file tree
Showing 32 changed files with 3,283 additions and 3,188 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,13 @@ public UTF8String getUTF8String(int rowId) {
int length = Platform.getInt(null, offsetBufferAddress + (rowId + 1L) * 4L) - offset;
return UTF8String.fromAddress(null, valueBufferAddress + offset, length);
} else {
// Iceberg maps UUID to StringType.
// The data type here must be UUID because the only FLBA -> String mapping we have is UUID.
BaseFixedWidthVector fixedWidthVector = (BaseFixedWidthVector) valueVector;
int length = fixedWidthVector.getTypeWidth();
int offset = rowId * length;
byte[] result = new byte[length];
Platform.copyMemory(
null, valueBufferAddress + offset, result, Platform.BYTE_ARRAY_OFFSET, length);
return UTF8String.fromString(convertToUuid(result).toString());
return UTF8String.fromBytes(result);
}
}

Expand Down
96 changes: 94 additions & 2 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
joins::SortMergeJoinExec,
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
};
use datafusion_common::ScalarValue;
use datafusion_common::{JoinType as DFJoinType, ScalarValue};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
Expand Down Expand Up @@ -77,7 +78,7 @@ use crate::{
agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
ScalarFunc,
},
spark_operator::{operator::OpStruct, Operator},
spark_operator::{operator::OpStruct, JoinType, Operator},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
Expand Down Expand Up @@ -868,6 +869,87 @@ impl PhysicalPlanner {
Arc::new(CometExpandExec::new(projections, child, schema)),
))
}
OpStruct::SortMergeJoin(join) => {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};

let sort_options = join
.sort_options
.iter()
.map(|sort_option| {
let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap();
SortOptions {
descending: sort_expr.options.descending,
nulls_first: sort_expr.options.nulls_first,
}
})
.collect();

// DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need
// to copy the input batch to avoid the data corruption from reusing the input
// batch.
let left = if can_reuse_input_batch(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if can_reuse_input_batch(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

let join = Arc::new(SortMergeJoinExec::try_new(
left,
right,
join_on,
None,
join_type,
sort_options,
// null doesn't equal to null in Spark join key. If the join key is
// `EqualNullSafe`, Spark will rewrite it during planning.
false,
)?);

Ok((left_scans, join))
}
}
}

Expand Down Expand Up @@ -1051,6 +1133,16 @@ impl From<ExpressionError> for DataFusionError {
}
}

/// Returns true if given operator can return input array as output array without
/// modification. This is used to determine if we need to copy the input batch to avoid
/// data corruption from reusing the input batch.
fn can_reuse_input_batch(op: &Arc<dyn ExecutionPlan>) -> bool {
op.as_any().downcast_ref::<ScanExec>().is_some()
|| op.as_any().downcast_ref::<LocalLimitExec>().is_some()
|| op.as_any().downcast_ref::<ProjectionExec>().is_some()
|| op.as_any().downcast_ref::<FilterExec>().is_some()
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
Expand Down
2 changes: 1 addition & 1 deletion core/src/execution/operators/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec {
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
self.input.children()
vec![self.input.clone()]
}

fn with_new_children(
Expand Down
19 changes: 19 additions & 0 deletions core/src/execution/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ message Operator {
Limit limit = 105;
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
}
}

Expand Down Expand Up @@ -87,3 +88,21 @@ message Expand {
repeated spark.spark_expression.Expr project_list = 1;
int32 num_expr_per_project = 3;
}

message SortMergeJoin {
repeated spark.spark_expression.Expr left_join_keys = 1;
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
repeated spark.spark_expression.Expr sort_options = 4;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
RightOuter = 2;
FullOuter = 3;
LeftSemi = 4;
RightSemi = 5;
LeftAnti = 6;
RightAnti = 7;
}
42 changes: 36 additions & 6 deletions dev/diffs/3.4.2.diff
Original file line number Diff line number Diff line change
Expand Up @@ -935,18 +935,20 @@ index d083cac48ff..3c11bcde807 100644
import testImplicits._

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 266bb343526..cb90d15fed7 100644
index 266bb343526..f393606997c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle._
import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -101,12 +103,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
}
Expand Down Expand Up @@ -980,7 +982,35 @@ index 266bb343526..cb90d15fed7 100644

val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -461,18 +472,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -451,28 +462,46 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case b: CometSortMergeJoinExec =>
+ b.originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
} else {
val executedPlan = joined.queryExecution.executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case ColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}

// check existence of shuffle
assert(
Expand All @@ -1007,7 +1037,7 @@ index 266bb343526..cb90d15fed7 100644
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")

// check the output partitioning
@@ -835,11 +850,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -835,11 +864,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

val scanDF = spark.table("bucketed_table").select("j")
Expand All @@ -1021,7 +1051,7 @@ index 266bb343526..cb90d15fed7 100644
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1031,10 +1046,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -1031,10 +1060,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti

val scans = plan.collect {
case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -223,13 +224,10 @@ class CometSparkSessionExtensions
// spotless:on
private def transform(plan: SparkPlan): SparkPlan = {
def transform1(op: SparkPlan): Option[Operator] = {
val allNativeExec = op.children.map {
case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
case _ => None
}

if (allNativeExec.forall(_.isDefined)) {
QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
QueryPlanSerde.operator2Proto(
op,
op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
} else {
None
}
Expand Down Expand Up @@ -337,6 +335,26 @@ class CometSparkSessionExtensions
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortMergeJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case c @ CoalesceExec(numPartitions, child)
if isCometOperatorEnabled(conf, "coalesce")
&& isCometNative(child) =>
Expand Down Expand Up @@ -576,7 +594,9 @@ object CometSparkSessionExtensions extends Logging {

private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = {
val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled"
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled"
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) &&
!conf.getConfString(operatorDisabledFlag, "false").toBoolean
}

private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
Expand Down
Loading

0 comments on commit 8aab44c

Please sign in to comment.