From ab09337699876c839f841c3ed4545279130d1522 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Dec 2024 11:51:45 -0700 Subject: [PATCH] [comet-parquet-exec] Comet parquet exec 2 (copy of Parth's PR) (#1138) * WIP: (POC2) A Parquet reader that uses the arrow-rs Parquet reader directly * Change default config --------- Co-authored-by: Parth Chandra --- .../java/org/apache/comet/parquet/Native.java | 52 ++ .../comet/parquet/NativeBatchReader.java | 507 ++++++++++++++++++ .../comet/parquet/NativeColumnReader.java | 190 +++++++ .../scala/org/apache/comet/CometConf.scala | 12 +- native/core/src/parquet/mod.rs | 220 +++++++- .../parquet/CometParquetFileFormat.scala | 51 +- .../org/apache/spark/sql/CometTestBase.scala | 3 + 7 files changed, 1018 insertions(+), 17 deletions(-) create mode 100644 common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java create mode 100644 common/src/main/java/org/apache/comet/parquet/NativeColumnReader.java diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java index 1e666652e..1ed01d326 100644 --- a/common/src/main/java/org/apache/comet/parquet/Native.java +++ b/common/src/main/java/org/apache/comet/parquet/Native.java @@ -234,4 +234,56 @@ public static native void setPageV2( * @param handle the handle to the native Parquet column reader */ public static native void closeColumnReader(long handle); + + ///////////// Arrow Native Parquet Reader APIs + // TODO: Add partitionValues(?), improve requiredColumns to use a projection mask that corresponds + // to arrow. + // Add batch size, datetimeRebaseModeSpec, metrics(how?)... + + /** + * Initialize a record batch reader for a PartitionedFile + * + * @param filePath + * @param start + * @param length + * @param required_columns array of names of fields to read + * @return a handle to the record batch reader, used in subsequent calls. + */ + public static native long initRecordBatchReader( + String filePath, long start, long length, Object[] required_columns); + + public static native int numRowGroups(long handle); + + public static native long numTotalRows(long handle); + + // arrow native version of read batch + /** + * Read the next batch of data into memory on native side + * + * @param handle + * @return the number of rows read + */ + public static native int readNextRecordBatch(long handle); + + // arrow native equivalent of currentBatch. 'columnNum' is number of the column in the record + // batch + /** + * Load the column corresponding to columnNum in the currently loaded record batch into JVM + * + * @param handle + * @param columnNum + * @param arrayAddr + * @param schemaAddr + */ + public static native void currentColumnBatch( + long handle, int columnNum, long arrayAddr, long schemaAddr); + + // arrow native version to close record batch reader + + /** + * Close the record batch reader. Free the resources + * + * @param handle + */ + public static native void closeRecordBatchReader(long handle); } diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java new file mode 100644 index 000000000..17fab47e5 --- /dev/null +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet; + +import java.io.Closeable; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URISyntaxException; +import java.util.*; + +import scala.Option; +import scala.collection.Seq; +import scala.collection.mutable.Buffer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.arrow.c.CometSchemaImporter; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.HadoopReadOptions; +import org.apache.parquet.ParquetReadOptions; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.comet.parquet.CometParquetReadSupport; +import org.apache.spark.sql.execution.datasources.PartitionedFile; +import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter; +import org.apache.spark.sql.execution.metric.SQLMetric; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.util.AccumulatorV2; + +import org.apache.comet.CometConf; +import org.apache.comet.shims.ShimBatchReader; +import org.apache.comet.shims.ShimFileFormat; +import org.apache.comet.vector.CometVector; + +/** + * A vectorized Parquet reader that reads a Parquet file in a batched fashion. + * + *

<p>Example of how to use this: + * + * <pre>
+ *   NativeBatchReader reader = new NativeBatchReader(parquetFile, batchSize);
+ *   try {
+ *     reader.init();
+ *     while (reader.readBatch()) {
+ *       ColumnarBatch batch = reader.currentBatch();
+ *       // consume the batch
+ *     }
+ *   } finally { // resources associated with the reader should be released
+ *     reader.close();
+ *   }
+ * </pre>
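+ *
+ * <p>In the Arrow-native path added by this patch, {@link #nextBatch()} advances the
+ * reader by calling Native.readNextRecordBatch on the native handle, and each projected
+ * column is then imported into the JVM through the Arrow C data interface via
+ * Native.currentColumnBatch and CometSchemaImporter (see {@link NativeColumnReader}).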
+ */ +public class NativeBatchReader extends RecordReader implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(NativeBatchReader.class); + protected static final BufferAllocator ALLOCATOR = new RootAllocator(); + + private Configuration conf; + private int capacity; + private boolean isCaseSensitive; + private boolean useFieldId; + private boolean ignoreMissingIds; + private StructType partitionSchema; + private InternalRow partitionValues; + private PartitionedFile file; + private final Map metrics; + + private long rowsRead; + private StructType sparkSchema; + private MessageType requestedSchema; + private CometVector[] vectors; + private AbstractColumnReader[] columnReaders; + private CometSchemaImporter importer; + private ColumnarBatch currentBatch; + // private FileReader fileReader; + private boolean[] missingColumns; + private boolean isInitialized; + private ParquetMetadata footer; + + /** The total number of rows across all row groups of the input split. */ + private long totalRowCount; + + /** + * Whether the native scan should always return decimal represented by 128 bits, regardless of its + * precision. Normally, this should be true if native execution is enabled, since Arrow compute + * kernels doesn't support 32 and 64 bit decimals yet. + */ + // TODO: (ARROW NATIVE) + private boolean useDecimal128; + + /** + * Whether to return dates/timestamps that were written with legacy hybrid (Julian + Gregorian) + * calendar as it is. If this is true, Comet will return them as it is, instead of rebasing them + * to the new Proleptic Gregorian calendar. If this is false, Comet will throw exceptions when + * seeing these dates/timestamps. + */ + // TODO: (ARROW NATIVE) + private boolean useLegacyDateTimestamp; + + /** The TaskContext object for executing this task. */ + private final TaskContext taskContext; + + private long handle; + + // Only for testing + public NativeBatchReader(String file, int capacity) { + this(file, capacity, null, null); + } + + // Only for testing + public NativeBatchReader( + String file, int capacity, StructType partitionSchema, InternalRow partitionValues) { + this(new Configuration(), file, capacity, partitionSchema, partitionValues); + } + + // Only for testing + public NativeBatchReader( + Configuration conf, + String file, + int capacity, + StructType partitionSchema, + InternalRow partitionValues) { + + this.conf = conf; + this.capacity = capacity; + this.isCaseSensitive = false; + this.useFieldId = false; + this.ignoreMissingIds = false; + this.partitionSchema = partitionSchema; + this.partitionValues = partitionValues; + + this.file = ShimBatchReader.newPartitionedFile(partitionValues, file); + this.metrics = new HashMap<>(); + + this.taskContext = TaskContext$.MODULE$.get(); + } + + public NativeBatchReader(AbstractColumnReader[] columnReaders) { + // Todo: set useDecimal128 and useLazyMaterialization + int numColumns = columnReaders.length; + this.columnReaders = new AbstractColumnReader[numColumns]; + vectors = new CometVector[numColumns]; + currentBatch = new ColumnarBatch(vectors); + // This constructor is used by Iceberg only. 
The columnReaders are + // initialized in Iceberg, so no need to call the init() + isInitialized = true; + this.taskContext = TaskContext$.MODULE$.get(); + this.metrics = new HashMap<>(); + } + + NativeBatchReader( + Configuration conf, + PartitionedFile inputSplit, + ParquetMetadata footer, + int capacity, + StructType sparkSchema, + boolean isCaseSensitive, + boolean useFieldId, + boolean ignoreMissingIds, + boolean useLegacyDateTimestamp, + StructType partitionSchema, + InternalRow partitionValues, + Map metrics) { + this.conf = conf; + this.capacity = capacity; + this.sparkSchema = sparkSchema; + this.isCaseSensitive = isCaseSensitive; + this.useFieldId = useFieldId; + this.ignoreMissingIds = ignoreMissingIds; + this.useLegacyDateTimestamp = useLegacyDateTimestamp; + this.partitionSchema = partitionSchema; + this.partitionValues = partitionValues; + this.file = inputSplit; + this.footer = footer; + this.metrics = metrics; + this.taskContext = TaskContext$.MODULE$.get(); + } + + /** + * Initialize this reader. The reason we don't do it in the constructor is that we want to close + * any resource hold by this reader when error happens during the initialization. + */ + public void init() throws URISyntaxException, IOException { + + useDecimal128 = + conf.getBoolean( + CometConf.COMET_USE_DECIMAL_128().key(), + (Boolean) CometConf.COMET_USE_DECIMAL_128().defaultValue().get()); + + long start = file.start(); + long length = file.length(); + String filePath = file.filePath().toString(); + + requestedSchema = footer.getFileMetaData().getSchema(); + MessageType fileSchema = requestedSchema; + // TODO: (ARROW NATIVE) Get requested schema - Convert the Spark schema (from catalyst) into a + // list of fields to project (?). Fields must be matched by field id first and then by name + { //////// Get requested Schema - replace this block of code native (avoid reading the footer + ParquetReadOptions.Builder builder = HadoopReadOptions.builder(conf, new Path(filePath)); + + if (start >= 0 && length >= 0) { + builder = builder.withRange(start, start + length); + } + ParquetReadOptions readOptions = builder.build(); + + ReadOptions cometReadOptions = ReadOptions.builder(conf).build(); + + if (sparkSchema == null) { + sparkSchema = new ParquetToSparkSchemaConverter(conf).convert(requestedSchema); + } else { + requestedSchema = + CometParquetReadSupport.clipParquetSchema( + requestedSchema, sparkSchema, isCaseSensitive, useFieldId, ignoreMissingIds); + if (requestedSchema.getColumns().size() != sparkSchema.size()) { + throw new IllegalArgumentException( + String.format( + "Spark schema has %d columns while " + "Parquet schema has %d columns", + sparkSchema.size(), requestedSchema.getColumns().size())); + } + } + } ////// End get requested schema + + //// Create Column readers + List columns = requestedSchema.getColumns(); + int numColumns = columns.size(); + if (partitionSchema != null) numColumns += partitionSchema.size(); + columnReaders = new AbstractColumnReader[numColumns]; + + // Initialize missing columns and use null vectors for them + missingColumns = new boolean[columns.size()]; + List paths = requestedSchema.getPaths(); + StructField[] nonPartitionFields = sparkSchema.fields(); + // ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema); + for (int i = 0; i < requestedSchema.getFieldCount(); i++) { + Type t = requestedSchema.getFields().get(i); + Preconditions.checkState( + t.isPrimitive() && !t.isRepetition(Type.Repetition.REPEATED), + "Complex type is not supported"); + String[] colPath 
= paths.get(i); + if (nonPartitionFields[i].name().equals(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME())) { + // Values of ROW_INDEX_TEMPORARY_COLUMN_NAME column are always populated with + // generated row indexes, rather than read from the file. + // TODO(SPARK-40059): Allow users to include columns named + // FileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME in their schemas. + // TODO: (ARROW NATIVE) Support row indices ... + // long[] rowIndices = fileReader.getRowIndices(); + // columnReaders[i] = new RowIndexColumnReader(nonPartitionFields[i], capacity, + // rowIndices); + missingColumns[i] = true; + } else if (fileSchema.containsPath(colPath)) { + ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); + if (!fd.equals(columns.get(i))) { + throw new UnsupportedOperationException("Schema evolution is not supported"); + } + missingColumns[i] = false; + } else { + if (columns.get(i).getMaxDefinitionLevel() == 0) { + throw new IOException( + "Required column '" + + Arrays.toString(colPath) + + "' is missing" + + " in data file " + + filePath); + } + ConstantColumnReader reader = + new ConstantColumnReader(nonPartitionFields[i], capacity, useDecimal128); + columnReaders[i] = reader; + missingColumns[i] = true; + } + } + + // Initialize constant readers for partition columns + if (partitionSchema != null) { + StructField[] partitionFields = partitionSchema.fields(); + for (int i = columns.size(); i < columnReaders.length; i++) { + int fieldIndex = i - columns.size(); + StructField field = partitionFields[fieldIndex]; + ConstantColumnReader reader = + new ConstantColumnReader(field, capacity, partitionValues, fieldIndex, useDecimal128); + columnReaders[i] = reader; + } + } + + vectors = new CometVector[numColumns]; + currentBatch = new ColumnarBatch(vectors); + + // For test purpose only + // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read + // will be updated to the accumulator. So we can check if the row groups are filtered or not + // in test case. + // Note that this tries to get thread local TaskContext object, if this is called at other + // thread, it won't update the accumulator. + if (taskContext != null) { + Option> accu = getTaskAccumulator(taskContext.taskMetrics()); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + // TODO: Get num_row_groups from native + // intAccum.add(fileReader.getRowGroups().size()); + } + } + + // TODO: (ARROW NATIVE) Use a ProjectionMask here ? + ArrayList requiredColumns = new ArrayList<>(); + for (Type col : requestedSchema.asGroupType().getFields()) { + requiredColumns.add(col.getName()); + } + this.handle = Native.initRecordBatchReader(filePath, start, length, requiredColumns.toArray()); + totalRowCount = Native.numRowGroups(handle); + isInitialized = true; + } + + public void setSparkSchema(StructType schema) { + this.sparkSchema = schema; + } + + public AbstractColumnReader[] getColumnReaders() { + return columnReaders; + } + + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + // Do nothing. The initialization work is done in 'init' already. 
+ } + + @Override + public boolean nextKeyValue() throws IOException { + return nextBatch(); + } + + @Override + public Void getCurrentKey() { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() { + return currentBatch(); + } + + @Override + public float getProgress() { + return (float) rowsRead / totalRowCount; + } + + /** + * Returns the current columnar batch being read. + * + *

Note that this must be called AFTER {@link NativeBatchReader#nextBatch()}. + */ + public ColumnarBatch currentBatch() { + return currentBatch; + } + + /** + * Loads the next batch of rows. This is called by Spark _and_ Iceberg + * + * @return true if there are no more rows to read, false otherwise. + */ + public boolean nextBatch() throws IOException { + Preconditions.checkState(isInitialized, "init() should be called first!"); + + if (rowsRead >= totalRowCount) return false; + int batchSize; + + try { + batchSize = loadNextBatch(); + } catch (RuntimeException e) { + // Spark will check certain exception e.g. `SchemaColumnConvertNotSupportedException`. + throw e; + } catch (Throwable e) { + throw new IOException(e); + } + + if (batchSize == 0) return false; + + long totalDecodeTime = 0, totalLoadTime = 0; + for (int i = 0; i < columnReaders.length; i++) { + AbstractColumnReader reader = columnReaders[i]; + long startNs = System.nanoTime(); + // TODO: read from native reader + reader.readBatch(batchSize); + // totalDecodeTime += System.nanoTime() - startNs; + // startNs = System.nanoTime(); + vectors[i] = reader.currentBatch(); + totalLoadTime += System.nanoTime() - startNs; + } + + // TODO: (ARROW NATIVE) Add Metrics + // SQLMetric decodeMetric = metrics.get("ParquetNativeDecodeTime"); + // if (decodeMetric != null) { + // decodeMetric.add(totalDecodeTime); + // } + SQLMetric loadMetric = metrics.get("ParquetNativeLoadTime"); + if (loadMetric != null) { + loadMetric.add(totalLoadTime); + } + + currentBatch.setNumRows(batchSize); + rowsRead += batchSize; + return true; + } + + @Override + public void close() throws IOException { + if (columnReaders != null) { + for (AbstractColumnReader reader : columnReaders) { + if (reader != null) { + reader.close(); + } + } + } + if (importer != null) { + importer.close(); + importer = null; + } + Native.closeRecordBatchReader(this.handle); + } + + @SuppressWarnings("deprecation") + private int loadNextBatch() throws Throwable { + long startNs = System.nanoTime(); + + int batchSize = Native.readNextRecordBatch(this.handle); + if (importer != null) importer.close(); + importer = new CometSchemaImporter(ALLOCATOR); + + List columns = requestedSchema.getColumns(); + for (int i = 0; i < columns.size(); i++) { + // TODO: (ARROW NATIVE) check this. Currently not handling missing columns correctly? + if (missingColumns[i]) continue; + if (columnReaders[i] != null) columnReaders[i].close(); + // TODO: (ARROW NATIVE) handle tz, datetime & int96 rebase + DataType dataType = sparkSchema.fields()[i].dataType(); + NativeColumnReader reader = + new NativeColumnReader( + this.handle, + i, + dataType, + columns.get(i), + importer, + capacity, + useDecimal128, + useLegacyDateTimestamp); + columnReaders[i] = reader; + } + return batchSize; + } + + // Signature of externalAccums changed from returning a Buffer to returning a Seq. If comet is + // expecting a Buffer but the Spark version returns a Seq or vice versa, we get a + // method not found exception. 
+ @SuppressWarnings("unchecked") + private Option> getTaskAccumulator(TaskMetrics taskMetrics) { + Method externalAccumsMethod; + try { + externalAccumsMethod = TaskMetrics.class.getDeclaredMethod("externalAccums"); + externalAccumsMethod.setAccessible(true); + String returnType = externalAccumsMethod.getReturnType().getName(); + if (returnType.equals("scala.collection.mutable.Buffer")) { + return ((Buffer>) externalAccumsMethod.invoke(taskMetrics)) + .lastOption(); + } else if (returnType.equals("scala.collection.Seq")) { + return ((Seq>) externalAccumsMethod.invoke(taskMetrics)).lastOption(); + } else { + return Option.apply(null); // None + } + } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + return Option.apply(null); // None + } + } +} diff --git a/common/src/main/java/org/apache/comet/parquet/NativeColumnReader.java b/common/src/main/java/org/apache/comet/parquet/NativeColumnReader.java new file mode 100644 index 000000000..448ba0fec --- /dev/null +++ b/common/src/main/java/org/apache/comet/parquet/NativeColumnReader.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.CometSchemaImporter; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.page.*; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.spark.sql.types.DataType; + +import org.apache.comet.vector.*; + +// TODO: extend ColumnReader instead of AbstractColumnReader to reduce code duplication +public class NativeColumnReader extends AbstractColumnReader { + protected static final Logger LOG = LoggerFactory.getLogger(NativeColumnReader.class); + protected final BufferAllocator ALLOCATOR = new RootAllocator(); + + /** + * The current Comet vector holding all the values read by this column reader. Owned by this + * reader and MUST be closed after use. + */ + private CometDecodedVector currentVector; + + /** Dictionary values for this column. Only set if the column is using dictionary encoding. */ + protected CometDictionary dictionary; + + /** + * The number of values in the current batch, used when we are skipping importing of Arrow + * vectors, in which case we'll simply update the null count of the existing vectors. 
+ */ + int currentNumValues; + + /** + * Whether the last loaded vector contains any null value. This is used to determine if we can + * skip vector reloading. If the flag is false, Arrow C API will skip to import the validity + * buffer, and therefore we cannot skip vector reloading. + */ + boolean hadNull; + + private final CometSchemaImporter importer; + + private ArrowArray array = null; + private ArrowSchema schema = null; + + private long nativeBatchHandle = 0xDEADBEEFL; + private final int columnNum; + + public NativeColumnReader( + long nativeBatchHandle, + int columnNum, + DataType type, + ColumnDescriptor descriptor, + CometSchemaImporter importer, + int batchSize, + boolean useDecimal128, + boolean useLegacyDateTimestamp) { + super(type, descriptor, useDecimal128, useLegacyDateTimestamp); + assert batchSize > 0 : "Batch size must be positive, found " + batchSize; + this.batchSize = batchSize; + this.importer = importer; + this.nativeBatchHandle = nativeBatchHandle; + this.columnNum = columnNum; + initNative(); + } + + @Override + // Override in order to avoid creation of JVM side column readers + protected void initNative() { + LOG.debug( + "Native column reader " + String.join(".", this.descriptor.getPath()) + " is initialized"); + nativeHandle = 0; + } + + @Override + public void readBatch(int total) { + LOG.debug("Reading column batch of size = " + total); + + this.currentNumValues = total; + } + + /** Returns the {@link CometVector} read by this reader. */ + @Override + public CometVector currentBatch() { + return loadVector(); + } + + @Override + public void close() { + if (currentVector != null) { + currentVector.close(); + currentVector = null; + } + super.close(); + } + + /** Returns a decoded {@link CometDecodedVector Comet vector}. */ + public CometDecodedVector loadVector() { + + LOG.debug("Loading vector for next batch"); + + // Close the previous vector first to release struct memory allocated to import Arrow array & + // schema from native side, through the C data interface + if (currentVector != null) { + currentVector.close(); + } + + LogicalTypeAnnotation logicalTypeAnnotation = + descriptor.getPrimitiveType().getLogicalTypeAnnotation(); + boolean isUuid = + logicalTypeAnnotation instanceof LogicalTypeAnnotation.UUIDLogicalTypeAnnotation; + + array = ArrowArray.allocateNew(ALLOCATOR); + schema = ArrowSchema.allocateNew(ALLOCATOR); + + long arrayAddr = array.memoryAddress(); + long schemaAddr = schema.memoryAddress(); + + Native.currentColumnBatch(nativeBatchHandle, columnNum, arrayAddr, schemaAddr); + + FieldVector vector = importer.importVector(array, schema); + + DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary(); + + CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128); + + // Update whether the current vector contains any null values. This is used in the following + // batch(s) to determine whether we can skip loading the native vector. + hadNull = cometVector.hasNull(); + + if (dictionaryEncoding == null) { + if (dictionary != null) { + // This means the column was using dictionary encoding but now has fall-back to plain + // encoding, on the native side. Setting 'dictionary' to null here, so we can use it as + // a condition to check if we can re-use vector later. + dictionary = null; + } + // Either the column is not dictionary encoded, or it was using dictionary encoding but + // a new data page has switched back to use plain encoding. For both cases we should + // return plain vector. 
+ currentVector = cometVector; + return currentVector; + } + + // We should already re-initiate `CometDictionary` here because `Data.importVector` API will + // release the previous dictionary vector and create a new one. + Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId()); + CometPlainVector dictionaryVector = + new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid); + if (dictionary != null) { + dictionary.setDictionaryVector(dictionaryVector); + } else { + dictionary = new CometDictionary(dictionaryVector); + } + + currentVector = + new CometDictionaryVector( + cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid); + + currentVector = + new CometDictionaryVector(cometVector, dictionary, importer.getProvider(), useDecimal128); + return currentVector; + } +} diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 09355446c..275114a11 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -85,7 +85,17 @@ object CometConf extends ShimCometConf { "read supported data sources (currently only Parquet is supported natively)." + " By default, this config is true.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) + + val COMET_NATIVE_ARROW_SCAN_ENABLED: ConfigEntry[Boolean] = conf( + "spark.comet.native.arrow.scan.enabled") + .internal() + .doc( + "Whether to enable the fully native arrow based scan. When this is turned on, Spark will " + + "use Comet to read Parquet files natively via the Arrow based Parquet reader." + + " By default, this config is false.") + .booleanConf + .createWithDefault(false) val COMET_PARQUET_PARALLEL_IO_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.parquet.read.parallel.io.enabled") diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index 455f19929..ffd342167 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -23,6 +23,7 @@ pub use mutable_vector::*; pub mod util; pub mod read; +use std::fs::File; use std::{boxed::Box, ptr::NonNull, sync::Arc}; use crate::errors::{try_unwrap_or_throw, CometError}; @@ -39,10 +40,18 @@ use jni::{ }, }; +use crate::execution::operators::ExecutionError; use crate::execution::utils::SparkArrowConvert; use arrow::buffer::{Buffer, MutableBuffer}; -use jni::objects::{JBooleanArray, JLongArray, JPrimitiveArray, ReleaseMode}; +use arrow_array::{Array, RecordBatch}; +use jni::objects::{ + JBooleanArray, JLongArray, JObjectArray, JPrimitiveArray, JString, ReleaseMode, +}; +use jni::sys::jstring; +use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; +use parquet::arrow::ProjectionMask; use read::ColumnReader; +use url::Url; use util::jni::{convert_column_descriptor, convert_encoding}; use self::util::jni::TypePromotionInfo; @@ -582,3 +591,212 @@ fn from_u8_slice(src: &mut [u8]) -> &mut [i8] { let raw_ptr = src.as_mut_ptr() as *mut i8; unsafe { std::slice::from_raw_parts_mut(raw_ptr, src.len()) } } + +// TODO: (ARROW NATIVE) remove this if not needed. +enum ParquetReaderState { + Init, + Reading, + Complete, +} +/// Parquet read context maintained across multiple JNI calls. 
+struct BatchContext { + batch_reader: ParquetRecordBatchReader, + current_batch: Option, + reader_state: ParquetReaderState, + num_row_groups: i32, + total_rows: i64, +} + +#[inline] +fn get_batch_context<'a>(handle: jlong) -> Result<&'a mut BatchContext, CometError> { + unsafe { + (handle as *mut BatchContext) + .as_mut() + .ok_or_else(|| CometError::NullPointer("null batch context handle".to_string())) + } +} + +#[inline] +fn get_batch_reader<'a>(handle: jlong) -> Result<&'a mut ParquetRecordBatchReader, CometError> { + Ok(&mut get_batch_context(handle)?.batch_reader) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBatchReader( + e: JNIEnv, + _jclass: JClass, + file_path: jstring, + start: jlong, + length: jlong, + required_columns: jobjectArray, +) -> jlong { + try_unwrap_or_throw(&e, |mut env| unsafe { + let path: String = env + .get_string(&JString::from_raw(file_path)) + .unwrap() + .into(); + //TODO: (ARROW NATIVE) - this works only for 'file://' urls + let path = Url::parse(path.as_ref()).unwrap().to_file_path().unwrap(); + let file = File::open(path).unwrap(); + + // Create a async parquet reader builder with batch_size. + // batch_size is the number of rows to read up to buffer once from pages, defaults to 1024 + // TODO: (ARROW NATIVE) Use async reader ParquetRecordBatchStreamBuilder + let mut builder = ParquetRecordBatchReaderBuilder::try_new(file) + .unwrap() + .with_batch_size(8192); // TODO: (ARROW NATIVE) Use batch size configured in JVM + + //TODO: (ARROW NATIVE) if we can get the ParquetMetadata serialized, we need not do this. + let metadata = builder.metadata().clone(); + + let mut columns_to_read: Vec = Vec::new(); + let columns_to_read_array = JObjectArray::from_raw(required_columns); + let array_len = env.get_array_length(&columns_to_read_array)?; + let mut required_columns: Vec = Vec::new(); + for i in 0..array_len { + let p: JString = env + .get_object_array_element(&columns_to_read_array, i)? + .into(); + required_columns.push(env.get_string(&p)?.into()); + } + for (i, col) in metadata + .file_metadata() + .schema_descr() + .columns() + .iter() + .enumerate() + { + for (_, required) in required_columns.iter().enumerate() { + if col.name().to_uppercase().eq(&required.to_uppercase()) { + columns_to_read.push(i); + break; + } + } + } + //TODO: (ARROW NATIVE) make this work for complex types (especially deeply nested structs) + let mask = ProjectionMask::leaves(metadata.file_metadata().schema_descr(), columns_to_read); + // Set projection mask to read only root columns 1 and 2. + builder = builder.with_projection(mask); + + let mut row_groups_to_read: Vec = Vec::new(); + let mut total_rows: i64 = 0; + // get row groups - + for (i, rg) in metadata.row_groups().into_iter().enumerate() { + let rg_start = rg.file_offset().unwrap(); + let rg_end = rg_start + rg.compressed_size(); + if rg_start >= start && rg_end <= start + length { + row_groups_to_read.push(i); + total_rows += rg.num_rows(); + } + } + + // Build a sync parquet reader. 
+ let batch_reader = builder + .with_row_groups(row_groups_to_read.clone()) + .build() + .unwrap(); + + let ctx = BatchContext { + batch_reader, + current_batch: None, + reader_state: ParquetReaderState::Init, + num_row_groups: row_groups_to_read.len() as i32, + total_rows: total_rows, + }; + let res = Box::new(ctx); + Ok(Box::into_raw(res) as i64) + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_comet_parquet_Native_numRowGroups( + e: JNIEnv, + _jclass: JClass, + handle: jlong, +) -> jint { + try_unwrap_or_throw(&e, |_env| { + let context = get_batch_context(handle)?; + // Read data + Ok(context.num_row_groups) + }) as jint +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_comet_parquet_Native_numTotalRows( + e: JNIEnv, + _jclass: JClass, + handle: jlong, +) -> jlong { + try_unwrap_or_throw(&e, |_env| { + let context = get_batch_context(handle)?; + // Read data + Ok(context.total_rows) + }) as jlong +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch( + e: JNIEnv, + _jclass: JClass, + handle: jlong, +) -> jint { + try_unwrap_or_throw(&e, |_env| { + let context = get_batch_context(handle)?; + let batch_reader = &mut context.batch_reader; + // Read data + let mut rows_read: i32 = 0; + let batch = batch_reader.next(); + + match batch { + Some(record_batch) => { + let batch = record_batch?; + rows_read = batch.num_rows() as i32; + context.current_batch = Some(batch); + context.reader_state = ParquetReaderState::Reading; + } + None => { + context.current_batch = None; + context.reader_state = ParquetReaderState::Complete; + } + } + Ok(rows_read) + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_comet_parquet_Native_currentColumnBatch( + e: JNIEnv, + _jclass: JClass, + handle: jlong, + column_idx: jint, + array_addr: jlong, + schema_addr: jlong, +) { + try_unwrap_or_throw(&e, |_env| { + let context = get_batch_context(handle)?; + let batch_reader = context + .current_batch + .as_mut() + .ok_or_else(|| CometError::Execution { + source: ExecutionError::GeneralError("There is no more data to read".to_string()), + }); + let data = batch_reader?.column(column_idx as usize).into_data(); + data.move_to_spark(array_addr, schema_addr) + .map_err(|e| e.into()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_comet_parquet_Native_closeRecordBatchReader( + env: JNIEnv, + _jclass: JClass, + handle: jlong, +) { + try_unwrap_or_throw(&env, |_| { + unsafe { + let ctx = handle as *mut BatchContext; + let _ = Box::from_raw(ctx); + }; + Ok(()) + }) +} diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala index 52d8d09a0..4c96bef4e 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -100,6 +100,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with // Comet specific configurations val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) + val nativeArrowReaderEnabled = CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get(sqlConf) (file: PartitionedFile) => { val sharedConf = broadcastedHadoopConf.value.value @@ -134,22 +135,42 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with } pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) - val batchReader = new BatchReader( - sharedConf, - file, - footer, - capacity, - 
requiredSchema, - isCaseSensitive, - useFieldId, - ignoreMissingIds, - datetimeRebaseSpec.mode == CORRECTED, - partitionSchema, - file.partitionValues, - JavaConverters.mapAsJavaMap(metrics)) - val iter = new RecordReaderIterator(batchReader) + val recordBatchReader = + if (nativeArrowReaderEnabled) { + val batchReader = new NativeBatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + batchReader.init() + batchReader + } else { + val batchReader = new BatchReader( + sharedConf, + file, + footer, + capacity, + requiredSchema, + isCaseSensitive, + useFieldId, + ignoreMissingIds, + datetimeRebaseSpec.mode == CORRECTED, + partitionSchema, + file.partitionValues, + JavaConverters.mapAsJavaMap(metrics)) + batchReader.init() + batchReader + } + val iter = new RecordReaderIterator(recordBatchReader) try { - batchReader.init() iter.asInstanceOf[Iterator[InternalRow]] } catch { case e: Throwable => diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 35ba06902..99ed5d3cb 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -79,6 +79,9 @@ abstract class CometTestBase conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") + conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.key, "false") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key, "true") conf
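
For reference, the sketch below illustrates the intended call sequence of the new Arrow-native JNI API declared in Native.java (initRecordBatchReader, numTotalRows, readNextRecordBatch, currentColumnBatch, closeRecordBatchReader). It is not part of this patch: the file path, split offsets, and projected column names are hypothetical, the native side currently resolves the path as a URL (so a file:// URI is assumed), and the Arrow import/consumption step is elided. NativeBatchReader wires these calls up for Spark; this only shows the raw lifecycle.

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;

import org.apache.comet.parquet.Native;

public class ArrowNativeReaderSketch {
  public static void main(String[] args) {
    BufferAllocator allocator = new RootAllocator();
    String[] requiredColumns = {"id", "name"}; // hypothetical projection
    long splitStart = 0L;
    long splitLength = 1_000_000L; // split length in bytes (hypothetical)

    // Open a native reader for one PartitionedFile split.
    long handle =
        Native.initRecordBatchReader(
            "file:///tmp/data.parquet", splitStart, splitLength, requiredColumns);
    try {
      System.out.println("rows in split: " + Native.numTotalRows(handle));
      // readNextRecordBatch() decodes the next batch on the native side and returns
      // its row count; 0 means the reader is exhausted.
      while (Native.readNextRecordBatch(handle) > 0) {
        for (int col = 0; col < requiredColumns.length; col++) {
          ArrowArray array = ArrowArray.allocateNew(allocator);
          ArrowSchema schema = ArrowSchema.allocateNew(allocator);
          // Export one column of the current batch through the Arrow C data interface;
          // NativeColumnReader imports it with CometSchemaImporter to build a CometVector.
          Native.currentColumnBatch(handle, col, array.memoryAddress(), schema.memoryAddress());
          // ... import and consume the column, then release array and schema ...
        }
      }
    } finally {
      Native.closeRecordBatchReader(handle);
    }
  }
}

The design keeps batch decoding entirely on the native side; the JVM only sees each column as exported Arrow C structures, which NativeColumnReader wraps into Comet vectors.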