From b9d7df7db2c7131cb20ab8618977cfcefd0ff4c2 Mon Sep 17 00:00:00 2001 From: Michael Xu Date: Mon, 22 Jan 2024 15:04:06 -0500 Subject: [PATCH] refactor(udf-context): assemble sql udf related functions inside UdfContext --- src/frontend/src/binder/expr/function.rs | 78 ++----------------- src/frontend/src/binder/mod.rs | 72 ++++++++++++++++- .../src/handler/create_sql_function.rs | 43 +--------- 3 files changed, 78 insertions(+), 115 deletions(-) diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c0ffc86cd5ebe..1e21c45ae033a 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -30,15 +30,14 @@ use risingwave_expr::window_function::{ Frame, FrameBound, FrameBounds, FrameExclusion, RowsFrameBounds, WindowFuncKind, }; use risingwave_sqlparser::ast::{ - self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr, - Statement, WindowFrameBound, WindowFrameExclusion, WindowFrameUnits, WindowSpec, + self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion, + WindowFrameUnits, WindowSpec, }; use risingwave_sqlparser::parser::ParserError; use thiserror_ext::AsReport; use crate::binder::bind_context::Clause; -use crate::binder::{Binder, BoundQuery, BoundSetExpr}; -use crate::catalog::function_catalog::FunctionCatalog; +use crate::binder::{Binder, BoundQuery, BoundSetExpr, UdfContext}; use crate::expr::{ AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy, Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction, @@ -160,73 +159,6 @@ impl Binder { return Ok(TableFunction::new(function_type, inputs)?.into()); } - /// TODO: add name related logic - /// NOTE: need to think of a way to prevent naming conflict - /// e.g., when existing column names conflict with parameter names in sql udf - fn create_udf_context( - args: &[FunctionArg], - _catalog: &Arc, - ) -> Result> { - let mut ret: HashMap = HashMap::new(); - for (i, current_arg) in args.iter().enumerate() { - if let FunctionArg::Unnamed(arg) = current_arg { - let FunctionArgExpr::Expr(e) = arg else { - return Err( - ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into() - ); - }; - // if catalog.arg_names.is_some() { - // todo!() - // } - ret.insert(format!("${}", i + 1), e.clone()); - continue; - } - return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); - } - Ok(ret) - } - - fn extract_udf_expression(ast: Vec) -> Result { - if ast.len() != 1 { - return Err(ErrorCode::InvalidInputSyntax( - "the query for sql udf should contain only one statement".to_string(), - ) - .into()); - } - - // Extract the expression out - let Statement::Query(query) = ast[0].clone() else { - return Err(ErrorCode::InvalidInputSyntax( - "invalid function definition, please recheck the syntax".to_string(), - ) - .into()); - }; - - let SetExpr::Select(select) = query.body else { - return Err(ErrorCode::InvalidInputSyntax( - "missing `select` body for sql udf expression, please recheck the syntax" - .to_string(), - ) - .into()); - }; - - if select.projection.len() != 1 { - return Err(ErrorCode::InvalidInputSyntax( - "`projection` should contain only one `SelectItem`".to_string(), - ) - .into()); - } - - let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { - return Err(ErrorCode::InvalidInputSyntax( - "expect `UnnamedExpr` for `projection`".to_string(), - ) - .into()); - }; - - Ok(expr) - } - // user defined function // TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422 if let Ok(schema) = self.first_valid_schema() @@ -267,7 +199,7 @@ impl Binder { // The actual inline logic for sql udf // Note that we will always create new udf context for each sql udf - let Ok(context) = create_udf_context(&args, &Arc::clone(func)) else { + let Ok(context) = UdfContext::create_udf_context(&args, &Arc::clone(func)) else { return Err(ErrorCode::InvalidInputSyntax( "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() ) @@ -306,7 +238,7 @@ impl Binder { self.udf_context.incr_global_count(); } - if let Ok(expr) = extract_udf_expression(ast) { + if let Ok(expr) = UdfContext::extract_udf_expression(ast) { let bind_result = self.bind_expr(expr); // Restore context information for subsequent binding self.udf_context.update_context(stashed_udf_context); diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index bf090a87a7514..d26d827618182 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -21,7 +21,9 @@ use risingwave_common::error::Result; use risingwave_common::session_config::{ConfigMap, SearchPath}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{ + Expr as AstExpr, FunctionArg, FunctionArgExpr, SelectItem, SetExpr, Statement, +}; mod bind_context; mod bind_param; @@ -57,6 +59,7 @@ pub use update::BoundUpdate; pub use values::BoundValues; use crate::catalog::catalog_service::CatalogReadGuard; +use crate::catalog::function_catalog::FunctionCatalog; use crate::catalog::schema_catalog::SchemaCatalog; use crate::catalog::{CatalogResult, TableId, ViewId}; use crate::expr::ExprImpl; @@ -168,6 +171,73 @@ impl UdfContext { pub fn get_context(&self) -> HashMap { self.udf_param_context.clone() } + + /// A common utility function to extract sql udf + /// expression out from the input `ast` + pub fn extract_udf_expression(ast: Vec) -> Result { + if ast.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "the query for sql udf should contain only one statement".to_string(), + ) + .into()); + } + + // Extract the expression out + let Statement::Query(query) = ast[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "invalid function definition, please recheck the syntax".to_string(), + ) + .into()); + }; + + let SetExpr::Select(select) = query.body else { + return Err(ErrorCode::InvalidInputSyntax( + "missing `select` body for sql udf expression, please recheck the syntax" + .to_string(), + ) + .into()); + }; + + if select.projection.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "`projection` should contain only one `SelectItem`".to_string(), + ) + .into()); + } + + let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "expect `UnnamedExpr` for `projection`".to_string(), + ) + .into()); + }; + + Ok(expr) + } + + /// TODO: add name related logic + /// NOTE: need to think of a way to prevent naming conflict + /// e.g., when existing column names conflict with parameter names in sql udf + pub fn create_udf_context( + args: &[FunctionArg], + _catalog: &Arc, + ) -> Result> { + let mut ret: HashMap = HashMap::new(); + for (i, current_arg) in args.iter().enumerate() { + if let FunctionArg::Unnamed(arg) = current_arg { + let FunctionArgExpr::Expr(e) = arg else { + return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); + }; + // if catalog.arg_names.is_some() { + // todo!() + // } + ret.insert(format!("${}", i + 1), e.clone()); + continue; + } + return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); + } + Ok(ret) + } } /// `ParameterTypes` is used to record the types of the parameters during binding. It works diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 2af3f5d9291b6..66429c19fad12 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -26,6 +26,7 @@ use risingwave_sqlparser::ast::{ use risingwave_sqlparser::parser::{Parser, ParserError}; use super::*; +use crate::binder::UdfContext; use crate::catalog::CatalogError; use crate::expr::{ExprImpl, Literal}; use crate::{bind_data_type, Binder}; @@ -41,46 +42,6 @@ fn create_mock_udf_context(arg_types: Vec) -> HashMap) -> Result { - if ast.len() != 1 { - return Err(ErrorCode::InvalidInputSyntax( - "the query for sql udf should contain only one statement".to_string(), - ) - .into()); - } - - // Extract the expression out - let Statement::Query(query) = ast[0].clone() else { - return Err(ErrorCode::InvalidInputSyntax( - "invalid function definition, please recheck the syntax".to_string(), - ) - .into()); - }; - - let SetExpr::Select(select) = query.body else { - return Err(ErrorCode::InvalidInputSyntax( - "missing `select` body for sql udf expression, please recheck the syntax".to_string(), - ) - .into()); - }; - - if select.projection.len() != 1 { - return Err(ErrorCode::InvalidInputSyntax( - "`projection` should contain only one `SelectItem`".to_string(), - ) - .into()); - } - - let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { - return Err(ErrorCode::InvalidInputSyntax( - "expect `UnnamedExpr` for `projection`".to_string(), - ) - .into()); - }; - - Ok(expr) -} - pub async fn handle_create_sql_function( handler_args: HandlerArgs, or_replace: bool, @@ -213,7 +174,7 @@ pub async fn handle_create_sql_function( .udf_context_mut() .update_context(create_mock_udf_context(arg_types.clone())); - if let Ok(expr) = extract_udf_expression(ast) { + if let Ok(expr) = UdfContext::extract_udf_expression(ast) { if let Err(e) = binder.bind_expr(expr) { return Err(ErrorCode::InvalidInputSyntax( format!("failed to conduct semantic check, please see if you are calling non-existence functions.\nDetailed error: {e}")