From 848555818713ca9dff6fedfdf6e407969f1eeca2 Mon Sep 17 00:00:00 2001
From: advancedxy
Date: Fri, 26 Apr 2024 13:02:47 +0800
Subject: [PATCH] feat: Support murmur3_hash and sha2 family hash functions
 (#226)

* feat: Support murmur3_hash and sha2 family hash functions

* address comments

* apply scalafix

* ensure crypto_expressions feature is enabled
---
 core/Cargo.toml                                 |   4 +-
 .../datafusion/expressions/scalar_funcs.rs      | 206 ++++++++++++------
 .../apache/comet/serde/QueryPlanSerde.scala     |  43 +++-
 .../apache/comet/CometExpressionSuite.scala     |  26 ++-
 4 files changed, 205 insertions(+), 74 deletions(-)

diff --git a/core/Cargo.toml b/core/Cargo.toml
index 5d1604952..b09b0ea7f 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -67,8 +67,8 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] }
 chrono-tz = { version = "0.8" }
 paste = "1.0.14"
 datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
-datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions"] }
-datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
+datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions", "crypto_expressions"] }
+datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["crypto_expressions"] }
 datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", default-features = false, features = ["unicode_expressions"] }
 unicode-segmentation = "^1.10.1"
 once_cell = "1.18.0"
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index e6f8de16b..2895937ca 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -15,8 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, cmp::min, fmt::Debug, str::FromStr, sync::Arc};
+use std::{
+    any::Any,
+    cmp::min,
+    fmt::{Debug, Write},
+    str::FromStr,
+    sync::Arc,
+};
 
+use crate::execution::datafusion::spark_hash::create_hashes;
 use arrow::{
     array::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
-use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
+use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray};
 use arrow_schema::DataType;
 use datafusion::{
     execution::FunctionRegistry,
     physical_plan::ColumnarValue,
 };
 use datafusion_common::{
-    cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
-    Result as DataFusionResult, ScalarValue,
+    cast::{as_binary_array, as_generic_string_array},
+    exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
 };
 use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
 use num::{
 };
 use unicode_segmentation::UnicodeSegmentation;
 
make_comet_scalar_udf { + ($name:expr, $func:ident, $data_type:ident) => {{ + let scalar_func = CometScalarFunction::new( + $name.to_string(), + Signature::variadic_any(Volatility::Immutable), + $data_type.clone(), + Arc::new(move |args| $func(args, &$data_type)), + ); + Ok(ScalarFunctionDefinition::UDF(Arc::new( + ScalarUDF::new_from_impl(scalar_func), + ))) + }}; + ($name:expr, $func:expr, without $data_type:ident) => {{ + let scalar_func = CometScalarFunction::new( + $name.to_string(), + Signature::variadic_any(Volatility::Immutable), + $data_type, + $func, + ); + Ok(ScalarFunctionDefinition::UDF(Arc::new( + ScalarUDF::new_from_impl(scalar_func), + ))) + }}; +} + /// Create a physical scalar function. pub fn create_comet_physical_fun( fun_name: &str, data_type: DataType, registry: &dyn FunctionRegistry, ) -> Result { + let sha2_functions = ["sha224", "sha256", "sha384", "sha512"]; match fun_name { "ceil" => { - let scalar_func = CometScalarFunction::new( - "ceil".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(move |args| spark_ceil(args, &data_type)), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + make_comet_scalar_udf!("ceil", spark_ceil, data_type) } "floor" => { - let scalar_func = CometScalarFunction::new( - "floor".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(move |args| spark_floor(args, &data_type)), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + make_comet_scalar_udf!("floor", spark_floor, data_type) } "rpad" => { - let scalar_func = CometScalarFunction::new( - "rpad".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(spark_rpad), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + let func = Arc::new(spark_rpad); + make_comet_scalar_udf!("rpad", func, without data_type) } "round" => { - let scalar_func = CometScalarFunction::new( - "round".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(move |args| spark_round(args, &data_type)), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + make_comet_scalar_udf!("round", spark_round, data_type) } "unscaled_value" => { - let scalar_func = CometScalarFunction::new( - "unscaled_value".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(spark_unscaled_value), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + let func = Arc::new(spark_unscaled_value); + make_comet_scalar_udf!("unscaled_value", func, without data_type) } "make_decimal" => { - let scalar_func = CometScalarFunction::new( - "make_decimal".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(move |args| spark_make_decimal(args, &data_type)), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } "decimal_div" => { - let scalar_func = CometScalarFunction::new( - "decimal_div".to_string(), - Signature::variadic_any(Volatility::Immutable), - data_type.clone(), - Arc::new(move |args| spark_decimal_div(args, &data_type)), - ); - Ok(ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(scalar_func), - ))) + make_comet_scalar_udf!("decimal_div", 
+            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+        }
+        "murmur3_hash" => {
+            let func = Arc::new(spark_murmur3_hash);
+            make_comet_scalar_udf!("murmur3_hash", func, without data_type)
+        }
+        sha if sha2_functions.contains(&sha) => {
+            // Spark expects the result of the sha2 family to be a hex string, so we wrap the
+            // binary output of the digest functions as a hex string
+            let func = registry.udf(sha)?;
+            let wrapped_func = Arc::new(move |args: &[ColumnarValue]| {
+                wrap_digest_result_as_hex_string(args, func.fun())
+            });
+            let spark_func_name = "spark".to_owned() + sha;
+            make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
         }
         _ => {
             let fun = BuiltinScalarFunction::from_str(fun_name);
@@ -629,3 +622,82 @@ fn spark_decimal_div(
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
     Ok(ColumnarValue::Array(Arc::new(result)))
 }
+
+fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    let length = args.len();
+    let seed = &args[length - 1];
+    match seed {
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => {
+            // iterate over the array arguments to determine the number of rows
+            let num_rows = args[0..args.len() - 1]
+                .iter()
+                .find_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    ColumnarValue::Scalar(_) => None,
+                })
+                .unwrap_or(1);
+            let mut hashes: Vec<u32> = vec![0_u32; num_rows];
+            hashes.fill(*seed as u32);
+            let arrays = args[0..args.len() - 1]
+                .iter()
+                .map(|arg| match arg {
+                    ColumnarValue::Array(array) => array.clone(),
+                    ColumnarValue::Scalar(scalar) => {
+                        scalar.clone().to_array_of_size(num_rows).unwrap()
+                    }
+                })
+                .collect::<Vec<ArrayRef>>();
+            create_hashes(&arrays, &mut hashes)?;
+            if num_rows == 1 {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
+                    hashes[0] as i32,
+                ))))
+            } else {
+                let hashes: Vec<i32> = hashes.into_iter().map(|x| x as i32).collect();
+                Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes))))
+            }
+        }
+        _ => {
+            internal_err!(
+                "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.",
+                seed
+            )
+        }
+    }
+}
+
+#[inline]
+fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
+    let mut s = String::with_capacity(data.as_ref().len() * 2);
+    for b in data.as_ref() {
+        // Writing to a string never errors, so we can unwrap here.
+        write!(&mut s, "{b:02x}").unwrap();
+    }
+    s
+}
+
+fn wrap_digest_result_as_hex_string(
+    args: &[ColumnarValue],
+    digest: ScalarFunctionImplementation,
+) -> Result<ColumnarValue, DataFusionError> {
+    let value = digest(args)?;
+    match value {
+        ColumnarValue::Array(array) => {
+            let binary_array = as_binary_array(&array)?;
+            let string_array: StringArray = binary_array
+                .iter()
+                .map(|opt| opt.map(hex_encode::<_>))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(string_array)))
+        }
+        ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
+            ScalarValue::Utf8(opt.map(hex_encode::<_>)),
+        )),
+        _ => {
+            exec_err!(
+                "digest function should return binary value, but got: {:?}",
+                value.data_type()
+            )
+        }
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index d08fb6b90..57b15e2f5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1613,10 +1613,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         optExprWithInfo(optExpr, expr, castExpr)
 
       case Md5(child) =>
