diff --git a/common/src/main/java/org/apache/arrow/c/CometSchemaImporter.java b/common/src/main/java/org/apache/arrow/c/CometSchemaImporter.java new file mode 100644 index 000000000..32955f1ac --- /dev/null +++ b/common/src/main/java/org/apache/arrow/c/CometSchemaImporter.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.arrow.c; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.types.pojo.Field; + +/** A simple wrapper that exposes Arrow's package-private SchemaImporter to Comet. */ +public class CometSchemaImporter { + private final BufferAllocator allocator; + private final SchemaImporter importer; + private final CDataDictionaryProvider provider = new CDataDictionaryProvider(); + + public CometSchemaImporter(BufferAllocator allocator) { + this.allocator = allocator; + this.importer = new SchemaImporter(allocator); + } + + public BufferAllocator getAllocator() { + return allocator; + } + + public CDataDictionaryProvider getProvider() { + return provider; + } + + public Field importField(ArrowSchema schema) { + try { + return importer.importField(schema, provider); + } finally { + schema.release(); + schema.close(); + } + } + + /** + * Imports data from ArrowArray/ArrowSchema into a FieldVector. This is essentially the same as + * Java Arrow's `Data.importVector`, except that `Data.importVector` instantiates a new + * `SchemaImporter` internally, which assigns the dictionary ids used by dictionary-encoded + * vectors. Each such call starts dictionary ids from 0 again, so separate calls to + * `Data.importVector` would overwrite each other's dictionary ids. To avoid this, we reuse the + * same `SchemaImporter` instance for all calls to `importVector`.
+ */ + public FieldVector importVector(ArrowArray array, ArrowSchema schema) { + Field field = importField(schema); + FieldVector vector = field.createVector(allocator); + Data.importIntoVector(allocator, array, vector, provider); + + return vector; + } + + public void close() { + provider.close(); + } +} diff --git a/common/src/main/java/org/apache/comet/parquet/BatchReader.java b/common/src/main/java/org/apache/comet/parquet/BatchReader.java index 9940390dc..bf8e6e550 100644 --- a/common/src/main/java/org/apache/comet/parquet/BatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/BatchReader.java @@ -37,6 +37,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.arrow.c.CometSchemaImporter; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; @@ -88,6 +91,7 @@ */ public class BatchReader extends RecordReader implements Closeable { private static final Logger LOG = LoggerFactory.getLogger(FileReader.class); + protected static final BufferAllocator ALLOCATOR = new RootAllocator(); private Configuration conf; private int capacity; @@ -104,6 +108,7 @@ public class BatchReader extends RecordReader implements Cl private MessageType requestedSchema; private CometVector[] vectors; private AbstractColumnReader[] columnReaders; + private CometSchemaImporter importer; private ColumnarBatch currentBatch; private Future> prefetchTask; private LinkedBlockingQueue> prefetchQueue; @@ -515,6 +520,10 @@ public void close() throws IOException { fileReader.close(); fileReader = null; } + if (importer != null) { + importer.close(); + importer = null; + } } @SuppressWarnings("deprecation") @@ -552,6 +561,9 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable { numRowGroupsMetric.add(1); } + if (importer != null) importer.close(); + importer = new CometSchemaImporter(ALLOCATOR); + List columns = requestedSchema.getColumns(); for (int i = 0; i < columns.size(); i++) { if (missingColumns[i]) continue; @@ -564,6 +576,7 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable { Utils.getColumnReader( dataType, columns.get(i), + importer, capacity, useDecimal128, useLazyMaterialization, diff --git a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java index 46fd87f6b..9e594804f 100644 --- a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java +++ b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java @@ -27,10 +27,7 @@ import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; -import org.apache.arrow.c.CDataDictionaryProvider; -import org.apache.arrow.c.Data; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.c.CometSchemaImporter; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -53,7 +50,6 @@ public class ColumnReader extends AbstractColumnReader { protected static final Logger LOG = LoggerFactory.getLogger(ColumnReader.class); - protected static final BufferAllocator ALLOCATOR = new RootAllocator(); /** * The current Comet vector holding all the values read by this column reader. 
Owned by this @@ -89,18 +85,19 @@ public class ColumnReader extends AbstractColumnReader { */ boolean hadNull; - /** Dictionary provider for this column. */ - private final CDataDictionaryProvider dictionaryProvider = new CDataDictionaryProvider(); + private final CometSchemaImporter importer; public ColumnReader( DataType type, ColumnDescriptor descriptor, + CometSchemaImporter importer, int batchSize, boolean useDecimal128, boolean useLegacyDateTimestamp) { super(type, descriptor, useDecimal128, useLegacyDateTimestamp); assert batchSize > 0 : "Batch size must be positive, found " + batchSize; this.batchSize = batchSize; + this.importer = importer; initNative(); } @@ -164,7 +161,6 @@ public void close() { currentVector.close(); currentVector = null; } - dictionaryProvider.close(); super.close(); } @@ -209,10 +205,11 @@ public CometDecodedVector loadVector() { try (ArrowArray array = ArrowArray.wrap(addresses[0]); ArrowSchema schema = ArrowSchema.wrap(addresses[1])) { - FieldVector vector = Data.importVector(ALLOCATOR, array, schema, dictionaryProvider); + FieldVector vector = importer.importVector(array, schema); + DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary(); - CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, isUuid); + CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128); // Update whether the current vector contains any null values. This is used in the following // batch(s) to determine whether we can skip loading the native vector. @@ -230,19 +227,25 @@ public CometDecodedVector loadVector() { // return plain vector. currentVector = cometVector; return currentVector; - } else if (dictionary == null) { - // There is dictionary from native side but the Java side dictionary hasn't been - // initialized yet. - Dictionary arrowDictionary = dictionaryProvider.lookup(dictionaryEncoding.getId()); - CometPlainVector dictionaryVector = - new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid); + } + + // We should already re-initiate `CometDictionary` here because `Data.importVector` API will + // release the previous dictionary vector and create a new one. 
+ Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId()); + CometPlainVector dictionaryVector = + new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid); + if (dictionary != null) { + dictionary.setDictionaryVector(dictionaryVector); + } else { dictionary = new CometDictionary(dictionaryVector); } currentVector = new CometDictionaryVector( - cometVector, dictionary, dictionaryProvider, useDecimal128, false, isUuid); + cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid); return currentVector; } } diff --git a/common/src/main/java/org/apache/comet/parquet/LazyColumnReader.java b/common/src/main/java/org/apache/comet/parquet/LazyColumnReader.java index a15d84192..dd08a88ab 100644 --- a/common/src/main/java/org/apache/comet/parquet/LazyColumnReader.java +++ b/common/src/main/java/org/apache/comet/parquet/LazyColumnReader.java @@ -21,6 +21,7 @@ import java.io.IOException; +import org.apache.arrow.c.CometSchemaImporter; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.page.PageReader; import org.apache.spark.sql.types.DataType; @@ -45,10 +46,11 @@ public class LazyColumnReader extends ColumnReader { public LazyColumnReader( DataType sparkReadType, ColumnDescriptor descriptor, + CometSchemaImporter importer, int batchSize, boolean useDecimal128, boolean useLegacyDateTimestamp) { - super(sparkReadType, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp); + super(sparkReadType, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp); this.batchSize = 0; // the batch size is set later in `readBatch` this.vector = new CometLazyVector(sparkReadType, this, useDecimal128); } diff --git a/common/src/main/java/org/apache/comet/parquet/Utils.java b/common/src/main/java/org/apache/comet/parquet/Utils.java index 95ca06cda..99f3a4edd 100644 --- a/common/src/main/java/org/apache/comet/parquet/Utils.java +++ b/common/src/main/java/org/apache/comet/parquet/Utils.java @@ -19,6 +19,7 @@ package org.apache.comet.parquet; +import org.apache.arrow.c.CometSchemaImporter; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.PrimitiveType; @@ -28,26 +29,29 @@ public class Utils { public static ColumnReader getColumnReader( DataType type, ColumnDescriptor descriptor, + CometSchemaImporter importer, int batchSize, boolean useDecimal128, boolean useLazyMaterialization) { // TODO: support `useLegacyDateTimestamp` for Iceberg return getColumnReader( - type, descriptor, batchSize, useDecimal128, useLazyMaterialization, true); + type, descriptor, importer, batchSize, useDecimal128, useLazyMaterialization, true); } public static ColumnReader getColumnReader( DataType type, ColumnDescriptor descriptor, + CometSchemaImporter importer, int batchSize, boolean useDecimal128, boolean useLazyMaterialization, boolean useLegacyDateTimestamp) { if (useLazyMaterialization && supportLazyMaterialization(type)) { return new LazyColumnReader( - type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp); + type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp); } else { - return new ColumnReader(type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp); + return new ColumnReader( + type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp); } } diff
--git a/common/src/main/java/org/apache/comet/vector/CometDictionary.java b/common/src/main/java/org/apache/comet/vector/CometDictionary.java index 9aa42bd68..b213b75d8 100644 --- a/common/src/main/java/org/apache/comet/vector/CometDictionary.java +++ b/common/src/main/java/org/apache/comet/vector/CometDictionary.java @@ -26,7 +26,7 @@ public class CometDictionary implements AutoCloseable { private static final int DECIMAL_BYTE_WIDTH = 16; - private final CometPlainVector values; + private CometPlainVector values; private final int numValues; /** Decoded dictionary values. Only one of the following is set. */ @@ -47,6 +47,13 @@ public CometDictionary(CometPlainVector values) { initialize(); } + public void setDictionaryVector(CometPlainVector values) { + this.values = values; + if (values.numValues() != numValues) { + throw new IllegalArgumentException("Mismatched dictionary size"); + } + } + public ValueVector getValueVector() { return values.getValueVector(); } diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index fd1f9166d..7e8a96f28 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -503,41 +503,37 @@ impl Cast { fn cast_array(&self, array: ArrayRef) -> DataFusionResult { let to_type = &self.data_type; let array = array_with_timezone(array, self.timezone.clone(), Some(to_type)); + let from_type = array.data_type().clone(); + + // unpack dictionary string arrays first + // TODO: we are unpacking a dictionary-encoded array and then performing + // the cast. We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code + let array = match &from_type { + DataType::Dictionary(key_type, value_type) + if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)? + } + _ => array, + }; let from_type = array.data_type(); + let cast_result = match (from_type, to_type) { (DataType::Utf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) } (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) } (DataType::Utf8, DataType::Timestamp(_, _)) => { - Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? + Self::cast_string_to_timestamp(&array, to_type, self.eval_mode) } (DataType::Utf8, DataType::Date32) => { - Self::cast_string_to_date(&array, to_type, self.eval_mode)? - } - (DataType::Dictionary(key_type, value_type), DataType::Date32) - if key_type.as_ref() == &DataType::Int32 - && (value_type.as_ref() == &DataType::Utf8 - || value_type.as_ref() == &DataType::LargeUtf8) => - { - match value_type.as_ref() { - DataType::Utf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; - Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? - } - DataType::LargeUtf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; - Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? 
- } - dt => unreachable!( - "{}", - format!("invalid value type {dt} for dictionary-encoded string array") - ), - } + Self::cast_string_to_date(&array, to_type, self.eval_mode) } (DataType::Int64, DataType::Int32) | (DataType::Int64, DataType::Int16) @@ -547,61 +543,33 @@ impl Cast { | (DataType::Int16, DataType::Int8) if self.eval_mode != EvalMode::Try => { - Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)? + Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type) } ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), ( DataType::LargeUtf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, - ( - DataType::Dictionary(key_type, value_type), - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) if key_type.as_ref() == &DataType::Int32 - && (value_type.as_ref() == &DataType::Utf8 - || value_type.as_ref() == &DataType::LargeUtf8) => - { - // TODO: we are unpacking a dictionary-encoded array and then performing - // the cast. We could potentially improve performance here by casting the - // dictionary values directly without unpacking the array first, although this - // would add more complexity to the code - match value_type.as_ref() { - DataType::Utf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; - Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? - } - DataType::LargeUtf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; - Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? - } - dt => unreachable!( - "{}", - format!("invalid value type {dt} for dictionary-encoded string array") - ), - } - } + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), (DataType::Float64, DataType::Utf8) => { - Self::spark_cast_float64_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) } (DataType::Float64, DataType::LargeUtf8) => { - Self::spark_cast_float64_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) } (DataType::Float32, DataType::Utf8) => { - Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) } (DataType::Float32, DataType::LargeUtf8) => { - Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) } (DataType::Float32, DataType::Decimal128(precision, scale)) => { - Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)? + Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode) } (DataType::Float64, DataType::Decimal128(precision, scale)) => { - Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)? + Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode) } (DataType::Float32, DataType::Int8) | (DataType::Float32, DataType::Int16) @@ -622,14 +590,94 @@ impl Cast { self.eval_mode, from_type, to_type, - )? + ) + } + _ if Self::is_datafusion_spark_compatible(from_type, to_type) => { + // use DataFusion cast only when we know that it is compatible with Spark + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
} _ => { - // when we have no Spark-specific casting we delegate to DataFusion - cast_with_options(&array, to_type, &CAST_OPTIONS)? + // we should never reach this code because the Scala code should be checking + // for supported cast operations and falling back to Spark for anything that + // is not yet supported + Err(CometError::Internal(format!( + "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" + ))) } }; - Ok(spark_cast(cast_result, from_type, to_type)) + Ok(spark_cast(cast_result?, from_type, to_type)) + } + + /// Determines if DataFusion supports the given cast in a way that is + /// compatible with Spark + fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { + if from_type == to_type { + return true; + } + match from_type { + DataType::Boolean => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + // note that the cast from Int32/Int64 -> Decimal128 here is actually + // not compatible with Spark (no overflow checks) but we have tests that + // rely on this cast working so we have to leave it here for now + matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) + } + _ => false, + } } fn cast_string_to_int( diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 01d892381..7a37e3aae 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -395,7 +395,8 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; // Spark Substring's start is 1-based when start > 0 let start = expr.start - i32::from(expr.start > 0); - let len = expr.len; + // substring negative len is treated as 0 in Spark + let len = std::cmp::max(expr.len, 0); Ok(Arc::new(SubstringExec::new( child, diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs index aa4269dd0..6d25a72f6 100644 --- a/core/src/execution/datafusion/spark_hash.rs +++ b/core/src/execution/datafusion/spark_hash.rs @@ -17,7 +17,10 @@ //! This includes utilities for hashing and murmur3 hashing. 
-use arrow::datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; +use arrow::{ + compute::take, + datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}, +}; use std::sync::Arc; use datafusion::{ @@ -95,19 +98,8 @@ pub(crate) fn spark_compatible_murmur3_hash>(data: T, seed: u32) } } -#[test] -fn test_murmur3() { - let _hashes = ["", "a", "ab", "abc", "abcd", "abcde"] - .into_iter() - .map(|s| spark_compatible_murmur3_hash(s.as_bytes(), 42) as i32) - .collect::>(); - let _expected = vec![ - 142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928, - ]; -} - macro_rules! hash_array { - ($array_type:ident, $column: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { for (i, hash) in $hashes.iter_mut().enumerate() { @@ -123,8 +115,31 @@ macro_rules! hash_array { }; } +macro_rules! hash_array_boolean { + ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = spark_compatible_murmur3_hash( + $hash_input_type::from(array.value(i)).to_le_bytes(), + *hash, + ); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = spark_compatible_murmur3_hash( + $hash_input_type::from(array.value(i)).to_le_bytes(), + *hash, + ); + } + } + } + }; +} + macro_rules! hash_array_primitive { - ($array_type:ident, $column: ident, $ty: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $ty: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); @@ -143,7 +158,7 @@ macro_rules! hash_array_primitive { } macro_rules! hash_array_primitive_float { - ($array_type:ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); @@ -172,7 +187,7 @@ macro_rules! hash_array_primitive_float { } macro_rules! hash_array_decimal { - ($array_type:ident, $column: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { @@ -193,27 +208,33 @@ macro_rules! hash_array_decimal { fn create_hashes_dictionary( array: &ArrayRef, hashes_buffer: &mut [u32], + first_col: bool, ) -> Result<()> { let dict_array = array.as_any().downcast_ref::>().unwrap(); - - // Hash each dictionary value once, and then use that computed - // hash for each key value to avoid a potentially expensive - // redundant hashing for large dictionary elements (e.g. 
strings) - let dict_values = Arc::clone(dict_array.values()); - let mut dict_hashes = vec![0; dict_values.len()]; - create_hashes(&[dict_values], &mut dict_hashes)?; - - for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { - if let Some(key) = key { - let idx = key.to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, - dict_array.data_type() - )) - })?; - *hash = dict_hashes[idx] - } // no update for Null, consistent with other hashes + if !first_col { + // unpack the dictionary array as each row may have a different hash input + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_hashes(&[unpacked], hashes_buffer)?; + } else { + // For the first column, hash each dictionary value once, and then use + // that computed hash for each key value to avoid a potentially + // expensive redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42; dict_values.len()]; + create_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } } Ok(()) } @@ -227,27 +248,11 @@ pub fn create_hashes<'a>( arrays: &[ArrayRef], hashes_buffer: &'a mut [u32], ) -> Result<&'a mut [u32]> { - for col in arrays { + for (i, col) in arrays.iter().enumerate() { + let first_col = i == 0; match col.data_type() { DataType::Boolean => { - let array = col.as_any().downcast_ref::().unwrap(); - if array.null_count() == 0 { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - *hash = spark_compatible_murmur3_hash( - i32::from(array.value(i)).to_le_bytes(), - *hash, - ); - } - } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash( - i32::from(array.value(i)).to_le_bytes(), - *hash, - ); - } - } - } + hash_array_boolean!(BooleanArray, col, i32, hashes_buffer); } DataType::Int8 => { hash_array_primitive!(Int8Array, col, i32, hashes_buffer); @@ -305,28 +310,28 @@ pub fn create_hashes<'a>( } DataType::Dictionary(index_type, _) => match **index_type { DataType::Int8 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::Int16 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::Int32 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::Int64 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::UInt8 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::UInt16 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::UInt32 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } DataType::UInt64 => { - 
create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, first_col)?; } _ => { return Err(DataFusionError::Internal(format!( @@ -363,78 +368,64 @@ mod tests { use crate::execution::datafusion::spark_hash::{create_hashes, pmod}; use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray}; - macro_rules! test_hashes { - ($ty:ty, $values:expr, $expected:expr) => { - let i = Arc::new(<$ty>::from($values)) as ArrayRef; - let mut hashes = vec![42; $values.len()]; + macro_rules! test_hashes_internal { + ($input: expr, $len: expr, $expected: expr) => { + let i = $input as ArrayRef; + let mut hashes = vec![42; $len]; create_hashes(&[i], &mut hashes).unwrap(); assert_eq!(hashes, $expected); }; } + fn test_murmur3_hash>> + 'static>( + values: Vec>, + expected: Vec, + ) { + // copied before inserting nulls + let mut input_with_nulls = values.clone(); + let mut expected_with_nulls = expected.clone(); + let len = values.len(); + let i = Arc::new(T::from(values)) as ArrayRef; + test_hashes_internal!(i, len, expected); + + // test with nulls + let median = len / 2; + input_with_nulls.insert(0, None); + input_with_nulls.insert(median, None); + expected_with_nulls.insert(0, 42); + expected_with_nulls.insert(median, 42); + let with_nulls_len = len + 2; + let nullable_input = Arc::new(T::from(input_with_nulls)) as ArrayRef; + test_hashes_internal!(nullable_input, with_nulls_len, expected_with_nulls); + } + #[test] fn test_i8() { - test_hashes!( - Int8Array, + test_murmur3_hash::( vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], - vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365] - ); - // with null input - test_hashes!( - Int8Array, - vec![Some(1), None, Some(-1), Some(i8::MAX), Some(i8::MIN)], - vec![0xdea578e3, 42, 0xa0590e3d, 0x43b4d8ed, 0x422a1365] + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365], ); } #[test] fn test_i32() { - test_hashes!( - Int32Array, + test_murmur3_hash::( vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], - vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6] - ); - // with null input - test_hashes!( - Int32Array, - vec![ - Some(1), - Some(0), - Some(-1), - None, - Some(i32::MAX), - Some(i32::MIN) - ], - vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 42, 0x07fb67e7, 0x2b1f0fc6] + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6], ); } #[test] fn test_i64() { - test_hashes!( - Int64Array, + test_murmur3_hash::( vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], - vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb] - ); - // with null input - test_hashes!( - Int64Array, - vec![ - Some(1), - Some(0), - Some(-1), - None, - Some(i64::MAX), - Some(i64::MIN) - ], - vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 42, 0xa05b5d7b, 0xcd1e64fb] + vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb], ); } #[test] fn test_f32() { - test_hashes!( - Float32Array, + test_murmur3_hash::( vec![ Some(1.0), Some(0.0), @@ -443,28 +434,15 @@ mod tests { Some(99999999999.99999999999), Some(-99999999999.99999999999), ], - vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86] - ); - // with null input - test_hashes!( - Float32Array, vec![ - Some(1.0), - Some(0.0), - Some(-0.0), - Some(-1.0), - None, - Some(99999999999.99999999999), - Some(-99999999999.99999999999) + 0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86, ], - vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 42, 
0xcbdc340f, 0xc0361c86] ); } #[test] fn test_f64() { - test_hashes!( - Float64Array, + test_murmur3_hash::( vec![ Some(1.0), Some(0.0), @@ -473,44 +451,26 @@ mod tests { Some(99999999999.99999999999), Some(-99999999999.99999999999), ], - vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9] - ); - // with null input - test_hashes!( - Float64Array, vec![ - Some(1.0), - Some(0.0), - Some(-0.0), - Some(-1.0), - None, - Some(99999999999.99999999999), - Some(-99999999999.99999999999) + 0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9, ], - vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 42, 0xb87e1595, 0xa0eef9f9] ); } #[test] fn test_str() { - test_hashes!( - StringArray, - vec!["hello", "bar", "", "😁", "天地"], - vec![3286402344, 2486176763, 142593372, 885025535, 2395000894] - ); - // test with null input - test_hashes!( - StringArray, - vec![ - Some("hello"), - Some("bar"), - None, - Some(""), - Some("😁"), - Some("天地") - ], - vec![3286402344, 2486176763, 42, 142593372, 885025535, 2395000894] - ); + let input = vec![ + "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", + ] + .iter() + .map(|s| Some(s.to_string())) + .collect::>>(); + let expected: Vec = vec![ + 3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b, + 1322437556, 0xe860e5cc, 814637928, + ]; + + test_murmur3_hash::(input, expected); } #[test] diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 6ca4baf60..1afdd78ec 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -156,6 +156,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("string type and substring") { withParquetTable((0 until 5).map(i => (i.toString, (i + 100).toString)), "tbl") { checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, -2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, -2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 10) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 0, 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 1, 0) FROM tbl") } } @@ -1454,17 +1460,55 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { withTable(table) { sql(s"create table $table(col string, a int, b float) using parquet") sql(s""" - |insert into $table values - |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) - |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) - |""".stripMargin) + |insert into $table values + |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) + |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) + |""".stripMargin) checkSparkAnswerAndOperator(""" - |select - |md5(col), md5(cast(a as string)), md5(cast(b as string)), - |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128) - |from test - |""".stripMargin) + |select + |md5(col), md5(cast(a as string)), md5(cast(b as string)), + |hash(col), hash(col, 1), hash(col, 0), 
hash(col, a, b), hash(b, a, col), + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128) + |from test + |""".stripMargin) + } + } + } + } + + test("hash functions with random input") { + val dataGen = DataGenerator.DEFAULT + // sufficient number of rows to create dictionary encoded ArrowArray. + val randomNumRows = 1000 + + val whitespaceChars = " \t\r\n" + val timestampPattern = "0123456789/:T" + whitespaceChars + Seq(true, false).foreach { dictionary => + withSQLConf( + "parquet.enable.dictionary" -> dictionary.toString, + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + val table = "test" + withTable(table) { + sql(s"create table $table(col string, a int, b float) using parquet") + // TODO: Add a Row generator in the data gen class and replace th following code + val col = dataGen.generateStrings(randomNumRows, timestampPattern, 6) + val colA = dataGen.generateInts(randomNumRows) + val colB = dataGen.generateFloats(randomNumRows) + val data = col.zip(colA).zip(colB).map { case ((a, b), c) => (a, b, c) } + data + .toDF("col", "a", "b") + .write + .mode("append") + .insertInto(table) + // with random generated data + // disable cast(b as string) for now, as the cast from float to string may produce incompatible result + checkSparkAnswerAndOperator(""" + |select + |md5(col), md5(cast(a as string)), --md5(cast(b as string)), + |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128) + |from test + |""".stripMargin) } } }
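
The `CometSchemaImporter` javadoc above explains why a single `SchemaImporter` has to be reused: every fresh importer restarts dictionary ids from 0, so per-batch importers would clobber each other's dictionary ids within a row group. Below is a minimal sketch of the intended call pattern, assuming the patched common module and Arrow Java's c-data module are on the classpath; the round trip through `Data.exportVector` stands in for the native addresses that `ColumnReader.loadVector` normally wraps, and the class and variable names are illustrative only.

```java
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;

public class ImporterRoundTrip {
  public static void main(String[] args) throws Exception {
    try (BufferAllocator allocator = new RootAllocator()) {
      // One importer per row group, shared by every column reader, so the dictionary ids
      // assigned by the underlying SchemaImporter never restart from 0 mid row group.
      CometSchemaImporter importer = new CometSchemaImporter(allocator);
      try (IntVector source = new IntVector("c0", allocator);
          ArrowArray array = ArrowArray.allocateNew(allocator)) {
        source.allocateNew(3);
        for (int i = 0; i < 3; i++) {
          source.setSafe(i, i + 1);
        }
        source.setValueCount(3);
        // Export through the C Data Interface, then import with the shared importer.
        // The schema struct is not closed here because importField releases and closes it.
        ArrowSchema schema = ArrowSchema.allocateNew(allocator);
        Data.exportVector(allocator, source, importer.getProvider(), array, schema);
        try (FieldVector imported = importer.importVector(array, schema)) {
          System.out.println("imported " + imported.getValueCount() + " values");
        }
      } finally {
        importer.close(); // releases the shared CDataDictionaryProvider
      }
    }
  }
}
```

In `BatchReader` the importer is created once per row group, handed to every column reader through `Utils.getColumnReader`, and closed when the reader closes, which is what keeps dictionary ids consistent across batches of the same row group.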
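The planner.rs change clamps a negative substring length to zero because Spark defines `substring` with `len < 0` as returning an empty string; the new `CometExpressionSuite` cases exercise exactly that. Here is a small Spark session showing the behaviour the native `SubstringExec` has to match (a sketch only, not part of the patch; any local Spark session will do):

```java
import org.apache.spark.sql.SparkSession;

public class SubstringNegativeLength {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[1]").appName("substring-negative-len").getOrCreate();
    // Spark treats a negative length as 0, so all of these return an empty string,
    // which is why the planner now maps expr.len to max(expr.len, 0).
    spark
        .sql("SELECT substring('abcd', 2, -2), substring('abcd', -2, -2), substring('abcd', 1, 0)")
        .show(false);
    spark.stop();
  }
}
```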
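The `first_col` parameter added to `create_hashes_dictionary` exists because Spark's multi-column hash chains seeds: each row starts from seed 42, and every later column is hashed with the running result as its seed. Only the first column therefore sees a constant seed, which is what makes a per-dictionary-value hash cache valid there; for later columns the same dictionary value hashes differently in every row, so the dictionary array has to be unpacked (via `take`) and hashed row by row. The sketch below illustrates the chaining with Spark's own `Murmur3_x86_32`; this is an assumption-laden example (it needs the spark-unsafe artifact on the classpath), and `0xdea578e3` is the value the Rust `test_i32` case above expects for the input 1.

```java
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

public class HashSeedChaining {
  public static void main(String[] args) {
    // First column: every row uses Spark's fixed seed 42, so equal dictionary values always
    // produce equal hashes. Should print dea578e3, matching the Rust test expectation for 1.
    int h1 = Murmur3_x86_32.hashInt(1, 42);
    System.out.println(Integer.toHexString(h1));

    // Later columns: the seed is the running hash of the preceding columns in the same row,
    // so the same value (7 here) hashes differently depending on what came before it.
    // This is why only the first column can reuse hashes precomputed per dictionary value.
    System.out.println(Integer.toHexString(Murmur3_x86_32.hashInt(7, h1)));
    System.out.println(Integer.toHexString(Murmur3_x86_32.hashInt(7, Murmur3_x86_32.hashInt(2, 42))));
  }
}
```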