Skip to content

Commit

Permalink
simplifying the scalar/array handling fn
Browse files Browse the repository at this point in the history
  • Loading branch information
vaibhawvipul committed Jun 9, 2024
1 parent a3ab124 commit c80ac59
Showing 1 changed file with 17 additions and 35 deletions.
52 changes: 17 additions & 35 deletions core/src/execution/datafusion/expressions/scalar_funcs/chr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,44 +94,26 @@ impl ScalarUDFImpl for ChrFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(chr)(args)
handle_chr_fn(args)
}
}

/// The make_scalar_function function is a higher-order function that:
/// - Takes a function inner designed to operate on arrays.
/// - Wraps this function in a closure that can accept a mix of scalar and array inputs.
/// - Converts scalar inputs to arrays, calls the inner function, and then converts the result back
/// to a scalar if the original inputs were all scalars.
///
/// taken from datafusion utils
fn make_scalar_function<F>(inner: F) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
move |args: &[ColumnarValue]| {
// first, identify if any of the arguments is an Array. If yes, store its `len`,
// as any scalar will need to be converted to an array of len `len`.
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let is_scalar = len.is_none();

let args = ColumnarValue::values_to_arrays(args)?;

let result = (inner)(&args);

if is_scalar {
// If all inputs are scalar, keeps output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
result.map(ColumnarValue::Scalar)
} else {
result.map(ColumnarValue::Array)
fn handle_chr_fn(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let array = args[0].clone();
match array {
ColumnarValue::Array(array) => {
let array = chr(&[array])?;
Ok(ColumnarValue::Array(array))
}
ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => {
match core::char::from_u32(value as u32) {
Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(ch.to_string())))),
None => exec_err!("requested character too large for encoding."),
}
}
ColumnarValue::Scalar(ScalarValue::Int64(None)) => {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
}
_ => exec_err!("The argument must be an Int64 array or scalar."),
}
}

0 comments on commit c80ac59

Please sign in to comment.