build: Enable spark-4.0 Spark tests
kazuyukitanimura committed Jun 8, 2024
1 parent c1d90aa commit b980c0e
Showing 6 changed files with 31 additions and 21 deletions.
@@ -285,6 +285,7 @@ public void init() throws URISyntaxException, IOException {
     missingColumns = new boolean[columns.size()];
     List<String[]> paths = requestedSchema.getPaths();
     StructField[] nonPartitionFields = sparkSchema.fields();
+    ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema);
     for (int i = 0; i < requestedSchema.getFieldCount(); i++) {
       Type t = requestedSchema.getFields().get(i);
       Preconditions.checkState(
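
For context, the line added above asks the shim layer where the temporary row-index column sits in the Spark schema. A minimal sketch of that contract in Scala, using a hand-built schema that is an illustrative assumption, not part of the commit:

  import org.apache.comet.shims.ShimFileFormat
  import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

  // Illustrative schema: one data column plus the temporary row-index column.
  val withRowIndex = StructType(Seq(
    StructField("value", StringType),
    StructField(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME, LongType)))

  ShimFileFormat.findRowIndexColumnIndexInSchema(withRowIndex)    // 1
  ShimFileFormat.findRowIndexColumnIndexInSchema(StructType(Nil)) // -1 (absent)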
@@ -19,6 +19,8 @@

 package org.apache.comet.shims

+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+
 object ShimFileFormat {

   // TODO: remove after dropping Spark 3.2 & 3.3 support and directly use FileFormat.ROW_INDEX
@@ -32,4 +32,20 @@ object ShimFileFormat {

   // TODO: remove after dropping Spark 3.2 support and use FileFormat.OPTION_RETURNING_BATCH
   val OPTION_RETURNING_BATCH = "returning_batch"
+
+  // TODO: remove after dropping Spark 3.2 & 3.3 support and directly use
+  // RowIndexUtil.findRowIndexColumnIndexInSchema
+  def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = {
+    sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) =>
+      field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
+    } match {
+      case Some((field: StructField, idx: Int)) =>
+        if (field.dataType != LongType) {
+          throw new RuntimeException(
+            s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType")
+        }
+        idx
+      case _ => -1
+    }
+  }
 }
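
A hedged illustration of the LongType guard in the method just added: a column that reuses the reserved name with any other type fails fast rather than being silently mis-read (the schema below is assumed for illustration):

  import org.apache.comet.shims.ShimFileFormat
  import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

  val badType = StructType(Seq(
    StructField(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME, IntegerType)))

  // Throws RuntimeException: "<row-index column> must be of LongType"
  ShimFileFormat.findRowIndexColumnIndexInSchema(badType)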
@@ -21,11 +21,16 @@ package org.apache.comet.shims

 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRowIndexUtil
+import org.apache.spark.sql.types.StructType

 object ShimFileFormat {
   // A name for a temporary column that holds row indexes computed by the file format reader
   // until they can be placed in the _metadata struct.
   val ROW_INDEX_TEMPORARY_COLUMN_NAME = ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME

   val OPTION_RETURNING_BATCH = FileFormat.OPTION_RETURNING_BATCH
+
+  def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int =
+    ParquetRowIndexUtil.findRowIndexColumnIndexInSchema(sparkSchema)
 }
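
Since the Spark 3.2/3.3 shim above and this Spark 3.4+ shim expose the same findRowIndexColumnIndexInSchema signature, call sites compile unchanged against either build profile. A sketch of a version-agnostic caller (rowIndexColumn is a hypothetical helper, not part of the commit):

  import org.apache.comet.shims.ShimFileFormat
  import org.apache.spark.sql.types.StructType

  // Fold the -1 "not found" sentinel into an Option for callers.
  def rowIndexColumn(schema: StructType): Option[Int] =
    ShimFileFormat.findRowIndexColumnIndexInSchema(schema) match {
      case -1  => None
      case idx => Some(idx)
    }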
@@ -33,7 +33,7 @@ trait CometExprShim {
   }

   protected def isTimestampNTZType(dt: DataType): Boolean = dt match {
-    // `TimestampNTZType` is private in Spark 3.2.
+    // `TimestampNTZType` is private
     case dt if dt.typeName == "timestamp_ntz" => true
     case _ => false
   }
@@ -19,7 +19,7 @@
 package org.apache.comet.shims

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, TimestampNTZType}
+import org.apache.spark.sql.types.DataType

 /**
  * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
@@ -33,7 +33,8 @@ trait CometExprShim {
   }

   protected def isTimestampNTZType(dt: DataType): Boolean = dt match {
-    case _: TimestampNTZType => true
+    // `TimestampNTZType` is private
+    case dt if dt.typeName == "timestamp_ntz" => true
     case _ => false
   }
 }
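
Both CometExprShim variants now probe the stable typeName string instead of pattern-matching on the TimestampNTZType class, which is not public on every supported Spark version. A small sketch of why the probe works (assumes a Spark version whose DDL parser accepts timestamp_ntz, such as 3.4+):

  import org.apache.spark.sql.types.DataType

  // The runtime type name is "timestamp_ntz" even where the class itself
  // cannot be referenced from shim code.
  val dt = DataType.fromDDL("timestamp_ntz")
  assert(dt.typeName == "timestamp_ntz")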
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, H
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.types.StructType

 trait ShimCometScanExec {
   def wrapped: FileSourceScanExec
@@ -108,24 +108,9 @@
     }
   }

-  // Copied from Spark 3.4 RowIndexUtil due to PARQUET-2161 (tracked in SPARK-39634)
-  // TODO: remove after PARQUET-2161 becomes available in Parquet
-  private def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = {
-    sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) =>
-      field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
-    } match {
-      case Some((field: StructField, idx: Int)) =>
-        if (field.dataType != LongType) {
-          throw new RuntimeException(
-            s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType")
-        }
-        idx
-      case _ => -1
-    }
-  }
-
   protected def isNeededForSchema(sparkSchema: StructType): Boolean = {
-    findRowIndexColumnIndexInSchema(sparkSchema) >= 0
+    // TODO: remove after PARQUET-2161 becomes available in Parquet (tracked in SPARK-39634)
+    ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema) >= 0
   }

   protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile =
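
With the per-class copy removed, the scan's only remaining row-index concern is the boolean gate above. A hedged sketch of its behavior (schemas are illustrative assumptions):

  import org.apache.comet.shims.ShimFileFormat
  import org.apache.spark.sql.types.{LongType, StructField, StructType}

  // Mirrors isNeededForSchema: true only when the temporary column is present.
  def isNeeded(schema: StructType): Boolean =
    ShimFileFormat.findRowIndexColumnIndexInSchema(schema) >= 0

  isNeeded(StructType(Seq(StructField("id", LongType))))  // false
  isNeeded(StructType(Seq(
    StructField(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME, LongType))))  // true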
