diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala deleted file mode 100644 index ac871cf60..000000000 --- a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.FileMetaData -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.execution.datasources.RecordReaderIterator -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{DateType, StructType, TimestampType} -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.CometConf -import org.apache.comet.MetricsSupport -import org.apache.comet.shims.ShimSQLConf -import org.apache.comet.vector.CometVector - -/** - * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's - * [[ParquetFileFormat]], but overrides: - * - * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap - * column vector in the whole-stage codegen path. - * - `supportBatch`, which simply returns true since data types should have already been checked - * in [[org.apache.comet.CometSparkSessionExtensions]] - * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. 
- */ -class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { - override def shortName(): String = "parquet" - override def toString: String = "CometParquet" - override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] - - override def vectorTypes( - requiredSchema: StructType, - partitionSchema: StructType, - sqlConf: SQLConf): Option[Seq[String]] = { - val length = requiredSchema.fields.length + partitionSchema.fields.length - Option(Seq.fill(length)(classOf[CometVector].getName)) - } - - override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true - - override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val sqlConf = sparkSession.sessionState.conf - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - val isCaseSensitive = sqlConf.caseSensitiveAnalysis - val useFieldId = CometParquetUtils.readFieldId(sqlConf) - val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - val pushDownDate = sqlConf.parquetFilterPushDownDate - val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp - val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - val optionsMap = CaseInsensitiveMap[String](options) - val parquetOptions = new ParquetOptions(optionsMap, sqlConf) - val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead - val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - (file: PartitionedFile) => { - val sharedConf = broadcastedHadoopConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - requiredSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can - // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a - // `flatMap` is used here. 
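// A minimal illustrative sketch (not from the original file) of the pushdown pattern the comment
// above describes, assuming a simplified converter signature: predicates that cannot be converted
// are dropped by `flatMap`, the survivors are AND-ed into a single row-group-level predicate, and
// `None` means nothing is pushed down.

import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.spark.sql.sources.Filter

def combinePushedFilters(
    filters: Seq[Filter],
    createFilter: Filter => Option[FilterPredicate]): Option[FilterPredicate] =
  filters.flatMap(createFilter(_)).reduceOption(FilterApi.and)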
- .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) - - val batchReader = new BatchReader( - sharedConf, - file, - footer, - capacity, - requiredSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val iter = new RecordReaderIterator(batchReader) - try { - batchReader.init() - iter.asInstanceOf[Iterator[InternalRow]] - } catch { - case e: Throwable => - iter.close() - throw e - } - } - } -} - -object CometParquetFileFormat extends Logging { - - /** - * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` - */ - def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) - hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) - hadoopConf.setBoolean( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, - sqlConf.nestedSchemaPruningEnabled) - hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) - - // Sets flags for `ParquetToSparkSchemaConverter` - hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp) - - // Comet specific configs - hadoopConf.setBoolean( - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) - hadoopConf.setBoolean( - CometConf.COMET_USE_DECIMAL_128.key, - CometConf.COMET_USE_DECIMAL_128.get()) - hadoopConf.setBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) - } - - def getDatetimeRebaseSpec( - file: PartitionedFile, - sparkSchema: StructType, - sharedConf: Configuration, - footerFileMetaData: FileMetaData, - datetimeRebaseModeInRead: String): RebaseSpec = { - val exceptionOnRebase = sharedConf.getBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) - var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) - val hasDateOrTimestamp = sparkSchema.exists(f => - f.dataType match { - case DateType | TimestampType => true - case _ => false - }) - - if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { - if (exceptionOnRebase) { - logWarning( - s"""Found Parquet file $file that could potentially contain dates/timestamps that were - written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase - and return these according to the new Proleptic Gregorian calendar, Comet will throw - exception when reading them. If you want to read them as it is according to the hybrid - Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to - false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian - calendar, please disable Comet for this query.""") - } else { - // do not throw exception on rebase - read as it is - datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) - } - } - - datetimeRebaseSpec - } -} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala deleted file mode 100644 index 693af125b..000000000 --- a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters -import scala.collection.mutable - -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.ParquetMetadata -import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.connector.read.PartitionReader -import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.{CometConf, CometRuntimeException} -import org.apache.comet.shims.ShimSQLConf - -case class CometParquetPartitionReaderFactory( - @transient sqlConf: SQLConf, - broadcastedConf: Broadcast[SerializableConfiguration], - readDataSchema: StructType, - partitionSchema: StructType, - filters: Array[Filter], - options: ParquetOptions, - metrics: Map[String, SQLMetric]) - extends FilePartitionReaderFactory - with ShimSQLConf - with Logging { - - private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val useFieldId = CometParquetUtils.readFieldId(sqlConf) - private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - private val pushDownDate = sqlConf.parquetFilterPushDownDate - private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp - private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead - private val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - // This is only called at executor on a Broadcast variable, so we don't want it to be - // materialized at driver. - @transient private lazy val preFetchEnabled = { - val conf = broadcastedConf.value.value - - conf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - } - - private var cometReaders: Iterator[BatchReader] = _ - private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() - - // TODO: we may want to revisit this as we're going to only support flat types at the beginning - override def supportColumnarReads(partition: InputPartition): Boolean = true - - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - if (preFetchEnabled) { - val filePartition = partition.asInstanceOf[FilePartition] - val conf = broadcastedConf.value.value - - val threadNum = conf.getInt( - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) - val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) - - this.cometReaders = filePartition.files - .map { file => - // `init()` call is deferred to when the prefetch task begins. - // Otherwise we will hold too many resources for readers which are not ready - // to prefetch. - val cometReader = buildCometReader(file) - if (cometReader != null) { - cometReader.submitPrefetchTask(prefetchThreadPool) - } - - cometReader - } - .toSeq - .toIterator - } - - super.createColumnarReader(partition) - } - - override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") - - private def buildCometReader(file: PartitionedFile): BatchReader = { - val conf = broadcastedConf.value.value - - try { - val (datetimeRebaseSpec, footer, filters) = getFilter(file) - filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) - val cometReader = new BatchReader( - conf, - file, - footer, - batchSize, - readDataSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val taskContext = Option(TaskContext.get) - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) - return cometReader - } catch { - case e: Throwable if preFetchEnabled => - // Keep original exception - cometReaderExceptionMap.put(file, e) - } - null - } - - override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val cometReader = if (!preFetchEnabled) { - // Prefetch is not enabled, create comet reader and initiate it. - val cometReader = buildCometReader(file) - cometReader.init() - - cometReader - } else { - // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
- // It is possibly we got an exception like `FileNotFoundException` and we need to throw it - // now to let Spark handle it. - val reader = cometReaders.next() - val exception = cometReaderExceptionMap.get(file) - exception.foreach(e => throw e) - - if (reader == null) { - throw new CometRuntimeException(s"Cannot find comet file reader for $file") - } - reader - } - CometPartitionReader(cometReader) - } - - def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { - val sharedConf = broadcastedConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - readDataSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - (datetimeRebaseSpec, footer, pushed) - } - - override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") - - /** - * A simple adapter on Comet's [[BatchReader]]. - */ - protected case class CometPartitionReader(reader: BatchReader) - extends PartitionReader[ColumnarBatch] { - - override def next(): Boolean = { - reader.nextBatch() - } - - override def get(): ColumnarBatch = { - reader.currentBatch() - } - - override def close(): Unit = { - reader.close() - } - } -} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala deleted file mode 100644 index 5994dfb41..000000000 --- a/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.comet.parquet - -import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} -import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, Period} -import java.util.Locale - -import scala.collection.JavaConverters.asScalaBufferConverter - -import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} -import org.apache.parquet.filter2.predicate._ -import org.apache.parquet.filter2.predicate.SparkFilterApi._ -import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} -import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.parquet.schema.Type.Repetition -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources -import org.apache.spark.unsafe.types.UTF8String - -import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus - -/** - * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove - * this duplication - * - * Some utility function to convert Spark data source filters to Parquet filters. - */ -class ParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStringPredicate: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { - // A map which contains parquet field name and data type, if predicate push down applies. - // - // Each key in `nameToParquetField` represents a column; `dots` are used as separators for - // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. - // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. - private val nameToParquetField: Map[String, ParquetPrimitiveField] = { - // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. - // `parentFieldNames` is used to keep track of the current nested level when traversing. - def getPrimitiveFields( - fields: Seq[Type], - parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { - fields.flatMap { - // Parquet only supports predicate push-down for non-repeated primitive types. - // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for - // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) - case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => - Some( - ParquetPrimitiveField( - fieldNames = parentFieldNames :+ p.getName, - fieldType = ParquetSchemaType( - p.getLogicalTypeAnnotation, - p.getPrimitiveTypeName, - p.getTypeLength))) - // Note that when g is a `Struct`, `g.getOriginalType` is `null`. - // When g is a `Map`, `g.getOriginalType` is `MAP`. - // When g is a `List`, `g.getOriginalType` is `LIST`. 
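// A minimal illustrative sketch (not from the original file) of how the primitive fields collected
// by this traversal are keyed (see the map construction just below), reusing the `quoteIfNeeded`
// helper imported above: dots separate nesting levels, and quoting keeps a nested path distinct
// from a single column whose name happens to contain a dot.

import org.apache.spark.sql.catalyst.util.quoteIfNeeded

def parquetFieldKey(fieldNames: Seq[String]): String =
  fieldNames.map(quoteIfNeeded).mkString(".")

// parquetFieldKey(Seq("a", "b")) == "a.b"     (field `b` nested inside struct column `a`)
// parquetFieldKey(Seq("a.b"))    == "`a.b`"   (one top-level column literally named "a.b")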
- case g: GroupType if g.getOriginalType == null => - getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) - // Parquet only supports push-down for primitive types; as a result, Map and List types - // are removed. - case _ => None - } - } - - val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => - (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) - } - if (caseSensitive) { - primitiveFields.toMap - } else { - // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive - // mode, just skip pushdown for these fields, they will trigger Exception when reading, - // See: SPARK-25132. - val dedupPrimitiveFields = - primitiveFields - .groupBy(_._1.toLowerCase(Locale.ROOT)) - .filter(_._2.size == 1) - .mapValues(_.head._2) - CaseInsensitiveMap(dedupPrimitiveFields.toMap) - } - } - - /** - * Holds a single primitive field information stored in the underlying parquet file. - * - * @param fieldNames - * a field name as an array of string multi-identifier in parquet file - * @param fieldType - * field type related info in parquet file - */ - private case class ParquetPrimitiveField( - fieldNames: Array[String], - fieldType: ParquetSchemaType) - - private case class ParquetSchemaType( - logicalTypeAnnotation: LogicalTypeAnnotation, - primitiveTypeName: PrimitiveTypeName, - length: Int) - - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) - private val ParquetByteType = - ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) - private val ParquetShortType = - ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) - private val ParquetLongType = ParquetSchemaType(null, INT64, 0) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) - private val ParquetStringType = - ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) - private val ParquetDateType = - ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) - private val ParquetTimestampMicrosType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) - private val ParquetTimestampMillisType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) - - private def dateToDays(date: Any): Int = { - val gregorianDays = date match { - case d: Date => DateTimeUtils.fromJavaDate(d) - case ld: LocalDate => DateTimeUtils.localDateToDays(ld) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) - case _ => gregorianDays - } - } - - private def timestampToMicros(v: Any): JLong = { - val gregorianMicros = v match { - case i: Instant => DateTimeUtils.instantToMicros(i) - case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => - rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) - case _ => gregorianMicros - } - } - - private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() - - private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() - - private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { - val decimalBuffer = new Array[Byte](numBytes) - val bytes = 
decimal.unscaledValue().toByteArray - - val fixedLengthBytes = if (bytes.length == numBytes) { - bytes - } else { - val signByte = if (bytes.head < 0) -1: Byte else 0: Byte - java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) - System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) - decimalBuffer - } - Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) - } - - private def timestampToMillis(v: Any): JLong = { - val micros = timestampToMicros(v) - val millis = DateTimeUtils.microsToMillis(micros) - millis.asInstanceOf[JLong] - } - - private def toIntValue(v: Any): Integer = { - Option(v) - .map { - case p: Period => IntervalUtils.periodToMonths(p) - case n => n.asInstanceOf[Number].intValue - } - .map(_.asInstanceOf[Integer]) - .orNull - } - - private def toLongValue(v: Any): JLong = v match { - case d: Duration => IntervalUtils.durationToMicros(d) - case l => l.asInstanceOf[JLong] - } - - private val makeEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeNotEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeLt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) - - 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeLtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if 
pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeInPredicate: PartialFunction[ - ParquetSchemaType, - (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toIntValue(_).toInt).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case 
ParquetLongType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toLongValue).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetFloatType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), - FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) - - case ParquetDoubleType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), - FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) - - case ParquetStringType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetBinaryType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetDateType if pushDownDate => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMicros).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMillis).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) - } - - // Returns filters that can be pushed down when reading Parquet files. - def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { - filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) - } - - private def convertibleFiltersHelper( - predicate: sources.Filter, - canPartialPushDown: Boolean): Option[sources.Filter] = { - predicate match { - case sources.And(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) - case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) - case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) - case _ => None - } - - case sources.Or(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { - None - } else { - Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) - } - case sources.Not(pred) => - val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - resultOptional.map(sources.Not) - - case other => - if (createFilter(other).isDefined) { - Some(other) - } else { - None - } - } - } - - /** - * Converts data sources filters to Parquet filter predicates. - */ - def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { - createFilterHelper(predicate, canPartialPushDownConjuncts = true) - } - - // Parquet's type in the given file should be matched to the value's type - // in the pushed filter in order to push down the filter to Parquet. - private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToParquetField(name).fieldType match { - case ParquetBooleanType => value.isInstanceOf[JBoolean] - case ParquetByteType | ParquetShortType | ParquetIntegerType => - if (isSpark34Plus) { - value match { - // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type - // Int. We don't create a filter if the value would overflow. - case _: JByte | _: JShort | _: Integer => true - case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue - case _ => false - } - } else { - // If not Spark 3.4+, we still following the old behavior as Spark does. 
- value.isInstanceOf[Number] - } - case ParquetLongType => value.isInstanceOf[JLong] - case ParquetFloatType => value.isInstanceOf[JFloat] - case ParquetDoubleType => value.isInstanceOf[JDouble] - case ParquetStringType => value.isInstanceOf[String] - case ParquetBinaryType => value.isInstanceOf[Array[Byte]] - case ParquetDateType => - value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] - case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType( - decimalType: DecimalLogicalTypeAnnotation, - FIXED_LEN_BYTE_ARRAY, - _) => - isDecimalMatched(value, decimalType) - case _ => false - }) - } - - // Decimal type must make sure that filter value's scale matched the file. - // If doesn't matched, which would cause data corruption. - private def isDecimalMatched( - value: Any, - decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { - case decimal: JBigDecimal => - decimal.scale == decimalLogicalType.getScale - case _ => false - } - - private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) - } - - /** - * @param predicate - * the input filter predicates. Not all the predicates can be pushed down. - * @param canPartialPushDownConjuncts - * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one - * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. - * @return - * the Parquet-native filter predicates that are eligible for pushdown. - */ - private def createFilterHelper( - predicate: sources.Filter, - canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `PruneFilters` rule for details. - - // Hyukjin: - // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. - // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. - // The reason why I did this is, that the actual Parquet filter checks null-safe equality - // comparison. - // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because - // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. - // Probably I missed something and obviously this should be changed. 
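// A minimal illustrative sketch (not from the original file) of the two value guards defined
// above, with simplified inputs: on Spark 3.4+ a Long literal can still be compared against an
// INT32 column only when it fits the Int range, and a decimal literal is pushed down only when
// its scale matches the scale declared in the file, since comparing differently scaled unscaled
// values would return wrong results.

def longFitsInInt(v: java.lang.Long): Boolean =
  v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue

def decimalScaleMatches(value: java.math.BigDecimal, fileScale: Int): Boolean =
  value.scale == fileScale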
- - predicate match { - case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - - case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.And(lhs, rhs) => - // At here, it is not safe to just convert one side and remove the other side - // if we do not understand what the parent filters are. - // - // Here is an example used to explain the reason. - // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to - // convert b in ('1'). If we only convert a = 2, we will end up with a filter - // NOT(a = 2), which will generate wrong results. - // - // Pushing one side of AND down is only safe to do at the top level or in the child - // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate - // can be safely removed. - val lhsFilterOption = - createFilterHelper(lhs, canPartialPushDownConjuncts) - val rhsFilterOption = - createFilterHelper(rhs, canPartialPushDownConjuncts) - - (lhsFilterOption, rhsFilterOption) match { - case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) - case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) - case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) - case _ => None - } - - case sources.Or(lhs, rhs) => - // The Or predicate is convertible when both of its children can be pushed down. - // That is to say, if one/both of the children can be partially pushed down, the Or - // predicate can be partially pushed down as well. - // - // Here is an example used to explain the reason. - // Let's say we have - // (a1 AND a2) OR (b1 AND b2), - // a1 and b1 is convertible, while a2 and b2 is not. 
- // The predicate can be converted as - // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) - // As per the logical in And predicate, we can push down (a1 OR b1). - for { - lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) - rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case sources.Not(pred) => - createFilterHelper(pred, canPartialPushDownConjuncts = false) - .map(FilterApi.not) - - case sources.In(name, values) - if pushDownInFilterThreshold > 0 && values.nonEmpty && - canMakeFilterOn(name, values.head) => - val fieldType = nameToParquetField(name).fieldType - val fieldNames = nameToParquetField(name).fieldNames - if (values.length <= pushDownInFilterThreshold) { - values.distinct - .flatMap { v => - makeEq.lift(fieldType).map(_(fieldNames, v)) - } - .reduceLeftOption(FilterApi.or) - } else if (canPartialPushDownConjuncts) { - val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType - val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) - if (values.contains(null)) { - Seq( - makeEq.lift(fieldType).map(_(fieldNames, null)), - makeInPredicate - .lift(fieldType) - .map(_(fieldNames, values.filter(_ != null), statistics))).flatten - .reduceLeftOption(FilterApi.or) - } else { - makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) - } - } else { - None - } - - case sources.StringStartsWith(name, prefix) - if pushDownStringPredicate && canMakeFilterOn(name, prefix) => - Option(prefix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val strToBinary = Binary.fromReusedByteArray(v.getBytes) - private val size = strToBinary.length - - override def canDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 - } - - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 - } - - override def keep(value: Binary): Boolean = { - value != null && UTF8String - .fromBytes(value.getBytes) - .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) - } - }) - } - - case sources.StringEndsWith(name, suffix) - if pushDownStringPredicate && canMakeFilterOn(name, suffix) => - Option(suffix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val suffixStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) - } - }) - } - - case sources.StringContains(name, value) - if pushDownStringPredicate && canMakeFilterOn(name, value) => - Option(value).map { v => - FilterApi.userDefined( - 
binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val subStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) - } - }) - } - - case _ => None - } - } -} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala deleted file mode 100644 index ac871cf60..000000000 --- a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.FileMetaData -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.execution.datasources.RecordReaderIterator -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{DateType, StructType, TimestampType} -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.CometConf -import org.apache.comet.MetricsSupport -import org.apache.comet.shims.ShimSQLConf -import org.apache.comet.vector.CometVector - -/** - * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's - * [[ParquetFileFormat]], but overrides: - * - * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap - * column vector in the whole-stage codegen path. 
- * - `supportBatch`, which simply returns true since data types should have already been checked - * in [[org.apache.comet.CometSparkSessionExtensions]] - * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. - */ -class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { - override def shortName(): String = "parquet" - override def toString: String = "CometParquet" - override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] - - override def vectorTypes( - requiredSchema: StructType, - partitionSchema: StructType, - sqlConf: SQLConf): Option[Seq[String]] = { - val length = requiredSchema.fields.length + partitionSchema.fields.length - Option(Seq.fill(length)(classOf[CometVector].getName)) - } - - override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true - - override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val sqlConf = sparkSession.sessionState.conf - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - val isCaseSensitive = sqlConf.caseSensitiveAnalysis - val useFieldId = CometParquetUtils.readFieldId(sqlConf) - val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - val pushDownDate = sqlConf.parquetFilterPushDownDate - val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp - val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - val optionsMap = CaseInsensitiveMap[String](options) - val parquetOptions = new ParquetOptions(optionsMap, sqlConf) - val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead - val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - (file: PartitionedFile) => { - val sharedConf = broadcastedHadoopConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - requiredSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can - // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a - // `flatMap` is used here. 
- .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) - - val batchReader = new BatchReader( - sharedConf, - file, - footer, - capacity, - requiredSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val iter = new RecordReaderIterator(batchReader) - try { - batchReader.init() - iter.asInstanceOf[Iterator[InternalRow]] - } catch { - case e: Throwable => - iter.close() - throw e - } - } - } -} - -object CometParquetFileFormat extends Logging { - - /** - * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` - */ - def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) - hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) - hadoopConf.setBoolean( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, - sqlConf.nestedSchemaPruningEnabled) - hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) - - // Sets flags for `ParquetToSparkSchemaConverter` - hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp) - - // Comet specific configs - hadoopConf.setBoolean( - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) - hadoopConf.setBoolean( - CometConf.COMET_USE_DECIMAL_128.key, - CometConf.COMET_USE_DECIMAL_128.get()) - hadoopConf.setBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) - } - - def getDatetimeRebaseSpec( - file: PartitionedFile, - sparkSchema: StructType, - sharedConf: Configuration, - footerFileMetaData: FileMetaData, - datetimeRebaseModeInRead: String): RebaseSpec = { - val exceptionOnRebase = sharedConf.getBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) - var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) - val hasDateOrTimestamp = sparkSchema.exists(f => - f.dataType match { - case DateType | TimestampType => true - case _ => false - }) - - if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { - if (exceptionOnRebase) { - logWarning( - s"""Found Parquet file $file that could potentially contain dates/timestamps that were - written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase - and return these according to the new Proleptic Gregorian calendar, Comet will throw - exception when reading them. If you want to read them as it is according to the hybrid - Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to - false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian - calendar, please disable Comet for this query.""") - } else { - // do not throw exception on rebase - read as it is - datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) - } - } - - datetimeRebaseSpec - } -} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala deleted file mode 100644 index 693af125b..000000000 --- a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters -import scala.collection.mutable - -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.ParquetMetadata -import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.connector.read.PartitionReader -import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.{CometConf, CometRuntimeException} -import org.apache.comet.shims.ShimSQLConf - -case class CometParquetPartitionReaderFactory( - @transient sqlConf: SQLConf, - broadcastedConf: Broadcast[SerializableConfiguration], - readDataSchema: StructType, - partitionSchema: StructType, - filters: Array[Filter], - options: ParquetOptions, - metrics: Map[String, SQLMetric]) - extends FilePartitionReaderFactory - with ShimSQLConf - with Logging { - - private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val useFieldId = CometParquetUtils.readFieldId(sqlConf) - private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - private val pushDownDate = sqlConf.parquetFilterPushDownDate - private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp - private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead - private val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - // This is only called at executor on a Broadcast variable, so we don't want it to be - // materialized at driver. - @transient private lazy val preFetchEnabled = { - val conf = broadcastedConf.value.value - - conf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - } - - private var cometReaders: Iterator[BatchReader] = _ - private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() - - // TODO: we may want to revisit this as we're going to only support flat types at the beginning - override def supportColumnarReads(partition: InputPartition): Boolean = true - - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - if (preFetchEnabled) { - val filePartition = partition.asInstanceOf[FilePartition] - val conf = broadcastedConf.value.value - - val threadNum = conf.getInt( - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) - val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) - - this.cometReaders = filePartition.files - .map { file => - // `init()` call is deferred to when the prefetch task begins. - // Otherwise we will hold too many resources for readers which are not ready - // to prefetch. - val cometReader = buildCometReader(file) - if (cometReader != null) { - cometReader.submitPrefetchTask(prefetchThreadPool) - } - - cometReader - } - .toSeq - .toIterator - } - - super.createColumnarReader(partition) - } - - override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") - - private def buildCometReader(file: PartitionedFile): BatchReader = { - val conf = broadcastedConf.value.value - - try { - val (datetimeRebaseSpec, footer, filters) = getFilter(file) - filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) - val cometReader = new BatchReader( - conf, - file, - footer, - batchSize, - readDataSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val taskContext = Option(TaskContext.get) - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) - return cometReader - } catch { - case e: Throwable if preFetchEnabled => - // Keep original exception - cometReaderExceptionMap.put(file, e) - } - null - } - - override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val cometReader = if (!preFetchEnabled) { - // Prefetch is not enabled, create comet reader and initiate it. - val cometReader = buildCometReader(file) - cometReader.init() - - cometReader - } else { - // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
- // It is possibly we got an exception like `FileNotFoundException` and we need to throw it - // now to let Spark handle it. - val reader = cometReaders.next() - val exception = cometReaderExceptionMap.get(file) - exception.foreach(e => throw e) - - if (reader == null) { - throw new CometRuntimeException(s"Cannot find comet file reader for $file") - } - reader - } - CometPartitionReader(cometReader) - } - - def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { - val sharedConf = broadcastedConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - readDataSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - (datetimeRebaseSpec, footer, pushed) - } - - override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") - - /** - * A simple adapter on Comet's [[BatchReader]]. - */ - protected case class CometPartitionReader(reader: BatchReader) - extends PartitionReader[ColumnarBatch] { - - override def next(): Boolean = { - reader.nextBatch() - } - - override def get(): ColumnarBatch = { - reader.currentBatch() - } - - override def close(): Unit = { - reader.close() - } - } -} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala deleted file mode 100644 index 5994dfb41..000000000 --- a/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.comet.parquet - -import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} -import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, Period} -import java.util.Locale - -import scala.collection.JavaConverters.asScalaBufferConverter - -import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} -import org.apache.parquet.filter2.predicate._ -import org.apache.parquet.filter2.predicate.SparkFilterApi._ -import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} -import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.parquet.schema.Type.Repetition -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources -import org.apache.spark.unsafe.types.UTF8String - -import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus - -/** - * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove - * this duplication - * - * Some utility function to convert Spark data source filters to Parquet filters. - */ -class ParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStringPredicate: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { - // A map which contains parquet field name and data type, if predicate push down applies. - // - // Each key in `nameToParquetField` represents a column; `dots` are used as separators for - // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. - // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. - private val nameToParquetField: Map[String, ParquetPrimitiveField] = { - // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. - // `parentFieldNames` is used to keep track of the current nested level when traversing. - def getPrimitiveFields( - fields: Seq[Type], - parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { - fields.flatMap { - // Parquet only supports predicate push-down for non-repeated primitive types. - // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for - // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) - case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => - Some( - ParquetPrimitiveField( - fieldNames = parentFieldNames :+ p.getName, - fieldType = ParquetSchemaType( - p.getLogicalTypeAnnotation, - p.getPrimitiveTypeName, - p.getTypeLength))) - // Note that when g is a `Struct`, `g.getOriginalType` is `null`. - // When g is a `Map`, `g.getOriginalType` is `MAP`. - // When g is a `List`, `g.getOriginalType` is `LIST`. 
- case g: GroupType if g.getOriginalType == null => - getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) - // Parquet only supports push-down for primitive types; as a result, Map and List types - // are removed. - case _ => None - } - } - - val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => - (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) - } - if (caseSensitive) { - primitiveFields.toMap - } else { - // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive - // mode, just skip pushdown for these fields, they will trigger Exception when reading, - // See: SPARK-25132. - val dedupPrimitiveFields = - primitiveFields - .groupBy(_._1.toLowerCase(Locale.ROOT)) - .filter(_._2.size == 1) - .mapValues(_.head._2) - CaseInsensitiveMap(dedupPrimitiveFields.toMap) - } - } - - /** - * Holds a single primitive field information stored in the underlying parquet file. - * - * @param fieldNames - * a field name as an array of string multi-identifier in parquet file - * @param fieldType - * field type related info in parquet file - */ - private case class ParquetPrimitiveField( - fieldNames: Array[String], - fieldType: ParquetSchemaType) - - private case class ParquetSchemaType( - logicalTypeAnnotation: LogicalTypeAnnotation, - primitiveTypeName: PrimitiveTypeName, - length: Int) - - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) - private val ParquetByteType = - ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) - private val ParquetShortType = - ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) - private val ParquetLongType = ParquetSchemaType(null, INT64, 0) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) - private val ParquetStringType = - ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) - private val ParquetDateType = - ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) - private val ParquetTimestampMicrosType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) - private val ParquetTimestampMillisType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) - - private def dateToDays(date: Any): Int = { - val gregorianDays = date match { - case d: Date => DateTimeUtils.fromJavaDate(d) - case ld: LocalDate => DateTimeUtils.localDateToDays(ld) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) - case _ => gregorianDays - } - } - - private def timestampToMicros(v: Any): JLong = { - val gregorianMicros = v match { - case i: Instant => DateTimeUtils.instantToMicros(i) - case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => - rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) - case _ => gregorianMicros - } - } - - private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() - - private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() - - private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { - val decimalBuffer = new Array[Byte](numBytes) - val bytes = 
decimal.unscaledValue().toByteArray - - val fixedLengthBytes = if (bytes.length == numBytes) { - bytes - } else { - val signByte = if (bytes.head < 0) -1: Byte else 0: Byte - java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) - System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) - decimalBuffer - } - Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) - } - - private def timestampToMillis(v: Any): JLong = { - val micros = timestampToMicros(v) - val millis = DateTimeUtils.microsToMillis(micros) - millis.asInstanceOf[JLong] - } - - private def toIntValue(v: Any): Integer = { - Option(v) - .map { - case p: Period => IntervalUtils.periodToMonths(p) - case n => n.asInstanceOf[Number].intValue - } - .map(_.asInstanceOf[Integer]) - .orNull - } - - private def toLongValue(v: Any): JLong = v match { - case d: Duration => IntervalUtils.durationToMicros(d) - case l => l.asInstanceOf[JLong] - } - - private val makeEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeNotEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeLt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) - - 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeLtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if 
pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeInPredicate: PartialFunction[ - ParquetSchemaType, - (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toIntValue(_).toInt).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case 
ParquetLongType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toLongValue).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetFloatType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), - FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) - - case ParquetDoubleType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), - FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) - - case ParquetStringType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetBinaryType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetDateType if pushDownDate => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMicros).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMillis).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) - } - - // Returns filters that can be pushed down when reading Parquet files. - def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { - filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) - } - - private def convertibleFiltersHelper( - predicate: sources.Filter, - canPartialPushDown: Boolean): Option[sources.Filter] = { - predicate match { - case sources.And(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) - case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) - case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) - case _ => None - } - - case sources.Or(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { - None - } else { - Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) - } - case sources.Not(pred) => - val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - resultOptional.map(sources.Not) - - case other => - if (createFilter(other).isDefined) { - Some(other) - } else { - None - } - } - } - - /** - * Converts data sources filters to Parquet filter predicates. - */ - def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { - createFilterHelper(predicate, canPartialPushDownConjuncts = true) - } - - // Parquet's type in the given file should be matched to the value's type - // in the pushed filter in order to push down the filter to Parquet. - private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToParquetField(name).fieldType match { - case ParquetBooleanType => value.isInstanceOf[JBoolean] - case ParquetByteType | ParquetShortType | ParquetIntegerType => - if (isSpark34Plus) { - value match { - // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type - // Int. We don't create a filter if the value would overflow. - case _: JByte | _: JShort | _: Integer => true - case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue - case _ => false - } - } else { - // If not Spark 3.4+, we still following the old behavior as Spark does. 
- value.isInstanceOf[Number] - } - case ParquetLongType => value.isInstanceOf[JLong] - case ParquetFloatType => value.isInstanceOf[JFloat] - case ParquetDoubleType => value.isInstanceOf[JDouble] - case ParquetStringType => value.isInstanceOf[String] - case ParquetBinaryType => value.isInstanceOf[Array[Byte]] - case ParquetDateType => - value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] - case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType( - decimalType: DecimalLogicalTypeAnnotation, - FIXED_LEN_BYTE_ARRAY, - _) => - isDecimalMatched(value, decimalType) - case _ => false - }) - } - - // Decimal type must make sure that filter value's scale matched the file. - // If doesn't matched, which would cause data corruption. - private def isDecimalMatched( - value: Any, - decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { - case decimal: JBigDecimal => - decimal.scale == decimalLogicalType.getScale - case _ => false - } - - private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) - } - - /** - * @param predicate - * the input filter predicates. Not all the predicates can be pushed down. - * @param canPartialPushDownConjuncts - * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one - * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. - * @return - * the Parquet-native filter predicates that are eligible for pushdown. - */ - private def createFilterHelper( - predicate: sources.Filter, - canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `PruneFilters` rule for details. - - // Hyukjin: - // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. - // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. - // The reason why I did this is, that the actual Parquet filter checks null-safe equality - // comparison. - // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because - // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. - // Probably I missed something and obviously this should be changed. 
- - predicate match { - case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - - case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.And(lhs, rhs) => - // At here, it is not safe to just convert one side and remove the other side - // if we do not understand what the parent filters are. - // - // Here is an example used to explain the reason. - // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to - // convert b in ('1'). If we only convert a = 2, we will end up with a filter - // NOT(a = 2), which will generate wrong results. - // - // Pushing one side of AND down is only safe to do at the top level or in the child - // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate - // can be safely removed. - val lhsFilterOption = - createFilterHelper(lhs, canPartialPushDownConjuncts) - val rhsFilterOption = - createFilterHelper(rhs, canPartialPushDownConjuncts) - - (lhsFilterOption, rhsFilterOption) match { - case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) - case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) - case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) - case _ => None - } - - case sources.Or(lhs, rhs) => - // The Or predicate is convertible when both of its children can be pushed down. - // That is to say, if one/both of the children can be partially pushed down, the Or - // predicate can be partially pushed down as well. - // - // Here is an example used to explain the reason. - // Let's say we have - // (a1 AND a2) OR (b1 AND b2), - // a1 and b1 is convertible, while a2 and b2 is not. 
- // The predicate can be converted as - // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) - // As per the logical in And predicate, we can push down (a1 OR b1). - for { - lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) - rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case sources.Not(pred) => - createFilterHelper(pred, canPartialPushDownConjuncts = false) - .map(FilterApi.not) - - case sources.In(name, values) - if pushDownInFilterThreshold > 0 && values.nonEmpty && - canMakeFilterOn(name, values.head) => - val fieldType = nameToParquetField(name).fieldType - val fieldNames = nameToParquetField(name).fieldNames - if (values.length <= pushDownInFilterThreshold) { - values.distinct - .flatMap { v => - makeEq.lift(fieldType).map(_(fieldNames, v)) - } - .reduceLeftOption(FilterApi.or) - } else if (canPartialPushDownConjuncts) { - val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType - val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) - if (values.contains(null)) { - Seq( - makeEq.lift(fieldType).map(_(fieldNames, null)), - makeInPredicate - .lift(fieldType) - .map(_(fieldNames, values.filter(_ != null), statistics))).flatten - .reduceLeftOption(FilterApi.or) - } else { - makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) - } - } else { - None - } - - case sources.StringStartsWith(name, prefix) - if pushDownStringPredicate && canMakeFilterOn(name, prefix) => - Option(prefix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val strToBinary = Binary.fromReusedByteArray(v.getBytes) - private val size = strToBinary.length - - override def canDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 - } - - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 - } - - override def keep(value: Binary): Boolean = { - value != null && UTF8String - .fromBytes(value.getBytes) - .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) - } - }) - } - - case sources.StringEndsWith(name, suffix) - if pushDownStringPredicate && canMakeFilterOn(name, suffix) => - Option(suffix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val suffixStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) - } - }) - } - - case sources.StringContains(name, value) - if pushDownStringPredicate && canMakeFilterOn(name, value) => - Option(value).map { v => - FilterApi.userDefined( - 
binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val subStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) - } - }) - } - - case _ => None - } - } -} diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala deleted file mode 100644 index 14a664108..000000000 --- a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet - -import scala.collection.mutable.HashMap -import scala.concurrent.duration.NANOSECONDS -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection._ - -import org.apache.comet.{CometConf, MetricsSupport} -import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} - -/** - * Comet physical scan node for DataSource V1. 
Most of the code here follow Spark's - * [[FileSourceScanExec]], - */ -case class CometScanExec( - @transient relation: HadoopFsRelation, - output: Seq[Attribute], - requiredSchema: StructType, - partitionFilters: Seq[Expression], - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - dataFilters: Seq[Expression], - tableIdentifier: Option[TableIdentifier], - disableBucketedScan: Boolean = false, - wrapped: FileSourceScanExec) - extends DataSourceScanExec - with ShimCometScanExec - with CometPlan { - - // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests - override lazy val supportsColumnar: Boolean = - relation.fileFormat.supportBatch(relation.sparkSession, schema) - - override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes - - private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - private def sendDriverMetrics(): Unit = { - driverMetrics.foreach(e => metrics(e._1).add(e._2)) - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates( - sparkContext, - executionId, - metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) - } - - private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - - @transient lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - val startTime = System.nanoTime() - val ret = - relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) - setFilesNumAndSizeMetric(ret, true) - val timeTakenMs = - NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) - driverMetrics("metadataTime") = timeTakenMs - ret - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). 
- @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) - - if (dynamicPartitionFilters.nonEmpty) { - val startTime = System.nanoTime() - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, false) - val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 - driverMetrics("pruningTime") = timeTakenMs - ret - } else { - selectedPartitions - } - } - - // exposed for testing - lazy val bucketedScan: Boolean = wrapped.bucketedScan - - override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = - (wrapped.outputPartitioning, wrapped.outputOrdering) - - @transient - private lazy val pushedDownFilters = { - val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - } - - override lazy val metadata: Map[String, String] = - if (wrapped == null) Map.empty else wrapped.metadata - - override def verboseStringWithOperatorId(): String = { - getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) - wrapped.verboseStringWithOperatorId() - } - - lazy val inputRDD: RDD[InternalRow] = { - val options = relation.options + - (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = requiredSchema, - filters = pushedDownFilters, - options = options, - hadoopConf = - relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - - val readRDD = if (bucketedScan) { - createBucketedReadRDD( - relation.bucketSpec.get, - readFile, - dynamicallySelectedPartitions, - relation) - } else { - createReadRDD(readFile, dynamicallySelectedPartitions, relation) - } - sendDriverMetrics() - readRDD - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - inputRDD :: Nil - } - - /** Helper for computing total number and size of files in selected partitions. */ - private def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { - driverMetrics("numFiles") = filesNum - driverMetrics("filesSize") = filesSize - } else { - driverMetrics("staticFilesNum") = filesNum - driverMetrics("staticFilesSize") = filesSize - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions") = partitions.length - } - } - - override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { - // Tracking scan time has overhead, we can't afford to do it for each row, and can only do - // it for each batch. 
- if (supportsColumnar) { - Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { - None - } - } ++ { - relation.fileFormat match { - case f: MetricsSupport => f.initMetrics(sparkContext) - case _ => Map.empty - } - } - - protected override def doExecute(): RDD[InternalRow] = { - ColumnarToRowExec(this).doExecute() - } - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. - val startNs = System.nanoTime() - val res = batches.hasNext - scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) - res - } - - override def next(): ColumnarBatch = { - val batch = batches.next() - numOutputRows += batch.numRows() - batch - } - } - } - } - - override def executeCollect(): Array[InternalRow] = { - ColumnarToRowExec(this).executeCollect() - } - - override val nodeName: String = - s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" - - /** - * Create an RDD for bucketed reads. The non-bucketed variant of this function is - * [[createReadRDD]]. - * - * The algorithm is pretty simple: each RDD partition being returned should include all the - * files with the same bucket id from all the given Hive partitions. - * - * @param bucketSpec - * the bucketing spec. - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createBucketedReadRDD( - bucketSpec: BucketSpec, - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val filesGroupedToBuckets = - selectedPartitions - .flatMap { p => - p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) - } - } - .groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath.toString()).getName) - .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) - } - - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - filesGroupedToBuckets.filter { f => - bucketSet.get(f._1) - } - } else { - filesGroupedToBuckets - } - - val filePartitions = optionalNumCoalescedBuckets - .map { numCoalescedBuckets => - logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") - val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) - Seq.tabulate(numCoalescedBuckets) { bucketId => - val partitionedFiles = coalescedBuckets - .get(bucketId) - .map { - _.values.flatten.toArray - } - .getOrElse(Array.empty) - FilePartition(bucketId, partitionedFiles) - } - } - .getOrElse { - Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) - } - } - - prepareRDD(fsRelation, readFile, filePartitions) - } - - /** - * Create an RDD for non-bucketed reads. The bucketed variant of this function is - * [[createBucketedReadRDD]]. - * - * @param readFile - * a function to read each (part of a) file. 
- * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = - FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - - // Filter files with bucket pruning if possible - val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled - val shouldProcess: Path => Boolean = optionalBucketSet match { - case Some(bucketSet) if bucketingEnabled => - // Do not prune the file if bucket file name is invalid - filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) - case _ => - _ => true - } - - val splitFiles = selectedPartitions - .flatMap { partition => - partition.files.flatMap { file => - // getPath() is very expensive so we only want to call it once in this block: - val filePath = file.getPath - - if (shouldProcess(filePath)) { - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, - relation.options, - filePath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. - !isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) - } else { - Seq.empty - } - } - } - .sortBy(_.length)(implicitly[Ordering[Long]].reverse) - - prepareRDD( - fsRelation, - readFile, - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) - } - - private def prepareRDD( - fsRelation: HadoopFsRelation, - readFile: (PartitionedFile) => Iterator[InternalRow], - partitions: Seq[FilePartition]): RDD[InternalRow] = { - val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) - val prefetchEnabled = hadoopConf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - - val sqlConf = fsRelation.sparkSession.sessionState.conf - if (prefetchEnabled) { - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedConf = - fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val partitionReaderFactory = CometParquetPartitionReaderFactory( - sqlConf, - broadcastedConf, - requiredSchema, - relation.partitionSchema, - pushedDownFilters.toArray, - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), - metrics) - - newDataSourceRDD( - fsRelation.sparkSession.sparkContext, - partitions.map(Seq(_)), - partitionReaderFactory, - true, - Map.empty) - } else { - newFileScanRDD( - fsRelation.sparkSession, - readFile, - partitions, - new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) - } - } - - // Filters unused DynamicPruningExpression expressions - one which has been replaced - // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning - private def 
filterUnusedDynamicPruningExpressions( - predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) - } - - override def doCanonicalize(): CometScanExec = { - CometScanExec( - relation, - output.map(QueryPlan.normalizeExpressions(_, output)), - requiredSchema, - QueryPlan.normalizePredicates( - filterUnusedDynamicPruningExpressions(partitionFilters), - output), - optionalBucketSet, - optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), - None, - disableBucketedScan, - null) - } -} - -object CometScanExec { - def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - // TreeNode.mapProductIterator is protected method. - def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { - val arr = Array.ofDim[B](product.productArity) - var i = 0 - while (i < arr.length) { - arr(i) = f(product.productElement(i)) - i += 1 - } - arr - } - - // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues - // on other Spark distributions if FileSourceScanExec constructor is changed. - // Using `makeCopy` to avoid the issue. - // https://github.com/apache/arrow-datafusion-comet/issues/190 - def transform(arg: Any): AnyRef = arg match { - case _: HadoopFsRelation => - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) - case other: AnyRef => other - case null => null - } - val newArgs = mapProductIterator(scanExec, transform(_)) - val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] - val batchScanExec = CometScanExec( - wrapped.relation, - wrapped.output, - wrapped.requiredSchema, - wrapped.partitionFilters, - wrapped.optionalBucketSet, - wrapped.optionalNumCoalescedBuckets, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - wrapped) - scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) - batchScanExec - } -} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala deleted file mode 100644 index 8a017fabe..000000000 --- a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.spark.sql.comet - -import scala.collection.mutable.HashMap -import scala.concurrent.duration.NANOSECONDS -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection._ - -import org.apache.comet.{CometConf, MetricsSupport} -import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} - -/** - * Comet physical scan node for DataSource V1. Most of the code here follow Spark's - * [[FileSourceScanExec]], - */ -case class CometScanExec( - @transient relation: HadoopFsRelation, - output: Seq[Attribute], - requiredSchema: StructType, - partitionFilters: Seq[Expression], - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - dataFilters: Seq[Expression], - tableIdentifier: Option[TableIdentifier], - disableBucketedScan: Boolean = false, - wrapped: FileSourceScanExec) - extends DataSourceScanExec - with ShimCometScanExec - with CometPlan { - - // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests - override lazy val supportsColumnar: Boolean = - relation.fileFormat.supportBatch(relation.sparkSession, schema) - - override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes - - private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - private def sendDriverMetrics(): Unit = { - driverMetrics.foreach(e => metrics(e._1).add(e._2)) - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates( - sparkContext, - executionId, - metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) - } - - private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - - @transient lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - val startTime = System.nanoTime() - val ret = - relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) - setFilesNumAndSizeMetric(ret, true) - val timeTakenMs = - NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) - driverMetrics("metadataTime") = timeTakenMs - ret - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). 
- @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) - - if (dynamicPartitionFilters.nonEmpty) { - val startTime = System.nanoTime() - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, false) - val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 - driverMetrics("pruningTime") = timeTakenMs - ret - } else { - selectedPartitions - } - } - - // exposed for testing - lazy val bucketedScan: Boolean = wrapped.bucketedScan - - override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = - (wrapped.outputPartitioning, wrapped.outputOrdering) - - @transient - private lazy val pushedDownFilters = { - val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - } - - override lazy val metadata: Map[String, String] = - if (wrapped == null) Map.empty else wrapped.metadata - - override def verboseStringWithOperatorId(): String = { - getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) - wrapped.verboseStringWithOperatorId() - } - - lazy val inputRDD: RDD[InternalRow] = { - val options = relation.options + - (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = requiredSchema, - filters = pushedDownFilters, - options = options, - hadoopConf = - relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - - val readRDD = if (bucketedScan) { - createBucketedReadRDD( - relation.bucketSpec.get, - readFile, - dynamicallySelectedPartitions, - relation) - } else { - createReadRDD(readFile, dynamicallySelectedPartitions, relation) - } - sendDriverMetrics() - readRDD - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - inputRDD :: Nil - } - - /** Helper for computing total number and size of files in selected partitions. */ - private def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { - driverMetrics("numFiles") = filesNum - driverMetrics("filesSize") = filesSize - } else { - driverMetrics("staticFilesNum") = filesNum - driverMetrics("staticFilesSize") = filesSize - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions") = partitions.length - } - } - - override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { - // Tracking scan time has overhead, we can't afford to do it for each row, and can only do - // it for each batch. 
- if (supportsColumnar) { - Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { - None - } - } ++ { - relation.fileFormat match { - case f: MetricsSupport => f.initMetrics(sparkContext) - case _ => Map.empty - } - } - - protected override def doExecute(): RDD[InternalRow] = { - ColumnarToRowExec(this).doExecute() - } - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. - val startNs = System.nanoTime() - val res = batches.hasNext - scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) - res - } - - override def next(): ColumnarBatch = { - val batch = batches.next() - numOutputRows += batch.numRows() - batch - } - } - } - } - - override def executeCollect(): Array[InternalRow] = { - ColumnarToRowExec(this).executeCollect() - } - - override val nodeName: String = - s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" - - /** - * Create an RDD for bucketed reads. The non-bucketed variant of this function is - * [[createReadRDD]]. - * - * The algorithm is pretty simple: each RDD partition being returned should include all the - * files with the same bucket id from all the given Hive partitions. - * - * @param bucketSpec - * the bucketing spec. - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createBucketedReadRDD( - bucketSpec: BucketSpec, - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val filesGroupedToBuckets = - selectedPartitions - .flatMap { p => - p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, /*f.getPath,*/ p.values) - } - } - .groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath.toString()).getName) - .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) - } - - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - filesGroupedToBuckets.filter { f => - bucketSet.get(f._1) - } - } else { - filesGroupedToBuckets - } - - val filePartitions = optionalNumCoalescedBuckets - .map { numCoalescedBuckets => - logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") - val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) - Seq.tabulate(numCoalescedBuckets) { bucketId => - val partitionedFiles = coalescedBuckets - .get(bucketId) - .map { - _.values.flatten.toArray - } - .getOrElse(Array.empty) - FilePartition(bucketId, partitionedFiles) - } - } - .getOrElse { - Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) - } - } - - prepareRDD(fsRelation, readFile, filePartitions) - } - - /** - * Create an RDD for non-bucketed reads. The bucketed variant of this function is - * [[createBucketedReadRDD]]. - * - * @param readFile - * a function to read each (part of a) file. 
- * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = - FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - - // Filter files with bucket pruning if possible - val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled - val shouldProcess: Path => Boolean = optionalBucketSet match { - case Some(bucketSet) if bucketingEnabled => - // Do not prune the file if bucket file name is invalid - filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) - case _ => - _ => true - } - - val splitFiles = selectedPartitions - .flatMap { partition => - partition.files.flatMap { file => - // getPath() is very expensive so we only want to call it once in this block: - val filePath = file.getPath - - if (shouldProcess(filePath)) { - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, - relation.options, - filePath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. - !isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - //filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) - } else { - Seq.empty - } - } - } - .sortBy(_.length)(implicitly[Ordering[Long]].reverse) - - prepareRDD( - fsRelation, - readFile, - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) - } - - private def prepareRDD( - fsRelation: HadoopFsRelation, - readFile: (PartitionedFile) => Iterator[InternalRow], - partitions: Seq[FilePartition]): RDD[InternalRow] = { - val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) - val prefetchEnabled = hadoopConf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - - val sqlConf = fsRelation.sparkSession.sessionState.conf - if (prefetchEnabled) { - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedConf = - fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val partitionReaderFactory = CometParquetPartitionReaderFactory( - sqlConf, - broadcastedConf, - requiredSchema, - relation.partitionSchema, - pushedDownFilters.toArray, - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), - metrics) - - newDataSourceRDD( - fsRelation.sparkSession.sparkContext, - partitions.map(Seq(_)), - partitionReaderFactory, - true, - Map.empty) - } else { - newFileScanRDD( - fsRelation.sparkSession, - readFile, - partitions, - new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) - } - } - - // Filters unused DynamicPruningExpression expressions - one which has been replaced - // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning - private def 
filterUnusedDynamicPruningExpressions( - predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) - } - - override def doCanonicalize(): CometScanExec = { - CometScanExec( - relation, - output.map(QueryPlan.normalizeExpressions(_, output)), - requiredSchema, - QueryPlan.normalizePredicates( - filterUnusedDynamicPruningExpressions(partitionFilters), - output), - optionalBucketSet, - optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), - None, - disableBucketedScan, - null) - } -} - -object CometScanExec { - def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - // TreeNode.mapProductIterator is protected method. - def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { - val arr = Array.ofDim[B](product.productArity) - var i = 0 - while (i < arr.length) { - arr(i) = f(product.productElement(i)) - i += 1 - } - arr - } - - // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues - // on other Spark distributions if FileSourceScanExec constructor is changed. - // Using `makeCopy` to avoid the issue. - // https://github.com/apache/arrow-datafusion-comet/issues/190 - def transform(arg: Any): AnyRef = arg match { - case _: HadoopFsRelation => - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) - case other: AnyRef => other - case null => null - } - val newArgs = mapProductIterator(scanExec, transform(_)) - val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] - val batchScanExec = CometScanExec( - wrapped.relation, - wrapped.output, - wrapped.requiredSchema, - wrapped.partitionFilters, - wrapped.optionalBucketSet, - wrapped.optionalNumCoalescedBuckets, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - wrapped) - scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) - batchScanExec - } -} diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala diff --git a/spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala rename to spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala
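The comment blocks in the removed ParquetFilters.scala above carry the real reasoning behind partial push-down: one conjunct of an AND may be kept on its own only while we are still under top-level conjunctions, an OR is convertible only when both of its children are (possibly partially), and anything underneath a NOT must convert exactly or not at all. A minimal, dependency-free Scala sketch of those rules follows; Pred, Leaf and the string result are illustrative stand-ins, not Comet or Parquet types.

// Illustrative model of the push-down rules documented in the removed createFilterHelper.
sealed trait Pred
case class Leaf(name: String, convertible: Boolean) extends Pred
case class And(l: Pred, r: Pred) extends Pred
case class Or(l: Pred, r: Pred) extends Pred
case class Not(p: Pred) extends Pred

object PushDownSketch {
  // Returns Some(pushed-down form) when the predicate, or a safe subset of it, converts.
  def convert(p: Pred, canPartialPushDownConjuncts: Boolean): Option[String] = p match {
    case Leaf(name, ok) => if (ok) Some(name) else None
    case And(l, r) =>
      (convert(l, canPartialPushDownConjuncts), convert(r, canPartialPushDownConjuncts)) match {
        case (Some(lf), Some(rf)) => Some(s"($lf AND $rf)")
        // Dropping one conjunct is only safe while we are still under top-level ANDs.
        case (Some(lf), None) if canPartialPushDownConjuncts => Some(lf)
        case (None, Some(rf)) if canPartialPushDownConjuncts => Some(rf)
        case _ => None
      }
    case Or(l, r) =>
      // Both sides must convert for the OR to be pushed down at all.
      for {
        lf <- convert(l, canPartialPushDownConjuncts)
        rf <- convert(r, canPartialPushDownConjuncts)
      } yield s"($lf OR $rf)"
    case Not(inner) =>
      // Under a NOT, dropping a conjunct would change results, so partial push-down is disabled.
      convert(inner, canPartialPushDownConjuncts = false)
  }

  def main(args: Array[String]): Unit = {
    val a1 = Leaf("a1", convertible = true)
    val a2 = Leaf("a2", convertible = false)
    val b1 = Leaf("b1", convertible = true)
    val b2 = Leaf("b2", convertible = false)
    // (a1 AND a2) OR (b1 AND b2) keeps the convertible conjunct of each side: Some((a1 OR b1)).
    println(convert(Or(And(a1, a2), And(b1, b2)), canPartialPushDownConjuncts = true))
    // NOT(a1 AND a2) yields None: a2 may not be dropped underneath the NOT.
    println(convert(Not(And(a1, a2)), canPartialPushDownConjuncts = true))
  }
}

The two printed results match the (a1 AND a2) OR (b1 AND b2) example and the NOT(a = 2 AND b IN ('1')) counter-example given in the removed comments.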
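The removed StringStartsWith user-defined predicate prunes a row group when the column statistics prove that no value can begin with the requested prefix: it compares the prefix against the min and max values truncated to the prefix length, using unsigned lexicographic byte order. Below is a self-contained sketch of that test over raw byte arrays, assuming that ordering; unsignedCompare and canDrop are illustrative helpers, not the Parquet API.

object PrefixPruningSketch {
  // Unsigned, lexicographic byte comparison (the ordering assumed for Parquet binary statistics).
  private def unsignedCompare(a: Array[Byte], b: Array[Byte]): Int = {
    val len = math.min(a.length, b.length)
    var i = 0
    var cmp = 0
    while (i < len && cmp == 0) {
      cmp = (a(i) & 0xff) - (b(i) & 0xff)
      i += 1
    }
    if (cmp != 0) cmp else a.length - b.length
  }

  // True when no value in [min, max] can start with `prefix`, so the row group can be skipped.
  def canDrop(min: Array[Byte], max: Array[Byte], prefix: Array[Byte]): Boolean = {
    val size = prefix.length
    // Mirrors the removed predicate: compare the prefix against min/max truncated to the prefix
    // length. If even the largest value sorts below the prefix, or the smallest value sorts
    // above it, nothing in between can match.
    unsignedCompare(max.take(math.min(size, max.length)), prefix) < 0 ||
    unsignedCompare(min.take(math.min(size, min.length)), prefix) > 0
  }
}

For example, with min "apple", max "banana" and prefix "car", the truncated max "ban" sorts below "car", so canDrop returns true and the row group is skipped. The inverseCanDrop in the removed code is the stricter variant used under NOT: it only drops the group when both truncated bounds equal the prefix, i.e. when every value in the group is known to start with it.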
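createBucketedReadRDD in the removed CometScanExec.scala builds one output partition per bucket: every selected file is grouped by the bucket id parsed from its file name, and when optionalNumCoalescedBuckets is set the groups are folded together modulo that number. A self-contained sketch of just that grouping arithmetic, with plain strings standing in for PartitionedFile; the file-name pattern and the groupFiles helper are assumptions for illustration, not Spark's BucketingUtils.

object BucketGroupingSketch {
  // Bucket id parsed from a name such as "part-00000_00002.c000.snappy.parquet" -> 2.
  // The pattern is an assumption for this sketch, not Spark's actual bucket file naming logic.
  private val BucketIdPattern = "_(\\d+)\\.".r

  def bucketId(fileName: String): Option[Int] =
    BucketIdPattern.findFirstMatchIn(fileName).map(_.group(1).toInt)

  def groupFiles(
      files: Seq[String],
      numBuckets: Int,
      numCoalescedBuckets: Option[Int]): Seq[(Int, Seq[String])] = {
    val byBucket: Map[Int, Seq[String]] =
      files.groupBy(f => bucketId(f).getOrElse(sys.error(s"invalid bucket file: $f")))
    numCoalescedBuckets match {
      case Some(n) =>
        // Coalescing: the original bucket id modulo n picks the output partition,
        // mirroring the `_._1 % numCoalescedBuckets` grouping in the removed code.
        Seq.tabulate(n) { id =>
          id -> byBucket.collect { case (b, fs) if b % n == id => fs }.flatten.toSeq
        }
      case None =>
        // One output partition per bucket; buckets with no files become empty partitions.
        Seq.tabulate(numBuckets)(id => id -> byBucket.getOrElse(id, Seq.empty))
    }
  }
}

With numBuckets = 4 and numCoalescedBuckets = Some(2), files from buckets 0 and 2 land in output partition 0 and files from buckets 1 and 3 in partition 1, matching the behaviour of the removed code.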