Skip to content

Commit

Permalink
Return error for unresolved scalar function
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 13, 2023
1 parent ffe1756 commit 66304c0
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 12 deletions.
62 changes: 55 additions & 7 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,17 @@ impl ScalarFunctionDefinition {

/// Whether this function is volatile, i.e. whether it can return different results
/// when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> bool {
pub fn is_volatile(&self) -> Result<bool> {
match self {
ScalarFunctionDefinition::BuiltIn(fun) => {
fun.volatility() == crate::Volatility::Volatile
Ok(fun.volatility() == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::UDF(udf) => {
udf.signature().volatility == crate::Volatility::Volatile
Ok(udf.signature().volatility == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Cannot determine volatility of unresolved function")
}
ScalarFunctionDefinition::Name(_) => false,
}
}
}
Expand Down Expand Up @@ -1708,21 +1710,26 @@ fn create_names(exprs: &[Expr]) -> Result<String> {

/// Whether the given expression is volatile, i.e. whether it can return different results
/// when evaluated multiple times with the same input.
pub fn is_volatile(expr: &Expr) -> bool {
pub fn is_volatile(expr: &Expr) -> Result<bool> {
match expr {
Expr::ScalarFunction(func) => func.func_def.is_volatile(),
_ => false,
_ => Ok(false),
}
}

#[cfg(test)]
mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{case, lit, Expr};
use crate::{
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
use std::sync::Arc;

#[test]
fn format_case_when() -> Result<()> {
Expand Down Expand Up @@ -1823,4 +1830,45 @@ mod test {
"UInt32(1) OR UInt32(2)"
);
}

#[test]
fn test_is_volatile_scalar_func_definition() {
// BuiltIn
assert!(
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
.is_volatile()
.unwrap()
);
assert!(
!ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
.is_volatile()
.unwrap()
);

// UDF
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation =
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
let udf = Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
&return_type,
&fun,
));
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());

let udf = Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
&return_type,
&fun,
));
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());

// Unresolved function
ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
.is_volatile()
.expect_err("Unresolved function should not be resolved");
}
}
10 changes: 5 additions & 5 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ enum ExprMask {
}

impl ExprMask {
fn ignores(&self, expr: &Expr) -> bool {
fn ignores(&self, expr: &Expr) -> Result<bool> {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
Expand All @@ -529,14 +529,14 @@ impl ExprMask {
| Expr::Wildcard { .. }
);

let is_volatile = is_volatile(expr);
let is_volatile = is_volatile(expr)?;

let is_aggr = matches!(expr, Expr::AggregateFunction(..));

match self {
Ok(match self {
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
}
})
}
}

Expand Down Expand Up @@ -628,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

let (idx, sub_expr_desc) = self.pop_enter_mark();
// skip exprs should not be recognize.
if self.expr_mask.ignores(expr) {
if self.expr_mask.ignores(expr)? {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
Expand Down

0 comments on commit 66304c0

Please sign in to comment.