diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index dc333e8be..2d2e64b9d 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -58,6 +58,9 @@ use unhex::spark_unhex; mod hex; use hex::spark_hex; +mod chr; +use chr::spark_chr; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -130,6 +133,10 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_xxhash64); make_comet_scalar_udf!("xxhash64", func, without data_type) } + "chr" => { + let func = Arc::new(spark_chr); + make_comet_scalar_udf!("chr", func, without data_type) + } sha if sha2_functions.contains(&sha) => { // Spark requires hex string as the result of sha2 functions, we have to wrap the // result of digest functions as hex string diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/chr.rs b/core/src/execution/datafusion/expressions/scalar_funcs/chr.rs new file mode 100644 index 000000000..53d498443 --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/chr.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ArrayRef, StringArray}, + datatypes::{ + DataType, + DataType::{Int64, Utf8}, + }, +}; + +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{cast::as_int64_array, exec_err, DataFusionError, Result}; + +/// chr(65) = 'A' + +// Compatible with Apache Spark's Chr function +pub fn spark_chr(args: &[ColumnarValue]) -> Result { + let chr_func = ChrFunc::default(); + chr_func.invoke(args) +} + +pub fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| match core::char::from_u32(integer as u32) { + Some(integer) => Ok(integer.to_string()), + None => { + exec_err!("requested character too large for encoding.") + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub struct ChrFunc { + signature: Signature, +} + +impl Default for ChrFunc { + fn default() -> Self { + Self::new() + } +} + +impl ChrFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ChrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "chr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let array = args[0].clone(); + match array { + ColumnarValue::Array(array) => { + let array = chr(&[array])?; + Ok(ColumnarValue::Array(array)) + } + _ => { + exec_err!("The first argument must be an array, but got: {:?}", array) + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index cb1dea847..c04b53d77 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -981,6 +981,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("Chr with null character") { + // test compatibility with Spark, spark supports chr(0) Seq(false, true).foreach { dictionary => withSQLConf( "parquet.enable.dictionary" -> dictionary.toString, @@ -993,10 +994,14 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { withTable(table) { sql(s"create table $table(c9 int, c4 int) using parquet") sql(s"insert into $table values(0, 0), (66, null), (null, 70), (null, null)") - checkSparkMaybeThrows(sql(s"SELECT chr(c9), chr(c4) FROM $table")) match { + val query = s"SELECT chr(c9), chr(c4) FROM $table" + checkSparkMaybeThrows(sql(query)) match { case (None, None) => {} - case (_, _) => fail("Expected no exception") + case (Some(e), None) => throw e + case (None, Some(e)) => throw e + case (Some(e1), Some(e2)) => throw new Exception(s"$e1, $e2") } + checkSparkAnswerAndOperator(query) } } }