Fix more
viirya committed Apr 30, 2024
1 parent edfce1f commit 67a309b
Showing 1 changed file with 86 additions and 7 deletions.
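
The hunks below all follow the same pattern: Spark test assertions that pattern-match on vanilla Spark physical operators (ShuffleExchangeExec, SortMergeJoinExec, BatchScanExec) are extended to also match the Comet replacement operators, so the assertions still hold when Comet rewrites the physical plan. A minimal sketch of that pattern — not code from this commit, using only operator names and imports that appear in the diff, with a hypothetical helper name — would look like:

import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}

// Collect every shuffle in a plan, whether Spark or Comet executes it.
// CometShuffleExchangeExec is a ShuffleExchangeLike, so the common return
// type covers both cases (this mirrors the collectShuffles change in
// KeyGroupedPartitioningSuite below).
def collectShuffleExchanges(plan: SparkPlan): Seq[ShuffleExchangeLike] =
  plan.collect {
    case s: ShuffleExchangeExec => s       // vanilla Spark shuffle
    case c: CometShuffleExchangeExec => c  // Comet's replacement operator
  }
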
93 changes: 86 additions & 7 deletions dev/diffs/3.4.2.diff
@@ -315,6 +315,26 @@ index 9ddb4abe98b..2bd28d4041d 100644
withTable("tbl") {
sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7dec558f8df..064cf6d4d97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -2255,6 +2256,7 @@ class DatasetSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
case h: ShuffleExchangeExec => h
+ case c: CometShuffleExchangeExec => c
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index f33432ddb6f..060f874ea72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -390,34 +410,42 @@ index a6b295578d6..a5cb616945a 100644
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 2796b1cf154..94591f83c84 100644
index 2796b1cf154..be7078b38f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -875,6 +876,7 @@ class FileBasedDataSourceSuite extends QueryTest
@@ -815,6 +816,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(bJoinExec.isEmpty)
val smJoinExec = collect(joinedDF.queryExecution.executedPlan) {
case smJoin: SortMergeJoinExec => smJoin
+ case smJoin: CometSortMergeJoinExec => smJoin
}
assert(smJoinExec.nonEmpty)
}
@@ -875,6 +877,7 @@ class FileBasedDataSourceSuite extends QueryTest

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _), _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -916,6 +918,7 @@ class FileBasedDataSourceSuite extends QueryTest
@@ -916,6 +919,7 @@ class FileBasedDataSourceSuite extends QueryTest

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _), _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -1100,6 +1103,8 @@ class FileBasedDataSourceSuite extends QueryTest
@@ -1100,6 +1104,8 @@ class FileBasedDataSourceSuite extends QueryTest
val filters = df.queryExecution.executedPlan.collect {
case f: FileSourceScanLike => f.dataFilters
case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
@@ -738,6 +766,45 @@ index cfc8b2cc845..c6fcfd7bd08 100644
}
} finally {
spark.listenerManager.unregister(listener)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index cf76f6ca32c..8a7c2b894ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.comet.CometSortMergeJoinExec
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
import org.apache.spark.sql.connector.catalog.functions._
@@ -31,7 +33,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
@@ -279,13 +281,15 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50)))
}

- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
// here we skip collecting shuffle operators that are not associated with SMJ
collect(plan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c.originalPlan
}.flatMap(smj =>
collect(smj) {
case s: ShuffleExchangeExec => s
+ case c: CometShuffleExchangeExec => c
})
}

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index c0ec8a58bd5..4e8bc6ed3c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -1369,10 +1436,22 @@ index 3a0bd35cb70..b28f06a757f 100644
val workDirPath = workDir.getAbsolutePath
val input = spark.range(5).toDF("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 26e61c6b58d..cde10983c68 100644
index 26e61c6b58d..cb09d7e116a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -737,7 +737,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
@@ -45,8 +45,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}

// Disable AQE because metric info is different with AQE on/off
+// This test suite runs tests against the metrics of physical operators.
+// Disabling it for Comet because the metrics are different with Comet enabled.
class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
- with DisableAdaptiveExecutionSuite {
+ with DisableAdaptiveExecutionSuite with IgnoreCometSuite {
import testImplicits._

/**
@@ -737,7 +739,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
}
}
