diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 59818857e7..6be3291cca 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -40,7 +40,7 @@ use datafusion::{
     physical_plan::{
         aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
         filter::FilterExec,
-        joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec},
+        joins::{utils::JoinFilter, CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec},
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
@@ -978,6 +978,15 @@
                 )?);
                 Ok((scans, join))
             }
+            OpStruct::CrossJoin(_) => {
+                assert!(children.len() == 2);
+                let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
+                let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;
+
+                left_scans.append(&mut right_scans);
+                let join = Arc::new(CrossJoinExec::new(left, right));
+                Ok((left_scans, join))
+            }
         }
     }
 
diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto
index 6080c56682..f99ee80f6e 100644
--- a/core/src/execution/proto/operator.proto
+++ b/core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
     Expand expand = 107;
     SortMergeJoin sort_merge_join = 108;
     HashJoin hash_join = 109;
+    CrossJoin cross_join = 110;
   }
 }
 
@@ -104,6 +105,9 @@ message SortMergeJoin {
   repeated spark.spark_expression.Expr sort_options = 4;
 }
 
+message CrossJoin {
+}
+
 enum JoinType {
   Inner = 0;
   LeftOuter = 1;
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 7ddc950eaa..f352c036f7 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -698,6 +698,31 @@
         withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
         s
 
+      case op: CartesianProductExec
+          if isCometOperatorEnabled(conf, "cross_join") &&
+            op.children.forall(isCometNative) =>
+        val newOp = transform1(op)
+        newOp match {
+          case Some(nativeOp) =>
+            CometCartesianProductExec(
+              nativeOp,
+              op,
+              op.condition,
+              op.left,
+              op.right,
+              SerializedPlan(None))
+          case None =>
+            op
+        }
+
+      case op: CartesianProductExec if !isCometOperatorEnabled(conf, "cross_join") =>
+        withInfo(op, "Cross join is not enabled")
+        op
+
+      case op: CartesianProductExec if !op.children.forall(isCometNative(_)) =>
+        withInfo(op, "Cross join disabled because not all child plans are native")
+        op
+
       case op =>
         // An operator that is not supported by Comet
         op match {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cf7c86a9fd..ab0a91617b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -2545,6 +2545,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         withInfo(join, "SortMergeJoin is not enabled")
         None
 
+      case join: CartesianProductExec if isCometOperatorEnabled(op.conf, "cross_join") =>
+        // TODO: Support CartesianProductExec with join condition after new DataFusion release
+        if (join.condition.isDefined) {
+          withInfo(op, "cross_join with a join condition is not supported")
+          return None
+        }
+
+        if (childOp.nonEmpty) {
+          // CrossJoin carries no fields of its own; the child operators are enough.
+          val crossJoinBuilder = OperatorOuterClass.CrossJoin.newBuilder()
+          Some(result.setCrossJoin(crossJoinBuilder).build())
+        } else {
+          withInfo(join, "not all child plans are native")
+          None
+        }
+
+      case join: CartesianProductExec if !isCometOperatorEnabled(op.conf, "cross_join") =>
+        withInfo(join, "cross_join is not enabled")
+        None
+
       case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType)) =>
         // These operators are source of Comet native execution chain
         val scanBuilder = OperatorOuterClass.Scan.newBuilder()
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index ad07ff0e25..db9bef430c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -899,6 +899,39 @@ case class CometSortMergeJoinExec(
     "join_time" -> SQLMetrics.createNanoTimingMetric(sparkContext, "Total time for joining"))
 }
 
+case class CometCartesianProductExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    condition: Option[Expression],
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan,
+      newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] = Iterator(condition, left, right)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometCartesianProductExec =>
+        this.condition == other.condition &&
+          this.left == other.left &&
+          this.right == other.right &&
+          this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int = Objects.hashCode(condition, left, right)
+
+  override lazy val metrics: Map[String, SQLMetric] =
+    CometMetricNode.baselineMetrics(sparkContext)
+}
 
 case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
     extends CometNativeExec
     with LeafExecNode {
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 91d88c76e8..42f1de7320 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -251,4 +251,14 @@ class CometJoinSuite extends CometTestBase {
       }
     }
   }
+
+  // TODO: Add a test for CartesianProductExec with join filter after new DataFusion release
+  test("CartesianProductExec without join filter") {
+    withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+      withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+        val df1 = sql("SELECT * FROM tbl_a CROSS JOIN tbl_b")
+        checkSparkAnswerAndOperator(df1)
+      }
+    }
+  }
 }
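
For local verification outside the test suite, below is a hypothetical, standalone smoke test (not part of the patch) that exercises the new cross join path end to end. It assumes Comet's usual "spark.comet.exec.<operator>.enabled" naming for the "cross_join" key checked by isCometOperatorEnabled, and that the Comet jar is on the classpath; those property names are an assumption, not taken from this diff.

// Hypothetical smoke test for the CometCartesianProductExec path added above.
// Config key names below follow Comet's spark.comet.exec.<op>.enabled convention
// and are an assumption; adjust them to the actual CometConf entries if they differ.
import org.apache.spark.sql.SparkSession

object CrossJoinSmokeTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
      .config("spark.comet.enabled", "true")
      .config("spark.comet.exec.enabled", "true")
      .config("spark.comet.exec.cross_join.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    // The extension only offloads CartesianProductExec when both children are
    // native, so stage the inputs as Parquet files that Comet's scan can read.
    val dir = java.nio.file.Files.createTempDirectory("comet-cross-join").toString
    (0 until 10).map(i => (i, i % 5)).toDF("a", "b").write.parquet(s"$dir/tbl_a")
    (0 until 10).map(i => (i % 10, i + 2)).toDF("c", "d").write.parquet(s"$dir/tbl_b")
    spark.read.parquet(s"$dir/tbl_a").createOrReplaceTempView("tbl_a")
    spark.read.parquet(s"$dir/tbl_b").createOrReplaceTempView("tbl_b")

    val df = spark.sql("SELECT * FROM tbl_a CROSS JOIN tbl_b")
    df.explain() // CometCartesianProductExec is expected here when the offload succeeds.
    assert(df.count() == 100)

    spark.stop()
  }
}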