From 76ab5888f7783bb3c44e5af0b637bb455f0a641b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 13:33:27 -0700 Subject: [PATCH] fix: Use `makeCopy` to change relation in `FileSourceScanExec` --- .../spark/sql/comet/CometScanExec.scala | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala index 4bf01f0f4..42cc96bf5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet import scala.collection.mutable.HashMap import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD @@ -439,8 +440,29 @@ case class CometScanExec( object CometScanExec { def apply(scanExec: FileSourceScanExec, session: SparkSession): CometScanExec = { - val wrapped = scanExec.copy(relation = - scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session)) + // TreeNode.mapProductIterator is a protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec via `copy` seems to cause issues + // on other Spark distributions if the FileSourceScanExec constructor is changed. + // Use `makeCopy` to avoid the issue. 
+ // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] val batchScanExec = CometScanExec( wrapped.relation, wrapped.output,