From 66304c08781088fa599a1afe5fc2f6d58cc6ba73 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Dec 2023 13:21:53 -0800 Subject: [PATCH] Return error for unresolved scalar function --- datafusion/expr/src/expr.rs | 62 ++++++++++++++++--- .../optimizer/src/common_subexpr_eliminate.rs | 10 +-- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 32296a15ded2..76b258a8a9e9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -376,15 +376,17 @@ impl ScalarFunctionDefinition { /// Whether this function is volatile, i.e. whether it can return different results /// when evaluated multiple times with the same input. - pub fn is_volatile(&self) -> bool { + pub fn is_volatile(&self) -> Result { match self { ScalarFunctionDefinition::BuiltIn(fun) => { - fun.volatility() == crate::Volatility::Volatile + Ok(fun.volatility() == crate::Volatility::Volatile) } ScalarFunctionDefinition::UDF(udf) => { - udf.signature().volatility == crate::Volatility::Volatile + Ok(udf.signature().volatility == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Cannot determine volatility of unresolved function") } - ScalarFunctionDefinition::Name(_) => false, } } } @@ -1708,10 +1710,10 @@ fn create_names(exprs: &[Expr]) -> Result { /// Whether the given expression is volatile, i.e. whether it can return different results /// when evaluated multiple times with the same input. -pub fn is_volatile(expr: &Expr) -> bool { +pub fn is_volatile(expr: &Expr) -> Result { match expr { Expr::ScalarFunction(func) => func.func_def.is_volatile(), - _ => false, + _ => Ok(false), } } @@ -1719,10 +1721,15 @@ pub fn is_volatile(expr: &Expr) -> bool { mod test { use crate::expr::Cast; use crate::expr_fn::col; - use crate::{case, lit, Expr}; + use crate::{ + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction, + ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, + }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::sync::Arc; #[test] fn format_case_when() -> Result<()> { @@ -1823,4 +1830,45 @@ mod test { "UInt32(1) OR UInt32(2)" ); } + + #[test] + fn test_is_volatile_scalar_func_definition() { + // BuiltIn + assert!( + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) + .is_volatile() + .unwrap() + ); + assert!( + !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) + .is_volatile() + .unwrap() + ); + + // UDF + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = + Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + )); + assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile), + &return_type, + &fun, + )); + assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + // Unresolved function + ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) + .is_volatile() + .expect_err("Unresolved function should not be resolved"); + } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index b42b095a8c88..1e089257c61a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -518,7 +518,7 @@ enum ExprMask { } impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { + fn ignores(&self, expr: &Expr) -> Result { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -529,14 +529,14 @@ impl ExprMask { | Expr::Wildcard { .. } ); - let is_volatile = is_volatile(expr); + let is_volatile = is_volatile(expr)?; let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - match self { + Ok(match self { Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, - } + }) } } @@ -628,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr) { + if self.expr_mask.ignores(expr)? { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc));