Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql-udf): better hint display for invalid sql udf definition #15091

Merged
merged 18 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 40 additions & 5 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -280,23 +280,58 @@ $1 + 114514 + $1

# Named sql udf with invalid parameter in body definition
# Will be rejected at creation time
statement error failed to find named parameter aa
statement error
create function unknown_parameter(a INT) returns int language sql as 'select a + aa + a';
----
db error: ERROR: Failed to run the query

Caused by:
Invalid input syntax: Failed to conduct semantic check
Bind error: [sql udf] failed to find named parameter aa
In SQL UDF definition: `select a + aa + a`
^^


# With unnamed parameter
statement error
create function unnamed_param_hint(INT, INT) returns int language sql as 'select $1 + $3 + $2';
----
db error: ERROR: Failed to run the query

Caused by:
Invalid input syntax: Failed to conduct semantic check
Bind error: [sql udf] failed to find unnamed parameter $3
In SQL UDF definition: `select $1 + $3 + $2`
^^


# A mixture of both
statement error
create function mix_hint(INT, aa INT, INT) returns int language sql as 'select $1 + aa + a + $2';
----
db error: ERROR: Failed to run the query

Caused by:
Invalid input syntax: Failed to conduct semantic check
Bind error: [sql udf] failed to find named parameter a
In SQL UDF definition: `select $1 + aa + a + $2`
^


statement error Expected end of statement, found: 💩
create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')';

# Recursive definition can NOT be accepted at present due to semantic check
statement error failed to conduct semantic check, please see if you are calling non-existent functions
statement error Failed to conduct semantic check
create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)';

# Complex but error-prone definition, recursive & normal sql udfs interleaving
statement error failed to conduct semantic check, please see if you are calling non-existent functions
statement error Failed to conduct semantic check
create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)';

# Create a valid recursive function
# Please note we do NOT support actual running the recursive sql udf at present
statement error failed to conduct semantic check, please see if you are calling non-existent functions
statement error Failed to conduct semantic check
create function fib(INT) returns int
language sql as 'select case
when $1 = 0 then 0
Expand All @@ -307,7 +342,7 @@ create function fib(INT) returns int
end;';

# Calling a non-existent function
statement error failed to conduct semantic check, please see if you are calling non-existent functions
statement error Failed to conduct semantic check
create function non_exist(INT) returns int language sql as 'select yo(114514)';

# Try to create an sql udf with unnamed parameters whose return type mismatches with the sql body definition
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ easy-ext = "1"
educe = "0.5"
either = "1"
enum-as-inner = "0.6"
fancy-regex = "0.11.0"
fixedbitset = "0.4.2"
futures = { version = "0.3", default-features = false, features = ["alloc"] }
futures-async-stream = { workspace = true }
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ impl Binder {
// there will not exist any column identifiers
// And invalid cases should already be caught
// during semantic check phase
// Note: the error message here also help with hint display
// when invalid definition occurs at sql udf creation time
return Err(ErrorCode::BindError(format!(
"failed to find named parameter {column_name}"
"[sql udf] failed to find named parameter {column_name}"
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
))
.into());
}
Expand Down
7 changes: 7 additions & 0 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ impl Binder {
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
} else {
// Same as `bind_column`, the error message here
// help with hint display when invalid definition occurs
return Err(ErrorCode::BindError(format!(
"[sql udf] failed to find unnamed parameter ${index}"
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
))
.into());
}
}

Expand Down
70 changes: 64 additions & 6 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::collections::HashMap;

use fancy_regex::Regex;
use itertools::Itertools;
use pgwire::pg_response::StatementType;
use risingwave_common::catalog::FunctionId;
Expand All @@ -24,14 +25,19 @@ use risingwave_sqlparser::ast::{
CreateFunctionBody, FunctionDefinition, ObjectName, OperateFunctionArg,
};
use risingwave_sqlparser::parser::{Parser, ParserError};
use thiserror_ext::AsReport;

use super::*;
use crate::binder::UdfContext;
use crate::catalog::CatalogError;
use crate::expr::{Expr, ExprImpl, Literal};
use crate::{bind_data_type, Binder};

const DEFAULT_ERR_MSG: &str = "Failed to conduct semantic check";

const PROMPT: &str = "In SQL UDF definition: ";

const SQL_UDF_PATTERN: &str = "[sql udf]";

/// Create a mock `udf_context`, which is used for semantic check
fn create_mock_udf_context(
arg_types: Vec<DataType>,
Expand All @@ -53,6 +59,23 @@ fn create_mock_udf_context(
ret
}

/// Find the pattern for better hint display
/// return the exact index where the pattern first appears
fn find_target(input: &str, target: &str) -> Option<usize> {
// Regex pattern to find `target` not preceded/followed by an ASCII letter.
// \b asserts a word boundary, and \B asserts NOT a word boundary
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
// The pattern uses negative lookbehind (?<!...) and lookahead (?!...) to ensure
// the target is not surrounded by ASCII alphabetic characters
let pattern = format!(r"(?<![A-Za-z]){0}(?![A-Za-z])", fancy_regex::escape(target));
let re = Regex::new(&pattern).unwrap();

let Ok(Some(ma)) = re.find(input) else {
return None;
};

Some(ma.start())
}

pub async fn handle_create_sql_function(
handler_args: HandlerArgs,
or_replace: bool,
Expand Down Expand Up @@ -205,11 +228,46 @@ pub async fn handle_create_sql_function(
.into());
}
}
Err(e) => return Err(ErrorCode::InvalidInputSyntax(format!(
"failed to conduct semantic check, please see if you are calling non-existence functions or parameters\ndetailed error message: {}",
e.as_report()
))
.into()),
Err(e) => {
if let ErrorCode::BindErrorRoot { expr: _, error } = e.inner() {
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
let invalid_msg = error.to_string();

// Valid error message for hint display
let Some(_) = invalid_msg.as_str().find(SQL_UDF_PATTERN) else {
return Err(
ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into()
);
};

// Get the name of the invalid item
// We will just display the first one found
let invalid_item_name =
invalid_msg.split_whitespace().last().unwrap_or("null");

// Find the invalid parameter / column
let Some(idx) = find_target(body.as_str(), invalid_item_name) else {
return Err(
ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into()
);
};

// The exact error position for `^` to point to
let position = format!(
"{}{}",
" ".repeat(idx + PROMPT.len() + 1),
"^".repeat(invalid_item_name.len())
);

return Err(ErrorCode::InvalidInputSyntax(format!(
"{}\n{}\n{}`{}`\n{}",
DEFAULT_ERR_MSG, invalid_msg, PROMPT, body, position
))
.into());
} else {
// Otherwise return the default error message
return Err(ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into());
}
}
}
} else {
return Err(ErrorCode::InvalidInputSyntax(
Expand Down
1 change: 1 addition & 0 deletions src/workspace-hack/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ aws-smithy-types = { version = "1", default-features = false, features = ["byte-
axum = { version = "0.6" }
base64 = { version = "0.21", features = ["alloc"] }
bigdecimal = { version = "0.4" }
bit-set = { version = "0.5" }
bit-vec = { version = "0.6" }
bitflags = { version = "2", default-features = false, features = ["serde", "std"] }
byteorder = { version = "1" }
Expand Down
Loading