feat: Support sort merge join #178

Merged 9 commits on Mar 18, 2024
@@ -105,15 +105,13 @@ public UTF8String getUTF8String(int rowId) {
int length = Platform.getInt(null, offsetBufferAddress + (rowId + 1L) * 4L) - offset;
return UTF8String.fromAddress(null, valueBufferAddress + offset, length);
} else {
-// Iceberg maps UUID to StringType.
-// The data type here must be UUID because the only FLBA -> String mapping we have is UUID.
BaseFixedWidthVector fixedWidthVector = (BaseFixedWidthVector) valueVector;
int length = fixedWidthVector.getTypeWidth();
int offset = rowId * length;
byte[] result = new byte[length];
Platform.copyMemory(
null, valueBufferAddress + offset, result, Platform.BYTE_ARRAY_OFFSET, length);
-return UTF8String.fromString(convertToUuid(result).toString());
+return UTF8String.fromBytes(result);
}
}

96 changes: 94 additions & 2 deletions core/src/execution/datafusion/planner.rs
@@ -37,13 +37,14 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
joins::SortMergeJoinExec,
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
};
-use datafusion_common::ScalarValue;
+use datafusion_common::{JoinType as DFJoinType, ScalarValue};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
@@ -77,7 +78,7 @@ use crate::{
agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
ScalarFunc,
},
-spark_operator::{operator::OpStruct, Operator},
+spark_operator::{operator::OpStruct, JoinType, Operator},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
@@ -868,6 +869,87 @@ impl PhysicalPlanner {
Arc::new(CometExpandExec::new(projections, child, schema)),
))
}
OpStruct::SortMergeJoin(join) => {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};

let sort_options = join
.sort_options
.iter()
.map(|sort_option| {
let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap();
SortOptions {
descending: sort_expr.options.descending,
nulls_first: sort_expr.options.nulls_first,
}
})
.collect();

// The DataFusion `SortMergeJoinExec` operator keeps input batches internally. We need
// to copy each input batch to avoid data corruption from reusing a batch whose buffers
// the producing operator may overwrite.
let left = if can_reuse_input_batch(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if can_reuse_input_batch(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

let join = Arc::new(SortMergeJoinExec::try_new(
left,
right,
join_on,
None,
join_type,
sort_options,
// In Spark, null does not equal null in a join key. If the join condition is
// `EqualNullSafe`, Spark rewrites it during planning, so null equality is always false here.
false,
)?);

Ok((left_scans, join))
}
}
}

@@ -1051,6 +1133,16 @@ impl From<ExpressionError> for DataFusionError {
}
}

/// Returns true if the given operator can return its input array as its output array without
/// modification. This is used to determine whether we need to copy the input batch to avoid
/// data corruption from reusing it.
fn can_reuse_input_batch(op: &Arc<dyn ExecutionPlan>) -> bool {
op.as_any().downcast_ref::<ScanExec>().is_some()
|| op.as_any().downcast_ref::<LocalLimitExec>().is_some()
|| op.as_any().downcast_ref::<ProjectionExec>().is_some()
|| op.as_any().downcast_ref::<FilterExec>().is_some()
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
2 changes: 1 addition & 1 deletion core/src/execution/operators/copy.rs
@@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec {
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-self.input.children()
+vec![self.input.clone()]
}

fn with_new_children(
19 changes: 19 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
Limit limit = 105;
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
}
}

@@ -87,3 +88,21 @@ message Expand {
repeated spark.spark_expression.Expr project_list = 1;
int32 num_expr_per_project = 3;
}

message SortMergeJoin {
repeated spark.spark_expression.Expr left_join_keys = 1;
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
repeated spark.spark_expression.Expr sort_options = 4;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
RightOuter = 2;
FullOuter = 3;
LeftSemi = 4;
RightSemi = 5;
LeftAnti = 6;
RightAnti = 7;
}
42 changes: 36 additions & 6 deletions dev/diffs/3.4.2.diff
@@ -935,18 +935,20 @@ index d083cac48ff..3c11bcde807 100644
import testImplicits._

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 266bb343526..cb90d15fed7 100644
index 266bb343526..f393606997c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle._
import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -101,12 +103,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
}
@@ -980,7 +982,35 @@ index 266bb343526..cb90d15fed7 100644

val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -461,18 +472,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -451,28 +462,46 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case b: CometSortMergeJoinExec =>
+ b.originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
} else {
val executedPlan = joined.queryExecution.executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case ColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}

// check existence of shuffle
assert(
@@ -1007,7 +1037,7 @@ index 266bb343526..cb90d15fed7 100644
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")

// check the output partitioning
@@ -835,11 +850,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -835,11 +864,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

val scanDF = spark.table("bucketed_table").select("j")
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1031,10 +1046,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -1031,10 +1060,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti

val scans = plan.collect {
case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -223,13 +224,10 @@ class CometSparkSessionExtensions
// spotless:on
private def transform(plan: SparkPlan): SparkPlan = {
def transform1(op: SparkPlan): Option[Operator] = {
-val allNativeExec = op.children.map {
-case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
-case _ => None
-}
-
-if (allNativeExec.forall(_.isDefined)) {
-QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
+if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
+QueryPlanSerde.operator2Proto(
+op,
+op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
} else {
None
}
@@ -337,6 +335,26 @@ class CometSparkSessionExtensions
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortMergeJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case c @ CoalesceExec(numPartitions, child)
if isCometOperatorEnabled(conf, "coalesce")
&& isCometNative(child) =>
@@ -576,7 +594,9 @@ object CometSparkSessionExtensions extends Logging {

private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = {
val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled"
-conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
+val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled"
+conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) &&
+!conf.getConfString(operatorDisabledFlag, "false").toBoolean
viirya (Member, Author) commented on lines +597 to +599:
This "disabled" flag is useful for disabling a particular operator in unit tests. For example, I disable sort merge join in one existing test below.
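
For illustration, here is a sketch of how a test could opt out of only the native sort merge join while leaving the rest of Comet on. The config keys are assumptions derived from the `$COMET_EXEC_CONFIG_PREFIX.<operator>.enabled` / `.disabled` construction above (assuming the prefix resolves to `spark.comet.exec`); the suite scaffolding is plain Spark test infrastructure, not something added in this PR.

```scala
// Hypothetical test sketch; all "spark.comet.*" keys below are assumptions based on
// the flag construction in isCometOperatorEnabled above.
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.test.SharedSparkSession

class SortMergeJoinDisabledSuite extends QueryTest with SharedSparkSession {
  import testImplicits._

  test("native sort merge join can be disabled individually") {
    withSQLConf(
      "spark.comet.exec.all.enabled" -> "true", // assumed "enable all operators" flag
      "spark.comet.exec.sort_merge_join.disabled" -> "true", // the new per-operator switch
      "spark.sql.autoBroadcastJoinThreshold" -> "-1", // force a sort merge join
      "spark.sql.adaptive.enabled" -> "false") { // keep the executed plan easy to inspect
      val left = Seq((1, "a"), (2, "b")).toDF("id", "l")
      val right = Seq((1, "x"), (3, "y")).toDF("id", "r")
      val plan = left.join(right, "id").queryExecution.executedPlan
      // With the disabled flag set, Spark's own SortMergeJoinExec should remain in the
      // plan instead of being replaced by CometSortMergeJoinExec.
      assert(plan.collect { case smj: SortMergeJoinExec => smj }.nonEmpty)
    }
  }
}
```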

}

private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {