Skip to content

Commit

Permalink
fix: CometReader.loadVector should not overwrite dictionary ids
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed May 26, 2024
1 parent 5bfe46d commit 0eb2ca6
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
Expand Down Expand Up @@ -88,6 +91,7 @@
*/
public class BatchReader extends RecordReader<Void, ColumnarBatch> implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(FileReader.class);
protected static final BufferAllocator ALLOCATOR = new RootAllocator();

private Configuration conf;
private int capacity;
Expand Down Expand Up @@ -552,6 +556,8 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable {
numRowGroupsMetric.add(1);
}

CometSchemaImporter importer = new CometSchemaImporter(ALLOCATOR);

List<ColumnDescriptor> columns = requestedSchema.getColumns();
for (int i = 0; i < columns.size(); i++) {
if (missingColumns[i]) continue;
Expand All @@ -564,6 +570,7 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable {
Utils.getColumnReader(
dataType,
columns.get(i),
importer,
capacity,
useDecimal128,
useLazyMaterialization,
Expand Down
25 changes: 13 additions & 12 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
Expand All @@ -42,6 +39,7 @@
import org.apache.parquet.column.page.DictionaryPage;
import org.apache.parquet.column.page.PageReader;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.spark.sql.comet.util.Utils$;
import org.apache.spark.sql.types.DataType;

import org.apache.comet.CometConf;
Expand All @@ -53,7 +51,6 @@

public class ColumnReader extends AbstractColumnReader {
protected static final Logger LOG = LoggerFactory.getLogger(ColumnReader.class);
protected static final BufferAllocator ALLOCATOR = new RootAllocator();

/**
* The current Comet vector holding all the values read by this column reader. Owned by this
Expand Down Expand Up @@ -89,18 +86,19 @@ public class ColumnReader extends AbstractColumnReader {
*/
boolean hadNull;

/** Dictionary provider for this column. */
private final CDataDictionaryProvider dictionaryProvider = new CDataDictionaryProvider();
private final CometSchemaImporter importer;

/**
 * Creates a column reader that decodes one Parquet column into Arrow vectors.
 *
 * @param type the Spark data type the column is read as
 * @param descriptor the Parquet column descriptor for this column
 * @param importer shared schema importer; reused across columns so dictionary ids assigned
 *     during vector import are not reset to 0 on every call (see commit message)
 * @param batchSize maximum number of rows decoded per batch; must be positive
 * @param useDecimal128 whether decimals are always read as 128-bit values
 * @param useLegacyDateTimestamp whether to use Spark's legacy date/timestamp semantics
 */
public ColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLegacyDateTimestamp) {
super(type, descriptor, useDecimal128, useLegacyDateTimestamp);
assert batchSize > 0 : "Batch size must be positive, found " + batchSize;
this.batchSize = batchSize;
// Keep a reference to the shared importer; used later when importing vectors.
this.importer = importer;
initNative();
}

Expand Down Expand Up @@ -164,7 +162,6 @@ public void close() {
currentVector.close();
currentVector = null;
}
dictionaryProvider.close();
super.close();
}

Expand Down Expand Up @@ -209,10 +206,12 @@ public CometDecodedVector loadVector() {

try (ArrowArray array = ArrowArray.wrap(addresses[0]);
ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
FieldVector vector = Data.importVector(ALLOCATOR, array, schema, dictionaryProvider);
FieldVector vector =
Utils$.MODULE$.importVector(importer.getAllocator(), importer, array, schema);

DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();

CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, isUuid);
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);

// Update whether the current vector contains any null values. This is used in the following
// batch(s) to determine whether we can skip loading the native vector.
Expand All @@ -234,15 +233,17 @@ public CometDecodedVector loadVector() {

// We should already re-initiate `CometDictionary` here because `Data.importVector` API will
// release the previous dictionary vector and create a new one.
Dictionary arrowDictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId());
CometPlainVector dictionaryVector =
new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid);
dictionary = new CometDictionary(dictionaryVector);

currentVector =
new CometDictionaryVector(
cometVector, dictionary, dictionaryProvider, useDecimal128, false, isUuid);
cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid);

