Skip to content

Commit

Permalink
refactor(udf-context): assemble sql udf related functions inside UdfC…
Browse files Browse the repository at this point in the history
…ontext
  • Loading branch information
xzhseh committed Jan 22, 2024
1 parent 657b8ab commit b9d7df7
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 115 deletions.
78 changes: 5 additions & 73 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ use risingwave_expr::window_function::{
Frame, FrameBound, FrameBounds, FrameExclusion, RowsFrameBounds, WindowFuncKind,
};
use risingwave_sqlparser::ast::{
self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr,
Statement, WindowFrameBound, WindowFrameExclusion, WindowFrameUnits, WindowSpec,
self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion,
WindowFrameUnits, WindowSpec,
};
use risingwave_sqlparser::parser::ParserError;
use thiserror_ext::AsReport;

use crate::binder::bind_context::Clause;
use crate::binder::{Binder, BoundQuery, BoundSetExpr};
use crate::catalog::function_catalog::FunctionCatalog;
use crate::binder::{Binder, BoundQuery, BoundSetExpr, UdfContext};
use crate::expr::{
AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy,
Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction,
Expand Down Expand Up @@ -160,73 +159,6 @@ impl Binder {
return Ok(TableFunction::new(function_type, inputs)?.into());
}

/// TODO: add name related logic
/// NOTE: need to think of a way to prevent naming conflict
/// e.g., when existing column names conflict with parameter names in sql udf
fn create_udf_context(
args: &[FunctionArg],
_catalog: &Arc<FunctionCatalog>,
) -> Result<HashMap<String, AstExpr>> {
let mut ret: HashMap<String, AstExpr> = HashMap::new();
for (i, current_arg) in args.iter().enumerate() {
if let FunctionArg::Unnamed(arg) = current_arg {
let FunctionArgExpr::Expr(e) = arg else {
return Err(
ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()
);
};
// if catalog.arg_names.is_some() {
// todo!()
// }
ret.insert(format!("${}", i + 1), e.clone());
continue;
}
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
}
Ok(ret)
}

fn extract_udf_expression(ast: Vec<Statement>) -> Result<AstExpr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax"
.to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

// user defined function
// TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422
if let Ok(schema) = self.first_valid_schema()
Expand Down Expand Up @@ -267,7 +199,7 @@ impl Binder {

// The actual inline logic for sql udf
// Note that we will always create new udf context for each sql udf
let Ok(context) = create_udf_context(&args, &Arc::clone(func)) else {
let Ok(context) = UdfContext::create_udf_context(&args, &Arc::clone(func)) else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
Expand Down Expand Up @@ -306,7 +238,7 @@ impl Binder {
self.udf_context.incr_global_count();
}

if let Ok(expr) = extract_udf_expression(ast) {
if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
let bind_result = self.bind_expr(expr);
// Restore context information for subsequent binding
self.udf_context.update_context(stashed_udf_context);
Expand Down
72 changes: 71 additions & 1 deletion src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use risingwave_common::error::Result;
use risingwave_common::session_config::{ConfigMap, SearchPath};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{
Expr as AstExpr, FunctionArg, FunctionArgExpr, SelectItem, SetExpr, Statement,
};

mod bind_context;
mod bind_param;
Expand Down Expand Up @@ -57,6 +59,7 @@ pub use update::BoundUpdate;
pub use values::BoundValues;

use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::function_catalog::FunctionCatalog;
use crate::catalog::schema_catalog::SchemaCatalog;
use crate::catalog::{CatalogResult, TableId, ViewId};
use crate::expr::ExprImpl;
Expand Down Expand Up @@ -168,6 +171,73 @@ impl UdfContext {
pub fn get_context(&self) -> HashMap<String, ExprImpl> {
self.udf_param_context.clone()
}

/// A common utility function to extract sql udf
/// expression out from the input `ast`
pub fn extract_udf_expression(ast: Vec<Statement>) -> Result<AstExpr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax"
.to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

/// TODO: add name related logic
/// NOTE: need to think of a way to prevent naming conflict
/// e.g., when existing column names conflict with parameter names in sql udf
pub fn create_udf_context(
args: &[FunctionArg],
_catalog: &Arc<FunctionCatalog>,
) -> Result<HashMap<String, AstExpr>> {
let mut ret: HashMap<String, AstExpr> = HashMap::new();
for (i, current_arg) in args.iter().enumerate() {
if let FunctionArg::Unnamed(arg) = current_arg {
let FunctionArgExpr::Expr(e) = arg else {
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
};
// if catalog.arg_names.is_some() {
// todo!()
// }
ret.insert(format!("${}", i + 1), e.clone());
continue;
}
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
}
Ok(ret)
}
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down
43 changes: 2 additions & 41 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use risingwave_sqlparser::ast::{
use risingwave_sqlparser::parser::{Parser, ParserError};

use super::*;
use crate::binder::UdfContext;
use crate::catalog::CatalogError;
use crate::expr::{ExprImpl, Literal};
use crate::{bind_data_type, Binder};
Expand All @@ -41,46 +42,6 @@ fn create_mock_udf_context(arg_types: Vec<DataType>) -> HashMap<String, ExprImpl
.collect()
}

fn extract_udf_expression(ast: Vec<Statement>) -> Result<Expr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax".to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

pub async fn handle_create_sql_function(
handler_args: HandlerArgs,
or_replace: bool,
Expand Down Expand Up @@ -213,7 +174,7 @@ pub async fn handle_create_sql_function(
.udf_context_mut()
.update_context(create_mock_udf_context(arg_types.clone()));

if let Ok(expr) = extract_udf_expression(ast) {
if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
if let Err(e) = binder.bind_expr(expr) {
return Err(ErrorCode::InvalidInputSyntax(
format!("failed to conduct semantic check, please see if you are calling non-existence functions.\nDetailed error: {e}")
Expand Down

0 comments on commit b9d7df7

Please sign in to comment.