From 5dbd4aa1551b661f42959953afac3cbd775604d3 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sun, 28 Apr 2024 11:37:54 -0700 Subject: [PATCH 01/20] feat: ic for native unhex --- .../datafusion/expressions/scalar_funcs.rs | 132 +++++++++++++++++- core/src/execution/datafusion/planner.rs | 10 +- .../apache/comet/serde/QueryPlanSerde.scala | 11 ++ .../apache/comet/CometExpressionSuite.scala | 13 ++ 4 files changed, 157 insertions(+), 9 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 2895937ca..f59c5f40c 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -43,7 +43,7 @@ use datafusion::{ }; use datafusion_common::{ cast::{as_binary_array, as_generic_string_array}, - exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, + exec_err, internal_err, not_impl_err, DataFusionError, Result as DataFusionResult, ScalarValue, }; use datafusion_physical_expr::{math_expressions, udf::ScalarUDF}; use num::{ @@ -105,6 +105,9 @@ pub fn create_comet_physical_fun( "make_decimal" => { make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } + "unhex" => { + make_comet_scalar_udf!("unhex", spark_unhex, data_type) + } "decimal_div" => { make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type) } @@ -123,11 +126,10 @@ pub fn create_comet_physical_fun( make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type) } _ => { - let fun = BuiltinScalarFunction::from_str(fun_name); - if fun.is_err() { - Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) + if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) { + Ok(ScalarFunctionDefinition::BuiltIn(fun)) } else { - Ok(ScalarFunctionDefinition::BuiltIn(fun?)) + Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) } } } @@ -573,6 +575,111 @@ fn spark_rpad_internal( Ok(ColumnarValue::Array(Arc::new(result))) } +fn unhex(string: &str, result: &mut Vec) -> Result<(), DataFusionError> { + // https://docs.databricks.com/en/sql/language-manual/functions/unhex.html + // If the length of expr is odd, the first character is discarded and the result is padded with + // a null byte. If expr contains non hex characters the result is NULL. + let string = if string.len() % 2 == 1 { + &string[1..] + } else { + string + }; + + let mut iter = string.chars().peekable(); + while let Some(c) = iter.next() { + let high = if let Some(high) = c.to_digit(16) { + high + } else { + return Ok(()); + }; + + let low = iter + .next() + .ok_or_else(|| DataFusionError::Internal("Odd number of hex characters".to_string()))? 
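+ // Decode the pair's second character as the low nibble of the output byte.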
+ .to_digit(16); + + let low = if let Some(low) = low { + low + } else { + return Ok(()); + }; + + result.push((high << 4 | low) as u8); + } + + if string.len() % 2 == 1 { + result.push(0); + } + + Ok(()) +} + +fn spark_unhex_inner( + array: &ColumnarValue, + fail_on_error: bool, +) -> Result { + // let string_array = as_generic_string_array::(array)?; + + let string_array = match array { + ColumnarValue::Array(array) => as_generic_string_array::(array)?, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(_string))) => { + return not_impl_err!("unhex with scalar string is not implemented yet"); + } + _ => { + return internal_err!( + "The first argument must be a string scalar or array, but got: {:?}", + array.data_type() + ); + } + }; + + let mut builder = arrow::array::BinaryBuilder::new(); + let mut encoded = Vec::new(); + + for i in 0..string_array.len() { + let string = string_array.value(i); + + match unhex(string, &mut encoded) { + Ok(_) => { + builder.append_value(encoded.as_slice()); + encoded.clear(); + } + Err(_) if fail_on_error => { + return internal_err!("Invalid hex string: {:?}", string); + } + _ => { + builder.append_null(); + } + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) +} + +fn spark_unhex( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + if args.len() != 2 { + return internal_err!("unhex takes exactly two argument"); + } + + let val_to_unhex = &args[0]; + let fail_on_error = if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = &args[1] { + *b + } else { + return internal_err!("The second argument must be a boolean scalar"); + }; + + match data_type { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + _ => internal_err!( + "The first argument must be an array, but got: {:?}", + data_type + ), + } +} + // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). // Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to // get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since @@ -701,3 +808,18 @@ fn wrap_digest_result_as_hex_string( } } } + +#[cfg(test)] +mod test { + use super::unhex; + + #[test] + fn test_unhex() { + let mut result = Vec::new(); + + unhex("537061726B2053514C", &mut result).unwrap(); + let result_str = std::str::from_utf8(&result).unwrap(); + assert_eq!(result_str, "Spark SQL"); + result.clear(); + } +} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 72174790b..6939f81df 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -1301,6 +1301,7 @@ impl PhysicalPlanner { .iter() .map(|x| x.data_type(input_schema.as_ref())) .collect::, _>>()?; + let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) { Some(t) => t, None => { @@ -1308,17 +1309,18 @@ impl PhysicalPlanner { // scalar function // Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll // throw error. - let fun = BuiltinScalarFunction::from_str(fun_name); - if fun.is_err() { + + if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) { + fun.return_type(&input_expr_types)? + } else { self.session_ctx .udf(fun_name)? .inner() .return_type(&input_expr_types)? - } else { - fun?.return_type(&input_expr_types)? 
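+ // (Lookup order: DataFusion built-in scalar functions are tried first,
+ // then UDFs registered with the session context.)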
} } }; + let fun_expr = create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 57b15e2f5..99af03bc6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1396,6 +1396,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, left, right) + case e @ Unhex(child, failOnError) => + val childCast = Cast(child, StringType) + val failOnErrorCast = Cast(Literal(failOnError), BooleanType) + + val childExpr = exprToProtoInternal(childCast, inputs) + val failOnErrorExpr = exprToProtoInternal(failOnErrorCast, inputs) + + val optExpr = + scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) + optExprWithInfo(optExpr, expr, child, failOnErrorCast) + case e @ Ceil(child) => val childExpr = exprToProtoInternal(child, inputs) child.dataType match { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 3683c8d44..ae2e88da9 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1025,6 +1025,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("unhex") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col string) using parquet") + sql(s"insert into $table values('4A4D'), ('4A4D'), ('4A4D'), ('4A4D')") + checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") + } + } + } + } + test("length, reverse, instr, replace, translate") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { From 04bb619b03478cd9b9a111dd5c2f2539ad894d1a Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 19:11:24 -0700 Subject: [PATCH 02/20] feat: setup shim for unhex --- common/pom.xml | 3 +- .../datafusion/expressions/scalar_funcs.rs | 34 ++++++------------- pom.xml | 3 +- spark/pom.xml | 3 +- .../apache/comet/serde/QueryPlanSerde.scala | 14 +++++--- .../comet/shims/ShimCometUnhexExpr.scala | 29 ++++++++++++++++ .../comet/shims/ShimCometUnhexExpr.scala | 29 ++++++++++++++++ .../apache/comet/CometExpressionSuite.scala | 2 +- 8 files changed, 84 insertions(+), 33 deletions(-) create mode 100644 spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala create mode 100644 spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala diff --git a/common/pom.xml b/common/pom.xml index 540101d71..ac9d136b3 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -179,7 +179,8 @@ under the License. 
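+ <!-- Shim sources are split into a major-version tree and a minor-version
+ tree; both are registered as compile roots. -->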
- src/main/${shims.source} + src/main/${shims.majorSource} + src/main/${shims.minorSource} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index f59c5f40c..e02f17022 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -616,10 +616,7 @@ fn unhex(string: &str, result: &mut Vec) -> Result<(), DataFusionError> { fn spark_unhex_inner( array: &ColumnarValue, - fail_on_error: bool, ) -> Result { - // let string_array = as_generic_string_array::(array)?; - let string_array = match array { ColumnarValue::Array(array) => as_generic_string_array::(array)?, ColumnarValue::Scalar(ScalarValue::Utf8(Some(_string))) => { @@ -639,17 +636,11 @@ fn spark_unhex_inner( for i in 0..string_array.len() { let string = string_array.value(i); - match unhex(string, &mut encoded) { - Ok(_) => { - builder.append_value(encoded.as_slice()); - encoded.clear(); - } - Err(_) if fail_on_error => { - return internal_err!("Invalid hex string: {:?}", string); - } - _ => { - builder.append_null(); - } + if let Ok(_) = unhex(string, &mut encoded) { + builder.append_value(encoded.as_slice()); + encoded.clear(); + } else { + builder.append_null(); } } Ok(ColumnarValue::Array(Arc::new(builder.finish()))) @@ -659,22 +650,17 @@ fn spark_unhex( args: &[ColumnarValue], data_type: &DataType, ) -> Result { - if args.len() != 2 { - return internal_err!("unhex takes exactly two argument"); + if args.len() != 1 { + return internal_err!("unhex takes exactly one argument"); } let val_to_unhex = &args[0]; - let fail_on_error = if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = &args[1] { - *b - } else { - return internal_err!("The second argument must be a boolean scalar"); - }; match data_type { - DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), - DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::Utf8 => spark_unhex_inner::(val_to_unhex), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex), _ => internal_err!( - "The first argument must be an array, but got: {:?}", + "The first argument must be string array, but got: {:?}", data_type ), } diff --git a/pom.xml b/pom.xml index 6d28c8168..a60a3c2f7 100644 --- a/pom.xml +++ b/pom.xml @@ -88,7 +88,8 @@ under the License. -ea -Xmx4g -Xss4m ${extraJavaTestArgs} spark-3.3-plus spark-3.4 - spark-3.x + spark-3.x + spark-3.4.x diff --git a/spark/pom.xml b/spark/pom.xml index 66ff82909..8c900451a 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -258,7 +258,8 @@ under the License. 
- src/main/${shims.source} + src/main/${shims.majorSource} + src/main/${shims.minorSource} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 99af03bc6..83482bc80 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -41,16 +41,18 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} import org.apache.comet.shims.ShimQueryPlanSerde +import org.apache.comet.shims.ShimCometUnhexExpr /** * An utility object for query plan and expression serialization. */ -object QueryPlanSerde extends Logging with ShimQueryPlanSerde { +object QueryPlanSerde extends Logging with ShimCometUnhexExpr with ShimQueryPlanSerde { def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -1396,16 +1398,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, left, right) - case e @ Unhex(child, failOnError) => - val childCast = Cast(child, StringType) - val failOnErrorCast = Cast(Literal(failOnError), BooleanType) + case e: Unhex => + val unHex = unhexSerde(e) + + val childCast = Cast(unHex._1, StringType) + val failOnErrorCast = Cast(unHex._2, BooleanType) val childExpr = exprToProtoInternal(childCast, inputs) val failOnErrorExpr = exprToProtoInternal(failOnErrorCast, inputs) val optExpr = scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) - optExprWithInfo(optExpr, expr, child, failOnErrorCast) + optExprWithInfo(optExpr, expr, unHex._1) case e @ Ceil(child) => val childExpr = exprToProtoInternal(child, inputs) diff --git a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala new file mode 100644 index 000000000..f51b1b8c7 --- /dev/null +++ b/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, BooleanType} +import org.apache.comet.serde.ExprOuterClass.Expr + +trait ShimCometUnhexExpr { + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(false)) + } +} diff --git a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala new file mode 100644 index 000000000..450dd701c --- /dev/null +++ b/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, BooleanType} +import org.apache.comet.serde.ExprOuterClass.Expr + +trait ShimCometUnhexExpr { + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(unhex.failOnError)) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index ae2e88da9..b59dd7324 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1031,7 +1031,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val table = "test" withTable(table) { sql(s"create table $table(col string) using parquet") - sql(s"insert into $table values('4A4D'), ('4A4D'), ('4A4D'), ('4A4D')") + sql(s"insert into $table values('537061726B2053514C')") checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") } } From 6cb88c71c98fdd30aa75ee86330052570b76f51f Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 19:40:47 -0700 Subject: [PATCH 03/20] style: cleanup --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 5 ++--- .../org/apache/comet/shims/ShimCometUnhexExpr.scala | 1 - .../org/apache/comet/shims/ShimCometUnhexExpr.scala | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 83482bc80..e5b317813 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -41,18 +41,17 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo} import 
org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} -import org.apache.comet.shims.ShimQueryPlanSerde import org.apache.comet.shims.ShimCometUnhexExpr +import org.apache.comet.shims.ShimQueryPlanSerde /** * An utility object for query plan and expression serialization. */ -object QueryPlanSerde extends Logging with ShimCometUnhexExpr with ShimQueryPlanSerde { +object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometUnhexExpr { def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } diff --git a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala index f51b1b8c7..30a1cc04f 100644 --- a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -20,7 +20,6 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StringType, BooleanType} -import org.apache.comet.serde.ExprOuterClass.Expr trait ShimCometUnhexExpr { def unhexSerde(unhex: Unhex): (Expression, Expression) = { diff --git a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala index 450dd701c..612169d32 100644 --- a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -20,7 +20,6 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StringType, BooleanType} -import org.apache.comet.serde.ExprOuterClass.Expr trait ShimCometUnhexExpr { def unhexSerde(unhex: Unhex): (Expression, Expression) = { From c649aef38e469b7b1120594c9567371274b07c2b Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 19:53:30 -0700 Subject: [PATCH 04/20] refactor: set minor source --- pom.xml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pom.xml b/pom.xml index a60a3c2f7..e810632ed 100644 --- a/pom.xml +++ b/pom.xml @@ -501,6 +501,7 @@ under the License. not-needed-yet not-needed-yet + spark-3.3.x @@ -513,6 +514,7 @@ under the License. 1.12.0 spark-3.3-plus not-needed-yet + spark-3.3.x @@ -524,6 +526,7 @@ under the License. 
1.13.1 spark-3.3-plus spark-3.4 + spark-3.4.x From bb4ad43bfbe079d792257808056265605034d5d7 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 20:18:59 -0700 Subject: [PATCH 05/20] style: fix clippy in core --- core/src/execution/datafusion/expressions/scalar_funcs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index e02f17022..0c3d1d721 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -636,7 +636,7 @@ fn spark_unhex_inner( for i in 0..string_array.len() { let string = string_array.value(i); - if let Ok(_) = unhex(string, &mut encoded) { + if unhex(string, &mut encoded).is_ok() { builder.append_value(encoded.as_slice()); encoded.clear(); } else { From a0bdbbec9447e5e2c8d3233de314f3aad11fa334 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 21:10:16 -0700 Subject: [PATCH 06/20] fix: fix tests --- .../datafusion/expressions/scalar_funcs.rs | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 0c3d1d721..a0887d710 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -42,8 +42,9 @@ use datafusion::{ physical_plan::ColumnarValue, }; use datafusion_common::{ - cast::{as_binary_array, as_generic_string_array}, - exec_err, internal_err, not_impl_err, DataFusionError, Result as DataFusionResult, ScalarValue, + cast::{as_binary_array, as_generic_string_array, as_string_array}, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result as DataFusionResult, + ScalarValue, }; use datafusion_physical_expr::{math_expressions, udf::ScalarUDF}; use num::{ @@ -616,6 +617,7 @@ fn unhex(string: &str, result: &mut Vec) -> Result<(), DataFusionError> { fn spark_unhex_inner( array: &ColumnarValue, + fail_on_error: bool, ) -> Result { let string_array = match array { ColumnarValue::Array(array) => as_generic_string_array::(array)?, @@ -639,6 +641,8 @@ fn spark_unhex_inner( if unhex(string, &mut encoded).is_ok() { builder.append_value(encoded.as_slice()); encoded.clear(); + } else if fail_on_error { + return plan_err!("Input to unhex is not a valid hex string: {:?}", string); } else { builder.append_null(); } @@ -650,19 +654,34 @@ fn spark_unhex( args: &[ColumnarValue], data_type: &DataType, ) -> Result { - if args.len() != 1 { - return internal_err!("unhex takes exactly one argument"); + if args.len() > 2 { + return plan_err!("unhex takes at most 2 arguments, but got: {}", args.len()); } let val_to_unhex = &args[0]; + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, + _ => { + return plan_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; - match data_type { - DataType::Utf8 => spark_unhex_inner::(val_to_unhex), - DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex), - _ => internal_err!( - "The first argument must be string array, but got: {:?}", - data_type - ), + match val_to_unhex.data_type() { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), 
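+ // Utf8 and LargeUtf8 differ only in the offset width of the backing
+ // array; anything else is unexpected and reported as an internal error.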
+ other => { + return internal_err!( + "The first argument must be a string scalar or array, but got: {:?}", + other + ); + } } } From 70c9dddcf07fac4ad803ae93a7b233e481375f0c Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 21:11:00 -0700 Subject: [PATCH 07/20] style: fix clippy in core --- core/src/execution/datafusion/expressions/scalar_funcs.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index a0887d710..514afe31d 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -42,7 +42,7 @@ use datafusion::{ physical_plan::ColumnarValue, }; use datafusion_common::{ - cast::{as_binary_array, as_generic_string_array, as_string_array}, + cast::{as_binary_array, as_generic_string_array}, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result as DataFusionResult, ScalarValue, }; @@ -652,7 +652,7 @@ fn spark_unhex_inner( fn spark_unhex( args: &[ColumnarValue], - data_type: &DataType, + _data_type: &DataType, ) -> Result { if args.len() > 2 { return plan_err!("unhex takes at most 2 arguments, but got: {}", args.len()); @@ -677,10 +677,10 @@ fn spark_unhex( DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), other => { - return internal_err!( + internal_err!( "The first argument must be a string scalar or array, but got: {:?}", other - ); + ) } } } From 663aef5bda593eccbead79c5b25fdd72f97fb845 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 30 Apr 2024 21:23:58 -0700 Subject: [PATCH 08/20] style: run scalacheck --- .../org/apache/comet/shims/ShimCometUnhexExpr.scala | 4 +++- .../org/apache/comet/shims/ShimCometUnhexExpr.scala | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala index 30a1cc04f..43d3343cf 100644 --- a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -19,8 +19,10 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, BooleanType} +/** + * `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x. + */ trait ShimCometUnhexExpr { def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) diff --git a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala index 612169d32..8fffb2c38 100644 --- a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -19,8 +19,10 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, BooleanType} +/** + * `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x. 
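+ * On Spark 3.4 `Unhex` carries its own `failOnError` flag, which the shim
+ * forwards to the native implementation.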
+ */ trait ShimCometUnhexExpr { def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) From bfe92c46d4174f722363aeb773f50227f1f786be Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Thu, 2 May 2024 15:25:10 -0700 Subject: [PATCH 09/20] refactor: update w/ feedback --- pom.xml | 8 ++--- .../apache/comet/serde/QueryPlanSerde.scala | 12 ++++---- .../comet/shims/ShimCometUnhexExpr.scala | 4 +-- .../comet/shims/ShimCometUnhexExpr.scala | 30 +++++++++++++++++++ .../comet/shims/ShimCometUnhexExpr.scala | 4 +-- 5 files changed, 44 insertions(+), 14 deletions(-) rename spark/src/main/{spark-3.3.x => spark-3.2}/org/apache/comet/shims/ShimCometUnhexExpr.scala (88%) create mode 100644 spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala rename spark/src/main/{spark-3.4.x => spark-3.4}/org/apache/comet/shims/ShimCometUnhexExpr.scala (89%) diff --git a/pom.xml b/pom.xml index e810632ed..5e0454589 100644 --- a/pom.xml +++ b/pom.xml @@ -89,7 +89,7 @@ under the License. spark-3.3-plus spark-3.4 spark-3.x - spark-3.4.x + spark-3.4 @@ -501,7 +501,7 @@ under the License. not-needed-yet not-needed-yet - spark-3.3.x + spark-3.2 @@ -514,7 +514,7 @@ under the License. 1.12.0 spark-3.3-plus not-needed-yet - spark-3.3.x + spark-3.3 @@ -526,7 +526,7 @@ under the License. 1.13.1 spark-3.3-plus spark-3.4 - spark-3.4.x + spark-3.4 diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e5b317813..002161f82 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -45,13 +45,13 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isC import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} -import org.apache.comet.shims.ShimCometUnhexExpr +import org.apache.comet.shims.ShimCometExpr import org.apache.comet.shims.ShimQueryPlanSerde /** * An utility object for query plan and expression serialization. 
*/ -object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometUnhexExpr { +object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometExpr { def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -1400,11 +1400,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometUnhe case e: Unhex => val unHex = unhexSerde(e) - val childCast = Cast(unHex._1, StringType) - val failOnErrorCast = Cast(unHex._2, BooleanType) + // val childCast = Cast(unHex._1, StringType) + // val failOnErrorCast = Cast(unHex._2, BooleanType) - val childExpr = exprToProtoInternal(childCast, inputs) - val failOnErrorExpr = exprToProtoInternal(failOnErrorCast, inputs) + val childExpr = exprToProtoInternal(unHex._1, inputs) + val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs) val optExpr = scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) diff --git a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/ShimCometUnhexExpr.scala similarity index 88% rename from spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala rename to spark/src/main/spark-3.2/org/apache/comet/shims/ShimCometUnhexExpr.scala index 43d3343cf..37f857f0f 100644 --- a/spark/src/main/spark-3.3.x/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -21,9 +21,9 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ /** - * `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x. + * `ShimCometExpr` parses the `Unhex` expression assuming that the catalyst version is 3.2.x. */ -trait ShimCometUnhexExpr { +trait ShimCometExpr { def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala new file mode 100644 index 000000000..0b21ed8b4 --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `ShimCometExpr` parses the `Unhex` expression assuming that the catalyst version is 3.3.x. 
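+ * `Unhex` has no `failOnError` field on this Spark version, so a constant
+ * `Literal(false)` is supplied in its place.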
+ */ +trait ShimCometExpr { + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(false)) + } +} diff --git a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometUnhexExpr.scala similarity index 89% rename from spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala rename to spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometUnhexExpr.scala index 8fffb2c38..b361b3a15 100644 --- a/spark/src/main/spark-3.4.x/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometUnhexExpr.scala @@ -21,9 +21,9 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ /** - * `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x. + * `ShimCometExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x. */ -trait ShimCometUnhexExpr { +trait ShimCometExpr { def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } From 966d30785e5d4028813e31a457523a1b35100954 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Thu, 2 May 2024 15:47:06 -0700 Subject: [PATCH 10/20] refactor: delete unused code --- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 002161f82..6a4bfb90d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1400,9 +1400,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometExpr case e: Unhex => val unHex = unhexSerde(e) - // val childCast = Cast(unHex._1, StringType) - // val failOnErrorCast = Cast(unHex._2, BooleanType) - val childExpr = exprToProtoInternal(unHex._1, inputs) val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs) From 97eae4bff81cea543ff6aae32a535f8bee4aeff9 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Thu, 2 May 2024 16:50:53 -0700 Subject: [PATCH 11/20] refactor: improve rust --- .../datafusion/expressions/scalar_funcs.rs | 133 +--------------- .../expressions/scalar_funcs/unhex.rs | 149 ++++++++++++++++++ 2 files changed, 155 insertions(+), 127 deletions(-) create mode 100644 core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 514afe31d..8c5e1f391 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -43,8 +43,7 @@ use datafusion::{ }; use datafusion_common::{ cast::{as_binary_array, as_generic_string_array}, - exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result as DataFusionResult, - ScalarValue, + exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, }; use datafusion_physical_expr::{math_expressions, udf::ScalarUDF}; use num::{ @@ -53,6 +52,9 @@ use num::{ }; use unicode_segmentation::UnicodeSegmentation; +mod unhex; +use unhex::spark_unhex; + macro_rules! 
make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -107,7 +109,8 @@ pub fn create_comet_physical_fun( make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } "unhex" => { - make_comet_scalar_udf!("unhex", spark_unhex, data_type) + let func = Arc::new(spark_unhex); + make_comet_scalar_udf!("unhex", func, without data_type) } "decimal_div" => { make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type) @@ -576,115 +579,6 @@ fn spark_rpad_internal( Ok(ColumnarValue::Array(Arc::new(result))) } -fn unhex(string: &str, result: &mut Vec) -> Result<(), DataFusionError> { - // https://docs.databricks.com/en/sql/language-manual/functions/unhex.html - // If the length of expr is odd, the first character is discarded and the result is padded with - // a null byte. If expr contains non hex characters the result is NULL. - let string = if string.len() % 2 == 1 { - &string[1..] - } else { - string - }; - - let mut iter = string.chars().peekable(); - while let Some(c) = iter.next() { - let high = if let Some(high) = c.to_digit(16) { - high - } else { - return Ok(()); - }; - - let low = iter - .next() - .ok_or_else(|| DataFusionError::Internal("Odd number of hex characters".to_string()))? - .to_digit(16); - - let low = if let Some(low) = low { - low - } else { - return Ok(()); - }; - - result.push((high << 4 | low) as u8); - } - - if string.len() % 2 == 1 { - result.push(0); - } - - Ok(()) -} - -fn spark_unhex_inner( - array: &ColumnarValue, - fail_on_error: bool, -) -> Result { - let string_array = match array { - ColumnarValue::Array(array) => as_generic_string_array::(array)?, - ColumnarValue::Scalar(ScalarValue::Utf8(Some(_string))) => { - return not_impl_err!("unhex with scalar string is not implemented yet"); - } - _ => { - return internal_err!( - "The first argument must be a string scalar or array, but got: {:?}", - array.data_type() - ); - } - }; - - let mut builder = arrow::array::BinaryBuilder::new(); - let mut encoded = Vec::new(); - - for i in 0..string_array.len() { - let string = string_array.value(i); - - if unhex(string, &mut encoded).is_ok() { - builder.append_value(encoded.as_slice()); - encoded.clear(); - } else if fail_on_error { - return plan_err!("Input to unhex is not a valid hex string: {:?}", string); - } else { - builder.append_null(); - } - } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) -} - -fn spark_unhex( - args: &[ColumnarValue], - _data_type: &DataType, -) -> Result { - if args.len() > 2 { - return plan_err!("unhex takes at most 2 arguments, but got: {}", args.len()); - } - - let val_to_unhex = &args[0]; - let fail_on_error = if args.len() == 2 { - match &args[1] { - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, - _ => { - return plan_err!( - "The second argument must be boolean scalar, but got: {:?}", - args[1] - ); - } - } - } else { - false - }; - - match val_to_unhex.data_type() { - DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), - DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), - other => { - internal_err!( - "The first argument must be a string scalar or array, but got: {:?}", - other - ) - } - } -} - // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). // Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to // get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. 
Since @@ -813,18 +707,3 @@ fn wrap_digest_result_as_hex_string( } } } - -#[cfg(test)] -mod test { - use super::unhex; - - #[test] - fn test_unhex() { - let mut result = Vec::new(); - - unhex("537061726B2053514C", &mut result).unwrap(); - let result_str = std::str::from_utf8(&result).unwrap(); - assert_eq!(result_str, "Spark SQL"); - result.clear(); - } -} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs new file mode 100644 index 000000000..5a7a16aba --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{Array, OffsetSizeTrait}; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue}; + +fn unhex(string: &str, result: &mut Vec) -> Result<(), DataFusionError> { + if string.is_empty() { + return Ok(()); + } + + // Adjust the string if it has an odd length, and prepare to add a padding byte if needed. + let needs_padding = string.len() % 2 != 0; + let adjusted_string = if needs_padding { &string[1..] 
} else { string }; + + let mut iter = adjusted_string.chars().peekable(); + while let (Some(high_char), Some(low_char)) = (iter.next(), iter.next()) { + let high = high_char + .to_digit(16) + .ok_or_else(|| DataFusionError::Internal("Invalid hex character".to_string()))?; + let low = low_char + .to_digit(16) + .ok_or_else(|| DataFusionError::Internal("Invalid hex character".to_string()))?; + + result.push((high << 4 | low) as u8); + } + + if needs_padding { + result.push(0); + } + + Ok(()) +} + +fn spark_unhex_inner( + array: &ColumnarValue, + fail_on_error: bool, +) -> Result { + match array { + ColumnarValue::Array(array) => { + let string_array = as_generic_string_array::(array)?; + + let mut builder = arrow::array::BinaryBuilder::new(); + let mut encoded = Vec::new(); + + for i in 0..string_array.len() { + let string = string_array.value(i); + + if unhex(string, &mut encoded).is_ok() { + builder.append_value(encoded.as_slice()); + encoded.clear(); + } else if fail_on_error { + return exec_err!("Input to unhex is not a valid hex string: {string}"); + } else { + builder.append_null(); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => { + let mut encoded = Vec::new(); + + if unhex(string, &mut encoded).is_ok() { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded)))) + } else if fail_on_error { + exec_err!("Input to unhex is not a valid hex string: {string}") + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + } + _ => { + exec_err!( + "The first argument must be a string scalar or array, but got: {:?}", + array + ) + } + } +} + +pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len()); + } + + let val_to_unhex = &args[0]; + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, + _ => { + return exec_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match val_to_unhex.data_type() { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + other => exec_err!( + "The first argument must be a string scalar or array, but got: {:?}", + other + ), + } +} + +#[cfg(test)] +mod test { + use super::unhex; + + #[test] + fn test_unhex() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("537061726B2053514C", &mut result)?; + let result_str = std::str::from_utf8(&result)?; + assert_eq!(result_str, "Spark SQL"); + result.clear(); + + assert!(unhex("hello", &mut result).is_err()); + result.clear(); + + unhex("", &mut result)?; + assert!(result.is_empty()); + + Ok(()) + } +} From a378f7433836d81543c4e6b8ce6c108934a893c5 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 3 May 2024 12:23:40 -0700 Subject: [PATCH 12/20] refactor: rename to CometExprShim and update docs --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 4 ++-- .../{ShimCometUnhexExpr.scala => CometExprShim.scala} | 7 +++++-- .../{ShimCometUnhexExpr.scala => CometExprShim.scala} | 7 +++++-- .../{ShimCometUnhexExpr.scala => CometExprShim.scala} | 7 +++++-- 4 files changed, 17 insertions(+), 8 deletions(-) rename spark/src/main/spark-3.2/org/apache/comet/shims/{ShimCometUnhexExpr.scala => CometExprShim.scala} (83%) rename 
spark/src/main/spark-3.3/org/apache/comet/shims/{ShimCometUnhexExpr.scala => CometExprShim.scala} (83%) rename spark/src/main/spark-3.4/org/apache/comet/shims/{ShimCometUnhexExpr.scala => CometExprShim.scala} (83%) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6a4bfb90d..16361334e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -45,13 +45,13 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isC import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} -import org.apache.comet.shims.ShimCometExpr +import org.apache.comet.shims.CometExprShim import org.apache.comet.shims.ShimQueryPlanSerde /** * An utility object for query plan and expression serialization. */ -object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometExpr { +object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim { def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala similarity index 83% rename from spark/src/main/spark-3.2/org/apache/comet/shims/ShimCometUnhexExpr.scala rename to spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index 37f857f0f..0c45a9c2c 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -21,9 +21,12 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ /** - * `ShimCometExpr` parses the `Unhex` expression assuming that the catalyst version is 3.2.x. + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. */ -trait ShimCometExpr { +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala similarity index 83% rename from spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala rename to spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index 0b21ed8b4..0c45a9c2c 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -21,9 +21,12 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ /** - * `ShimCometExpr` parses the `Unhex` expression assuming that the catalyst version is 3.3.x. + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. */ -trait ShimCometExpr { +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. 
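+ * The tuple is `(child, failOnError)`, the two arguments the native
+ * unhex kernel expects.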
+ */ def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometUnhexExpr.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala similarity index 83% rename from spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometUnhexExpr.scala rename to spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index b361b3a15..409e1c94b 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometUnhexExpr.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -21,9 +21,12 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ /** - * `ShimCometExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x. + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. */ -trait ShimCometExpr { +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } From 112c7c61db86f9a7383b0667af846ac02b1703c7 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 3 May 2024 12:37:13 -0700 Subject: [PATCH 13/20] refactor: rename majorSource to majorVerSrc, same with minor --- common/pom.xml | 4 ++-- pom.xml | 10 +++++----- spark/pom.xml | 4 ++-- .../org/apache/comet/CometExpressionSuite.scala | 14 +++++--------- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/common/pom.xml b/common/pom.xml index ac9d136b3..cc1f44481 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -179,8 +179,8 @@ under the License. - src/main/${shims.majorSource} - src/main/${shims.minorSource} + src/main/${shims.majorVerSrc} + src/main/${shims.minorVerSrc} diff --git a/pom.xml b/pom.xml index 5e0454589..3ffdda97a 100644 --- a/pom.xml +++ b/pom.xml @@ -88,8 +88,8 @@ under the License. -ea -Xmx4g -Xss4m ${extraJavaTestArgs} spark-3.3-plus spark-3.4 - spark-3.x - spark-3.4 + spark-3.x + spark-3.4 @@ -501,7 +501,7 @@ under the License. not-needed-yet not-needed-yet - spark-3.2 + spark-3.2 @@ -514,7 +514,7 @@ under the License. 1.12.0 spark-3.3-plus not-needed-yet - spark-3.3 + spark-3.3 @@ -526,7 +526,7 @@ under the License. 1.13.1 spark-3.3-plus spark-3.4 - spark-3.4 + spark-3.4 diff --git a/spark/pom.xml b/spark/pom.xml index 8c900451a..f8fe68221 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -258,8 +258,8 @@ under the License. 
- src/main/${shims.majorSource} - src/main/${shims.minorSource} + src/main/${shims.majorVerSrc} + src/main/${shims.minorVerSrc} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index b59dd7324..62e26247c 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1026,15 +1026,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("unhex") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col string) using parquet") - sql(s"insert into $table values('537061726B2053514C')") - checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") - } - } + val table = "test" + withTable(table) { + sql(s"create table $table(col string) using parquet") + sql(s"insert into $table values('537061726B2053514C')") + checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") } } From 6146f3e1d2e7268d2e32d5d046b95530b00ae919 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sat, 4 May 2024 11:24:20 -0700 Subject: [PATCH 14/20] refactor: import unhex impl and testing --- .../expressions/scalar_funcs/unhex.rs | 94 ++++++++++++++----- 1 file changed, 73 insertions(+), 21 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs index 5a7a16aba..c262ac379 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs @@ -22,29 +22,40 @@ use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue}; -fn unhex(string: &str, result: &mut Vec) -> Result<(), DataFusionError> { - if string.is_empty() { - return Ok(()); +/// Helper function to convert a hex digit to a binary value. Returns None if the input is not a +/// valid hex digit. +fn unhex_digit(c: u8) -> Result { + match c { + b'0'..=b'9' => Ok(c - b'0'), + b'A'..=b'F' => Ok(10 + c - b'A'), + b'a'..=b'f' => Ok(10 + c - b'a'), + _ => Err(DataFusionError::Execution( + "Input to unhex_digit is not a valid hex digit".to_string(), + )), } +} + +/// Convert a hex string to binary and store the result in `result`. Returns an error if the input +/// is not a valid hex string. +fn unhex(hex_str: &str, result: &mut Vec) -> Result<(), DataFusionError> { + let bytes = hex_str.as_bytes(); - // Adjust the string if it has an odd length, and prepare to add a padding byte if needed. - let needs_padding = string.len() % 2 != 0; - let adjusted_string = if needs_padding { &string[1..] 
} else { string }; + let mut i = 0; - let mut iter = adjusted_string.chars().peekable(); - while let (Some(high_char), Some(low_char)) = (iter.next(), iter.next()) { - let high = high_char - .to_digit(16) - .ok_or_else(|| DataFusionError::Internal("Invalid hex character".to_string()))?; - let low = low_char - .to_digit(16) - .ok_or_else(|| DataFusionError::Internal("Invalid hex character".to_string()))?; + if (bytes.len() & 0x01) != 0 { + let v = unhex_digit(bytes[0])?; - result.push((high << 4 | low) as u8); + result.push(v); + i += 1; } - if needs_padding { - result.push(0); + while i < bytes.len() { + let first = unhex_digit(bytes[i])?; + let second = unhex_digit(bytes[i + 1])?; + // result.push(((first << 4) | second) & 0xFF); + result.push((first << 4) | second); + + i += 2; } Ok(()) @@ -130,7 +141,7 @@ mod test { use super::unhex; #[test] - fn test_unhex() -> Result<(), Box> { + fn test_unhex_valid() -> Result<(), Box> { let mut result = Vec::new(); unhex("537061726B2053514C", &mut result)?; @@ -138,12 +149,53 @@ mod test { assert_eq!(result_str, "Spark SQL"); result.clear(); - assert!(unhex("hello", &mut result).is_err()); + unhex("1C", &mut result)?; + assert_eq!(result, vec![28]); result.clear(); - unhex("", &mut result)?; - assert!(result.is_empty()); + unhex("737472696E67", &mut result)?; + assert_eq!(result, "string".as_bytes()); + result.clear(); + + unhex("1", &mut result)?; + assert_eq!(result, vec![1]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_odd_length() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + unhex("0A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); Ok(()) } + + #[test] + fn test_unhex_empty() { + let mut result = Vec::new(); + + // Empty hex string + unhex("", &mut result).unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn test_unhex_invalid() { + let mut result = Vec::new(); + + // Invalid hex strings + assert!(unhex("##", &mut result).is_err()); + assert!(unhex("G123", &mut result).is_err()); + assert!(unhex("hello", &mut result).is_err()); + assert!(unhex("\0", &mut result).is_err()); + } } From 1de0887a015079a7fd059af63ba2b3b13f193084 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sat, 4 May 2024 12:53:38 -0700 Subject: [PATCH 15/20] tests: improve spark tests, better null handling --- .../expressions/scalar_funcs/unhex.rs | 51 +++++++++++++++---- .../apache/comet/CometExpressionSuite.scala | 17 +++++-- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs index c262ac379..a7c8244b7 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use arrow_array::{Array, OffsetSizeTrait}; +use arrow_array::OffsetSizeTrait; use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue}; @@ -52,7 +52,6 @@ fn unhex(hex_str: &str, result: &mut Vec) -> Result<(), DataFusionError> { while i < bytes.len() { let first = unhex_digit(bytes[i])?; let second = unhex_digit(bytes[i + 1])?; - // result.push(((first << 4) | second) & 0xFF); result.push((first << 4) | second); i += 2; @@ -69,17 +68,19 @@ fn spark_unhex_inner( ColumnarValue::Array(array) => { let 
From 1de0887a015079a7fd059af63ba2b3b13f193084 Mon Sep 17 00:00:00 2001
From: Trent Hauck
Date: Sat, 4 May 2024 12:53:38 -0700
Subject: [PATCH 15/20] tests: improve spark tests, better null handling

---
 .../expressions/scalar_funcs/unhex.rs         | 51 +++++++++++++++----
 .../apache/comet/CometExpressionSuite.scala   | 17 +++++--
 2 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
index c262ac379..a7c8244b7 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
@@ -17,7 +17,7 @@
 
 use std::sync::Arc;
 
-use arrow_array::{Array, OffsetSizeTrait};
+use arrow_array::OffsetSizeTrait;
 use arrow_schema::DataType;
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};
@@ -52,7 +52,6 @@ fn unhex(hex_str: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
     while i < bytes.len() {
         let first = unhex_digit(bytes[i])?;
         let second = unhex_digit(bytes[i + 1])?;
-        // result.push(((first << 4) | second) & 0xFF);
         result.push((first << 4) | second);
 
         i += 2;
@@ -69,17 +68,19 @@ fn spark_unhex_inner<T: OffsetSizeTrait>(
         ColumnarValue::Array(array) => {
             let string_array = as_generic_string_array::<T>(array)?;
 
-            let mut builder = arrow::array::BinaryBuilder::new();
             let mut encoded = Vec::new();
+            let mut builder = arrow::array::BinaryBuilder::new();
 
-            for i in 0..string_array.len() {
-                let string = string_array.value(i);
-
-                if unhex(string, &mut encoded).is_ok() {
-                    builder.append_value(encoded.as_slice());
-                    encoded.clear();
-                } else if fail_on_error {
-                    return exec_err!("Input to unhex is not a valid hex string: {string}");
+            for item in string_array.iter() {
+                if let Some(s) = item {
+                    if unhex(s, &mut encoded).is_ok() {
+                        builder.append_value(encoded.as_slice());
+                        encoded.clear();
+                    } else if fail_on_error {
+                        return exec_err!("Input to unhex is not a valid hex string: {s}");
+                    } else {
+                        builder.append_null();
+                    }
                 } else {
                     builder.append_null();
                 }
@@ -97,6 +98,9 @@ fn spark_unhex_inner<T: OffsetSizeTrait>(
                 Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
             }
         }
+        ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
+            Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
+        }
         _ => {
             exec_err!(
                 "The first argument must be a string scalar or array, but got: {:?}",
@@ -138,8 +142,33 @@ pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
 mod test {
-    use super::unhex;
+    use std::sync::Arc;
+
+    use arrow::array::{make_array, ArrayData};
+    use datafusion::logical_expr::ColumnarValue;
+    use datafusion_common::ScalarValue;
+
+    use super::{spark_unhex, unhex};
+
+    #[test]
+    fn test_unhex_null() -> Result<(), Box<dyn std::error::Error>> {
+        let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2);
+        let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2);
+
+        let input = ColumnarValue::Array(Arc::new(make_array(input)));
+        let expected = ColumnarValue::Array(Arc::new(make_array(output)));
+
+        let result = super::spark_unhex(&[input])?;
+
+        match (result, expected) {
+            (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => {
+                assert_eq!(*result, *expected);
+                Ok(())
+            }
+            _ => Err("Unexpected result type".into()),
+        }
+    }
+
     #[test]
     fn test_unhex_valid() -> Result<(), Box<dyn std::error::Error>> {
         let mut result = Vec::new();

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 62e26247c..bd56ea2df 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1024,16 +1024,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
-
   test("unhex") {
-    val table = "test"
+    val table = "unhex_table"
     withTable(table) {
       sql(s"create table $table(col string) using parquet")
-      sql(s"insert into $table values('537061726B2053514C')")
+
+      sql(s"""INSERT INTO $table VALUES
+        |('537061726B2053514C'),
+        |('737472696E67'),
+        |('\0'),
+        |(''),
+        |('###'),
+        |('G123'),
+        |('hello'),
+        |('A1B'),
+        |('0A1B')""".stripMargin)
+
      checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table")
     }
   }
-
   test("length, reverse, instr, replace, translate") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
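Note: the null-handling contract this patch pins down (null element in, null element out; invalid hex with fail_on_error off, null out) is mostly Arrow builder mechanics. A hedged standalone sketch of the same pattern, using the arrow crate directly with `is_ascii_hexdigit` as a stand-in validator rather than Comet's real unhex (crate version assumed compatible with what Comet pins):

    use arrow::array::{Array, BinaryBuilder, StringArray};

    fn main() {
        // One valid row, one null row, one invalid row.
        let input = StringArray::from(vec![Some("1C"), None, Some("GG")]);
        let mut builder = BinaryBuilder::new();

        for item in input.iter() {
            match item {
                // stand-in for a successful unhex: pass the bytes through
                Some(s) if s.bytes().all(|b| b.is_ascii_hexdigit()) => {
                    builder.append_value(s.as_bytes())
                }
                // invalid hex (fail_on_error = false) and nulls both yield null
                _ => builder.append_null(),
            }
        }

        let out = builder.finish();
        assert!(!out.is_null(0));
        assert!(out.is_null(1));
        assert!(out.is_null(2));
    }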
From bd07fed3f307db60bb7944b63995c02d4e69e7d1 Mon Sep 17 00:00:00 2001
From: Trent Hauck
Date: Sat, 4 May 2024 13:31:41 -0700
Subject: [PATCH 16/20] docs: tweak docs

---
 .../execution/datafusion/expressions/scalar_funcs/unhex.rs | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
index a7c8244b7..2726fe310 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
@@ -22,8 +22,7 @@ use arrow_schema::DataType;
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};
 
-/// Helper function to convert a hex digit to a binary value. Returns None if the input is not a
-/// valid hex digit.
+/// Helper function to convert a hex digit to a binary value.
 fn unhex_digit(c: u8) -> Result<u8, DataFusionError> {
     match c {
         b'0'..=b'9' => Ok(c - b'0'),
@@ -134,7 +133,7 @@ pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
         DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
         DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
         other => exec_err!(
-            "The first argument must be a string scalar or array, but got: {:?}",
+            "The first argument must be a Utf8 or LargeUtf8: {:?}",
             other
         ),
     }

From 36baf8e87dd2689baf8c45961c7ed34a97f1f493 Mon Sep 17 00:00:00 2001
From: Trent Hauck
Date: Tue, 7 May 2024 18:46:39 -0700
Subject: [PATCH 17/20] build: don't use unhex on spark 3.2

---
 .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +-
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 16361334e..c320b581f 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1397,7 +1397,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
         optExprWithInfo(optExpr, expr, left, right)
 
-      case e: Unhex =>
+      case e: Unhex if !isSpark32 =>
         val unHex = unhexSerde(e)
 
         val childExpr = exprToProtoInternal(unHex._1, inputs)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index bd56ea2df..eeccc8d24 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1025,6 +1025,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
   test("unhex") {
+    assume(!isSpark32, "unhex function has incorrect behavior in 3.2")
+
     val table = "unhex_table"
     withTable(table) {
       sql(s"create table $table(col string) using parquet")

From d5a1c46a4cbe0832ecf0ed373c2c75c35769f961 Mon Sep 17 00:00:00 2001
From: Trent Hauck
Date: Wed, 8 May 2024 07:03:30 -0700
Subject: [PATCH 18/20] fix: escape null byte for scala 2.13

---
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index eeccc8d24..4f1ae8f1a 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1034,7 +1034,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       sql(s"""INSERT INTO $table VALUES
         |('537061726B2053514C'),
         |('737472696E67'),
-        |('\0'),
+        |('\\0'),
         |(''),
         |('###'),
         |('G123'),

From fb1c24a7b45c0be5dc70f03ab5fb37ea7147b2ce Mon Sep 17 00:00:00 2001
From: Trent Hauck
Date: Wed, 8 May 2024 07:05:09 -0700
Subject: [PATCH 19/20] docs: better docs around why unhex test is skipped

Co-authored-by: Andy Grove
---
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 4f1ae8f1a..82d889c22 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1025,6 +1025,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
   test("unhex") {
+    // When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that
+    // was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not
+    // the same (and this only applies to edge cases with hex inputs with lengths that are not divisible by 2)
     assume(!isSpark32, "unhex function has incorrect behavior in 3.2")
 
     val table = "unhex_table"
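Note: to make the edge case in that comment concrete, the two plausible readings of an odd-length input only disagree on where the lone digit goes. The sketch below contrasts the Spark 3.3+ reading that the tests above pin down (a leading lone digit decodes to its own byte) with a drop-and-pad reading (an assumption for illustration, not taken from Spark's source):

    fn hex_val(c: u8) -> u8 {
        match c {
            b'0'..=b'9' => c - b'0',
            b'A'..=b'F' => 10 + c - b'A',
            b'a'..=b'f' => 10 + c - b'a',
            _ => panic!("not a hex digit"),
        }
    }

    // Spark 3.3+ reading, matching test_odd_length above.
    fn unhex_leading_nibble(s: &str) -> Vec<u8> {
        let b = s.as_bytes();
        let mut out = Vec::new();
        let mut i = 0;
        if b.len() % 2 == 1 {
            out.push(hex_val(b[0]));
            i = 1;
        }
        while i < b.len() {
            out.push((hex_val(b[i]) << 4) | hex_val(b[i + 1]));
            i += 2;
        }
        out
    }

    // Hypothetical drop-and-pad reading: discard the first digit and
    // append a trailing null byte.
    fn unhex_drop_and_pad(s: &str) -> Vec<u8> {
        let b = if s.len() % 2 == 1 { &s.as_bytes()[1..] } else { s.as_bytes() };
        let mut out: Vec<u8> = b
            .chunks(2)
            .map(|p| (hex_val(p[0]) << 4) | hex_val(p[1]))
            .collect();
        if s.len() % 2 == 1 {
            out.push(0);
        }
        out
    }

    fn main() {
        // The readings agree on even-length input and differ otherwise.
        assert_eq!(unhex_leading_nibble("1C"), unhex_drop_and_pad("1C"));
        assert_eq!(unhex_leading_nibble("A1B"), vec![0x0A, 0x1B]);
        assert_eq!(unhex_drop_and_pad("A1B"), vec![0x1B, 0x00]);
    }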
From c5c3fcd51edddba44d1dd3b5ad176066ece0a5eb Mon Sep 17 00:00:00 2001
From: Trent Hauck
Date: Wed, 8 May 2024 15:28:08 -0700
Subject: [PATCH 20/20] fix: clear result vec of incomplete conversion on fail

---
 .../expressions/scalar_funcs/unhex.rs | 30 ++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
index 2726fe310..38d5c0478 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
@@ -74,12 +74,12 @@ fn spark_unhex_inner<T: OffsetSizeTrait>(
             if let Some(s) = item {
                 if unhex(s, &mut encoded).is_ok() {
                     builder.append_value(encoded.as_slice());
-                    encoded.clear();
                 } else if fail_on_error {
                     return exec_err!("Input to unhex is not a valid hex string: {s}");
                 } else {
                     builder.append_null();
                 }
+                encoded.clear();
             } else {
                 builder.append_null();
             }
@@ -143,9 +143,11 @@ pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
 mod test {
-    use arrow::array::{make_array, ArrayData};
+    use arrow::array::{make_array, ArrayData, BinaryBuilder, StringBuilder};
+    use arrow_schema::DataType;
     use datafusion::logical_expr::ColumnarValue;
     use datafusion_common::ScalarValue;
 
+    #[test]
+    fn test_partial_error() -> Result<(), Box<dyn std::error::Error>> {
+        let mut input = StringBuilder::new();
+
+        input.append_value("1CGG"); // 1C is ok, but GG is invalid
+        input.append_value("537061726B2053514C"); // followed by valid
+
+        let input = ColumnarValue::Array(Arc::new(input.finish()));
+        let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));
+
+        let result = super::spark_unhex(&[input, fail_on_error])?;
+
+        let mut expected = BinaryBuilder::new();
+        expected.append_null();
+        expected.append_value("Spark SQL".as_bytes());
+
+        match (result, ColumnarValue::Array(Arc::new(expected.finish()))) {
+            (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => {
+                assert_eq!(*result, *expected);
+
+                Ok(())
+            }
+            _ => Err("Unexpected result type".into()),
+        }
+    }
+
     #[test]
     fn test_unhex_valid() -> Result<(), Box<dyn std::error::Error>> {
         let mut result = Vec::new();
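Note: the hazard this last patch addresses is easy to reproduce in isolation. `unhex` writes decoded bytes into the shared scratch buffer before it can discover an invalid digit, so clearing only on the success path leaves stale bytes behind for the next row — exactly what `test_partial_error` guards against. A minimal standalone sketch of that failure mode (stand-in decoder with even-length inputs, not the Comet source):

    fn unhex_digit(c: u8) -> Result<u8, ()> {
        match c {
            b'0'..=b'9' => Ok(c - b'0'),
            b'A'..=b'F' => Ok(10 + c - b'A'),
            b'a'..=b'f' => Ok(10 + c - b'a'),
            _ => Err(()),
        }
    }

    /// Decodes pairs of hex digits, pushing bytes as it goes; on an invalid
    /// digit it bails out with the partial output already in `out`.
    fn unhex(s: &str, out: &mut Vec<u8>) -> Result<(), ()> {
        for pair in s.as_bytes().chunks(2) {
            out.push((unhex_digit(pair[0])? << 4) | unhex_digit(pair[1])?);
        }
        Ok(())
    }

    fn main() {
        let mut encoded = Vec::new();

        // "1CGG" decodes 0x1C before failing on 'G', leaving a stale byte:
        assert!(unhex("1CGG", &mut encoded).is_err());
        assert_eq!(encoded, vec![0x1C]);

        // Without clearing here, the next row would start with that 0x1C.
        encoded.clear();
        unhex("537061726B2053514C", &mut encoded).unwrap();
        assert_eq!(encoded, b"Spark SQL");
    }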