diff --git a/.github/workflows/pr_build.yml b/.github/workflows/pr_build.yml index 1e347250f..410f1e1fe 100644 --- a/.github/workflows/pr_build.yml +++ b/.github/workflows/pr_build.yml @@ -76,6 +76,44 @@ jobs: # upload test reports only for java 17 upload-test-reports: ${{ matrix.java_version == '17' }} + linux-test-with-spark4_0: + strategy: + matrix: + os: [ubuntu-latest] + java_version: [17] + test-target: [java] + spark-version: ['4.0'] + is_push_event: + - ${{ github.event_name == 'push' }} + fail-fast: false + name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + runs-on: ${{ matrix.os }} + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{env.RUST_VERSION}} + jdk-version: ${{ matrix.java_version }} + - name: Clone Spark + uses: actions/checkout@v4 + with: + repository: "apache/spark" + path: "apache-spark" + - name: Install Spark + shell: bash + working-directory: ./apache-spark + run: build/mvn install -Phive -Phadoop-cloud -DskipTests + - name: Java test steps + uses: ./.github/actions/java-test + with: + # TODO: remove -DskipTests after fixing tests + maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests" + # TODO: upload test reports after enabling tests + upload-test-reports: false + linux-test-with-old-spark: strategy: matrix: @@ -169,6 +207,81 @@ jobs: with: maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }} + macos-test-with-spark4_0: + strategy: + matrix: + os: [macos-13] + java_version: [17] + test-target: [java] + spark-version: ['4.0'] + fail-fast: false + if: github.event_name == 'push' + name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-macos-builder + with: + rust-version: ${{env.RUST_VERSION}} + jdk-version: ${{ matrix.java_version }} + - name: Clone Spark + uses: actions/checkout@v4 + with: + repository: "apache/spark" + path: "apache-spark" + - name: Install Spark + shell: bash + working-directory: ./apache-spark + run: build/mvn install -Phive -Phadoop-cloud -DskipTests + - name: Java test steps + uses: ./.github/actions/java-test + with: + # TODO: remove -DskipTests after fixing tests + maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests" + # TODO: upload test reports after enabling tests + upload-test-reports: false + + macos-aarch64-test-with-spark4_0: + strategy: + matrix: + java_version: [17] + test-target: [java] + spark-version: ['4.0'] + is_push_event: + - ${{ github.event_name == 'push' }} + exclude: # exclude java 11 for pull_request event + - java_version: 11 + is_push_event: false + fail-fast: false + name: macos-14(Silicon)/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + runs-on: macos-14 + steps: + - uses: actions/checkout@v4 + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-macos-builder + with: + rust-version: ${{env.RUST_VERSION}} + jdk-version: ${{ matrix.java_version }} + jdk-architecture: aarch64 + protoc-architecture: aarch_64 + - name: Clone Spark + uses: actions/checkout@v4 + with: + repository: "apache/spark" + path: "apache-spark" + - name: Install Spark + shell: bash + working-directory: ./apache-spark + run: build/mvn install -Phive -Phadoop-cloud -DskipTests + - name: Java test steps + uses: ./.github/actions/java-test + with: + # TODO: remove -DskipTests after fixing tests + maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests" + # TODO: upload test reports after enabling tests + upload-test-reports: false + macos-aarch64-test-with-old-spark: strategy: matrix: diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala index d851067b5..d03252d06 100644 --- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala +++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala @@ -20,10 +20,10 @@ package org.apache.comet.parquet import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.comet.shims.ShimCometParquetUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -object CometParquetUtils { +object CometParquetUtils extends ShimCometParquetUtils { private val PARQUET_FIELD_ID_WRITE_ENABLED = "spark.sql.parquet.fieldId.write.enabled" private val PARQUET_FIELD_ID_READ_ENABLED = "spark.sql.parquet.fieldId.read.enabled" private val IGNORE_MISSING_PARQUET_FIELD_ID = "spark.sql.parquet.fieldId.read.ignoreMissing" @@ -39,61 +39,4 @@ object CometParquetUtils { def ignoreMissingIds(conf: SQLConf): Boolean = conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean - - // The following is copied from QueryExecutionErrors - // TODO: remove after dropping Spark 3.2.0 support and directly use - // QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError - def foundDuplicateFieldInFieldIdLookupModeError( - requiredId: Int, - matchedFields: String): Throwable = { - new RuntimeException(s""" - |Found duplicate field(s) "$requiredId": $matchedFields - |in id mapping mode - """.stripMargin.replaceAll("\n", " ")) - } - - // The followings are copied from org.apache.spark.sql.execution.datasources.parquet.ParquetUtils - // TODO: remove after dropping Spark 3.2.0 support and directly use ParquetUtils - /** - * A StructField metadata key used to set the field id of a column in the Parquet schema. - */ - val FIELD_ID_METADATA_KEY = "parquet.field.id" - - /** - * Whether there exists a field in the schema, whether inner or leaf, has the parquet field ID - * metadata. - */ - def hasFieldIds(schema: StructType): Boolean = { - def recursiveCheck(schema: DataType): Boolean = { - schema match { - case st: StructType => - st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType)) - - case at: ArrayType => recursiveCheck(at.elementType) - - case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType) - - case _ => - // No need to really check primitive types, just to terminate the recursion - false - } - } - if (schema.isEmpty) false else recursiveCheck(schema) - } - - def hasFieldId(field: StructField): Boolean = - field.metadata.contains(FIELD_ID_METADATA_KEY) - - def getFieldId(field: StructField): Int = { - require( - hasFieldId(field), - s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) - try { - Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) - } catch { - case _: ArithmeticException | _: ClassCastException => - throw new IllegalArgumentException( - s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer") - } - } } diff --git a/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala b/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala new file mode 100644 index 000000000..f22ac4060 --- /dev/null +++ b/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.sql.types._ + +trait ShimCometParquetUtils { + // The following is copied from QueryExecutionErrors + // TODO: remove after dropping Spark 3.2.0 support and directly use + // QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError + def foundDuplicateFieldInFieldIdLookupModeError( + requiredId: Int, + matchedFields: String): Throwable = { + new RuntimeException(s""" + |Found duplicate field(s) "$requiredId": $matchedFields + |in id mapping mode + """.stripMargin.replaceAll("\n", " ")) + } + + // The followings are copied from org.apache.spark.sql.execution.datasources.parquet.ParquetUtils + // TODO: remove after dropping Spark 3.2.0 support and directly use ParquetUtils + /** + * A StructField metadata key used to set the field id of a column in the Parquet schema. + */ + val FIELD_ID_METADATA_KEY = "parquet.field.id" + + /** + * Whether there exists a field in the schema, whether inner or leaf, has the parquet field ID + * metadata. + */ + def hasFieldIds(schema: StructType): Boolean = { + def recursiveCheck(schema: DataType): Boolean = { + schema match { + case st: StructType => + st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType)) + + case at: ArrayType => recursiveCheck(at.elementType) + + case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType) + + case _ => + // No need to really check primitive types, just to terminate the recursion + false + } + } + if (schema.isEmpty) false else recursiveCheck(schema) + } + + def hasFieldId(field: StructField): Boolean = + field.metadata.contains(FIELD_ID_METADATA_KEY) + + def getFieldId(field: StructField): Int = { + require( + hasFieldId(field), + s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) + try { + Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) + } catch { + case _: ArithmeticException | _: ClassCastException => + throw new IllegalArgumentException( + s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer") + } + } +} diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala new file mode 100644 index 000000000..448d0886c --- /dev/null +++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.paths.SparkPath +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.PartitionedFile + +object ShimBatchReader { + def newPartitionedFile(partitionValues: InternalRow, file: String): PartitionedFile = + PartitionedFile( + partitionValues, + SparkPath.fromUrlString(file), + -1, // -1 means we read the entire file + -1, + Array.empty[String], + 0, + 0 + ) +} diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala new file mode 100644 index 000000000..2f386869a --- /dev/null +++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat + +object ShimFileFormat { + // A name for a temporary column that holds row indexes computed by the file format reader + // until they can be placed in the _metadata struct. + val ROW_INDEX_TEMPORARY_COLUMN_NAME = ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME + + val OPTION_RETURNING_BATCH = FileFormat.OPTION_RETURNING_BATCH +} diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala new file mode 100644 index 000000000..60e21765b --- /dev/null +++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + + +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns +import org.apache.spark.sql.types.{StructField, StructType} + +object ShimResolveDefaultColumns { + def getExistenceDefaultValue(field: StructField): Any = + ResolveDefaultColumns.getExistenceDefaultValues(StructType(Seq(field))).head +} diff --git a/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala new file mode 100644 index 000000000..d402cd786 --- /dev/null +++ b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils +import org.apache.spark.sql.types._ + +trait ShimCometParquetUtils { + def foundDuplicateFieldInFieldIdLookupModeError( + requiredId: Int, + matchedFields: String): Throwable = { + QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(requiredId, matchedFields) + } + + def hasFieldIds(schema: StructType): Boolean = ParquetUtils.hasFieldIds(schema) + + def hasFieldId(field: StructField): Boolean = ParquetUtils.hasFieldId(field) + + def getFieldId(field: StructField): Int = ParquetUtils.getFieldId (field) +} diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh index 1f97d2d4a..12f555b8e 100755 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -40,9 +40,11 @@ allowed_expr="(^org/$|^org/apache/$" # we have to allow the directories that lead to the org/apache/comet dir # We allow all the classes under the following packages: # * org.apache.comet +# * org.apache.spark.comet # * org.apache.spark.sql.comet # * org.apache.arrow.c allowed_expr+="|^org/apache/comet/" +allowed_expr+="|^org/apache/spark/comet/" allowed_expr+="|^org/apache/spark/sql/comet/" allowed_expr+="|^org/apache/arrow/c/" # * whatever in the "META-INF" directory diff --git a/pom.xml b/pom.xml index 59e0569ff..57b4206cf 100644 --- a/pom.xml +++ b/pom.xml @@ -87,7 +87,7 @@ under the License. -ea -Xmx4g -Xss4m ${extraJavaTestArgs} spark-3.3-plus - spark-3.4 + spark-3.4-plus spark-3.x spark-3.4 @@ -512,7 +512,6 @@ under the License. 3.3.2 3.3 1.12.0 - spark-3.3-plus not-needed-yet spark-3.3 @@ -524,9 +523,25 @@ under the License. 2.12.17 3.4 1.13.1 - spark-3.3-plus - spark-3.4 - spark-3.4 + + + + + + spark-4.0 + + + 2.13.13 + 2.13 + 4.0.0-SNAPSHOT + 4.0 + 1.13.1 + spark-4.0 + not-needed-yet + + 17 + ${java.version} + ${java.version} @@ -605,7 +620,7 @@ under the License. org.scalameta semanticdb-scalac_${scala.version} - 4.7.5 + 4.8.8 diff --git a/spark/pom.xml b/spark/pom.xml index 21fa09fc2..84e2e501f 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -252,6 +252,8 @@ under the License. src/test/${additional.3_3.test.source} src/test/${additional.3_4.test.source} + src/test/${shims.majorVerSrc} + src/test/${shims.minorVerSrc} diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala index ac871cf60..52d8d09a0 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{DateType, StructType, TimestampType} import org.apache.spark.util.SerializableConfiguration @@ -144,7 +143,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with isCaseSensitive, useFieldId, ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + datetimeRebaseSpec.mode == CORRECTED, partitionSchema, file.partitionValues, JavaConverters.mapAsJavaMap(metrics)) @@ -161,7 +160,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with } } -object CometParquetFileFormat extends Logging { +object CometParquetFileFormat extends Logging with ShimSQLConf { /** * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` @@ -210,7 +209,7 @@ object CometParquetFileFormat extends Logging { case _ => false }) - if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { + if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LEGACY) { if (exceptionOnRebase) { logWarning( s"""Found Parquet file $file that could potentially contain dates/timestamps that were @@ -222,7 +221,7 @@ object CometParquetFileFormat extends Logging { calendar, please disable Comet for this query.""") } else { // do not throw exception on rebase - read as it is - datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) + datetimeRebaseSpec = datetimeRebaseSpec.copy(CORRECTED) } } diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala index 693af125b..e48d76384 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -135,7 +134,7 @@ case class CometParquetPartitionReaderFactory( isCaseSensitive, useFieldId, ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + datetimeRebaseSpec.mode == CORRECTED, partitionSchema, file.partitionValues, JavaConverters.mapAsJavaMap(metrics)) diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala index 5994dfb41..58c2aeb41 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala @@ -38,11 +38,11 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus +import org.apache.comet.shims.ShimSQLConf /** * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove @@ -58,7 +58,8 @@ class ParquetFilters( pushDownStringPredicate: Boolean, pushDownInFilterThreshold: Int, caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { + datetimeRebaseSpec: RebaseSpec) + extends ShimSQLConf { // A map which contains parquet field name and data type, if predicate push down applies. // // Each key in `nameToParquetField` represents a column; `dots` are used as separators for @@ -153,7 +154,7 @@ class ParquetFilters( case ld: LocalDate => DateTimeUtils.localDateToDays(ld) } datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) + case LEGACY => rebaseGregorianToJulianDays(gregorianDays) case _ => gregorianDays } } @@ -164,7 +165,7 @@ class ParquetFilters( case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) } datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => + case LEGACY => rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) case _ => gregorianMicros } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6333650dd..a717e066e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1987,18 +1987,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior. // See https://github.com/apache/spark/pull/38151 - case StaticInvoke( - _: Class[CharVarcharCodegenUtils], - _: StringType, - "readSidePadding", - arguments, - _, - true, - false, - true) if arguments.size == 2 => + case s: StaticInvoke + if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] && + s.dataType.isInstanceOf[StringType] && + s.functionName == "readSidePadding" && + s.arguments.size == 2 && + s.propagateNull && + !s.returnNullable && + s.isDeterministic => val argsExpr = Seq( - exprToProtoInternal(Cast(arguments(0), StringType), inputs), - exprToProtoInternal(arguments(1), inputs)) + exprToProtoInternal(Cast(s.arguments(0), StringType), inputs), + exprToProtoInternal(s.arguments(1), inputs)) if (argsExpr.forall(_.isDefined)) { val builder = ExprOuterClass.ScalarFunc.newBuilder() @@ -2007,7 +2006,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } else { - withInfo(expr, arguments: _*) + withInfo(expr, s.arguments: _*) None } diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 97838448a..dcc00f66c 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -23,9 +23,9 @@ import java.{util => ju} import java.util.Collections import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} +import org.apache.spark.comet.shims.ShimCometDriverPlugin import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.resource.ResourceProfile import org.apache.comet.{CometConf, CometSparkSessionExtensions} @@ -40,9 +40,7 @@ import org.apache.comet.{CometConf, CometSparkSessionExtensions} * * To enable this plugin, set the config "spark.plugins" to `org.apache.spark.CometPlugin`. */ -class CometDriverPlugin extends DriverPlugin with Logging { - import CometDriverPlugin._ - +class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPlugin { override def init(sc: SparkContext, pluginContext: PluginContext): ju.Map[String, String] = { logInfo("CometDriverPlugin init") @@ -52,14 +50,10 @@ class CometDriverPlugin extends DriverPlugin with Logging { } else { // By default, executorMemory * spark.executor.memoryOverheadFactor, with minimum of 384MB val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key) - val memoryOverheadFactor = - sc.getConf.getDouble( - EXECUTOR_MEMORY_OVERHEAD_FACTOR, - EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT) - - Math.max( - (executorMemory * memoryOverheadFactor).toInt, - ResourceProfile.MEMORY_OVERHEAD_MIN_MIB) + val memoryOverheadFactor = getMemoryOverheadFactor(sc.getConf) + val memoryOverheadMinMib = getMemoryOverheadMinMib(sc.getConf) + + Math.max((executorMemory * memoryOverheadFactor).toLong, memoryOverheadMinMib) } val cometMemOverhead = CometSparkSessionExtensions.getCometMemoryOverheadInMiB(sc.getConf) @@ -100,12 +94,6 @@ class CometDriverPlugin extends DriverPlugin with Logging { } } -object CometDriverPlugin { - // `org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR` was added since Spark 3.3.0 - val EXECUTOR_MEMORY_OVERHEAD_FACTOR = "spark.executor.memoryOverheadFactor" - val EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT = 0.1 -} - /** * The Comet plugin for Spark. To enable this plugin, set the config "spark.plugins" to * `org.apache.spark.CometPlugin` diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 7bd34debb..38247b2c4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext} +import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -44,7 +45,6 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import org.apache.comet.CometRuntimeException -import org.apache.comet.shims.ShimCometBroadcastExchangeExec /** * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala index 14a664108..9a5b55d65 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.comet.shims.ShimCometScanExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions @@ -43,7 +44,7 @@ import org.apache.spark.util.collection._ import org.apache.comet.{CometConf, MetricsSupport} import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} +import org.apache.comet.shims.ShimFileFormat /** * Comet physical scan node for DataSource V1. Most of the code here follow Spark's @@ -271,7 +272,7 @@ case class CometScanExec( selectedPartitions .flatMap { p => p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + getPartitionedFile(f, p) } } .groupBy { f => @@ -358,7 +359,7 @@ case class CometScanExec( // SPARK-39634: Allow file splitting in combination with row index generation once // the fix for PARQUET-2161 is available. !isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( + super.splitFiles( sparkSession = relation.sparkSession, file = file, filePath = filePath, @@ -409,7 +410,7 @@ case class CometScanExec( Map.empty) } else { newFileScanRDD( - fsRelation.sparkSession, + fsRelation, readFile, partitions, new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala index 13f26ce58..a2cdf421c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala @@ -59,7 +59,7 @@ object DecimalPrecision { } CheckOverflow(add, resultType, nullOnOverflow) - case sub @ Subtract(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultScale = max(s1, s2) val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) @@ -68,7 +68,7 @@ object DecimalPrecision { } CheckOverflow(sub, resultType, nullOnOverflow) - case mul @ Multiply(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) } else { @@ -76,7 +76,7 @@ object DecimalPrecision { } CheckOverflow(mul, resultType, nullOnOverflow) - case div @ Divide(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultType = if (allowPrecisionLoss) { // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) // Scale: max(6, s1 + p2 + 1) @@ -96,7 +96,7 @@ object DecimalPrecision { } CheckOverflow(div, resultType, nullOnOverflow) - case rem @ Remainder(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { @@ -108,6 +108,7 @@ object DecimalPrecision { } } + // TODO: consider to use `org.apache.spark.sql.types.DecimalExpression` for Spark 3.5+ object DecimalExpression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { case t: DecimalType => Some((t.precision, t.scale)) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 49c263f3f..3f4d7bfd3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -28,10 +28,10 @@ import scala.concurrent.Future import org.apache.spark._ import org.apache.spark.internal.config -import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor} +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} @@ -39,12 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan} +import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -68,7 +68,8 @@ case class CometShuffleExchangeExec( shuffleType: ShuffleType = CometNativeShuffle, advisoryPartitionSize: Option[Long] = None) extends ShuffleExchangeLike - with CometPlan { + with CometPlan + with ShimCometShuffleExchangeExec { private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) @@ -127,6 +128,9 @@ case class CometShuffleExchangeExec( Statistics(dataSize, Some(rowCount)) } + // TODO: add `override` keyword after dropping Spark-3.x supports + def shuffleId: Int = getShuffleId(shuffleDependency) + /** * A [[ShuffleDependency]] that will partition rows of its child based on the partitioning * scheme defined in `newPartitioning`. Those partitions of the returned ShuffleDependency will @@ -386,7 +390,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = UnsafeExternalRowSorter.createWithRecordComparator( - StructType.fromAttributes(outputAttributes), + fromAttributes(outputAttributes), recordComparatorSupplier, prefixComparator, prefixComputer, @@ -430,7 +434,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, - schema = Some(StructType.fromAttributes(outputAttributes))) + schema = Some(fromAttributes(outputAttributes))) dependency } @@ -445,7 +449,7 @@ class CometShuffleWriteProcessor( outputPartitioning: Partitioning, outputAttributes: Seq[Attribute], metrics: Map[String, SQLMetric]) - extends ShuffleWriteProcessor { + extends ShimCometShuffleWriteProcessor { private val OFFSET_LENGTH = 8 @@ -455,11 +459,11 @@ class CometShuffleWriteProcessor( } override def write( - rdd: RDD[_], + inputs: Iterator[_], dep: ShuffleDependency[_, _, _], mapId: Long, - context: TaskContext, - partition: Partition): MapStatus = { + mapIndex: Int, + context: TaskContext): MapStatus = { val shuffleBlockResolver = SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) @@ -469,10 +473,6 @@ class CometShuffleWriteProcessor( val tempDataFilePath = Paths.get(tempDataFilename) val tempIndexFilePath = Paths.get(tempIndexFilename) - // Getting rid of the fake partitionId - val cometRDD = - rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[RDD[ColumnarBatch]] - // Call native shuffle write val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) @@ -482,8 +482,12 @@ class CometShuffleWriteProcessor( "elapsed_compute" -> metrics("shuffleReadElapsedCompute")) val nativeMetrics = CometMetricNode(nativeSQLMetrics) - val rawIter = cometRDD.iterator(partition, context) - val cometIter = CometExec.getCometIterator(Seq(rawIter), nativePlan, nativeMetrics) + // Getting rid of the fake partitionId + val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + val cometIter = CometExec.getCometIterator( + Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + nativePlan, + nativeMetrics) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index 0c45a9c2c..f5a578f82 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ */ trait CometExprShim { /** - * Returns a tuple of expressions for the `unhex` function. - */ + * Returns a tuple of expressions for the `unhex` function. + */ def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index 0c45a9c2c..f5a578f82 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ */ trait CometExprShim { /** - * Returns a tuple of expressions for the `unhex` function. - */ + * Returns a tuple of expressions for the `unhex` function. + */ def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 409e1c94b..3f2301f0a 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ */ trait CometExprShim { /** - * Returns a tuple of expressions for the `unhex` function. - */ + * Returns a tuple of expressions for the `unhex` function. + */ def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala index 6b4fad974..350aeb9f0 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala @@ -19,8 +19,11 @@ package org.apache.comet.shims +import org.apache.spark.ShuffleDependency +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.types.{StructField, StructType} trait ShimCometShuffleExchangeExec { // TODO: remove after dropping Spark 3.2 and 3.3 support @@ -37,4 +40,11 @@ trait ShimCometShuffleExchangeExec { shuffleType, advisoryPartitionSize) } + + // TODO: remove after dropping Spark 3.x support + protected def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + // TODO: remove after dropping Spark 3.x support + protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): Int = 0 } diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index eb04c68ab..377485335 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -24,8 +24,6 @@ import org.apache.spark.sql.execution.{LimitExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan trait ShimCometSparkSessionExtensions { - import org.apache.comet.shims.ShimCometSparkSessionExtensions._ - /** * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate */ @@ -45,9 +43,7 @@ trait ShimCometSparkSessionExtensions { * SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key */ protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = "spark.sql.extendedExplainProviders" -} -object ShimCometSparkSessionExtensions { private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields .filter(_.getName == "offset") .map { a => a.setAccessible(true); a.get(plan) } diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala index ff60ef964..c3d0c56e5 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy trait ShimSQLConf { @@ -39,4 +40,7 @@ trait ShimSQLConf { case _ => None }) .head + + protected val LEGACY = LegacyBehaviorPolicy.LEGACY + protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED } diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala similarity index 98% rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala rename to spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala index 63ff2a2c1..aede47951 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala +++ b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.comet.shims +package org.apache.spark.comet.shims import scala.reflect.ClassTag diff --git a/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala new file mode 100644 index 000000000..cfb6a0088 --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet.shims + +import org.apache.spark.SparkConf + +trait ShimCometDriverPlugin { + // `org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR` was added since Spark 3.3.0 + private val EXECUTOR_MEMORY_OVERHEAD_FACTOR = "spark.executor.memoryOverheadFactor" + private val EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT = 0.1 + // `org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD` was added since Spark 4.0.0 + private val EXECUTOR_MIN_MEMORY_OVERHEAD = "spark.executor.minMemoryOverhead" + private val EXECUTOR_MIN_MEMORY_OVERHEAD_DEFAULT = 384L + + def getMemoryOverheadFactor(sc: SparkConf): Double = + sc.getDouble( + EXECUTOR_MEMORY_OVERHEAD_FACTOR, + EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT) + def getMemoryOverheadMinMib(sc: SparkConf): Long = + sc.getLong(EXECUTOR_MIN_MEMORY_OVERHEAD, EXECUTOR_MIN_MEMORY_OVERHEAD_DEFAULT) +} diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala similarity index 81% rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala rename to spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala index 544a67385..02b97f9fb 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala +++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala @@ -17,17 +17,21 @@ * under the License. */ -package org.apache.comet.shims +package org.apache.spark.sql.comet.shims + +import org.apache.comet.shims.ShimFileFormat import scala.language.implicitConversions +import org.apache.hadoop.fs.{FileStatus, Path} + import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} -import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile} import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD import org.apache.spark.sql.execution.metric.SQLMetric @@ -63,7 +67,7 @@ trait ShimCometScanExec { // TODO: remove after dropping Spark 3.2 support and directly call new FileScanRDD protected def newFileScanRDD( - sparkSession: SparkSession, + fsRelation: HadoopFsRelation, readFunction: PartitionedFile => Iterator[InternalRow], filePartitions: Seq[FilePartition], readSchema: StructType, @@ -73,12 +77,12 @@ trait ShimCometScanExec { .filter(c => List(3, 5, 6).contains(c.getParameterCount()) ) .map { c => c.getParameterCount match { - case 3 => c.newInstance(sparkSession, readFunction, filePartitions) + case 3 => c.newInstance(fsRelation.sparkSession, readFunction, filePartitions) case 5 => - c.newInstance(sparkSession, readFunction, filePartitions, readSchema, metadataColumns) + c.newInstance(fsRelation.sparkSession, readFunction, filePartitions, readSchema, metadataColumns) case 6 => c.newInstance( - sparkSession, + fsRelation.sparkSession, readFunction, filePartitions, readSchema, @@ -123,4 +127,15 @@ trait ShimCometScanExec { protected def isNeededForSchema(sparkSchema: StructType): Boolean = { findRowIndexColumnIndexInSchema(sparkSchema) >= 0 } + + protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile = + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + + protected def splitFiles(sparkSession: SparkSession, + file: FileStatus, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow): Seq[PartitionedFile] = + PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues) } diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala new file mode 100644 index 000000000..9100b90c2 --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{Partition, ShuffleDependency, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.ShuffleWriteProcessor + +trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor { + override def write( + rdd: RDD[_], + dep: ShuffleDependency[_, _, _], + mapId: Long, + context: TaskContext, + partition: Partition): MapStatus = { + val rawIter = rdd.iterator(partition, context) + write(rawIter, dep, mapId, partition.index, context) + } + + def write( + inputs: Iterator[_], + dep: ShuffleDependency[_, _, _], + mapId: Long, + mapIndex: Int, + context: TaskContext): MapStatus +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 000000000..01f923206 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(unhex.failOnError)) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala new file mode 100644 index 000000000..167b539f8 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec + +trait ShimCometBatchScanExec { + def wrapped: BatchScanExec + + def keyGroupedPartitioning: Option[Seq[Expression]] = wrapped.keyGroupedPartitioning + + def inputPartitions: Seq[InputPartition] = wrapped.inputPartitions + + def ordering: Option[Seq[SortOrder]] = wrapped.ordering +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala new file mode 100644 index 000000000..1f689b400 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning} + +trait ShimCometBroadcastHashJoinExec { + protected def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = + partitioning match { + case p: HashPartitioningLike => p.expressions + case _ => Seq() + } +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala new file mode 100644 index 000000000..559e327b4 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.ShuffleDependency +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.types.StructType + +trait ShimCometShuffleExchangeExec { + def apply(s: ShuffleExchangeExec, shuffleType: ShuffleType): CometShuffleExchangeExec = { + CometShuffleExchangeExec( + s.outputPartitioning, + s.child, + s, + s.shuffleOrigin, + shuffleType, + s.advisoryPartitionSize) + } + + protected def fromAttributes(attributes: Seq[Attribute]): StructType = DataTypeUtils.fromAttributes(attributes) + + protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): Int = shuffleDependency.shuffleId +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala new file mode 100644 index 000000000..9fb7355ee --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.execution.{CollectLimitExec, GlobalLimitExec, LocalLimitExec, QueryExecution} +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.internal.SQLConf + +trait ShimCometSparkSessionExtensions { + protected def getPushedAggregate(scan: ParquetScan): Option[Aggregation] = scan.pushedAggregate + + protected def getOffset(limit: LocalLimitExec): Int = 0 + protected def getOffset(limit: GlobalLimitExec): Int = limit.offset + protected def getOffset(limit: CollectLimitExec): Int = limit.offset + + protected def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = true + + protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala new file mode 100644 index 000000000..5a8ac97b3 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.execution.TakeOrderedAndProjectExec + +trait ShimCometTakeOrderedAndProjectExec { + protected def getOffset(plan: TakeOrderedAndProjectExec): Option[Int] = Some(plan.offset) +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala new file mode 100644 index 000000000..4d261f3c2 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, BinaryExpression, BloomFilterMightContain, EvalMode} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} + +trait ShimQueryPlanSerde { + protected def getFailOnError(b: BinaryArithmetic): Boolean = + b.getClass.getMethod("failOnError").invoke(b).asInstanceOf[Boolean] + + protected def getFailOnError(aggregate: Sum): Boolean = aggregate.initQueryContext().isDefined + protected def getFailOnError(aggregate: Average): Boolean = aggregate.initQueryContext().isDefined + + protected def isLegacyMode(aggregate: Sum): Boolean = aggregate.evalMode.equals(EvalMode.LEGACY) + protected def isLegacyMode(aggregate: Average): Boolean = aggregate.evalMode.equals(EvalMode.LEGACY) + + protected def isBloomFilterMightContain(binary: BinaryExpression): Boolean = + binary.isInstanceOf[BloomFilterMightContain] +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala new file mode 100644 index 000000000..574967767 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.LegacyBehaviorPolicy + +trait ShimSQLConf { + protected def getPushDownStringPredicate(sqlConf: SQLConf): Boolean = + sqlConf.parquetFilterPushDownStringPredicate + + protected val LEGACY = LegacyBehaviorPolicy.LEGACY + protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala new file mode 100644 index 000000000..ba87a2515 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet.shims + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.broadcast.Broadcast + +trait ShimCometBroadcastExchangeExec { + protected def doBroadcast[T: ClassTag](sparkContext: SparkContext, value: T): Broadcast[Any] = + sparkContext.broadcastInternal(value, true) +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala new file mode 100644 index 000000000..f7a57a642 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet.shims + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR +import org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD + +trait ShimCometDriverPlugin { + protected def getMemoryOverheadFactor(sparkConf: SparkConf): Double = sparkConf.get( + EXECUTOR_MEMORY_OVERHEAD_FACTOR) + + protected def getMemoryOverheadMinMib(sparkConf: SparkConf): Long = sparkConf.get( + EXECUTOR_MIN_MEMORY_OVERHEAD) +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala new file mode 100644 index 000000000..543116c10 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} +import org.apache.spark.sql.types.StructType +import org.apache.spark.SparkContext + +trait ShimCometScanExec { + def wrapped: FileSourceScanExec + + lazy val fileConstantMetadataColumns: Seq[AttributeReference] = + wrapped.fileConstantMetadataColumns + + protected def newDataSourceRDD( + sc: SparkContext, + inputPartitions: Seq[Seq[InputPartition]], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean, + customMetrics: Map[String, SQLMetric]): DataSourceRDD = + new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, customMetrics) + + protected def newFileScanRDD( + fsRelation: HadoopFsRelation, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readSchema: StructType, + options: ParquetOptions): FileScanRDD = { + new FileScanRDD( + fsRelation.sparkSession, + readFunction, + filePartitions, + readSchema, + fileConstantMetadataColumns, + fsRelation.fileFormat.fileConstantMetadataExtractors, + options) + } + + protected def invalidBucketFile(path: String, sparkVersion: String): Throwable = + QueryExecutionErrors.invalidBucketFile(path) + + // see SPARK-39634 + protected def isNeededForSchema(sparkSchema: StructType): Boolean = false + + protected def getPartitionedFile(f: FileStatusWithMetadata, p: PartitionDirectory): PartitionedFile = + PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen) + + protected def splitFiles(sparkSession: SparkSession, + file: FileStatusWithMetadata, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow): Seq[PartitionedFile] = + PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues) +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala new file mode 100644 index 000000000..f875e3f38 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.shuffle.ShuffleWriteProcessor + +trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor { + +} diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index 1357d6548..53186b131 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -21,13 +21,13 @@ package org.apache.spark.sql import org.apache.spark.SparkConf import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} +import org.apache.spark.sql.comet.shims.ShimCometTPCDSQuerySuite import org.apache.comet.CometConf class CometTPCDSQuerySuite extends { - // This is private in `TPCDSBase`. - val excludedTpcdsQueries: Seq[String] = Seq() + override val excludedTpcdsQueries: Set[String] = Set() // This is private in `TPCDSBase` and `excludedTpcdsQueries` is private too. // So we cannot override `excludedTpcdsQueries` to exclude the queries. @@ -145,7 +145,8 @@ class CometTPCDSQuerySuite override val tpcdsQueries: Seq[String] = tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains) } - with TPCDSQueryTestSuite { + with TPCDSQueryTestSuite + with ShimCometTPCDSQuerySuite { override def sparkConf: SparkConf = { val conf = super.sparkConf conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala index 1abe5faeb..ec87f19e9 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession} import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus +import org.apache.comet.shims.ShimCometTPCHQuerySuite /** * End-to-end tests to check TPCH query results. @@ -49,7 +50,7 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus * ./mvnw -Dsuites=org.apache.spark.sql.CometTPCHQuerySuite test * }}} */ -class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestHelper { +class CometTPCHQuerySuite extends QueryTest with CometTPCBase with ShimCometTPCHQuerySuite { private val tpchDataPath = sys.env.get("SPARK_TPCH_DATA") @@ -142,7 +143,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH val shouldSortResults = sortMergeJoinConf != conf // Sort for other joins withSQLConf(conf.toSeq: _*) { try { - val (schema, output) = handleExceptions(getNormalizedResult(spark, query)) + val (schema, output) = handleExceptions(getNormalizedQueryExecutionResult(spark, query)) val queryString = query.trim val outputString = output.mkString("\n").replaceAll("\\s+$", "") if (shouldRegenerateGoldenFiles) { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 0530d764c..d8c82f12b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -48,7 +48,6 @@ import org.apache.spark.sql.types.StructType import org.apache.comet._ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus import org.apache.comet.shims.ShimCometSparkSessionExtensions -import org.apache.comet.shims.ShimCometSparkSessionExtensions.supportsExtendedExplainInfo /** * Base class for testing. This exists in `org.apache.spark.sql` since [[SQLTestUtils]] is diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala index 4c2f832af..fc4549445 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala @@ -25,11 +25,12 @@ import scala.collection.JavaConverters._ import scala.util.Random import org.apache.spark.benchmark.Benchmark +import org.apache.spark.comet.shims.ShimTestUtils import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnVector -import org.apache.comet.{CometConf, TestUtils} +import org.apache.comet.CometConf import org.apache.comet.parquet.BatchReader /** @@ -123,7 +124,7 @@ object CometReadBenchmark extends CometBenchmarkBase { (col: ColumnVector, i: Int) => longSum += col.getUTF8String(i).toLongExact } - val files = TestUtils.listDirectory(new File(dir, "parquetV1")) + val files = ShimTestUtils.listDirectory(new File(dir, "parquetV1")) sqlBenchmark.addCase("ParquetReader Spark") { _ => files.map(_.asInstanceOf[String]).foreach { p => diff --git a/spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala similarity index 98% rename from spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala rename to spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala index 019b4f030..31d1ffbf7 100644 --- a/spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala +++ b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.CometTestBase import org.apache.comet.CometConf /** - * This test suite contains tests for only Spark 3.4. + * This test suite contains tests for only Spark 3.4+. */ -class CometExec3_4Suite extends CometTestBase { +class CometExec3_4PlusSuite extends CometTestBase { import testImplicits._ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit diff --git a/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala b/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala new file mode 100644 index 000000000..caa943c26 --- /dev/null +++ b/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.{SQLQueryTestHelper, SparkSession} + +trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper { + protected def getNormalizedQueryExecutionResult(session: SparkSession, sql: String): (String, Seq[String]) = { + getNormalizedResult(session, sql) + } +} diff --git a/spark/src/test/scala/org/apache/comet/TestUtils.scala b/spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala similarity index 96% rename from spark/src/test/scala/org/apache/comet/TestUtils.scala rename to spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala index d4e771568..fcb543f9b 100644 --- a/spark/src/test/scala/org/apache/comet/TestUtils.scala +++ b/spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala @@ -17,13 +17,12 @@ * under the License. */ -package org.apache.comet +package org.apache.spark.comet.shims import java.io.File - import scala.collection.mutable.ArrayBuffer -object TestUtils { +object ShimTestUtils { /** * Spark 3.3.0 moved {{{SpecificParquetRecordReaderBase.listDirectory}}} to diff --git a/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala b/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala new file mode 100644 index 000000000..f8d621c7e --- /dev/null +++ b/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +trait ShimCometTPCDSQuerySuite { + // This is private in `TPCDSBase`. + val excludedTpcdsQueries: Set[String] = Set() +} diff --git a/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala b/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala new file mode 100644 index 000000000..ec9823e52 --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.SQLQueryTestHelper + +trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper { +} diff --git a/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala b/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala new file mode 100644 index 000000000..923ae68f2 --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet.shims + +import java.io.File + +object ShimTestUtils { + def listDirectory(path: File): Array[String] = + org.apache.spark.TestUtils.listDirectory(path) +} diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala new file mode 100644 index 000000000..43917df63 --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +trait ShimCometTPCDSQuerySuite { + +}