Skip to content

Commit

Permalink
Move from invoke to invoke batch
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-isaacs committed Nov 19, 2024
1 parent ae73371 commit 6b3db8c
Show file tree
Hide file tree
Showing 121 changed files with 921 additions and 462 deletions.
6 changes: 5 additions & 1 deletion datafusion-examples/examples/advanced_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ impl ScalarUDFImpl for PowUdf {
///
/// However, it also means the implementation is more complex than when
/// using `create_udf`.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
// DataFusion has arranged for the correct inputs to be passed to this
// function, but we check again to make sure
assert_eq!(args.len(), 2);
Expand Down
6 changes: 5 additions & 1 deletion datafusion-examples/examples/optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,11 @@ impl ScalarUDFImpl for MyEq {
Ok(DataType::Boolean)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
// this example simply returns "true" which is not what a real
// implementation would do.
Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1382,7 +1382,11 @@ mod tests {
Ok(DataType::Int32)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!("DummyUDF::invoke")
}
}
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/fuzz_cases/equivalence/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,11 @@ impl ScalarUDFImpl for TestScalarUDF {
Ok(input[0].sort_properties)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

let arr: ArrayRef = match args[0].data_type() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,6 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
Ok(self.return_type.clone())
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!("index_with_offset function does not accept arguments")
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand Down Expand Up @@ -720,7 +716,11 @@ impl ScalarUDFImpl for CastToI64UDF {
Ok(ExprSimplifyResult::Simplified(new_expr))
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!("Function should have been simplified prior to evaluation")
}
}
Expand Down Expand Up @@ -848,7 +848,11 @@ impl ScalarUDFImpl for TakeUDF {
}

// The actual implementation
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
let take_idx = match &args[2] {
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
_ => unreachable!(),
Expand Down Expand Up @@ -956,7 +960,11 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
Ok(self.return_type.clone())
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
internal_err!("This function should not get invoked!")
}

Expand Down Expand Up @@ -1240,7 +1248,11 @@ impl ScalarUDFImpl for MyRegexUdf {
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
match args {
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2355,7 +2355,7 @@ mod test {
use crate::expr_fn::col;
use crate::{
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
ScalarUDF, ScalarUDFImpl, Volatility,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
};
use sqlparser::ast;
use sqlparser::ast::{Ident, IdentWithAlias};
Expand Down Expand Up @@ -2484,7 +2484,7 @@ mod test {
Ok(DataType::Utf8)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::function::{
};
use crate::{
conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
};
use crate::{
Expand Down Expand Up @@ -462,8 +462,8 @@ impl ScalarUDFImpl for SimpleScalarUDF {
Ok(self.return_type.clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
(self.fun)(args)
fn invoke(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
(self.fun)(args.args.as_slice())
}
}

Expand Down
104 changes: 17 additions & 87 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
use crate::expr::schema_name_from_exprs_comma_seperated_without_space;
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
use crate::sort_properties::{ExprProperties, SortProperties};
use crate::{
ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature,
};
use crate::{ColumnarValue, Documentation, Expr, Signature};
use arrow::datatypes::DataType;
use datafusion_common::{not_impl_err, ExprSchema, Result};
use datafusion_expr_common::interval_arithmetic::Interval;
Expand Down Expand Up @@ -203,12 +201,6 @@ impl ScalarUDF {
self.inner.simplify(args, info)
}

#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke(args)
}

pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
self.inner.is_nullable(args, schema)
}
Expand All @@ -225,27 +217,9 @@ impl ScalarUDF {

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for more details.
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_no_args(number_rows)
}

/// Returns a `ScalarFunctionImplementation` that can invoke the function
/// during execution
#[deprecated(since = "42.0.0", note = "Use `invoke_batch` instead")]
pub fn fun(&self) -> ScalarFunctionImplementation {
let captured = Arc::clone(&self.inner);
#[allow(deprecated)]
Arc::new(move |args| captured.invoke(args))
/// See [`ScalarUDFImpl::invoke`] for more details.
pub fn invoke(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke(args)
}

/// Get the circuits of inner implementation
Expand Down Expand Up @@ -329,7 +303,7 @@ where

pub struct ScalarFunctionArgs<'a> {
// The evaluated arguments to the function
pub args: &'a [ColumnarValue],
pub args: Vec<ColumnarValue>,
// The number of rows in record batch being evaluated
pub number_rows: usize,
// The return type of the scalar function (as returned by `return_type` or `return_type_from_exprs`)
Expand All @@ -353,7 +327,7 @@ pub struct ScalarFunctionArgs<'a> {
/// # use std::sync::OnceLock;
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
///
Expand Down Expand Up @@ -396,7 +370,7 @@ pub struct ScalarFunctionArgs<'a> {
/// Ok(DataType::Int32)
/// }
/// // The actual implementation would add one to the argument
/// fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
/// fn invoke(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { unimplemented!() }
/// fn documentation(&self) -> Option<&Documentation> {
/// Some(get_doc())
/// }
Expand Down Expand Up @@ -490,33 +464,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
true
}

/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// If the function does not take any arguments, please use [invoke_no_args]
/// instead and return [not_impl_err] for this function.
///
///
/// # Performance
///
/// For the best performance, the implementations of `invoke` should handle
/// the common case when one or more of their arguments are constant values
/// (aka [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke but called",
self.name()
)
}

/// Invoke the function with `args` and the number of rows,
/// returning the appropriate result.
///
Expand All @@ -531,24 +478,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
#[deprecated(since = "43.0.0", note = "Use `invoke` instead")]
fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
match args.is_empty() {
true =>
{
#[allow(deprecated)]
self.invoke_no_args(number_rows)
}
false =>
{
#[allow(deprecated)]
self.invoke(args)
}
}
not_impl_err!(
"invoke_batch, this method is deprecated implement `invoke` instead"
)
}

/// Invoke the function with `args: ScalarFunctionArgs` returning the appropriate result.
Expand All @@ -563,19 +501,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
/// Note that this `invoke` method replaces the original `invoke` function, which was
/// deprecated in version 42.1.0.
fn invoke(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.invoke_batch(args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke_no_args but called",
self.name()
)
self.invoke_batch(args.args.as_slice(), args.number_rows)
}

/// Returns any aliases (alternate names) for this function.
Expand Down
18 changes: 15 additions & 3 deletions datafusion/functions-nested/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ impl ScalarUDFImpl for ArrayHas {
Ok(DataType::Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
match &args[1] {
ColumnarValue::Array(array_needle) => {
// the needle is already an array, convert the haystack to an array of the same length
Expand Down Expand Up @@ -322,7 +326,11 @@ impl ScalarUDFImpl for ArrayHasAll {
Ok(DataType::Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
make_scalar_function(array_has_all_inner)(args)
}

Expand Down Expand Up @@ -403,7 +411,11 @@ impl ScalarUDFImpl for ArrayHasAny {
Ok(DataType::Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
make_scalar_function(array_has_any_inner)(args)
}

Expand Down
6 changes: 5 additions & 1 deletion datafusion/functions-nested/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ impl ScalarUDFImpl for Cardinality {
})
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
make_scalar_function(cardinality_inner)(args)
}

Expand Down
Loading

0 comments on commit 6b3db8c

Please sign in to comment.