Skip to content

Commit

Permalink
fix(sql-udf): correctly handle udf_binding_flag & udf_global_count (
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh authored Feb 4, 2024
1 parent 50b91c0 commit 823382b
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 29 deletions.
12 changes: 10 additions & 2 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,13 @@ select add_named(a, b) from t3 order by a asc;
################################

# Mixed parameter with calling inner sql udfs
# statement ok
# create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';
statement ok
create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';

query I
select add_sub_mix_wrapper(1, 2, 3);
----
4

# Named sql udf with corner case
statement ok
Expand Down Expand Up @@ -404,6 +409,9 @@ drop function add_named_wrapper;
statement ok
drop function type_match;

statement ok
drop function add_sub_mix_wrapper;

statement ok
drop table t1;

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Binder {
// to the name of the defined sql udf parameters stored in `udf_context`.
// If so, we will treat this bind as an special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if self.udf_binding_flag {
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&column_name) {
return Ok(expr.clone());
} else {
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,16 @@ impl Binder {
}

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
self.set_udf_binding_flag();
let bind_result = self.bind_expr(expr);
self.unset_udf_binding_flag();

// We should properly decrement global count after a successful binding
// Since the subsequent probe operation in `bind_column` or
// `bind_parameter` relies on global counting
self.udf_context.decr_global_count();

// Restore context information for subsequent binding
self.udf_context.update_context(stashed_udf_context);

return bind_result;
} else {
return Err(ErrorCode::InvalidInputSyntax(
Expand Down
8 changes: 5 additions & 3 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,13 @@ impl Binder {

fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
// Special check for sql udf
// Note: This is specific to anonymous sql udf, since the
// Note: This is specific to sql udf with unnamed parameters, since the
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
}
}

Ok(Parameter::new(index, self.param_types.clone()).into())
Expand Down
33 changes: 16 additions & 17 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ pub struct Binder {

/// The sql udf context that will be used during binding phase
udf_context: UdfContext,

/// Udf binding flag, used to distinguish between
/// columns and named parameters during sql udf binding
udf_binding_flag: bool,
}

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -155,6 +151,10 @@ impl UdfContext {
self.udf_global_counter += 1;
}

pub fn decr_global_count(&mut self) {
self.udf_global_counter -= 1;
}

pub fn _is_empty(&self) -> bool {
self.udf_param_context.is_empty()
}
Expand Down Expand Up @@ -219,6 +219,8 @@ impl UdfContext {
Ok(expr)
}

/// Create the sql udf context
/// used per `bind_function` for sql udf & semantic check at definition time
pub fn create_udf_context(
args: &[FunctionArg],
catalog: &Arc<FunctionCatalog>,
Expand All @@ -228,9 +230,10 @@ impl UdfContext {
match current_arg {
FunctionArg::Unnamed(arg) => {
let FunctionArgExpr::Expr(e) = arg else {
return Err(
ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()
);
return Err(ErrorCode::InvalidInputSyntax(
"expect `FunctionArgExpr` for unnamed argument".to_string(),
)
.into());
};
if catalog.arg_names[i].is_empty() {
ret.insert(format!("${}", i + 1), e.clone());
Expand All @@ -240,7 +243,12 @@ impl UdfContext {
ret.insert(catalog.arg_names[i].clone(), e.clone());
}
}
_ => return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()),
_ => {
return Err(ErrorCode::InvalidInputSyntax(
"expect unnamed argument when creating sql udf context".to_string(),
)
.into())
}
}
}
Ok(ret)
Expand Down Expand Up @@ -347,7 +355,6 @@ impl Binder {
included_relations: HashSet::new(),
param_types: ParameterTypes::new(param_types),
udf_context: UdfContext::new(),
udf_binding_flag: false,
}
}

Expand Down Expand Up @@ -497,14 +504,6 @@ impl Binder {
pub fn udf_context_mut(&mut self) -> &mut UdfContext {
&mut self.udf_context
}

pub fn set_udf_binding_flag(&mut self) {
self.udf_binding_flag = true;
}

pub fn unset_udf_binding_flag(&mut self) {
self.udf_binding_flag = false;
}
}

/// The column name stored in [`BindContext`] for a column without an alias.
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ pub async fn handle_create_sql_function(
arg_names.clone(),
));

binder.set_udf_binding_flag();
// Need to set the initial global count to 1
// otherwise the context will not be probed during the semantic check
binder.udf_context_mut().incr_global_count();

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
match binder.bind_expr(expr) {
Expand All @@ -204,7 +206,7 @@ pub async fn handle_create_sql_function(
}
}
Err(e) => return Err(ErrorCode::InvalidInputSyntax(format!(
"failed to conduct semantic check, please see if you are calling non-existent functions: {}",
"failed to conduct semantic check, please see if you are calling non-existence functions or parameters\ndetailed error message: {}",
e.as_report()
))
.into()),
Expand All @@ -217,8 +219,6 @@ pub async fn handle_create_sql_function(
)
.into());
}

binder.unset_udf_binding_flag();
}

// Create the actual function, will be stored in function catalog
Expand Down

0 comments on commit 823382b

Please sign in to comment.