From d3a236a27303df7094436fbc3e3390d382de4b15 Mon Sep 17 00:00:00 2001 From: awol2005ex Date: Wed, 29 May 2024 10:04:01 +0800 Subject: [PATCH 01/12] Profile spark3.5.1 and centos7 for compatible on spark 3.5.1 and centos7 old glic 2.7 --- build_for_centos7.sh | 5 + core/Cross.toml | 2 + core/comet_build_env_centos7.dockerfile | 36 + pom.xml | 21 +- ...scala => CometParquetFileFormat.scala.old} | 0 ...etParquetPartitionReaderFactory.scala.old} | 0 ...Filters.scala => ParquetFilters.scala.old} | 0 ...ScanExec.scala => CometScanExec.scala.old} | 0 .../spark/sql/comet/DecimalPrecision.scala | 8 +- .../shuffle/CometShuffleExchangeExec.scala | 8 +- .../parquet/CometParquetFileFormat.scala | 231 +++++ .../CometParquetPartitionReaderFactory.scala | 231 +++++ .../apache/comet/parquet/ParquetFilters.scala | 882 ++++++++++++++++++ .../spark/sql/comet/CometScanExec.scala | 483 ++++++++++ .../parquet/CometParquetFileFormat.scala | 231 +++++ .../CometParquetPartitionReaderFactory.scala | 231 +++++ .../apache/comet/parquet/ParquetFilters.scala | 882 ++++++++++++++++++ .../spark/sql/comet/CometScanExec.scala | 483 ++++++++++ .../parquet/CometParquetFileFormat.scala | 231 +++++ .../CometParquetPartitionReaderFactory.scala | 231 +++++ .../apache/comet/parquet/ParquetFilters.scala | 882 ++++++++++++++++++ .../spark/sql/comet/CometScanExec.scala | 483 ++++++++++ .../parquet/CometParquetFileFormat.scala | 231 +++++ .../CometParquetPartitionReaderFactory.scala | 231 +++++ .../apache/comet/parquet/ParquetFilters.scala | 882 ++++++++++++++++++ .../apache/comet/shims/CometExprShim.scala | 33 + .../spark/sql/comet/CometScanExec.scala | 483 ++++++++++ 27 files changed, 7413 insertions(+), 8 deletions(-) create mode 100644 build_for_centos7.sh create mode 100644 core/Cross.toml create mode 100644 core/comet_build_env_centos7.dockerfile rename spark/src/main/scala/org/apache/comet/parquet/{CometParquetFileFormat.scala => CometParquetFileFormat.scala.old} (100%) rename spark/src/main/scala/org/apache/comet/parquet/{CometParquetPartitionReaderFactory.scala => CometParquetPartitionReaderFactory.scala.old} (100%) rename spark/src/main/scala/org/apache/comet/parquet/{ParquetFilters.scala => ParquetFilters.scala.old} (100%) rename spark/src/main/scala/org/apache/spark/sql/comet/{CometScanExec.scala => CometScanExec.scala.old} (100%) create mode 100644 spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala create mode 100644 spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala create mode 100644 spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala create mode 100644 spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala create mode 100644 spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala create mode 100644 spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala create mode 100644 spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala create mode 100644 spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala create mode 100644 spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala create mode 100644 spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala create mode 100644 spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala create mode 100644 spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala create mode 100644 
spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala create mode 100644 spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala create mode 100644 spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala create mode 100644 spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala create mode 100644 spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala diff --git a/build_for_centos7.sh b/build_for_centos7.sh new file mode 100644 index 000000000..ccd257162 --- /dev/null +++ b/build_for_centos7.sh @@ -0,0 +1,5 @@ +docker build -t comet_build_env_centos7:1.0 -f core/comet_build_env_centos7.dockerfile +export PROFILES="-Dmaven.test.skip=true -Pjdk1.8 -Pspark-3.5 -Dscalastyle.skip -Dspotless.check.skip=true -Denforcer.skip -Prelease-centos7 -Drat.skip=true" +cd core && RUSTFLAGS="-Ctarget-cpu=native" cross build --target x86_64-unknown-linux-gnu --features nightly --release +cd .. +mvn install -Prelease -DskipTests ${PROFILES} \ No newline at end of file diff --git a/core/Cross.toml b/core/Cross.toml new file mode 100644 index 000000000..cbf43363b --- /dev/null +++ b/core/Cross.toml @@ -0,0 +1,2 @@ +[target.x86_64-unknown-linux-gnu] +image = "comet_build_env_centos7:1.0" \ No newline at end of file diff --git a/core/comet_build_env_centos7.dockerfile b/core/comet_build_env_centos7.dockerfile new file mode 100644 index 000000000..026facd49 --- /dev/null +++ b/core/comet_build_env_centos7.dockerfile @@ -0,0 +1,36 @@ +FROM centos:7 +RUN ulimit -n 65536 + +# install common tools +#RUN yum update -y +RUN yum install -y centos-release-scl epel-release +RUN rm -f /var/lib/rpm/__db* && rpm --rebuilddb +RUN yum clean all && rm -rf /var/cache/yum +RUN yum makecache +RUN yum install -y -v git +RUN yum install -y -v libzip unzip wget cmake3 openssl-devel +# install protoc +RUN wget -O /protobuf-21.7-linux-x86_64.zip https://github.com/protocolbuffers/protobuf/releases/download/v21.7/protoc-21.7-linux-x86_64.zip +RUN mkdir /protobuf-bin && (cd /protobuf-bin && unzip /protobuf-21.7-linux-x86_64.zip) +RUN echo 'export PATH="$PATH:/protobuf-bin/bin"' >> ~/.bashrc + + +# install gcc-11 +RUN yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++ +RUN echo '. /opt/rh/devtoolset-11/enable' >> ~/.bashrc + +# install rust nightly toolchain +RUN curl https://sh.rustup.rs > /rustup-init +RUN chmod +x /rustup-init +RUN /rustup-init -y --default-toolchain nightly-2023-08-01-x86_64-unknown-linux-gnu + +RUN echo 'source $HOME/.cargo/env' >> ~/.bashrc + +# install java +RUN yum install -y java-1.8.0-openjdk java-1.8.0-openjdk-devel +RUN echo 'export JAVA_HOME="/usr/lib/jvm/java-1.8.0-openjdk"' >> ~/.bashrc + +# install maven +RUN yum install -y rh-maven35 +RUN echo 'source /opt/rh/rh-maven35/enable' >> ~/.bashrc +RUN yum -y install gcc automake autoconf libtool make diff --git a/pom.xml b/pom.xml index 59e0569ff..1f1f00ecc 100644 --- a/pom.xml +++ b/pom.xml @@ -421,6 +421,13 @@ under the License. + + release-centos7 + + ${project.basedir}/../core/target/x86_64-unknown-linux-gnu/release + + + Win-x86 @@ -530,6 +537,19 @@ under the License. + + spark-3.5 + + 2.12.17 + 3.5.1 + 3.5 + 1.13.1 + spark-3.3-plus + spark-3.5 + spark-3.5 + + + scala-2.13 @@ -652,7 +672,6 @@ under the License. 
- --> ${scala.version} true true diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala.old similarity index 100% rename from spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala rename to spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala.old diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala.old similarity index 100% rename from spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala rename to spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala.old diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala.old similarity index 100% rename from spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala rename to spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala.old diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala.old similarity index 100% rename from spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala rename to spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala.old diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala index 13f26ce58..acd86c1ac 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala @@ -59,7 +59,7 @@ object DecimalPrecision { } CheckOverflow(add, resultType, nullOnOverflow) - case sub @ Subtract(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultScale = max(s1, s2) val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) @@ -68,7 +68,7 @@ object DecimalPrecision { } CheckOverflow(sub, resultType, nullOnOverflow) - case mul @ Multiply(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) } else { @@ -76,7 +76,7 @@ object DecimalPrecision { } CheckOverflow(mul, resultType, nullOnOverflow) - case div @ Divide(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultType = if (allowPrecisionLoss) { // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) // Scale: max(6, s1 + p2 + 1) @@ -96,7 +96,7 @@ object DecimalPrecision { } CheckOverflow(div, resultType, nullOnOverflow) - case rem @ Remainder(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) => + case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) => val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 49c263f3f..d73b0d908 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.StructField import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -386,7 +387,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = UnsafeExternalRowSorter.createWithRecordComparator( - StructType.fromAttributes(outputAttributes), + //StructType.fromAttributes(outputAttributes), + StructType(outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))), recordComparatorSupplier, prefixComparator, prefixComputer, @@ -430,8 +432,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, - schema = Some(StructType.fromAttributes(outputAttributes))) - + //schema = Some(StructType.fromAttributes(outputAttributes))) + schema = Some(StructType(outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))) dependency } } diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala new file mode 100644 index 000000000..ac871cf60 --- /dev/null +++ b/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.FileMetaData +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.RecordReaderIterator +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.{DateType, StructType, TimestampType} +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.CometConf +import org.apache.comet.MetricsSupport +import org.apache.comet.shims.ShimSQLConf +import org.apache.comet.vector.CometVector + +/** + * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's + * [[ParquetFileFormat]], but overrides: + * + * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap + * column vector in the whole-stage codegen path. + * - `supportBatch`, which simply returns true since data types should have already been checked + * in [[org.apache.comet.CometSparkSessionExtensions]] + * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. 
+ */ +class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { + override def shortName(): String = "parquet" + override def toString: String = "CometParquet" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] + + override def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + val length = requiredSchema.fields.length + partitionSchema.fields.length + Option(Seq.fill(length)(classOf[CometVector].getName)) + } + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = sparkSession.sessionState.conf + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val useFieldId = CometParquetUtils.readFieldId(sqlConf) + val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val optionsMap = CaseInsensitiveMap[String](options) + val parquetOptions = new ParquetOptions(optionsMap, sqlConf) + val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead + val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + (file: PartitionedFile) => { + val sharedConf = broadcastedHadoopConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + requiredSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can + // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a + // `flatMap` is used here. 
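The comment above captures the pushdown contract: `ParquetFilters.createFilter` returns an `Option`, so `flatMap` silently drops every source filter that cannot be converted, and `reduceOption(FilterApi.and)` conjoins whatever remains (yielding `None` when nothing converts, in which case no predicate is set). A minimal, dependency-free Scala sketch of that pattern; `SimplePred`, `Eq`, `And`, and `convert` are illustrative stand-ins, not the real Parquet `FilterPredicate` API:

object PushdownConjunctionSketch {
  sealed trait SimplePred
  final case class Eq(column: String, value: Int) extends SimplePred
  final case class And(left: SimplePred, right: SimplePred) extends SimplePred

  // Pretend only filters on column "a" are convertible; everything else yields None.
  def convert(filter: (String, Int)): Option[SimplePred] =
    if (filter._1 == "a") Some(Eq(filter._1, filter._2)) else None

  def main(args: Array[String]): Unit = {
    val sourceFilters = Seq(("a", 1), ("b", 2), ("a", 3))
    // flatMap keeps only the convertible filters; reduceOption ANDs them together.
    val pushed: Option[SimplePred] = sourceFilters.flatMap(convert).reduceOption(And(_, _))
    println(pushed) // Some(And(Eq(a,1),Eq(a,3))) -- the filter on "b" is simply not pushed down
  }
}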
+ .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + + val batchReader = new BatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val iter = new RecordReaderIterator(batchReader) + try { + batchReader.init() + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + iter.close() + throw e + } + } + } +} + +object CometParquetFileFormat extends Logging { + + /** + * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` + */ + def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp) + + // Comet specific configs + hadoopConf.setBoolean( + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) + hadoopConf.setBoolean( + CometConf.COMET_USE_DECIMAL_128.key, + CometConf.COMET_USE_DECIMAL_128.get()) + hadoopConf.setBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) + } + + def getDatetimeRebaseSpec( + file: PartitionedFile, + sparkSchema: StructType, + sharedConf: Configuration, + footerFileMetaData: FileMetaData, + datetimeRebaseModeInRead: String): RebaseSpec = { + val exceptionOnRebase = sharedConf.getBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) + var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val hasDateOrTimestamp = sparkSchema.exists(f => + f.dataType match { + case DateType | TimestampType => true + case _ => false + }) + + if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { + if (exceptionOnRebase) { + logWarning( + s"""Found Parquet file $file that could potentially contain dates/timestamps that were + written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase + and return these according to the new Proleptic Gregorian calendar, Comet will throw + exception when reading them. If you want to read them as it is according to the hybrid + Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to + false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian + calendar, please disable Comet for this query.""") + } else { + // do not throw exception on rebase - read as it is + datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) + } + } + + datetimeRebaseSpec + } +} diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala new file mode 100644 index 000000000..693af125b --- /dev/null +++ b/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters +import scala.collection.mutable + +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.{CometConf, CometRuntimeException} +import org.apache.comet.shims.ShimSQLConf + +case class CometParquetPartitionReaderFactory( + @transient sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + readDataSchema: StructType, + partitionSchema: StructType, + filters: Array[Filter], + options: ParquetOptions, + metrics: Map[String, SQLMetric]) + extends FilePartitionReaderFactory + with ShimSQLConf + with Logging { + + private val isCaseSensitive = sqlConf.caseSensitiveAnalysis + private val useFieldId = CometParquetUtils.readFieldId(sqlConf) + private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + private val pushDownDate = sqlConf.parquetFilterPushDownDate + private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp + private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead + private val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + // This is only called at executor on a Broadcast variable, so we don't want it to be + // materialized at driver. + @transient private lazy val preFetchEnabled = { + val conf = broadcastedConf.value.value + + conf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + } + + private var cometReaders: Iterator[BatchReader] = _ + private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() + + // TODO: we may want to revisit this as we're going to only support flat types at the beginning + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + if (preFetchEnabled) { + val filePartition = partition.asInstanceOf[FilePartition] + val conf = broadcastedConf.value.value + + val threadNum = conf.getInt( + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) + val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) + + this.cometReaders = filePartition.files + .map { file => + // `init()` call is deferred to when the prefetch task begins. + // Otherwise we will hold too many resources for readers which are not ready + // to prefetch. + val cometReader = buildCometReader(file) + if (cometReader != null) { + cometReader.submitPrefetchTask(prefetchThreadPool) + } + + cometReader + } + .toSeq + .toIterator + } + + super.createColumnarReader(partition) + } + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") + + private def buildCometReader(file: PartitionedFile): BatchReader = { + val conf = broadcastedConf.value.value + + try { + val (datetimeRebaseSpec, footer, filters) = getFilter(file) + filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) + val cometReader = new BatchReader( + conf, + file, + footer, + batchSize, + readDataSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val taskContext = Option(TaskContext.get) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) + return cometReader + } catch { + case e: Throwable if preFetchEnabled => + // Keep original exception + cometReaderExceptionMap.put(file, e) + } + null + } + + override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { + val cometReader = if (!preFetchEnabled) { + // Prefetch is not enabled, create comet reader and initiate it. + val cometReader = buildCometReader(file) + cometReader.init() + + cometReader + } else { + // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
+ // It is possibly we got an exception like `FileNotFoundException` and we need to throw it + // now to let Spark handle it. + val reader = cometReaders.next() + val exception = cometReaderExceptionMap.get(file) + exception.foreach(e => throw e) + + if (reader == null) { + throw new CometRuntimeException(s"Cannot find comet file reader for $file") + } + reader + } + CometPartitionReader(cometReader) + } + + def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { + val sharedConf = broadcastedConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + readDataSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + (datetimeRebaseSpec, footer, pushed) + } + + override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") + + /** + * A simple adapter on Comet's [[BatchReader]]. + */ + protected case class CometPartitionReader(reader: BatchReader) + extends PartitionReader[ColumnarBatch] { + + override def next(): Boolean = { + reader.nextBatch() + } + + override def get(): ColumnarBatch = { + reader.currentBatch() + } + + override def close(): Unit = { + reader.close() + } + } +} diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala new file mode 100644 index 000000000..5994dfb41 --- /dev/null +++ b/spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate, Period} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus + +/** + * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove + * this duplication + * + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class ParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStringPredicate: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + datetimeRebaseSpec: RebaseSpec) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + // Parquet only supports predicate push-down for non-repeated primitive types. + // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for + // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) + case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getLogicalTypeAnnotation, + p.getPrimitiveTypeName, + p.getTypeLength))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. 
+ case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + logicalTypeAnnotation: LogicalTypeAnnotation, + primitiveTypeName: PrimitiveTypeName, + length: Int) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) + private val ParquetByteType = + ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) + private val ParquetShortType = + ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) + private val ParquetStringType = + ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) + private val ParquetDateType = + ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) + private val ParquetTimestampMicrosType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) + private val ParquetTimestampMillisType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) + + private def dateToDays(date: Any): Int = { + val gregorianDays = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) + case _ => gregorianDays + } + } + + private def timestampToMicros(v: Any): JLong = { + val gregorianMicros = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => + rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) + case _ => gregorianMicros + } + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = 
decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = DateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private def toIntValue(v: Any): Integer = { + Option(v) + .map { + case p: Period => IntervalUtils.periodToMonths(p) + case n => n.asInstanceOf[Number].intValue + } + .map(_.asInstanceOf[Integer]) + .orNull + } + + private def toLongValue(v: Any): JLong = v match { + case d: Duration => IntervalUtils.durationToMicros(d) + case l => l.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if 
pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeInPredicate: PartialFunction[ + ParquetSchemaType, + (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toIntValue(_).toInt).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case 
ParquetLongType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toLongValue).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetFloatType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), + FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) + + case ParquetDoubleType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), + FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) + + case ParquetStringType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetBinaryType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetDateType if pushDownDate => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMicros).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMillis).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => + if (isSpark34Plus) { + value match { + // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type + // Int. We don't create a filter if the value would overflow. + case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } + } else { + // If not Spark 3.4+, we still following the old behavior as Spark does. 
+ value.isInstanceOf[Number] + } + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => + value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType( + decimalType: DecimalLogicalTypeAnnotation, + FIXED_LEN_BYTE_ARRAY, + _) => + isDecimalMatched(value, decimalType) + case _ => false + }) + } + + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + private def isDecimalMatched( + value: Any, + decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalLogicalType.getScale + case _ => false + } + + private def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) + } + + /** + * @param predicate + * the input filter predicates. Not all the predicates can be pushed down. + * @param canPartialPushDownConjuncts + * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one + * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. + * @return + * the Parquet-native filter predicates that are eligible for pushdown. + */ + private def createFilterHelper( + predicate: sources.Filter, + canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `PruneFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. 
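To make the decimal scale check above concrete, here is a minimal self-contained sketch (the `DecimalScaleCheck` object and `scaleMatches` helper are illustrative names, not part of the patch): a decimal filter value is only considered pushable when its scale equals the scale recorded in the file's Parquet decimal annotation.

    import java.math.{BigDecimal => JBigDecimal}

    import org.apache.parquet.schema.LogicalTypeAnnotation
    import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation

    // Illustrative sketch: mismatched scales are rejected so the unscaled comparison against
    // file statistics cannot silently return wrong results.
    object DecimalScaleCheck {
      def scaleMatches(value: Any, annotation: LogicalTypeAnnotation): Boolean =
        (value, annotation) match {
          case (d: JBigDecimal, dec: DecimalLogicalTypeAnnotation) => d.scale == dec.getScale
          case _ => false
        }

      def main(args: Array[String]): Unit = {
        val decimal_10_2 = LogicalTypeAnnotation.decimalType(2, 10) // scale 2, precision 10
        println(scaleMatches(new JBigDecimal("12.34"), decimal_10_2))  // true: scales match
        println(scaleMatches(new JBigDecimal("12.345"), decimal_10_2)) // false: filter is skipped
      }
    }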
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if pushDownInFilterThreshold > 0 && values.nonEmpty && + canMakeFilterOn(name, values.head) => + val fieldType = nameToParquetField(name).fieldType + val fieldNames = nameToParquetField(name).fieldNames + if (values.length <= pushDownInFilterThreshold) { + values.distinct + .flatMap { v => + makeEq.lift(fieldType).map(_(fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + } else if (canPartialPushDownConjuncts) { + val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType + val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) + if (values.contains(null)) { + Seq( + makeEq.lift(fieldType).map(_(fieldNames, null)), + makeInPredicate + .lift(fieldType) + .map(_(fieldNames, values.filter(_ != null), statistics))).flatten + .reduceLeftOption(FilterApi.or) + } else { + makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) + } + } else { + None + } + + case sources.StringStartsWith(name, prefix) + if pushDownStringPredicate && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case sources.StringEndsWith(name, suffix) + if pushDownStringPredicate && canMakeFilterOn(name, suffix) => + Option(suffix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val suffixStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) + } + }) + } + + case sources.StringContains(name, value) + if pushDownStringPredicate && canMakeFilterOn(name, value) => + Option(value).map { v => + FilterApi.userDefined( + 
binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val subStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala new file mode 100644 index 000000000..14a664108 --- /dev/null +++ b/spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable.HashMap +import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection._ + +import org.apache.comet.{CometConf, MetricsSupport} +import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} +import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} + +/** + * Comet physical scan node for DataSource V1. 
Most of the code here follow Spark's + * [[FileSourceScanExec]], + */ +case class CometScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false, + wrapped: FileSourceScanExec) + extends DataSourceScanExec + with ShimCometScanExec + with CometPlan { + + // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests + override lazy val supportsColumnar: Boolean = + relation.fileFormat.supportBatch(relation.sparkSession, schema) + + override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = + NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). 
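As a sketch of the runtime pruning described above (the object name and the literal "join key" value are illustrative, assuming Spark's catalyst classes on the classpath), a dynamic filter is bound against each partition-value row and only matching partitions are kept:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, EqualTo, Literal, Predicate}
    import org.apache.spark.sql.types.IntegerType

    // Illustrative sketch of dynamic partition pruning: the predicate references partition
    // column 0 and a value (Literal(3)) that is only known at run time, e.g. a join key.
    object DynamicPruningSketch {
      def main(args: Array[String]): Unit = {
        val partitionValues = Seq(InternalRow(1), InternalRow(3), InternalRow(5))
        val boundPredicate = Predicate.create(
          EqualTo(BoundReference(0, IntegerType, nullable = true), Literal(3)),
          Nil)
        val kept = partitionValues.filter(row => boundPredicate.eval(row))
        println(kept.size) // 1: only the partition whose value equals the runtime key survives
      }
    }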
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = wrapped.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (wrapped.outputPartitioning, wrapped.outputOrdering) + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = + if (wrapped == null) Map.empty else wrapped.metadata + + override def verboseStringWithOperatorId(): String = { + getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) + wrapped.verboseStringWithOperatorId() + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
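The per-batch timing noted above can be sketched with a small wrapper (names are illustrative): the expensive work happens inside the underlying iterator's `hasNext`, so timing `hasNext` once per batch captures scan time without paying a per-row cost.

    import scala.concurrent.duration.NANOSECONDS

    // Illustrative sketch of per-batch scan-time accounting: wrap the batch iterator and time
    // hasNext, since that is where the underlying source performs the actual scan work.
    class TimedIterator[T](underlying: Iterator[T]) extends Iterator[T] {
      var scanTimeMs: Long = 0L
      override def hasNext: Boolean = {
        val startNs = System.nanoTime()
        val res = underlying.hasNext
        scanTimeMs += NANOSECONDS.toMillis(System.nanoTime() - startNs)
        res
      }
      override def next(): T = underlying.next()
    }

    object TimedIteratorSketch {
      def main(args: Array[String]): Unit = {
        val slowSource = new Iterator[Int] {
          private var remaining = 3
          override def hasNext: Boolean = { Thread.sleep(2); remaining > 0 } // scan work lives here
          override def next(): Int = { remaining -= 1; remaining }
        }
        val it = new TimedIterator(slowSource)
        while (it.hasNext) it.next()
        println(s"scan time ~ ${it.scanTimeMs} ms") // roughly 8 ms: four hasNext calls of ~2 ms each
      }
    }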
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + protected override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. 
+ * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. + !isNeededForSchema(requiredSchema) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + newDataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation.sparkSession, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def 
filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): CometScanExec = { + CometScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), + output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan, + null) + } +} + +object CometScanExec { + def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { + // TreeNode.mapProductIterator is protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues + // on other Spark distributions if FileSourceScanExec constructor is changed. + // Using `makeCopy` to avoid the issue. + // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] + val batchScanExec = CometScanExec( + wrapped.relation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + wrapped) + scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) + batchScanExec + } +} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala new file mode 100644 index 000000000..ac871cf60 --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.FileMetaData +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.RecordReaderIterator +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.{DateType, StructType, TimestampType} +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.CometConf +import org.apache.comet.MetricsSupport +import org.apache.comet.shims.ShimSQLConf +import org.apache.comet.vector.CometVector + +/** + * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's + * [[ParquetFileFormat]], but overrides: + * + * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap + * column vector in the whole-stage codegen path. + * - `supportBatch`, which simply returns true since data types should have already been checked + * in [[org.apache.comet.CometSparkSessionExtensions]] + * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. 
+ */ +class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { + override def shortName(): String = "parquet" + override def toString: String = "CometParquet" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] + + override def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + val length = requiredSchema.fields.length + partitionSchema.fields.length + Option(Seq.fill(length)(classOf[CometVector].getName)) + } + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = sparkSession.sessionState.conf + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val useFieldId = CometParquetUtils.readFieldId(sqlConf) + val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val optionsMap = CaseInsensitiveMap[String](options) + val parquetOptions = new ParquetOptions(optionsMap, sqlConf) + val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead + val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + (file: PartitionedFile) => { + val sharedConf = broadcastedHadoopConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + requiredSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can + // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a + // `flatMap` is used here. 
+ .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + + val batchReader = new BatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val iter = new RecordReaderIterator(batchReader) + try { + batchReader.init() + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + iter.close() + throw e + } + } + } +} + +object CometParquetFileFormat extends Logging { + + /** + * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` + */ + def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp) + + // Comet specific configs + hadoopConf.setBoolean( + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) + hadoopConf.setBoolean( + CometConf.COMET_USE_DECIMAL_128.key, + CometConf.COMET_USE_DECIMAL_128.get()) + hadoopConf.setBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) + } + + def getDatetimeRebaseSpec( + file: PartitionedFile, + sparkSchema: StructType, + sharedConf: Configuration, + footerFileMetaData: FileMetaData, + datetimeRebaseModeInRead: String): RebaseSpec = { + val exceptionOnRebase = sharedConf.getBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) + var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val hasDateOrTimestamp = sparkSchema.exists(f => + f.dataType match { + case DateType | TimestampType => true + case _ => false + }) + + if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { + if (exceptionOnRebase) { + logWarning( + s"""Found Parquet file $file that could potentially contain dates/timestamps that were + written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase + and return these according to the new Proleptic Gregorian calendar, Comet will throw + exception when reading them. If you want to read them as it is according to the hybrid + Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to + false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian + calendar, please disable Comet for this query.""") + } else { + // do not throw exception on rebase - read as it is + datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) + } + } + + datetimeRebaseSpec + } +} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala new file mode 100644 index 000000000..693af125b --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters +import scala.collection.mutable + +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.{CometConf, CometRuntimeException} +import org.apache.comet.shims.ShimSQLConf + +case class CometParquetPartitionReaderFactory( + @transient sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + readDataSchema: StructType, + partitionSchema: StructType, + filters: Array[Filter], + options: ParquetOptions, + metrics: Map[String, SQLMetric]) + extends FilePartitionReaderFactory + with ShimSQLConf + with Logging { + + private val isCaseSensitive = sqlConf.caseSensitiveAnalysis + private val useFieldId = CometParquetUtils.readFieldId(sqlConf) + private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + private val pushDownDate = sqlConf.parquetFilterPushDownDate + private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp + private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead + private val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + // This is only called at executor on a Broadcast variable, so we don't want it to be + // materialized at driver. + @transient private lazy val preFetchEnabled = { + val conf = broadcastedConf.value.value + + conf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + } + + private var cometReaders: Iterator[BatchReader] = _ + private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() + + // TODO: we may want to revisit this as we're going to only support flat types at the beginning + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + if (preFetchEnabled) { + val filePartition = partition.asInstanceOf[FilePartition] + val conf = broadcastedConf.value.value + + val threadNum = conf.getInt( + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) + val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) + + this.cometReaders = filePartition.files + .map { file => + // `init()` call is deferred to when the prefetch task begins. + // Otherwise we will hold too many resources for readers which are not ready + // to prefetch. + val cometReader = buildCometReader(file) + if (cometReader != null) { + cometReader.submitPrefetchTask(prefetchThreadPool) + } + + cometReader + } + .toSeq + .toIterator + } + + super.createColumnarReader(partition) + } + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") + + private def buildCometReader(file: PartitionedFile): BatchReader = { + val conf = broadcastedConf.value.value + + try { + val (datetimeRebaseSpec, footer, filters) = getFilter(file) + filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) + val cometReader = new BatchReader( + conf, + file, + footer, + batchSize, + readDataSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val taskContext = Option(TaskContext.get) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) + return cometReader + } catch { + case e: Throwable if preFetchEnabled => + // Keep original exception + cometReaderExceptionMap.put(file, e) + } + null + } + + override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { + val cometReader = if (!preFetchEnabled) { + // Prefetch is not enabled, create comet reader and initiate it. + val cometReader = buildCometReader(file) + cometReader.init() + + cometReader + } else { + // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
+ // It is possibly we got an exception like `FileNotFoundException` and we need to throw it + // now to let Spark handle it. + val reader = cometReaders.next() + val exception = cometReaderExceptionMap.get(file) + exception.foreach(e => throw e) + + if (reader == null) { + throw new CometRuntimeException(s"Cannot find comet file reader for $file") + } + reader + } + CometPartitionReader(cometReader) + } + + def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { + val sharedConf = broadcastedConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + readDataSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + (datetimeRebaseSpec, footer, pushed) + } + + override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") + + /** + * A simple adapter on Comet's [[BatchReader]]. + */ + protected case class CometPartitionReader(reader: BatchReader) + extends PartitionReader[ColumnarBatch] { + + override def next(): Boolean = { + reader.nextBatch() + } + + override def get(): ColumnarBatch = { + reader.currentBatch() + } + + override def close(): Unit = { + reader.close() + } + } +} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala new file mode 100644 index 000000000..5994dfb41 --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate, Period} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus + +/** + * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove + * this duplication + * + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class ParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStringPredicate: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + datetimeRebaseSpec: RebaseSpec) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + // Parquet only supports predicate push-down for non-repeated primitive types. + // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for + // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) + case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getLogicalTypeAnnotation, + p.getPrimitiveTypeName, + p.getTypeLength))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. 
+ case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + logicalTypeAnnotation: LogicalTypeAnnotation, + primitiveTypeName: PrimitiveTypeName, + length: Int) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) + private val ParquetByteType = + ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) + private val ParquetShortType = + ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) + private val ParquetStringType = + ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) + private val ParquetDateType = + ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) + private val ParquetTimestampMicrosType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) + private val ParquetTimestampMillisType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) + + private def dateToDays(date: Any): Int = { + val gregorianDays = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) + case _ => gregorianDays + } + } + + private def timestampToMicros(v: Any): JLong = { + val gregorianMicros = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => + rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) + case _ => gregorianMicros + } + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = 
decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = DateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private def toIntValue(v: Any): Integer = { + Option(v) + .map { + case p: Period => IntervalUtils.periodToMonths(p) + case n => n.asInstanceOf[Number].intValue + } + .map(_.asInstanceOf[Integer]) + .orNull + } + + private def toLongValue(v: Any): JLong = v match { + case d: Duration => IntervalUtils.durationToMicros(d) + case l => l.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + 
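
Each of the `make*` builders in this file pairs a Parquet physical/logical type with a function that turns a column path and a Scala value into a parquet-mr `FilterPredicate`. As a point of reference, here is a minimal, standalone sketch (not part of this patch) of the same parquet-mr `FilterApi` calls, using the public dotted-string column constructors; the column names and literal values are invented for the example:

```scala
// Standalone illustration only; not part of this patch. Column names and literal values
// ("id", "event.ts", 42, ...) are invented for the example.
import java.lang.{Long => JLong}

import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

object FilterApiSketch {
  def main(args: Array[String]): Unit = {
    // Roughly `id = 42 AND event.ts < 1700000000000`, built directly with parquet-mr's API.
    val idEq: FilterPredicate =
      FilterApi.eq(FilterApi.intColumn("id"), Integer.valueOf(42))
    val tsLt: FilterPredicate =
      FilterApi.lt(FilterApi.longColumn("event.ts"), JLong.valueOf(1700000000000L))
    println(FilterApi.and(idEq, tsLt)) // roughly: and(eq(id, 42), lt(event.ts, 1700000000000))
  }
}
```

The builders in this file do the same thing, except that they resolve columns from an `Array[String]` path and wrap values in `Option(v).map(...).orNull`, so that `IsNull`/`IsNotNull` can reuse `eq`/`notEq` with a null value.
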
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if 
pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeInPredicate: PartialFunction[ + ParquetSchemaType, + (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toIntValue(_).toInt).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case 
ParquetLongType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toLongValue).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetFloatType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), + FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) + + case ParquetDoubleType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), + FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) + + case ParquetStringType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetBinaryType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetDateType if pushDownDate => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMicros).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMillis).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => + if (isSpark34Plus) { + value match { + // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type + // Int. We don't create a filter if the value would overflow. + case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } + } else { + // If not Spark 3.4+, we still following the old behavior as Spark does. 
+ value.isInstanceOf[Number] + } + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => + value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType( + decimalType: DecimalLogicalTypeAnnotation, + FIXED_LEN_BYTE_ARRAY, + _) => + isDecimalMatched(value, decimalType) + case _ => false + }) + } + + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + private def isDecimalMatched( + value: Any, + decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalLogicalType.getScale + case _ => false + } + + private def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) + } + + /** + * @param predicate + * the input filter predicates. Not all the predicates can be pushed down. + * @param canPartialPushDownConjuncts + * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one + * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. + * @return + * the Parquet-native filter predicates that are eligible for pushdown. + */ + private def createFilterHelper( + predicate: sources.Filter, + canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `PruneFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. 
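
Before any of the conversions below run, `canMakeFilterOn` has already vetted the value via `valueCanMakeFilterOn` and `isDecimalMatched`. The following is a small, self-contained sketch (not part of this patch; the helper names and literals are invented) of the two less obvious checks, namely the Spark 3.4+ overflow guard for INT32 columns and the decimal scale match:

```scala
// Standalone sketch only; not part of this patch. Helper names and values are invented.
import java.math.{BigDecimal => JBigDecimal}

object PushdownChecksSketch {
  // A java.lang.Long literal is only pushed down against an INT32 Parquet column if it fits
  // in an Int; otherwise no filter is created, mirroring the Spark 3.4+ branch above.
  def fitsInInt32(v: java.lang.Long): Boolean =
    v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue

  // A decimal literal is only pushed down when its scale matches the scale declared in the
  // file, since the filter compares unscaled values against the stored encoding.
  def scaleMatches(value: JBigDecimal, fileScale: Int): Boolean =
    value.scale == fileScale

  def main(args: Array[String]): Unit = {
    println(fitsInInt32(java.lang.Long.valueOf(42L)))              // true
    println(fitsInInt32(java.lang.Long.valueOf(1L << 40)))         // false, would overflow Int
    println(scaleMatches(new JBigDecimal("12.30"), fileScale = 2)) // true
    println(scaleMatches(new JBigDecimal("12.3"), fileScale = 2))  // false, pushdown skipped
  }
}
```

Skipping pushdown in these cases is the safe default: a long that does not fit in an Int, or a decimal whose scale differs from the file's, cannot be compared correctly against the values and statistics stored in the file.
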
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if pushDownInFilterThreshold > 0 && values.nonEmpty && + canMakeFilterOn(name, values.head) => + val fieldType = nameToParquetField(name).fieldType + val fieldNames = nameToParquetField(name).fieldNames + if (values.length <= pushDownInFilterThreshold) { + values.distinct + .flatMap { v => + makeEq.lift(fieldType).map(_(fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + } else if (canPartialPushDownConjuncts) { + val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType + val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) + if (values.contains(null)) { + Seq( + makeEq.lift(fieldType).map(_(fieldNames, null)), + makeInPredicate + .lift(fieldType) + .map(_(fieldNames, values.filter(_ != null), statistics))).flatten + .reduceLeftOption(FilterApi.or) + } else { + makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) + } + } else { + None + } + + case sources.StringStartsWith(name, prefix) + if pushDownStringPredicate && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case sources.StringEndsWith(name, suffix) + if pushDownStringPredicate && canMakeFilterOn(name, suffix) => + Option(suffix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val suffixStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) + } + }) + } + + case sources.StringContains(name, value) + if pushDownStringPredicate && canMakeFilterOn(name, value) => + Option(value).map { v => + FilterApi.userDefined( + 
binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val subStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala new file mode 100644 index 000000000..14a664108 --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable.HashMap +import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection._ + +import org.apache.comet.{CometConf, MetricsSupport} +import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} +import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} + +/** + * Comet physical scan node for DataSource V1. 
Most of the code here follows Spark's
+ * [[FileSourceScanExec]].
+ */
+case class CometScanExec(
+    @transient relation: HadoopFsRelation,
+    output: Seq[Attribute],
+    requiredSchema: StructType,
+    partitionFilters: Seq[Expression],
+    optionalBucketSet: Option[BitSet],
+    optionalNumCoalescedBuckets: Option[Int],
+    dataFilters: Seq[Expression],
+    tableIdentifier: Option[TableIdentifier],
+    disableBucketedScan: Boolean = false,
+    wrapped: FileSourceScanExec)
+    extends DataSourceScanExec
+    with ShimCometScanExec
+    with CometPlan {
+
+  // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests
+  override lazy val supportsColumnar: Boolean =
+    relation.fileFormat.supportBatch(relation.sparkSession, schema)
+
+  override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes
+
+  private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
+
+  /**
+   * Send the driver-side metrics. Before calling this function, selectedPartitions has been
+   * initialized. See SPARK-26327 for more details.
+   */
+  private def sendDriverMetrics(): Unit = {
+    driverMetrics.foreach(e => metrics(e._1).add(e._2))
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(
+      sparkContext,
+      executionId,
+      metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
+  }
+
+  private def isDynamicPruningFilter(e: Expression): Boolean =
+    e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+  @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
+    val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
+    val startTime = System.nanoTime()
+    val ret =
+      relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters)
+    setFilesNumAndSizeMetric(ret, true)
+    val timeTakenMs =
+      NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs)
+    driverMetrics("metadataTime") = timeTakenMs
+    ret
+  }.toArray
+
+  // We can only determine the actual partitions at runtime when a dynamic partition filter is
+  // present. This is because such a filter relies on information that is only available at run
+  // time (for instance the keys used on the other side of a join).
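
The comment above describes the runtime side of dynamic partition pruning, implemented by `dynamicallySelectedPartitions` just below: partitions listed at planning time are re-filtered at execution time once the values from the other side of the join are known. The following is a conceptual, self-contained sketch of that idea (not part of this patch, and deliberately free of Spark internals; `PartitionDir`, the column name, and the key set are invented):

```scala
// Conceptual sketch only; not part of this patch.
object DynamicPruningSketch {
  final case class PartitionDir(values: Map[String, String], files: Seq[String])

  // Equivalent in spirit to evaluating the bound dynamic-pruning predicate per partition.
  def prune(
      selected: Array[PartitionDir],
      runtimeKeys: Set[String],
      column: String): Array[PartitionDir] =
    selected.filter(p => p.values.get(column).exists(runtimeKeys.contains))

  def main(args: Array[String]): Unit = {
    val listed = Array(
      PartitionDir(Map("dt" -> "2024-01-01"), Seq("part-0.parquet")),
      PartitionDir(Map("dt" -> "2024-01-02"), Seq("part-1.parquet")))
    // Keys that, say, a broadcast join produced at run time.
    println(prune(listed, runtimeKeys = Set("2024-01-02"), column = "dt").toSeq)
  }
}
```
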
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = wrapped.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (wrapped.outputPartitioning, wrapped.outputOrdering) + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = + if (wrapped == null) Map.empty else wrapped.metadata + + override def verboseStringWithOperatorId(): String = { + getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) + wrapped.verboseStringWithOperatorId() + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
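
The "scan time" metric added below is only tracked per batch because, as the comment above notes, timing each row would be too expensive. The actual accounting happens in `doExecuteColumnar` further down, where the wrapped iterator's `hasNext` performs the scan and its elapsed time is added to the metric. Here is a minimal, standalone sketch of that pattern (not part of this patch; `TimedIterator` and the string elements are invented stand-ins for the real `ColumnarBatch` iterator):

```scala
// Standalone sketch only; not part of this patch.
object ScanTimeSketch {
  // Wraps an iterator whose hasNext does the real work (like FileScanRDD's), and accumulates
  // the elapsed time of each hasNext call into a caller-supplied metric.
  final class TimedIterator[T](underlying: Iterator[T], addScanTimeMs: Long => Unit)
      extends Iterator[T] {
    override def hasNext: Boolean = {
      val startNs = System.nanoTime()
      val res = underlying.hasNext
      addScanTimeMs((System.nanoTime() - startNs) / 1000000L)
      res
    }
    override def next(): T = underlying.next()
  }

  def main(args: Array[String]): Unit = {
    var scanTimeMs = 0L
    val it = new TimedIterator(Iterator.fill(3)("batch"), (ms: Long) => scanTimeMs += ms)
    it.foreach(_ => ()) // drain; in the real operator each element is a ColumnarBatch
    println(s"accumulated scan time: $scanTimeMs ms")
  }
}
```
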
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + protected override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. 
+ * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. + !isNeededForSchema(requiredSchema) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + newDataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation.sparkSession, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def 
filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): CometScanExec = { + CometScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), + output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan, + null) + } +} + +object CometScanExec { + def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { + // TreeNode.mapProductIterator is protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues + // on other Spark distributions if FileSourceScanExec constructor is changed. + // Using `makeCopy` to avoid the issue. + // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] + val batchScanExec = CometScanExec( + wrapped.relation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + wrapped) + scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) + batchScanExec + } +} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala new file mode 100644 index 000000000..ac871cf60 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.comet.parquet
+
+import scala.collection.JavaConverters
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.hadoop.ParquetInputFormat
+import org.apache.parquet.hadoop.metadata.FileMetaData
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.RecordReaderIterator
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.CometConf
+import org.apache.comet.MetricsSupport
+import org.apache.comet.shims.ShimSQLConf
+import org.apache.comet.vector.CometVector
+
+/**
+ * A Comet-specific Parquet format. This mostly reuses the functionality of Spark's
+ * [[ParquetFileFormat]], but overrides:
+ *
+ * - `vectorTypes`, so Spark allocates [[CometVector]] instead of its own on-heap or off-heap
+ *   column vector in the whole-stage codegen path.
+ * - `supportBatch`, which simply returns true since data types should have already been checked
+ *   in [[org.apache.comet.CometSparkSessionExtensions]].
+ * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values.
+ */ +class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { + override def shortName(): String = "parquet" + override def toString: String = "CometParquet" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] + + override def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + val length = requiredSchema.fields.length + partitionSchema.fields.length + Option(Seq.fill(length)(classOf[CometVector].getName)) + } + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = sparkSession.sessionState.conf + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val useFieldId = CometParquetUtils.readFieldId(sqlConf) + val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val optionsMap = CaseInsensitiveMap[String](options) + val parquetOptions = new ParquetOptions(optionsMap, sqlConf) + val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead + val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + (file: PartitionedFile) => { + val sharedConf = broadcastedHadoopConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + requiredSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can + // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a + // `flatMap` is used here. 
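
The `flatMap`/`reduceOption(FilterApi.and)` pair that follows is the standard way to fold a set of maybe-convertible source filters into a single Parquet predicate: filters that fail to convert simply disappear rather than aborting pushdown. A toy, self-contained sketch of the pattern (not part of this patch; `convert` and the tuple-shaped "filters" are invented stand-ins for `ParquetFilters.createFilter` and Spark's `sources.Filter`):

```scala
// Standalone sketch only; not part of this patch.
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

object PushedFilterSketch {
  // Pretend only equality on an "id" column is convertible.
  def convert(filter: (String, Any)): Option[FilterPredicate] = filter match {
    case ("id", v: Int) => Some(FilterApi.eq(FilterApi.intColumn("id"), Integer.valueOf(v)))
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val sourceFilters: Seq[(String, Any)] = Seq(("id", 1), ("name", "a"), ("id", 2))
    val pushed: Option[FilterPredicate] =
      sourceFilters.flatMap(convert).reduceOption(FilterApi.and)
    // Roughly Some(and(eq(id, 1), eq(id, 2))); the unconvertible filter is silently dropped.
    println(pushed)
  }
}
```
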
+ .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + + val batchReader = new BatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val iter = new RecordReaderIterator(batchReader) + try { + batchReader.init() + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + iter.close() + throw e + } + } + } +} + +object CometParquetFileFormat extends Logging { + + /** + * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` + */ + def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp) + + // Comet specific configs + hadoopConf.setBoolean( + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) + hadoopConf.setBoolean( + CometConf.COMET_USE_DECIMAL_128.key, + CometConf.COMET_USE_DECIMAL_128.get()) + hadoopConf.setBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) + } + + def getDatetimeRebaseSpec( + file: PartitionedFile, + sparkSchema: StructType, + sharedConf: Configuration, + footerFileMetaData: FileMetaData, + datetimeRebaseModeInRead: String): RebaseSpec = { + val exceptionOnRebase = sharedConf.getBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) + var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val hasDateOrTimestamp = sparkSchema.exists(f => + f.dataType match { + case DateType | TimestampType => true + case _ => false + }) + + if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { + if (exceptionOnRebase) { + logWarning( + s"""Found Parquet file $file that could potentially contain dates/timestamps that were + written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase + and return these according to the new Proleptic Gregorian calendar, Comet will throw + exception when reading them. If you want to read them as it is according to the hybrid + Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to + false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian + calendar, please disable Comet for this query.""") + } else { + // do not throw exception on rebase - read as it is + datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) + } + } + + datetimeRebaseSpec + } +} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala new file mode 100644 index 000000000..693af125b --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters +import scala.collection.mutable + +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.{CometConf, CometRuntimeException} +import org.apache.comet.shims.ShimSQLConf + +case class CometParquetPartitionReaderFactory( + @transient sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + readDataSchema: StructType, + partitionSchema: StructType, + filters: Array[Filter], + options: ParquetOptions, + metrics: Map[String, SQLMetric]) + extends FilePartitionReaderFactory + with ShimSQLConf + with Logging { + + private val isCaseSensitive = sqlConf.caseSensitiveAnalysis + private val useFieldId = CometParquetUtils.readFieldId(sqlConf) + private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + private val pushDownDate = sqlConf.parquetFilterPushDownDate + private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp + private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead + private val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + // This is only called at executor on a Broadcast variable, so we don't want it to be + // materialized at driver. + @transient private lazy val preFetchEnabled = { + val conf = broadcastedConf.value.value + + conf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + } + + private var cometReaders: Iterator[BatchReader] = _ + private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() + + // TODO: we may want to revisit this as we're going to only support flat types at the beginning + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + if (preFetchEnabled) { + val filePartition = partition.asInstanceOf[FilePartition] + val conf = broadcastedConf.value.value + + val threadNum = conf.getInt( + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) + val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) + + this.cometReaders = filePartition.files + .map { file => + // `init()` call is deferred to when the prefetch task begins. + // Otherwise we will hold too many resources for readers which are not ready + // to prefetch. + val cometReader = buildCometReader(file) + if (cometReader != null) { + cometReader.submitPrefetchTask(prefetchThreadPool) + } + + cometReader + } + .toSeq + .toIterator + } + + super.createColumnarReader(partition) + } + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") + + private def buildCometReader(file: PartitionedFile): BatchReader = { + val conf = broadcastedConf.value.value + + try { + val (datetimeRebaseSpec, footer, filters) = getFilter(file) + filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) + val cometReader = new BatchReader( + conf, + file, + footer, + batchSize, + readDataSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val taskContext = Option(TaskContext.get) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) + return cometReader + } catch { + case e: Throwable if preFetchEnabled => + // Keep original exception + cometReaderExceptionMap.put(file, e) + } + null + } + + override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { + val cometReader = if (!preFetchEnabled) { + // Prefetch is not enabled, create comet reader and initiate it. + val cometReader = buildCometReader(file) + cometReader.init() + + cometReader + } else { + // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
+ // It is possibly we got an exception like `FileNotFoundException` and we need to throw it + // now to let Spark handle it. + val reader = cometReaders.next() + val exception = cometReaderExceptionMap.get(file) + exception.foreach(e => throw e) + + if (reader == null) { + throw new CometRuntimeException(s"Cannot find comet file reader for $file") + } + reader + } + CometPartitionReader(cometReader) + } + + def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { + val sharedConf = broadcastedConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + readDataSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + (datetimeRebaseSpec, footer, pushed) + } + + override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") + + /** + * A simple adapter on Comet's [[BatchReader]]. + */ + protected case class CometPartitionReader(reader: BatchReader) + extends PartitionReader[ColumnarBatch] { + + override def next(): Boolean = { + reader.nextBatch() + } + + override def get(): ColumnarBatch = { + reader.currentBatch() + } + + override def close(): Unit = { + reader.close() + } + } +} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala new file mode 100644 index 000000000..5994dfb41 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate, Period} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus + +/** + * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove + * this duplication + * + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class ParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStringPredicate: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + datetimeRebaseSpec: RebaseSpec) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + // Parquet only supports predicate push-down for non-repeated primitive types. + // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for + // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) + case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getLogicalTypeAnnotation, + p.getPrimitiveTypeName, + p.getTypeLength))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. 
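+        // Illustrative example (hypothetical schema, for clarity only): given
+        //   message spark_schema { optional group g { optional int32 x; } }
+        // the struct `g` has a null original type, so the recursion descends into it and
+        // collects one primitive field with fieldNames = Array("g", "x"), keyed below as "g.x".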
+ case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + logicalTypeAnnotation: LogicalTypeAnnotation, + primitiveTypeName: PrimitiveTypeName, + length: Int) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) + private val ParquetByteType = + ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) + private val ParquetShortType = + ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) + private val ParquetStringType = + ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) + private val ParquetDateType = + ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) + private val ParquetTimestampMicrosType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) + private val ParquetTimestampMillisType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) + + private def dateToDays(date: Any): Int = { + val gregorianDays = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) + case _ => gregorianDays + } + } + + private def timestampToMicros(v: Any): JLong = { + val gregorianMicros = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => + rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) + case _ => gregorianMicros + } + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = 
decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = DateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private def toIntValue(v: Any): Integer = { + Option(v) + .map { + case p: Period => IntervalUtils.periodToMonths(p) + case n => n.asInstanceOf[Number].intValue + } + .map(_.asInstanceOf[Integer]) + .orNull + } + + private def toLongValue(v: Any): JLong = v match { + case d: Duration => IntervalUtils.durationToMicros(d) + case l => l.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if 
pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeInPredicate: PartialFunction[ + ParquetSchemaType, + (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toIntValue(_).toInt).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case 
ParquetLongType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toLongValue).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetFloatType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), + FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) + + case ParquetDoubleType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), + FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) + + case ParquetStringType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetBinaryType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetDateType if pushDownDate => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMicros).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMillis).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => + if (isSpark34Plus) { + value match { + // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type + // Int. We don't create a filter if the value would overflow. + case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } + } else { + // If not Spark 3.4+, we still following the old behavior as Spark does. 
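+          // (Descriptive note: any java.lang.Number value is accepted on this branch; the
+          // Int-range overflow check above is only applied on Spark 3.4+.)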
+          value.isInstanceOf[Number]
+        }
+      case ParquetLongType => value.isInstanceOf[JLong]
+      case ParquetFloatType => value.isInstanceOf[JFloat]
+      case ParquetDoubleType => value.isInstanceOf[JDouble]
+      case ParquetStringType => value.isInstanceOf[String]
+      case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+      case ParquetDateType =>
+        value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
+      case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+        value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
+      case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) =>
+        isDecimalMatched(value, decimalType)
+      case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) =>
+        isDecimalMatched(value, decimalType)
+      case ParquetSchemaType(
+            decimalType: DecimalLogicalTypeAnnotation,
+            FIXED_LEN_BYTE_ARRAY,
+            _) =>
+        isDecimalMatched(value, decimalType)
+      case _ => false
+    })
+  }
+
+  // Decimal filters must make sure that the filter value's scale matches the scale in the file;
+  // if it doesn't match, pushing the filter down could cause data corruption.
+  private def isDecimalMatched(
+      value: Any,
+      decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match {
+    case decimal: JBigDecimal =>
+      decimal.scale == decimalLogicalType.getScale
+    case _ => false
+  }
+
+  private def canMakeFilterOn(name: String, value: Any): Boolean = {
+    nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
+  }
+
+  /**
+   * @param predicate
+   *   the input filter predicates. Not all the predicates can be pushed down.
+   * @param canPartialPushDownConjuncts
+   *   whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one
+   *   side of an AND down is safe only at the top level, or when none of its ancestors is a NOT
+   *   or an OR.
+   * @return
+   *   the Parquet-native filter predicates that are eligible for pushdown.
+   */
+  private def createFilterHelper(
+      predicate: sources.Filter,
+      canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
+    // NOTE:
+    //
+    // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+    // which can be cast to `false` implicitly. Please refer to the `eval` method of these
+    // operators and the `PruneFilters` rule for details.
+
+    // Hyukjin:
+    // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]].
+    // So, it performs the equality comparison identically to when the given [[sources.Filter]]
+    // is [[EqualTo]]. The reason I did this is that the actual Parquet filter checks null-safe
+    // equality comparison. So I added this, and maybe [[EqualTo]] should be changed. It still
+    // seems fine though, because physical planning does not set `NULL` on [[EqualTo]] but
+    // changes it to [[IsNull]] etc. Probably I missed something, and obviously this should be
+    // changed.
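+    // Illustrative behavior on a hypothetical INT32 column "a" (sketch only; the column is
+    // assumed to be present in `nameToParquetField` with a matching value type):
+    //   GreaterThan("a", 10)                    -> Some(FilterApi.gt(intColumn(Array("a")), 10))
+    //   And(GreaterThan("a", 10), unsupported)  -> Some(gt(a, 10)) if partial pushdown of
+    //                                              conjuncts is allowed here, otherwise None
+    //   Or(GreaterThan("a", 10), unsupported)   -> None, since both sides of an OR must convert
+    //   Not(pred)                               -> pred is converted with
+    //                                              canPartialPushDownConjuncts = false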
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if pushDownInFilterThreshold > 0 && values.nonEmpty && + canMakeFilterOn(name, values.head) => + val fieldType = nameToParquetField(name).fieldType + val fieldNames = nameToParquetField(name).fieldNames + if (values.length <= pushDownInFilterThreshold) { + values.distinct + .flatMap { v => + makeEq.lift(fieldType).map(_(fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + } else if (canPartialPushDownConjuncts) { + val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType + val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) + if (values.contains(null)) { + Seq( + makeEq.lift(fieldType).map(_(fieldNames, null)), + makeInPredicate + .lift(fieldType) + .map(_(fieldNames, values.filter(_ != null), statistics))).flatten + .reduceLeftOption(FilterApi.or) + } else { + makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) + } + } else { + None + } + + case sources.StringStartsWith(name, prefix) + if pushDownStringPredicate && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case sources.StringEndsWith(name, suffix) + if pushDownStringPredicate && canMakeFilterOn(name, suffix) => + Option(suffix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val suffixStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) + } + }) + } + + case sources.StringContains(name, value) + if pushDownStringPredicate && canMakeFilterOn(name, value) => + Option(value).map { v => + FilterApi.userDefined( + 
binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val subStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala new file mode 100644 index 000000000..14a664108 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable.HashMap +import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection._ + +import org.apache.comet.{CometConf, MetricsSupport} +import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} +import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} + +/** + * Comet physical scan node for DataSource V1. 
Most of the code here follow Spark's + * [[FileSourceScanExec]], + */ +case class CometScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false, + wrapped: FileSourceScanExec) + extends DataSourceScanExec + with ShimCometScanExec + with CometPlan { + + // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests + override lazy val supportsColumnar: Boolean = + relation.fileFormat.supportBatch(relation.sparkSession, schema) + + override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = + NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). 
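+  // For example (hypothetical query, for illustration only): if `sales` is partitioned by `dt`,
+  //   SELECT ... FROM sales JOIN dim_date d ON sales.dt = d.dt WHERE d.is_holiday
+  // plans a DynamicPruningExpression on `dt`, and the qualifying `dt` partitions are only known
+  // once the `dim_date` side has been evaluated, so they are filtered below at execution time.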
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = wrapped.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (wrapped.outputPartitioning, wrapped.outputOrdering) + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = + if (wrapped == null) Map.empty else wrapped.metadata + + override def verboseStringWithOperatorId(): String = { + getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) + wrapped.verboseStringWithOperatorId() + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
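+    // (The "scanTime" metric added below is updated once per ColumnarBatch, in the iterator
+    // returned by `doExecuteColumnar`.)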
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + protected override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. 
+ * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. + !isNeededForSchema(requiredSchema) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + newDataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation.sparkSession, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def 
filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): CometScanExec = { + CometScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), + output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan, + null) + } +} + +object CometScanExec { + def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { + // TreeNode.mapProductIterator is protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues + // on other Spark distributions if FileSourceScanExec constructor is changed. + // Using `makeCopy` to avoid the issue. + // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] + val batchScanExec = CometScanExec( + wrapped.relation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + wrapped) + scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) + batchScanExec + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala new file mode 100644 index 000000000..91e86eea3 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.comet.parquet
+
+import scala.collection.JavaConverters
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.hadoop.ParquetInputFormat
+import org.apache.parquet.hadoop.metadata.FileMetaData
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.RecordReaderIterator
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.CometConf
+import org.apache.comet.MetricsSupport
+import org.apache.comet.shims.ShimSQLConf
+import org.apache.comet.vector.CometVector
+
+/**
+ * A Comet-specific Parquet format. This mostly reuses the functionality from Spark's
+ * [[ParquetFileFormat]], but overrides:
+ *
+ * - `vectorTypes`, so Spark allocates [[CometVector]] instead of its own on-heap or off-heap
+ *   column vector in the whole-stage codegen path.
+ * - `supportBatch`, which simply returns true since data types should have already been checked
+ *   in [[org.apache.comet.CometSparkSessionExtensions]]
+ * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values.
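+ *
+ * A minimal sketch, for illustration only (assuming a two-column data schema and a single
+ * partition column), of what `vectorTypes` reports so that Comet vectors are allocated for
+ * every column:
+ *
+ * {{{
+ *   val format = new CometParquetFileFormat
+ *   val dataSchema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
+ *   val partitionSchema = StructType(Seq(StructField("dt", DateType)))
+ *   // Expected: Some(Seq.fill(3)(classOf[CometVector].getName))
+ *   format.vectorTypes(dataSchema, partitionSchema, SQLConf.get)
+ * }}}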
+ */ +class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { + override def shortName(): String = "parquet" + override def toString: String = "CometParquet" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] + + override def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + val length = requiredSchema.fields.length + partitionSchema.fields.length + Option(Seq.fill(length)(classOf[CometVector].getName)) + } + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = sparkSession.sessionState.conf + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val useFieldId = CometParquetUtils.readFieldId(sqlConf) + val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val optionsMap = CaseInsensitiveMap[String](options) + val parquetOptions = new ParquetOptions(optionsMap, sqlConf) + val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead + val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + (file: PartitionedFile) => { + val sharedConf = broadcastedHadoopConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + requiredSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can + // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a + // `flatMap` is used here. 
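+          // For example (illustrative): if only one of two pushed filters is convertible, the
+          // other maps to None and is silently dropped; the surviving predicates are combined
+          // with FilterApi.and by reduceOption, which itself yields None when nothing converts.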
+ .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + + val batchReader = new BatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val iter = new RecordReaderIterator(batchReader) + try { + batchReader.init() + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + iter.close() + throw e + } + } + } +} + +object CometParquetFileFormat extends Logging { + + /** + * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` + */ + def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp) + + // Comet specific configs + hadoopConf.setBoolean( + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, + CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) + hadoopConf.setBoolean( + CometConf.COMET_USE_DECIMAL_128.key, + CometConf.COMET_USE_DECIMAL_128.get()) + hadoopConf.setBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) + } + + def getDatetimeRebaseSpec( + file: PartitionedFile, + sparkSchema: StructType, + sharedConf: Configuration, + footerFileMetaData: FileMetaData, + datetimeRebaseModeInRead: String): RebaseSpec = { + val exceptionOnRebase = sharedConf.getBoolean( + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, + CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) + var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val hasDateOrTimestamp = sparkSchema.exists(f => + f.dataType match { + case DateType | TimestampType => true + case _ => false + }) + + if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { + if (exceptionOnRebase) { + logWarning( + s"""Found Parquet file $file that could potentially contain dates/timestamps that were + written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase + and return these according to the new Proleptic Gregorian calendar, Comet will throw + exception when reading them. If you want to read them as it is according to the hybrid + Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to + false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian + calendar, please disable Comet for this query.""") + } else { + // do not throw exception on rebase - read as it is + datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) + } + } + + datetimeRebaseSpec + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala new file mode 100644 index 000000000..787357531 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import scala.collection.JavaConverters +import scala.collection.mutable + +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.hadoop.ParquetInputFormat +import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.{CometConf, CometRuntimeException} +import org.apache.comet.shims.ShimSQLConf + +case class CometParquetPartitionReaderFactory( + @transient sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + readDataSchema: StructType, + partitionSchema: StructType, + filters: Array[Filter], + options: ParquetOptions, + metrics: Map[String, SQLMetric]) + extends FilePartitionReaderFactory + with ShimSQLConf + with Logging { + + private val isCaseSensitive = sqlConf.caseSensitiveAnalysis + private val useFieldId = CometParquetUtils.readFieldId(sqlConf) + private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) + private val pushDownDate = sqlConf.parquetFilterPushDownDate + private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp + private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) + private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead + private val parquetFilterPushDown = sqlConf.parquetFilterPushDown + + // Comet specific configurations + private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) + + // This is only called at executor on a Broadcast variable, so we don't want it to be + // materialized at driver. + @transient private lazy val preFetchEnabled = { + val conf = broadcastedConf.value.value + + conf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + } + + private var cometReaders: Iterator[BatchReader] = _ + private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() + + // TODO: we may want to revisit this as we're going to only support flat types at the beginning + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + if (preFetchEnabled) { + val filePartition = partition.asInstanceOf[FilePartition] + val conf = broadcastedConf.value.value + + val threadNum = conf.getInt( + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, + CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) + val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) + + this.cometReaders = filePartition.files + .map { file => + // `init()` call is deferred to when the prefetch task begins. + // Otherwise we will hold too many resources for readers which are not ready + // to prefetch. + val cometReader = buildCometReader(file) + if (cometReader != null) { + cometReader.submitPrefetchTask(prefetchThreadPool) + } + + cometReader + } + .toSeq + .toIterator + } + + super.createColumnarReader(partition) + } + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") + + private def buildCometReader(file: PartitionedFile): BatchReader = { + val conf = broadcastedConf.value.value + + try { + val (datetimeRebaseSpec, footer, filters) = getFilter(file) + filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) + val cometReader = new BatchReader( + conf, + file, + footer, + batchSize, + readDataSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + val taskContext = Option(TaskContext.get) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) + return cometReader + } catch { + case e: Throwable if preFetchEnabled => + // Keep original exception + cometReaderExceptionMap.put(file, e) + } + null + } + + override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { + val cometReader = if (!preFetchEnabled) { + // Prefetch is not enabled, create comet reader and initiate it. + val cometReader = buildCometReader(file) + cometReader.init() + + cometReader + } else { + // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
+ // It is possibly we got an exception like `FileNotFoundException` and we need to throw it + // now to let Spark handle it. + val reader = cometReaders.next() + val exception = cometReaderExceptionMap.get(file) + exception.foreach(e => throw e) + + if (reader == null) { + throw new CometRuntimeException(s"Cannot find comet file reader for $file") + } + reader + } + CometPartitionReader(cometReader) + } + + def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { + val sharedConf = broadcastedConf.value.value + val footer = FooterReader.readFooter(sharedConf, file) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( + file, + readDataSchema, + sharedConf, + footerFileMetaData, + datetimeRebaseModeInRead) + + val pushed = if (parquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + (datetimeRebaseSpec, footer, pushed) + } + + override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = + throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") + + /** + * A simple adapter on Comet's [[BatchReader]]. + */ + protected case class CometPartitionReader(reader: BatchReader) + extends PartitionReader[ColumnarBatch] { + + override def next(): Boolean = { + reader.nextBatch() + } + + override def get(): ColumnarBatch = { + reader.currentBatch() + } + + override def close(): Unit = { + reader.close() + } + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala new file mode 100644 index 000000000..3cd434418 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/parquet/ParquetFilters.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate, Period} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus + +/** + * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove + * this duplication + * + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class ParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStringPredicate: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + datetimeRebaseSpec: RebaseSpec) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + // Parquet only supports predicate push-down for non-repeated primitive types. + // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for + // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) + case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getLogicalTypeAnnotation, + p.getPrimitiveTypeName, + p.getTypeLength))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. 
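+        // For illustration: a column `person: STRUCT<name: STRING, age: INT>` therefore yields two
+        // primitive fields, with fieldNames Array("person", "name") and Array("person", "age").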
+ case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + logicalTypeAnnotation: LogicalTypeAnnotation, + primitiveTypeName: PrimitiveTypeName, + length: Int) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) + private val ParquetByteType = + ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) + private val ParquetShortType = + ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) + private val ParquetStringType = + ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) + private val ParquetDateType = + ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) + private val ParquetTimestampMicrosType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) + private val ParquetTimestampMillisType = + ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) + + private def dateToDays(date: Any): Int = { + val gregorianDays = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) + case _ => gregorianDays + } + } + + private def timestampToMicros(v: Any): JLong = { + val gregorianMicros = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + datetimeRebaseSpec.mode match { + case LegacyBehaviorPolicy.LEGACY => + rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) + case _ => gregorianMicros + } + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = 
decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = DateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private def toIntValue(v: Any): Integer = { + Option(v) + .map { + case p: Period => IntervalUtils.periodToMonths(p) + case n => n.asInstanceOf[Number].intValue + } + .map(_.asInstanceOf[Integer]) + .orNull + } + + private def toLongValue(v: Any): JLong = v match { + case d: Duration => IntervalUtils.durationToMicros(d) + case l => l.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + 
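+    // Illustrative example: for a DECIMAL(9, 2) column stored as INT32, a source filter such as
+    // LessThan("price", new java.math.BigDecimal("12.34")) is pushed down on the unscaled value,
+    // i.e. FilterApi.lt(intColumn(Array("price")), 1234: Integer); wider decimals use the INT64 or
+    // FIXED_LEN_BYTE_ARRAY cases below, the latter with sign-extended bytes from decimalToByteArray.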
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if 
pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeInPredicate: PartialFunction[ + ParquetSchemaType, + (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toIntValue(_).toInt).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case 
ParquetLongType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(toLongValue).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetFloatType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), + FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) + + case ParquetDoubleType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), + FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) + + case ParquetStringType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetBinaryType => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) + + case ParquetDateType if pushDownDate => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMicros).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(timestampToMillis).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), + FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => + (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + 
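+        // For illustration: when an IN list is too large to expand into OR-ed equality predicates,
+        // its values can be folded into column statistics and pushed as a single range. E.g. an IN
+        // over decimals 0.50, 1.10 and 2.25 on a DECIMAL(10, 2)/INT64 column pushes
+        // gtEq(d, 50L) and ltEq(d, 225L).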
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) + FilterApi.and( + FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), + FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) + + case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) + if pushDownDecimal => + (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => + v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) + .foreach(statistics.updateStats) + FilterApi.and( + FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), + FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => + if (isSpark34Plus) { + value match { + // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type + // Int. We don't create a filter if the value would overflow. + case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } + } else { + // If not Spark 3.4+, we still following the old behavior as Spark does. 
+ value.isInstanceOf[Number] + } + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => + value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => + isDecimalMatched(value, decimalType) + case ParquetSchemaType( + decimalType: DecimalLogicalTypeAnnotation, + FIXED_LEN_BYTE_ARRAY, + _) => + isDecimalMatched(value, decimalType) + case _ => false + }) + } + + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + private def isDecimalMatched( + value: Any, + decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalLogicalType.getScale + case _ => false + } + + private def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) + } + + /** + * @param predicate + * the input filter predicates. Not all the predicates can be pushed down. + * @param canPartialPushDownConjuncts + * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one + * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. + * @return + * the Parquet-native filter predicates that are eligible for pushdown. + */ + private def createFilterHelper( + predicate: sources.Filter, + canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `PruneFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. 
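+    // Worked example (illustrative): createFilter(Not(And(EqualTo("a", 2), StringContains("b", "x"))))
+    // returns None when the StringContains leg cannot be converted (e.g. string push-down disabled),
+    // because under NOT it is not safe to keep only `a = 2`; the same AND at the top level would
+    // still push `a = 2` down on its own.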
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if pushDownInFilterThreshold > 0 && values.nonEmpty && + canMakeFilterOn(name, values.head) => + val fieldType = nameToParquetField(name).fieldType + val fieldNames = nameToParquetField(name).fieldNames + if (values.length <= pushDownInFilterThreshold) { + values.distinct + .flatMap { v => + makeEq.lift(fieldType).map(_(fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + } else if (canPartialPushDownConjuncts) { + val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType + val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) + if (values.contains(null)) { + Seq( + makeEq.lift(fieldType).map(_(fieldNames, null)), + makeInPredicate + .lift(fieldType) + .map(_(fieldNames, values.filter(_ != null), statistics))).flatten + .reduceLeftOption(FilterApi.or) + } else { + makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) + } + } else { + None + } + + case sources.StringStartsWith(name, prefix) + if pushDownStringPredicate && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case sources.StringEndsWith(name, suffix) + if pushDownStringPredicate && canMakeFilterOn(name, suffix) => + Option(suffix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val suffixStr = UTF8String.fromString(v) + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + override def keep(value: Binary): Boolean = { + value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) + } + }) + } + + case sources.StringContains(name, value) + if pushDownStringPredicate && canMakeFilterOn(name, value) => + Option(value).map { v => + FilterApi.userDefined( + 
binaryColumn(nameToParquetField(name).fieldNames),
+            new UserDefinedPredicate[Binary] with Serializable {
+              private val subStr = UTF8String.fromString(v)
+              override def canDrop(statistics: Statistics[Binary]): Boolean = false
+              override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false
+              override def keep(value: Binary): Boolean = {
+                value != null && UTF8String.fromBytes(value.getBytes).contains(subStr)
+              }
+            })
+        }
+
+      case _ => None
+    }
+  }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 000000000..409e1c94b
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim {
+    /**
+     * Returns a tuple of expressions for the `unhex` function.
+     */
+    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+        (unhex.child, Literal(unhex.failOnError))
+    }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala
new file mode 100644
index 000000000..8a017fabe
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.mutable.HashMap
+import scala.concurrent.duration.NANOSECONDS
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection._
+
+import org.apache.comet.{CometConf, MetricsSupport}
+import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory}
+import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat}
+
+/**
+ * Comet physical scan node for DataSource V1. Most of the code here follows Spark's
+ * [[FileSourceScanExec]].
+ */
+case class CometScanExec(
+    @transient relation: HadoopFsRelation,
+    output: Seq[Attribute],
+    requiredSchema: StructType,
+    partitionFilters: Seq[Expression],
+    optionalBucketSet: Option[BitSet],
+    optionalNumCoalescedBuckets: Option[Int],
+    dataFilters: Seq[Expression],
+    tableIdentifier: Option[TableIdentifier],
+    disableBucketedScan: Boolean = false,
+    wrapped: FileSourceScanExec)
+    extends DataSourceScanExec
+    with ShimCometScanExec
+    with CometPlan {
+
+  // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests
+  override lazy val supportsColumnar: Boolean =
+    relation.fileFormat.supportBatch(relation.sparkSession, schema)
+
+  override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes
+
+  private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
+
+  /**
+   * Send the driver-side metrics. Before calling this function, selectedPartitions has been
+   * initialized. See SPARK-26327 for more details.
+   */
+  private def sendDriverMetrics(): Unit = {
+    driverMetrics.foreach(e => metrics(e._1).add(e._2))
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(
+      sparkContext,
+      executionId,
+      metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
+  }
+
+  private def isDynamicPruningFilter(e: Expression): Boolean =
+    e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+  @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
+    val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
+    val startTime = System.nanoTime()
+    val ret =
+      relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters)
+    setFilesNumAndSizeMetric(ret, true)
+    val timeTakenMs =
+      NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs)
+    driverMetrics("metadataTime") = timeTakenMs
+    ret
+  }.toArray
+
+  // We can only determine the actual partitions at runtime when a dynamic partition filter is
+  // present. This is because such a filter relies on information that is only available at run
+  // time (for instance the keys used in the other side of a join).
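+  // For example, in a (purely illustrative) query such as
+  //   SELECT * FROM fact JOIN dim ON fact.part_col = dim.id WHERE dim.flag = true
+  // the relevant values of `part_col` are only known after the `dim` side has been evaluated,
+  // so the statically selected partitions are filtered again below with the bound dynamic
+  // predicate.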
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = wrapped.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (wrapped.outputPartitioning, wrapped.outputOrdering) + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = + if (wrapped == null) Map.empty else wrapped.metadata + + override def verboseStringWithOperatorId(): String = { + getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) + wrapped.verboseStringWithOperatorId() + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
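+    // The per-batch measurement is done in `doExecuteColumnar` below: `hasNext` wraps the call
+    // to the underlying batch iterator with `System.nanoTime()` and adds the elapsed time to
+    // this "scanTime" metric, so the timing cost is amortized over all rows of a batch.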
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + protected override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, /*f.getPath,*/ p.values) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. 
+ * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. + !isNeededForSchema(requiredSchema) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + //filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + newDataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation.sparkSession, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def 
filterUnusedDynamicPruningExpressions(
+      predicates: Seq[Expression]): Seq[Expression] = {
+    predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
+  }
+
+  override def doCanonicalize(): CometScanExec = {
+    CometScanExec(
+      relation,
+      output.map(QueryPlan.normalizeExpressions(_, output)),
+      requiredSchema,
+      QueryPlan.normalizePredicates(
+        filterUnusedDynamicPruningExpressions(partitionFilters),
+        output),
+      optionalBucketSet,
+      optionalNumCoalescedBuckets,
+      QueryPlan.normalizePredicates(dataFilters, output),
+      None,
+      disableBucketedScan,
+      null)
+  }
+}
+
+object CometScanExec {
+  def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = {
+    // TreeNode.mapProductIterator is a protected method.
+    def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = {
+      val arr = Array.ofDim[B](product.productArity)
+      var i = 0
+      while (i < arr.length) {
+        arr(i) = f(product.productElement(i))
+        i += 1
+      }
+      arr
+    }
+
+    // Replacing the relation in FileSourceScanExec by `copy` seems to cause some issues
+    // on other Spark distributions if the FileSourceScanExec constructor is changed.
+    // Using `makeCopy` avoids the issue.
+    // https://github.com/apache/arrow-datafusion-comet/issues/190
+    def transform(arg: Any): AnyRef = arg match {
+      case _: HadoopFsRelation =>
+        scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session)
+      case other: AnyRef => other
+      case null => null
+    }
+    val newArgs = mapProductIterator(scanExec, transform(_))
+    val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec]
+    val batchScanExec = CometScanExec(
+      wrapped.relation,
+      wrapped.output,
+      wrapped.requiredSchema,
+      wrapped.partitionFilters,
+      wrapped.optionalBucketSet,
+      wrapped.optionalNumCoalescedBuckets,
+      wrapped.dataFilters,
+      wrapped.tableIdentifier,
+      wrapped.disableBucketedScan,
+      wrapped)
+    scanExec.logicalLink.foreach(batchScanExec.setLogicalLink)
+    batchScanExec
+  }
+}
From db99cd4e1076bd5d69c4e9c81a99755fa4db3a5d Mon Sep 17 00:00:00 2001
From: awol2005ex
Date: Thu, 30 May 2024 11:48:12 +0800
Subject: [PATCH 02/12] delete old backup files

---
 .../parquet/CometParquetFileFormat.scala.old  | 231 -----
 ...metParquetPartitionReaderFactory.scala.old | 231 -----
 .../comet/parquet/ParquetFilters.scala.old    | 882 ------------------
 .../spark/sql/comet/CometScanExec.scala.old   | 483 ----------
 4 files changed, 1827 deletions(-)
 delete mode 100644 spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala.old
 delete mode 100644 spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala.old
 delete mode 100644 spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala.old
 delete mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala.old

diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala.old b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala.old
deleted file mode 100644
index ac871cf60..000000000
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala.old
+++ /dev/null
@@ -1,231 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.FileMetaData -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.execution.datasources.RecordReaderIterator -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{DateType, StructType, TimestampType} -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.CometConf -import org.apache.comet.MetricsSupport -import org.apache.comet.shims.ShimSQLConf -import org.apache.comet.vector.CometVector - -/** - * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's - * [[ParquetFileFormat]], but overrides: - * - * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap - * column vector in the whole-stage codegen path. - * - `supportBatch`, which simply returns true since data types should have already been checked - * in [[org.apache.comet.CometSparkSessionExtensions]] - * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. 
- */ -class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { - override def shortName(): String = "parquet" - override def toString: String = "CometParquet" - override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] - - override def vectorTypes( - requiredSchema: StructType, - partitionSchema: StructType, - sqlConf: SQLConf): Option[Seq[String]] = { - val length = requiredSchema.fields.length + partitionSchema.fields.length - Option(Seq.fill(length)(classOf[CometVector].getName)) - } - - override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true - - override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val sqlConf = sparkSession.sessionState.conf - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - val isCaseSensitive = sqlConf.caseSensitiveAnalysis - val useFieldId = CometParquetUtils.readFieldId(sqlConf) - val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - val pushDownDate = sqlConf.parquetFilterPushDownDate - val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp - val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - val optionsMap = CaseInsensitiveMap[String](options) - val parquetOptions = new ParquetOptions(optionsMap, sqlConf) - val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead - val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - (file: PartitionedFile) => { - val sharedConf = broadcastedHadoopConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - requiredSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can - // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a - // `flatMap` is used here. 
- .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) - - val batchReader = new BatchReader( - sharedConf, - file, - footer, - capacity, - requiredSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val iter = new RecordReaderIterator(batchReader) - try { - batchReader.init() - iter.asInstanceOf[Iterator[InternalRow]] - } catch { - case e: Throwable => - iter.close() - throw e - } - } - } -} - -object CometParquetFileFormat extends Logging { - - /** - * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` - */ - def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) - hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) - hadoopConf.setBoolean( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, - sqlConf.nestedSchemaPruningEnabled) - hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) - - // Sets flags for `ParquetToSparkSchemaConverter` - hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp) - - // Comet specific configs - hadoopConf.setBoolean( - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) - hadoopConf.setBoolean( - CometConf.COMET_USE_DECIMAL_128.key, - CometConf.COMET_USE_DECIMAL_128.get()) - hadoopConf.setBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) - } - - def getDatetimeRebaseSpec( - file: PartitionedFile, - sparkSchema: StructType, - sharedConf: Configuration, - footerFileMetaData: FileMetaData, - datetimeRebaseModeInRead: String): RebaseSpec = { - val exceptionOnRebase = sharedConf.getBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) - var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) - val hasDateOrTimestamp = sparkSchema.exists(f => - f.dataType match { - case DateType | TimestampType => true - case _ => false - }) - - if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { - if (exceptionOnRebase) { - logWarning( - s"""Found Parquet file $file that could potentially contain dates/timestamps that were - written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase - and return these according to the new Proleptic Gregorian calendar, Comet will throw - exception when reading them. If you want to read them as it is according to the hybrid - Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to - false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian - calendar, please disable Comet for this query.""") - } else { - // do not throw exception on rebase - read as it is - datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) - } - } - - datetimeRebaseSpec - } -} diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala.old b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala.old deleted file mode 100644 index 693af125b..000000000 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala.old +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters -import scala.collection.mutable - -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.ParquetMetadata -import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.connector.read.PartitionReader -import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.{CometConf, CometRuntimeException} -import org.apache.comet.shims.ShimSQLConf - -case class CometParquetPartitionReaderFactory( - @transient sqlConf: SQLConf, - broadcastedConf: Broadcast[SerializableConfiguration], - readDataSchema: StructType, - partitionSchema: StructType, - filters: Array[Filter], - options: ParquetOptions, - metrics: Map[String, SQLMetric]) - extends FilePartitionReaderFactory - with ShimSQLConf - with Logging { - - private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val useFieldId = CometParquetUtils.readFieldId(sqlConf) - private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - private val pushDownDate = sqlConf.parquetFilterPushDownDate - private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp - private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead - private val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - // This is only called at executor on a Broadcast variable, so we don't want it to be - // materialized at driver. - @transient private lazy val preFetchEnabled = { - val conf = broadcastedConf.value.value - - conf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - } - - private var cometReaders: Iterator[BatchReader] = _ - private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() - - // TODO: we may want to revisit this as we're going to only support flat types at the beginning - override def supportColumnarReads(partition: InputPartition): Boolean = true - - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - if (preFetchEnabled) { - val filePartition = partition.asInstanceOf[FilePartition] - val conf = broadcastedConf.value.value - - val threadNum = conf.getInt( - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) - val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) - - this.cometReaders = filePartition.files - .map { file => - // `init()` call is deferred to when the prefetch task begins. - // Otherwise we will hold too many resources for readers which are not ready - // to prefetch. - val cometReader = buildCometReader(file) - if (cometReader != null) { - cometReader.submitPrefetchTask(prefetchThreadPool) - } - - cometReader - } - .toSeq - .toIterator - } - - super.createColumnarReader(partition) - } - - override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") - - private def buildCometReader(file: PartitionedFile): BatchReader = { - val conf = broadcastedConf.value.value - - try { - val (datetimeRebaseSpec, footer, filters) = getFilter(file) - filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) - val cometReader = new BatchReader( - conf, - file, - footer, - batchSize, - readDataSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val taskContext = Option(TaskContext.get) - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) - return cometReader - } catch { - case e: Throwable if preFetchEnabled => - // Keep original exception - cometReaderExceptionMap.put(file, e) - } - null - } - - override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val cometReader = if (!preFetchEnabled) { - // Prefetch is not enabled, create comet reader and initiate it. - val cometReader = buildCometReader(file) - cometReader.init() - - cometReader - } else { - // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
- // It is possibly we got an exception like `FileNotFoundException` and we need to throw it - // now to let Spark handle it. - val reader = cometReaders.next() - val exception = cometReaderExceptionMap.get(file) - exception.foreach(e => throw e) - - if (reader == null) { - throw new CometRuntimeException(s"Cannot find comet file reader for $file") - } - reader - } - CometPartitionReader(cometReader) - } - - def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { - val sharedConf = broadcastedConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - readDataSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - (datetimeRebaseSpec, footer, pushed) - } - - override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") - - /** - * A simple adapter on Comet's [[BatchReader]]. - */ - protected case class CometPartitionReader(reader: BatchReader) - extends PartitionReader[ColumnarBatch] { - - override def next(): Boolean = { - reader.nextBatch() - } - - override def get(): ColumnarBatch = { - reader.currentBatch() - } - - override def close(): Unit = { - reader.close() - } - } -} diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala.old b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala.old deleted file mode 100644 index 5994dfb41..000000000 --- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala.old +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.comet.parquet - -import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} -import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, Period} -import java.util.Locale - -import scala.collection.JavaConverters.asScalaBufferConverter - -import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} -import org.apache.parquet.filter2.predicate._ -import org.apache.parquet.filter2.predicate.SparkFilterApi._ -import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} -import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.parquet.schema.Type.Repetition -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources -import org.apache.spark.unsafe.types.UTF8String - -import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus - -/** - * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove - * this duplication - * - * Some utility function to convert Spark data source filters to Parquet filters. - */ -class ParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStringPredicate: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { - // A map which contains parquet field name and data type, if predicate push down applies. - // - // Each key in `nameToParquetField` represents a column; `dots` are used as separators for - // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. - // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. - private val nameToParquetField: Map[String, ParquetPrimitiveField] = { - // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. - // `parentFieldNames` is used to keep track of the current nested level when traversing. - def getPrimitiveFields( - fields: Seq[Type], - parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { - fields.flatMap { - // Parquet only supports predicate push-down for non-repeated primitive types. - // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for - // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) - case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => - Some( - ParquetPrimitiveField( - fieldNames = parentFieldNames :+ p.getName, - fieldType = ParquetSchemaType( - p.getLogicalTypeAnnotation, - p.getPrimitiveTypeName, - p.getTypeLength))) - // Note that when g is a `Struct`, `g.getOriginalType` is `null`. - // When g is a `Map`, `g.getOriginalType` is `MAP`. - // When g is a `List`, `g.getOriginalType` is `LIST`. 
- case g: GroupType if g.getOriginalType == null => - getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) - // Parquet only supports push-down for primitive types; as a result, Map and List types - // are removed. - case _ => None - } - } - - val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => - (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) - } - if (caseSensitive) { - primitiveFields.toMap - } else { - // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive - // mode, just skip pushdown for these fields, they will trigger Exception when reading, - // See: SPARK-25132. - val dedupPrimitiveFields = - primitiveFields - .groupBy(_._1.toLowerCase(Locale.ROOT)) - .filter(_._2.size == 1) - .mapValues(_.head._2) - CaseInsensitiveMap(dedupPrimitiveFields.toMap) - } - } - - /** - * Holds a single primitive field information stored in the underlying parquet file. - * - * @param fieldNames - * a field name as an array of string multi-identifier in parquet file - * @param fieldType - * field type related info in parquet file - */ - private case class ParquetPrimitiveField( - fieldNames: Array[String], - fieldType: ParquetSchemaType) - - private case class ParquetSchemaType( - logicalTypeAnnotation: LogicalTypeAnnotation, - primitiveTypeName: PrimitiveTypeName, - length: Int) - - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) - private val ParquetByteType = - ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) - private val ParquetShortType = - ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) - private val ParquetLongType = ParquetSchemaType(null, INT64, 0) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) - private val ParquetStringType = - ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) - private val ParquetDateType = - ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) - private val ParquetTimestampMicrosType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) - private val ParquetTimestampMillisType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) - - private def dateToDays(date: Any): Int = { - val gregorianDays = date match { - case d: Date => DateTimeUtils.fromJavaDate(d) - case ld: LocalDate => DateTimeUtils.localDateToDays(ld) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) - case _ => gregorianDays - } - } - - private def timestampToMicros(v: Any): JLong = { - val gregorianMicros = v match { - case i: Instant => DateTimeUtils.instantToMicros(i) - case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => - rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) - case _ => gregorianMicros - } - } - - private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() - - private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() - - private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { - val decimalBuffer = new Array[Byte](numBytes) - val bytes = 
decimal.unscaledValue().toByteArray - - val fixedLengthBytes = if (bytes.length == numBytes) { - bytes - } else { - val signByte = if (bytes.head < 0) -1: Byte else 0: Byte - java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) - System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) - decimalBuffer - } - Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) - } - - private def timestampToMillis(v: Any): JLong = { - val micros = timestampToMicros(v) - val millis = DateTimeUtils.microsToMillis(micros) - millis.asInstanceOf[JLong] - } - - private def toIntValue(v: Any): Integer = { - Option(v) - .map { - case p: Period => IntervalUtils.periodToMonths(p) - case n => n.asInstanceOf[Number].intValue - } - .map(_.asInstanceOf[Integer]) - .orNull - } - - private def toLongValue(v: Any): JLong = v match { - case d: Duration => IntervalUtils.durationToMicros(d) - case l => l.asInstanceOf[JLong] - } - - private val makeEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeNotEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeLt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) - - 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeLtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if 
pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeInPredicate: PartialFunction[ - ParquetSchemaType, - (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toIntValue(_).toInt).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case 
ParquetLongType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toLongValue).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetFloatType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), - FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) - - case ParquetDoubleType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), - FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) - - case ParquetStringType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetBinaryType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetDateType if pushDownDate => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMicros).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMillis).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) - } - - // Returns filters that can be pushed down when reading Parquet files. - def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { - filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) - } - - private def convertibleFiltersHelper( - predicate: sources.Filter, - canPartialPushDown: Boolean): Option[sources.Filter] = { - predicate match { - case sources.And(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) - case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) - case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) - case _ => None - } - - case sources.Or(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { - None - } else { - Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) - } - case sources.Not(pred) => - val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - resultOptional.map(sources.Not) - - case other => - if (createFilter(other).isDefined) { - Some(other) - } else { - None - } - } - } - - /** - * Converts data sources filters to Parquet filter predicates. - */ - def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { - createFilterHelper(predicate, canPartialPushDownConjuncts = true) - } - - // Parquet's type in the given file should be matched to the value's type - // in the pushed filter in order to push down the filter to Parquet. - private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToParquetField(name).fieldType match { - case ParquetBooleanType => value.isInstanceOf[JBoolean] - case ParquetByteType | ParquetShortType | ParquetIntegerType => - if (isSpark34Plus) { - value match { - // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type - // Int. We don't create a filter if the value would overflow. - case _: JByte | _: JShort | _: Integer => true - case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue - case _ => false - } - } else { - // If not Spark 3.4+, we still following the old behavior as Spark does. 
- value.isInstanceOf[Number] - } - case ParquetLongType => value.isInstanceOf[JLong] - case ParquetFloatType => value.isInstanceOf[JFloat] - case ParquetDoubleType => value.isInstanceOf[JDouble] - case ParquetStringType => value.isInstanceOf[String] - case ParquetBinaryType => value.isInstanceOf[Array[Byte]] - case ParquetDateType => - value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] - case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType( - decimalType: DecimalLogicalTypeAnnotation, - FIXED_LEN_BYTE_ARRAY, - _) => - isDecimalMatched(value, decimalType) - case _ => false - }) - } - - // Decimal type must make sure that filter value's scale matched the file. - // If doesn't matched, which would cause data corruption. - private def isDecimalMatched( - value: Any, - decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { - case decimal: JBigDecimal => - decimal.scale == decimalLogicalType.getScale - case _ => false - } - - private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) - } - - /** - * @param predicate - * the input filter predicates. Not all the predicates can be pushed down. - * @param canPartialPushDownConjuncts - * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one - * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. - * @return - * the Parquet-native filter predicates that are eligible for pushdown. - */ - private def createFilterHelper( - predicate: sources.Filter, - canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `PruneFilters` rule for details. - - // Hyukjin: - // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. - // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. - // The reason why I did this is, that the actual Parquet filter checks null-safe equality - // comparison. - // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because - // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. - // Probably I missed something and obviously this should be changed. 
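For readers skimming this hunk: the `predicate match` that follows maps each supported `sources.Filter` onto a Parquet `FilterPredicate` through the `make*` builders defined earlier in the file. A minimal, self-contained sketch of the same idea, using Parquet's public `FilterApi` directly and assuming a single INT32 column named "x" (the real code resolves column paths and types from the Parquet schema instead), could look like this:

    import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
    import org.apache.spark.sql.sources

    // Hypothetical helper, not part of this patch: convert a few Spark source
    // filters into Parquet predicates for an INT32 column called "x".
    def toParquetPredicate(filter: sources.Filter): Option[FilterPredicate] = filter match {
      case sources.EqualTo("x", v: Int) =>
        Some(FilterApi.eq(FilterApi.intColumn("x"), Integer.valueOf(v)))
      case sources.GreaterThan("x", v: Int) =>
        Some(FilterApi.gt(FilterApi.intColumn("x"), Integer.valueOf(v)))
      case sources.IsNull("x") =>
        // Parquet treats eq(column, null) as an "is null" predicate.
        Some(FilterApi.eq(FilterApi.intColumn("x"), null.asInstanceOf[Integer]))
      case _ => None // anything unsupported is simply not pushed down
    }

Filters that yield None here are still evaluated by Spark after the scan, which is why dropping them from the pushdown is safe for row-level correctness.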
- - predicate match { - case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - - case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.And(lhs, rhs) => - // At here, it is not safe to just convert one side and remove the other side - // if we do not understand what the parent filters are. - // - // Here is an example used to explain the reason. - // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to - // convert b in ('1'). If we only convert a = 2, we will end up with a filter - // NOT(a = 2), which will generate wrong results. - // - // Pushing one side of AND down is only safe to do at the top level or in the child - // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate - // can be safely removed. - val lhsFilterOption = - createFilterHelper(lhs, canPartialPushDownConjuncts) - val rhsFilterOption = - createFilterHelper(rhs, canPartialPushDownConjuncts) - - (lhsFilterOption, rhsFilterOption) match { - case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) - case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) - case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) - case _ => None - } - - case sources.Or(lhs, rhs) => - // The Or predicate is convertible when both of its children can be pushed down. - // That is to say, if one/both of the children can be partially pushed down, the Or - // predicate can be partially pushed down as well. - // - // Here is an example used to explain the reason. - // Let's say we have - // (a1 AND a2) OR (b1 AND b2), - // a1 and b1 is convertible, while a2 and b2 is not. 
- // The predicate can be converted as - // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) - // As per the logical in And predicate, we can push down (a1 OR b1). - for { - lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) - rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case sources.Not(pred) => - createFilterHelper(pred, canPartialPushDownConjuncts = false) - .map(FilterApi.not) - - case sources.In(name, values) - if pushDownInFilterThreshold > 0 && values.nonEmpty && - canMakeFilterOn(name, values.head) => - val fieldType = nameToParquetField(name).fieldType - val fieldNames = nameToParquetField(name).fieldNames - if (values.length <= pushDownInFilterThreshold) { - values.distinct - .flatMap { v => - makeEq.lift(fieldType).map(_(fieldNames, v)) - } - .reduceLeftOption(FilterApi.or) - } else if (canPartialPushDownConjuncts) { - val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType - val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) - if (values.contains(null)) { - Seq( - makeEq.lift(fieldType).map(_(fieldNames, null)), - makeInPredicate - .lift(fieldType) - .map(_(fieldNames, values.filter(_ != null), statistics))).flatten - .reduceLeftOption(FilterApi.or) - } else { - makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) - } - } else { - None - } - - case sources.StringStartsWith(name, prefix) - if pushDownStringPredicate && canMakeFilterOn(name, prefix) => - Option(prefix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val strToBinary = Binary.fromReusedByteArray(v.getBytes) - private val size = strToBinary.length - - override def canDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 - } - - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 - } - - override def keep(value: Binary): Boolean = { - value != null && UTF8String - .fromBytes(value.getBytes) - .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) - } - }) - } - - case sources.StringEndsWith(name, suffix) - if pushDownStringPredicate && canMakeFilterOn(name, suffix) => - Option(suffix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val suffixStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) - } - }) - } - - case sources.StringContains(name, value) - if pushDownStringPredicate && canMakeFilterOn(name, value) => - Option(value).map { v => - FilterApi.userDefined( - 
binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val subStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) - } - }) - } - - case _ => None - } - } -} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala.old b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala.old deleted file mode 100644 index 14a664108..000000000 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala.old +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet - -import scala.collection.mutable.HashMap -import scala.concurrent.duration.NANOSECONDS -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection._ - -import org.apache.comet.{CometConf, MetricsSupport} -import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} - -/** - * Comet physical scan node for DataSource V1. 
Most of the code here follow Spark's - * [[FileSourceScanExec]], - */ -case class CometScanExec( - @transient relation: HadoopFsRelation, - output: Seq[Attribute], - requiredSchema: StructType, - partitionFilters: Seq[Expression], - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - dataFilters: Seq[Expression], - tableIdentifier: Option[TableIdentifier], - disableBucketedScan: Boolean = false, - wrapped: FileSourceScanExec) - extends DataSourceScanExec - with ShimCometScanExec - with CometPlan { - - // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests - override lazy val supportsColumnar: Boolean = - relation.fileFormat.supportBatch(relation.sparkSession, schema) - - override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes - - private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - private def sendDriverMetrics(): Unit = { - driverMetrics.foreach(e => metrics(e._1).add(e._2)) - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates( - sparkContext, - executionId, - metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) - } - - private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - - @transient lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - val startTime = System.nanoTime() - val ret = - relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) - setFilesNumAndSizeMetric(ret, true) - val timeTakenMs = - NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) - driverMetrics("metadataTime") = timeTakenMs - ret - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). 
- @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) - - if (dynamicPartitionFilters.nonEmpty) { - val startTime = System.nanoTime() - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, false) - val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 - driverMetrics("pruningTime") = timeTakenMs - ret - } else { - selectedPartitions - } - } - - // exposed for testing - lazy val bucketedScan: Boolean = wrapped.bucketedScan - - override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = - (wrapped.outputPartitioning, wrapped.outputOrdering) - - @transient - private lazy val pushedDownFilters = { - val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - } - - override lazy val metadata: Map[String, String] = - if (wrapped == null) Map.empty else wrapped.metadata - - override def verboseStringWithOperatorId(): String = { - getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) - wrapped.verboseStringWithOperatorId() - } - - lazy val inputRDD: RDD[InternalRow] = { - val options = relation.options + - (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = requiredSchema, - filters = pushedDownFilters, - options = options, - hadoopConf = - relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - - val readRDD = if (bucketedScan) { - createBucketedReadRDD( - relation.bucketSpec.get, - readFile, - dynamicallySelectedPartitions, - relation) - } else { - createReadRDD(readFile, dynamicallySelectedPartitions, relation) - } - sendDriverMetrics() - readRDD - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - inputRDD :: Nil - } - - /** Helper for computing total number and size of files in selected partitions. */ - private def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { - driverMetrics("numFiles") = filesNum - driverMetrics("filesSize") = filesSize - } else { - driverMetrics("staticFilesNum") = filesNum - driverMetrics("staticFilesSize") = filesSize - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions") = partitions.length - } - } - - override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { - // Tracking scan time has overhead, we can't afford to do it for each row, and can only do - // it for each batch. 
- if (supportsColumnar) { - Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { - None - } - } ++ { - relation.fileFormat match { - case f: MetricsSupport => f.initMetrics(sparkContext) - case _ => Map.empty - } - } - - protected override def doExecute(): RDD[InternalRow] = { - ColumnarToRowExec(this).doExecute() - } - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. - val startNs = System.nanoTime() - val res = batches.hasNext - scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) - res - } - - override def next(): ColumnarBatch = { - val batch = batches.next() - numOutputRows += batch.numRows() - batch - } - } - } - } - - override def executeCollect(): Array[InternalRow] = { - ColumnarToRowExec(this).executeCollect() - } - - override val nodeName: String = - s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" - - /** - * Create an RDD for bucketed reads. The non-bucketed variant of this function is - * [[createReadRDD]]. - * - * The algorithm is pretty simple: each RDD partition being returned should include all the - * files with the same bucket id from all the given Hive partitions. - * - * @param bucketSpec - * the bucketing spec. - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createBucketedReadRDD( - bucketSpec: BucketSpec, - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val filesGroupedToBuckets = - selectedPartitions - .flatMap { p => - p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) - } - } - .groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath.toString()).getName) - .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) - } - - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - filesGroupedToBuckets.filter { f => - bucketSet.get(f._1) - } - } else { - filesGroupedToBuckets - } - - val filePartitions = optionalNumCoalescedBuckets - .map { numCoalescedBuckets => - logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") - val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) - Seq.tabulate(numCoalescedBuckets) { bucketId => - val partitionedFiles = coalescedBuckets - .get(bucketId) - .map { - _.values.flatten.toArray - } - .getOrElse(Array.empty) - FilePartition(bucketId, partitionedFiles) - } - } - .getOrElse { - Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) - } - } - - prepareRDD(fsRelation, readFile, filePartitions) - } - - /** - * Create an RDD for non-bucketed reads. The bucketed variant of this function is - * [[createBucketedReadRDD]]. - * - * @param readFile - * a function to read each (part of a) file. 
- * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = - FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - - // Filter files with bucket pruning if possible - val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled - val shouldProcess: Path => Boolean = optionalBucketSet match { - case Some(bucketSet) if bucketingEnabled => - // Do not prune the file if bucket file name is invalid - filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) - case _ => - _ => true - } - - val splitFiles = selectedPartitions - .flatMap { partition => - partition.files.flatMap { file => - // getPath() is very expensive so we only want to call it once in this block: - val filePath = file.getPath - - if (shouldProcess(filePath)) { - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, - relation.options, - filePath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. - !isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) - } else { - Seq.empty - } - } - } - .sortBy(_.length)(implicitly[Ordering[Long]].reverse) - - prepareRDD( - fsRelation, - readFile, - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) - } - - private def prepareRDD( - fsRelation: HadoopFsRelation, - readFile: (PartitionedFile) => Iterator[InternalRow], - partitions: Seq[FilePartition]): RDD[InternalRow] = { - val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) - val prefetchEnabled = hadoopConf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - - val sqlConf = fsRelation.sparkSession.sessionState.conf - if (prefetchEnabled) { - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedConf = - fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val partitionReaderFactory = CometParquetPartitionReaderFactory( - sqlConf, - broadcastedConf, - requiredSchema, - relation.partitionSchema, - pushedDownFilters.toArray, - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), - metrics) - - newDataSourceRDD( - fsRelation.sparkSession.sparkContext, - partitions.map(Seq(_)), - partitionReaderFactory, - true, - Map.empty) - } else { - newFileScanRDD( - fsRelation.sparkSession, - readFile, - partitions, - new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) - } - } - - // Filters unused DynamicPruningExpression expressions - one which has been replaced - // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning - private def 
filterUnusedDynamicPruningExpressions( - predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) - } - - override def doCanonicalize(): CometScanExec = { - CometScanExec( - relation, - output.map(QueryPlan.normalizeExpressions(_, output)), - requiredSchema, - QueryPlan.normalizePredicates( - filterUnusedDynamicPruningExpressions(partitionFilters), - output), - optionalBucketSet, - optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), - None, - disableBucketedScan, - null) - } -} - -object CometScanExec { - def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - // TreeNode.mapProductIterator is protected method. - def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { - val arr = Array.ofDim[B](product.productArity) - var i = 0 - while (i < arr.length) { - arr(i) = f(product.productElement(i)) - i += 1 - } - arr - } - - // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues - // on other Spark distributions if FileSourceScanExec constructor is changed. - // Using `makeCopy` to avoid the issue. - // https://github.com/apache/arrow-datafusion-comet/issues/190 - def transform(arg: Any): AnyRef = arg match { - case _: HadoopFsRelation => - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) - case other: AnyRef => other - case null => null - } - val newArgs = mapProductIterator(scanExec, transform(_)) - val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] - val batchScanExec = CometScanExec( - wrapped.relation, - wrapped.output, - wrapped.requiredSchema, - wrapped.partitionFilters, - wrapped.optionalBucketSet, - wrapped.optionalNumCoalescedBuckets, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - wrapped) - scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) - batchScanExec - } -} From 6c8fc4f378723de1f8bb0b5dd7fde3049f318b0c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 07:56:47 -0600 Subject: [PATCH 03/12] remove centos support - unrelated to supporting Spark 3.5 --- core/Cross.toml | 2 -- core/comet_build_env_centos7.dockerfile | 36 ------------------------- 2 files changed, 38 deletions(-) delete mode 100644 core/Cross.toml delete mode 100644 core/comet_build_env_centos7.dockerfile diff --git a/core/Cross.toml b/core/Cross.toml deleted file mode 100644 index cbf43363b..000000000 --- a/core/Cross.toml +++ /dev/null @@ -1,2 +0,0 @@ -[target.x86_64-unknown-linux-gnu] -image = "comet_build_env_centos7:1.0" \ No newline at end of file diff --git a/core/comet_build_env_centos7.dockerfile b/core/comet_build_env_centos7.dockerfile deleted file mode 100644 index 026facd49..000000000 --- a/core/comet_build_env_centos7.dockerfile +++ /dev/null @@ -1,36 +0,0 @@ -FROM centos:7 -RUN ulimit -n 65536 - -# install common tools -#RUN yum update -y -RUN yum install -y centos-release-scl epel-release -RUN rm -f /var/lib/rpm/__db* && rpm --rebuilddb -RUN yum clean all && rm -rf /var/cache/yum -RUN yum makecache -RUN yum install -y -v git -RUN yum install -y -v libzip unzip wget cmake3 openssl-devel -# install protoc -RUN wget -O /protobuf-21.7-linux-x86_64.zip https://github.com/protocolbuffers/protobuf/releases/download/v21.7/protoc-21.7-linux-x86_64.zip -RUN mkdir /protobuf-bin && (cd /protobuf-bin && unzip /protobuf-21.7-linux-x86_64.zip) -RUN echo 'export PATH="$PATH:/protobuf-bin/bin"' >> ~/.bashrc - - -# install gcc-11 -RUN 
yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++ -RUN echo '. /opt/rh/devtoolset-11/enable' >> ~/.bashrc - -# install rust nightly toolchain -RUN curl https://sh.rustup.rs > /rustup-init -RUN chmod +x /rustup-init -RUN /rustup-init -y --default-toolchain nightly-2023-08-01-x86_64-unknown-linux-gnu - -RUN echo 'source $HOME/.cargo/env' >> ~/.bashrc - -# install java -RUN yum install -y java-1.8.0-openjdk java-1.8.0-openjdk-devel -RUN echo 'export JAVA_HOME="/usr/lib/jvm/java-1.8.0-openjdk"' >> ~/.bashrc - -# install maven -RUN yum install -y rh-maven35 -RUN echo 'source /opt/rh/rh-maven35/enable' >> ~/.bashrc -RUN yum -y install gcc automake autoconf libtool make From 4ad786937f38dc2a3e60ce77d9844abb99c80343 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 08:03:56 -0600 Subject: [PATCH 04/12] remove centos buid script --- build_for_centos7.sh | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 build_for_centos7.sh diff --git a/build_for_centos7.sh b/build_for_centos7.sh deleted file mode 100644 index ccd257162..000000000 --- a/build_for_centos7.sh +++ /dev/null @@ -1,5 +0,0 @@ -docker build -t comet_build_env_centos7:1.0 -f core/comet_build_env_centos7.dockerfile -export PROFILES="-Dmaven.test.skip=true -Pjdk1.8 -Pspark-3.5 -Dscalastyle.skip -Dspotless.check.skip=true -Denforcer.skip -Prelease-centos7 -Drat.skip=true" -cd core && RUSTFLAGS="-Ctarget-cpu=native" cross build --target x86_64-unknown-linux-gnu --features nightly --release -cd .. -mvn install -Prelease -DskipTests ${PROFILES} \ No newline at end of file From df7acae6f38654850b70e07eab1833ea20e9e10d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 08:04:40 -0600 Subject: [PATCH 05/12] remove centos profile and fix merge conflict --- pom.xml | 7 ------- .../comet/execution/shuffle/CometShuffleExchangeExec.scala | 5 ----- 2 files changed, 12 deletions(-) diff --git a/pom.xml b/pom.xml index 1cce9fde1..1c0961059 100644 --- a/pom.xml +++ b/pom.xml @@ -447,13 +447,6 @@ under the License. 
- - release-centos7 - - ${project.basedir}/../core/target/x86_64-unknown-linux-gnu/release - - - Win-x86 diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 93ab2811e..b94717cdf 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -391,12 +391,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = UnsafeExternalRowSorter.createWithRecordComparator( -<<<<<<< HEAD - //StructType.fromAttributes(outputAttributes), StructType(outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))), -======= - fromAttributes(outputAttributes), ->>>>>>> apache/main recordComparatorSupplier, prefixComparator, prefixComputer, From e7c5f257f0cfbf892c1f42da1e50fd302b74445f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 08:10:35 -0600 Subject: [PATCH 06/12] reduce shim duplication --- .../parquet/CometParquetFileFormat.scala | 231 ----- .../CometParquetPartitionReaderFactory.scala | 231 ----- .../apache/comet/parquet/ParquetFilters.scala | 882 ------------------ .../parquet/CometParquetFileFormat.scala | 231 ----- .../CometParquetPartitionReaderFactory.scala | 231 ----- .../apache/comet/parquet/ParquetFilters.scala | 882 ------------------ .../spark/sql/comet/CometScanExec.scala | 483 ---------- .../spark/sql/comet/CometScanExec.scala | 483 ---------- .../parquet/CometParquetFileFormat.scala | 0 .../CometParquetPartitionReaderFactory.scala | 0 .../apache/comet/parquet/ParquetFilters.scala | 0 .../spark/sql/comet/CometScanExec.scala | 0 12 files changed, 3654 deletions(-) delete mode 100644 spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala delete mode 100644 spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala delete mode 100644 spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala delete mode 100644 spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala delete mode 100644 spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala delete mode 100644 spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala delete mode 100644 spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala delete mode 100644 spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala rename spark/src/main/{spark-3.2 => spark-pre-3.5}/org/apache/comet/parquet/CometParquetFileFormat.scala (100%) rename spark/src/main/{spark-3.2 => spark-pre-3.5}/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala (100%) rename spark/src/main/{spark-3.2 => spark-pre-3.5}/org/apache/comet/parquet/ParquetFilters.scala (100%) rename spark/src/main/{spark-3.2 => spark-pre-3.5}/org/apache/spark/sql/comet/CometScanExec.scala (100%) diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala deleted file mode 100644 index ac871cf60..000000000 --- a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetFileFormat.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) 
under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.FileMetaData -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.execution.datasources.RecordReaderIterator -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{DateType, StructType, TimestampType} -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.CometConf -import org.apache.comet.MetricsSupport -import org.apache.comet.shims.ShimSQLConf -import org.apache.comet.vector.CometVector - -/** - * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's - * [[ParquetFileFormat]], but overrides: - * - * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap - * column vector in the whole-stage codegen path. - * - `supportBatch`, which simply returns true since data types should have already been checked - * in [[org.apache.comet.CometSparkSessionExtensions]] - * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. 
- */ -class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { - override def shortName(): String = "parquet" - override def toString: String = "CometParquet" - override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] - - override def vectorTypes( - requiredSchema: StructType, - partitionSchema: StructType, - sqlConf: SQLConf): Option[Seq[String]] = { - val length = requiredSchema.fields.length + partitionSchema.fields.length - Option(Seq.fill(length)(classOf[CometVector].getName)) - } - - override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true - - override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val sqlConf = sparkSession.sessionState.conf - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - val isCaseSensitive = sqlConf.caseSensitiveAnalysis - val useFieldId = CometParquetUtils.readFieldId(sqlConf) - val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - val pushDownDate = sqlConf.parquetFilterPushDownDate - val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp - val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - val optionsMap = CaseInsensitiveMap[String](options) - val parquetOptions = new ParquetOptions(optionsMap, sqlConf) - val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead - val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - (file: PartitionedFile) => { - val sharedConf = broadcastedHadoopConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - requiredSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can - // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a - // `flatMap` is used here. 
- .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) - - val batchReader = new BatchReader( - sharedConf, - file, - footer, - capacity, - requiredSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val iter = new RecordReaderIterator(batchReader) - try { - batchReader.init() - iter.asInstanceOf[Iterator[InternalRow]] - } catch { - case e: Throwable => - iter.close() - throw e - } - } - } -} - -object CometParquetFileFormat extends Logging { - - /** - * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` - */ - def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) - hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) - hadoopConf.setBoolean( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, - sqlConf.nestedSchemaPruningEnabled) - hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) - - // Sets flags for `ParquetToSparkSchemaConverter` - hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp) - - // Comet specific configs - hadoopConf.setBoolean( - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) - hadoopConf.setBoolean( - CometConf.COMET_USE_DECIMAL_128.key, - CometConf.COMET_USE_DECIMAL_128.get()) - hadoopConf.setBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) - } - - def getDatetimeRebaseSpec( - file: PartitionedFile, - sparkSchema: StructType, - sharedConf: Configuration, - footerFileMetaData: FileMetaData, - datetimeRebaseModeInRead: String): RebaseSpec = { - val exceptionOnRebase = sharedConf.getBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) - var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) - val hasDateOrTimestamp = sparkSchema.exists(f => - f.dataType match { - case DateType | TimestampType => true - case _ => false - }) - - if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { - if (exceptionOnRebase) { - logWarning( - s"""Found Parquet file $file that could potentially contain dates/timestamps that were - written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase - and return these according to the new Proleptic Gregorian calendar, Comet will throw - exception when reading them. If you want to read them as it is according to the hybrid - Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to - false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian - calendar, please disable Comet for this query.""") - } else { - // do not throw exception on rebase - read as it is - datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) - } - } - - datetimeRebaseSpec - } -} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala deleted file mode 100644 index 693af125b..000000000 --- a/spark/src/main/spark-3.3/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters -import scala.collection.mutable - -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.ParquetMetadata -import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.connector.read.PartitionReader -import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.{CometConf, CometRuntimeException} -import org.apache.comet.shims.ShimSQLConf - -case class CometParquetPartitionReaderFactory( - @transient sqlConf: SQLConf, - broadcastedConf: Broadcast[SerializableConfiguration], - readDataSchema: StructType, - partitionSchema: StructType, - filters: Array[Filter], - options: ParquetOptions, - metrics: Map[String, SQLMetric]) - extends FilePartitionReaderFactory - with ShimSQLConf - with Logging { - - private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val useFieldId = CometParquetUtils.readFieldId(sqlConf) - private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - private val pushDownDate = sqlConf.parquetFilterPushDownDate - private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp - private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead - private val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - // This is only called at executor on a Broadcast variable, so we don't want it to be - // materialized at driver. - @transient private lazy val preFetchEnabled = { - val conf = broadcastedConf.value.value - - conf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - } - - private var cometReaders: Iterator[BatchReader] = _ - private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() - - // TODO: we may want to revisit this as we're going to only support flat types at the beginning - override def supportColumnarReads(partition: InputPartition): Boolean = true - - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - if (preFetchEnabled) { - val filePartition = partition.asInstanceOf[FilePartition] - val conf = broadcastedConf.value.value - - val threadNum = conf.getInt( - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) - val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) - - this.cometReaders = filePartition.files - .map { file => - // `init()` call is deferred to when the prefetch task begins. - // Otherwise we will hold too many resources for readers which are not ready - // to prefetch. - val cometReader = buildCometReader(file) - if (cometReader != null) { - cometReader.submitPrefetchTask(prefetchThreadPool) - } - - cometReader - } - .toSeq - .toIterator - } - - super.createColumnarReader(partition) - } - - override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") - - private def buildCometReader(file: PartitionedFile): BatchReader = { - val conf = broadcastedConf.value.value - - try { - val (datetimeRebaseSpec, footer, filters) = getFilter(file) - filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) - val cometReader = new BatchReader( - conf, - file, - footer, - batchSize, - readDataSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val taskContext = Option(TaskContext.get) - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) - return cometReader - } catch { - case e: Throwable if preFetchEnabled => - // Keep original exception - cometReaderExceptionMap.put(file, e) - } - null - } - - override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val cometReader = if (!preFetchEnabled) { - // Prefetch is not enabled, create comet reader and initiate it. - val cometReader = buildCometReader(file) - cometReader.init() - - cometReader - } else { - // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
- // It is possibly we got an exception like `FileNotFoundException` and we need to throw it - // now to let Spark handle it. - val reader = cometReaders.next() - val exception = cometReaderExceptionMap.get(file) - exception.foreach(e => throw e) - - if (reader == null) { - throw new CometRuntimeException(s"Cannot find comet file reader for $file") - } - reader - } - CometPartitionReader(cometReader) - } - - def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { - val sharedConf = broadcastedConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - readDataSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - (datetimeRebaseSpec, footer, pushed) - } - - override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") - - /** - * A simple adapter on Comet's [[BatchReader]]. - */ - protected case class CometPartitionReader(reader: BatchReader) - extends PartitionReader[ColumnarBatch] { - - override def next(): Boolean = { - reader.nextBatch() - } - - override def get(): ColumnarBatch = { - reader.currentBatch() - } - - override def close(): Unit = { - reader.close() - } - } -} diff --git a/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala deleted file mode 100644 index 5994dfb41..000000000 --- a/spark/src/main/spark-3.3/org/apache/comet/parquet/ParquetFilters.scala +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.comet.parquet - -import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} -import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, Period} -import java.util.Locale - -import scala.collection.JavaConverters.asScalaBufferConverter - -import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} -import org.apache.parquet.filter2.predicate._ -import org.apache.parquet.filter2.predicate.SparkFilterApi._ -import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} -import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.parquet.schema.Type.Repetition -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources -import org.apache.spark.unsafe.types.UTF8String - -import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus - -/** - * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove - * this duplication - * - * Some utility function to convert Spark data source filters to Parquet filters. - */ -class ParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStringPredicate: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { - // A map which contains parquet field name and data type, if predicate push down applies. - // - // Each key in `nameToParquetField` represents a column; `dots` are used as separators for - // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. - // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. - private val nameToParquetField: Map[String, ParquetPrimitiveField] = { - // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. - // `parentFieldNames` is used to keep track of the current nested level when traversing. - def getPrimitiveFields( - fields: Seq[Type], - parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { - fields.flatMap { - // Parquet only supports predicate push-down for non-repeated primitive types. - // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for - // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) - case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => - Some( - ParquetPrimitiveField( - fieldNames = parentFieldNames :+ p.getName, - fieldType = ParquetSchemaType( - p.getLogicalTypeAnnotation, - p.getPrimitiveTypeName, - p.getTypeLength))) - // Note that when g is a `Struct`, `g.getOriginalType` is `null`. - // When g is a `Map`, `g.getOriginalType` is `MAP`. - // When g is a `List`, `g.getOriginalType` is `LIST`. 
- case g: GroupType if g.getOriginalType == null => - getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) - // Parquet only supports push-down for primitive types; as a result, Map and List types - // are removed. - case _ => None - } - } - - val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => - (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) - } - if (caseSensitive) { - primitiveFields.toMap - } else { - // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive - // mode, just skip pushdown for these fields, they will trigger Exception when reading, - // See: SPARK-25132. - val dedupPrimitiveFields = - primitiveFields - .groupBy(_._1.toLowerCase(Locale.ROOT)) - .filter(_._2.size == 1) - .mapValues(_.head._2) - CaseInsensitiveMap(dedupPrimitiveFields.toMap) - } - } - - /** - * Holds a single primitive field information stored in the underlying parquet file. - * - * @param fieldNames - * a field name as an array of string multi-identifier in parquet file - * @param fieldType - * field type related info in parquet file - */ - private case class ParquetPrimitiveField( - fieldNames: Array[String], - fieldType: ParquetSchemaType) - - private case class ParquetSchemaType( - logicalTypeAnnotation: LogicalTypeAnnotation, - primitiveTypeName: PrimitiveTypeName, - length: Int) - - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) - private val ParquetByteType = - ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) - private val ParquetShortType = - ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) - private val ParquetLongType = ParquetSchemaType(null, INT64, 0) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) - private val ParquetStringType = - ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) - private val ParquetDateType = - ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) - private val ParquetTimestampMicrosType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) - private val ParquetTimestampMillisType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) - - private def dateToDays(date: Any): Int = { - val gregorianDays = date match { - case d: Date => DateTimeUtils.fromJavaDate(d) - case ld: LocalDate => DateTimeUtils.localDateToDays(ld) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) - case _ => gregorianDays - } - } - - private def timestampToMicros(v: Any): JLong = { - val gregorianMicros = v match { - case i: Instant => DateTimeUtils.instantToMicros(i) - case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => - rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) - case _ => gregorianMicros - } - } - - private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() - - private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() - - private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { - val decimalBuffer = new Array[Byte](numBytes) - val bytes = 
decimal.unscaledValue().toByteArray - - val fixedLengthBytes = if (bytes.length == numBytes) { - bytes - } else { - val signByte = if (bytes.head < 0) -1: Byte else 0: Byte - java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) - System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) - decimalBuffer - } - Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) - } - - private def timestampToMillis(v: Any): JLong = { - val micros = timestampToMicros(v) - val millis = DateTimeUtils.microsToMillis(micros) - millis.asInstanceOf[JLong] - } - - private def toIntValue(v: Any): Integer = { - Option(v) - .map { - case p: Period => IntervalUtils.periodToMonths(p) - case n => n.asInstanceOf[Number].intValue - } - .map(_.asInstanceOf[Integer]) - .orNull - } - - private def toLongValue(v: Any): JLong = v match { - case d: Duration => IntervalUtils.durationToMicros(d) - case l => l.asInstanceOf[JLong] - } - - private val makeEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeNotEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeLt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) - - 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeLtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if 
pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeInPredicate: PartialFunction[ - ParquetSchemaType, - (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toIntValue(_).toInt).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case 
ParquetLongType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toLongValue).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetFloatType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), - FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) - - case ParquetDoubleType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), - FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) - - case ParquetStringType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetBinaryType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetDateType if pushDownDate => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMicros).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMillis).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) - } - - // Returns filters that can be pushed down when reading Parquet files. - def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { - filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) - } - - private def convertibleFiltersHelper( - predicate: sources.Filter, - canPartialPushDown: Boolean): Option[sources.Filter] = { - predicate match { - case sources.And(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) - case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) - case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) - case _ => None - } - - case sources.Or(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { - None - } else { - Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) - } - case sources.Not(pred) => - val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - resultOptional.map(sources.Not) - - case other => - if (createFilter(other).isDefined) { - Some(other) - } else { - None - } - } - } - - /** - * Converts data sources filters to Parquet filter predicates. - */ - def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { - createFilterHelper(predicate, canPartialPushDownConjuncts = true) - } - - // Parquet's type in the given file should be matched to the value's type - // in the pushed filter in order to push down the filter to Parquet. - private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToParquetField(name).fieldType match { - case ParquetBooleanType => value.isInstanceOf[JBoolean] - case ParquetByteType | ParquetShortType | ParquetIntegerType => - if (isSpark34Plus) { - value match { - // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type - // Int. We don't create a filter if the value would overflow. - case _: JByte | _: JShort | _: Integer => true - case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue - case _ => false - } - } else { - // If not Spark 3.4+, we still following the old behavior as Spark does. 
- value.isInstanceOf[Number] - } - case ParquetLongType => value.isInstanceOf[JLong] - case ParquetFloatType => value.isInstanceOf[JFloat] - case ParquetDoubleType => value.isInstanceOf[JDouble] - case ParquetStringType => value.isInstanceOf[String] - case ParquetBinaryType => value.isInstanceOf[Array[Byte]] - case ParquetDateType => - value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] - case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType( - decimalType: DecimalLogicalTypeAnnotation, - FIXED_LEN_BYTE_ARRAY, - _) => - isDecimalMatched(value, decimalType) - case _ => false - }) - } - - // Decimal type must make sure that filter value's scale matched the file. - // If doesn't matched, which would cause data corruption. - private def isDecimalMatched( - value: Any, - decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { - case decimal: JBigDecimal => - decimal.scale == decimalLogicalType.getScale - case _ => false - } - - private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) - } - - /** - * @param predicate - * the input filter predicates. Not all the predicates can be pushed down. - * @param canPartialPushDownConjuncts - * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one - * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. - * @return - * the Parquet-native filter predicates that are eligible for pushdown. - */ - private def createFilterHelper( - predicate: sources.Filter, - canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `PruneFilters` rule for details. - - // Hyukjin: - // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. - // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. - // The reason why I did this is, that the actual Parquet filter checks null-safe equality - // comparison. - // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because - // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. - // Probably I missed something and obviously this should be changed. 
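[Editor's note] The `And`/`Or`/`Not` handling that follows is the core of the partial-pushdown behaviour described in the comments above. As a rough, self-contained illustration only (not part of the deleted file; `Pred`, `Leaf`, and `convert` are made-up names, not Comet or Spark APIs), the rule can be sketched like this:

    // Simplified sketch of the partial-pushdown rule applied by createFilterHelper.
    object PushdownSketch {
      sealed trait Pred
      case class Leaf(name: String, convertible: Boolean) extends Pred
      case class And(l: Pred, r: Pred) extends Pred
      case class Or(l: Pred, r: Pred) extends Pred
      case class Not(p: Pred) extends Pred

      // Returns the sub-predicate that may be pushed down, if any.
      def convert(p: Pred, canPartialPushDown: Boolean): Option[Pred] = p match {
        case l: Leaf => if (l.convertible) Some(l) else None
        case And(l, r) =>
          (convert(l, canPartialPushDown), convert(r, canPartialPushDown)) match {
            case (Some(lc), Some(rc)) => Some(And(lc, rc))
            // Dropping one side of an AND is only safe outside of NOT/OR.
            case (Some(lc), None) if canPartialPushDown => Some(lc)
            case (None, Some(rc)) if canPartialPushDown => Some(rc)
            case _ => None
          }
        case Or(l, r) =>
          // OR needs both sides; otherwise rows matching the dropped side would be lost.
          for {
            lc <- convert(l, canPartialPushDown)
            rc <- convert(r, canPartialPushDown)
          } yield Or(lc, rc)
        case Not(inner) =>
          // Under NOT, dropping a conjunct is no longer safe, so partial pushdown is disabled.
          convert(inner, canPartialPushDown = false).map(Not)
      }

      def main(args: Array[String]): Unit = {
        val a = Leaf("a = 2", convertible = true)
        val b = Leaf("b in ('1')", convertible = false)
        // NOT(a AND b): nothing is pushed, matching the example in the comments above.
        println(convert(Not(And(a, b)), canPartialPushDown = true)) // None
        // a AND b at the top level: only `a` is pushed.
        println(convert(And(a, b), canPartialPushDown = true))      // Some(Leaf(a = 2,true))
      }
    }

The real code below has the same shape: `And` may keep a single convertible side only while `canPartialPushDownConjuncts` is true, `Or` requires both sides to convert, and `Not` recurses with partial pushdown disabled.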
- - predicate match { - case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - - case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.And(lhs, rhs) => - // At here, it is not safe to just convert one side and remove the other side - // if we do not understand what the parent filters are. - // - // Here is an example used to explain the reason. - // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to - // convert b in ('1'). If we only convert a = 2, we will end up with a filter - // NOT(a = 2), which will generate wrong results. - // - // Pushing one side of AND down is only safe to do at the top level or in the child - // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate - // can be safely removed. - val lhsFilterOption = - createFilterHelper(lhs, canPartialPushDownConjuncts) - val rhsFilterOption = - createFilterHelper(rhs, canPartialPushDownConjuncts) - - (lhsFilterOption, rhsFilterOption) match { - case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) - case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) - case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) - case _ => None - } - - case sources.Or(lhs, rhs) => - // The Or predicate is convertible when both of its children can be pushed down. - // That is to say, if one/both of the children can be partially pushed down, the Or - // predicate can be partially pushed down as well. - // - // Here is an example used to explain the reason. - // Let's say we have - // (a1 AND a2) OR (b1 AND b2), - // a1 and b1 is convertible, while a2 and b2 is not. 
- // The predicate can be converted as - // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) - // As per the logical in And predicate, we can push down (a1 OR b1). - for { - lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) - rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case sources.Not(pred) => - createFilterHelper(pred, canPartialPushDownConjuncts = false) - .map(FilterApi.not) - - case sources.In(name, values) - if pushDownInFilterThreshold > 0 && values.nonEmpty && - canMakeFilterOn(name, values.head) => - val fieldType = nameToParquetField(name).fieldType - val fieldNames = nameToParquetField(name).fieldNames - if (values.length <= pushDownInFilterThreshold) { - values.distinct - .flatMap { v => - makeEq.lift(fieldType).map(_(fieldNames, v)) - } - .reduceLeftOption(FilterApi.or) - } else if (canPartialPushDownConjuncts) { - val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType - val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) - if (values.contains(null)) { - Seq( - makeEq.lift(fieldType).map(_(fieldNames, null)), - makeInPredicate - .lift(fieldType) - .map(_(fieldNames, values.filter(_ != null), statistics))).flatten - .reduceLeftOption(FilterApi.or) - } else { - makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) - } - } else { - None - } - - case sources.StringStartsWith(name, prefix) - if pushDownStringPredicate && canMakeFilterOn(name, prefix) => - Option(prefix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val strToBinary = Binary.fromReusedByteArray(v.getBytes) - private val size = strToBinary.length - - override def canDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 - } - - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 - } - - override def keep(value: Binary): Boolean = { - value != null && UTF8String - .fromBytes(value.getBytes) - .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) - } - }) - } - - case sources.StringEndsWith(name, suffix) - if pushDownStringPredicate && canMakeFilterOn(name, suffix) => - Option(suffix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val suffixStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) - } - }) - } - - case sources.StringContains(name, value) - if pushDownStringPredicate && canMakeFilterOn(name, value) => - Option(value).map { v => - FilterApi.userDefined( - 
binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val subStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) - } - }) - } - - case _ => None - } - } -} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala deleted file mode 100644 index ac871cf60..000000000 --- a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetFileFormat.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.FileMetaData -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.execution.datasources.RecordReaderIterator -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{DateType, StructType, TimestampType} -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.CometConf -import org.apache.comet.MetricsSupport -import org.apache.comet.shims.ShimSQLConf -import org.apache.comet.vector.CometVector - -/** - * A Comet specific Parquet format. This mostly reuse the functionalities from Spark's - * [[ParquetFileFormat]], but overrides: - * - * - `vectorTypes`, so Spark allocates [[CometVector]] instead of it's own on-heap or off-heap - * column vector in the whole-stage codegen path. 
- * - `supportBatch`, which simply returns true since data types should have already been checked - * in [[org.apache.comet.CometSparkSessionExtensions]] - * - `buildReaderWithPartitionValues`, so Spark calls Comet's Parquet reader to read values. - */ -class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with ShimSQLConf { - override def shortName(): String = "parquet" - override def toString: String = "CometParquet" - override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[CometParquetFileFormat] - - override def vectorTypes( - requiredSchema: StructType, - partitionSchema: StructType, - sqlConf: SQLConf): Option[Seq[String]] = { - val length = requiredSchema.fields.length + partitionSchema.fields.length - Option(Seq.fill(length)(classOf[CometVector].getName)) - } - - override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = true - - override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val sqlConf = sparkSession.sessionState.conf - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - val isCaseSensitive = sqlConf.caseSensitiveAnalysis - val useFieldId = CometParquetUtils.readFieldId(sqlConf) - val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - val pushDownDate = sqlConf.parquetFilterPushDownDate - val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp - val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - val optionsMap = CaseInsensitiveMap[String](options) - val parquetOptions = new ParquetOptions(optionsMap, sqlConf) - val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead - val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - (file: PartitionedFile) => { - val sharedConf = broadcastedHadoopConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - requiredSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can - // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a - // `flatMap` is used here. 
- .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) - - val batchReader = new BatchReader( - sharedConf, - file, - footer, - capacity, - requiredSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val iter = new RecordReaderIterator(batchReader) - try { - batchReader.init() - iter.asInstanceOf[Iterator[InternalRow]] - } catch { - case e: Throwable => - iter.close() - throw e - } - } - } -} - -object CometParquetFileFormat extends Logging { - - /** - * Populates Parquet related configurations from the input `sqlConf` to the `hadoopConf` - */ - def populateConf(sqlConf: SQLConf, hadoopConf: Configuration): Unit = { - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) - hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) - hadoopConf.setBoolean( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, - sqlConf.nestedSchemaPruningEnabled) - hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) - - // Sets flags for `ParquetToSparkSchemaConverter` - hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp) - - // Comet specific configs - hadoopConf.setBoolean( - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.key, - CometConf.COMET_PARQUET_ENABLE_DIRECT_BUFFER.get()) - hadoopConf.setBoolean( - CometConf.COMET_USE_DECIMAL_128.key, - CometConf.COMET_USE_DECIMAL_128.get()) - hadoopConf.setBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get()) - } - - def getDatetimeRebaseSpec( - file: PartitionedFile, - sparkSchema: StructType, - sharedConf: Configuration, - footerFileMetaData: FileMetaData, - datetimeRebaseModeInRead: String): RebaseSpec = { - val exceptionOnRebase = sharedConf.getBoolean( - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key, - CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.defaultValue.get) - var datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) - val hasDateOrTimestamp = sparkSchema.exists(f => - f.dataType match { - case DateType | TimestampType => true - case _ => false - }) - - if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LegacyBehaviorPolicy.LEGACY) { - if (exceptionOnRebase) { - logWarning( - s"""Found Parquet file $file that could potentially contain dates/timestamps that were - written in legacy hybrid Julian/Gregorian calendar. Unlike Spark 3+, which will rebase - and return these according to the new Proleptic Gregorian calendar, Comet will throw - exception when reading them. If you want to read them as it is according to the hybrid - Julian/Gregorian calendar, please set `spark.comet.exceptionOnDatetimeRebase` to - false. 
Otherwise, if you want to read them according to the new Proleptic Gregorian - calendar, please disable Comet for this query.""") - } else { - // do not throw exception on rebase - read as it is - datetimeRebaseSpec = datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED) - } - } - - datetimeRebaseSpec - } -} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala deleted file mode 100644 index 693af125b..000000000 --- a/spark/src/main/spark-3.4/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.parquet - -import scala.collection.JavaConverters -import scala.collection.mutable - -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.hadoop.ParquetInputFormat -import org.apache.parquet.hadoop.metadata.ParquetMetadata -import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.connector.read.PartitionReader -import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration - -import org.apache.comet.{CometConf, CometRuntimeException} -import org.apache.comet.shims.ShimSQLConf - -case class CometParquetPartitionReaderFactory( - @transient sqlConf: SQLConf, - broadcastedConf: Broadcast[SerializableConfiguration], - readDataSchema: StructType, - partitionSchema: StructType, - filters: Array[Filter], - options: ParquetOptions, - metrics: Map[String, SQLMetric]) - extends FilePartitionReaderFactory - with ShimSQLConf - with Logging { - - private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val useFieldId = CometParquetUtils.readFieldId(sqlConf) - private val ignoreMissingIds = CometParquetUtils.ignoreMissingIds(sqlConf) - private val pushDownDate = sqlConf.parquetFilterPushDownDate - private val pushDownTimestamp = 
sqlConf.parquetFilterPushDownTimestamp - private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - private val pushDownStringPredicate = getPushDownStringPredicate(sqlConf) - private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead - private val parquetFilterPushDown = sqlConf.parquetFilterPushDown - - // Comet specific configurations - private val batchSize = CometConf.COMET_BATCH_SIZE.get(sqlConf) - - // This is only called at executor on a Broadcast variable, so we don't want it to be - // materialized at driver. - @transient private lazy val preFetchEnabled = { - val conf = broadcastedConf.value.value - - conf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - } - - private var cometReaders: Iterator[BatchReader] = _ - private val cometReaderExceptionMap = new mutable.HashMap[PartitionedFile, Throwable]() - - // TODO: we may want to revisit this as we're going to only support flat types at the beginning - override def supportColumnarReads(partition: InputPartition): Boolean = true - - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - if (preFetchEnabled) { - val filePartition = partition.asInstanceOf[FilePartition] - val conf = broadcastedConf.value.value - - val threadNum = conf.getInt( - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.key, - CometConf.COMET_SCAN_PREFETCH_THREAD_NUM.defaultValue.get) - val prefetchThreadPool = CometPrefetchThreadPool.getOrCreateThreadPool(threadNum) - - this.cometReaders = filePartition.files - .map { file => - // `init()` call is deferred to when the prefetch task begins. - // Otherwise we will hold too many resources for readers which are not ready - // to prefetch. - val cometReader = buildCometReader(file) - if (cometReader != null) { - cometReader.submitPrefetchTask(prefetchThreadPool) - } - - cometReader - } - .toSeq - .toIterator - } - - super.createColumnarReader(partition) - } - - override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Comet doesn't support 'buildReader'") - - private def buildCometReader(file: PartitionedFile): BatchReader = { - val conf = broadcastedConf.value.value - - try { - val (datetimeRebaseSpec, footer, filters) = getFilter(file) - filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed)) - val cometReader = new BatchReader( - conf, - file, - footer, - batchSize, - readDataSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val taskContext = Option(TaskContext.get) - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close())) - return cometReader - } catch { - case e: Throwable if preFetchEnabled => - // Keep original exception - cometReaderExceptionMap.put(file, e) - } - null - } - - override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val cometReader = if (!preFetchEnabled) { - // Prefetch is not enabled, create comet reader and initiate it. - val cometReader = buildCometReader(file) - cometReader.init() - - cometReader - } else { - // If prefetch is enabled, we already tried to access the file when in `buildCometReader`. 
- // It is possibly we got an exception like `FileNotFoundException` and we need to throw it - // now to let Spark handle it. - val reader = cometReaders.next() - val exception = cometReaderExceptionMap.get(file) - exception.foreach(e => throw e) - - if (reader == null) { - throw new CometRuntimeException(s"Cannot find comet file reader for $file") - } - reader - } - CometPartitionReader(cometReader) - } - - def getFilter(file: PartitionedFile): (RebaseSpec, ParquetMetadata, Option[FilterPredicate]) = { - val sharedConf = broadcastedConf.value.value - val footer = FooterReader.readFooter(sharedConf, file) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseSpec = CometParquetFileFormat.getDatetimeRebaseSpec( - file, - readDataSchema, - sharedConf, - footerFileMetaData, - datetimeRebaseModeInRead) - - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - (datetimeRebaseSpec, footer, pushed) - } - - override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = - throw new UnsupportedOperationException("Only 'createColumnarReader' is supported.") - - /** - * A simple adapter on Comet's [[BatchReader]]. - */ - protected case class CometPartitionReader(reader: BatchReader) - extends PartitionReader[ColumnarBatch] { - - override def next(): Boolean = { - reader.nextBatch() - } - - override def get(): ColumnarBatch = { - reader.currentBatch() - } - - override def close(): Unit = { - reader.close() - } - } -} diff --git a/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala deleted file mode 100644 index 5994dfb41..000000000 --- a/spark/src/main/spark-3.4/org/apache/comet/parquet/ParquetFilters.scala +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.comet.parquet - -import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} -import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, Period} -import java.util.Locale - -import scala.collection.JavaConverters.asScalaBufferConverter - -import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} -import org.apache.parquet.filter2.predicate._ -import org.apache.parquet.filter2.predicate.SparkFilterApi._ -import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, MessageType, PrimitiveComparator, PrimitiveType, Type} -import org.apache.parquet.schema.LogicalTypeAnnotation.{DecimalLogicalTypeAnnotation, TimeUnit} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.parquet.schema.Type.Repetition -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources -import org.apache.spark.unsafe.types.UTF8String - -import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus - -/** - * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: find a way to remove - * this duplication - * - * Some utility function to convert Spark data source filters to Parquet filters. - */ -class ParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStringPredicate: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { - // A map which contains parquet field name and data type, if predicate push down applies. - // - // Each key in `nameToParquetField` represents a column; `dots` are used as separators for - // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. - // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. - private val nameToParquetField: Map[String, ParquetPrimitiveField] = { - // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. - // `parentFieldNames` is used to keep track of the current nested level when traversing. - def getPrimitiveFields( - fields: Seq[Type], - parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { - fields.flatMap { - // Parquet only supports predicate push-down for non-repeated primitive types. - // TODO(SPARK-39393): Remove extra condition when parquet added filter predicate support for - // repeated columns (https://issues.apache.org/jira/browse/PARQUET-34) - case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => - Some( - ParquetPrimitiveField( - fieldNames = parentFieldNames :+ p.getName, - fieldType = ParquetSchemaType( - p.getLogicalTypeAnnotation, - p.getPrimitiveTypeName, - p.getTypeLength))) - // Note that when g is a `Struct`, `g.getOriginalType` is `null`. - // When g is a `Map`, `g.getOriginalType` is `MAP`. - // When g is a `List`, `g.getOriginalType` is `LIST`. 
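      // Illustrative note (made-up field names): the lookup keys built a few
      // lines below are the dot-joined, quoted paths of these primitive
      // fields, so a child literally named "a.b" nested under a struct "s" is
      // keyed as "s.`a.b`"; quoteIfNeeded backquotes any part that itself
      // contains a dot, e.g. Seq("s", "a.b").map(quoteIfNeeded).mkString(".")
      // gives "s.`a.b`" rather than the ambiguous "s.a.b".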
- case g: GroupType if g.getOriginalType == null => - getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) - // Parquet only supports push-down for primitive types; as a result, Map and List types - // are removed. - case _ => None - } - } - - val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => - (field.fieldNames.toSeq.map(quoteIfNeeded).mkString("."), field) - } - if (caseSensitive) { - primitiveFields.toMap - } else { - // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive - // mode, just skip pushdown for these fields, they will trigger Exception when reading, - // See: SPARK-25132. - val dedupPrimitiveFields = - primitiveFields - .groupBy(_._1.toLowerCase(Locale.ROOT)) - .filter(_._2.size == 1) - .mapValues(_.head._2) - CaseInsensitiveMap(dedupPrimitiveFields.toMap) - } - } - - /** - * Holds a single primitive field information stored in the underlying parquet file. - * - * @param fieldNames - * a field name as an array of string multi-identifier in parquet file - * @param fieldType - * field type related info in parquet file - */ - private case class ParquetPrimitiveField( - fieldNames: Array[String], - fieldType: ParquetSchemaType) - - private case class ParquetSchemaType( - logicalTypeAnnotation: LogicalTypeAnnotation, - primitiveTypeName: PrimitiveTypeName, - length: Int) - - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0) - private val ParquetByteType = - ParquetSchemaType(LogicalTypeAnnotation.intType(8, true), INT32, 0) - private val ParquetShortType = - ParquetSchemaType(LogicalTypeAnnotation.intType(16, true), INT32, 0) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0) - private val ParquetLongType = ParquetSchemaType(null, INT64, 0) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0) - private val ParquetStringType = - ParquetSchemaType(LogicalTypeAnnotation.stringType(), BINARY, 0) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0) - private val ParquetDateType = - ParquetSchemaType(LogicalTypeAnnotation.dateType(), INT32, 0) - private val ParquetTimestampMicrosType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0) - private val ParquetTimestampMillisType = - ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0) - - private def dateToDays(date: Any): Int = { - val gregorianDays = date match { - case d: Date => DateTimeUtils.fromJavaDate(d) - case ld: LocalDate => DateTimeUtils.localDateToDays(ld) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => rebaseGregorianToJulianDays(gregorianDays) - case _ => gregorianDays - } - } - - private def timestampToMicros(v: Any): JLong = { - val gregorianMicros = v match { - case i: Instant => DateTimeUtils.instantToMicros(i) - case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) - } - datetimeRebaseSpec.mode match { - case LegacyBehaviorPolicy.LEGACY => - rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, gregorianMicros) - case _ => gregorianMicros - } - } - - private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() - - private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() - - private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { - val decimalBuffer = new Array[Byte](numBytes) - val bytes = 
decimal.unscaledValue().toByteArray - - val fixedLengthBytes = if (bytes.length == numBytes) { - bytes - } else { - val signByte = if (bytes.head < 0) -1: Byte else 0: Byte - java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) - System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) - decimalBuffer - } - Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) - } - - private def timestampToMillis(v: Any): JLong = { - val micros = timestampToMicros(v) - val millis = DateTimeUtils.microsToMillis(micros) - millis.asInstanceOf[JLong] - } - - private def toIntValue(v: Any): Integer = { - Option(v) - .map { - case p: Period => IntervalUtils.periodToMonths(p) - case n => n.asInstanceOf[Number].intValue - } - .map(_.asInstanceOf[Integer]) - .orNull - } - - private def toLongValue(v: Any): JLong = v match { - case d: Duration => IntervalUtils.durationToMicros(d) - case l => l.asInstanceOf[JLong] - } - - private val makeEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.eq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeNotEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetBooleanType => - (n: Array[String], v: Any) => 
FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(_ => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => - FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.notEq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) - } - - private val makeLt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) - - 
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeLtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGt - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if 
pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeGtEq - : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) - case ParquetLongType => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) - case ParquetFloatType => - (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) - case ParquetDoubleType => - (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) - - case ParquetStringType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case ParquetBinaryType => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case ParquetDateType if pushDownDate => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (n: Array[String], v: Any) => - FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) - } - - private val makeInPredicate: PartialFunction[ - ParquetSchemaType, - (Array[String], Array[Any], ParquetStatistics[_]) => FilterPredicate] = { - case ParquetByteType | ParquetShortType | ParquetIntegerType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toIntValue(_).toInt).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case 
ParquetLongType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(toLongValue).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetFloatType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JFloat]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(floatColumn(n), statistics.genericGetMin().asInstanceOf[JFloat]), - FilterApi.ltEq(floatColumn(n), statistics.genericGetMax().asInstanceOf[JFloat])) - - case ParquetDoubleType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JDouble]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(doubleColumn(n), statistics.genericGetMin().asInstanceOf[JDouble]), - FilterApi.ltEq(doubleColumn(n), statistics.genericGetMax().asInstanceOf[JDouble])) - - case ParquetStringType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(s => Binary.fromString(s.asInstanceOf[String])).foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetBinaryType => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(b => Binary.fromReusedByteArray(b.asInstanceOf[Array[Byte]])) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(n), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(n), statistics.genericGetMax().asInstanceOf[Binary])) - - case ParquetDateType if pushDownDate => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(dateToDays).map(_.asInstanceOf[Integer]).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMicros).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(timestampToMillis).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt32).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(intColumn(n), statistics.genericGetMin().asInstanceOf[Integer]), - FilterApi.ltEq(intColumn(n), statistics.genericGetMax().asInstanceOf[Integer])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT64, _) if pushDownDecimal => - (n: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - 
v.map(_.asInstanceOf[JBigDecimal]).map(decimalToInt64).foreach(statistics.updateStats(_)) - FilterApi.and( - FilterApi.gtEq(longColumn(n), statistics.genericGetMin().asInstanceOf[JLong]), - FilterApi.ltEq(longColumn(n), statistics.genericGetMax().asInstanceOf[JLong])) - - case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, FIXED_LEN_BYTE_ARRAY, length) - if pushDownDecimal => - (path: Array[String], v: Array[Any], statistics: ParquetStatistics[_]) => - v.map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)) - .foreach(statistics.updateStats) - FilterApi.and( - FilterApi.gtEq(binaryColumn(path), statistics.genericGetMin().asInstanceOf[Binary]), - FilterApi.ltEq(binaryColumn(path), statistics.genericGetMax().asInstanceOf[Binary])) - } - - // Returns filters that can be pushed down when reading Parquet files. - def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { - filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) - } - - private def convertibleFiltersHelper( - predicate: sources.Filter, - canPartialPushDown: Boolean): Option[sources.Filter] = { - predicate match { - case sources.And(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) - case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) - case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) - case _ => None - } - - case sources.Or(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) - if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { - None - } else { - Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) - } - case sources.Not(pred) => - val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - resultOptional.map(sources.Not) - - case other => - if (createFilter(other).isDefined) { - Some(other) - } else { - None - } - } - } - - /** - * Converts data sources filters to Parquet filter predicates. - */ - def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { - createFilterHelper(predicate, canPartialPushDownConjuncts = true) - } - - // Parquet's type in the given file should be matched to the value's type - // in the pushed filter in order to push down the filter to Parquet. - private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToParquetField(name).fieldType match { - case ParquetBooleanType => value.isInstanceOf[JBoolean] - case ParquetByteType | ParquetShortType | ParquetIntegerType => - if (isSpark34Plus) { - value match { - // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type - // Int. We don't create a filter if the value would overflow. - case _: JByte | _: JShort | _: Integer => true - case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue - case _ => false - } - } else { - // If not Spark 3.4+, we still following the old behavior as Spark does. 
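            // Illustrative example (hypothetical column name): under the 3.4+
            // branch above, sources.EqualTo("i", 3000000000L) on an INT32-backed
            // column is not pushed down because the Long value overflows Int,
            // whereas the pre-3.4 check below accepts any Number, matching
            // Spark's earlier behavior.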
- value.isInstanceOf[Number] - } - case ParquetLongType => value.isInstanceOf[JLong] - case ParquetFloatType => value.isInstanceOf[JFloat] - case ParquetDoubleType => value.isInstanceOf[JDouble] - case ParquetStringType => value.isInstanceOf[String] - case ParquetBinaryType => value.isInstanceOf[Array[Byte]] - case ParquetDateType => - value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] - case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) => - isDecimalMatched(value, decimalType) - case ParquetSchemaType( - decimalType: DecimalLogicalTypeAnnotation, - FIXED_LEN_BYTE_ARRAY, - _) => - isDecimalMatched(value, decimalType) - case _ => false - }) - } - - // Decimal type must make sure that filter value's scale matched the file. - // If doesn't matched, which would cause data corruption. - private def isDecimalMatched( - value: Any, - decimalLogicalType: DecimalLogicalTypeAnnotation): Boolean = value match { - case decimal: JBigDecimal => - decimal.scale == decimalLogicalType.getScale - case _ => false - } - - private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) - } - - /** - * @param predicate - * the input filter predicates. Not all the predicates can be pushed down. - * @param canPartialPushDownConjuncts - * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one - * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. - * @return - * the Parquet-native filter predicates that are eligible for pushdown. - */ - private def createFilterHelper( - predicate: sources.Filter, - canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `PruneFilters` rule for details. - - // Hyukjin: - // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. - // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. - // The reason why I did this is, that the actual Parquet filter checks null-safe equality - // comparison. - // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because - // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. - // Probably I missed something and obviously this should be changed. 
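    // Usage sketch (illustrative only): how a caller could drive the public
    // createFilter entry point defined above. The schema string, the column
    // name "x" and the flag values are assumptions made for the example.
    def createFilterSketch(): Unit = {
      import org.apache.parquet.schema.MessageTypeParser
      val exampleSchema = MessageTypeParser.parseMessageType(
        "message spark_schema { optional int32 x; }")
      val exampleFilters = new ParquetFilters(
        exampleSchema,
        pushDownDate = true,
        pushDownTimestamp = true,
        pushDownDecimal = true,
        pushDownStringPredicate = true,
        pushDownInFilterThreshold = 10,
        caseSensitive = false,
        datetimeRebaseSpec = RebaseSpec(LegacyBehaviorPolicy.CORRECTED))
      // An INT32 column accepts an Int literal, so this prints roughly
      // Some(gt(x, 10)); a predicate that cannot be converted returns None.
      println(exampleFilters.createFilter(sources.GreaterThan("x", 10)))
    }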
- - predicate match { - case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, null)) - - case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq - .lift(nameToParquetField(name).fieldType) - .map(_(nameToParquetField(name).fieldNames, value)) - - case sources.And(lhs, rhs) => - // At here, it is not safe to just convert one side and remove the other side - // if we do not understand what the parent filters are. - // - // Here is an example used to explain the reason. - // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to - // convert b in ('1'). If we only convert a = 2, we will end up with a filter - // NOT(a = 2), which will generate wrong results. - // - // Pushing one side of AND down is only safe to do at the top level or in the child - // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate - // can be safely removed. - val lhsFilterOption = - createFilterHelper(lhs, canPartialPushDownConjuncts) - val rhsFilterOption = - createFilterHelper(rhs, canPartialPushDownConjuncts) - - (lhsFilterOption, rhsFilterOption) match { - case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) - case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) - case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) - case _ => None - } - - case sources.Or(lhs, rhs) => - // The Or predicate is convertible when both of its children can be pushed down. - // That is to say, if one/both of the children can be partially pushed down, the Or - // predicate can be partially pushed down as well. - // - // Here is an example used to explain the reason. - // Let's say we have - // (a1 AND a2) OR (b1 AND b2), - // a1 and b1 is convertible, while a2 and b2 is not. 
- // The predicate can be converted as - // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) - // As per the logical in And predicate, we can push down (a1 OR b1). - for { - lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) - rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case sources.Not(pred) => - createFilterHelper(pred, canPartialPushDownConjuncts = false) - .map(FilterApi.not) - - case sources.In(name, values) - if pushDownInFilterThreshold > 0 && values.nonEmpty && - canMakeFilterOn(name, values.head) => - val fieldType = nameToParquetField(name).fieldType - val fieldNames = nameToParquetField(name).fieldNames - if (values.length <= pushDownInFilterThreshold) { - values.distinct - .flatMap { v => - makeEq.lift(fieldType).map(_(fieldNames, v)) - } - .reduceLeftOption(FilterApi.or) - } else if (canPartialPushDownConjuncts) { - val primitiveType = schema.getColumnDescription(fieldNames).getPrimitiveType - val statistics: ParquetStatistics[_] = ParquetStatistics.createStats(primitiveType) - if (values.contains(null)) { - Seq( - makeEq.lift(fieldType).map(_(fieldNames, null)), - makeInPredicate - .lift(fieldType) - .map(_(fieldNames, values.filter(_ != null), statistics))).flatten - .reduceLeftOption(FilterApi.or) - } else { - makeInPredicate.lift(fieldType).map(_(fieldNames, values, statistics)) - } - } else { - None - } - - case sources.StringStartsWith(name, prefix) - if pushDownStringPredicate && canMakeFilterOn(name, prefix) => - Option(prefix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val strToBinary = Binary.fromReusedByteArray(v.getBytes) - private val size = strToBinary.length - - override def canDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 - } - - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { - val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR - val max = statistics.getMax - val min = statistics.getMin - comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && - comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 - } - - override def keep(value: Binary): Boolean = { - value != null && UTF8String - .fromBytes(value.getBytes) - .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) - } - }) - } - - case sources.StringEndsWith(name, suffix) - if pushDownStringPredicate && canMakeFilterOn(name, suffix) => - Option(suffix).map { v => - FilterApi.userDefined( - binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val suffixStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).endsWith(suffixStr) - } - }) - } - - case sources.StringContains(name, value) - if pushDownStringPredicate && canMakeFilterOn(name, value) => - Option(value).map { v => - FilterApi.userDefined( - 
binaryColumn(nameToParquetField(name).fieldNames), - new UserDefinedPredicate[Binary] with Serializable { - private val subStr = UTF8String.fromString(v) - override def canDrop(statistics: Statistics[Binary]): Boolean = false - override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false - override def keep(value: Binary): Boolean = { - value != null && UTF8String.fromBytes(value.getBytes).contains(subStr) - } - }) - } - - case _ => None - } - } -} diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala deleted file mode 100644 index 14a664108..000000000 --- a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/CometScanExec.scala +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet - -import scala.collection.mutable.HashMap -import scala.concurrent.duration.NANOSECONDS -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection._ - -import org.apache.comet.{CometConf, MetricsSupport} -import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} - -/** - * Comet physical scan node for DataSource V1. 
Most of the code here follow Spark's - * [[FileSourceScanExec]], - */ -case class CometScanExec( - @transient relation: HadoopFsRelation, - output: Seq[Attribute], - requiredSchema: StructType, - partitionFilters: Seq[Expression], - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - dataFilters: Seq[Expression], - tableIdentifier: Option[TableIdentifier], - disableBucketedScan: Boolean = false, - wrapped: FileSourceScanExec) - extends DataSourceScanExec - with ShimCometScanExec - with CometPlan { - - // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests - override lazy val supportsColumnar: Boolean = - relation.fileFormat.supportBatch(relation.sparkSession, schema) - - override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes - - private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - private def sendDriverMetrics(): Unit = { - driverMetrics.foreach(e => metrics(e._1).add(e._2)) - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates( - sparkContext, - executionId, - metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) - } - - private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - - @transient lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - val startTime = System.nanoTime() - val ret = - relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) - setFilesNumAndSizeMetric(ret, true) - val timeTakenMs = - NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) - driverMetrics("metadataTime") = timeTakenMs - ret - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). 
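  // Illustrative example (hypothetical tables): for a query such as
  //   SELECT ... FROM fact JOIN dim ON fact.part = dim.id WHERE dim.flag = 1
  // the filter on fact.part carries a DynamicPruningExpression whose values are
  // only known after dim has been evaluated, so the pruning below is applied to
  // selectedPartitions at execution time rather than at planning time.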
- @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) - - if (dynamicPartitionFilters.nonEmpty) { - val startTime = System.nanoTime() - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, false) - val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 - driverMetrics("pruningTime") = timeTakenMs - ret - } else { - selectedPartitions - } - } - - // exposed for testing - lazy val bucketedScan: Boolean = wrapped.bucketedScan - - override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = - (wrapped.outputPartitioning, wrapped.outputOrdering) - - @transient - private lazy val pushedDownFilters = { - val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - } - - override lazy val metadata: Map[String, String] = - if (wrapped == null) Map.empty else wrapped.metadata - - override def verboseStringWithOperatorId(): String = { - getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) - wrapped.verboseStringWithOperatorId() - } - - lazy val inputRDD: RDD[InternalRow] = { - val options = relation.options + - (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = requiredSchema, - filters = pushedDownFilters, - options = options, - hadoopConf = - relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - - val readRDD = if (bucketedScan) { - createBucketedReadRDD( - relation.bucketSpec.get, - readFile, - dynamicallySelectedPartitions, - relation) - } else { - createReadRDD(readFile, dynamicallySelectedPartitions, relation) - } - sendDriverMetrics() - readRDD - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - inputRDD :: Nil - } - - /** Helper for computing total number and size of files in selected partitions. */ - private def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { - driverMetrics("numFiles") = filesNum - driverMetrics("filesSize") = filesSize - } else { - driverMetrics("staticFilesNum") = filesNum - driverMetrics("staticFilesSize") = filesSize - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions") = partitions.length - } - } - - override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { - // Tracking scan time has overhead, we can't afford to do it for each row, and can only do - // it for each batch. 
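  // Concretely, the timing is accumulated once per returned ColumnarBatch, in
  // the hasNext of the iterator built by doExecuteColumnar below.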
- if (supportsColumnar) { - Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { - None - } - } ++ { - relation.fileFormat match { - case f: MetricsSupport => f.initMetrics(sparkContext) - case _ => Map.empty - } - } - - protected override def doExecute(): RDD[InternalRow] = { - ColumnarToRowExec(this).doExecute() - } - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. - val startNs = System.nanoTime() - val res = batches.hasNext - scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) - res - } - - override def next(): ColumnarBatch = { - val batch = batches.next() - numOutputRows += batch.numRows() - batch - } - } - } - } - - override def executeCollect(): Array[InternalRow] = { - ColumnarToRowExec(this).executeCollect() - } - - override val nodeName: String = - s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" - - /** - * Create an RDD for bucketed reads. The non-bucketed variant of this function is - * [[createReadRDD]]. - * - * The algorithm is pretty simple: each RDD partition being returned should include all the - * files with the same bucket id from all the given Hive partitions. - * - * @param bucketSpec - * the bucketing spec. - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createBucketedReadRDD( - bucketSpec: BucketSpec, - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val filesGroupedToBuckets = - selectedPartitions - .flatMap { p => - p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) - } - } - .groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath.toString()).getName) - .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) - } - - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - filesGroupedToBuckets.filter { f => - bucketSet.get(f._1) - } - } else { - filesGroupedToBuckets - } - - val filePartitions = optionalNumCoalescedBuckets - .map { numCoalescedBuckets => - logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") - val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) - Seq.tabulate(numCoalescedBuckets) { bucketId => - val partitionedFiles = coalescedBuckets - .get(bucketId) - .map { - _.values.flatten.toArray - } - .getOrElse(Array.empty) - FilePartition(bucketId, partitionedFiles) - } - } - .getOrElse { - Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) - } - } - - prepareRDD(fsRelation, readFile, filePartitions) - } - - /** - * Create an RDD for non-bucketed reads. The bucketed variant of this function is - * [[createBucketedReadRDD]]. - * - * @param readFile - * a function to read each (part of a) file. 
- * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = - FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - - // Filter files with bucket pruning if possible - val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled - val shouldProcess: Path => Boolean = optionalBucketSet match { - case Some(bucketSet) if bucketingEnabled => - // Do not prune the file if bucket file name is invalid - filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) - case _ => - _ => true - } - - val splitFiles = selectedPartitions - .flatMap { partition => - partition.files.flatMap { file => - // getPath() is very expensive so we only want to call it once in this block: - val filePath = file.getPath - - if (shouldProcess(filePath)) { - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, - relation.options, - filePath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. - !isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) - } else { - Seq.empty - } - } - } - .sortBy(_.length)(implicitly[Ordering[Long]].reverse) - - prepareRDD( - fsRelation, - readFile, - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) - } - - private def prepareRDD( - fsRelation: HadoopFsRelation, - readFile: (PartitionedFile) => Iterator[InternalRow], - partitions: Seq[FilePartition]): RDD[InternalRow] = { - val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) - val prefetchEnabled = hadoopConf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - - val sqlConf = fsRelation.sparkSession.sessionState.conf - if (prefetchEnabled) { - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedConf = - fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val partitionReaderFactory = CometParquetPartitionReaderFactory( - sqlConf, - broadcastedConf, - requiredSchema, - relation.partitionSchema, - pushedDownFilters.toArray, - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), - metrics) - - newDataSourceRDD( - fsRelation.sparkSession.sparkContext, - partitions.map(Seq(_)), - partitionReaderFactory, - true, - Map.empty) - } else { - newFileScanRDD( - fsRelation.sparkSession, - readFile, - partitions, - new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) - } - } - - // Filters unused DynamicPruningExpression expressions - one which has been replaced - // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning - private def 
filterUnusedDynamicPruningExpressions( - predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) - } - - override def doCanonicalize(): CometScanExec = { - CometScanExec( - relation, - output.map(QueryPlan.normalizeExpressions(_, output)), - requiredSchema, - QueryPlan.normalizePredicates( - filterUnusedDynamicPruningExpressions(partitionFilters), - output), - optionalBucketSet, - optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), - None, - disableBucketedScan, - null) - } -} - -object CometScanExec { - def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - // TreeNode.mapProductIterator is protected method. - def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { - val arr = Array.ofDim[B](product.productArity) - var i = 0 - while (i < arr.length) { - arr(i) = f(product.productElement(i)) - i += 1 - } - arr - } - - // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues - // on other Spark distributions if FileSourceScanExec constructor is changed. - // Using `makeCopy` to avoid the issue. - // https://github.com/apache/arrow-datafusion-comet/issues/190 - def transform(arg: Any): AnyRef = arg match { - case _: HadoopFsRelation => - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) - case other: AnyRef => other - case null => null - } - val newArgs = mapProductIterator(scanExec, transform(_)) - val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] - val batchScanExec = CometScanExec( - wrapped.relation, - wrapped.output, - wrapped.requiredSchema, - wrapped.partitionFilters, - wrapped.optionalBucketSet, - wrapped.optionalNumCoalescedBuckets, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - wrapped) - scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) - batchScanExec - } -} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala deleted file mode 100644 index 8a017fabe..000000000 --- a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.spark.sql.comet - -import scala.collection.mutable.HashMap -import scala.concurrent.duration.NANOSECONDS -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection._ - -import org.apache.comet.{CometConf, MetricsSupport} -import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} - -/** - * Comet physical scan node for DataSource V1. Most of the code here follow Spark's - * [[FileSourceScanExec]], - */ -case class CometScanExec( - @transient relation: HadoopFsRelation, - output: Seq[Attribute], - requiredSchema: StructType, - partitionFilters: Seq[Expression], - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - dataFilters: Seq[Expression], - tableIdentifier: Option[TableIdentifier], - disableBucketedScan: Boolean = false, - wrapped: FileSourceScanExec) - extends DataSourceScanExec - with ShimCometScanExec - with CometPlan { - - // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests - override lazy val supportsColumnar: Boolean = - relation.fileFormat.supportBatch(relation.sparkSession, schema) - - override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes - - private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - private def sendDriverMetrics(): Unit = { - driverMetrics.foreach(e => metrics(e._1).add(e._2)) - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates( - sparkContext, - executionId, - metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) - } - - private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - - @transient lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - val startTime = System.nanoTime() - val ret = - relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) - setFilesNumAndSizeMetric(ret, true) - val timeTakenMs = - NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) - driverMetrics("metadataTime") = timeTakenMs - ret - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). 
- @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) - - if (dynamicPartitionFilters.nonEmpty) { - val startTime = System.nanoTime() - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, false) - val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 - driverMetrics("pruningTime") = timeTakenMs - ret - } else { - selectedPartitions - } - } - - // exposed for testing - lazy val bucketedScan: Boolean = wrapped.bucketedScan - - override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = - (wrapped.outputPartitioning, wrapped.outputOrdering) - - @transient - private lazy val pushedDownFilters = { - val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - } - - override lazy val metadata: Map[String, String] = - if (wrapped == null) Map.empty else wrapped.metadata - - override def verboseStringWithOperatorId(): String = { - getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) - wrapped.verboseStringWithOperatorId() - } - - lazy val inputRDD: RDD[InternalRow] = { - val options = relation.options + - (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = requiredSchema, - filters = pushedDownFilters, - options = options, - hadoopConf = - relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - - val readRDD = if (bucketedScan) { - createBucketedReadRDD( - relation.bucketSpec.get, - readFile, - dynamicallySelectedPartitions, - relation) - } else { - createReadRDD(readFile, dynamicallySelectedPartitions, relation) - } - sendDriverMetrics() - readRDD - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - inputRDD :: Nil - } - - /** Helper for computing total number and size of files in selected partitions. */ - private def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { - driverMetrics("numFiles") = filesNum - driverMetrics("filesSize") = filesSize - } else { - driverMetrics("staticFilesNum") = filesNum - driverMetrics("staticFilesSize") = filesSize - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions") = partitions.length - } - } - - override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { - // Tracking scan time has overhead, we can't afford to do it for each row, and can only do - // it for each batch. 
- if (supportsColumnar) { - Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { - None - } - } ++ { - relation.fileFormat match { - case f: MetricsSupport => f.initMetrics(sparkContext) - case _ => Map.empty - } - } - - protected override def doExecute(): RDD[InternalRow] = { - ColumnarToRowExec(this).doExecute() - } - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. - val startNs = System.nanoTime() - val res = batches.hasNext - scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) - res - } - - override def next(): ColumnarBatch = { - val batch = batches.next() - numOutputRows += batch.numRows() - batch - } - } - } - } - - override def executeCollect(): Array[InternalRow] = { - ColumnarToRowExec(this).executeCollect() - } - - override val nodeName: String = - s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" - - /** - * Create an RDD for bucketed reads. The non-bucketed variant of this function is - * [[createReadRDD]]. - * - * The algorithm is pretty simple: each RDD partition being returned should include all the - * files with the same bucket id from all the given Hive partitions. - * - * @param bucketSpec - * the bucketing spec. - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createBucketedReadRDD( - bucketSpec: BucketSpec, - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val filesGroupedToBuckets = - selectedPartitions - .flatMap { p => - p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, /*f.getPath,*/ p.values) - } - } - .groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath.toString()).getName) - .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) - } - - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - filesGroupedToBuckets.filter { f => - bucketSet.get(f._1) - } - } else { - filesGroupedToBuckets - } - - val filePartitions = optionalNumCoalescedBuckets - .map { numCoalescedBuckets => - logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") - val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) - Seq.tabulate(numCoalescedBuckets) { bucketId => - val partitionedFiles = coalescedBuckets - .get(bucketId) - .map { - _.values.flatten.toArray - } - .getOrElse(Array.empty) - FilePartition(bucketId, partitionedFiles) - } - } - .getOrElse { - Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) - } - } - - prepareRDD(fsRelation, readFile, filePartitions) - } - - /** - * Create an RDD for non-bucketed reads. The bucketed variant of this function is - * [[createBucketedReadRDD]]. - * - * @param readFile - * a function to read each (part of a) file. 
- * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = - FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - - // Filter files with bucket pruning if possible - val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled - val shouldProcess: Path => Boolean = optionalBucketSet match { - case Some(bucketSet) if bucketingEnabled => - // Do not prune the file if bucket file name is invalid - filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) - case _ => - _ => true - } - - val splitFiles = selectedPartitions - .flatMap { partition => - partition.files.flatMap { file => - // getPath() is very expensive so we only want to call it once in this block: - val filePath = file.getPath - - if (shouldProcess(filePath)) { - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, - relation.options, - filePath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. - !isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - //filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) - } else { - Seq.empty - } - } - } - .sortBy(_.length)(implicitly[Ordering[Long]].reverse) - - prepareRDD( - fsRelation, - readFile, - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) - } - - private def prepareRDD( - fsRelation: HadoopFsRelation, - readFile: (PartitionedFile) => Iterator[InternalRow], - partitions: Seq[FilePartition]): RDD[InternalRow] = { - val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) - val prefetchEnabled = hadoopConf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - - val sqlConf = fsRelation.sparkSession.sessionState.conf - if (prefetchEnabled) { - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedConf = - fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val partitionReaderFactory = CometParquetPartitionReaderFactory( - sqlConf, - broadcastedConf, - requiredSchema, - relation.partitionSchema, - pushedDownFilters.toArray, - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), - metrics) - - newDataSourceRDD( - fsRelation.sparkSession.sparkContext, - partitions.map(Seq(_)), - partitionReaderFactory, - true, - Map.empty) - } else { - newFileScanRDD( - fsRelation.sparkSession, - readFile, - partitions, - new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) - } - } - - // Filters unused DynamicPruningExpression expressions - one which has been replaced - // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning - private def 
filterUnusedDynamicPruningExpressions( - predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) - } - - override def doCanonicalize(): CometScanExec = { - CometScanExec( - relation, - output.map(QueryPlan.normalizeExpressions(_, output)), - requiredSchema, - QueryPlan.normalizePredicates( - filterUnusedDynamicPruningExpressions(partitionFilters), - output), - optionalBucketSet, - optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), - None, - disableBucketedScan, - null) - } -} - -object CometScanExec { - def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - // TreeNode.mapProductIterator is protected method. - def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { - val arr = Array.ofDim[B](product.productArity) - var i = 0 - while (i < arr.length) { - arr(i) = f(product.productElement(i)) - i += 1 - } - arr - } - - // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues - // on other Spark distributions if FileSourceScanExec constructor is changed. - // Using `makeCopy` to avoid the issue. - // https://github.com/apache/arrow-datafusion-comet/issues/190 - def transform(arg: Any): AnyRef = arg match { - case _: HadoopFsRelation => - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) - case other: AnyRef => other - case null => null - } - val newArgs = mapProductIterator(scanExec, transform(_)) - val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] - val batchScanExec = CometScanExec( - wrapped.relation, - wrapped.output, - wrapped.requiredSchema, - wrapped.partitionFilters, - wrapped.optionalBucketSet, - wrapped.optionalNumCoalescedBuckets, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - wrapped) - scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) - batchScanExec - } -} diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetFileFormat.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetFileFormat.scala diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala diff --git a/spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/comet/parquet/ParquetFilters.scala rename to spark/src/main/spark-pre-3.5/org/apache/comet/parquet/ParquetFilters.scala diff --git a/spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala similarity index 100% rename from spark/src/main/spark-3.2/org/apache/spark/sql/comet/CometScanExec.scala rename to spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala From ea252bb87b57730bd1c495c90bd464fc70551e03 Mon Sep 17 
00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 08:31:14 -0600 Subject: [PATCH 07/12] save progress --- pom.xml | 34 +- spark/pom.xml | 2 + .../spark/sql/comet/CometScanExec.scala | 483 ++++++++++++++++++ 3 files changed, 506 insertions(+), 13 deletions(-) create mode 100644 spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala diff --git a/pom.xml b/pom.xml index 1c0961059..519c0c8cc 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ under the License. spark-3.4-plus spark-3.x spark-3.4 + spark-pre-3.5 @@ -528,6 +529,7 @@ under the License. not-needed-yet not-needed-yet spark-3.2 + spark-pre-3.5 @@ -540,6 +542,7 @@ under the License. 1.12.0 not-needed-yet spark-3.3 + spark-pre-3.5 @@ -549,6 +552,24 @@ under the License. 2.12.17 3.4 1.13.1 + spark-3.4 + spark-pre-3.5 + + + + + + spark-3.5 + + 2.12.17 + 3.5.1 + 3.5 + 1.13.1 + + + spark-3.5 + not-needed + @@ -571,19 +592,6 @@ under the License. - - spark-3.5 - - 2.12.17 - 3.5.1 - 3.5 - 1.13.1 - spark-3.3-plus - spark-3.5 - spark-3.5 - - - scala-2.13 diff --git a/spark/pom.xml b/spark/pom.xml index 84e2e501f..8585b4b8c 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -254,6 +254,7 @@ under the License. src/test/${additional.3_4.test.source} src/test/${shims.majorVerSrc} src/test/${shims.minorVerSrc} + src/test/${shims.extraVerSrc} @@ -267,6 +268,7 @@ under the License. src/main/${shims.majorVerSrc} src/main/${shims.minorVerSrc} + src/main/${shims.extraVerSrc} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala new file mode 100644 index 000000000..777248b41 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/CometScanExec.scala @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable.HashMap +import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection._ + +import org.apache.comet.{CometConf, MetricsSupport} +import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} +import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} + +/** + * Comet physical scan node for DataSource V1. Most of the code here follow Spark's + * [[FileSourceScanExec]], + */ +case class CometScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false, + wrapped: FileSourceScanExec) + extends DataSourceScanExec + with ShimCometScanExec + with CometPlan { + + // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests + override lazy val supportsColumnar: Boolean = + relation.fileFormat.supportBatch(relation.sparkSession, schema) + + override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = + NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). 
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = wrapped.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (wrapped.outputPartitioning, wrapped.outputOrdering) + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = + if (wrapped == null) Map.empty else wrapped.metadata + + override def verboseStringWithOperatorId(): String = { + getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) + wrapped.verboseStringWithOperatorId() + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + protected override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, /*f.getPath,*/ p.values) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. 
+ * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. + !isNeededForSchema(requiredSchema) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + //filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + newDataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation.sparkSession, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def 
filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): CometScanExec = { + CometScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), + output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan, + null) + } +} + +object CometScanExec { + def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { + // TreeNode.mapProductIterator is protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues + // on other Spark distributions if FileSourceScanExec constructor is changed. + // Using `makeCopy` to avoid the issue. + // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] + val batchScanExec = CometScanExec( + wrapped.relation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + wrapped) + scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) + batchScanExec + } +} \ No newline at end of file From e642e16e152e89031760d4a4acc7db00c726a435 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 08:43:54 -0600 Subject: [PATCH 08/12] remove duplicate class --- .../spark/sql/comet/CometScanExec.scala | 484 ------------------ .../spark/sql/comet/CometScanExec.scala | 9 +- 2 files changed, 5 insertions(+), 488 deletions(-) delete mode 100644 spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala diff --git a/spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala deleted file mode 100644 index 9a5b55d65..000000000 --- a/spark/src/main/spark-3.3/org/apache/spark/sql/comet/CometScanExec.scala +++ /dev/null @@ -1,484 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet - -import scala.collection.mutable.HashMap -import scala.concurrent.duration.NANOSECONDS -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.comet.shims.ShimCometScanExec -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection._ - -import org.apache.comet.{CometConf, MetricsSupport} -import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.ShimFileFormat - -/** - * Comet physical scan node for DataSource V1. Most of the code here follow Spark's - * [[FileSourceScanExec]], - */ -case class CometScanExec( - @transient relation: HadoopFsRelation, - output: Seq[Attribute], - requiredSchema: StructType, - partitionFilters: Seq[Expression], - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - dataFilters: Seq[Expression], - tableIdentifier: Option[TableIdentifier], - disableBucketedScan: Boolean = false, - wrapped: FileSourceScanExec) - extends DataSourceScanExec - with ShimCometScanExec - with CometPlan { - - // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests - override lazy val supportsColumnar: Boolean = - relation.fileFormat.supportBatch(relation.sparkSession, schema) - - override def vectorTypes: Option[Seq[String]] = wrapped.vectorTypes - - private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - private def sendDriverMetrics(): Unit = { - driverMetrics.foreach(e => metrics(e._1).add(e._2)) - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates( - sparkContext, - executionId, - metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) - } - - private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - - @transient lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - val startTime = System.nanoTime() - val ret = - relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) - setFilesNumAndSizeMetric(ret, true) - val timeTakenMs = - NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) - driverMetrics("metadataTime") = timeTakenMs - ret - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. 
This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). - @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) - - if (dynamicPartitionFilters.nonEmpty) { - val startTime = System.nanoTime() - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, false) - val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 - driverMetrics("pruningTime") = timeTakenMs - ret - } else { - selectedPartitions - } - } - - // exposed for testing - lazy val bucketedScan: Boolean = wrapped.bucketedScan - - override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = - (wrapped.outputPartitioning, wrapped.outputOrdering) - - @transient - private lazy val pushedDownFilters = { - val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - } - - override lazy val metadata: Map[String, String] = - if (wrapped == null) Map.empty else wrapped.metadata - - override def verboseStringWithOperatorId(): String = { - getTagValue(QueryPlan.OP_ID_TAG).foreach(id => wrapped.setTagValue(QueryPlan.OP_ID_TAG, id)) - wrapped.verboseStringWithOperatorId() - } - - lazy val inputRDD: RDD[InternalRow] = { - val options = relation.options + - (ShimFileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = requiredSchema, - filters = pushedDownFilters, - options = options, - hadoopConf = - relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - - val readRDD = if (bucketedScan) { - createBucketedReadRDD( - relation.bucketSpec.get, - readFile, - dynamicallySelectedPartitions, - relation) - } else { - createReadRDD(readFile, dynamicallySelectedPartitions, relation) - } - sendDriverMetrics() - readRDD - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - inputRDD :: Nil - } - - /** Helper for computing total number and size of files in selected partitions. 
*/ - private def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { - driverMetrics("numFiles") = filesNum - driverMetrics("filesSize") = filesSize - } else { - driverMetrics("staticFilesNum") = filesNum - driverMetrics("staticFilesSize") = filesSize - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions") = partitions.length - } - } - - override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics ++ { - // Tracking scan time has overhead, we can't afford to do it for each row, and can only do - // it for each batch. - if (supportsColumnar) { - Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { - None - } - } ++ { - relation.fileFormat match { - case f: MetricsSupport => f.initMetrics(sparkContext) - case _ => Map.empty - } - } - - protected override def doExecute(): RDD[InternalRow] = { - ColumnarToRowExec(this).doExecute() - } - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. - val startNs = System.nanoTime() - val res = batches.hasNext - scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) - res - } - - override def next(): ColumnarBatch = { - val batch = batches.next() - numOutputRows += batch.numRows() - batch - } - } - } - } - - override def executeCollect(): Array[InternalRow] = { - ColumnarToRowExec(this).executeCollect() - } - - override val nodeName: String = - s"CometScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" - - /** - * Create an RDD for bucketed reads. The non-bucketed variant of this function is - * [[createReadRDD]]. - * - * The algorithm is pretty simple: each RDD partition being returned should include all the - * files with the same bucket id from all the given Hive partitions. - * - * @param bucketSpec - * the bucketing spec. - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. 
- */ - private def createBucketedReadRDD( - bucketSpec: BucketSpec, - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val filesGroupedToBuckets = - selectedPartitions - .flatMap { p => - p.files.map { f => - getPartitionedFile(f, p) - } - } - .groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath.toString()).getName) - .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) - } - - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - filesGroupedToBuckets.filter { f => - bucketSet.get(f._1) - } - } else { - filesGroupedToBuckets - } - - val filePartitions = optionalNumCoalescedBuckets - .map { numCoalescedBuckets => - logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") - val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) - Seq.tabulate(numCoalescedBuckets) { bucketId => - val partitionedFiles = coalescedBuckets - .get(bucketId) - .map { - _.values.flatten.toArray - } - .getOrElse(Array.empty) - FilePartition(bucketId, partitionedFiles) - } - } - .getOrElse { - Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) - } - } - - prepareRDD(fsRelation, readFile, filePartitions) - } - - /** - * Create an RDD for non-bucketed reads. The bucketed variant of this function is - * [[createBucketedReadRDD]]. - * - * @param readFile - * a function to read each (part of a) file. - * @param selectedPartitions - * Hive-style partition that are part of the read. - * @param fsRelation - * [[HadoopFsRelation]] associated with the read. - */ - private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Array[PartitionDirectory], - fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = - FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - - // Filter files with bucket pruning if possible - val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled - val shouldProcess: Path => Boolean = optionalBucketSet match { - case Some(bucketSet) if bucketingEnabled => - // Do not prune the file if bucket file name is invalid - filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) - case _ => - _ => true - } - - val splitFiles = selectedPartitions - .flatMap { partition => - partition.files.flatMap { file => - // getPath() is very expensive so we only want to call it once in this block: - val filePath = file.getPath - - if (shouldProcess(filePath)) { - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, - relation.options, - filePath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. 
- !isNeededForSchema(requiredSchema) - super.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) - } else { - Seq.empty - } - } - } - .sortBy(_.length)(implicitly[Ordering[Long]].reverse) - - prepareRDD( - fsRelation, - readFile, - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) - } - - private def prepareRDD( - fsRelation: HadoopFsRelation, - readFile: (PartitionedFile) => Iterator[InternalRow], - partitions: Seq[FilePartition]): RDD[InternalRow] = { - val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) - val prefetchEnabled = hadoopConf.getBoolean( - CometConf.COMET_SCAN_PREFETCH_ENABLED.key, - CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) - - val sqlConf = fsRelation.sparkSession.sessionState.conf - if (prefetchEnabled) { - CometParquetFileFormat.populateConf(sqlConf, hadoopConf) - val broadcastedConf = - fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val partitionReaderFactory = CometParquetPartitionReaderFactory( - sqlConf, - broadcastedConf, - requiredSchema, - relation.partitionSchema, - pushedDownFilters.toArray, - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), - metrics) - - newDataSourceRDD( - fsRelation.sparkSession.sparkContext, - partitions.map(Seq(_)), - partitionReaderFactory, - true, - Map.empty) - } else { - newFileScanRDD( - fsRelation, - readFile, - partitions, - new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), - new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) - } - } - - // Filters unused DynamicPruningExpression expressions - one which has been replaced - // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning - private def filterUnusedDynamicPruningExpressions( - predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) - } - - override def doCanonicalize(): CometScanExec = { - CometScanExec( - relation, - output.map(QueryPlan.normalizeExpressions(_, output)), - requiredSchema, - QueryPlan.normalizePredicates( - filterUnusedDynamicPruningExpressions(partitionFilters), - output), - optionalBucketSet, - optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), - None, - disableBucketedScan, - null) - } -} - -object CometScanExec { - def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - // TreeNode.mapProductIterator is protected method. - def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { - val arr = Array.ofDim[B](product.productArity) - var i = 0 - while (i < arr.length) { - arr(i) = f(product.productElement(i)) - i += 1 - } - arr - } - - // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues - // on other Spark distributions if FileSourceScanExec constructor is changed. - // Using `makeCopy` to avoid the issue. 
- // https://github.com/apache/arrow-datafusion-comet/issues/190 - def transform(arg: Any): AnyRef = arg match { - case _: HadoopFsRelation => - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) - case other: AnyRef => other - case null => null - } - val newArgs = mapProductIterator(scanExec, transform(_)) - val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] - val batchScanExec = CometScanExec( - wrapped.relation, - wrapped.output, - wrapped.requiredSchema, - wrapped.partitionFilters, - wrapped.optionalBucketSet, - wrapped.optionalNumCoalescedBuckets, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - wrapped) - scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) - batchScanExec - } -} diff --git a/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala index 14a664108..9a5b55d65 100644 --- a/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala +++ b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/CometScanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.comet.shims.ShimCometScanExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions @@ -43,7 +44,7 @@ import org.apache.spark.util.collection._ import org.apache.comet.{CometConf, MetricsSupport} import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} -import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat} +import org.apache.comet.shims.ShimFileFormat /** * Comet physical scan node for DataSource V1. Most of the code here follow Spark's @@ -271,7 +272,7 @@ case class CometScanExec( selectedPartitions .flatMap { p => p.files.map { f => - PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + getPartitionedFile(f, p) } } .groupBy { f => @@ -358,7 +359,7 @@ case class CometScanExec( // SPARK-39634: Allow file splitting in combination with row index generation once // the fix for PARQUET-2161 is available. 
!isNeededForSchema(requiredSchema) - PartitionedFileUtil.splitFiles( + super.splitFiles( sparkSession = relation.sparkSession, file = file, filePath = filePath, @@ -409,7 +410,7 @@ case class CometScanExec( Map.empty) } else { newFileScanRDD( - fsRelation.sparkSession, + fsRelation, readFile, partitions, new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), From 594b08522f4d7d10af278359d25378c5a914a2aa Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 08:51:59 -0600 Subject: [PATCH 09/12] spark 3.4 build works again --- .../execution/shuffle/CometShuffleExchangeExec.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index b94717cdf..a1b9a3677 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -41,11 +41,11 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan} import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -391,7 +391,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = UnsafeExternalRowSorter.createWithRecordComparator( - StructType(outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))), + StructType( + outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))), recordComparatorSupplier, prefixComparator, prefixComputer, @@ -435,7 +436,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, - schema = Some(StructType(outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))) + schema = Some(StructType(outputAttributes.map(a => + StructField(a.name, a.dataType, a.nullable, a.metadata))))) dependency } } From 5eb01b0b5a0a641cadf0af4082b88da972f2c7c1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 09:01:38 -0600 Subject: [PATCH 10/12] save progress --- pom.xml | 5 +- .../sql/comet/shims/ShimCometScanExec.scala | 141 ------------------ 2 files changed, 4 insertions(+), 142 deletions(-) delete mode 100644 spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala diff --git a/pom.xml b/pom.xml index 519c0c8cc..378e7144c 100644 --- 
a/pom.xml +++ b/pom.xml @@ -528,6 +528,7 @@ under the License. not-needed-yet not-needed-yet + spark-3.x spark-3.2 spark-pre-3.5 @@ -541,6 +542,7 @@ under the License. 3.3 1.12.0 not-needed-yet + spark-3.x spark-3.3 spark-pre-3.5 @@ -552,6 +554,7 @@ under the License. 2.12.17 3.4 1.13.1 + spark-3.x spark-3.4 spark-pre-3.5 @@ -567,9 +570,9 @@ under the License. 1.13.1 + spark-3.x spark-3.5 not-needed - diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala deleted file mode 100644 index 02b97f9fb..000000000 --- a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet.shims - -import org.apache.comet.shims.ShimFileFormat - -import scala.language.implicitConversions - -import org.apache.hadoop.fs.{FileStatus, Path} - -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} -import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} -import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.{LongType, StructField, StructType} - -trait ShimCometScanExec { - def wrapped: FileSourceScanExec - - // TODO: remove after dropping Spark 3.2 support and directly call wrapped.metadataColumns - lazy val metadataColumns: Seq[AttributeReference] = wrapped.getClass.getDeclaredMethods - .filter(_.getName == "metadataColumns") - .map { a => a.setAccessible(true); a } - .flatMap(_.invoke(wrapped).asInstanceOf[Seq[AttributeReference]]) - - // TODO: remove after dropping Spark 3.2 and 3.3 support and directly call - // wrapped.fileConstantMetadataColumns - lazy val fileConstantMetadataColumns: Seq[AttributeReference] = - wrapped.getClass.getDeclaredMethods - .filter(_.getName == "fileConstantMetadataColumns") - .map { a => a.setAccessible(true); a } - .flatMap(_.invoke(wrapped).asInstanceOf[Seq[AttributeReference]]) - - // TODO: remove after dropping Spark 3.2 support and directly call new DataSourceRDD - protected def newDataSourceRDD( - sc: SparkContext, - inputPartitions: Seq[Seq[InputPartition]], - partitionReaderFactory: 
PartitionReaderFactory, - columnarReads: Boolean, - customMetrics: Map[String, SQLMetric]): DataSourceRDD = { - implicit def flattenSeq(p: Seq[Seq[InputPartition]]): Seq[InputPartition] = p.flatten - new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, customMetrics) - } - - // TODO: remove after dropping Spark 3.2 support and directly call new FileScanRDD - protected def newFileScanRDD( - fsRelation: HadoopFsRelation, - readFunction: PartitionedFile => Iterator[InternalRow], - filePartitions: Seq[FilePartition], - readSchema: StructType, - options: ParquetOptions): FileScanRDD = - classOf[FileScanRDD].getDeclaredConstructors - // Prevent to pick up incorrect constructors from any custom Spark forks. - .filter(c => List(3, 5, 6).contains(c.getParameterCount()) ) - .map { c => - c.getParameterCount match { - case 3 => c.newInstance(fsRelation.sparkSession, readFunction, filePartitions) - case 5 => - c.newInstance(fsRelation.sparkSession, readFunction, filePartitions, readSchema, metadataColumns) - case 6 => - c.newInstance( - fsRelation.sparkSession, - readFunction, - filePartitions, - readSchema, - fileConstantMetadataColumns, - options) - } - } - .last - .asInstanceOf[FileScanRDD] - - // TODO: remove after dropping Spark 3.2 and 3.3 support and directly call - // QueryExecutionErrors.SparkException - protected def invalidBucketFile(path: String, sparkVersion: String): Throwable = { - if (sparkVersion >= "3.3") { - val messageParameters = if (sparkVersion >= "3.4") Map("path" -> path) else Array(path) - classOf[SparkException].getDeclaredConstructors - .filter(_.getParameterCount == 3) - .map(_.newInstance("INVALID_BUCKET_FILE", messageParameters, null)) - .last - .asInstanceOf[SparkException] - } else { // Spark 3.2 - new IllegalStateException(s"Invalid bucket file ${path}") - } - } - - // Copied from Spark 3.4 RowIndexUtil due to PARQUET-2161 (tracked in SPARK-39634) - // TODO: remove after PARQUET-2161 becomes available in Parquet - private def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = { - sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) => - field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME - } match { - case Some((field: StructField, idx: Int)) => - if (field.dataType != LongType) { - throw new RuntimeException( - s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType") - } - idx - case _ => -1 - } - } - - protected def isNeededForSchema(sparkSchema: StructType): Boolean = { - findRowIndexColumnIndexInSchema(sparkSchema) >= 0 - } - - protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile = - PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) - - protected def splitFiles(sparkSession: SparkSession, - file: FileStatus, - filePath: Path, - isSplitable: Boolean, - maxSplitBytes: Long, - partitionValues: InternalRow): Seq[PartitionedFile] = - PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues) -} From becd80b614cd5b41d70c904ca6dc849f47cc470a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 09:01:42 -0600 Subject: [PATCH 11/12] save progress --- .../sql/comet/shims/ShimCometScanExec.scala | 83 +++++++++++ .../sql/comet/shims/ShimCometScanExec.scala | 141 ++++++++++++++++++ 2 files changed, 224 insertions(+) create mode 100644 spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala create mode 100644 
spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala new file mode 100644 index 000000000..543116c10 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} +import org.apache.spark.sql.types.StructType +import org.apache.spark.SparkContext + +trait ShimCometScanExec { + def wrapped: FileSourceScanExec + + lazy val fileConstantMetadataColumns: Seq[AttributeReference] = + wrapped.fileConstantMetadataColumns + + protected def newDataSourceRDD( + sc: SparkContext, + inputPartitions: Seq[Seq[InputPartition]], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean, + customMetrics: Map[String, SQLMetric]): DataSourceRDD = + new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, customMetrics) + + protected def newFileScanRDD( + fsRelation: HadoopFsRelation, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readSchema: StructType, + options: ParquetOptions): FileScanRDD = { + new FileScanRDD( + fsRelation.sparkSession, + readFunction, + filePartitions, + readSchema, + fileConstantMetadataColumns, + fsRelation.fileFormat.fileConstantMetadataExtractors, + options) + } + + protected def invalidBucketFile(path: String, sparkVersion: String): Throwable = + QueryExecutionErrors.invalidBucketFile(path) + + // see SPARK-39634 + protected def isNeededForSchema(sparkSchema: StructType): Boolean = false + + protected def getPartitionedFile(f: FileStatusWithMetadata, p: PartitionDirectory): PartitionedFile = + PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen) + + protected def splitFiles(sparkSession: SparkSession, + file: FileStatusWithMetadata, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + 
partitionValues: InternalRow): Seq[PartitionedFile] = + PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues) +} diff --git a/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala new file mode 100644 index 000000000..02b97f9fb --- /dev/null +++ b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.comet.shims.ShimFileFormat + +import scala.language.implicitConversions + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.{LongType, StructField, StructType} + +trait ShimCometScanExec { + def wrapped: FileSourceScanExec + + // TODO: remove after dropping Spark 3.2 support and directly call wrapped.metadataColumns + lazy val metadataColumns: Seq[AttributeReference] = wrapped.getClass.getDeclaredMethods + .filter(_.getName == "metadataColumns") + .map { a => a.setAccessible(true); a } + .flatMap(_.invoke(wrapped).asInstanceOf[Seq[AttributeReference]]) + + // TODO: remove after dropping Spark 3.2 and 3.3 support and directly call + // wrapped.fileConstantMetadataColumns + lazy val fileConstantMetadataColumns: Seq[AttributeReference] = + wrapped.getClass.getDeclaredMethods + .filter(_.getName == "fileConstantMetadataColumns") + .map { a => a.setAccessible(true); a } + .flatMap(_.invoke(wrapped).asInstanceOf[Seq[AttributeReference]]) + + // TODO: remove after dropping Spark 3.2 support and directly call new DataSourceRDD + protected def newDataSourceRDD( + sc: SparkContext, + inputPartitions: Seq[Seq[InputPartition]], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean, + customMetrics: Map[String, SQLMetric]): DataSourceRDD = { + implicit def flattenSeq(p: Seq[Seq[InputPartition]]): Seq[InputPartition] = p.flatten + new DataSourceRDD(sc, inputPartitions, 
partitionReaderFactory, columnarReads, customMetrics) + } + + // TODO: remove after dropping Spark 3.2 support and directly call new FileScanRDD + protected def newFileScanRDD( + fsRelation: HadoopFsRelation, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readSchema: StructType, + options: ParquetOptions): FileScanRDD = + classOf[FileScanRDD].getDeclaredConstructors + // Prevent to pick up incorrect constructors from any custom Spark forks. + .filter(c => List(3, 5, 6).contains(c.getParameterCount()) ) + .map { c => + c.getParameterCount match { + case 3 => c.newInstance(fsRelation.sparkSession, readFunction, filePartitions) + case 5 => + c.newInstance(fsRelation.sparkSession, readFunction, filePartitions, readSchema, metadataColumns) + case 6 => + c.newInstance( + fsRelation.sparkSession, + readFunction, + filePartitions, + readSchema, + fileConstantMetadataColumns, + options) + } + } + .last + .asInstanceOf[FileScanRDD] + + // TODO: remove after dropping Spark 3.2 and 3.3 support and directly call + // QueryExecutionErrors.SparkException + protected def invalidBucketFile(path: String, sparkVersion: String): Throwable = { + if (sparkVersion >= "3.3") { + val messageParameters = if (sparkVersion >= "3.4") Map("path" -> path) else Array(path) + classOf[SparkException].getDeclaredConstructors + .filter(_.getParameterCount == 3) + .map(_.newInstance("INVALID_BUCKET_FILE", messageParameters, null)) + .last + .asInstanceOf[SparkException] + } else { // Spark 3.2 + new IllegalStateException(s"Invalid bucket file ${path}") + } + } + + // Copied from Spark 3.4 RowIndexUtil due to PARQUET-2161 (tracked in SPARK-39634) + // TODO: remove after PARQUET-2161 becomes available in Parquet + private def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = { + sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) => + field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME + } match { + case Some((field: StructField, idx: Int)) => + if (field.dataType != LongType) { + throw new RuntimeException( + s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType") + } + idx + case _ => -1 + } + } + + protected def isNeededForSchema(sparkSchema: StructType): Boolean = { + findRowIndexColumnIndexInSchema(sparkSchema) >= 0 + } + + protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile = + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + + protected def splitFiles(sparkSession: SparkSession, + file: FileStatus, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow): Seq[PartitionedFile] = + PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues) +} From 1379fccc2d82da9f88947ac68ba066846026e5a5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 11 Jun 2024 09:13:43 -0600 Subject: [PATCH 12/12] fix spark 4 build --- pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 378e7144c..b9001c9cd 100644 --- a/pom.xml +++ b/pom.xml @@ -572,7 +572,7 @@ under the License. spark-3.x spark-3.5 - not-needed + not-needed-yet @@ -588,6 +588,7 @@ under the License. 1.13.1 spark-4.0 not-needed-yet + not-needed-yet 17 ${java.version}
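
For readers following the build changes above, the net effect of the spark-pre-3.5 / spark-3.5 source folders and their Maven profiles is that only one ShimCometScanExec implementation is ever on the compile path, so the shared CometScanExec code can call getPartitionedFile and splitFiles without knowing which Spark API shape it was compiled against. The sketch below is a hypothetical, self-contained Scala illustration of that pattern only: FileStatusLike, LegacyUtil, NewUtil, ScanLike and the two shim traits are invented stand-ins, not Spark or Comet classes, and their bodies merely mimic how the pre-3.5 and 3.5 helper signatures differ in the hunks above.

// Illustrative sketch of the version-shim pattern; none of these types are real Spark/Comet classes.
case class FileStatusLike(path: String, len: Long)
case class PartitionDirLike(values: Seq[String], files: Seq[FileStatusLike])
case class PartitionedFileLike(partitionValues: Seq[String], path: String, start: Long, length: Long)

// Two stand-in "upstream" helpers whose signatures differ, mimicking how
// PartitionedFileUtil changed shape between pre-3.5 and 3.5 in the diffs above.
object LegacyUtil {
  def getPartitionedFile(f: FileStatusLike, path: String, values: Seq[String]): PartitionedFileLike =
    PartitionedFileLike(values, path, 0L, f.len)
}
object NewUtil {
  def getPartitionedFile(f: FileStatusLike, values: Seq[String], start: Long, length: Long): PartitionedFileLike =
    PartitionedFileLike(values, f.path, start, length)
}

// Shared scan logic programs against this trait only.
trait ShimScan {
  def getPartitionedFile(f: FileStatusLike, p: PartitionDirLike): PartitionedFileLike
}

// Would live in the "pre-3.5" source folder: path is passed explicitly.
trait PreSpark35Shim extends ShimScan {
  def getPartitionedFile(f: FileStatusLike, p: PartitionDirLike): PartitionedFileLike =
    LegacyUtil.getPartitionedFile(f, f.path, p.values)
}

// Would live in the "3.5" source folder: start offset and length are passed instead.
trait Spark35Shim extends ShimScan {
  def getPartitionedFile(f: FileStatusLike, p: PartitionDirLike): PartitionedFileLike =
    NewUtil.getPartitionedFile(f, p.values, 0L, f.len)
}

// The scan node requires some shim via a self-type and never names a concrete one.
class ScanLike { self: ShimScan =>
  def partitionFiles(p: PartitionDirLike): Seq[PartitionedFileLike] =
    p.files.map(f => self.getPartitionedFile(f, p))
}

object ShimExample extends App {
  val dir  = PartitionDirLike(Seq("dt=2024-05-29"), Seq(FileStatusLike("/data/part-0.parquet", 1024L)))
  val scan = new ScanLike with PreSpark35Shim // a 3.5 build would mix in Spark35Shim instead
  scan.partitionFiles(dir).foreach(println)
}

In the real build the choice is not an explicit mix-in at the call site as in this sketch; it is made by which shim source directory the active Maven profile adds to the compilation, which is why the pom.xml hunks above wire spark-3.x, spark-pre-3.5, or spark-3.5 into each Spark version profile.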