From 5bee8ae17def46b1099b3f553f8379d3337e3cd9 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 18 May 2024 17:32:48 -0700 Subject: [PATCH] add a test --- .../org/apache/comet/parquet/BatchReader.java | 3 +- .../CometParquetToSparkSchemaConverter.scala | 403 ++++++++++++++++++ .../comet/parquet/ParquetReadSuite.scala | 86 ++++ 3 files changed, 490 insertions(+), 2 deletions(-) create mode 100644 common/src/main/java/org/apache/comet/parquet/CometParquetToSparkSchemaConverter.scala diff --git a/common/src/main/java/org/apache/comet/parquet/BatchReader.java b/common/src/main/java/org/apache/comet/parquet/BatchReader.java index 9940390dc5..600fb9ea22 100644 --- a/common/src/main/java/org/apache/comet/parquet/BatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/BatchReader.java @@ -57,7 +57,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.comet.parquet.CometParquetReadSupport; import org.apache.spark.sql.execution.datasources.PartitionedFile; -import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter; import org.apache.spark.sql.execution.metric.SQLMetric; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -257,7 +256,7 @@ public void init() throws URISyntaxException, IOException { MessageType fileSchema = requestedSchema; if (sparkSchema == null) { - sparkSchema = new ParquetToSparkSchemaConverter(conf).convert(requestedSchema); + sparkSchema = new CometParquetToSparkSchemaConverter(conf).convert(requestedSchema); } else { requestedSchema = CometParquetReadSupport.clipParquetSchema( diff --git a/common/src/main/java/org/apache/comet/parquet/CometParquetToSparkSchemaConverter.scala b/common/src/main/java/org/apache/comet/parquet/CometParquetToSparkSchemaConverter.scala new file mode 100644 index 0000000000..2e7311e0d7 --- /dev/null +++ b/common/src/main/java/org/apache/comet/parquet/CometParquetToSparkSchemaConverter.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.io.{ColumnIO, GroupColumnIO, PrimitiveColumnIO} +import org.apache.parquet.schema._ +import org.apache.parquet.schema.LogicalTypeAnnotation._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.normalizeFieldName +import org.apache.spark.sql.execution.datasources.parquet.{ParquetColumn, ParquetToSparkSchemaConverter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class CometParquetToSparkSchemaConverter( + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get, + inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get, + nanosAsLong: Boolean = SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get) extends ParquetToSparkSchemaConverter { + + def this(conf: Configuration) = this( + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + caseSensitive = conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean, + inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean, + nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean) + + override def convertField( + field: ColumnIO, + sparkReadType: Option[DataType] = None): ParquetColumn = { + val targetType = sparkReadType.map { + case udt: UserDefinedType[_] => udt.sqlType + case otherType => otherType + } + field match { + case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType) + case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType) + } + } + + private def convertPrimitiveField( + primitiveColumn: PrimitiveColumnIO, + sparkReadType: Option[DataType] = None): ParquetColumn = { + val parquetType = primitiveColumn.getType.asPrimitiveType() + val typeAnnotation = primitiveColumn.getType.getLogicalTypeAnnotation + val typeName = primitiveColumn.getPrimitive + + def typeString = + if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)" + + def typeNotImplemented() = + throw new UnsupportedOperationException("unsupported Parquet type: " + typeString) + + def illegalType() = + throw new UnsupportedOperationException("Illegal Parquet type: " + typeString) + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. + def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val decimalLogicalTypeAnnotation = typeAnnotation + .asInstanceOf[DecimalLogicalTypeAnnotation] + val precision = decimalLogicalTypeAnnotation.getPrecision + val scale = decimalLogicalTypeAnnotation.getScale + + CometParquetSchemaConverter.checkConversionRequirement( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + val sparkType = sparkReadType.getOrElse(typeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + typeAnnotation match { + case intTypeAnnotation: IntLogicalTypeAnnotation if intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + case 8 => ByteType + case 16 => ShortType + case 32 => IntegerType + case _ => illegalType() + } + case null => IntegerType + case _: DateLogicalTypeAnnotation => DateType + case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_INT_DIGITS) + case intTypeAnnotation: IntLogicalTypeAnnotation if !intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + case 8 => ShortType + case 16 => IntegerType + case 32 => LongType + case _ => illegalType() + } + case t: TimestampLogicalTypeAnnotation if t.getUnit == TimeUnit.MILLIS => + typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + typeAnnotation match { + case intTypeAnnotation: IntLogicalTypeAnnotation if intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + case 64 => LongType + case _ => illegalType() + } + case null => LongType + case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case intTypeAnnotation: IntLogicalTypeAnnotation if !intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + // The precision to hold the largest unsigned long is: + // `java.lang.Long.toUnsignedString(-1).length` = 20 + case 64 => DecimalType(20, 0) + case _ => illegalType() + } + case timestamp: TimestampLogicalTypeAnnotation + if timestamp.getUnit == TimeUnit.MICROS || timestamp.getUnit == TimeUnit.MILLIS => + val inferTimestampNTZ = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get + if (timestamp.isAdjustedToUTC || !inferTimestampNTZ) { + TimestampType + } else { + TimestampNTZType + } + // SPARK-40819: NANOS are not supported as a Timestamp, convert to LongType without + // timezone awareness to address behaviour regression introduced by SPARK-34661 + case timestamp: TimestampLogicalTypeAnnotation + if timestamp.getUnit == TimeUnit.NANOS && SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get => + LongType + case _ => illegalType() + } + + case INT96 => + CometParquetSchemaConverter.checkConversionRequirement( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + typeAnnotation match { + case _: StringLogicalTypeAnnotation | _: EnumLogicalTypeAnnotation | + _: JsonLogicalTypeAnnotation => StringType + case null if SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get => StringType + case null => BinaryType + case _: BsonLogicalTypeAnnotation => BinaryType + case _: DecimalLogicalTypeAnnotation => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + typeAnnotation match { + case _: DecimalLogicalTypeAnnotation => + makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength)) + case _: UUIDLogicalTypeAnnotation => StringType + case _: IntervalLogicalTypeAnnotation => typeNotImplemented() + case null => BinaryType + case _ => illegalType() + } + + case _ => illegalType() + }) + + ParquetColumn(sparkType, primitiveColumn) + } + + private def convertGroupField( + groupColumn: GroupColumnIO, + sparkReadType: Option[DataType] = None): ParquetColumn = { + val field = groupColumn.getType.asGroupType() + Option(field.getLogicalTypeAnnotation).fold( + convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. + // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case _: ListLogicalTypeAnnotation => + CometParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1, s"Invalid list type $field") + CometParquetSchemaConverter.checkConversionRequirement( + sparkReadType.forall(_.isInstanceOf[ArrayType]), + s"Invalid Spark read type: expected $field to be list type but found $sparkReadType") + + val repeated = groupColumn.getChild(0) + val repeatedType = repeated.getType + CometParquetSchemaConverter.checkConversionRequirement( + repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + val sparkReadElementType = sparkReadType.map(_.asInstanceOf[ArrayType].elementType) + + if (isElementType2(repeatedType, field.getName)) { + var converted = convertField(repeated, sparkReadElementType) + val convertedType = sparkReadElementType.getOrElse(converted.sparkType) + + // legacy format such as: + // optional group my_list (LIST) { + // repeated int32 element; + // } + // we should mark the primitive field as required + if (repeatedType.isPrimitive) converted = converted.copy(required = true) + + ParquetColumn(ArrayType(convertedType, containsNull = false), + groupColumn, Seq(converted)) + } else { + val element = repeated.asInstanceOf[GroupColumnIO].getChild(0) + val converted = convertField(element, sparkReadElementType) + val convertedType = sparkReadElementType.getOrElse(converted.sparkType) + val optional = element.getType.isRepetition(OPTIONAL) + ParquetColumn(ArrayType(convertedType, containsNull = optional), + groupColumn, Seq(converted)) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case _: MapLogicalTypeAnnotation | _: MapKeyValueTypeAnnotation => + CometParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + CometParquetSchemaConverter.checkConversionRequirement( + sparkReadType.forall(_.isInstanceOf[MapType]), + s"Invalid Spark read type: expected $field to be map type but found $sparkReadType") + + val keyValue = groupColumn.getChild(0).asInstanceOf[GroupColumnIO] + val keyValueType = keyValue.getType.asGroupType() + CometParquetSchemaConverter.checkConversionRequirement( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val key = keyValue.getChild(0) + val value = keyValue.getChild(1) + val sparkReadKeyType = sparkReadType.map(_.asInstanceOf[MapType].keyType) + val sparkReadValueType = sparkReadType.map(_.asInstanceOf[MapType].valueType) + val convertedKey = convertField(key, sparkReadKeyType) + val convertedValue = convertField(value, sparkReadValueType) + val convertedKeyType = sparkReadKeyType.getOrElse(convertedKey.sparkType) + val convertedValueType = sparkReadValueType.getOrElse(convertedValue.sparkType) + val valueOptional = value.getType.isRepetition(OPTIONAL) + ParquetColumn( + MapType(convertedKeyType, convertedValueType, + valueContainsNull = valueOptional), + groupColumn, Seq(convertedKey, convertedValue)) + case _ => + throw new UnsupportedOperationException("unrecognized Parquet type: " + field.toString) + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + private[parquet] def isElementType2(repeatedType: Type, parentName: String): Boolean = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // ARRAY (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } + + + private def convertInternal( + groupColumn: GroupColumnIO, + sparkReadSchema: Option[StructType] = None): ParquetColumn = { + // First convert the read schema into a map from field name to the field itself, to avoid O(n) + // lookup cost below. + val schemaMapOpt = sparkReadSchema.map { schema => + schema.map(f => normalizeFieldName(f.name) -> f).toMap + } + + val converted = (0 until groupColumn.getChildrenCount).map { i => + val field = groupColumn.getChild(i) + val fieldFromReadSchema = schemaMapOpt.flatMap { schemaMap => + schemaMap.get(normalizeFieldName(field.getName)) + } + var fieldReadType = fieldFromReadSchema.map(_.dataType) + + // If a field is repeated here then it is neither contained by a `LIST` nor `MAP` + // annotated group (these should've been handled in `convertGroupField`), e.g.: + // + // message schema { + // repeated int32 int_array; + // } + // or + // message schema { + // repeated group struct_array { + // optional int32 field; + // } + // } + // + // the corresponding Spark read type should be an array and we should pass the element type + // to the group or primitive type conversion method. + if (field.getType.getRepetition == REPEATED) { + fieldReadType = fieldReadType.flatMap { + case at: ArrayType => Some(at.elementType) + case _ => + throw new UnsupportedOperationException("Illegal Parquet type " + groupColumn.toString) + } + } + + val convertedField = convertField(field, fieldReadType) + val fieldName = fieldFromReadSchema.map(_.name).getOrElse(field.getType.getName) + + field.getType.getRepetition match { + case OPTIONAL | REQUIRED => + val nullable = field.getType.getRepetition == OPTIONAL + (StructField(fieldName, convertedField.sparkType, nullable = nullable), + convertedField) + + case REPEATED => + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + val arrayType = ArrayType(convertedField.sparkType, containsNull = false) + (StructField(fieldName, arrayType, nullable = false), + ParquetColumn(arrayType, None, convertedField.repetitionLevel - 1, + convertedField.definitionLevel - 1, required = true, convertedField.path, + Seq(convertedField.copy(required = true)))) + } + } + + ParquetColumn(StructType(converted.map(_._1)), groupColumn, converted.map(_._2)) + } +} + +private object CometParquetSchemaConverter { + + def checkConversionRequirement(f: => Boolean, message: String): Unit = { + if (!f) { + throw new UnsupportedOperationException("conversion is not supported " + message) + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala index f447522973..f94eafd1ea 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala @@ -21,7 +21,9 @@ package org.apache.comet.parquet import java.io.{File, FileFilter} import java.math.BigDecimal +import java.nio.ByteBuffer import java.time.{ZoneId, ZoneOffset} +import java.util.UUID import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -32,6 +34,7 @@ import org.scalatest.Tag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkException import org.apache.spark.sql.CometTestBase @@ -1251,6 +1254,89 @@ abstract class ParquetReadSuite extends CometTestBase { } } } + + test("read uuid column") { + Seq(false, true).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + val uuid: Array[UUID] = Array.fill(100)(UUID.randomUUID()) + + createParquetFileWithUUID(path, dictionaryEnabled = dictionaryEnabled, 100, uuid, 128) + + val file = dir + .listFiles(new FileFilter { + override def accept(pathname: File): Boolean = + pathname.isFile && pathname.toString.endsWith("parquet") + }) + .head + + val reader = new BatchReader(new Configuration, file.toString, 100, null, null) + reader.init() + + val uuidRead: Array[String] = new Array[String](100) + + try { + while (reader.nextBatch()) { + val batch = reader.currentBatch() + val column = batch.column(1) + val numRows = batch.numRows() + for (i <- 0 until numRows) { + if (!column.isNullAt(i)) { + uuidRead(i) = column.getUTF8String(i).toString + } + } + } + uuid.indices.foreach { i => + assert( + uuid(i).toString == uuidRead(i), + s"UUID mismatch at index $i: ${uuid(i)} != ${uuidRead(i)}") + } + } finally { + reader.close() + } + } + } + } + + def createParquetFileWithUUID( + path: Path, + dictionaryEnabled: Boolean, + num: Int, + uuid: Array[UUID], + pageSize: Int = 128): Unit = { + val schemaStr = + """ + |message root { + | required int32 id; + | required FIXED_LEN_BYTE_ARRAY(16) uuid (UUID); + |} + """.stripMargin + + val schema = MessageTypeParser.parseMessageType(schemaStr.toString) + val writer = createParquetWriter( + schema, + path, + dictionaryEnabled = dictionaryEnabled, + pageSize = pageSize, + dictionaryPageSize = pageSize) + + (0 until num).foreach { n => + val record = new SimpleGroup(schema) + + record.add(0, n) + + // Convert UUID to 16-byte array and add to record + val bb = ByteBuffer.allocate(16) + bb.putLong(uuid(n).getMostSignificantBits) + bb.putLong(uuid(n).getLeastSignificantBits) + record.add(1, Binary.fromConstantByteArray(bb.array())) + + writer.write(record) + } + + writer.close() + } + def testScanner(cometEnabled: String, scanner: String, v1: Option[String] = None): Unit = { withSQLConf( CometConf.COMET_ENABLED.key -> cometEnabled,