From f8e0691b4f7a30b00fbda9316aafa6ef84b39ae9 Mon Sep 17 00:00:00 2001
From: advancedxy
Date: Fri, 24 May 2024 23:13:24 +0800
Subject: [PATCH] fix: Compute murmur3 hash with dictionary input correctly
 (#433)

* fix: Handle compute murmur3 hash with dictionary input correctly

* add unit tests

* spotless apply

* apply scala fix

* address comment

* another style issue

* Update spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Co-authored-by: Liang-Chi Hsieh

---------

Co-authored-by: Liang-Chi Hsieh

(cherry picked from commit 93af70438b92049226dfd130e04dd83a9863f1a9)
---
 core/src/execution/datafusion/spark_hash.rs   | 270 ++++++++----------
 .../apache/comet/CometExpressionSuite.scala   |  58 +++-
 2 files changed, 163 insertions(+), 165 deletions(-)

diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index aa4269dd01..6d25a72f6c 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -17,7 +17,10 @@
 
 //! This includes utilities for hashing and murmur3 hashing.
 
-use arrow::datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type};
+use arrow::{
+    compute::take,
+    datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type},
+};
 use std::sync::Arc;
 
 use datafusion::{
@@ -95,19 +98,8 @@ pub(crate) fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32)
     }
 }
 
-#[test]
-fn test_murmur3() {
-    let _hashes = ["", "a", "ab", "abc", "abcd", "abcde"]
-        .into_iter()
-        .map(|s| spark_compatible_murmur3_hash(s.as_bytes(), 42) as i32)
-        .collect::<Vec<i32>>();
-    let _expected = vec![
-        142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928,
-    ];
-}
-
 macro_rules! hash_array {
-    ($array_type:ident, $column: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         if array.null_count() == 0 {
             for (i, hash) in $hashes.iter_mut().enumerate() {
@@ -123,8 +115,31 @@ macro_rules! hash_array {
     };
 }
 
+macro_rules! hash_array_boolean {
+    ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident) => {
+        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
+        if array.null_count() == 0 {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                *hash = spark_compatible_murmur3_hash(
+                    $hash_input_type::from(array.value(i)).to_le_bytes(),
+                    *hash,
+                );
+            }
+        } else {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                if !array.is_null(i) {
+                    *hash = spark_compatible_murmur3_hash(
+                        $hash_input_type::from(array.value(i)).to_le_bytes(),
+                        *hash,
+                    );
+                }
+            }
+        }
+    };
+}
+
 macro_rules! hash_array_primitive {
-    ($array_type:ident, $column: ident, $ty: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $ty: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         let values = array.values();
 
@@ -143,7 +158,7 @@
 }
 
 macro_rules! hash_array_primitive_float {
-    ($array_type:ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         let values = array.values();
 
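Why the fix below treats the first column specially: Spark's murmur3-based `hash()` starts every row at seed 42 and threads the hash of column N in as the seed for column N + 1. A minimal sketch of that chaining, reusing the `spark_compatible_murmur3_hash` defined earlier in this file (the `hash_row` helper and its two-column row shape are illustrative, not part of the patch):

```rust
// Illustrative sketch: how create_hashes() chains seeds row by row.
// Assumes this file's spark_compatible_murmur3_hash is in scope.
fn hash_row(a: i32, b: i64) -> u32 {
    let mut hash = 42u32; // Spark's fixed initial seed for every row
    // Values are hashed as little-endian bytes; each column's result
    // becomes the seed for the next column.
    hash = spark_compatible_murmur3_hash(a.to_le_bytes(), hash);
    hash = spark_compatible_murmur3_hash(b.to_le_bytes(), hash);
    hash
}
```

Because the first column sees the same constant seed (42) on every row, one hash per distinct dictionary value can be precomputed and reused; for any later column the incoming seed varies row by row, which is exactly the case the old dictionary code mishandled.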
@@ -172,7 +187,7 @@
 }
 
 macro_rules! hash_array_decimal {
-    ($array_type:ident, $column: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
 
         if array.null_count() == 0 {
@@ -193,27 +208,33 @@ fn create_hashes_dictionary<K: ArrowDictionaryKeyType>(
     array: &ArrayRef,
     hashes_buffer: &mut [u32],
+    first_col: bool,
 ) -> Result<()> {
     let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
-
-    // Hash each dictionary value once, and then use that computed
-    // hash for each key value to avoid a potentially expensive
-    // redundant hashing for large dictionary elements (e.g. strings)
-    let dict_values = Arc::clone(dict_array.values());
-    let mut dict_hashes = vec![0; dict_values.len()];
-    create_hashes(&[dict_values], &mut dict_hashes)?;
-
-    for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) {
-        if let Some(key) = key {
-            let idx = key.to_usize().ok_or_else(|| {
-                DataFusionError::Internal(format!(
-                    "Can not convert key value {:?} to usize in dictionary of type {:?}",
-                    key,
-                    dict_array.data_type()
-                ))
-            })?;
-            *hash = dict_hashes[idx]
-        } // no update for Null, consistent with other hashes
+    if !first_col {
+        // unpack the dictionary array as each row may have a different hash input
+        let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?;
+        create_hashes(&[unpacked], hashes_buffer)?;
+    } else {
+        // For the first column, hash each dictionary value once, and then use
+        // that computed hash for each key value to avoid a potentially
+        // expensive redundant hashing for large dictionary elements (e.g. strings)
+        let dict_values = Arc::clone(dict_array.values());
+        // same initial seed as Spark
+        let mut dict_hashes = vec![42; dict_values.len()];
+        create_hashes(&[dict_values], &mut dict_hashes)?;
+        for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) {
+            if let Some(key) = key {
+                let idx = key.to_usize().ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Can not convert key value {:?} to usize in dictionary of type {:?}",
+                        key,
+                        dict_array.data_type()
+                    ))
+                })?;
+                *hash = dict_hashes[idx]
+            } // no update for Null, consistent with other hashes
+        }
     }
     Ok(())
 }
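The `take` kernel used above is what makes the non-first-column path correct: it materializes `values[keys[i]]` for every row, turning the dictionary column back into a plain array whose rows line up with the per-row seeds already in `hashes_buffer`. A self-contained sketch of that unpacking (assuming the `arrow` crate; the literal data is made up):

```rust
use arrow::array::{Array, DictionaryArray, StringArray};
use arrow::compute::take;
use arrow::datatypes::Int32Type;

fn main() -> Result<(), arrow::error::ArrowError> {
    // Dictionary-encoded column: values ["a", "b"], keys [0, 1, 0, 0].
    let dict: DictionaryArray<Int32Type> = vec!["a", "b", "a", "a"].into_iter().collect();
    // take(values, keys) gathers values[keys[i]] for each row, i.e. the
    // plain StringArray ["a", "b", "a", "a"].
    let unpacked = take(dict.values().as_ref(), dict.keys(), None)?;
    let unpacked = unpacked.as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(unpacked.value(1), "b");
    Ok(())
}
```

Unpacking pays a copy, but it restores correctness; the memoized fast path is kept only for the first column, where it is still valid.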
@@ -227,27 +248,11 @@ pub fn create_hashes<'a>(
     arrays: &[ArrayRef],
     hashes_buffer: &'a mut [u32],
 ) -> Result<&'a mut [u32]> {
-    for col in arrays {
+    for (i, col) in arrays.iter().enumerate() {
+        let first_col = i == 0;
         match col.data_type() {
             DataType::Boolean => {
-                let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
-                if array.null_count() == 0 {
-                    for (i, hash) in hashes_buffer.iter_mut().enumerate() {
-                        *hash = spark_compatible_murmur3_hash(
-                            i32::from(array.value(i)).to_le_bytes(),
-                            *hash,
-                        );
-                    }
-                } else {
-                    for (i, hash) in hashes_buffer.iter_mut().enumerate() {
-                        if !array.is_null(i) {
-                            *hash = spark_compatible_murmur3_hash(
-                                i32::from(array.value(i)).to_le_bytes(),
-                                *hash,
-                            );
-                        }
-                    }
-                }
+                hash_array_boolean!(BooleanArray, col, i32, hashes_buffer);
             }
             DataType::Int8 => {
                 hash_array_primitive!(Int8Array, col, i32, hashes_buffer);
             }
@@ -305,28 +310,28 @@ pub fn create_hashes<'a>(
             }
             DataType::Dictionary(index_type, _) => match **index_type {
                 DataType::Int8 => {
-                    create_hashes_dictionary::<Int8Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int8Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::Int16 => {
-                    create_hashes_dictionary::<Int16Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int16Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::Int32 => {
-                    create_hashes_dictionary::<Int32Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int32Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::Int64 => {
-                    create_hashes_dictionary::<Int64Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int64Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::UInt8 => {
-                    create_hashes_dictionary::<UInt8Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<UInt8Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::UInt16 => {
-                    create_hashes_dictionary::<UInt16Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<UInt16Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::UInt32 => {
-                    create_hashes_dictionary::<UInt32Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<UInt32Type>(col, hashes_buffer, first_col)?;
                 }
                 DataType::UInt64 => {
-                    create_hashes_dictionary::<UInt64Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<UInt64Type>(col, hashes_buffer, first_col)?;
                 }
                 _ => {
                     return Err(DataFusionError::Internal(format!(
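A reading note for the expected values in the tests that follow: `create_hashes` works in `u32`, while Spark's `hash()` returns a signed 32-bit `Int`, so the literals below are the unsigned bit patterns of Spark's results. For example (plain two's-complement arithmetic, not taken from the patch):

```rust
// 0xdea578e3 is the expected hash of Some(1) in test_i32 below; the same
// 32 bits reinterpreted as i32 are what Spark would report.
let as_u32: u32 = 0xdea578e3;
let as_i32 = as_u32 as i32; // -559_580_957
assert_eq!(as_i32 as u32, as_u32); // a lossless round-trip
```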
@@ -363,78 +368,64 @@ mod tests {
     use crate::execution::datafusion::spark_hash::{create_hashes, pmod};
     use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray};
 
-    macro_rules! test_hashes {
-        ($ty:ty, $values:expr, $expected:expr) => {
-            let i = Arc::new(<$ty>::from($values)) as ArrayRef;
-            let mut hashes = vec![42; $values.len()];
+    macro_rules! test_hashes_internal {
+        ($input: expr, $len: expr, $expected: expr) => {
+            let i = $input as ArrayRef;
+            let mut hashes = vec![42; $len];
             create_hashes(&[i], &mut hashes).unwrap();
             assert_eq!(hashes, $expected);
         };
     }
 
+    fn test_murmur3_hash<I: Clone, T: Array + From<Vec<Option<I>>> + 'static>(
+        values: Vec<Option<I>>,
+        expected: Vec<u32>,
+    ) {
+        // copied before inserting nulls
+        let mut input_with_nulls = values.clone();
+        let mut expected_with_nulls = expected.clone();
+        let len = values.len();
+        let i = Arc::new(T::from(values)) as ArrayRef;
+        test_hashes_internal!(i, len, expected);
+
+        // test with nulls
+        let median = len / 2;
+        input_with_nulls.insert(0, None);
+        input_with_nulls.insert(median, None);
+        expected_with_nulls.insert(0, 42);
+        expected_with_nulls.insert(median, 42);
+        let with_nulls_len = len + 2;
+        let nullable_input = Arc::new(T::from(input_with_nulls)) as ArrayRef;
+        test_hashes_internal!(nullable_input, with_nulls_len, expected_with_nulls);
+    }
+
     #[test]
     fn test_i8() {
-        test_hashes!(
-            Int8Array,
+        test_murmur3_hash::<i8, Int8Array>(
             vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)],
-            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365]
-        );
-        // with null input
-        test_hashes!(
-            Int8Array,
-            vec![Some(1), None, Some(-1), Some(i8::MAX), Some(i8::MIN)],
-            vec![0xdea578e3, 42, 0xa0590e3d, 0x43b4d8ed, 0x422a1365]
+            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365],
         );
     }
 
     #[test]
     fn test_i32() {
-        test_hashes!(
-            Int32Array,
+        test_murmur3_hash::<i32, Int32Array>(
             vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)],
-            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6]
-        );
-        // with null input
-        test_hashes!(
-            Int32Array,
-            vec![
-                Some(1),
-                Some(0),
-                Some(-1),
-                None,
-                Some(i32::MAX),
-                Some(i32::MIN)
-            ],
-            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 42, 0x07fb67e7, 0x2b1f0fc6]
+            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6],
         );
     }
 
     #[test]
     fn test_i64() {
-        test_hashes!(
-            Int64Array,
+        test_murmur3_hash::<i64, Int64Array>(
             vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)],
-            vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb]
-        );
-        // with null input
-        test_hashes!(
-            Int64Array,
-            vec![
-                Some(1),
-                Some(0),
-                Some(-1),
-                None,
-                Some(i64::MAX),
-                Some(i64::MIN)
-            ],
-            vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 42, 0xa05b5d7b, 0xcd1e64fb]
+            vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb],
         );
     }
 
     #[test]
     fn test_f32() {
-        test_hashes!(
-            Float32Array,
+        test_murmur3_hash::<f32, Float32Array>(
             vec![
                 Some(1.0),
                 Some(0.0),
                 Some(-0.0),
                 Some(-1.0),
                 Some(99999999999.99999999999),
                 Some(-99999999999.99999999999),
             ],
-            vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86]
-        );
-        // with null input
-        test_hashes!(
-            Float32Array,
-            vec![
-                Some(1.0),
-                Some(0.0),
-                Some(-0.0),
-                Some(-1.0),
-                None,
-                Some(99999999999.99999999999),
-                Some(-99999999999.99999999999)
-            ],
-            vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 42, 0xcbdc340f, 0xc0361c86]
+            vec![
+                0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86,
+            ],
         );
     }
 
     #[test]
     fn test_f64() {
-        test_hashes!(
-            Float64Array,
+        test_murmur3_hash::<f64, Float64Array>(
             vec![
                 Some(1.0),
                 Some(0.0),
                 Some(-0.0),
                 Some(-1.0),
                 Some(99999999999.99999999999),
                 Some(-99999999999.99999999999),
             ],
-            vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9]
-        );
-        // with null input
-        test_hashes!(
-            Float64Array,
-            vec![
-                Some(1.0),
-                Some(0.0),
-                Some(-0.0),
-                Some(-1.0),
-                None,
-                Some(99999999999.99999999999),
-                Some(-99999999999.99999999999)
-            ],
-            vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 42, 0xb87e1595, 0xa0eef9f9]
+            vec![
+                0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9,
+            ],
         );
     }
 
     #[test]
     fn test_str() {
-        test_hashes!(
-            StringArray,
-            vec!["hello", "bar", "", "😁", "天地"],
-            vec![3286402344, 2486176763, 142593372, 885025535, 2395000894]
-        );
-        // test with null input
-        test_hashes!(
-            StringArray,
-            vec![
-                Some("hello"),
-                Some("bar"),
-                None,
-                Some(""),
-                Some("😁"),
-                Some("天地")
-            ],
-            vec![3286402344, 2486176763, 42, 142593372, 885025535, 2395000894]
-        );
+        let input = vec![
+            "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde",
+        ]
+        .iter()
+        .map(|s| Some(s.to_string()))
+        .collect::<Vec<Option<String>>>();
+        let expected: Vec<u32> = vec![
+            3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b,
+            1322437556, 0xe860e5cc, 814637928,
+        ];
+
+        test_murmur3_hash::<String, StringArray>(input, expected);
     }
 
     #[test]
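Before the Scala side of the patch: the regression the new suite guards against can also be written directly against `create_hashes`. A sketch (assuming `create_hashes` from this file and the `arrow` crate are in scope; the arrays are made up) that hashes the same logical data twice, with the string column dictionary-encoded in the second, previously broken, position:

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, DictionaryArray, Int32Array, StringArray};
use arrow::datatypes::Int32Type;

fn dict_matches_plain() -> datafusion::error::Result<()> {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let plain: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a", "a"]));
    let dict: DictionaryArray<Int32Type> = vec!["a", "b", "a", "a"].into_iter().collect();
    let dict: ArrayRef = Arc::new(dict);

    let mut h_plain = vec![42u32; 4]; // Spark's initial seed per row
    let mut h_dict = vec![42u32; 4];
    create_hashes(&[Arc::clone(&ints), plain], &mut h_plain)?;
    create_hashes(&[ints, dict], &mut h_dict)?;
    // Before this patch the dictionary path ignored the per-row seeds,
    // so the two buffers diverged; with the fix they must be identical.
    assert_eq!(h_plain, h_dict);
    Ok(())
}
```

This is essentially what the Scala test below exercises end to end, by writing enough rows to Parquet that the scan produces dictionary-encoded Arrow arrays.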
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 6ca4baf606..9926150840 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1454,17 +1454,55 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         withTable(table) {
           sql(s"create table $table(col string, a int, b float) using parquet")
           sql(s"""
-            |insert into $table values
-            |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
-            |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
-            |""".stripMargin)
+               |insert into $table values
+               |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+               |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+               |""".stripMargin)
           checkSparkAnswerAndOperator("""
-            |select
-            |md5(col), md5(cast(a as string)), md5(cast(b as string)),
-            |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
-            |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
-            |from test
-            |""".stripMargin)
+              |select
+              |md5(col), md5(cast(a as string)), md5(cast(b as string)),
+              |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
+              |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
+              |from test
+              |""".stripMargin)
         }
       }
     }
   }
 
+  test("hash functions with random input") {
+    val dataGen = DataGenerator.DEFAULT
+    // sufficient number of rows to create a dictionary-encoded ArrowArray.
+    val randomNumRows = 1000
+
+    val whitespaceChars = " \t\r\n"
+    val timestampPattern = "0123456789/:T" + whitespaceChars
+    Seq(true, false).foreach { dictionary =>
+      withSQLConf(
+        "parquet.enable.dictionary" -> dictionary.toString,
+        CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col string, a int, b float) using parquet")
+          // TODO: Add a Row generator in the data gen class and replace the following code
+          val col = dataGen.generateStrings(randomNumRows, timestampPattern, 6)
+          val colA = dataGen.generateInts(randomNumRows)
+          val colB = dataGen.generateFloats(randomNumRows)
+          val data = col.zip(colA).zip(colB).map { case ((a, b), c) => (a, b, c) }
+          data
+            .toDF("col", "a", "b")
+            .write
+            .mode("append")
+            .insertInto(table)
+          // with randomly generated data;
+          // disable cast(b as string) for now, as the cast from float to string may produce incompatible results
+          checkSparkAnswerAndOperator("""
+            |select
+            |md5(col), md5(cast(a as string)), --md5(cast(b as string)),
+            |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
+            |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
+            |from test
+            |""".stripMargin)
+        }
+      }
+    }
+  }