-        val castExpr = Cast(child, StringType)
-        val childExpr = exprToProtoInternal(castExpr, inputs)
+        val childExpr = exprToProtoInternal(child, inputs)
         val optExpr = scalarExprToProto("md5", childExpr)
-        optExprWithInfo(optExpr, expr, castExpr)
+        optExprWithInfo(optExpr, expr, child)
 
       case OctetLength(child) =>
         val castExpr = Cast(child, StringType)
@@ -1954,6 +1953,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
           None
         }
 
+      case Murmur3Hash(children, seed) =>
+        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
+        if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
+          return None
+        }
+        val exprs = children.map(exprToProtoInternal(_, inputs))
+        val seedBuilder = ExprOuterClass.Literal
+          .newBuilder()
+          .setDatatype(serializeDataType(IntegerType).get)
+          .setIntVal(seed)
+        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+        // the seed is appended as the last argument
+        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
+
+      case Sha2(left, numBits) =>
+        if (!numBits.foldable) {
+          withInfo(expr, "non-literal numBits is not supported")
+          return None
+        }
+        // Spark can compute the number of bits dynamically from the input expression,
+        // but DataFusion does not support that yet.
+        val childExpr = exprToProtoInternal(left, inputs)
+        val bits = numBits.eval().asInstanceOf[Int]
+        val algorithm = bits match {
+          case 224 => "sha224"
+          case 256 | 0 => "sha256"
+          case 384 => "sha384"
+          case 512 => "sha512"
+          case _ =>
+            null
+        }
+        if (algorithm == null) {
+          exprToProtoInternal(Literal(null, StringType), inputs)
+        } else {
+          scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
+        }
+
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 376baa378..3683c8d44 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -981,8 +981,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  // TODO: enable this when we add md5 function to Comet
-  ignore("md5") {
+  test("md5") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
         val table = "test"
@@ -1405,4 +1404,27 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("hash functions") {
+    Seq(true, false).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col string, a int, b float) using parquet")
+          sql(s"""
+                 |insert into $table values
+                 |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+                 |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+                 |""".stripMargin)
+          checkSparkAnswerAndOperator("""
+                 |select
+                 |md5(col), md5(cast(a as string)), md5(cast(b as string)),
+                 |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
+                 |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
+                 |from test
+                 |""".stripMargin)
+        }
+      }
+    }
+  }
 }