From 0c7a228b212a9e5cad0f9c551e757b2f0bf49f5c Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Wed, 13 Nov 2024 12:29:57 -0800
Subject: [PATCH] fix: Remove export

---
 .../comet/parquet/CometParquetPartitionReaderFactory.scala | 4 +++-
 .../scala/org/apache/comet/parquet/CometParquetScan.scala  | 7 +++++++
 .../org/apache/spark/sql/comet/CometBatchScanExec.scala    | 6 +++++-
 3 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
index 7afd6cfa3..edcb6641d 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
 
 import org.apache.comet.{CometConf, CometRuntimeException}
+import org.apache.comet.parquet.CometParquetFileFormat.HAS_NATIVE_OPERATIONS
 import org.apache.comet.shims.ShimSQLConf
 
 case class CometParquetPartitionReaderFactory(
@@ -125,6 +126,7 @@ case class CometParquetPartitionReaderFactory(
     try {
       val (datetimeRebaseSpec, footer, filters) = getFilter(file)
       filters.foreach(pushed => ParquetInputFormat.setFilterPredicate(conf, pushed))
+      val hasNativeOperations = conf.get(HAS_NATIVE_OPERATIONS, "false").toBoolean
       val cometReader = new BatchReader(
         conf,
         file,
@@ -138,7 +140,7 @@
         partitionSchema,
         file.partitionValues,
         JavaConverters.mapAsJavaMap(metrics),
-        false)
+        hasNativeOperations)
       val taskContext = Option(TaskContext.get)
       taskContext.foreach(_.addTaskCompletionListener[Unit](_ => cometReader.close()))
       return cometReader
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetScan.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetScan.scala
index e3cd33b41..8edd4e527 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetScan.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetScan.scala
@@ -33,8 +33,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
 
 import org.apache.comet.MetricsSupport
+import org.apache.comet.parquet.CometParquetFileFormat.HAS_NATIVE_OPERATIONS
 
 trait CometParquetScan extends FileScan with MetricsSupport {
+  private var hasNativeOperations = false
   def sparkSession: SparkSession
   def hadoopConf: Configuration
   def readDataSchema: StructType
@@ -55,6 +57,7 @@ trait CometParquetScan extends FileScan with MetricsSupport {
   override def createReaderFactory(): PartitionReaderFactory = {
     val sqlConf = sparkSession.sessionState.conf
     CometParquetFileFormat.populateConf(sqlConf, hadoopConf)
+    hadoopConf.set(HAS_NATIVE_OPERATIONS, hasNativeOperations.toString)
     val broadcastedConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
     CometParquetPartitionReaderFactory(
@@ -66,6 +69,10 @@
       new ParquetOptions(options.asScala.toMap, sqlConf),
       metrics)
   }
+
+  def prepareForNativeExec(): Unit = {
+    hasNativeOperations = true
+  }
 }
 
 object CometParquetScan {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala
index 57aff9a85..ff6220bfb 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.vectorized._
 import com.google.common.base.Objects
 
 import org.apache.comet.{DataTypeSupport, MetricsSupport}
+import org.apache.comet.parquet.CometParquetScan
 import org.apache.comet.shims.ShimCometBatchScanExec
 
 case class CometBatchScanExec(wrapped: BatchScanExec, runtimeFilters: Seq[Expression])
@@ -77,7 +78,10 @@ case class CometBatchScanExec(wrapped: BatchScanExec, runtimeFilters: Seq[Expres
   }
 
   def prepareForNativeExec(): Unit = {
-    // TODO: utilize this to avoid import and export Arrow arrays
+    scan match {
+      case s: CometParquetScan => s.prepareForNativeExec()
+      case _ => // TODO support Iceberg
+    }
   }
 
   override def executeCollect(): Array[InternalRow] = {
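
Note: the mechanism in this patch is a flag round trip through the Hadoop Configuration. On the driver, CometBatchScanExec.prepareForNativeExec() marks the CometParquetScan, which writes the flag into hadoopConf before it is broadcast; on the executor, CometParquetPartitionReaderFactory reads it back and passes it to BatchReader so the reader can skip exporting Arrow arrays. Below is a minimal, self-contained sketch of that round trip. The key string used here is a hypothetical stand-in; the real HAS_NATIVE_OPERATIONS constant is defined in CometParquetFileFormat and its value is not shown in this diff.

    import org.apache.hadoop.conf.Configuration

    object NativeFlagRoundTrip {
      // Hypothetical stand-in for CometParquetFileFormat.HAS_NATIVE_OPERATIONS;
      // the actual key string lives in Comet and is not part of this patch.
      val HAS_NATIVE_OPERATIONS = "spark.comet.scan.hasNativeOperations"

      def main(args: Array[String]): Unit = {
        val hadoopConf = new Configuration()

        // Driver side (CometParquetScan.createReaderFactory): record whether the
        // scan was prepared for native execution before the conf is broadcast.
        val hasNativeOperations = true
        hadoopConf.set(HAS_NATIVE_OPERATIONS, hasNativeOperations.toString)

        // Executor side (CometParquetPartitionReaderFactory): read the flag back,
        // defaulting to false when prepareForNativeExec() was never called.
        val readBack = hadoopConf.get(HAS_NATIVE_OPERATIONS, "false").toBoolean
        assert(readBack)
      }
    }

A string-valued configuration entry is used rather than a new constructor parameter or serialized field because hadoopConf is already broadcast to every task, so the flag rides along with no extra plumbing.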