fix: CometReader.loadVector should not overwrite dictionary ids #476

Merged
merged 2 commits into from May 28, 2024
Changes from 1 commit
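In outline, the change keeps a single `CometSchemaImporter` (wrapping one Arrow `SchemaImporter` and one `CDataDictionaryProvider`) alive for all columns of a row group, instead of each `ColumnReader` importing through its own dictionary provider and the fresh `SchemaImporter` that `Data.importVector` creates on every call. A condensed sketch of the new call pattern, assembled from the `BatchReader` and `Utils` hunks below (unrelated arguments and error handling omitted; `dataType` stands for the Spark type resolved for each requested column and is not a literal name from the diff):

```java
// One importer per row group, shared by every column reader (see the BatchReader hunk below).
CometSchemaImporter importer = new CometSchemaImporter(ALLOCATOR);

List<ColumnDescriptor> columns = requestedSchema.getColumns();
for (int i = 0; i < columns.size(); i++) {
  if (missingColumns[i]) continue;
  // The importer is threaded through Utils.getColumnReader into ColumnReader/LazyColumnReader.
  ColumnReader reader =
      Utils.getColumnReader(
          dataType, columns.get(i), importer, capacity, useDecimal128, useLazyMaterialization);
}
```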
common/src/main/java/org/apache/comet/parquet/BatchReader.java
@@ -37,6 +37,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
@@ -88,6 +91,7 @@
*/
public class BatchReader extends RecordReader<Void, ColumnarBatch> implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(FileReader.class);
protected static final BufferAllocator ALLOCATOR = new RootAllocator();

private Configuration conf;
private int capacity;
@@ -552,6 +556,8 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable {
numRowGroupsMetric.add(1);
}

CometSchemaImporter importer = new CometSchemaImporter(ALLOCATOR);

List<ColumnDescriptor> columns = requestedSchema.getColumns();
for (int i = 0; i < columns.size(); i++) {
if (missingColumns[i]) continue;
@@ -564,6 +570,7 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable {
Utils.getColumnReader(
dataType,
columns.get(i),
importer,
capacity,
useDecimal128,
useLazyMaterialization,
25 changes: 13 additions & 12 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
@@ -27,10 +27,7 @@

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
@@ -42,6 +39,7 @@
import org.apache.parquet.column.page.DictionaryPage;
import org.apache.parquet.column.page.PageReader;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.spark.sql.comet.util.Utils$;
import org.apache.spark.sql.types.DataType;

import org.apache.comet.CometConf;
@@ -53,7 +51,6 @@

public class ColumnReader extends AbstractColumnReader {
protected static final Logger LOG = LoggerFactory.getLogger(ColumnReader.class);
protected static final BufferAllocator ALLOCATOR = new RootAllocator();

/**
* The current Comet vector holding all the values read by this column reader. Owned by this
@@ -89,18 +86,19 @@ public class ColumnReader extends AbstractColumnReader {
*/
boolean hadNull;

/** Dictionary provider for this column. */
private final CDataDictionaryProvider dictionaryProvider = new CDataDictionaryProvider();
private final CometSchemaImporter importer;

public ColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLegacyDateTimestamp) {
super(type, descriptor, useDecimal128, useLegacyDateTimestamp);
assert batchSize > 0 : "Batch size must be positive, found " + batchSize;
this.batchSize = batchSize;
this.importer = importer;
initNative();
}

@@ -164,7 +162,6 @@ public void close() {
currentVector.close();
currentVector = null;
}
dictionaryProvider.close();
super.close();
}

@@ -209,10 +206,12 @@ public CometDecodedVector loadVector() {

try (ArrowArray array = ArrowArray.wrap(addresses[0]);
ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
FieldVector vector = Data.importVector(ALLOCATOR, array, schema, dictionaryProvider);
FieldVector vector =
Utils$.MODULE$.importVector(importer.getAllocator(), importer, array, schema);

DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();

CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, isUuid);
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);

// Update whether the current vector contains any null values. This is used in the following
// batch(s) to determine whether we can skip loading the native vector.
@@ -234,15 +233,17 @@

// We should already re-initiate `CometDictionary` here because `Data.importVector` API will
// release the previous dictionary vector and create a new one.
Dictionary arrowDictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId());
CometPlainVector dictionaryVector =
new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid);
dictionary = new CometDictionary(dictionaryVector);

currentVector =
new CometDictionaryVector(
cometVector, dictionary, dictionaryProvider, useDecimal128, false, isUuid);
cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid);

currentVector =
new CometDictionaryVector(cometVector, dictionary, importer.getProvider(), useDecimal128);
return currentVector;
}
}
common/src/main/java/org/apache/comet/parquet/LazyColumnReader.java
@@ -21,6 +21,7 @@

import java.io.IOException;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.page.PageReader;
import org.apache.spark.sql.types.DataType;
@@ -45,10 +46,11 @@ public class LazyColumnReader extends ColumnReader {
public LazyColumnReader(
DataType sparkReadType,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLegacyDateTimestamp) {
super(sparkReadType, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
super(sparkReadType, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
this.batchSize = 0; // the batch size is set later in `readBatch`
this.vector = new CometLazyVector(sparkReadType, this, useDecimal128);
}
10 changes: 7 additions & 3 deletions common/src/main/java/org/apache/comet/parquet/Utils.java
@@ -19,6 +19,7 @@

package org.apache.comet.parquet;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
@@ -28,26 +29,29 @@ public class Utils {
public static ColumnReader getColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLazyMaterialization) {
// TODO: support `useLegacyDateTimestamp` for Iceberg
return getColumnReader(
type, descriptor, batchSize, useDecimal128, useLazyMaterialization, true);
type, descriptor, importer, batchSize, useDecimal128, useLazyMaterialization, true);
}

public static ColumnReader getColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLazyMaterialization,
boolean useLegacyDateTimestamp) {
if (useLazyMaterialization && supportLazyMaterialization(type)) {
return new LazyColumnReader(
type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
} else {
return new ColumnReader(type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
return new ColumnReader(
type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
}
}

44 changes: 44 additions & 0 deletions common/src/main/scala/org/apache/arrow/c/CometSchemaImporter.scala
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.arrow.c

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.types.pojo.Field

/**
* This is a simple wrapper around SchemaImporter to make it accessible from Java Arrow.
*/
class CometSchemaImporter(allocator: BufferAllocator) {
val importer = new SchemaImporter(allocator)
val provider = new CDataDictionaryProvider()

def getAllocator(): BufferAllocator = allocator

def getProvider(): CDataDictionaryProvider = provider

def importField(schema: ArrowSchema): Field = {
try {
importer.importField(schema, provider)
} finally {
schema.release();
schema.close();
}
}
}
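A note on this new class: Arrow Java's `SchemaImporter` appears not to be accessible from outside the `org.apache.arrow.c` package, which would explain why the wrapper is declared there. A minimal, hypothetical usage sketch from the Java side — `arrayAddr` and `schemaAddr` stand in for the native addresses the Comet reader hands back and are not names from this PR:

```java
// Sketch only, reusing the classes shown in the diffs above.
BufferAllocator allocator = new RootAllocator();
CometSchemaImporter importer = new CometSchemaImporter(allocator);

try (ArrowArray array = ArrowArray.wrap(arrayAddr);
    ArrowSchema schema = ArrowSchema.wrap(schemaAddr)) {
  // Same call ColumnReader.loadVector now makes: the shared importer assigns the
  // dictionary id, and the imported dictionary ends up in importer.getProvider().
  FieldVector vector =
      Utils$.MODULE$.importVector(importer.getAllocator(), importer, array, schema);
}
```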
25 changes: 23 additions & 2 deletions common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -25,7 +25,8 @@ import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.c.CDataDictionaryProvider
import org.apache.arrow.c.{ArrowArray, ArrowSchema, CDataDictionaryProvider, CometSchemaImporter, Data}
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.dictionary.DictionaryProvider
@@ -38,7 +39,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

import org.apache.comet.vector.CometVector
import org.apache.comet.vector._

object Utils {
def getConfPath(confFileName: String): String = {
@@ -264,4 +265,24 @@ object Utils {
throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
}
}

/**
* Imports data from ArrowArray/ArrowSchema into a FieldVector. This is basically the same as
* Java Arrow `Data.importVector`. `Data.importVector` initiates `SchemaImporter` internally
* which is used to fill dictionary ids for dictionary encoded vectors. Every call to
* `importVector` will begin with dictionary ids starting from 0. So, separate calls to
* `importVector` will overwrite dictionary ids. To avoid this, we need to use the same
* `SchemaImporter` instance for all calls to `importVector`.
*/
def importVector(
allocator: BufferAllocator,
importer: CometSchemaImporter,
array: ArrowArray,
schema: ArrowSchema): FieldVector = {
val field = importer.importField(schema)
val vector = field.createVector(allocator)
Data.importIntoVector(allocator, array, vector, importer.getProvider())

vector
}
}
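To make the scaladoc above concrete, here is a hypothetical illustration (not part of the PR) of the behaviour it describes; `allocator`, `provider`, and the `array*`/`schema*` handles are stand-ins for two dictionary-encoded columns imported into the same dictionary provider:

```java
// Each Data.importVector call creates its own internal SchemaImporter, so dictionary ids
// are assigned starting from 0 on every call.
FieldVector v1 = Data.importVector(allocator, array1, schema1, provider); // dictionary registered under id 0
FieldVector v2 = Data.importVector(allocator, array2, schema2, provider); // id 0 again -> overwrites v1's entry

// The importVector helper above instead routes every import through one shared SchemaImporter
// (inside CometSchemaImporter), so the id counter keeps advancing and each column's dictionary
// stays registered in the shared CDataDictionaryProvider under its own id.
```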