diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index 514afe31d..8c5e1f391 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -43,8 +43,7 @@ use datafusion::{
 };
 use datafusion_common::{
     cast::{as_binary_array, as_generic_string_array},
-    exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result as DataFusionResult,
-    ScalarValue,
+    exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
 };
 use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
 use num::{
@@ -53,6 +52,9 @@ use num::{
 };
 use unicode_segmentation::UnicodeSegmentation;
 
+mod unhex;
+use unhex::spark_unhex;
+
 macro_rules! make_comet_scalar_udf {
     ($name:expr, $func:ident, $data_type:ident) => {{
         let scalar_func = CometScalarFunction::new(
@@ -107,7 +109,8 @@ pub fn create_comet_physical_fun(
             make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
         }
         "unhex" => {
-            make_comet_scalar_udf!("unhex", spark_unhex, data_type)
+            let func = Arc::new(spark_unhex);
+            make_comet_scalar_udf!("unhex", func, without data_type)
         }
         "decimal_div" => {
             make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
@@ -576,115 +579,6 @@ fn spark_rpad_internal(
     Ok(ColumnarValue::Array(Arc::new(result)))
 }
 
-fn unhex(string: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
-    // https://docs.databricks.com/en/sql/language-manual/functions/unhex.html
-    // If the length of expr is odd, the first character is discarded and the result is padded with
-    // a null byte. If expr contains non hex characters the result is NULL.
-    let string = if string.len() % 2 == 1 {
-        &string[1..]
-    } else {
-        string
-    };
-
-    let mut iter = string.chars().peekable();
-    while let Some(c) = iter.next() {
-        let high = if let Some(high) = c.to_digit(16) {
-            high
-        } else {
-            return Ok(());
-        };
-
-        let low = iter
-            .next()
-            .ok_or_else(|| DataFusionError::Internal("Odd number of hex characters".to_string()))?
-            .to_digit(16);
-
-        let low = if let Some(low) = low {
-            low
-        } else {
-            return Ok(());
-        };
-
-        result.push((high << 4 | low) as u8);
-    }
-
-    if string.len() % 2 == 1 {
-        result.push(0);
-    }
-
-    Ok(())
-}
-
-fn spark_unhex_inner<T: OffsetSizeTrait>(
-    array: &ColumnarValue,
-    fail_on_error: bool,
-) -> Result<ColumnarValue, DataFusionError> {
-    let string_array = match array {
-        ColumnarValue::Array(array) => as_generic_string_array::<T>(array)?,
-        ColumnarValue::Scalar(ScalarValue::Utf8(Some(_string))) => {
-            return not_impl_err!("unhex with scalar string is not implemented yet");
-        }
-        _ => {
-            return internal_err!(
-                "The first argument must be a string scalar or array, but got: {:?}",
-                array.data_type()
-            );
-        }
-    };
-
-    let mut builder = arrow::array::BinaryBuilder::new();
-    let mut encoded = Vec::new();
-
-    for i in 0..string_array.len() {
-        let string = string_array.value(i);
-
-        if unhex(string, &mut encoded).is_ok() {
-            builder.append_value(encoded.as_slice());
-            encoded.clear();
-        } else if fail_on_error {
-            return plan_err!("Input to unhex is not a valid hex string: {:?}", string);
-        } else {
-            builder.append_null();
-        }
-    }
-    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
-}
-
-fn spark_unhex(
-    args: &[ColumnarValue],
-    _data_type: &DataType,
-) -> Result<ColumnarValue, DataFusionError> {
-    if args.len() > 2 {
-        return plan_err!("unhex takes at most 2 arguments, but got: {}", args.len());
-    }
-
-    let val_to_unhex = &args[0];
-    let fail_on_error = if args.len() == 2 {
-        match &args[1] {
-            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
-            _ => {
-                return plan_err!(
-                    "The second argument must be boolean scalar, but got: {:?}",
-                    args[1]
-                );
-            }
-        }
-    } else {
-        false
-    };
-
-    match val_to_unhex.data_type() {
-        DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
-        DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
-        other => {
-            internal_err!(
-                "The first argument must be a string scalar or array, but got: {:?}",
-                other
-            )
-        }
-    }
-}
-
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
 // Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
 // get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
@@ -813,18 +707,3 @@ fn wrap_digest_result_as_hex_string(
         }
     }
 }
-
-#[cfg(test)]
-mod test {
-    use super::unhex;
-
-    #[test]
-    fn test_unhex() {
-        let mut result = Vec::new();
-
-        unhex("537061726B2053514C", &mut result).unwrap();
-        let result_str = std::str::from_utf8(&result).unwrap();
-        assert_eq!(result_str, "Spark SQL");
-        result.clear();
-    }
-}
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
new file mode 100644
index 000000000..5a7a16aba
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::{Array, OffsetSizeTrait};
+use arrow_schema::DataType;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};
+
+fn unhex(string: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
+    if string.is_empty() {
+        return Ok(());
+    }
+
+    // Adjust the string if it has an odd length, and prepare to add a padding byte if needed.
+    let needs_padding = string.len() % 2 != 0;
+    let adjusted_string = if needs_padding { &string[1..] } else { string };
+
+    let mut iter = adjusted_string.chars();
+    while let (Some(high_char), Some(low_char)) = (iter.next(), iter.next()) {
+        let high = high_char
+            .to_digit(16)
+            .ok_or_else(|| DataFusionError::Internal("Invalid hex character".to_string()))?;
+        let low = low_char
+            .to_digit(16)
+            .ok_or_else(|| DataFusionError::Internal("Invalid hex character".to_string()))?;
+
+        result.push((high << 4 | low) as u8);
+    }
+
+    if needs_padding {
+        result.push(0);
+    }
+
+    Ok(())
+}
+
+fn spark_unhex_inner<T: OffsetSizeTrait>(
+    array: &ColumnarValue,
+    fail_on_error: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+    match array {
+        ColumnarValue::Array(array) => {
+            let string_array = as_generic_string_array::<T>(array)?;
+
+            let mut builder = arrow::array::BinaryBuilder::new();
+            let mut encoded = Vec::new();
+
+            for i in 0..string_array.len() {
+                let string = string_array.value(i);
+
+                if unhex(string, &mut encoded).is_ok() {
+                    builder.append_value(encoded.as_slice());
+                    encoded.clear();
+                } else if fail_on_error {
+                    return exec_err!("Input to unhex is not a valid hex string: {string}");
+                } else {
+                    // Clear any bytes left over from a partial decode so they
+                    // cannot leak into the next row's value.
+                    encoded.clear();
+                    builder.append_null();
+                }
+            }
+            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+        }
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => {
+            let mut encoded = Vec::new();
+
+            if unhex(string, &mut encoded).is_ok() {
+                Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded))))
+            } else if fail_on_error {
+                exec_err!("Input to unhex is not a valid hex string: {string}")
+            } else {
+                Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
+            }
+        }
+        _ => {
+            exec_err!(
+                "The first argument must be a string scalar or array, but got: {:?}",
+                array
+            )
+        }
+    }
+}
+
+pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.len() > 2 {
+        return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len());
+    }
+
+    let val_to_unhex = &args[0];
+    let fail_on_error = if args.len() == 2 {
+        match &args[1] {
+            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
+            _ => {
+                return exec_err!(
+                    "The second argument must be a boolean scalar, but got: {:?}",
+                    args[1]
+                );
+            }
+        }
+    } else {
+        false
+    };
+
+    match val_to_unhex.data_type() {
+        DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
+        DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
+        other => exec_err!(
+            "The first argument must be a string scalar or array, but got: {:?}",
+            other
+        ),
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::unhex;
+
+    #[test]
+    fn test_unhex() -> Result<(), Box<dyn std::error::Error>> {
+        let mut result = Vec::new();
+
+        unhex("537061726B2053514C", &mut result)?;
+        let result_str = std::str::from_utf8(&result)?;
+        assert_eq!(result_str, "Spark SQL");
+        result.clear();
+
+        assert!(unhex("hello", &mut result).is_err());
+        result.clear();
+
+        unhex("", &mut result)?;
+        assert!(result.is_empty());
+
+        Ok(())
+    }
+}
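
For reference, a minimal sketch of how the new entry point behaves end to end. This is not part of the patch: spark_unhex_demo is a hypothetical extra test that would need to live inside the test module of unhex.rs (where the pub(super) function is visible), and it assumes the same arrow/datafusion versions this file already compiles against.

    #[test]
    fn spark_unhex_demo() -> Result<(), Box<dyn std::error::Error>> {
        use std::sync::Arc;

        use arrow_array::{Array, StringArray};
        use datafusion::logical_expr::ColumnarValue;
        use datafusion_common::ScalarValue;

        use super::spark_unhex;

        // Scalar path: "537061726B" decodes to the bytes of "Spark".
        let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("537061726B".to_string())));
        match spark_unhex(&[scalar])? {
            ColumnarValue::Scalar(ScalarValue::Binary(Some(bytes))) => assert_eq!(bytes, b"Spark"),
            other => panic!("unexpected result: {:?}", other),
        }

        // Array path with fail_on_error = false: an invalid row becomes NULL
        // instead of failing the whole batch.
        let strings = ColumnarValue::Array(Arc::new(StringArray::from(vec!["537061726B", "zz"])));
        let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));

        match spark_unhex(&[strings, fail_on_error])? {
            ColumnarValue::Array(array) => {
                assert_eq!(array.len(), 2);
                assert!(array.is_null(1)); // "zz" is not valid hex
            }
            other => panic!("unexpected result: {:?}", other),
        }

        Ok(())
    }

With fail_on_error = true, the same invalid row would instead surface as an execution error rather than a NULL.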