Skip to content

Commit

Permalink
fix: ReusedExchangeExec can be child operator of CometBroadcastExchan…
Browse files Browse the repository at this point in the history
…geExec (apache#713)

(cherry picked from commit e52cfb4)
  • Loading branch information
viirya authored and huaxingao committed Aug 7, 2024
1 parent 4bfc18b commit bb18e0a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -129,6 +129,11 @@ case class CometBroadcastExchangeExec(
case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
if s.plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
if plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
throw new CometRuntimeException(
"Child of CometBroadcastExchangeExec should be CometExec, " +
Expand Down
12 changes: 12 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus

class CometJoinSuite extends CometTestBase {
import testImplicits._

override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
pos: Position): Unit = {
Expand All @@ -40,6 +41,17 @@ class CometJoinSuite extends CometTestBase {
}
}

test("join - self join") {
val df1 = testData.select(testData("key")).as("df1")
val df2 = testData.select(testData("key")).as("df2")

checkAnswer(
df1.join(df2, $"df1.key" === $"df2.key"),
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
.collect()
.toSeq)
}

test("SortMergeJoin with unsupported key type should fall back to Spark") {
withSQLConf(
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
Expand Down

0 comments on commit bb18e0a

Please sign in to comment.