Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(udf-context): assemble sql udf related functions inside UdfContext #14732

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 5 additions & 73 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ use risingwave_expr::window_function::{
Frame, FrameBound, FrameBounds, FrameExclusion, RowsFrameBounds, WindowFuncKind,
};
use risingwave_sqlparser::ast::{
self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr,
Statement, WindowFrameBound, WindowFrameExclusion, WindowFrameUnits, WindowSpec,
self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion,
WindowFrameUnits, WindowSpec,
};
use risingwave_sqlparser::parser::ParserError;
use thiserror_ext::AsReport;

use crate::binder::bind_context::Clause;
use crate::binder::{Binder, BoundQuery, BoundSetExpr};
use crate::catalog::function_catalog::FunctionCatalog;
use crate::binder::{Binder, BoundQuery, BoundSetExpr, UdfContext};
use crate::expr::{
AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy,
Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction,
Expand Down Expand Up @@ -160,73 +159,6 @@ impl Binder {
return Ok(TableFunction::new(function_type, inputs)?.into());
}

/// TODO: add name related logic
/// NOTE: need to think of a way to prevent naming conflict
/// e.g., when existing column names conflict with parameter names in sql udf
fn create_udf_context(
args: &[FunctionArg],
_catalog: &Arc<FunctionCatalog>,
) -> Result<HashMap<String, AstExpr>> {
let mut ret: HashMap<String, AstExpr> = HashMap::new();
for (i, current_arg) in args.iter().enumerate() {
if let FunctionArg::Unnamed(arg) = current_arg {
let FunctionArgExpr::Expr(e) = arg else {
return Err(
ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()
);
};
// if catalog.arg_names.is_some() {
// todo!()
// }
ret.insert(format!("${}", i + 1), e.clone());
continue;
}
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
}
Ok(ret)
}

fn extract_udf_expression(ast: Vec<Statement>) -> Result<AstExpr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax"
.to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

// user defined function
// TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422
if let Ok(schema) = self.first_valid_schema()
Expand Down Expand Up @@ -267,7 +199,7 @@ impl Binder {

// The actual inline logic for sql udf
// Note that we will always create new udf context for each sql udf
let Ok(context) = create_udf_context(&args, &Arc::clone(func)) else {
let Ok(context) = UdfContext::create_udf_context(&args, &Arc::clone(func)) else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
Expand Down Expand Up @@ -306,7 +238,7 @@ impl Binder {
self.udf_context.incr_global_count();
}

if let Ok(expr) = extract_udf_expression(ast) {
if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
let bind_result = self.bind_expr(expr);
// Restore context information for subsequent binding
self.udf_context.update_context(stashed_udf_context);
Expand Down
72 changes: 71 additions & 1 deletion src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use risingwave_common::error::Result;
use risingwave_common::session_config::{ConfigMap, SearchPath};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{
Expr as AstExpr, FunctionArg, FunctionArgExpr, SelectItem, SetExpr, Statement,
};

mod bind_context;
mod bind_param;
Expand Down Expand Up @@ -57,6 +59,7 @@ pub use update::BoundUpdate;
pub use values::BoundValues;

use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::function_catalog::FunctionCatalog;
use crate::catalog::schema_catalog::SchemaCatalog;
use crate::catalog::{CatalogResult, TableId, ViewId};
use crate::expr::ExprImpl;
Expand Down Expand Up @@ -168,6 +171,73 @@ impl UdfContext {
pub fn get_context(&self) -> HashMap<String, ExprImpl> {
self.udf_param_context.clone()
}

/// A common utility function to extract sql udf
/// expression out from the input `ast`
pub fn extract_udf_expression(ast: Vec<Statement>) -> Result<AstExpr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax"
.to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

/// TODO: add name related logic
/// NOTE: need to think of a way to prevent naming conflict
/// e.g., when existing column names conflict with parameter names in sql udf
pub fn create_udf_context(
args: &[FunctionArg],
_catalog: &Arc<FunctionCatalog>,
) -> Result<HashMap<String, AstExpr>> {
let mut ret: HashMap<String, AstExpr> = HashMap::new();
for (i, current_arg) in args.iter().enumerate() {
if let FunctionArg::Unnamed(arg) = current_arg {
let FunctionArgExpr::Expr(e) = arg else {
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
};
// if catalog.arg_names.is_some() {
// todo!()
// }
ret.insert(format!("${}", i + 1), e.clone());
continue;
}
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
}
Ok(ret)
}
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down
43 changes: 2 additions & 41 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use risingwave_sqlparser::ast::{
use risingwave_sqlparser::parser::{Parser, ParserError};

use super::*;
use crate::binder::UdfContext;
use crate::catalog::CatalogError;
use crate::expr::{ExprImpl, Literal};
use crate::{bind_data_type, Binder};
Expand All @@ -41,46 +42,6 @@ fn create_mock_udf_context(arg_types: Vec<DataType>) -> HashMap<String, ExprImpl
.collect()
}

fn extract_udf_expression(ast: Vec<Statement>) -> Result<Expr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax".to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

pub async fn handle_create_sql_function(
handler_args: HandlerArgs,
or_replace: bool,
Expand Down Expand Up @@ -213,7 +174,7 @@ pub async fn handle_create_sql_function(
.udf_context_mut()
.update_context(create_mock_udf_context(arg_types.clone()));

if let Ok(expr) = extract_udf_expression(ast) {
if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
if let Err(e) = binder.bind_expr(expr) {
return Err(ErrorCode::InvalidInputSyntax(
format!("failed to conduct semantic check, please see if you are calling non-existence functions.\nDetailed error: {e}")
Expand Down
Loading