Support CartesianProductExec in comet

leoluan2009 committed May 17, 2024
1 parent f8fec7f commit bac46b2

Showing 6 changed files with 96 additions and 3 deletions.
11 changes: 10 additions & 1 deletion core/src/execution/datafusion/planner.rs
@@ -40,7 +40,7 @@ use datafusion::{
     physical_plan::{
         aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
         filter::FilterExec,
-        joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec},
+        joins::{utils::JoinFilter, CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec},
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
@@ -978,6 +978,15 @@ impl PhysicalPlanner {
                 )?);
                 Ok((scans, join))
             }
+            OpStruct::CrossJoin(_) => {
+                assert!(children.len() == 2);
+                let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
+                let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;
+
+                left_scans.append(&mut right_scans);
+                let join = Arc::new(CrossJoinExec::new(left, right));
+                Ok((left_scans, join))
+            }
         }
     }

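The new match arm recursively plans both children, concatenates their scan lists, and hands the pair to DataFusion's CrossJoinExec, which pairs every left row with every right row. For reference, a minimal Scala sketch of the cartesian-product contract the native operator must reproduce (plain collections, no DataFusion or Spark involved):

// Every row on the left pairs with every row on the right.
val left = Seq(1, 2, 3)
val right = Seq("a", "b")
val product = for (l <- left; r <- right) yield (l, r)
assert(product.size == left.size * right.size) // 3 * 2 = 6 rows
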
4 changes: 4 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
     Expand expand = 107;
     SortMergeJoin sort_merge_join = 108;
     HashJoin hash_join = 109;
+    CrossJoin cross_join = 110;
   }
 }

@@ -104,6 +105,9 @@ message SortMergeJoin {
   repeated spark.spark_expression.Expr sort_options = 4;
 }
 
+message CrossJoin {
+}
+
 enum JoinType {
   Inner = 0;
   LeftOuter = 1;
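The Scala side consumes this message through the protobuf-generated builder API; for a oneof field named cross_join the conventional generated accessors are setCrossJoin and hasCrossJoin. A sketch following those codegen conventions (names assumed, not verified signatures):

// Hedged sketch of the generated protobuf API for the new field.
val opProto = OperatorOuterClass.Operator
  .newBuilder()
  .setCrossJoin(OperatorOuterClass.CrossJoin.newBuilder())
  .build()
assert(opProto.hasCrossJoin) // the oneof now carries a CrossJoin node
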
27 changes: 26 additions & 1 deletion spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -698,6 +698,31 @@ class CometSparkSessionExtensions
           withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
           s
 
+        case op: CartesianProductExec
+            if isCometOperatorEnabled(conf, "cross_join") &&
+              op.children.forall(isCometNative) =>
+          val newOp = transform1(op)
+          newOp match {
+            case Some(nativeOp) =>
+              CometCartesianProductExec(
+                nativeOp,
+                op,
+                op.condition,
+                op.left,
+                op.right,
+                SerializedPlan(None))
+            case None =>
+              op
+          }
+
+        case op: CartesianProductExec if !isCometOperatorEnabled(conf, "cross_join") =>
+          withInfo(op, "Cross join is not enabled")
+          op
+
+        case op: CartesianProductExec if !op.children.forall(isCometNative(_)) =>
+          withInfo(op, "Cross join disabled because not all child plans are native")
+          op
+
         case op =>
           // An operator that is not supported by Comet
           op match {
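To exercise the new rule end to end, cross joins have to be enabled explicitly. A hedged sketch of a session that should produce CometCartesianProductExec in the physical plan; the spark.comet.exec.cross_join.enabled key is inferred from Comet's spark.comet.exec.<operator>.enabled naming pattern, and tbl_a/tbl_b are assumed to exist (see the test fixture below):

// Assumed config keys, following the "spark.comet.exec.<operator>.enabled" pattern.
spark.conf.set("spark.comet.exec.enabled", "true")
spark.conf.set("spark.comet.exec.cross_join.enabled", "true")

val df = spark.sql("SELECT * FROM tbl_a CROSS JOIN tbl_b")
df.explain() // expect CometCartesianProductExec when both children are native
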
14 changes: 13 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -2545,6 +2545,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         withInfo(join, "SortMergeJoin is not enabled")
         None
 
+      case join: CartesianProductExec if isCometOperatorEnabled(op.conf, "cross_join") =>
+        // TODO: Support CartesianProductExec with join condition after new DataFusion release
+        if (join.condition.isDefined) {
+          withInfo(join, "cross_join with a join condition is not supported")
+          return None
+        }
+        Some(result.setCrossJoin(OperatorOuterClass.CrossJoin.newBuilder()).build())
+
+      case join: CartesianProductExec if !isCometOperatorEnabled(op.conf, "cross_join") =>
+        withInfo(join, "cross_join is not enabled")
+        None
+
       case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType)) =>
         // These operators are source of Comet native execution chain
         val scanBuilder = OperatorOuterClass.Scan.newBuilder()
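Because the serde bails out whenever join.condition is defined, only unconditioned cross joins go native for now. A sketch of a query that keeps running on Spark, assuming the planner lowers the non-equi predicate into a CartesianProductExec condition (depending on table sizes and broadcast thresholds Spark may choose BroadcastNestedLoopJoinExec instead) and that tbl_a/tbl_b from the fixture exist:

// A non-equi predicate typically survives as a join condition on
// CartesianProductExec, so the guard above leaves this query on Spark.
val conditioned = spark.sql("SELECT * FROM tbl_a JOIN tbl_b ON tbl_a._1 < tbl_b._2")
conditioned.explain() // expect Spark's CartesianProductExec, not the Comet operator
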
33 changes: 33 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -899,6 +899,39 @@ case class CometSortMergeJoinExec(
     "join_time" -> SQLMetrics.createNanoTimingMetric(sparkContext, "Total time for joining"))
 }
 
+case class CometCartesianProductExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    condition: Option[Expression],
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan,
+      newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] = Iterator(condition, left, right)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometCartesianProductExec =>
+        this.condition == other.condition &&
+          this.left == other.left &&
+          this.right == other.right &&
+          this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int = Objects.hashCode(condition, left, right)
+
+  override lazy val metrics: Map[String, SQLMetric] =
+    CometMetricNode.baselineMetrics(sparkContext)
+}
+
 case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
   extends CometNativeExec
   with LeafExecNode {
10 changes: 10 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -251,4 +251,14 @@ class CometJoinSuite extends CometTestBase {
       }
     }
   }
+
+  // TODO: Add a test for CartesianProductExec with join filter after new DataFusion release
+  test("CartesianProductExec without join filter") {
+    withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+      withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+        val df1 = sql("SELECT * FROM tbl_a CROSS JOIN tbl_b")
+        checkSparkAnswerAndOperator(df1)
+      }
+    }
+  }
 }
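As a quick sanity check on the fixture: two 10-row inputs must yield exactly 100 output rows. A hypothetical extra assertion for the test body above (not part of the commit):

// Inside the nested withParquetTable blocks, alongside checkSparkAnswerAndOperator:
assert(df1.count() == 10 * 10) // every tbl_a row pairs with every tbl_b row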
