diff --git a/.github/workflows/pr_build.yml b/.github/workflows/pr_build.yml
index 1e347250f..410f1e1fe 100644
--- a/.github/workflows/pr_build.yml
+++ b/.github/workflows/pr_build.yml
@@ -76,6 +76,44 @@ jobs:
# upload test reports only for java 17
upload-test-reports: ${{ matrix.java_version == '17' }}
+ linux-test-with-spark4_0:
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ java_version: [17]
+ test-target: [java]
+ spark-version: ['4.0']
+ is_push_event:
+ - ${{ github.event_name == 'push' }}
+ fail-fast: false
+ name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }}
+ runs-on: ${{ matrix.os }}
+ container:
+ image: amd64/rust
+ steps:
+ - uses: actions/checkout@v4
+ - name: Setup Rust & Java toolchain
+ uses: ./.github/actions/setup-builder
+ with:
+ rust-version: ${{env.RUST_VERSION}}
+ jdk-version: ${{ matrix.java_version }}
+ - name: Clone Spark
+ uses: actions/checkout@v4
+ with:
+ repository: "apache/spark"
+ path: "apache-spark"
+ - name: Install Spark
+ shell: bash
+ working-directory: ./apache-spark
+ run: build/mvn install -Phive -Phadoop-cloud -DskipTests
+ - name: Java test steps
+ uses: ./.github/actions/java-test
+ with:
+ # TODO: remove -DskipTests after fixing tests
+ maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests"
+ # TODO: upload test reports after enabling tests
+ upload-test-reports: false
+
linux-test-with-old-spark:
strategy:
matrix:
@@ -169,6 +207,81 @@ jobs:
with:
maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }}
+ macos-test-with-spark4_0:
+ strategy:
+ matrix:
+ os: [macos-13]
+ java_version: [17]
+ test-target: [java]
+ spark-version: ['4.0']
+ fail-fast: false
+ if: github.event_name == 'push'
+ name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }}
+ runs-on: ${{ matrix.os }}
+ steps:
+ - uses: actions/checkout@v4
+ - name: Setup Rust & Java toolchain
+ uses: ./.github/actions/setup-macos-builder
+ with:
+ rust-version: ${{env.RUST_VERSION}}
+ jdk-version: ${{ matrix.java_version }}
+ - name: Clone Spark
+ uses: actions/checkout@v4
+ with:
+ repository: "apache/spark"
+ path: "apache-spark"
+ - name: Install Spark
+ shell: bash
+ working-directory: ./apache-spark
+ run: build/mvn install -Phive -Phadoop-cloud -DskipTests
+ - name: Java test steps
+ uses: ./.github/actions/java-test
+ with:
+ # TODO: remove -DskipTests after fixing tests
+ maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests"
+ # TODO: upload test reports after enabling tests
+ upload-test-reports: false
+
+ macos-aarch64-test-with-spark4_0:
+ strategy:
+ matrix:
+ java_version: [17]
+ test-target: [java]
+ spark-version: ['4.0']
+ is_push_event:
+ - ${{ github.event_name == 'push' }}
+ exclude: # exclude java 11 for pull_request event
+ - java_version: 11
+ is_push_event: false
+ fail-fast: false
+ name: macos-14(Silicon)/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }}
+ runs-on: macos-14
+ steps:
+ - uses: actions/checkout@v4
+ - name: Setup Rust & Java toolchain
+ uses: ./.github/actions/setup-macos-builder
+ with:
+ rust-version: ${{env.RUST_VERSION}}
+ jdk-version: ${{ matrix.java_version }}
+ jdk-architecture: aarch64
+ protoc-architecture: aarch_64
+ - name: Clone Spark
+ uses: actions/checkout@v4
+ with:
+ repository: "apache/spark"
+ path: "apache-spark"
+ - name: Install Spark
+ shell: bash
+ working-directory: ./apache-spark
+ run: build/mvn install -Phive -Phadoop-cloud -DskipTests
+ - name: Java test steps
+ uses: ./.github/actions/java-test
+ with:
+ # TODO: remove -DskipTests after fixing tests
+ maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests"
+ # TODO: upload test reports after enabling tests
+ upload-test-reports: false
+
macos-aarch64-test-with-old-spark:
strategy:
matrix:
diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
index d851067b5..d03252d06 100644
--- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
+++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
@@ -20,10 +20,10 @@
package org.apache.comet.parquet
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.sql.comet.shims.ShimCometParquetUtils
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
-object CometParquetUtils {
+object CometParquetUtils extends ShimCometParquetUtils {
private val PARQUET_FIELD_ID_WRITE_ENABLED = "spark.sql.parquet.fieldId.write.enabled"
private val PARQUET_FIELD_ID_READ_ENABLED = "spark.sql.parquet.fieldId.read.enabled"
private val IGNORE_MISSING_PARQUET_FIELD_ID = "spark.sql.parquet.fieldId.read.ignoreMissing"
@@ -39,61 +39,4 @@ object CometParquetUtils {
def ignoreMissingIds(conf: SQLConf): Boolean =
conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean
-
- // The following is copied from QueryExecutionErrors
- // TODO: remove after dropping Spark 3.2.0 support and directly use
- // QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError
- def foundDuplicateFieldInFieldIdLookupModeError(
- requiredId: Int,
- matchedFields: String): Throwable = {
- new RuntimeException(s"""
- |Found duplicate field(s) "$requiredId": $matchedFields
- |in id mapping mode
- """.stripMargin.replaceAll("\n", " "))
- }
-
- // The followings are copied from org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
- // TODO: remove after dropping Spark 3.2.0 support and directly use ParquetUtils
- /**
- * A StructField metadata key used to set the field id of a column in the Parquet schema.
- */
- val FIELD_ID_METADATA_KEY = "parquet.field.id"
-
- /**
- * Whether there exists a field in the schema, whether inner or leaf, has the parquet field ID
- * metadata.
- */
- def hasFieldIds(schema: StructType): Boolean = {
- def recursiveCheck(schema: DataType): Boolean = {
- schema match {
- case st: StructType =>
- st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType))
-
- case at: ArrayType => recursiveCheck(at.elementType)
-
- case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType)
-
- case _ =>
- // No need to really check primitive types, just to terminate the recursion
- false
- }
- }
- if (schema.isEmpty) false else recursiveCheck(schema)
- }
-
- def hasFieldId(field: StructField): Boolean =
- field.metadata.contains(FIELD_ID_METADATA_KEY)
-
- def getFieldId(field: StructField): Int = {
- require(
- hasFieldId(field),
- s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field)
- try {
- Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY))
- } catch {
- case _: ArithmeticException | _: ClassCastException =>
- throw new IllegalArgumentException(
- s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer")
- }
- }
}
diff --git a/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala b/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
new file mode 100644
index 000000000..f22ac4060
--- /dev/null
+++ b/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.sql.types._
+
+trait ShimCometParquetUtils {
+ // The following is copied from QueryExecutionErrors
+ // TODO: remove after dropping Spark 3.2.0 support and directly use
+ // QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError
+ def foundDuplicateFieldInFieldIdLookupModeError(
+ requiredId: Int,
+ matchedFields: String): Throwable = {
+ new RuntimeException(s"""
+ |Found duplicate field(s) "$requiredId": $matchedFields
+ |in id mapping mode
+ """.stripMargin.replaceAll("\n", " "))
+ }
+
+ // The following are copied from org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
+ // TODO: remove after dropping Spark 3.2.0 support and directly use ParquetUtils
+ /**
+ * A StructField metadata key used to set the field id of a column in the Parquet schema.
+ */
+ val FIELD_ID_METADATA_KEY = "parquet.field.id"
+
+ /**
+ * Whether there exists a field in the schema, whether inner or leaf, has the parquet field ID
+ * metadata.
+ */
+ def hasFieldIds(schema: StructType): Boolean = {
+ def recursiveCheck(schema: DataType): Boolean = {
+ schema match {
+ case st: StructType =>
+ st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType))
+
+ case at: ArrayType => recursiveCheck(at.elementType)
+
+ case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType)
+
+ case _ =>
+ // No need to really check primitive types, just to terminate the recursion
+ false
+ }
+ }
+ if (schema.isEmpty) false else recursiveCheck(schema)
+ }
+
+ def hasFieldId(field: StructField): Boolean =
+ field.metadata.contains(FIELD_ID_METADATA_KEY)
+
+ def getFieldId(field: StructField): Int = {
+ require(
+ hasFieldId(field),
+ s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field)
+ try {
+ Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY))
+ } catch {
+ case _: ArithmeticException | _: ClassCastException =>
+ throw new IllegalArgumentException(
+ s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer")
+ }
+ }
+}
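
A minimal usage sketch (not part of this patch): call sites keep going through `CometParquetUtils`, and the mixed-in shim supplies either the Spark 3.x copies above or the Spark 4.0 delegations added later in this diff. The `"parquet.field.id"` key is the `FIELD_ID_METADATA_KEY` constant.

```scala
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField, StructType}
import org.apache.comet.parquet.CometParquetUtils

// One-field schema whose field carries a Parquet field id in its metadata.
val meta = new MetadataBuilder().putLong("parquet.field.id", 1L).build()
val schema = StructType(Seq(StructField("a", IntegerType, nullable = true, meta)))

// Same calls on every supported Spark version; the shim picks the implementation.
assert(CometParquetUtils.hasFieldIds(schema))
assert(CometParquetUtils.getFieldId(schema.head) == 1)
```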
diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala
new file mode 100644
index 000000000..448d0886c
--- /dev/null
+++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.paths.SparkPath
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+
+object ShimBatchReader {
+ def newPartitionedFile(partitionValues: InternalRow, file: String): PartitionedFile =
+ PartitionedFile(
+ partitionValues,
+ SparkPath.fromUrlString(file),
+ -1, // -1 means we read the entire file
+ -1,
+ Array.empty[String],
+ 0,
+ 0
+ )
+}
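
A usage sketch (illustrative only; the path is hypothetical): callers pass partition values and a URL string, and the shim fills in the Spark 4.0 `PartitionedFile` arguments, with start/length of -1 meaning "read the whole file".

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.comet.shims.ShimBatchReader

// Covers the entire file; no partition columns in this example.
val file = ShimBatchReader.newPartitionedFile(InternalRow.empty, "file:///tmp/part-00000.parquet")
```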
diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala
new file mode 100644
index 000000000..2f386869a
--- /dev/null
+++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+
+object ShimFileFormat {
+ // A name for a temporary column that holds row indexes computed by the file format reader
+ // until they can be placed in the _metadata struct.
+ val ROW_INDEX_TEMPORARY_COLUMN_NAME = ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
+
+ val OPTION_RETURNING_BATCH = FileFormat.OPTION_RETURNING_BATCH
+}
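
A short sketch of the intent: version-neutral code references these constants through the shim rather than importing `FileFormat`/`ParquetFileFormat` directly.

```scala
import org.apache.comet.shims.ShimFileFormat

// Ask the file source to return columnar batches.
val readerOptions: Map[String, String] =
  Map(ShimFileFormat.OPTION_RETURNING_BATCH -> "true")
```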
diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala
new file mode 100644
index 000000000..60e21765b
--- /dev/null
+++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.types.{StructField, StructType}
+
+object ShimResolveDefaultColumns {
+ def getExistenceDefaultValue(field: StructField): Any =
+ ResolveDefaultColumns.getExistenceDefaultValues(StructType(Seq(field))).head
+}
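
A sketch of the behavior this wrapper relies on (assumption: a field without existence-default metadata resolves to null):

```scala
import org.apache.spark.sql.types.{IntegerType, StructField}
import org.apache.comet.shims.ShimResolveDefaultColumns

// No existence default recorded on the field, so the resolver yields null.
val default = ShimResolveDefaultColumns.getExistenceDefaultValue(StructField("a", IntegerType))
assert(default == null)
```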
diff --git a/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
new file mode 100644
index 000000000..d402cd786
--- /dev/null
+++ b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
+import org.apache.spark.sql.types._
+
+trait ShimCometParquetUtils {
+ def foundDuplicateFieldInFieldIdLookupModeError(
+ requiredId: Int,
+ matchedFields: String): Throwable = {
+ QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(requiredId, matchedFields)
+ }
+
+ def hasFieldIds(schema: StructType): Boolean = ParquetUtils.hasFieldIds(schema)
+
+ def hasFieldId(field: StructField): Boolean = ParquetUtils.hasFieldId(field)
+
+ def getFieldId(field: StructField): Int = ParquetUtils.getFieldId(field)
+}
diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh
index 1f97d2d4a..12f555b8e 100755
--- a/dev/ensure-jars-have-correct-contents.sh
+++ b/dev/ensure-jars-have-correct-contents.sh
@@ -40,9 +40,11 @@ allowed_expr="(^org/$|^org/apache/$"
# we have to allow the directories that lead to the org/apache/comet dir
# We allow all the classes under the following packages:
# * org.apache.comet
+# * org.apache.spark.comet
# * org.apache.spark.sql.comet
# * org.apache.arrow.c
allowed_expr+="|^org/apache/comet/"
+allowed_expr+="|^org/apache/spark/comet/"
allowed_expr+="|^org/apache/spark/sql/comet/"
allowed_expr+="|^org/apache/arrow/c/"
# * whatever in the "META-INF" directory
diff --git a/pom.xml b/pom.xml
index 59e0569ff..57b4206cf 100644
--- a/pom.xml
+++ b/pom.xml
@@ -87,7 +87,7 @@ under the License.
-ea -Xmx4g -Xss4m ${extraJavaTestArgs}spark-3.3-plus
- spark-3.4
+ spark-3.4-plusspark-3.xspark-3.4
@@ -512,7 +512,6 @@ under the License.
3.3.23.31.12.0
- spark-3.3-plusnot-needed-yetspark-3.3
@@ -524,9 +523,25 @@ under the License.
2.12.173.41.13.1
- spark-3.3-plus
- spark-3.4
- spark-3.4
+
+
+
+
+
+ spark-4.0
+
+
+ 2.13.13
+ 2.13
+ 4.0.0-SNAPSHOT
+ 4.0
+ 1.13.1
+ spark-4.0
+ not-needed-yet
+
+ 17
+ ${java.version}
+ ${java.version}
@@ -605,7 +620,7 @@ under the License.
             <groupId>org.scalameta</groupId>
             <artifactId>semanticdb-scalac_${scala.version}</artifactId>
-            <version>4.7.5</version>
+            <version>4.8.8</version>
diff --git a/spark/pom.xml b/spark/pom.xml
index 21fa09fc2..84e2e501f 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -252,6 +252,8 @@ under the License.
+
+
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
index ac871cf60..52d8d09a0 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
import org.apache.spark.util.SerializableConfiguration
@@ -144,7 +143,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
isCaseSensitive,
useFieldId,
ignoreMissingIds,
- datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED,
+ datetimeRebaseSpec.mode == CORRECTED,
partitionSchema,
file.partitionValues,
JavaConverters.mapAsJavaMap(metrics))
@@ -161,7 +160,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
}
}
-object CometParquetFileFormat extends Logging {
+object CometParquetFileFormat extends Logging with ShimSQLConf {
/**
* Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf`
@@ -210,7 +209,7 @@ object CometParquetFileFormat extends Logging {
case _ => false
})
- if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) {
+ if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LEGACY) {
if (exceptionOnRebase) {
logWarning(
s"""Found Parquet file $file that could potentially contain dates/timestamps that were
@@ -222,7 +221,7 @@ object CometParquetFileFormat extends Logging {
calendar, please disable Comet for this query.""")
} else {
// do not throw exception on rebase - read as it is
- datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED)
+ datetimeRebaseSpec = datetimeRebaseSpec.copy(CORRECTED)
}
}
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
index 693af125b..e48d76384 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -135,7 +134,7 @@ case class CometParquetPartitionReaderFactory(
isCaseSensitive,
useFieldId,
ignoreMissingIds,
- datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED,
+ datetimeRebaseSpec.mode == CORRECTED,
partitionSchema,
file.partitionValues,
JavaConverters.mapAsJavaMap(metrics))
diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
index 5994dfb41..58c2aeb41 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
@@ -38,11 +38,11 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.Type.Repetition
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec}
-import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.sources
import org.apache.spark.unsafe.types.UTF8String
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.shims.ShimSQLConf
/**
* Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove
@@ -58,7 +58,8 @@ class ParquetFilters(
pushDownStringPredicate: Boolean,
pushDownInFilterThreshold: Int,
caseSensitive: Boolean,
- datetimeRebaseSpec: RebaseSpec) {
+ datetimeRebaseSpec: RebaseSpec)
+ extends ShimSQLConf {
// A map which contains parquet field name and data type, if predicate push down applies.
//
// Each key in `nameToParquetField` represents a column; `dots` are used as separators for
@@ -153,7 +154,7 @@ class ParquetFilters(
case ld: LocalDate => DateTimeUtils.localDateToDays(ld)
}
datetimeRebaseSpec.mode match {
- case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays)
+ case LEGACY => rebaseGregorianToJulianDays(gregorianDays)
case _ => gregorianDays
}
}
@@ -164,7 +165,7 @@ class ParquetFilters(
case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
}
datetimeRebaseSpec.mode match {
- case LegacyBehaviorPolicy.LEGACY =>
+ case LEGACY =>
rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros)
case _ => gregorianMicros
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6333650dd..a717e066e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1987,18 +1987,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
// char types. Use rpad to achieve the behavior.
// See https://github.com/apache/spark/pull/38151
- case StaticInvoke(
- _: Class[CharVarcharCodegenUtils],
- _: StringType,
- "readSidePadding",
- arguments,
- _,
- true,
- false,
- true) if arguments.size == 2 =>
+ case s: StaticInvoke
+ if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
+ s.dataType.isInstanceOf[StringType] &&
+ s.functionName == "readSidePadding" &&
+ s.arguments.size == 2 &&
+ s.propagateNull &&
+ !s.returnNullable &&
+ s.isDeterministic =>
val argsExpr = Seq(
- exprToProtoInternal(Cast(arguments(0), StringType), inputs),
- exprToProtoInternal(arguments(1), inputs))
+ exprToProtoInternal(Cast(s.arguments(0), StringType), inputs),
+ exprToProtoInternal(s.arguments(1), inputs))
if (argsExpr.forall(_.isDefined)) {
val builder = ExprOuterClass.ScalarFunc.newBuilder()
@@ -2007,7 +2006,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
} else {
- withInfo(expr, arguments: _*)
+ withInfo(expr, s.arguments: _*)
None
}
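
The motivation for the rewrite above: a positional `StaticInvoke(...)` pattern must name every constructor argument, so it stops compiling when a Spark release adds a parameter, while accessor-plus-guard matching keeps working. An illustrative sketch with hypothetical case classes (not Spark's `StaticInvoke`):

```scala
// "Old" and "new" shapes of the same node across two hypothetical versions.
case class InvokeV1(obj: Any, name: String, deterministic: Boolean)
case class InvokeV2(obj: Any, name: String, deterministic: Boolean, extraFlag: Boolean)

// Positional match: must be rewritten whenever the arity changes.
def matchPositional(i: InvokeV1): Boolean = i match {
  case InvokeV1(_, "readSidePadding", true) => true
  case _ => false
}

// Guard-based match: unaffected by added fields, as in the QueryPlanSerde change above.
def matchWithGuards(i: InvokeV2): Boolean = i match {
  case inv: InvokeV2 if inv.name == "readSidePadding" && inv.deterministic => true
  case _ => false
}
```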
diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala
index 97838448a..dcc00f66c 100644
--- a/spark/src/main/scala/org/apache/spark/Plugins.scala
+++ b/spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -23,9 +23,9 @@ import java.{util => ju}
import java.util.Collections
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
+import org.apache.spark.comet.shims.ShimCometDriverPlugin
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD}
-import org.apache.spark.resource.ResourceProfile
import org.apache.comet.{CometConf, CometSparkSessionExtensions}
@@ -40,9 +40,7 @@ import org.apache.comet.{CometConf, CometSparkSessionExtensions}
*
* To enable this plugin, set the config "spark.plugins" to `org.apache.spark.CometPlugin`.
*/
-class CometDriverPlugin extends DriverPlugin with Logging {
- import CometDriverPlugin._
-
+class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPlugin {
override def init(sc: SparkContext, pluginContext: PluginContext): ju.Map[String, String] = {
logInfo("CometDriverPlugin init")
@@ -52,14 +50,10 @@ class CometDriverPlugin extends DriverPlugin with Logging {
} else {
// By default, executorMemory * spark.executor.memoryOverheadFactor, with minimum of 384MB
val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key)
- val memoryOverheadFactor =
- sc.getConf.getDouble(
- EXECUTOR_MEMORY_OVERHEAD_FACTOR,
- EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT)
-
- Math.max(
- (executorMemory * memoryOverheadFactor).toInt,
- ResourceProfile.MEMORY_OVERHEAD_MIN_MIB)
+ val memoryOverheadFactor = getMemoryOverheadFactor(sc.getConf)
+ val memoryOverheadMinMib = getMemoryOverheadMinMib(sc.getConf)
+
+ Math.max((executorMemory * memoryOverheadFactor).toLong, memoryOverheadMinMib)
}
val cometMemOverhead = CometSparkSessionExtensions.getCometMemoryOverheadInMiB(sc.getConf)
@@ -100,12 +94,6 @@ class CometDriverPlugin extends DriverPlugin with Logging {
}
}
-object CometDriverPlugin {
- // `org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR` was added since Spark 3.3.0
- val EXECUTOR_MEMORY_OVERHEAD_FACTOR = "spark.executor.memoryOverheadFactor"
- val EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT = 0.1
-}
-
/**
* The Comet plugin for Spark. To enable this plugin, set the config "spark.plugins" to
* `org.apache.spark.CometPlugin`
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 7bd34debb..38247b2c4 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -27,6 +27,7 @@ import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal
import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -44,7 +45,6 @@ import org.apache.spark.util.io.ChunkedByteBuffer
import com.google.common.base.Objects
import org.apache.comet.CometRuntimeException
-import org.apache.comet.shims.ShimCometBroadcastExchangeExec
/**
* A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
index 14a664108..9a5b55d65 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.comet.shims.ShimCometScanExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
@@ -43,7 +44,7 @@ import org.apache.spark.util.collection._
import org.apache.comet.{CometConf, MetricsSupport}
import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory}
-import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat}
+import org.apache.comet.shims.ShimFileFormat
/**
* Comet physical scan node for DataSource V1. Most of the code here follow Spark's
@@ -271,7 +272,7 @@ case class CometScanExec(
selectedPartitions
.flatMap { p =>
p.files.map { f =>
- PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
+ getPartitionedFile(f, p)
}
}
.groupBy { f =>
@@ -358,7 +359,7 @@ case class CometScanExec(
// SPARK-39634: Allow file splitting in combination with row index generation once
// the fix for PARQUET-2161 is available.
!isNeededForSchema(requiredSchema)
- PartitionedFileUtil.splitFiles(
+ super.splitFiles(
sparkSession = relation.sparkSession,
file = file,
filePath = filePath,
@@ -409,7 +410,7 @@ case class CometScanExec(
Map.empty)
} else {
newFileScanRDD(
- fsRelation.sparkSession,
+ fsRelation,
readFile,
partitions,
new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields),
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
index 13f26ce58..a2cdf421c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
@@ -59,7 +59,7 @@ object DecimalPrecision {
}
CheckOverflow(add, resultType, nullOnOverflow)
- case sub @ Subtract(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+ case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
val resultScale = max(s1, s2)
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
@@ -68,7 +68,7 @@ object DecimalPrecision {
}
CheckOverflow(sub, resultType, nullOnOverflow)
- case mul @ Multiply(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+ case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
} else {
@@ -76,7 +76,7 @@ object DecimalPrecision {
}
CheckOverflow(mul, resultType, nullOnOverflow)
- case div @ Divide(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+ case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
val resultType = if (allowPrecisionLoss) {
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
@@ -96,7 +96,7 @@ object DecimalPrecision {
}
CheckOverflow(div, resultType, nullOnOverflow)
- case rem @ Remainder(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+ case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
@@ -108,6 +108,7 @@ object DecimalPrecision {
}
}
+ // TODO: consider to use `org.apache.spark.sql.types.DecimalExpression` for Spark 3.5+
object DecimalExpression {
def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
case t: DecimalType => Some((t.precision, t.scale))
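
An illustrative use of the `DecimalExpression` extractor defined in this file (a sketch, written as if in the same scope):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{Decimal, DecimalType}

// Any expression whose data type is a DecimalType yields its (precision, scale).
val lit = Literal(Decimal(BigDecimal("12.34")), DecimalType(10, 2))
lit match {
  case DecimalExpression(p, s) => assert(p == 10 && s == 2)
  case _ => ()
}
```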
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 49c263f3f..3f4d7bfd3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -28,10 +28,10 @@ import scala.concurrent.Future
import org.apache.spark._
import org.apache.spark.internal.config
-import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
@@ -39,12 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
+import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
@@ -68,7 +68,8 @@ case class CometShuffleExchangeExec(
shuffleType: ShuffleType = CometNativeShuffle,
advisoryPartitionSize: Option[Long] = None)
extends ShuffleExchangeLike
- with CometPlan {
+ with CometPlan
+ with ShimCometShuffleExchangeExec {
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -127,6 +128,9 @@ case class CometShuffleExchangeExec(
Statistics(dataSize, Some(rowCount))
}
+ // TODO: add `override` keyword after dropping Spark 3.x support
+ def shuffleId: Int = getShuffleId(shuffleDependency)
+
/**
* A [[ShuffleDependency]] that will partition rows of its child based on the partitioning
* scheme defined in `newPartitioning`. Those partitions of the returned ShuffleDependency will
@@ -386,7 +390,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
- StructType.fromAttributes(outputAttributes),
+ fromAttributes(outputAttributes),
recordComparatorSupplier,
prefixComparator,
prefixComputer,
@@ -430,7 +434,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
serializer,
shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
shuffleType = CometColumnarShuffle,
- schema = Some(StructType.fromAttributes(outputAttributes)))
+ schema = Some(fromAttributes(outputAttributes)))
dependency
}
@@ -445,7 +449,7 @@ class CometShuffleWriteProcessor(
outputPartitioning: Partitioning,
outputAttributes: Seq[Attribute],
metrics: Map[String, SQLMetric])
- extends ShuffleWriteProcessor {
+ extends ShimCometShuffleWriteProcessor {
private val OFFSET_LENGTH = 8
@@ -455,11 +459,11 @@ class CometShuffleWriteProcessor(
}
override def write(
- rdd: RDD[_],
+ inputs: Iterator[_],
dep: ShuffleDependency[_, _, _],
mapId: Long,
- context: TaskContext,
- partition: Partition): MapStatus = {
+ mapIndex: Int,
+ context: TaskContext): MapStatus = {
val shuffleBlockResolver =
SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
@@ -469,10 +473,6 @@ class CometShuffleWriteProcessor(
val tempDataFilePath = Paths.get(tempDataFilename)
val tempIndexFilePath = Paths.get(tempIndexFilename)
- // Getting rid of the fake partitionId
- val cometRDD =
- rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[RDD[ColumnarBatch]]
-
// Call native shuffle write
val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)
@@ -482,8 +482,12 @@ class CometShuffleWriteProcessor(
"elapsed_compute" -> metrics("shuffleReadElapsedCompute"))
val nativeMetrics = CometMetricNode(nativeSQLMetrics)
- val rawIter = cometRDD.iterator(partition, context)
- val cometIter = CometExec.getCometIterator(Seq(rawIter), nativePlan, nativeMetrics)
+ // Getting rid of the fake partitionId
+ val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)
+ val cometIter = CometExec.getCometIterator(
+ Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
+ nativePlan,
+ nativeMetrics)
while (cometIter.hasNext) {
cometIter.next()
diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
index 0c45a9c2c..f5a578f82 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
*/
trait CometExprShim {
/**
- * Returns a tuple of expressions for the `unhex` function.
- */
+ * Returns a tuple of expressions for the `unhex` function.
+ */
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}
diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
index 0c45a9c2c..f5a578f82 100644
--- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
*/
trait CometExprShim {
/**
- * Returns a tuple of expressions for the `unhex` function.
- */
+ * Returns a tuple of expressions for the `unhex` function.
+ */
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
index 409e1c94b..3f2301f0a 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
*/
trait CometExprShim {
/**
- * Returns a tuple of expressions for the `unhex` function.
- */
+ * Returns a tuple of expressions for the `unhex` function.
+ */
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(unhex.failOnError))
}
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index 6b4fad974..350aeb9f0 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -19,8 +19,11 @@
package org.apache.comet.shims
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.types.{StructField, StructType}
trait ShimCometShuffleExchangeExec {
// TODO: remove after dropping Spark 3.2 and 3.3 support
@@ -37,4 +40,11 @@ trait ShimCometShuffleExchangeExec {
shuffleType,
advisoryPartitionSize)
}
+
+ // TODO: remove after dropping Spark 3.x support
+ protected def fromAttributes(attributes: Seq[Attribute]): StructType =
+ StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+ // TODO: remove after dropping Spark 3.x support
+ protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): Int = 0
}
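
A sketch of what the added `fromAttributes` produces (written as if inside a class mixing in this trait); `CometShuffleExchangeExec` now reaches this conversion only through the shim:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.LongType

// One attribute in, one StructField out; nullability and metadata are preserved.
val attrs = Seq(AttributeReference("id", LongType, nullable = false)())
val schema = fromAttributes(attrs) // StructType(Seq(StructField("id", LongType, false)))
```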
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index eb04c68ab..377485335 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -24,8 +24,6 @@ import org.apache.spark.sql.execution.{LimitExec, QueryExecution, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
trait ShimCometSparkSessionExtensions {
- import org.apache.comet.shims.ShimCometSparkSessionExtensions._
-
/**
* TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate
*/
@@ -45,9 +43,7 @@ trait ShimCometSparkSessionExtensions {
* SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key
*/
protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = "spark.sql.extendedExplainProviders"
-}
-object ShimCometSparkSessionExtensions {
private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields
.filter(_.getName == "offset")
.map { a => a.setAccessible(true); a.get(plan) }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
index ff60ef964..c3d0c56e5 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
@@ -20,6 +20,7 @@
package org.apache.comet.shims
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
trait ShimSQLConf {
@@ -39,4 +40,7 @@ trait ShimSQLConf {
case _ => None
})
.head
+
+ protected val LEGACY = LegacyBehaviorPolicy.LEGACY
+ protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED
}
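
A sketch of the call-site pattern this enables (written as if inside a class mixing in `ShimSQLConf`): rebase-mode comparisons no longer need to import `LegacyBehaviorPolicy` from a version-specific location.

```scala
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec

// Mirrors checks such as `datetimeRebaseSpec.mode == CORRECTED` in CometParquetFileFormat above.
def needsJulianRebase(spec: RebaseSpec): Boolean = spec.mode == LEGACY
```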
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
similarity index 98%
rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
rename to spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
index 63ff2a2c1..aede47951 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.comet.shims
+package org.apache.spark.comet.shims
import scala.reflect.ClassTag
diff --git a/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
new file mode 100644
index 000000000..cfb6a0088
--- /dev/null
+++ b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import org.apache.spark.SparkConf
+
+trait ShimCometDriverPlugin {
+ // `org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR` was added in Spark 3.3.0
+ private val EXECUTOR_MEMORY_OVERHEAD_FACTOR = "spark.executor.memoryOverheadFactor"
+ private val EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT = 0.1
+ // `org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD` was added in Spark 4.0.0
+ private val EXECUTOR_MIN_MEMORY_OVERHEAD = "spark.executor.minMemoryOverhead"
+ private val EXECUTOR_MIN_MEMORY_OVERHEAD_DEFAULT = 384L
+
+ def getMemoryOverheadFactor(sc: SparkConf): Double =
+ sc.getDouble(
+ EXECUTOR_MEMORY_OVERHEAD_FACTOR,
+ EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT)
+ def getMemoryOverheadMinMib(sc: SparkConf): Long =
+ sc.getLong(EXECUTOR_MIN_MEMORY_OVERHEAD, EXECUTOR_MIN_MEMORY_OVERHEAD_DEFAULT)
+}
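
For reference, a sketch of how `CometDriverPlugin` (see Plugins.scala above) combines these two values, written as if inside a class mixing in the trait; the 0.2 factor is only an example setting.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.memory", "4g")
  .set("spark.executor.memoryOverheadFactor", "0.2")

val executorMemoryMib = conf.getSizeAsMb("spark.executor.memory") // 4096
// max(4096 * 0.2, 384) = 819 MiB of overhead on a Spark 3.x build.
val overheadMib =
  Math.max((executorMemoryMib * getMemoryOverheadFactor(conf)).toLong, getMemoryOverheadMinMib(conf))
```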
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
similarity index 81%
rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala
rename to spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
index 544a67385..02b97f9fb 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -17,17 +17,21 @@
* under the License.
*/
-package org.apache.comet.shims
+package org.apache.spark.sql.comet.shims
+
+import org.apache.comet.shims.ShimFileFormat
import scala.language.implicitConversions
+import org.apache.hadoop.fs.{FileStatus, Path}
+
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -63,7 +67,7 @@ trait ShimCometScanExec {
// TODO: remove after dropping Spark 3.2 support and directly call new FileScanRDD
protected def newFileScanRDD(
- sparkSession: SparkSession,
+ fsRelation: HadoopFsRelation,
readFunction: PartitionedFile => Iterator[InternalRow],
filePartitions: Seq[FilePartition],
readSchema: StructType,
@@ -73,12 +77,12 @@ trait ShimCometScanExec {
.filter(c => List(3, 5, 6).contains(c.getParameterCount()) )
.map { c =>
c.getParameterCount match {
- case 3 => c.newInstance(sparkSession, readFunction, filePartitions)
+ case 3 => c.newInstance(fsRelation.sparkSession, readFunction, filePartitions)
case 5 =>
- c.newInstance(sparkSession, readFunction, filePartitions, readSchema, metadataColumns)
+ c.newInstance(fsRelation.sparkSession, readFunction, filePartitions, readSchema, metadataColumns)
case 6 =>
c.newInstance(
- sparkSession,
+ fsRelation.sparkSession,
readFunction,
filePartitions,
readSchema,
@@ -123,4 +127,15 @@ trait ShimCometScanExec {
protected def isNeededForSchema(sparkSchema: StructType): Boolean = {
findRowIndexColumnIndexInSchema(sparkSchema) >= 0
}
+
+ protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile =
+ PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
+
+ protected def splitFiles(sparkSession: SparkSession,
+ file: FileStatus,
+ filePath: Path,
+ isSplitable: Boolean,
+ maxSplitBytes: Long,
+ partitionValues: InternalRow): Seq[PartitionedFile] =
+ PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues)
}
diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
new file mode 100644
index 000000000..9100b90c2
--- /dev/null
+++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.{Partition, ShuffleDependency, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.ShuffleWriteProcessor
+
+trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor {
+ override def write(
+ rdd: RDD[_],
+ dep: ShuffleDependency[_, _, _],
+ mapId: Long,
+ context: TaskContext,
+ partition: Partition): MapStatus = {
+ val rawIter = rdd.iterator(partition, context)
+ write(rawIter, dep, mapId, partition.index, context)
+ }
+
+ def write(
+ inputs: Iterator[_],
+ dep: ShuffleDependency[_, _, _],
+ mapId: Long,
+ mapIndex: Int,
+ context: TaskContext): MapStatus
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 000000000..01f923206
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim {
+ /**
+ * Returns a tuple of expressions for the `unhex` function.
+ */
+ protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ (unhex.child, Literal(unhex.failOnError))
+ }
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala
new file mode 100644
index 000000000..167b539f8
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+
+trait ShimCometBatchScanExec {
+ def wrapped: BatchScanExec
+
+ def keyGroupedPartitioning: Option[Seq[Expression]] = wrapped.keyGroupedPartitioning
+
+ def inputPartitions: Seq[InputPartition] = wrapped.inputPartitions
+
+ def ordering: Option[Seq[SortOrder]] = wrapped.ordering
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
new file mode 100644
index 000000000..1f689b400
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning}
+
+trait ShimCometBroadcastHashJoinExec {
+ protected def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] =
+ partitioning match {
+ case p: HashPartitioningLike => p.expressions
+ case _ => Seq()
+ }
+}
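
An illustrative sketch, written as if inside a class mixing in this trait, and assuming (as the shim does) that `HashPartitioning` is a `HashPartitioningLike` on Spark 4.0:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, SinglePartition}
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("k", IntegerType)()
// Hash-like partitionings expose their expressions; anything else yields Seq().
assert(getHashPartitioningLikeExpressions(HashPartitioning(Seq(key), 8)) == Seq(key))
assert(getHashPartitioningLikeExpressions(SinglePartition) == Seq())
```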
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
new file mode 100644
index 000000000..559e327b4
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.types.StructType
+
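+/**
+ * Builds a `CometShuffleExchangeExec` from Spark's `ShuffleExchangeExec` and wraps the
+ * small helpers (schema from attributes, shuffle id) that differ across Spark versions.
+ */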
+trait ShimCometShuffleExchangeExec {
+ def apply(s: ShuffleExchangeExec, shuffleType: ShuffleType): CometShuffleExchangeExec = {
+ CometShuffleExchangeExec(
+ s.outputPartitioning,
+ s.child,
+ s,
+ s.shuffleOrigin,
+ shuffleType,
+ s.advisoryPartitionSize)
+ }
+
+  protected def fromAttributes(attributes: Seq[Attribute]): StructType =
+    DataTypeUtils.fromAttributes(attributes)
+
+  protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): Int =
+    shuffleDependency.shuffleId
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
new file mode 100644
index 000000000..9fb7355ee
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.{CollectLimitExec, GlobalLimitExec, LocalLimitExec, QueryExecution}
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.internal.SQLConf
+
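+/**
+ * Spark 4.0 exposes `offset` on `GlobalLimitExec` and `CollectLimitExec` and supports
+ * extended explain providers, so these helpers delegate directly to the wrapped
+ * plan and `SQLConf`. `LocalLimitExec` carries no offset, hence the constant 0.
+ */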
+trait ShimCometSparkSessionExtensions {
+ protected def getPushedAggregate(scan: ParquetScan): Option[Aggregation] = scan.pushedAggregate
+
+ protected def getOffset(limit: LocalLimitExec): Int = 0
+ protected def getOffset(limit: GlobalLimitExec): Int = limit.offset
+ protected def getOffset(limit: CollectLimitExec): Int = limit.offset
+
+ protected def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = true
+
+ protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala
new file mode 100644
index 000000000..5a8ac97b3
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.TakeOrderedAndProjectExec
+
+trait ShimCometTakeOrderedAndProjectExec {
+ protected def getOffset(plan: TakeOrderedAndProjectExec): Option[Int] = Some(plan.offset)
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
new file mode 100644
index 000000000..4d261f3c2
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, BinaryExpression, BloomFilterMightContain, EvalMode}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+
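+/**
+ * Version-dependent serde helpers: `failOnError` is read reflectively for binary
+ * arithmetic, while `Sum`/`Average` rely on `EvalMode` and `initQueryContext()`
+ * to detect ANSI and legacy evaluation modes.
+ */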
+trait ShimQueryPlanSerde {
+ protected def getFailOnError(b: BinaryArithmetic): Boolean =
+ b.getClass.getMethod("failOnError").invoke(b).asInstanceOf[Boolean]
+
+ protected def getFailOnError(aggregate: Sum): Boolean = aggregate.initQueryContext().isDefined
+ protected def getFailOnError(aggregate: Average): Boolean = aggregate.initQueryContext().isDefined
+
+ protected def isLegacyMode(aggregate: Sum): Boolean = aggregate.evalMode.equals(EvalMode.LEGACY)
+ protected def isLegacyMode(aggregate: Average): Boolean = aggregate.evalMode.equals(EvalMode.LEGACY)
+
+ protected def isBloomFilterMightContain(binary: BinaryExpression): Boolean =
+ binary.isInstanceOf[BloomFilterMightContain]
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala
new file mode 100644
index 000000000..574967767
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
+import org.apache.spark.sql.internal.SQLConf
+
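+/**
+ * Re-exports the parquet string-predicate pushdown flag and the `LegacyBehaviorPolicy`
+ * values; in Spark 4.0 the policy lives under `org.apache.spark.sql.internal`.
+ */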
+trait ShimSQLConf {
+ protected def getPushDownStringPredicate(sqlConf: SQLConf): Boolean =
+ sqlConf.parquetFilterPushDownStringPredicate
+
+ protected val LEGACY = LegacyBehaviorPolicy.LEGACY
+ protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED
+}
diff --git a/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
new file mode 100644
index 000000000..ba87a2515
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkContext
+import org.apache.spark.broadcast.Broadcast
+
+trait ShimCometBroadcastExchangeExec {
+ protected def doBroadcast[T: ClassTag](sparkContext: SparkContext, value: T): Broadcast[Any] =
+ sparkContext.broadcastInternal(value, true)
+}
diff --git a/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
new file mode 100644
index 000000000..f7a57a642
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR
+import org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD
+
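+/**
+ * Reads the executor memory overhead factor and the minimum overhead (in MiB) from
+ * Spark 4.0's built-in config entries.
+ */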
+trait ShimCometDriverPlugin {
+ protected def getMemoryOverheadFactor(sparkConf: SparkConf): Double = sparkConf.get(
+ EXECUTOR_MEMORY_OVERHEAD_FACTOR)
+
+ protected def getMemoryOverheadMinMib(sparkConf: SparkConf): Long = sparkConf.get(
+ EXECUTOR_MIN_MEMORY_OVERHEAD)
+}
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
new file mode 100644
index 000000000..543116c10
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.SparkContext
+
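+/**
+ * Shims the scan-related Spark 4.0 APIs used by Comet: `DataSourceRDD` takes grouped
+ * input partitions, `FileScanRDD` carries file constant-metadata extractors, and
+ * `PartitionedFileUtil` operates on `FileStatusWithMetadata`.
+ */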
+trait ShimCometScanExec {
+ def wrapped: FileSourceScanExec
+
+ lazy val fileConstantMetadataColumns: Seq[AttributeReference] =
+ wrapped.fileConstantMetadataColumns
+
+ protected def newDataSourceRDD(
+ sc: SparkContext,
+ inputPartitions: Seq[Seq[InputPartition]],
+ partitionReaderFactory: PartitionReaderFactory,
+ columnarReads: Boolean,
+ customMetrics: Map[String, SQLMetric]): DataSourceRDD =
+ new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, customMetrics)
+
+ protected def newFileScanRDD(
+ fsRelation: HadoopFsRelation,
+ readFunction: PartitionedFile => Iterator[InternalRow],
+ filePartitions: Seq[FilePartition],
+ readSchema: StructType,
+ options: ParquetOptions): FileScanRDD = {
+ new FileScanRDD(
+ fsRelation.sparkSession,
+ readFunction,
+ filePartitions,
+ readSchema,
+ fileConstantMetadataColumns,
+ fsRelation.fileFormat.fileConstantMetadataExtractors,
+ options)
+ }
+
+ protected def invalidBucketFile(path: String, sparkVersion: String): Throwable =
+ QueryExecutionErrors.invalidBucketFile(path)
+
+ // see SPARK-39634
+ protected def isNeededForSchema(sparkSchema: StructType): Boolean = false
+
+  protected def getPartitionedFile(
+      f: FileStatusWithMetadata,
+      p: PartitionDirectory): PartitionedFile =
+    PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen)
+
+ protected def splitFiles(sparkSession: SparkSession,
+ file: FileStatusWithMetadata,
+ filePath: Path,
+ isSplitable: Boolean,
+ maxSplitBytes: Long,
+ partitionValues: InternalRow): Seq[PartitionedFile] =
+ PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues)
+}
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
new file mode 100644
index 000000000..f875e3f38
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.shuffle.ShuffleWriteProcessor
+
+trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor {
+
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
index 1357d6548..53186b131 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
@@ -21,13 +21,13 @@ package org.apache.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
+import org.apache.spark.sql.comet.shims.ShimCometTPCDSQuerySuite
import org.apache.comet.CometConf
class CometTPCDSQuerySuite
extends {
- // This is private in `TPCDSBase`.
- val excludedTpcdsQueries: Seq[String] = Seq()
+ override val excludedTpcdsQueries: Set[String] = Set()
// This is private in `TPCDSBase` and `excludedTpcdsQueries` is private too.
// So we cannot override `excludedTpcdsQueries` to exclude the queries.
@@ -145,7 +145,8 @@ class CometTPCDSQuerySuite
override val tpcdsQueries: Seq[String] =
tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)
}
- with TPCDSQueryTestSuite {
+ with TPCDSQueryTestSuite
+ with ShimCometTPCDSQuerySuite {
override def sparkConf: SparkConf = {
val conf = super.sparkConf
conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
index 1abe5faeb..ec87f19e9 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.shims.ShimCometTPCHQuerySuite
/**
* End-to-end tests to check TPCH query results.
@@ -49,7 +50,7 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
* ./mvnw -Dsuites=org.apache.spark.sql.CometTPCHQuerySuite test
* }}}
*/
-class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestHelper {
+class CometTPCHQuerySuite extends QueryTest with CometTPCBase with ShimCometTPCHQuerySuite {
private val tpchDataPath = sys.env.get("SPARK_TPCH_DATA")
@@ -142,7 +143,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH
val shouldSortResults = sortMergeJoinConf != conf // Sort for other joins
withSQLConf(conf.toSeq: _*) {
try {
- val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
+ val (schema, output) = handleExceptions(getNormalizedQueryExecutionResult(spark, query))
val queryString = query.trim
val outputString = output.mkString("\n").replaceAll("\\s+$", "")
if (shouldRegenerateGoldenFiles) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 0530d764c..d8c82f12b 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -48,7 +48,6 @@ import org.apache.spark.sql.types.StructType
import org.apache.comet._
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
import org.apache.comet.shims.ShimCometSparkSessionExtensions
-import org.apache.comet.shims.ShimCometSparkSessionExtensions.supportsExtendedExplainInfo
/**
* Base class for testing. This exists in `org.apache.spark.sql` since [[SQLTestUtils]] is
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
index 4c2f832af..fc4549445 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
@@ -25,11 +25,12 @@ import scala.collection.JavaConverters._
import scala.util.Random
import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.comet.shims.ShimTestUtils
import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnVector
-import org.apache.comet.{CometConf, TestUtils}
+import org.apache.comet.CometConf
import org.apache.comet.parquet.BatchReader
/**
@@ -123,7 +124,7 @@ object CometReadBenchmark extends CometBenchmarkBase {
(col: ColumnVector, i: Int) => longSum += col.getUTF8String(i).toLongExact
}
- val files = TestUtils.listDirectory(new File(dir, "parquetV1"))
+ val files = ShimTestUtils.listDirectory(new File(dir, "parquetV1"))
sqlBenchmark.addCase("ParquetReader Spark") { _ =>
files.map(_.asInstanceOf[String]).foreach { p =>
diff --git a/spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
similarity index 98%
rename from spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala
rename to spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
index 019b4f030..31d1ffbf7 100644
--- a/spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala
+++ b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
@@ -27,9 +27,9 @@ import org.apache.spark.sql.CometTestBase
import org.apache.comet.CometConf
/**
- * This test suite contains tests for only Spark 3.4.
+ * This test suite contains tests for only Spark 3.4+.
*/
-class CometExec3_4Suite extends CometTestBase {
+class CometExec3_4PlusSuite extends CometTestBase {
import testImplicits._
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
diff --git a/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala b/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
new file mode 100644
index 000000000..caa943c26
--- /dev/null
+++ b/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.{SQLQueryTestHelper, SparkSession}
+
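+/**
+ * Spark 3.x's `SQLQueryTestHelper` names this helper `getNormalizedResult`; expose it
+ * under the Spark 4.0 name that `CometTPCHQuerySuite` calls.
+ */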
+trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper {
+  protected def getNormalizedQueryExecutionResult(
+      session: SparkSession,
+      sql: String): (String, Seq[String]) = {
+    getNormalizedResult(session, sql)
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/TestUtils.scala b/spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala
similarity index 96%
rename from spark/src/test/scala/org/apache/comet/TestUtils.scala
rename to spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala
index d4e771568..fcb543f9b 100644
--- a/spark/src/test/scala/org/apache/comet/TestUtils.scala
+++ b/spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala
@@ -17,13 +17,12 @@
* under the License.
*/
-package org.apache.comet
+package org.apache.spark.comet.shims
import java.io.File
-
import scala.collection.mutable.ArrayBuffer
-object TestUtils {
+object ShimTestUtils {
/**
* Spark 3.3.0 moved {{{SpecificParquetRecordReaderBase.listDirectory}}} to
diff --git a/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala b/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
new file mode 100644
index 000000000..f8d621c7e
--- /dev/null
+++ b/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+trait ShimCometTPCDSQuerySuite {
+ // This is private in `TPCDSBase`.
+ val excludedTpcdsQueries: Set[String] = Set()
+}
diff --git a/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala b/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
new file mode 100644
index 000000000..ec9823e52
--- /dev/null
+++ b/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.SQLQueryTestHelper
+
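+/**
+ * Spark 4.0's `SQLQueryTestHelper` already provides `getNormalizedQueryExecutionResult`,
+ * so nothing extra is needed here.
+ */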
+trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper {
+}
diff --git a/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala b/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala
new file mode 100644
index 000000000..923ae68f2
--- /dev/null
+++ b/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import java.io.File
+
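+/** Delegates directly to `org.apache.spark.TestUtils.listDirectory` on Spark 4.0. */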
+object ShimTestUtils {
+ def listDirectory(path: File): Array[String] =
+ org.apache.spark.TestUtils.listDirectory(path)
+}
diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
new file mode 100644
index 000000000..43917df63
--- /dev/null
+++ b/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
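+/**
+ * Empty on Spark 4.0: `excludedTpcdsQueries` can be overridden directly in the suite,
+ * so there is nothing to shim.
+ */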
+trait ShimCometTPCDSQuerySuite {
+
+}