diff --git a/pom.xml b/pom.xml
index 4fb377954..b9001c9cd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -91,6 +91,7 @@ under the License.
spark-3.4-plus
spark-3.x
spark-3.4
+ spark-pre-3.5
@@ -527,7 +528,9 @@ under the License.
not-needed-yet
not-needed-yet
+ spark-3.x
spark-3.2
+ spark-pre-3.5
@@ -539,7 +542,9 @@ under the License.
3.3
1.12.0
not-needed-yet
+ spark-3.x
spark-3.3
+ spark-pre-3.5
@@ -549,6 +554,25 @@ under the License.
2.12.17
3.4
1.13.1
+ spark-3.x
+ spark-3.4
+ spark-pre-3.5
+
+
+
+
+
+ spark-3.5
+
+ 2.12.17
+ 3.5.1
+ 3.5
+ 1.13.1
+
+
+ spark-3.x
+ spark-3.5
+ not-needed-yet
@@ -564,6 +588,7 @@ under the License.
1.13.1
spark-4.0
not-needed-yet
+ not-needed-yet
17
${java.version}
@@ -693,7 +718,6 @@ under the License.
- -->
${scala.version}
true
true
diff --git a/spark/pom.xml b/spark/pom.xml
index 84e2e501f..8585b4b8c 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -254,6 +254,7 @@ under the License.
+
@@ -267,6 +268,7 @@ under the License.
+
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 3f4d7bfd3..a1b9a3677 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -41,10 +41,11 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
@@ -390,7 +391,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
- fromAttributes(outputAttributes),
+ StructType(
+ outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))),
recordComparatorSupplier,
prefixComparator,
prefixComputer,
@@ -434,8 +436,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
serializer,
shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
shuffleType = CometColumnarShuffle,
- schema = Some(fromAttributes(outputAttributes)))
-
+ schema = Some(StructType(outputAttributes.map(a =>
+ StructField(a.name, a.dataType, a.nullable, a.metadata)))))
dependency
}
}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala
new file mode 100644
index 000000000..91e86eea3
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet
+
+import scala.collection.JavaConverters
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.hadoop.ParquetInputFormat
+import org.apache.parquet.hadoop.metadata.FileMetaData
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.RecordReaderIterator
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.CometConf
+import org.apache.comet.MetricsSupport
+import org.apache.comet.shims.ShimSQLConf
+import org.apache.comet.vector.CometVector
+
+/**
+ * A Comet-specific Parquet format. This mostly reuses the functionality of Spark's
+ * [[ParquetFileFormat]], but overrides:
+ *
+ * - `vectorTypes`, so Spark allocates [[CometVector]] instead of its own on-heap or off-heap
+ * column vector in the whole-stage codegen path.
+ * - `supportBatch`, which simply returns true since data types should have already been checked
+ * in [[org.apache.comet.CometSparkSessionExtensions]]
+ * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values.
+ */
+class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf {
+ override def shortName(): String = "parquet"
+ override def toString: String = "CometParquet"
+ override def hashCode(): Int = getClass.hashCode()
+ override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat]
+
+ override def vectorTypes(
+ requiredSchema: StructType,
+ partitionSchema: StructType,
+ sqlConf: SQLConf): Option[Seq[String]] = {
+ val length = requiredSchema.fields.length + partitionSchema.fields.length
+ Option(Seq.fill(length)(classOf[CometVector].getName))
+ }
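+ // For example, with a required schema of two columns and a partition schema of one column,
+ // `vectorTypes` above returns Some(Seq.fill(3)(classOf[CometVector].getName)), so the
+ // whole-stage codegen path allocates a CometVector for every output column slot.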
+
+ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true
+
+ override def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+ val sqlConf = sparkSession.sessionState.conf
+ CometParquetFileFormat.populateConf(sqlConf, hadoopConf)
+ val broadcastedHadoopConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+ val isCaseSensitive = sqlConf.caseSensitiveAnalysis
+ val useFieldId = CometParquetUtils.readFieldId(sqlConf)
+ val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf)
+ val pushDownDate = sqlConf.parquetFilterPushDownDate
+ val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp
+ val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
+ val pushDownStringPredicate = getPushDownStringPredicate(sqlConf)
+ val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
+ val optionsMap = CaseInsensitiveMap[String](options)
+ val parquetOptions = new ParquetOptions(optionsMap, sqlConf)
+ val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
+ val parquetFilterPushDown = sqlConf.parquetFilterPushDown
+
+ // Comet specific configurations
+ val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf)
+
+ (file: PartitionedFile) => {
+ val sharedConf = broadcastedHadoopConf.value.value
+ val footer = FooterReader.readFooter(sharedConf, file)
+ val footerFileMetaData = footer.getFileMetaData
+ val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec(
+ file,
+ requiredSchema,
+ sharedConf,
+ footerFileMetaData,
+ datetimeRebaseModeInRead)
+
+ val pushed = if (parquetFilterPushDown) {
+ val parquetSchema = footerFileMetaData.getSchema
+ val parquetFilters = new ParquetFilters(
+ parquetSchema,
+ pushDownDate,
+ pushDownTimestamp,
+ pushDownDecimal,
+ pushDownStringPredicate,
+ pushDownInFilterThreshold,
+ isCaseSensitive,
+ datetimeRebaseSpec)
+ filters
+ // Collects all converted Parquet filter predicates. Notice that not all predicates can
+ // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a
+ // `flatMap` is used here.
+ .flatMap(parquetFilters.createFilter)
+ .reduceOption(FilterApi.and)
+ } else {
+ None
+ }
+ pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p))
+
+ val batchReader = new BatchReader(
+ sharedConf,
+ file,
+ footer,
+ capacity,
+ requiredSchema,
+ isCaseSensitive,
+ useFieldId,
+ ignoreMissingIds,
+ datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED,
+ partitionSchema,
+ file.partitionValues,
+ JavaConverters.mapAsJavaMap(metrics))
+ val iter = new RecordReaderIterator(batchReader)
+ try {
+ batchReader.init()
+ iter.asInstanceOf[Iterator[InternalRow]]
+ } catch {
+ case e: Throwable =>
+ iter.close()
+ throw e
+ }
+ }
+ }
+}
+
+object CometParquetFileFormat extends Logging {
+
+ /**
+ * Populates Parquet-related configurations from the input `sqlConf` into the given `hadoopConf`.
+ */
+ def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = {
+ hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
+ hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone)
+ hadoopConf.setBoolean(
+ SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+ sqlConf.nestedSchemaPruningEnabled)
+ hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis)
+
+ // Sets flags for `ParquetToSparkSchemaConverter`
+ hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString)
+ hadoopConf.setBoolean(
+ SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
+ sqlConf.isParquetINT96AsTimestamp)
+
+ // Comet specific configs
+ hadoopConf.setBoolean(
+ CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key,
+ CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get())
+ hadoopConf.setBoolean(
+ CometConf.COMET_USE_DECIMAL_128.key,
+ CometConf.COMET_USE_DECIMAL_128.get())
+ hadoopConf.setBoolean(
+ CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key,
+ CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get())
+ }
+
+ def getDatetimeRebaseSpec(
+ file: PartitionedFile,
+ sparkSchema: StructType,
+ sharedConf: Configuration,
+ footerFileMetaData: FileMetaData,
+ datetimeRebaseModeInRead: String): RebaseSpec = {
+ val exceptionOnRebase = sharedConf.getBoolean(
+ CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key,
+ CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get)
+ var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec(
+ footerFileMetaData.getKeyValueMetaData.get,
+ datetimeRebaseModeInRead)
+ val hasDateOrTimestamp = sparkSchema.exists(f =>
+ f.dataType match {
+ case DateType | TimestampType => true
+ case _ => false
+ })
+
+ if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) {
+ if (exceptionOnRebase) {
+ logWarning(
+ s"""Found Parquet file $file that could potentially contain dates/timestamps that were
+ written in the legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase
+ and return these according to the new Proleptic Gregorian calendar, Comet will throw an
+ exception when reading them. If you want to read them as is, according to the hybrid
+ Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to
+ false. Otherwise, if you want to read them according to the new Proleptic Gregorian
+ calendar, please disable Comet for this query.""")
+ } else {
+ // do not throw exception on rebase - read as it is
+ datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED)
+ }
+ }
+
+ datetimeRebaseSpec
+ }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
new file mode 100644
index 000000000..787357531
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet
+
+import scala.collection.JavaConverters
+import scala.collection.mutable
+
+import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
+import org.apache.parquet.hadoop.ParquetInputFormat
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.{CometConf, CometRuntimeException}
+import org.apache.comet.shims.ShimSQLConf
+
+case class CometParquetPartitionReaderFactory(
+ @transient sqlConf: SQLConf,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ readDataSchema: StructType,
+ partitionSchema: StructType,
+ filters: Array[Filter],
+ options: ParquetOptions,
+ metrics: Map[String, SQLMetric])
+ extends FilePartitionReaderFactory
+ with ShimSQLConf
+ with Logging {
+
+ private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
+ private val useFieldId = CometParquetUtils.readFieldId(sqlConf)
+ private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf)
+ private val pushDownDate = sqlConf.parquetFilterPushDownDate
+ private val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp
+ private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
+ private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf)
+ private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
+ private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead
+ private val parquetFilterPushDown = sqlConf.parquetFilterPushDown
+
+ // Comet specific configurations
+ private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf)
+
+ // This is only called on executors, on a broadcast variable, so we don't want it to be
+ // materialized on the driver.
+ @transient private lazy val preFetchEnabled = {
+ val conf = broadcastedConf.value.value
+
+ conf.getBoolean(
+ CometConf.COMET_SCAN_PREFETCH_ENABLED.key,
+ CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get)
+ }
+
+ private var cometReaders: Iterator[BatchReader] = _
+ private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]()
+
+ // TODO: we may want to revisit this, as we're only going to support flat types at the beginning
+ override def supportColumnarReads(partition: InputPartition): Boolean = true
+
+ override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
+ if (preFetchEnabled) {
+ val filePartition = partition.asInstanceOf[FilePartition]
+ val conf = broadcastedConf.value.value
+
+ val threadNum = conf.getInt(
+ CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key,
+ CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get)
+ val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum)
+
+ this.cometReaders = filePartition.files
+ .map { file =>
+ // The `init()` call is deferred until the prefetch task begins.
+ // Otherwise we would hold too many resources for readers that are not yet ready
+ // to prefetch.
+ val cometReader = buildCometReader(file)
+ if (cometReader != null) {
+ cometReader.submitPrefetchTask(prefetchThreadPool)
+ }
+
+ cometReader
+ }
+ .toSeq
+ .toIterator
+ }
+
+ super.createColumnarReader(partition)
+ }
+
+ override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] =
+ throw new UnsupportedOperationException("Comet doesn't support 'buildReader'")
+
+ private def buildCometReader(file: PartitionedFile): BatchReader = {
+ val conf = broadcastedConf.value.value
+
+ try {
+ val (datetimeRebaseSpec, footer, filters) = getFilter(file)
+ filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed))
+ val cometReader = new BatchReader(
+ conf,
+ file,
+ footer,
+ batchSize,
+ readDataSchema,
+ isCaseSensitive,
+ useFieldId,
+ ignoreMissingIds,
+ datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED,
+ partitionSchema,
+ file.partitionValues,
+ JavaConverters.mapAsJavaMap(metrics))
+ val taskContext = Option(TaskContext.get)
+ taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close()))
+ return cometReader
+ } catch {
+ case e: Throwable if preFetchEnabled =>
+ // Keep original exception
+ cometReaderExceptionMap.put(file, e)
+ }
+ null
+ }
+
+ override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
+ val cometReader = if (!preFetchEnabled) {
+ // Prefetch is not enabled, so create the Comet reader and initialize it.
+ val cometReader = buildCometReader(file)
+ cometReader.init()
+
+ cometReader
+ } else {
+ // If prefetch is enabled, we already tried to access the file in `buildCometReader`.
+ // We may have gotten an exception such as `FileNotFoundException`, and we need to throw it
+ // now to let Spark handle it.
+ val reader = cometReaders.next()
+ val exception = cometReaderExceptionMap.get(file)
+ exception.foreach(e => throw e)
+
+ if (reader == null) {
+ throw new CometRuntimeException(s"Cannot find comet file reader for $file")
+ }
+ reader
+ }
+ CometPartitionReader(cometReader)
+ }
+
+ def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = {
+ val sharedConf = broadcastedConf.value.value
+ val footer = FooterReader.readFooter(sharedConf, file)
+ val footerFileMetaData = footer.getFileMetaData
+ val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec(
+ file,
+ readDataSchema,
+ sharedConf,
+ footerFileMetaData,
+ datetimeRebaseModeInRead)
+
+ val pushed = if (parquetFilterPushDown) {
+ val parquetSchema = footerFileMetaData.getSchema
+ val parquetFilters = new ParquetFilters(
+ parquetSchema,
+ pushDownDate,
+ pushDownTimestamp,
+ pushDownDecimal,
+ pushDownStringPredicate,
+ pushDownInFilterThreshold,
+ isCaseSensitive,
+ datetimeRebaseSpec)
+ filters
+ // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+ // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+ // is used here.
+ .flatMap(parquetFilters.createFilter)
+ .reduceOption(FilterApi.and)
+ } else {
+ None
+ }
+ (datetimeRebaseSpec, footer, pushed)
+ }
+
+ override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] =
+ throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.")
+
+ /**
+ * A simple adapter on Comet's [[BatchReader]].
+ */
+ protected case class CometPartitionReader(reader: BatchReader)
+ extends PartitionReader[ColumnarBatch] {
+
+ override def next(): Boolean = {
+ reader.nextBatch()
+ }
+
+ override def get(): ColumnarBatch = {
+ reader.currentBatch()
+ }
+
+ override def close(): Unit = {
+ reader.close()
+ }
+ }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala
new file mode 100644
index 000000000..3cd434418
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala
@@ -0,0 +1,882 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet
+
+import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
+import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.time.{Duration, Instant, LocalDate, Period}
+import java.util.Locale
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics}
+import org.apache.parquet.filter2.predicate._
+import org.apache.parquet.filter2.predicate.SparkFilterApi._
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type}
+import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
+import org.apache.parquet.schema.Type.Repetition
+import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec}
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
+import org.apache.spark.sql.sources
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+
+/**
+ * Copied from Spark 3.2 & 3.4 in order to fix the Parquet shading issue. TODO: find a way to
+ * remove this duplication.
+ *
+ * Utility functions to convert Spark data source filters to Parquet filters.
+ */
+class ParquetFilters(
+ schema: MessageType,
+ pushDownDate: Boolean,
+ pushDownTimestamp: Boolean,
+ pushDownDecimal: Boolean,
+ pushDownStringPredicate: Boolean,
+ pushDownInFilterThreshold: Int,
+ caseSensitive: Boolean,
+ datetimeRebaseSpec: RebaseSpec) {
+ // A map from column name to the Parquet field info (field name parts and data type), used when
+ // predicate push-down applies.
+ //
+ // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+ // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+ // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
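+ //
+ // For example, a field "b" nested inside a struct column "a" maps to the key "a.b", while a
+ // top-level column literally named "a.b" is backtick-quoted by `quoteIfNeeded`:
+ //   Seq("a", "b").map(quoteIfNeeded).mkString(".")   // "a.b"
+ //   Seq("a.b").map(quoteIfNeeded).mkString(".")      // "`a.b`"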
+ private val nameToParquetField: Map[String, ParquetPrimitiveField] = {
+ // Recursively traverse the parquet schema to get primitive fields that can be pushed-down.
+ // `parentFieldNames` is used to keep track of the current nested level when traversing.
+ def getPrimitiveFields(
+ fields: Seq[Type],
+ parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
+ fields.flatMap {
+ // Parquet only supports predicate push-down for non-repeated primitive types.
+ // TODO(SPARK-39393): Remove the extra condition when Parquet adds filter predicate support
+ // for repeated columns (https://issues.apache.org/jira/browse/PARQUET-34)
+ case p: PrimitiveType if p.getRepetition != Repetition.REPEATED =>
+ Some(
+ ParquetPrimitiveField(
+ fieldNames = parentFieldNames :+ p.getName,
+ fieldType = ParquetSchemaType(
+ p.getLogicalTypeAnnotation,
+ p.getPrimitiveTypeName,
+ p.getTypeLength)))
+ // Note that when g is a `Struct`, `g.getOriginalType` is `null`.
+ // When g is a `Map`, `g.getOriginalType` is `MAP`.
+ // When g is a `List`, `g.getOriginalType` is `LIST`.
+ case g: GroupType if g.getOriginalType == null =>
+ getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName)
+ // Parquet only supports push-down for primitive types; as a result, Map and List types
+ // are removed.
+ case _ => None
+ }
+ }
+
+ val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field =>
+ (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field)
+ }
+ if (caseSensitive) {
+ primitiveFields.toMap
+ } else {
+ // Don't consider ambiguity here: if more than one field is matched in case-insensitive
+ // mode, just skip push-down for those fields; they will trigger an exception when read.
+ // See: SPARK-25132.
+ val dedupPrimitiveFields =
+ primitiveFields
+ .groupBy(_._1.toLowerCase(Locale.ROOT))
+ .filter(_._2.size == 1)
+ .mapValues(_.head._2)
+ CaseInsensitiveMap(dedupPrimitiveFields.toMap)
+ }
+ }
+
+ /**
+ * Holds the information of a single primitive field stored in the underlying Parquet file.
+ *
+ * @param fieldNames
+ *   the field name as a multi-part identifier (array of strings) in the Parquet file
+ * @param fieldType
+ *   field type related information in the Parquet file
+ */
+ private case class ParquetPrimitiveField(
+ fieldNames: Array[String],
+ fieldType: ParquetSchemaType)
+
+ private case class ParquetSchemaType(
+ logicalTypeAnnotation: LogicalTypeAnnotation,
+ primitiveTypeName: PrimitiveTypeName,
+ length: Int)
+
+ private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0)
+ private val ParquetByteType =
+ ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0)
+ private val ParquetShortType =
+ ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0)
+ private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0)
+ private val ParquetLongType = ParquetSchemaType(null, INT64, 0)
+ private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0)
+ private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0)
+ private val ParquetStringType =
+ ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0)
+ private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0)
+ private val ParquetDateType =
+ ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0)
+ private val ParquetTimestampMicrosType =
+ ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0)
+ private val ParquetTimestampMillisType =
+ ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0)
+
+ private def dateToDays(date: Any): Int = {
+ val gregorianDays = date match {
+ case d: Date => DateTimeUtils.fromJavaDate(d)
+ case ld: LocalDate => DateTimeUtils.localDateToDays(ld)
+ }
+ datetimeRebaseSpec.mode match {
+ case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays)
+ case _ => gregorianDays
+ }
+ }
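+ // For example, `dateToDays(Date.valueOf("1970-01-02"))` yields 1, the number of days since
+ // the Unix epoch; under LEGACY mode the value is additionally rebased from the Proleptic
+ // Gregorian calendar to the hybrid Julian/Gregorian calendar (a no-op for modern dates).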
+
+ private def timestampToMicros(v: Any): JLong = {
+ val gregorianMicros = v match {
+ case i: Instant => DateTimeUtils.instantToMicros(i)
+ case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
+ }
+ datetimeRebaseSpec.mode match {
+ case LegacyBehaviorPolicy.LEGACY =>
+ rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros)
+ case _ => gregorianMicros
+ }
+ }
+
+ private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue()
+
+ private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue()
+
+ private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = {
+ val decimalBuffer = new Array[Byte](numBytes)
+ val bytes = decimal.unscaledValue().toByteArray
+
+ val fixedLengthBytes = if (bytes.length == numBytes) {
+ bytes
+ } else {
+ val signByte = if (bytes.head < 0) -1: Byte else 0: Byte
+ java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte)
+ System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length)
+ decimalBuffer
+ }
+ Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes)
+ }
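+ // Worked example: for a decimal value 3.14 written to a 4-byte FIXED_LEN_BYTE_ARRAY column,
+ // the unscaled value is 314 = 0x013A, so the result is sign-extended to
+ // [0x00, 0x00, 0x01, 0x3A]; a negative value such as -3.14 is padded with 0xFF sign bytes.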
+
+ private def timestampToMillis(v: Any): JLong = {
+ val micros = timestampToMicros(v)
+ val millis = DateTimeUtils.microsToMillis(micros)
+ millis.asInstanceOf[JLong]
+ }
+
+ private def toIntValue(v: Any): Integer = {
+ Option(v)
+ .map {
+ case p: Period => IntervalUtils.periodToMonths(p)
+ case n => n.asInstanceOf[Number].intValue
+ }
+ .map(_.asInstanceOf[Integer])
+ .orNull
+ }
+
+ private def toLongValue(v: Any): JLong = v match {
+ case d: Duration => IntervalUtils.durationToMicros(d)
+ case l => l.asInstanceOf[JLong]
+ }
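+ // For example, `toIntValue(Period.ofYears(1))` yields 12 (months) and
+ // `toLongValue(Duration.ofSeconds(1))` yields 1000000L (microseconds), matching Spark's
+ // internal representation of year-month and day-time intervals.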
+
+ private val makeEq
+ : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetBooleanType =>
+ (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean])
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ intColumn(n),
+ Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ // Binary.fromString and Binary.fromByteArray don't accept null values
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ binaryColumn(n),
+ Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ binaryColumn(n),
+ Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull)
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull)
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull)
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ intColumn(n),
+ Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ longColumn(n),
+ Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.eq(
+ binaryColumn(n),
+ Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull)
+ }
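+ // Usage sketch: the builders above are looked up through `PartialFunction.lift`, e.g. for a
+ // plain INT32 column named "id" (a hypothetical column):
+ //   makeEq.lift(ParquetIntegerType).map(_(Array("id"), Integer.valueOf(5)))
+ // yields Some(FilterApi.eq(intColumn(Array("id")), 5)), and None for unsupported types.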
+
+ private val makeNotEq
+ : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetBooleanType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean])
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ binaryColumn(n),
+ Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ binaryColumn(n),
+ Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull)
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull)
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull)
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ longColumn(n),
+ Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.notEq(
+ binaryColumn(n),
+ Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull)
+ }
+
+ private val makeLt
+ : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeLtEq
+ : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeGt
+ : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeGtEq
+ : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeInPredicate: PartialFunction[
+ ParquetSchemaType,
+ (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(toIntValue(_).toInt).foreach(statistics.updateStats)
+ FilterApi.and(
+ FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]),
+ FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer]))
+
+ case ParquetLongType =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(toLongValue).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]),
+ FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong]))
+
+ case ParquetFloatType =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]),
+ FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat]))
+
+ case ParquetDoubleType =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]),
+ FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble]))
+
+ case ParquetStringType =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats)
+ FilterApi.and(
+ FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]),
+ FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary]))
+
+ case ParquetBinaryType =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]]))
+ .foreach(statistics.updateStats)
+ FilterApi.and(
+ FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]),
+ FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary]))
+
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]),
+ FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer]))
+
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(timestampToMicros).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]),
+ FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong]))
+
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(timestampToMillis).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]),
+ FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong]))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]),
+ FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer]))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal =>
+ (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_))
+ FilterApi.and(
+ FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]),
+ FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong]))
+
+ case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length)
+ if pushDownDecimal =>
+ (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) =>
+ v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length))
+ .foreach(statistics.updateStats)
+ FilterApi.and(
+ FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]),
+ FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary]))
+ }
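+ // For example, an IN list that exceeds the push-down threshold, say In("a", Array(1, 5, 3)),
+ // is collapsed by the builders above into the range predicate a >= 1 AND a <= 5 using Parquet
+ // statistics; this is a superset of the original set, so it never filters out matching rows
+ // and only helps prune row groups and pages.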
+
+ // Returns filters that can be pushed down when reading Parquet files.
+ def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = {
+ filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true))
+ }
+
+ private def convertibleFiltersHelper(
+ predicate: sources.Filter,
+ canPartialPushDown: Boolean): Option[sources.Filter] = {
+ predicate match {
+ case sources.And(left, right) =>
+ val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown)
+ val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown)
+ (leftResultOptional, rightResultOptional) match {
+ case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult))
+ case (Some(leftResult), None) if canPartialPushDown => Some(leftResult)
+ case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult)
+ case _ => None
+ }
+
+ case sources.Or(left, right) =>
+ val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown)
+ val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown)
+ if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) {
+ None
+ } else {
+ Some(sources.Or(leftResultOptional.get, rightResultOptional.get))
+ }
+ case sources.Not(pred) =>
+ val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false)
+ resultOptional.map(sources.Not)
+
+ case other =>
+ if (createFilter(other).isDefined) {
+ Some(other)
+ } else {
+ None
+ }
+ }
+ }
+
+ /**
+ * Converts data sources filters to Parquet filter predicates.
+ */
+ def createFilter(predicate: sources.Filter): Option[FilterPredicate] = {
+ createFilterHelper(predicate, canPartialPushDownConjuncts = true)
+ }
+
+ // The Parquet type in the given file must match the value's type in the pushed filter
+ // in order to push the filter down to Parquet.
+ private def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
+ value == null || (nameToParquetField(name).fieldType match {
+ case ParquetBooleanType => value.isInstanceOf[JBoolean]
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ if (isSpark34Plus) {
+ value match {
+ // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type
+ // Int. We don't create a filter if the value would overflow.
+ case _: JByte | _: JShort | _: Integer => true
+ case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue
+ case _ => false
+ }
+ } else {
+ // If not on Spark 3.4+, we still follow the old behavior, as Spark does.
+ value.isInstanceOf[Number]
+ }
+ case ParquetLongType => value.isInstanceOf[JLong]
+ case ParquetFloatType => value.isInstanceOf[JFloat]
+ case ParquetDoubleType => value.isInstanceOf[JDouble]
+ case ParquetStringType => value.isInstanceOf[String]
+ case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+ case ParquetDateType =>
+ value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
+ case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+ value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
+ case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) =>
+ isDecimalMatched(value, decimalType)
+ case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) =>
+ isDecimalMatched(value, decimalType)
+ case ParquetSchemaType(
+ decimalType: DecimalLogicalTypeAnnotation,
+ FIXED_LEN_BYTE_ARRAY,
+ _) =>
+ isDecimalMatched(value, decimalType)
+ case _ => false
+ })
+ }
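+ // For example, on Spark 3.4+ a filter on an INT32-backed column (say `byteCol`) with a Long
+ // literal such as EqualTo("byteCol", 300L) is still eligible for push-down because 300 fits
+ // in an Int, whereas EqualTo("byteCol", 3000000000L) is rejected since it would overflow Int.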
+
+ // For decimal types, the filter value's scale must match the scale in the file.
+ // If it doesn't match, it would cause data corruption.
+ private def isDecimalMatched(
+ value: Any,
+ decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match {
+ case decimal: JBigDecimal =>
+ decimal.scale == decimalLogicalType.getScale
+ case _ => false
+ }
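+ // For example, if the file declares the column as decimal(10, 2), a java.math.BigDecimal
+ // literal "3.14" (scale 2) can be pushed down, while "3.1" (scale 1) cannot, because
+ // comparing unscaled values at different scales would produce wrong results.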
+
+ private def canMakeFilterOn(name: String, value: Any): Boolean = {
+ nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
+ }
+
+ /**
+ * @param predicate
+ * the input filter predicates. Not all the predicates can be pushed down.
+ * @param canPartialPushDownConjuncts
+ * whether a subset of the predicate's conjuncts can be pushed down safely. Pushing ONLY one
+ * side of an AND down is safe to do at the top level, or when none of its ancestors is a NOT or an OR.
+ * @return
+ * the Parquet-native filter predicates that are eligible for pushdown.
+ */
+ private def createFilterHelper(
+ predicate: sources.Filter,
+ canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
+ // NOTE:
+ //
+ // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+ // which can be implicitly cast to `false`. Please refer to the `eval` method of these
+ // operators and the `PruneFilters` rule for details.
+
+ // Hyukjin:
+ // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]].
+ // So it performs the equality comparison identically to when the given [[sources.Filter]] is
+ // [[EqualTo]]. The reason I did this is that the actual Parquet filter performs a null-safe
+ // equality comparison.
+ // So I added this, and maybe [[EqualTo]] should be changed. It still seems fine though, because
+ // physical planning does not pass `NULL` to [[EqualTo]] but rewrites it to [[IsNull]] etc.
+ // Probably I missed something, and obviously this should be changed.
+
+ predicate match {
+ case sources.IsNull(name) if canMakeFilterOn(name, null) =>
+ makeEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, null))
+ case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
+ makeNotEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, null))
+
+ case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
+ makeEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
+ makeNotEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
+ makeEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
+ makeNotEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
+ makeLt
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
+ makeLtEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
+ makeGt
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
+ makeGtEq
+ .lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.And(lhs, rhs) =>
+ // Here, it is not safe to just convert one side and remove the other side
+ // if we do not understand what the parent filters are.
+ //
+ // Here is an example used to explain the reason.
+ // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to
+ // convert b in ('1'). If we only convert a = 2, we will end up with a filter
+ // NOT(a = 2), which will generate wrong results.
+ //
+ // Pushing one side of AND down is only safe to do at the top level or in the child
+ // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate
+ // can be safely removed.
+ val lhsFilterOption =
+ createFilterHelper(lhs, canPartialPushDownConjuncts)
+ val rhsFilterOption =
+ createFilterHelper(rhs, canPartialPushDownConjuncts)
+
+ (lhsFilterOption, rhsFilterOption) match {
+ case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter))
+ case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter)
+ case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter)
+ case _ => None
+ }
+
+ case sources.Or(lhs, rhs) =>
+ // The Or predicate is convertible when both of its children can be pushed down.
+ // That is to say, if one/both of the children can be partially pushed down, the Or
+ // predicate can be partially pushed down as well.
+ //
+ // Here is an example used to explain the reason.
+ // Let's say we have
+ // (a1 AND a2) OR (b1 AND b2),
+ // where a1 and b1 are convertible, while a2 and b2 are not.
+ // The predicate can be converted as
+ // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2)
+ // As per the logic in the And predicate, we can push down (a1 OR b1).
+ for {
+ lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts)
+ rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts)
+ } yield FilterApi.or(lhsFilter, rhsFilter)
+
+ case sources.Not(pred) =>
+ createFilterHelper(pred, canPartialPushDownConjuncts = false)
+ .map(FilterApi.not)
+
+ case sources.In(name, values)
+ if pushDownInFilterThreshold > 0 && values.nonEmpty &&
+ canMakeFilterOn(name, values.head) =>
+ val fieldType = nameToParquetField(name).fieldType
+ val fieldNames = nameToParquetField(name).fieldNames
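+ // Small IN lists are expanded into an OR of per-value equality predicates; larger lists
+ // are handed to makeInPredicate together with column statistics (a null value, if present,
+ // is handled through an extra equality-with-null predicate), and are skipped entirely when
+ // partial pushdown is not allowed.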
+ if (values.length <= pushDownInFilterThreshold) {
+ values.distinct
+ .flatMap { v =>
+ makeEq.lift(fieldType).map(_(fieldNames, v))
+ }
+ .reduceLeftOption(FilterApi.or)
+ } else if (canPartialPushDownConjuncts) {
+ val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType
+ val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType)
+ if (values.contains(null)) {
+ Seq(
+ makeEq.lift(fieldType).map(_(fieldNames, null)),
+ makeInPredicate
+ .lift(fieldType)
+ .map(_(fieldNames, values.filter(_ != null), statistics))).flatten
+ .reduceLeftOption(FilterApi.or)
+ } else {
+ makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics))
+ }
+ } else {
+ None
+ }
+
+ case sources.StringStartsWith(name, prefix)
+ if pushDownStringPredicate && canMakeFilterOn(name, prefix) =>
+ Option(prefix).map { v =>
+ FilterApi.userDefined(
+ binaryColumn(nameToParquetField(name).fieldNames),
+ new UserDefinedPredicate[Binary] with Serializable {
+ private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
+ private val size = strToBinary.length
+
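+ // The row group can be dropped when its statistics show that the prefix compares strictly
+ // above the truncated max or strictly below the truncated min, i.e. no value in the group
+ // can start with the prefix.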
+ override def canDrop(statistics: Statistics[Binary]): Boolean = {
+ val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR
+ val max = statistics.getMax
+ val min = statistics.getMin
+ comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 ||
+ comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0
+ }
+
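+ // The inverse (NOT startsWith) can drop the group only when both truncated min and max
+ // equal the prefix, meaning every value in the group starts with it.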
+ override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = {
+ val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR
+ val max = statistics.getMax
+ val min = statistics.getMin
+ comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 &&
+ comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0
+ }
+
+ override def keep(value: Binary): Boolean = {
+ value != null && UTF8String
+ .fromBytes(value.getBytes)
+ .startsWith(UTF8String.fromBytes(strToBinary.getBytes))
+ }
+ })
+ }
+
+ case sources.StringEndsWith(name, suffix)
+ if pushDownStringPredicate && canMakeFilterOn(name, suffix) =>
+ Option(suffix).map { v =>
+ FilterApi.userDefined(
+ binaryColumn(nameToParquetField(name).fieldNames),
+ new UserDefinedPredicate[Binary] with Serializable {
+ private val suffixStr = UTF8String.fromString(v)
+ override def canDrop(statistics: Statistics[Binary]): Boolean = false
+ override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false
+ override def keep(value: Binary): Boolean = {
+ value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr)
+ }
+ })
+ }
+
+ case sources.StringContains(name, value)
+ if pushDownStringPredicate && canMakeFilterOn(name, value) =>
+ Option(value).map { v =>
+ FilterApi.userDefined(
+ binaryColumn(nameToParquetField(name).fieldNames),
+ new UserDefinedPredicate[Binary] with Serializable {
+ private val subStr = UTF8String.fromString(v)
+ override def canDrop(statistics: Statistics[Binary]): Boolean = false
+ override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false
+ override def keep(value: Binary): Boolean = {
+ value != null && UTF8String.fromBytes(value.getBytes).contains(subStr)
+ }
+ })
+ }
+
+ case _ => None
+ }
+ }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 000000000..409e1c94b
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim {
+ /**
+ * Returns a tuple of expressions for the `unhex` function.
+ */
+ def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ (unhex.child, Literal(unhex.failOnError))
+ }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala
new file mode 100644
index 000000000..777248b41
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.mutable.HashMap
+import scala.concurrent.duration.NANOSECONDS
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection._
+
+import org.apache.comet.{CometConf, MetricsSupport}
+import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory}
+import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat}
+
+/**
+ * Comet physical scan node for DataSource V1. Most of the code here follows Spark's
+ * [[FileSourceScanExec]].
+ */
+case class CometScanExec(
+ @transient relation: HadoopFsRelation,
+ output: Seq[Attribute],
+ requiredSchema: StructType,
+ partitionFilters: Seq[Expression],
+ optionalBucketSet: Option[BitSet],
+ optionalNumCoalescedBuckets: Option[Int],
+ dataFilters: Seq[Expression],
+ tableIdentifier: Option[TableIdentifier],
+ disableBucketedScan: Boolean = false,
+ wrapped: FileSourceScanExec)
+ extends DataSourceScanExec
+ with ShimCometScanExec
+ with CometPlan {
+
+ // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests
+ override lazy val supportsColumnar: Boolean =
+ relation.fileFormat.supportBatch(relation.sparkSession, schema)
+
+ override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes
+
+ private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
+
+ /**
+ * Send the driver-side metrics. Before calling this function, selectedPartitions has been
+ * initialized. See SPARK-26327 for more details.
+ */
+ private def sendDriverMetrics(): Unit = {
+ driverMetrics.foreach(e => metrics(e._1).add(e._2))
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(
+ sparkContext,
+ executionId,
+ metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
+ }
+
+ private def isDynamicPruningFilter(e: Expression): Boolean =
+ e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+ @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
+ val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
+ val startTime = System.nanoTime()
+ val ret =
+ relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters)
+ setFilesNumAndSizeMetric(ret, true)
+ val timeTakenMs =
+ NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs)
+ driverMetrics("metadataTime") = timeTakenMs
+ ret
+ }.toArray
+
+ // We can only determine the actual partitions at runtime when a dynamic partition filter is
+ // present. This is because such a filter relies on information that is only available at run
+ // time (for instance the keys used in the other side of a join).
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
+ val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter)
+
+ if (dynamicPartitionFilters.nonEmpty) {
+ val startTime = System.nanoTime()
+ // call the file index for the files matching all filters except dynamic partition filters
+ val predicate = dynamicPartitionFilters.reduce(And)
+ val partitionColumns = relation.partitionSchema
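+ // Bind each partition-column reference to its ordinal in the partition schema so the
+ // predicate can be evaluated directly against each partition's `values` row.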
+ val boundPredicate = Predicate.create(
+ predicate.transform { case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ },
+ Nil)
+ val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
+ setFilesNumAndSizeMetric(ret, false)
+ val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000
+ driverMetrics("pruningTime") = timeTakenMs
+ ret
+ } else {
+ selectedPartitions
+ }
+ }
+
+ // exposed for testing
+ lazy val bucketedScan: Boolean = wrapped.bucketedScan
+
+ override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) =
+ (wrapped.outputPartitioning, wrapped.outputOrdering)
+
+ @transient
+ private lazy val pushedDownFilters = {
+ val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
+ dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+ }
+
+ override lazy val metadata: Map[String, String] =
+ if (wrapped == null) Map.empty else wrapped.metadata
+
+ override def verboseStringWithOperatorId(): String = {
+ getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id))
+ wrapped.verboseStringWithOperatorId()
+ }
+
+ lazy val inputRDD: RDD[InternalRow] = {
+ val options = relation.options +
+ (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString)
+ val readFile: (PartitionedFile) => Iterator[InternalRow] =
+ relation.fileFormat.buildReaderWithPartitionValues(
+ sparkSession = relation.sparkSession,
+ dataSchema = relation.dataSchema,
+ partitionSchema = relation.partitionSchema,
+ requiredSchema = requiredSchema,
+ filters = pushedDownFilters,
+ options = options,
+ hadoopConf =
+ relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
+
+ val readRDD = if (bucketedScan) {
+ createBucketedReadRDD(
+ relation.bucketSpec.get,
+ readFile,
+ dynamicallySelectedPartitions,
+ relation)
+ } else {
+ createReadRDD(readFile, dynamicallySelectedPartitions, relation)
+ }
+ sendDriverMetrics()
+ readRDD
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ inputRDD :: Nil
+ }
+
+ /** Helper for computing total number and size of files in selected partitions. */
+ private def setFilesNumAndSizeMetric(
+ partitions: Seq[PartitionDirectory],
+ static: Boolean): Unit = {
+ val filesNum = partitions.map(_.files.size.toLong).sum
+ val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
+ if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
+ driverMetrics("numFiles") = filesNum
+ driverMetrics("filesSize") = filesSize
+ } else {
+ driverMetrics("staticFilesNum") = filesNum
+ driverMetrics("staticFilesSize") = filesSize
+ }
+ if (relation.partitionSchema.nonEmpty) {
+ driverMetrics("numPartitions") = partitions.length
+ }
+ }
+
+ override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ {
+ // Tracking scan time has overhead; we can't afford to do it for each row, so we only do
+ // it for each batch.
+ if (supportsColumnar) {
+ Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
+ } else {
+ None
+ }
+ } ++ {
+ relation.fileFormat match {
+ case f: MetricsSupport => f.initMetrics(sparkContext)
+ case _ => Map.empty
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ ColumnarToRowExec(this).doExecute()
+ }
+
+ protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val numOutputRows = longMetric("numOutputRows")
+ val scanTime = longMetric("scanTime")
+ inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches =>
+ new Iterator[ColumnarBatch] {
+
+ override def hasNext: Boolean = {
+ // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call.
+ val startNs = System.nanoTime()
+ val res = batches.hasNext
+ scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs)
+ res
+ }
+
+ override def next(): ColumnarBatch = {
+ val batch = batches.next()
+ numOutputRows += batch.numRows()
+ batch
+ }
+ }
+ }
+ }
+
+ override def executeCollect(): Array[InternalRow] = {
+ ColumnarToRowExec(this).executeCollect()
+ }
+
+ override val nodeName: String =
+ s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}"
+
+ /**
+ * Create an RDD for bucketed reads. The non-bucketed variant of this function is
+ * [[createReadRDD]].
+ *
+ * The algorithm is pretty simple: each RDD partition being returned should include all the
+ * files with the same bucket id from all the given Hive partitions.
+ *
+ * @param bucketSpec
+ * the bucketing spec.
+ * @param readFile
+ * a function to read each (part of a) file.
+ * @param selectedPartitions
+ * Hive-style partitions that are part of the read.
+ * @param fsRelation
+ * [[HadoopFsRelation]] associated with the read.
+ */
+ private def createBucketedReadRDD(
+ bucketSpec: BucketSpec,
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ selectedPartitions: Array[PartitionDirectory],
+ fsRelation: HadoopFsRelation): RDD[InternalRow] = {
+ logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
+ val filesGroupedToBuckets =
+ selectedPartitions
+ .flatMap { p =>
+ p.files.map { f =>
+ PartitionedFileUtil.getPartitionedFile(f, p.values)
+ }
+ }
+ .groupBy { f =>
+ BucketingUtils
+ .getBucketId(new Path(f.filePath.toString()).getName)
+ .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version))
+ }
+
+ val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
+ val bucketSet = optionalBucketSet.get
+ filesGroupedToBuckets.filter { f =>
+ bucketSet.get(f._1)
+ }
+ } else {
+ filesGroupedToBuckets
+ }
+
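+ // With bucket coalescing, buckets are merged by `bucketId % numCoalescedBuckets`, so each
+ // coalesced partition reads the files of every original bucket that maps to it.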
+ val filePartitions = optionalNumCoalescedBuckets
+ .map { numCoalescedBuckets =>
+ logInfo(s"Coalescing to ${numCoalescedBuckets} buckets")
+ val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets)
+ Seq.tabulate(numCoalescedBuckets) { bucketId =>
+ val partitionedFiles = coalescedBuckets
+ .get(bucketId)
+ .map {
+ _.values.flatten.toArray
+ }
+ .getOrElse(Array.empty)
+ FilePartition(bucketId, partitionedFiles)
+ }
+ }
+ .getOrElse {
+ Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
+ FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
+ }
+ }
+
+ prepareRDD(fsRelation, readFile, filePartitions)
+ }
+
+ /**
+ * Create an RDD for non-bucketed reads. The bucketed variant of this function is
+ * [[createBucketedReadRDD]].
+ *
+ * @param readFile
+ * a function to read each (part of a) file.
+ * @param selectedPartitions
+ * Hive-style partitions that are part of the read.
+ * @param fsRelation
+ * [[HadoopFsRelation]] associated with the read.
+ */
+ private def createReadRDD(
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ selectedPartitions: Array[PartitionDirectory],
+ fsRelation: HadoopFsRelation): RDD[InternalRow] = {
+ val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
+ val maxSplitBytes =
+ FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions)
+ logInfo(
+ s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
+ s"open cost is considered as scanning $openCostInBytes bytes.")
+
+ // Filter files with bucket pruning if possible
+ val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled
+ val shouldProcess: Path => Boolean = optionalBucketSet match {
+ case Some(bucketSet) if bucketingEnabled =>
+ // Do not prune the file if bucket file name is invalid
+ filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get)
+ case _ =>
+ _ => true
+ }
+
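+ // Split each selected file into chunks of at most `maxSplitBytes` (when the format is
+ // splittable), then sort by descending length so bin-packing into FilePartitions handles
+ // the largest splits first.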
+ val splitFiles = selectedPartitions
+ .flatMap { partition =>
+ partition.files.flatMap { file =>
+ // getPath() is very expensive so we only want to call it once in this block:
+ val filePath = file.getPath
+
+ if (shouldProcess(filePath)) {
+ val isSplitable = relation.fileFormat.isSplitable(
+ relation.sparkSession,
+ relation.options,
+ filePath) &&
+ // SPARK-39634: Allow file splitting in combination with row index generation once
+ // the fix for PARQUET-2161 is available.
+ !isNeededForSchema(requiredSchema)
+ PartitionedFileUtil.splitFiles(
+ sparkSession = relation.sparkSession,
+ file = file,
+ isSplitable = isSplitable,
+ maxSplitBytes = maxSplitBytes,
+ partitionValues = partition.values)
+ } else {
+ Seq.empty
+ }
+ }
+ }
+ .sortBy(_.length)(implicitly[Ordering[Long]].reverse)
+
+ prepareRDD(
+ fsRelation,
+ readFile,
+ FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes))
+ }
+
+ private def prepareRDD(
+ fsRelation: HadoopFsRelation,
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ partitions: Seq[FilePartition]): RDD[InternalRow] = {
+ val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)
+ val prefetchEnabled = hadoopConf.getBoolean(
+ CometConf.COMET_SCAN_PREFETCH_ENABLED.key,
+ CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get)
+
+ val sqlConf = fsRelation.sparkSession.sessionState.conf
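+ // With prefetching enabled, read through Comet's DataSourceRDD backed by
+ // CometParquetPartitionReaderFactory; otherwise fall back to a FileScanRDD driven by the
+ // file format's reader function.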
+ if (prefetchEnabled) {
+ CometParquetFileFormat.populateConf(sqlConf, hadoopConf)
+ val broadcastedConf =
+ fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ val partitionReaderFactory = CometParquetPartitionReaderFactory(
+ sqlConf,
+ broadcastedConf,
+ requiredSchema,
+ relation.partitionSchema,
+ pushedDownFilters.toArray,
+ new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf),
+ metrics)
+
+ newDataSourceRDD(
+ fsRelation.sparkSession.sparkContext,
+ partitions.map(Seq(_)),
+ partitionReaderFactory,
+ true,
+ Map.empty)
+ } else {
+ newFileScanRDD(
+ fsRelation.sparkSession,
+ readFile,
+ partitions,
+ new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields),
+ new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf))
+ }
+ }
+
+ // Filters out unused DynamicPruningExpression expressions - ones that have been replaced
+ // with DynamicPruningExpression(Literal.TrueLiteral) during physical planning.
+ private def filterUnusedDynamicPruningExpressions(
+ predicates: Seq[Expression]): Seq[Expression] = {
+ predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
+ }
+
+ override def doCanonicalize(): CometScanExec = {
+ CometScanExec(
+ relation,
+ output.map(QueryPlan.normalizeExpressions(_, output)),
+ requiredSchema,
+ QueryPlan.normalizePredicates(
+ filterUnusedDynamicPruningExpressions(partitionFilters),
+ output),
+ optionalBucketSet,
+ optionalNumCoalescedBuckets,
+ QueryPlan.normalizePredicates(dataFilters, output),
+ None,
+ disableBucketedScan,
+ null)
+ }
+}
+
+object CometScanExec {
+ def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = {
+ // TreeNode.mapProductIterator is a protected method, so it is replicated here.
+ def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = {
+ val arr = Array.ofDim[B](product.productArity)
+ var i = 0
+ while (i < arr.length) {
+ arr(i) = f(product.productElement(i))
+ i += 1
+ }
+ arr
+ }
+
+ // Replacing the relation in FileSourceScanExec via `copy` seems to cause issues on other
+ // Spark distributions if the FileSourceScanExec constructor has changed.
+ // Use `makeCopy` to avoid the issue.
+ // https://github.com/apache/arrow-datafusion-comet/issues/190
+ def transform(arg: Any): AnyRef = arg match {
+ case _: HadoopFsRelation =>
+ scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session)
+ case other: AnyRef => other
+ case null => null
+ }
+ val newArgs = mapProductIterator(scanExec, transform(_))
+ val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec]
+ val batchScanExec = CometScanExec(
+ wrapped.relation,
+ wrapped.output,
+ wrapped.requiredSchema,
+ wrapped.partitionFilters,
+ wrapped.optionalBucketSet,
+ wrapped.optionalNumCoalescedBuckets,
+ wrapped.dataFilters,
+ wrapped.tableIdentifier,
+ wrapped.disableBucketedScan,
+ wrapped)
+ scanExec.logicalLink.foreach(batchScanExec.setLogicalLink)
+ batchScanExec
+ }
+}
\ No newline at end of file
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
new file mode 100644
index 000000000..543116c10
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.SparkContext
+
+trait ShimCometScanExec {
+ def wrapped: FileSourceScanExec
+
+ lazy val fileConstantMetadataColumns: Seq[AttributeReference] =
+ wrapped.fileConstantMetadataColumns
+
+ protected def newDataSourceRDD(
+ sc: SparkContext,
+ inputPartitions: Seq[Seq[InputPartition]],
+ partitionReaderFactory: PartitionReaderFactory,
+ columnarReads: Boolean,
+ customMetrics: Map[String, SQLMetric]): DataSourceRDD =
+ new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, customMetrics)
+
+ protected def newFileScanRDD(
+ fsRelation: HadoopFsRelation,
+ readFunction: PartitionedFile => Iterator[InternalRow],
+ filePartitions: Seq[FilePartition],
+ readSchema: StructType,
+ options: ParquetOptions): FileScanRDD = {
+ new FileScanRDD(
+ fsRelation.sparkSession,
+ readFunction,
+ filePartitions,
+ readSchema,
+ fileConstantMetadataColumns,
+ fsRelation.fileFormat.fileConstantMetadataExtractors,
+ options)
+ }
+
+ protected def invalidBucketFile(path: String, sparkVersion: String): Throwable =
+ QueryExecutionErrors.invalidBucketFile(path)
+
+ // see SPARK-39634
+ protected def isNeededForSchema(sparkSchema: StructType): Boolean = false
+
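+ // Spark 3.5 passes FileStatusWithMetadata to these helpers and no longer needs a separate
+ // file Path, so the shims keep the signatures expected by the shared Comet code and adapt
+ // the calls to the 3.5 PartitionedFileUtil API (the unused parameters are simply dropped).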
+ protected def getPartitionedFile(f: FileStatusWithMetadata, p: PartitionDirectory): PartitionedFile =
+ PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen)
+
+ protected def splitFiles(sparkSession: SparkSession,
+ file: FileStatusWithMetadata,
+ filePath: Path,
+ isSplitable: Boolean,
+ maxSplitBytes: Long,
+ partitionValues: InternalRow): Seq[PartitionedFile] =
+ PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues)
+}
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala
similarity index 100%
rename from spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
similarity index 100%
rename from spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala
similarity index 100%
rename from spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala
similarity index 100%
rename from spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
rename to spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala
diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
similarity index 100%
rename from spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
rename to spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala