diff --git a/core/Cargo.lock b/core/Cargo.lock index 52f105591..105bcaf7c 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -637,6 +637,7 @@ dependencies = [ "thrift 0.17.0", "tokio", "tokio-stream", + "twox-hash", "unicode-segmentation", "zstd", ] @@ -2823,6 +2824,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", + "rand", "static_assertions", ] diff --git a/core/Cargo.toml b/core/Cargo.toml index 5e3e0ee74..6a179a6ee 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -75,6 +75,7 @@ once_cell = "1.18.0" regex = "1.9.6" crc32fast = "1.3.2" simd-adler32 = "0.3.7" +twox-hash = "1.6.3" [build-dependencies] prost-build = "0.9.0" diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 8c5e1f391..0f254004b 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -23,7 +23,7 @@ use std::{ sync::Arc, }; -use crate::execution::datafusion::spark_hash::create_hashes; +use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; use arrow::{ array::{ ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray, @@ -119,6 +119,10 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_murmur3_hash); make_comet_scalar_udf!("murmur3_hash", func, without data_type) } + "xxhash64" => { + let func = Arc::new(spark_xxhash64); + make_comet_scalar_udf!("xxhash64", func, without data_type) + } sha if sha2_functions.contains(&sha) => { // Spark requires hex string as the result of sha2 functions, we have to wrap the // result of digest functions as hex string @@ -653,7 +657,7 @@ fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result>(); - create_hashes(&arrays, &mut hashes)?; + create_murmur3_hashes(&arrays, &mut hashes)?; if num_rows == 1 { Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( hashes[0] as i32, @@ -672,6 +676,49 @@ fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u64; num_rows]; + hashes.fill(*seed as u64); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_xxhash64_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", + seed + ) + } + } +} + #[inline] fn hex_encode>(data: T) -> String { let mut s = String::with_capacity(data.as_ref().len() * 2); diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs index 3b92abbde..967340979 100644 --- a/core/src/execution/datafusion/shuffle_writer.rs +++ b/core/src/execution/datafusion/shuffle_writer.rs @@ -62,7 +62,7 @@ use tokio::task; use crate::{ common::bit::ceil, errors::{CometError, CometResult}, - execution::datafusion::spark_hash::{create_hashes, pmod}, + execution::datafusion::spark_hash::{create_murmur3_hashes, pmod}, }; /// The shuffle writer operator maps each input partition to M output partitions based on a @@ -673,7 +673,7 @@ impl ShuffleRepartitioner { // Hash arrays and compute buckets based on number of partitions let partition_ids = &mut self.partition_ids[..arrays[0].len()]; - create_hashes(&arrays, hashes_buf)? + create_murmur3_hashes(&arrays, hashes_buf)? .iter() .enumerate() .for_each(|(idx, hash)| { diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs index 6d25a72f6..d06aba4a4 100644 --- a/core/src/execution/datafusion/spark_hash.rs +++ b/core/src/execution/datafusion/spark_hash.rs @@ -21,7 +21,8 @@ use arrow::{ compute::take, datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}, }; -use std::sync::Arc; +use std::{hash::Hasher, sync::Arc}; +use twox_hash::XxHash64; use datafusion::{ arrow::{ @@ -98,17 +99,25 @@ pub(crate) fn spark_compatible_murmur3_hash>(data: T, seed: u32) } } +#[inline] +pub(crate) fn spark_compatible_xxhash64>(data: T, seed: u64) -> u64 { + // TODO: Rewrite with a stateless hasher to reduce stack allocation? + let mut hasher = XxHash64::with_seed(seed); + hasher.write(data.as_ref()); + hasher.finish() +} + macro_rules! hash_array { - ($array_type: ident, $column: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = spark_compatible_murmur3_hash(&array.value(i), *hash); + *hash = $hash_method(&array.value(i), *hash); } } else { for (i, hash) in $hashes.iter_mut().enumerate() { if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash(&array.value(i), *hash); + *hash = $hash_method(&array.value(i), *hash); } } } @@ -116,22 +125,17 @@ macro_rules! hash_array { } macro_rules! hash_array_boolean { - ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = spark_compatible_murmur3_hash( - $hash_input_type::from(array.value(i)).to_le_bytes(), - *hash, - ); + *hash = $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); } } else { for (i, hash) in $hashes.iter_mut().enumerate() { if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash( - $hash_input_type::from(array.value(i)).to_le_bytes(), - *hash, - ); + *hash = + $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); } } } @@ -139,18 +143,18 @@ macro_rules! hash_array_boolean { } macro_rules! hash_array_primitive { - ($array_type: ident, $column: ident, $ty: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); if array.null_count() == 0 { for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash); + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); } } else { for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash); + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); } } } @@ -158,7 +162,7 @@ macro_rules! hash_array_primitive { } macro_rules! hash_array_primitive_float { - ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => { + ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); @@ -166,9 +170,9 @@ macro_rules! hash_array_primitive_float { for (hash, value) in $hashes.iter_mut().zip(values.iter()) { // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. if *value == 0.0 && value.is_sign_negative() { - *hash = spark_compatible_murmur3_hash((0 as $ty2).to_le_bytes(), *hash); + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); } else { - *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash); + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); } } } else { @@ -176,9 +180,9 @@ macro_rules! hash_array_primitive_float { if !array.is_null(i) { // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. if *value == 0.0 && value.is_sign_negative() { - *hash = spark_compatible_murmur3_hash((0 as $ty2).to_le_bytes(), *hash); + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); } else { - *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash); + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); } } } @@ -187,17 +191,17 @@ macro_rules! hash_array_primitive_float { } macro_rules! hash_array_decimal { - ($array_type: ident, $column: ident, $hashes: ident) => { + ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = spark_compatible_murmur3_hash(array.value(i).to_le_bytes(), *hash); + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); } } else { for (i, hash) in $hashes.iter_mut().enumerate() { if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash(array.value(i).to_le_bytes(), *hash); + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); } } } @@ -214,7 +218,7 @@ fn create_hashes_dictionary( if !first_col { // unpack the dictionary array as each row may have a different hash input let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; - create_hashes(&[unpacked], hashes_buffer)?; + create_murmur3_hashes(&[unpacked], hashes_buffer)?; } else { // For the first column, hash each dictionary value once, and then use // that computed hash for each key value to avoid a potentially @@ -222,7 +226,42 @@ fn create_hashes_dictionary( let dict_values = Arc::clone(dict_array.values()); // same initial seed as Spark let mut dict_hashes = vec![42; dict_values.len()]; - create_hashes(&[dict_values], &mut dict_hashes)?; + create_murmur3_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +// Hash the values in a dictionary array using xxhash64 +fn create_xxhash64_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u64], + first_col: bool, +) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_xxhash64_hashes(&[unpacked], hashes_buffer)?; + } else { + // Hash each dictionary value once, and then use that computed + // hash for each key value to avoid a potentially expensive + // redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42u64; dict_values.len()]; + create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { if let Some(key) = key { let idx = key.to_usize().ok_or_else(|| { @@ -244,111 +283,214 @@ fn create_hashes_dictionary( /// /// The number of rows to hash is determined by `hashes_buffer.len()`. /// `hashes_buffer` should be pre-sized appropriately -pub fn create_hashes<'a>( - arrays: &[ArrayRef], - hashes_buffer: &'a mut [u32], -) -> Result<&'a mut [u32]> { - for (i, col) in arrays.iter().enumerate() { - let first_col = i == 0; - match col.data_type() { - DataType::Boolean => { - hash_array_boolean!(BooleanArray, col, i32, hashes_buffer); - } - DataType::Int8 => { - hash_array_primitive!(Int8Array, col, i32, hashes_buffer); - } - DataType::Int16 => { - hash_array_primitive!(Int16Array, col, i32, hashes_buffer); - } - DataType::Int32 => { - hash_array_primitive!(Int32Array, col, i32, hashes_buffer); - } - DataType::Int64 => { - hash_array_primitive!(Int64Array, col, i64, hashes_buffer); - } - DataType::Float32 => { - hash_array_primitive_float!(Float32Array, col, f32, i32, hashes_buffer); - } - DataType::Float64 => { - hash_array_primitive_float!(Float64Array, col, f64, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Second, _) => { - hash_array_primitive!(TimestampSecondArray, col, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - hash_array_primitive!(TimestampMillisecondArray, col, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - hash_array_primitive!(TimestampMicrosecondArray, col, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - hash_array_primitive!(TimestampNanosecondArray, col, i64, hashes_buffer); - } - DataType::Date32 => { - hash_array_primitive!(Date32Array, col, i32, hashes_buffer); - } - DataType::Date64 => { - hash_array_primitive!(Date64Array, col, i64, hashes_buffer); - } - DataType::Utf8 => { - hash_array!(StringArray, col, hashes_buffer); - } - DataType::LargeUtf8 => { - hash_array!(LargeStringArray, col, hashes_buffer); - } - DataType::Binary => { - hash_array!(BinaryArray, col, hashes_buffer); - } - DataType::LargeBinary => { - hash_array!(LargeBinaryArray, col, hashes_buffer); - } - DataType::FixedSizeBinary(_) => { - hash_array!(FixedSizeBinaryArray, col, hashes_buffer); - } - DataType::Decimal128(_, _) => { - hash_array_decimal!(Decimal128Array, col, hashes_buffer); - } - DataType::Dictionary(index_type, _) => match **index_type { +/// +/// `hash_method` is the hash function to use +/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input +macro_rules! create_hashes_internal { + ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => { + for (i, col) in $arrays.iter().enumerate() { + let first_col = i == 0; + match col.data_type() { + DataType::Boolean => { + hash_array_boolean!(BooleanArray, col, i32, $hashes_buffer, $hash_method); + } DataType::Int8 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + hash_array_primitive!(Int8Array, col, i32, $hashes_buffer, $hash_method); } DataType::Int16 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + hash_array_primitive!(Int16Array, col, i32, $hashes_buffer, $hash_method); } DataType::Int32 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + hash_array_primitive!(Int32Array, col, i32, $hashes_buffer, $hash_method); } DataType::Int64 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + hash_array_primitive!(Int64Array, col, i64, $hashes_buffer, $hash_method); } - DataType::UInt8 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + DataType::Float32 => { + hash_array_primitive_float!( + Float32Array, + col, + f32, + i32, + $hashes_buffer, + $hash_method + ); } - DataType::UInt16 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + DataType::Float64 => { + hash_array_primitive_float!( + Float64Array, + col, + f64, + i64, + $hashes_buffer, + $hash_method + ); } - DataType::UInt32 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + DataType::Timestamp(TimeUnit::Second, _) => { + hash_array_primitive!( + TimestampSecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + hash_array_primitive!( + TimestampMillisecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + hash_array_primitive!( + TimestampMicrosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + hash_array_primitive!( + TimestampNanosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); } - DataType::UInt64 => { - create_hashes_dictionary::(col, hashes_buffer, first_col)?; + DataType::Date32 => { + hash_array_primitive!(Date32Array, col, i32, $hashes_buffer, $hash_method); } + DataType::Date64 => { + hash_array_primitive!(Date64Array, col, i64, $hashes_buffer, $hash_method); + } + DataType::Utf8 => { + hash_array!(StringArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeUtf8 => { + hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method); + } + DataType::Binary => { + hash_array!(BinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeBinary => { + hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::FixedSizeBinary(_) => { + hash_array!(FixedSizeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::Decimal128(_, _) => { + hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Dictionary(index_type, _) => match **index_type { + DataType::Int8 => { + $create_dictionary_hash_method::(col, $hashes_buffer, first_col)?; + } + DataType::Int16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt8 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported dictionary type in hasher hashing: {}", + col.data_type(), + ))) + } + }, _ => { + // This is internal because we should have caught this before. return Err(DataFusionError::Internal(format!( - "Unsupported dictionary type in hasher hashing: {}", - col.data_type(), - ))) + "Unsupported data type in hasher: {}", + col.data_type() + ))); } - }, - _ => { - // This is internal because we should have caught this before. - return Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {}", - col.data_type() - ))); } } - } + }; +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +pub(crate) fn create_murmur3_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u32], +) -> Result<&'a mut [u32]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_murmur3_hash, + create_hashes_dictionary + ); + Ok(hashes_buffer) +} + +/// Creates xxhash64 hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +pub(crate) fn create_xxhash64_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_xxhash64, + create_xxhash64_hashes_dictionary + ); Ok(hashes_buffer) } @@ -365,38 +507,61 @@ mod tests { use arrow::array::{Float32Array, Float64Array}; use std::sync::Arc; - use crate::execution::datafusion::spark_hash::{create_hashes, pmod}; + use crate::execution::datafusion::spark_hash::{ + create_murmur3_hashes, create_xxhash64_hashes, pmod, + }; use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray}; macro_rules! test_hashes_internal { - ($input: expr, $len: expr, $expected: expr) => { - let i = $input as ArrayRef; - let mut hashes = vec![42; $len]; - create_hashes(&[i], &mut hashes).unwrap(); + ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => { + let i = $input; + let mut hashes = $initial_seeds.clone(); + $hash_method(&[i], &mut hashes).unwrap(); assert_eq!(hashes, $expected); }; } + macro_rules! test_hashes_with_nulls { + ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => { + // copied before inserting nulls + let mut input_with_nulls = $values.clone(); + let mut expected_with_nulls = $expected.clone(); + // test before inserting nulls + let len = $values.len(); + let initial_seeds = vec![42 as $seed_type; len]; + let i = Arc::new(<$t>::from($values)) as ArrayRef; + test_hashes_internal!($method, i, initial_seeds, $expected); + + // test with nulls + let median = len / 2; + input_with_nulls.insert(0, None); + input_with_nulls.insert(median, None); + expected_with_nulls.insert(0, 42 as $seed_type); + expected_with_nulls.insert(median, 42 as $seed_type); + let len_with_nulls = len + 2; + let initial_seeds_with_nulls = vec![42 as $seed_type; len_with_nulls]; + let nullable_input = Arc::new(<$t>::from(input_with_nulls)) as ArrayRef; + test_hashes_internal!( + $method, + nullable_input, + initial_seeds_with_nulls, + expected_with_nulls + ); + }; + } + fn test_murmur3_hash>> + 'static>( values: Vec>, expected: Vec, ) { - // copied before inserting nulls - let mut input_with_nulls = values.clone(); - let mut expected_with_nulls = expected.clone(); - let len = values.len(); - let i = Arc::new(T::from(values)) as ArrayRef; - test_hashes_internal!(i, len, expected); - - // test with nulls - let median = len / 2; - input_with_nulls.insert(0, None); - input_with_nulls.insert(median, None); - expected_with_nulls.insert(0, 42); - expected_with_nulls.insert(median, 42); - let with_nulls_len = len + 2; - let nullable_input = Arc::new(T::from(input_with_nulls)) as ArrayRef; - test_hashes_internal!(nullable_input, with_nulls_len, expected_with_nulls); + test_hashes_with_nulls!(create_murmur3_hashes, T, values, expected, u32); + } + + fn test_xxhash64_hash>> + 'static>( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_xxhash64_hashes, T, values, expected, u64); } #[test] @@ -405,6 +570,16 @@ mod tests { vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365], ); + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x77cc15d9f9f2cdc2, + 0x39bc22b9e94d81d0, + ], + ); } #[test] @@ -413,6 +588,16 @@ mod tests { vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6], ); + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x14f0ac009c21721c, + 0x1cc7cb8d034769cd, + ], + ); } #[test] @@ -421,6 +606,16 @@ mod tests { vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb], ); + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![ + 0x9ed50fd59358d232, + 0xb71b47ebda15746c, + 0x358ae035bfb46fd2, + 0xd2f1c616ae7eb306, + 0x88608019c494c1f4, + ], + ); } #[test] @@ -438,6 +633,24 @@ mod tests { 0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86, ], ); + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0x9b92689757fcdbd, + 0x3229fbc4681e48f3, + 0x3229fbc4681e48f3, + 0xa2becc0e61bb3823, + 0x8f20ab82d4f3687f, + 0xdce4982d97f7ac4, + ], + ) } #[test] @@ -455,6 +668,25 @@ mod tests { 0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9, ], ); + + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe1fd6e07fee8ad53, + 0xb71b47ebda15746c, + 0xb71b47ebda15746c, + 0x8cdde022746f8f1f, + 0x793c5c88d313eac7, + 0xc5e60e7b75d9b232, + ], + ) } #[test] @@ -470,7 +702,22 @@ mod tests { 1322437556, 0xe860e5cc, 814637928, ]; - test_murmur3_hash::(input, expected); + test_murmur3_hash::(input.clone(), expected); + test_xxhash64_hash::( + input, + vec![ + 0xc3629e6318d53932, + 0xe7097b6a54378d8a, + 0x98b1582b0977e704, + 0xa80d9d5a6a523bd5, + 0xfcba5f61ac666c61, + 0x88e4fe59adf7b0cc, + 0x259dd873209a3fe3, + 0x13c1d910702770e6, + 0xa17b5eb5dc364dff, + 0xf241303e4a90f299, + ], + ) } #[test] diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6333650dd..e015e87cd 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2086,6 +2086,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // the seed is put at the end of the arguments scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*) + case XxHash64(children, seed) => + val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType)) + if (firstUnSupportedInput.isDefined) { + withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}") + return None + } + val exprs = children.map(exprToProtoInternal(_, inputs)) + val seedBuilder = ExprOuterClass.Literal + .newBuilder() + .setDatatype(serializeDataType(LongType).get) + .setLongVal(seed) + val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) + // the seed is put at the end of the arguments + scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*) + case Sha2(left, numBits) => if (!numBits.foldable) { withInfo(expr, "non literal numBits is not supported") diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 1afdd78ec..b75777efb 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1468,6 +1468,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |select |md5(col), md5(cast(a as string)), md5(cast(b as string)), |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), + |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128) |from test |""".stripMargin) @@ -1490,14 +1491,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val table = "test" withTable(table) { sql(s"create table $table(col string, a int, b float) using parquet") - // TODO: Add a Row generator in the data gen class and replace th following code - val col = dataGen.generateStrings(randomNumRows, timestampPattern, 6) - val colA = dataGen.generateInts(randomNumRows) - val colB = dataGen.generateFloats(randomNumRows) - val data = col.zip(colA).zip(colB).map { case ((a, b), c) => (a, b, c) } - data - .toDF("col", "a", "b") - .write + val tableSchema = spark.table(table).schema + val rows = dataGen.generateRows( + randomNumRows, + tableSchema, + Some(() => dataGen.generateString(timestampPattern, 6))) + val data = spark.createDataFrame(spark.sparkContext.parallelize(rows), tableSchema) + data.write .mode("append") .insertInto(table) // with random generated data @@ -1506,6 +1506,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |select |md5(col), md5(cast(a as string)), --md5(cast(b as string)), |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), + |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128) |from test |""".stripMargin)