diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index ca440d329..e6f8de16b 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -27,6 +27,7 @@ use arrow::{ use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array}; use arrow_schema::DataType; use datafusion::{ + execution::FunctionRegistry, logical_expr::{ BuiltinScalarFunction, ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDFImpl, Signature, Volatility, @@ -48,6 +49,7 @@ use unicode_segmentation::UnicodeSegmentation; pub fn create_comet_physical_fun( fun_name: &str, data_type: DataType, + registry: &dyn FunctionRegistry, ) -> Result { match fun_name { "ceil" => { @@ -128,8 +130,12 @@ pub fn create_comet_physical_fun( ))) } _ => { - let fun = &BuiltinScalarFunction::from_str(fun_name)?; - Ok(ScalarFunctionDefinition::BuiltIn(*fun)) + let fun = BuiltinScalarFunction::from_str(fun_name); + if fun.is_err() { + Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) + } else { + Ok(ScalarFunctionDefinition::BuiltIn(fun?)) + } } } } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 1b0a2d0d8..adb143aca 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -23,6 +23,7 @@ use arrow_schema::{DataType, Field, Schema, TimeUnit}; use datafusion::{ arrow::{compute::SortOptions, datatypes::SchemaRef}, common::DataFusionError, + execution::FunctionRegistry, functions::math, logical_expr::{ BuiltinScalarFunction, Operator as DataFusionOperator, ScalarFunctionDefinition, @@ -45,6 +46,7 @@ use datafusion::{ sorts::sort::SortExec, ExecutionPlan, Partitioning, }, + prelude::SessionContext, }; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, @@ -109,20 +111,16 @@ pub struct PhysicalPlanner { // The execution context id of this planner. exec_context_id: i64, execution_props: ExecutionProps, -} - -impl Default for PhysicalPlanner { - fn default() -> Self { - Self::new() - } + session_ctx: Arc, } impl PhysicalPlanner { - pub fn new() -> Self { + pub fn new(session_ctx: Arc) -> Self { let execution_props = ExecutionProps::new(); Self { exec_context_id: TEST_EXEC_CONTEXT_ID, execution_props, + session_ctx, } } @@ -130,6 +128,7 @@ impl PhysicalPlanner { Self { exec_context_id, execution_props: self.execution_props, + session_ctx: self.session_ctx.clone(), } } @@ -636,7 +635,11 @@ impl PhysicalPlanner { Ok(DataType::Decimal128(_p2, _s2)), ) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); - let fun_expr = create_comet_physical_fun("decimal_div", data_type.clone())?; + let fun_expr = create_comet_physical_fun( + "decimal_div", + data_type.clone(), + &self.session_ctx.state(), + )?; Ok(Arc::new(ScalarFunctionExpr::new( "decimal_div", fun_expr, @@ -1213,11 +1216,19 @@ impl PhysicalPlanner { // scalar function // Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll // throw error. - let fun = &BuiltinScalarFunction::from_str(fun_name)?; - fun.return_type(&input_expr_types)? + let fun = BuiltinScalarFunction::from_str(fun_name); + if fun.is_err() { + self.session_ctx + .udf(fun_name)? + .inner() + .return_type(&input_expr_types)? + } else { + fun?.return_type(&input_expr_types)? + } } }; - let fun_expr = create_comet_physical_fun(fun_name, data_type.clone())?; + let fun_expr = + create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?; let scalar_expr: Arc = Arc::new(ScalarFunctionExpr::new( fun_name, diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 20f98a3a4..8249097a1 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -321,7 +321,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // Because we don't know if input arrays are dictionary-encoded when we create // query plan, we need to defer stream initialization to first time execution. if exec_context.root_op.is_none() { - let planner = PhysicalPlanner::new().with_exec_id(exec_context_id); + let planner = PhysicalPlanner::new(exec_context.session_ctx.clone()) + .with_exec_id(exec_context_id); let (scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(),