diff --git a/pom.xml b/pom.xml index 4fb377954..b9001c9cd 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ under the License. spark-3.4-plus spark-3.x spark-3.4 + spark-pre-3.5 @@ -527,7 +528,9 @@ under the License. not-needed-yet not-needed-yet + spark-3.x spark-3.2 + spark-pre-3.5 @@ -539,7 +542,9 @@ under the License. 3.3 1.12.0 not-needed-yet + spark-3.x spark-3.3 + spark-pre-3.5 @@ -549,6 +554,25 @@ under the License. 2.12.17 3.4 1.13.1 + spark-3.x + spark-3.4 + spark-pre-3.5 + + + + + + spark-3.5 + + 2.12.17 + 3.5.1 + 3.5 + 1.13.1 + + + spark-3.x + spark-3.5 + not-needed-yet @@ -564,6 +588,7 @@ under the License. 1.13.1 spark-4.0 not-needed-yet + not-needed-yet 17 ${java.version} @@ -693,7 +718,6 @@ under the License. - --> ${scala.version} true true diff --git a/spark/pom.xml b/spark/pom.xml index 84e2e501f..8585b4b8c 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -254,6 +254,7 @@ under the License. src/test/${additional.3_4.test.source} src/test/${shims.majorVerSrc} src/test/${shims.minorVerSrc} + src/test/${shims.extraVerSrc} @@ -267,6 +268,7 @@ under the License. src/main/${shims.majorVerSrc} src/main/${shims.minorVerSrc} + src/main/${shims.extraVerSrc} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 3f4d7bfd3..a1b9a3677 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -41,10 +41,11 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan} import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -390,7 +391,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = UnsafeExternalRowSorter.createWithRecordComparator( - fromAttributes(outputAttributes), + StructType( + outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))), recordComparatorSupplier, prefixComparator, prefixComputer, @@ -434,8 +436,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, - schema = Some(fromAttributes(outputAttributes))) - + schema = Some(StructType(outputAttributes.map(a => + StructField(a.name, a.dataType, a.nullable, a.metadata))))) dependency } } diff --git 
a/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala new file mode 100644 index 000000000..91e86eea3 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.FileMetaData +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.RecordReaderIterator +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.{DateType, StructType, TimestampType} +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.CometConf +import org.apache.comet.MetricsSupport +import org.apache.comet.shims.ShimSQLConf +import org.apache.comet.vector.CometVector + +/** + * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's + * [[ParquetFileFormat]], but overrides: + * + * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap + * column vector in the whole-stage codegen path. + * - `supportBatch`, which simply returns true since data types should have already been checked + * in [[org.apache.comet.CometSparkSessionExtensions]] + * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. 
+ */ +class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { + override def shortName(): String = "parquet" + override def toString: String = "CometParquet" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] + + override def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + val length = requiredSchema.fields.length + partitionSchema.fields.length + Option(Seq.fill(length)(classOf[CometVector].getName)) + } + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = sparkSession.sessionState.conf + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val useFieldId = CometParquetUtils.readFieldId(sqlConf) + val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val optionsMap = CaseInsensitiveMap[String](options) + val parquetOptions = new ParquetOptions(optionsMap, sqlConf) + val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead + val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + (file: PartitionedFile) => { + val sharedConf = broadcastedHadoopConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + requiredSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can + // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a + // `flatMap` is used here. 
+ .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + + val batchReader = new BatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val iter = new RecordReaderIterator(batchReader) + try { + batchReader.init() + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + iter.close() + throw e + } + } + } +} + +object CometParquetFileFormat extends Logging { + + /** + * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` + */ + def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp) + + // Comet specific configs + hadoopConf.setBoolean( + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) + hadoopConf.setBoolean( + CometConf.COMET_USE_DECIMAL_128.key, + CometConf.COMET_USE_DECIMAL_128.get()) + hadoopConf.setBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) + } + + def getDatetimeRebaseSpec( + file: PartitionedFile, + sparkSchema: StructType, + sharedConf: Configuration, + footerFileMetaData: FileMetaData, + datetimeRebaseModeInRead: String): RebaseSpec = { + val exceptionOnRebase = sharedConf.getBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) + var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val hasDateOrTimestamp = sparkSchema.exists(f => + f.dataType match { + case DateType | TimestampType => true + case _ => false + }) + + if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { + if (exceptionOnRebase) { + logWarning( + s"""Found Parquet file $file that could potentially contain dates/timestamps that were + written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase + and return these according to the new Proleptic Gregorian calendar, Comet will throw + exception when reading them. If you want to read them as it is according to the hybrid + Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to + false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian + calendar, please disable Comet for this query.""") + } else { + // do not throw exception on rebase - read as it is + datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) + } + } + + datetimeRebaseSpec + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala new file mode 100644 index 000000000..787357531 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters +import scala.collection.mutable + +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.{CometConf, CometRuntimeException} +import org.apache.comet.shims.ShimSQLConf + +case class CometParquetPartitionReaderFactory( + @transient sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + readDataSchema: StructType, + partitionSchema: StructType, + filters: Array[Filter], + options: ParquetOptions, + metrics: Map[String, SQLMetric]) + extends FilePartitionReaderFactory + with ShimSQLConf + with Logging { + + private val isCaseSensitive = sqlConf.caseSensitiveAnalysis + private val useFieldId = CometParquetUtils.readFieldId(sqlConf) + private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + private val pushDownDate = sqlConf.parquetFilterPushDownDate + private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp + private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead + private val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + // This is only called at executor on a Broadcast variable, so we don't want it to be + // materialized at driver. + @transient private lazy val preFetchEnabled = { + val conf = broadcastedConf.value.value + + conf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + } + + private var cometReaders: Iterator[BatchReader] = _ + private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() + + // TODO: we may want to revisit this as we're going to only support flat types at the beginning + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + if (preFetchEnabled) { + val filePartition = partition.asInstanceOf[FilePartition] + val conf = broadcastedConf.value.value + + val threadNum = conf.getInt( + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) + val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) + + this.cometReaders = filePartition.files + .map { file => + // `init()` call is deferred to when the prefetch task begins. + // Otherwise we will hold too many resources for readers which are not ready + // to prefetch. + val cometReader = buildCometReader(file) + if (cometReader != null) { + cometReader.submitPrefetchTask(prefetchThreadPool) + } + + cometReader + } + .toSeq + .toIterator + } + + super.createColumnarReader(partition) + } + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") + + private def buildCometReader(file: PartitionedFile): BatchReader = { + val conf = broadcastedConf.value.value + + try { + val (datetimeRebaseSpec, footer, filters) = getFilter(file) + filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) + val cometReader = new BatchReader( + conf, + file, + footer, + batchSize, + readDataSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val taskContext = Option(TaskContext.get) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) + return cometReader + } catch { + case e: Throwable if preFetchEnabled => + // Keep original exception + cometReaderExceptionMap.put(file, e) + } + null + } + + override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { + val cometReader = if (!preFetchEnabled) { + // Prefetch is not enabled, create comet reader and initiate it. + val cometReader = buildCometReader(file) + cometReader.init() + + cometReader + } else { + // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
+ // It is possibly we got an exception like `FileNotFoundException` and we need to throw it + // now to let Spark handle it. + val reader = cometReaders.next() + val exception = cometReaderExceptionMap.get(file) + exception.foreach(e => throw e) + + if (reader == null) { + throw new CometRuntimeException(s"Cannot find comet file reader for $file") + } + reader + } + CometPartitionReader(cometReader) + } + + def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { + val sharedConf = broadcastedConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + readDataSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + (datetimeRebaseSpec, footer, pushed) + } + + override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") + + /** + * A simple adapter on Comet's [[BatchReader]]. + */ + protected case class CometPartitionReader(reader: BatchReader) + extends PartitionReader[ColumnarBatch] { + + override def next(): Boolean = { + reader.nextBatch() + } + + override def get(): ColumnarBatch = { + reader.currentBatch() + } + + override def close(): Unit = { + reader.close() + } + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala new file mode 100644 index 000000000..3cd434418 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate, Period} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus + +/** + * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove + * this duplication + * + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class ParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStringPredicate: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + datetimeRebaseSpec: RebaseSpec) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + // Parquet only supports predicate push-down for non-repeated primitive types. + // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for + // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) + case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getLogicalTypeAnnotation, + p.getPrimitiveTypeName, + p.getTypeLength))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. 
+ case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + logicalTypeAnnotation: LogicalTypeAnnotation, + primitiveTypeName: PrimitiveTypeName, + length: Int) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) + private val ParquetByteType = + ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) + private val ParquetShortType = + ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) + private val ParquetStringType = + ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) + private val ParquetDateType = + ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) + private val ParquetTimestampMicrosType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) + private val ParquetTimestampMillisType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) + + private def dateToDays(date: Any): Int = { + val gregorianDays = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) + case _ => gregorianDays + } + } + + private def timestampToMicros(v: Any): JLong = { + val gregorianMicros = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => + rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) + case _ => gregorianMicros + } + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = 
decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = DateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private def toIntValue(v: Any): Integer = { + Option(v) + .map { + case p: Period => IntervalUtils.periodToMonths(p) + case n => n.asInstanceOf[Number].intValue + } + .map(_.asInstanceOf[Integer]) + .orNull + } + + private def toLongValue(v: Any): JLong = v match { + case d: Duration => IntervalUtils.durationToMicros(d) + case l => l.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if 
pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeInPredicate: PartialFunction[ + ParquetSchemaType, + (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toIntValue(_).toInt).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case 
ParquetLongType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toLongValue).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetFloatType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), + FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) + + case ParquetDoubleType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), + FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) + + case ParquetStringType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetBinaryType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetDateType if pushDownDate => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMicros).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMillis).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => + if (isSpark34Plus) { + value match { + // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type + // Int. We don't create a filter if the value would overflow. + case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } + } else { + // If not Spark 3.4+, we still following the old behavior as Spark does. 
+ value.isInstanceOf[Number] + } + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => + value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType( + decimalType: DecimalLogicalTypeAnnotation, + FIXED_LEN_BYTE_ARRAY, + _) => + isDecimalMatched(value, decimalType) + case _ => false + }) + } + + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + private def isDecimalMatched( + value: Any, + decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalLogicalType.getScale + case _ => false + } + + private def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) + } + + /** + * @param predicate + * the input filter predicates. Not all the predicates can be pushed down. + * @param canPartialPushDownConjuncts + * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one + * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. + * @return + * the Parquet-native filter predicates that are eligible for pushdown. + */ + private def createFilterHelper( + predicate: sources.Filter, + canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `PruneFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. 
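For orientation before the match below: this is how the rest of this patch drives the class. Both CometParquetFileFormat and CometParquetPartitionReaderFactory build a ParquetFilters from the footer schema, run the data source filters through createFilter, AND the survivors together, and hand the result to ParquetInputFormat. A minimal sketch of that call shape follows; the object and method names, the column names "x" and "name", the threshold of 10, and the single-argument RebaseSpec constructor are illustrative assumptions, not part of the patch.

import org.apache.hadoop.conf.Configuration
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.parquet.schema.MessageType
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
import org.apache.spark.sql.sources

import org.apache.comet.parquet.ParquetFilters

object ParquetPushDownSketch {
  // Sketch: convert two simple source filters against a Parquet footer schema and push them down.
  def pushDownExample(parquetSchema: MessageType, hadoopConf: Configuration): Unit = {
    val parquetFilters = new ParquetFilters(
      parquetSchema,
      pushDownDate = true,
      pushDownTimestamp = true,
      pushDownDecimal = true,
      pushDownStringPredicate = true,
      pushDownInFilterThreshold = 10,
      caseSensitive = false,
      datetimeRebaseSpec = RebaseSpec(LegacyBehaviorPolicy.CORRECTED))

    val pushed: Option[FilterPredicate] =
      Seq(sources.GreaterThan("x", 5), sources.IsNotNull("name"))
        .flatMap(parquetFilters.createFilter) // filters that cannot be converted drop out here
        .reduceOption(FilterApi.and)

    // Same call the two readers earlier in this patch make once a predicate survives conversion.
    pushed.foreach(p => ParquetInputFormat.setFilterPredicate(hadoopConf, p))
  }
}

The match that follows implements createFilter case by case for each supported source filter.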
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if pushDownInFilterThreshold > 0 && values.nonEmpty && + canMakeFilterOn(name, values.head) => + val fieldType = nameToParquetField(name).fieldType + val fieldNames = nameToParquetField(name).fieldNames + if (values.length <= pushDownInFilterThreshold) { + values.distinct + .flatMap { v => + makeEq.lift(fieldType).map(_(fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + } else if (canPartialPushDownConjuncts) { + val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType + val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) + if (values.contains(null)) { + Seq( + makeEq.lift(fieldType).map(_(fieldNames, null)), + makeInPredicate + .lift(fieldType) + .map(_(fieldNames, values.filter(_ != null), statistics))).flatten + .reduceLeftOption(FilterApi.or) + } else { + makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) + } + } else { + None + } + + case sources.StringStartsWith(name, prefix) + if pushDownStringPredicate && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case sources.StringEndsWith(name, suffix) + if pushDownStringPredicate && canMakeFilterOn(name, suffix) => + Option(suffix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val suffixStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) + } + }) + } + + case sources.StringContains(name, value) + if pushDownStringPredicate && canMakeFilterOn(name, value) => + Option(value).map { v => + FilterApi.userDefined( + 
binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val subStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 000000000..409e1c94b --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(unhex.failOnError)) + } +} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala new file mode 100644 index 000000000..777248b41 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable.HashMap +import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection._ + +import org.apache.comet.{CometConf, MetricsSupport} +import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} +import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} + +/** + * Comet physical scan node for DataSource V1. Most of the code here follow Spark's + * [[FileSourceScanExec]], + */ +case class CometScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false, + wrapped: FileSourceScanExec) + extends DataSourceScanExec + with ShimCometScanExec + with CometPlan { + + // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests + override lazy val supportsColumnar: Boolean = + relation.fileFormat.supportBatch(relation.sparkSession, schema) + + override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = + NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). 
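To make the comment above concrete, here is a Spark-free sketch of the idea: partitions are listed once, and a predicate that only becomes known at runtime (for example the keys observed on the other side of a join) is then used to prune the listed partition directories. `PartitionDir`, `prune`, and the column/key names are illustrative, not Comet APIs.

```scala
object DynamicPruningSketch {
  // A stand-in for a listed partition directory: its partition-column values
  // plus the data files underneath it.
  final case class PartitionDir(values: Map[String, Any], files: Seq[String])

  // Prune the statically listed partitions with keys learned at runtime.
  def prune(
      selected: Seq[PartitionDir],
      column: String,
      allowedKeys: Set[Any]): Seq[PartitionDir] =
    selected.filter(p => p.values.get(column).exists(allowedKeys.contains))

  def main(args: Array[String]): Unit = {
    val listed = Seq(
      PartitionDir(Map("dt" -> "2024-01-01"), Seq("part-0.parquet")),
      PartitionDir(Map("dt" -> "2024-01-02"), Seq("part-1.parquet")))
    // Keys discovered at runtime, e.g. from the build side of a join:
    val runtimeKeys: Set[Any] = Set("2024-01-02")
    println(prune(listed, "dt", runtimeKeys).flatMap(_.files)) // List(part-1.parquet)
  }
}
```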
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = wrapped.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (wrapped.outputPartitioning, wrapped.outputOrdering) + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = + if (wrapped == null) Map.empty else wrapped.metadata + + override def verboseStringWithOperatorId(): String = { + getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) + wrapped.verboseStringWithOperatorId() + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
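The comment immediately above is the rationale for the per-batch timing pattern used further down in `doExecuteColumnar`. A generic, self-contained sketch of that pattern follows (assumed names `Batch`, `timed`, and `addScanTimeMs` are illustrative): time is charged once per batch, inside `hasNext`, because that is where the underlying file scan actually advances.

```scala
object ScanTimeSketch {
  final case class Batch(numRows: Int)

  // Wrap a batch iterator so the time spent pulling the next batch is
  // reported to a metric callback once per batch, not once per row.
  def timed(batches: Iterator[Batch], addScanTimeMs: Long => Unit): Iterator[Batch] =
    new Iterator[Batch] {
      override def hasNext: Boolean = {
        val start = System.nanoTime()
        val res = batches.hasNext // the underlying scan does its work here
        addScanTimeMs((System.nanoTime() - start) / 1000000L)
        res
      }
      override def next(): Batch = batches.next()
    }

  def main(args: Array[String]): Unit = {
    var scanTimeMs = 0L
    val out = timed(Iterator(Batch(1024), Batch(512)), (ms: Long) => scanTimeMs += ms)
    println(out.map(_.numRows).sum)      // 1536
    println(s"scan time (ms): $scanTimeMs")
  }
}
```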
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + protected override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, /*f.getPath,*/ p.values) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. 
+ * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. + !isNeededForSchema(requiredSchema) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + //filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + newDataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation.sparkSession, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def 
filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): CometScanExec = { + CometScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), + output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan, + null) + } +} + +object CometScanExec { + def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { + // TreeNode.mapProductIterator is protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues + // on other Spark distributions if FileSourceScanExec constructor is changed. + // Using `makeCopy` to avoid the issue. + // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] + val batchScanExec = CometScanExec( + wrapped.relation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + wrapped) + scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) + batchScanExec + } +} \ No newline at end of file diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala new file mode 100644 index 000000000..543116c10 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} +import org.apache.spark.sql.types.StructType +import org.apache.spark.SparkContext + +trait ShimCometScanExec { + def wrapped: FileSourceScanExec + + lazy val fileConstantMetadataColumns: Seq[AttributeReference] = + wrapped.fileConstantMetadataColumns + + protected def newDataSourceRDD( + sc: SparkContext, + inputPartitions: Seq[Seq[InputPartition]], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean, + customMetrics: Map[String, SQLMetric]): DataSourceRDD = + new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, customMetrics) + + protected def newFileScanRDD( + fsRelation: HadoopFsRelation, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readSchema: StructType, + options: ParquetOptions): FileScanRDD = { + new FileScanRDD( + fsRelation.sparkSession, + readFunction, + filePartitions, + readSchema, + fileConstantMetadataColumns, + fsRelation.fileFormat.fileConstantMetadataExtractors, + options) + } + + protected def invalidBucketFile(path: String, sparkVersion: String): Throwable = + QueryExecutionErrors.invalidBucketFile(path) + + // see SPARK-39634 + protected def isNeededForSchema(sparkSchema: StructType): Boolean = false + + protected def getPartitionedFile(f: FileStatusWithMetadata, p: PartitionDirectory): PartitionedFile = + PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen) + + protected def splitFiles(sparkSession: SparkSession, + file: FileStatusWithMetadata, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow): Seq[PartitionedFile] = + PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues) +} diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala similarity index 100% rename from spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala similarity index 100% rename from spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala similarity index 100% rename from spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala rename to 
spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala similarity index 100% rename from spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala rename to spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala similarity index 100% rename from spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala rename to spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
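These renames move the previously shared sources under `spark-pre-3.5`, while the new `spark-3.5` directory carries variants built against the Spark 3.5 APIs (for example, `ShimCometScanExec` above works with `FileStatusWithMetadata` and the 3.5 `PartitionedFileUtil` signatures). A minimal sketch of the shim pattern this source layout supports is below; the trait, classes, and descriptions are illustrative, not the actual Comet shims. Version-neutral code depends only on the trait, and the Maven profile (via the `shims.*VerSrc` properties) decides which source directory, and therefore which implementation, is compiled in.

```scala
// Version-neutral code programs against the shim trait only.
trait ShimPartitionedFileFactory {
  def partitionedFileDescription: String
}

// Hypothetical implementation that would live under src/main/spark-pre-3.5/...
class Pre35Shim extends ShimPartitionedFileFactory {
  def partitionedFileDescription: String = "pre-3.5 style (illustrative)"
}

// Hypothetical implementation that would live under src/main/spark-3.5/...
class Spark35Shim extends ShimPartitionedFileFactory {
  def partitionedFileDescription: String = "Spark 3.5 style (illustrative)"
}

object ShimDemo extends App {
  // In the real build this choice is made by the Maven profile, not at runtime.
  val shim: ShimPartitionedFileFactory = new Spark35Shim
  println(shim.partitionedFileDescription)
}
```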