currentVector =
new CometDictionaryVector(cometVector, dictionary, importer.getProvider(), useDecimal128);
return currentVector;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.IOException;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.page.PageReader;
import org.apache.spark.sql.types.DataType;
Expand All @@ -45,10 +46,11 @@ public class LazyColumnReader extends ColumnReader {
/**
 * Creates a lazily-materialized column reader. The actual batch size is set later in
 * {@code readBatch}, so it is initialized to 0 here.
 *
 * @param sparkReadType the Spark data type the column is read as
 * @param descriptor the Parquet column descriptor for this column
 * @param importer shared schema importer, forwarded to {@link ColumnReader} so dictionary ids
 *     are not reset per import
 * @param batchSize requested batch size (validated by the superclass; overwritten to 0 below)
 * @param useDecimal128 whether decimals are always read as 128-bit values
 * @param useLegacyDateTimestamp whether to use Spark's legacy date/timestamp semantics
 */
public LazyColumnReader(
DataType sparkReadType,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLegacyDateTimestamp) {
// Note: the rendered diff contained both the pre-change super(...) call (without `importer`)
// and this post-change one; only the post-change call is valid after this commit.
super(sparkReadType, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
this.batchSize = 0; // the batch size is set later in `readBatch`
this.vector = new CometLazyVector(sparkReadType, this, useDecimal128);
}
Expand Down
10 changes: 7 additions & 3 deletions common/src/main/java/org/apache/comet/parquet/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.comet.parquet;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
Expand All @@ -28,26 +29,29 @@ public class Utils {
public static ColumnReader getColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLazyMaterialization) {
// TODO: support `useLegacyDateTimestamp` for Iceberg
return getColumnReader(
type, descriptor, batchSize, useDecimal128, useLazyMaterialization, true);
type, descriptor, importer, batchSize, useDecimal128, useLazyMaterialization, true);
}

public static ColumnReader getColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLazyMaterialization,
boolean useLegacyDateTimestamp) {
if (useLazyMaterialization && supportLazyMaterialization(type)) {
return new LazyColumnReader(
type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
} else {
return new ColumnReader(type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
return new ColumnReader(
type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
}
}

Expand Down
44 changes: 44 additions & 0 deletions common/src/main/scala/org/apache/arrow/c/CometSchemaImporter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.arrow.c

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.types.pojo.Field

/**
* This is a simple wrapper around SchemaImporter to make it accessible from Java Arrow.
*/
/**
 * Thin wrapper around Arrow's `SchemaImporter` and a `CDataDictionaryProvider`, placed in the
 * `org.apache.arrow.c` package so the (non-public) `SchemaImporter` is reachable from Java
 * callers. Sharing one instance across imports keeps dictionary-id assignment consistent.
 */
class CometSchemaImporter(allocator: BufferAllocator) {
  val importer = new SchemaImporter(allocator)
  val provider = new CDataDictionaryProvider()

  /** Allocator this importer was created with. */
  def getAllocator(): BufferAllocator = allocator

  /** Dictionary provider shared by all fields imported through this instance. */
  def getProvider(): CDataDictionaryProvider = provider

  /**
   * Imports a field definition from the given C Data Interface schema, registering any
   * dictionaries with the shared provider. The schema is released and closed afterwards,
   * whether or not the import succeeds.
   */
  def importField(schema: ArrowSchema): Field =
    try importer.importField(schema, provider)
    finally {
      schema.release()
      schema.close()
    }
}
25 changes: 23 additions & 2 deletions common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.c.CDataDictionaryProvider
import org.apache.arrow.c.{ArrowArray, ArrowSchema, CDataDictionaryProvider, CometSchemaImporter, Data}
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.dictionary.DictionaryProvider
Expand All @@ -38,7 +39,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

import org.apache.comet.vector.CometVector
import org.apache.comet.vector._

object Utils {
def getConfPath(confFileName: String): String = {
Expand Down Expand Up @@ -264,4 +265,24 @@ object Utils {
throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
}
}

/**
 * Imports data from ArrowArray/ArrowSchema into a FieldVector. Functionally equivalent to Java
 * Arrow's `Data.importVector`, except that the caller supplies the schema importer instead of a
 * fresh `SchemaImporter` being created internally on every call. `Data.importVector` starts
 * dictionary ids from 0 each time, so separate calls would overwrite one another's dictionary
 * ids; reusing a single importer (and its dictionary provider) across calls avoids that.
 */
def importVector(
    allocator: BufferAllocator,
    importer: CometSchemaImporter,
    array: ArrowArray,
    schema: ArrowSchema): FieldVector = {
  val importedField = importer.importField(schema)
  val targetVector = importedField.createVector(allocator)
  Data.importIntoVector(allocator, array, targetVector, importer.getProvider())
  targetVector
}
}

0 comments on commit 0eb2ca6

Please sign in to comment.