From 0b7f600985e91c0b50a9420f879c5e57ca3cab8a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 Mar 2024 23:23:54 -0800 Subject: [PATCH] feat: Support sort merge join --- .../apache/comet/vector/CometPlainVector.java | 6 +- core/src/execution/datafusion/planner.rs | 93 ++++++++++++++++++- core/src/execution/operators/copy.rs | 2 +- core/src/execution/proto/operator.proto | 19 ++++ .../comet/CometSparkSessionExtensions.scala | 42 +++++++-- .../apache/comet/serde/QueryPlanSerde.scala | 60 +++++++++++- .../apache/spark/sql/comet/operators.scala | 40 +++++++- .../apache/comet/exec/CometExecSuite.scala | 48 ++++++++++ 8 files changed, 296 insertions(+), 14 deletions(-) diff --git a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java index a7373590ab..ce40eb6206 100644 --- a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java +++ b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java @@ -113,7 +113,11 @@ public UTF8String getUTF8String(int rowId) { byte[] result = new byte[length]; Platform.copyMemory( null, valueBufferAddress + offset, result, Platform.BYTE_ARRAY_OFFSET, length); - return UTF8String.fromString(convertToUuid(result).toString()); + if (length == 16) { + return UTF8String.fromString(convertToUuid(result).toString()); + } else { + return UTF8String.fromBytes(result); + } } } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 33cf636a2e..96b423c899 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -32,13 +32,14 @@ use datafusion::{ physical_plan::{ aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy}, filter::FilterExec, + joins::SortMergeJoinExec, limit::LocalLimitExec, projection::ProjectionExec, sorts::sort::SortExec, ExecutionPlan, Partitioning, }, }; -use datafusion_common::ScalarValue; +use datafusion_common::{JoinType as DFJoinType, ScalarValue}; use datafusion_physical_expr::{ execution_props::ExecutionProps, expressions::{ @@ -79,7 +80,7 @@ use crate::{ agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr, ScalarFunc, }, - spark_operator::{operator::OpStruct, Operator}, + spark_operator::{operator::OpStruct, JoinType, Operator}, spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }, }; @@ -849,6 +850,85 @@ impl PhysicalPlanner { Arc::new(CometExpandExec::new(projections, child, schema)), )) } + OpStruct::SortMergeJoin(join) => { + assert!(children.len() == 2); + let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; + let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; + + left_scans.append(&mut right_scans); + + let left_join_exprs = join + .left_join_keys + .iter() + .map(|expr| self.create_expr(expr, left.schema())) + .collect::, _>>()?; + let right_join_exprs = join + .right_join_keys + .iter() + .map(|expr| self.create_expr(expr, right.schema())) + .collect::, _>>()?; + + let join_on = left_join_exprs + .into_iter() + .zip(right_join_exprs) + .collect::>(); + + let join_type = match join.join_type.try_into() { + Ok(JoinType::Inner) => DFJoinType::Inner, + Ok(JoinType::LeftOuter) => DFJoinType::Left, + Ok(JoinType::RightOuter) => DFJoinType::Right, + Ok(JoinType::FullOuter) => DFJoinType::Full, + Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, + Ok(JoinType::RightSemi) => DFJoinType::RightSemi, + Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, + Ok(JoinType::RightAnti) => DFJoinType::RightAnti, + Err(_) => { + return Err(ExecutionError::GeneralError(format!( + "Unsupported join type: {:?}", + join.join_type + ))); + } + }; + + let sort_options = join + .sort_options + .iter() + .map(|sort_option| { + let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap(); + SortOptions { + descending: sort_expr.options.descending, + nulls_first: sort_expr.options.nulls_first, + } + }) + .collect(); + + // DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need + // to copy the input batch to avoid the data corruption from reusing the input + // batch. + let left = if op_reuse_array(&left) { + Arc::new(CopyExec::new(left)) + } else { + left + }; + + let right = if op_reuse_array(&right) { + Arc::new(CopyExec::new(right)) + } else { + right + }; + + let join = Arc::new(SortMergeJoinExec::try_new( + left, + right, + join_on, + None, + join_type, + sort_options, + false, + )?); + + Ok((left_scans, join)) + } } } @@ -1017,6 +1097,15 @@ impl From for DataFusionError { } } +/// Returns true if given operator probably returns input array as output array without +/// modification. +fn op_reuse_array(op: &Arc) -> bool { + op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() +} + #[cfg(test)] mod tests { use std::{sync::Arc, task::Poll}; diff --git a/core/src/execution/operators/copy.rs b/core/src/execution/operators/copy.rs index 996db2b470..699ccf7ae7 100644 --- a/core/src/execution/operators/copy.rs +++ b/core/src/execution/operators/copy.rs @@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec { } fn children(&self) -> Vec> { - self.input.children() + vec![self.input.clone()] } fn with_new_children( diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto index 5b07cb30b1..0b7888d404 100644 --- a/core/src/execution/proto/operator.proto +++ b/core/src/execution/proto/operator.proto @@ -40,6 +40,7 @@ message Operator { Limit limit = 105; ShuffleWriter shuffle_writer = 106; Expand expand = 107; + SortMergeJoin sort_merge_join = 108; } } @@ -87,3 +88,21 @@ message Expand { repeated spark.spark_expression.Expr project_list = 1; int32 num_expr_per_project = 3; } + +message SortMergeJoin { + repeated spark.spark_expression.Expr left_join_keys = 1; + repeated spark.spark_expression.Expr right_join_keys = 2; + JoinType join_type = 3; + repeated spark.spark_expression.Expr sort_options = 4; +} + +enum JoinType { + Inner = 0; + LeftOuter = 1; + RightOuter = 2; + FullOuter = 3; + LeftSemi = 4; + RightSemi = 5; + LeftAnti = 6; + RightAnti = 7; +} diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 87c2265fcb..a51acc988a 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} @@ -38,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -222,12 +222,16 @@ class CometSparkSessionExtensions */ // spotless:on private def transform(plan: SparkPlan): SparkPlan = { - def transform1(op: UnaryExecNode): Option[Operator] = { - op.child match { - case childNativeOp: CometNativeExec => - QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp) - case _ => - None + def transform1(op: SparkPlan): Option[Operator] = { + val allNativeExec = op.children.map { + case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp) + case _ => None + } + + if (allNativeExec.forall(_.isDefined)) { + QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*) + } else { + None } } @@ -333,6 +337,26 @@ class CometSparkSessionExtensions op } + case op: SortMergeJoinExec + if isCometOperatorEnabled(conf, "sort_merge_join") && + op.children.forall(isCometNative(_)) => + val newOp = transform1(op) + newOp match { + case Some(nativeOp) => + CometSortMergeJoinExec( + nativeOp, + op, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.left, + op.right, + SerializedPlan(None)) + case None => + op + } + case c @ CoalesceExec(numPartitions, child) if isCometOperatorEnabled(conf, "coalesce") && isCometNative(child) => @@ -547,7 +571,9 @@ object CometSparkSessionExtensions extends Logging { private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = { val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled" - conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) + val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled" + conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) && + !conf.getConfString(operatorDisabledFlag, "false").toBoolean } private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b27fa3a754..336640fc24 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision} @@ -33,6 +34,7 @@ import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -40,7 +42,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} import org.apache.comet.shims.ShimQueryPlanSerde /** @@ -1836,6 +1838,62 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } + case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, "sort_merge_join") => + // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec. + def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + keys.map(SortOrder(_, Ascending)) + } + + def getKeyOrdering( + keys: Seq[Expression], + childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = { + val requiredOrdering = requiredOrders(keys) + if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key + SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq) + } + } else { + requiredOrdering + } + } + + // TODO: Support SortMergeJoin with join condition after new DataFusion release + if (join.condition.isDefined) { + return None + } + + val joinType = join.joinType match { + case Inner => JoinType.Inner + case LeftOuter => JoinType.LeftOuter + case RightOuter => JoinType.RightOuter + case FullOuter => JoinType.FullOuter + case LeftSemi => JoinType.LeftSemi + case LeftAnti => JoinType.LeftAnti + case _ => return None // Spark doesn't support other join types + } + + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) + val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + + val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering) + .map(exprToProto(_, join.left.output)) + + if (sortOptions.forall(_.isDefined) && + leftKeys.forall(_.isDefined) && + rightKeys.forall(_.isDefined) && + childOp.nonEmpty) { + val joinBuilder = OperatorOuterClass.SortMergeJoin + .newBuilder() + .setJoinType(joinType) + .addAllSortOptions(sortOptions.map(_.get).asJava) + .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) + .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + Some(result.setSortMergeJoin(joinBuilder).build()) + } else { + None + } + case op if isCometSink(op) => // These operators are source of Comet native execution chain val scanBuilder = OperatorOuterClass.Scan.newBuilder() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e75f9a4a52..553791de87 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -31,9 +31,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode} +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -323,6 +324,8 @@ abstract class CometNativeExec extends CometExec { abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode +abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode + /** * Represents the serialized plan of Comet native operators. Only the first operator in a block of * continuous Comet native operators has defined plan bytes which contains the serialization of @@ -583,6 +586,41 @@ case class CometHashAggregateExec( Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child) } +case class CometSortMergeJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometBinaryExec { + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometSortMergeJoinExec => + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(leftKeys, rightKeys, condition, left, right) +} + case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan) extends CometNativeExec with LeafExecNode { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 6a34d4fe4a..e3a99ceca8 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -58,6 +58,53 @@ class CometExecSuite extends CometTestBase { } } + // TODO: Add a test for SortMergeJoin with join filter after new DataFusion release + test("SortMergeJoin without join filter") { + withSQLConf( + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + val df1 = sql("SELECT * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df1) + + val df2 = sql("SELECT * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df2) + + val df3 = sql("SELECT * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df3) + + val df4 = sql("SELECT * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df4) + + val df5 = sql("SELECT * FROM tbl_b RIGHT JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df5) + + val df6 = sql("SELECT * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df6) + + val df7 = sql("SELECT * FROM tbl_b FULL JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df7) + + val left = sql("SELECT * FROM tbl_a") + val right = sql("SELECT * FROM tbl_b") + + val df8 = left.join(right, left("_2") === right("_1"), "leftsemi") + checkSparkAnswerAndOperator(df8) + + val df9 = right.join(left, left("_2") === right("_1"), "leftsemi") + checkSparkAnswerAndOperator(df9) + + val df10 = left.join(right, left("_2") === right("_1"), "leftanti") + checkSparkAnswerAndOperator(df10) + + val df11 = right.join(left, left("_2") === right("_1"), "leftanti") + checkSparkAnswerAndOperator(df11) + } + } + } + } + test("Fix corrupted AggregateMode when transforming plan parameters") { withParquetTable((0 until 5).map(i => (i, i + 1)), "table") { val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2")) @@ -859,6 +906,7 @@ class CometExecSuite extends CometTestBase { .saveAsTable("bucketed_table2") withSQLConf( + "spark.comet.exec.sort_merge_join.disabled" -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { val t1 = spark.table("bucketed_table1")