From faab96d89df9e3c20a06094f8b9a009d808df442 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Wed, 15 May 2024 09:42:17 +0800 Subject: [PATCH] fix: Handle compute murmur3 hash with dictionary input correctly --- core/src/execution/datafusion/spark_hash.rs | 151 +++++++++++--------- 1 file changed, 85 insertions(+), 66 deletions(-) diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs index aa4269dd01..600c0924fd 100644 --- a/core/src/execution/datafusion/spark_hash.rs +++ b/core/src/execution/datafusion/spark_hash.rs @@ -17,7 +17,10 @@ //! This includes utilities for hashing and murmur3 hashing. -use arrow::datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; +use arrow::{ + compute::take, + datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}, +}; use std::sync::Arc; use datafusion::{ @@ -95,19 +98,8 @@ pub(crate) fn spark_compatible_murmur3_hash>(data: T, seed: u32) } } -#[test] -fn test_murmur3() { - let _hashes = ["", "a", "ab", "abc", "abcd", "abcde"] - .into_iter() - .map(|s| spark_compatible_murmur3_hash(s.as_bytes(), 42) as i32) - .collect::>(); - let _expected = vec![ - 142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928, - ]; -} - macro_rules! hash_array { - ($array_type:ident, $column: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { for (i, hash) in $hashes.iter_mut().enumerate() { @@ -123,8 +115,31 @@ macro_rules! hash_array { }; } +macro_rules! hash_array_boolean { + ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = spark_compatible_murmur3_hash( + $hash_input_type::from(array.value(i)).to_le_bytes(), + *hash, + ); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = spark_compatible_murmur3_hash( + $hash_input_type::from(array.value(i)).to_le_bytes(), + *hash, + ); + } + } + } + }; +} + macro_rules! hash_array_primitive { - ($array_type:ident, $column: ident, $ty: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $ty: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); @@ -143,7 +158,7 @@ macro_rules! hash_array_primitive { } macro_rules! hash_array_primitive_float { - ($array_type:ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); @@ -172,7 +187,7 @@ macro_rules! hash_array_primitive_float { } macro_rules! hash_array_decimal { - ($array_type:ident, $column: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $hashes: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { @@ -193,27 +208,33 @@ macro_rules! hash_array_decimal { fn create_hashes_dictionary( array: &ArrayRef, hashes_buffer: &mut [u32], + multi_col: bool, ) -> Result<()> { let dict_array = array.as_any().downcast_ref::>().unwrap(); - - // Hash each dictionary value once, and then use that computed - // hash for each key value to avoid a potentially expensive - // redundant hashing for large dictionary elements (e.g. strings) - let dict_values = Arc::clone(dict_array.values()); - let mut dict_hashes = vec![0; dict_values.len()]; - create_hashes(&[dict_values], &mut dict_hashes)?; - - for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { - if let Some(key) = key { - let idx = key.to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, - dict_array.data_type() - )) - })?; - *hash = dict_hashes[idx] - } // no update for Null, consistent with other hashes + if multi_col { + // unpack the dictionary array as each row may have a different hash input + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_hashes(&[unpacked], hashes_buffer)?; + } else { + // For the first column, hash each dictionary value once, and then use + // that computed hash for each key value to avoid a potentially + // expensive redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42; dict_values.len()]; + create_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } } Ok(()) } @@ -227,27 +248,11 @@ pub fn create_hashes<'a>( arrays: &[ArrayRef], hashes_buffer: &'a mut [u32], ) -> Result<&'a mut [u32]> { - for col in arrays { + for (i, col) in arrays.iter().enumerate() { + let multi_col = i >= 1; match col.data_type() { DataType::Boolean => { - let array = col.as_any().downcast_ref::().unwrap(); - if array.null_count() == 0 { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - *hash = spark_compatible_murmur3_hash( - i32::from(array.value(i)).to_le_bytes(), - *hash, - ); - } - } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash( - i32::from(array.value(i)).to_le_bytes(), - *hash, - ); - } - } - } + hash_array_boolean!(BooleanArray, col, i32, hashes_buffer); } DataType::Int8 => { hash_array_primitive!(Int8Array, col, i32, hashes_buffer); @@ -305,28 +310,28 @@ pub fn create_hashes<'a>( } DataType::Dictionary(index_type, _) => match **index_type { DataType::Int8 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::Int16 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::Int32 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::Int64 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::UInt8 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::UInt16 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::UInt32 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } DataType::UInt64 => { - create_hashes_dictionary::(col, hashes_buffer)?; + create_hashes_dictionary::(col, hashes_buffer, multi_col)?; } _ => { return Err(DataFusionError::Internal(format!( @@ -493,12 +498,21 @@ mod tests { #[test] fn test_str() { + let expected: Vec = vec![ + 3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b, + 1322437556, 0xe860e5cc, 814637928, + ]; + test_hashes!( StringArray, - vec!["hello", "bar", "", "😁", "天地"], - vec![3286402344, 2486176763, 142593372, 885025535, 2395000894] + vec!["hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde"], + expected ); // test with null input + let with_null_expected: Vec = vec![ + 3286402344, 2486176763, 42, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b, + 1322437556, 0xe860e5cc, 814637928, + ]; test_hashes!( StringArray, vec![ @@ -507,9 +521,14 @@ mod tests { None, Some(""), Some("😁"), - Some("天地") + Some("天地"), + Some("a"), + Some("ab"), + Some("abc"), + Some("abcd"), + Some("abcde"), ], - vec![3286402344, 2486176763, 42, 142593372, 885025535, 2395000894] + with_null_expected ); }