From de3696ff44e64558b436c4d88648c77a967f14a3 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Tue, 27 Feb 2024 16:56:34 +0800 Subject: [PATCH] feat(udf): support `always_retry_on_network_error` config for udf functions (#15163) Co-authored-by: August Co-authored-by: Runji Wang --- ci/scripts/run-e2e-test.sh | 1 + e2e_test/udf/always_retry_python.slt | 75 +++++++++++++ proto/catalog.proto | 1 + proto/expr.proto | 1 + src/expr/core/src/expr/expr_udf.rs | 53 +++++---- src/expr/udf/src/external.rs | 19 ++++ src/frontend/src/catalog/function_catalog.rs | 2 + .../src/expr/user_defined_function.rs | 2 + src/frontend/src/handler/create_function.rs | 4 + .../src/handler/create_sql_function.rs | 1 + src/frontend/src/handler/mod.rs | 2 + .../migration/src/m20230908_072257_init.rs | 6 ++ src/meta/model_v2/src/function.rs | 2 + src/meta/src/controller/mod.rs | 1 + src/sqlparser/src/ast/mod.rs | 101 ++++++++++++++++++ src/sqlparser/src/parser.rs | 4 +- src/sqlparser/tests/sqlparser_postgres.rs | 9 +- 17 files changed, 260 insertions(+), 24 deletions(-) create mode 100644 e2e_test/udf/always_retry_python.slt diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index 95bf4fc679add..eeed357487198 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -90,6 +90,7 @@ pkill python3 sqllogictest -p 4566 -d dev './e2e_test/udf/alter_function.slt' sqllogictest -p 4566 -d dev './e2e_test/udf/graceful_shutdown_python.slt' +sqllogictest -p 4566 -d dev './e2e_test/udf/always_retry_python.slt' # FIXME: flaky test # sqllogictest -p 4566 -d dev './e2e_test/udf/retry_python.slt' diff --git a/e2e_test/udf/always_retry_python.slt b/e2e_test/udf/always_retry_python.slt new file mode 100644 index 0000000000000..64477a7df6be9 --- /dev/null +++ b/e2e_test/udf/always_retry_python.slt @@ -0,0 +1,75 @@ +system ok +python3 e2e_test/udf/test.py & + +# wait for server to start +sleep 10s + +statement ok +CREATE FUNCTION sleep_always_retry(INT) RETURNS INT AS 'sleep' USING LINK 'http://localhost:8815' WITH ( always_retry_on_network_error = true ); + +statement ok +CREATE FUNCTION sleep_no_retry(INT) RETURNS INT AS 'sleep' USING LINK 'http://localhost:8815'; + +# Create a table with 30 records +statement ok +CREATE TABLE t (v1 int); + +statement ok +INSERT INTO t select 0 from generate_series(1, 30); + +statement ok +flush; + +statement ok +SET STREAMING_RATE_LIMIT=1; + +statement ok +SET BACKGROUND_DDL=true; + +statement ok +CREATE MATERIALIZED VIEW mv_no_retry AS SELECT sleep_no_retry(v1) as s1 from t; + +# Create a Materialized View +statement ok +CREATE MATERIALIZED VIEW mv_always_retry AS SELECT sleep_always_retry(v1) as s1 from t; + +# Immediately kill the server, sleep for 1minute. +system ok +pkill -9 -i python && sleep 60 + +# Restart the server +system ok +python3 e2e_test/udf/test.py & + +# Wait for materialized view to be complete +statement ok +wait; + +query I +SELECT count(*) FROM mv_always_retry where s1 is NULL; +---- +0 + +query B +SELECT count(*) > 0 FROM mv_no_retry where s1 is NULL; +---- +t + +statement ok +SET STREAMING_RATE_LIMIT=0; + +statement ok +SET BACKGROUND_DDL=false; + +# close the server +system ok +pkill -i python + +statement ok +DROP FUNCTION sleep_always_retry; + +statement ok +DROP FUNCTION sleep_no_retry; + +statement ok +DROP TABLE t CASCADE; \ No newline at end of file diff --git a/proto/catalog.proto b/proto/catalog.proto index 99fd1b0a69514..513022c13c17c 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -220,6 +220,7 @@ message Function { optional string link = 8; optional string identifier = 10; optional string body = 14; + bool always_retry_on_network_error = 16; oneof kind { ScalarFunction scalar = 11; diff --git a/proto/expr.proto b/proto/expr.proto index 3a24c306ae5da..1597fd7983c60 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -527,6 +527,7 @@ message UserDefinedFunction { optional string identifier = 6; // For JavaScript UDF, it's the body of the function. optional string body = 7; + bool always_retry_on_network_error = 9; } // Additional information for user defined table functions. diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 246944ae5d9d3..5744e3e0ee8c9 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -55,6 +55,8 @@ pub struct UserDefinedFunction { /// On each successful call, the count will be decreased by 1. /// See . disable_retry_count: AtomicU8, + /// Always retry. Overrides `disable_retry_count`. + always_retry_on_network_error: bool, } const INITIAL_RETRY_COUNT: u8 = 16; @@ -111,32 +113,40 @@ impl UserDefinedFunction { UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &arrow_input)?, UdfImpl::External(client) => { let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); - let result = if disable_retry_count != 0 { + let result = if self.always_retry_on_network_error { client - .call(&self.identifier, arrow_input) + .call_with_always_retry_on_network_error(&self.identifier, arrow_input) .instrument_await(self.span.clone()) .await } else { - client - .call_with_retry(&self.identifier, arrow_input) - .instrument_await(self.span.clone()) - .await + let result = if disable_retry_count != 0 { + client + .call(&self.identifier, arrow_input) + .instrument_await(self.span.clone()) + .await + } else { + client + .call_with_retry(&self.identifier, arrow_input) + .instrument_await(self.span.clone()) + .await + }; + let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); + let connection_error = matches!(&result, Err(e) if e.is_connection_error()); + if connection_error && disable_retry_count != INITIAL_RETRY_COUNT { + // reset count on connection error + self.disable_retry_count + .store(INITIAL_RETRY_COUNT, Ordering::Relaxed); + } else if !connection_error && disable_retry_count != 0 { + // decrease count on success, ignore if exchange failed + _ = self.disable_retry_count.compare_exchange( + disable_retry_count, + disable_retry_count - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ); + } + result }; - let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); - let connection_error = matches!(&result, Err(e) if e.is_connection_error()); - if connection_error && disable_retry_count != INITIAL_RETRY_COUNT { - // reset count on connection error - self.disable_retry_count - .store(INITIAL_RETRY_COUNT, Ordering::Relaxed); - } else if !connection_error && disable_retry_count != 0 { - // decrease count on success, ignore if exchange failed - _ = self.disable_retry_count.compare_exchange( - disable_retry_count, - disable_retry_count - 1, - Ordering::Relaxed, - Ordering::Relaxed, - ); - } result? } }; @@ -234,6 +244,7 @@ impl Build for UserDefinedFunction { identifier: identifier.clone(), span: format!("udf_call({})", identifier).into(), disable_retry_count: AtomicU8::new(0), + always_retry_on_network_error: udf.always_retry_on_network_error, }) } } diff --git a/src/expr/udf/src/external.rs b/src/expr/udf/src/external.rs index f37647a591df4..fd5a05b84ec2d 100644 --- a/src/expr/udf/src/external.rs +++ b/src/expr/udf/src/external.rs @@ -204,6 +204,25 @@ impl ArrowFlightUdfClient { unreachable!() } + /// Always retry on connection error + pub async fn call_with_always_retry_on_network_error( + &self, + id: &str, + input: RecordBatch, + ) -> Result { + let mut backoff = Duration::from_millis(100); + loop { + match self.call(id, input.clone()).await { + Err(err) if err.is_connection_error() => { + tracing::error!(error = %err.as_report(), "UDF connection error. retry..."); + } + ret => return ret, + } + tokio::time::sleep(backoff).await; + backoff *= 2; + } + } + /// Call a function with streaming input and output. #[panic_return = "Result>"] pub async fn call_stream( diff --git a/src/frontend/src/catalog/function_catalog.rs b/src/frontend/src/catalog/function_catalog.rs index 96dbbe77c2a12..e60a3a758b7b5 100644 --- a/src/frontend/src/catalog/function_catalog.rs +++ b/src/frontend/src/catalog/function_catalog.rs @@ -33,6 +33,7 @@ pub struct FunctionCatalog { pub identifier: Option, pub body: Option, pub link: Option, + pub always_retry_on_network_error: bool, } #[derive(Clone, Display, PartialEq, Eq, Hash, Debug)] @@ -68,6 +69,7 @@ impl From<&PbFunction> for FunctionCatalog { identifier: prost.identifier.clone(), body: prost.body.clone(), link: prost.link.clone(), + always_retry_on_network_error: prost.always_retry_on_network_error, } } } diff --git a/src/frontend/src/expr/user_defined_function.rs b/src/frontend/src/expr/user_defined_function.rs index 323d74b04be08..6d35bd37145ce 100644 --- a/src/frontend/src/expr/user_defined_function.rs +++ b/src/frontend/src/expr/user_defined_function.rs @@ -58,6 +58,7 @@ impl UserDefinedFunction { identifier: udf.identifier.clone(), body: udf.body.clone(), link: udf.link.clone(), + always_retry_on_network_error: udf.always_retry_on_network_error, }; Ok(Self { @@ -92,6 +93,7 @@ impl Expr for UserDefinedFunction { identifier: self.catalog.identifier.clone(), link: self.catalog.link.clone(), body: self.catalog.body.clone(), + always_retry_on_network_error: self.catalog.always_retry_on_network_error, })), } } diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 0a8329e54be08..108fdb16e788b 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -40,6 +40,7 @@ pub async fn handle_create_function( args: Option>, returns: Option, params: CreateFunctionBody, + with_options: CreateFunctionWithOptions, ) -> Result { if or_replace { bail_not_implemented!("CREATE OR REPLACE FUNCTION"); @@ -285,6 +286,9 @@ pub async fn handle_create_function( link, body, owner: session.user_id(), + always_retry_on_network_error: with_options + .always_retry_on_network_error + .unwrap_or_default(), }; let catalog_writer = session.catalog_writer()?; diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 71fb480ae9ab3..0f9b42cf515a1 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -295,6 +295,7 @@ pub async fn handle_create_sql_function( body: Some(body), link: None, owner: session.user_id(), + always_retry_on_network_error: false, }; let catalog_writer = session.catalog_writer()?; diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 827f28f87319e..a8cb4c03e1bd2 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -246,6 +246,7 @@ pub async fn handle( args, returns, params, + with_options, } => { // For general udf, `language` clause could be ignored // refer: https://github.com/risingwavelabs/risingwave/pull/10608 @@ -266,6 +267,7 @@ pub async fn handle( args, returns, params, + with_options, ) .await } else { diff --git a/src/meta/model_v2/migration/src/m20230908_072257_init.rs b/src/meta/model_v2/migration/src/m20230908_072257_init.rs index 444433d84c31e..ace5f52a85002 100644 --- a/src/meta/model_v2/migration/src/m20230908_072257_init.rs +++ b/src/meta/model_v2/migration/src/m20230908_072257_init.rs @@ -724,6 +724,11 @@ impl MigrationTrait for Migration { .col(ColumnDef::new(Function::Identifier).string()) .col(ColumnDef::new(Function::Body).string()) .col(ColumnDef::new(Function::Kind).string().not_null()) + .col( + ColumnDef::new(Function::AlwaysRetryOnNetworkError) + .boolean() + .not_null(), + ) .foreign_key( &mut ForeignKey::create() .name("FK_function_object_id") @@ -1124,6 +1129,7 @@ enum Function { Identifier, Body, Kind, + AlwaysRetryOnNetworkError, } #[derive(DeriveIden)] diff --git a/src/meta/model_v2/src/function.rs b/src/meta/model_v2/src/function.rs index ae68782a50fd1..1976cee4f867a 100644 --- a/src/meta/model_v2/src/function.rs +++ b/src/meta/model_v2/src/function.rs @@ -45,6 +45,7 @@ pub struct Model { pub identifier: Option, pub body: Option, pub kind: FunctionKind, + pub always_retry_on_network_error: bool, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] @@ -100,6 +101,7 @@ impl From for ActiveModel { identifier: Set(function.identifier), body: Set(function.body), kind: Set(function.kind.unwrap().into()), + always_retry_on_network_error: Set(function.always_retry_on_network_error), } } } diff --git a/src/meta/src/controller/mod.rs b/src/meta/src/controller/mod.rs index 7fd70318a1ff3..1f61f55c754ba 100644 --- a/src/meta/src/controller/mod.rs +++ b/src/meta/src/controller/mod.rs @@ -285,6 +285,7 @@ impl From> for PbFunction { identifier: value.0.identifier, body: value.0.body, kind: Some(value.0.kind.into()), + always_retry_on_network_error: value.0.always_retry_on_network_error, } } } diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 952ce05eb933f..7051c10862d44 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -26,6 +26,7 @@ use alloc::{ vec::Vec, }; use core::fmt; +use core::fmt::Display; use itertools::Itertools; #[cfg(feature = "serde")] @@ -1168,6 +1169,7 @@ pub enum Statement { returns: Option, /// Optional parameters. params: CreateFunctionBody, + with_options: CreateFunctionWithOptions, }, /// CREATE AGGREGATE /// @@ -1536,6 +1538,7 @@ impl fmt::Display for Statement { args, returns, params, + with_options, } => { write!( f, @@ -1550,6 +1553,7 @@ impl fmt::Display for Statement { write!(f, " {}", return_type)?; } write!(f, "{params}")?; + write!(f, "{with_options}")?; Ok(()) } Statement::CreateAggregate { @@ -2736,6 +2740,57 @@ impl fmt::Display for CreateFunctionBody { Ok(()) } } +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CreateFunctionWithOptions { + /// Always retry on network errors. + pub always_retry_on_network_error: Option, +} + +/// TODO(kwannoel): Generate from the struct definition instead. +impl CreateFunctionWithOptions { + fn is_empty(&self) -> bool { + self.always_retry_on_network_error.is_none() + } +} + +/// TODO(kwannoel): Generate from the struct definition instead. +impl TryFrom> for CreateFunctionWithOptions { + type Error = ParserError; + + fn try_from(with_options: Vec) -> Result { + let mut always_retry_on_network_error = None; + for option in with_options { + if option.name.to_string().to_lowercase() == "always_retry_on_network_error" { + always_retry_on_network_error = Some(option.value == Value::Boolean(true)); + } else { + return Err(ParserError::ParserError(format!( + "Unsupported option: {}", + option.name + ))); + } + } + Ok(Self { + always_retry_on_network_error, + }) + } +} + +impl Display for CreateFunctionWithOptions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_empty() { + return Ok(()); + } + let mut options = vec![]; + if let Some(always_retry_on_network_error) = self.always_retry_on_network_error { + options.push(format!( + "ALWAYS_RETRY_NETWORK_ERRORS = {}", + always_retry_on_network_error + )); + } + write!(f, " WITH ( {} )", display_comma_separated(&options)) + } +} #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -2936,4 +2991,50 @@ mod tests { }; assert_eq!("NOT true IS NOT FALSE", format!("{}", unary_op)); } + + #[test] + fn test_create_function_display() { + let create_function = Statement::CreateFunction { + temporary: false, + or_replace: false, + name: ObjectName(vec![Ident::new_unchecked("foo")]), + args: Some(vec![OperateFunctionArg::unnamed(DataType::Int)]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some(Ident::new_unchecked("python")), + behavior: Some(FunctionBehavior::Immutable), + as_: Some(FunctionDefinition::SingleQuotedDef("SELECT 1".to_string())), + return_: None, + using: None, + }, + with_options: CreateFunctionWithOptions { + always_retry_on_network_error: None, + }, + }; + assert_eq!( + "CREATE FUNCTION foo(INT) RETURNS INT LANGUAGE python IMMUTABLE AS 'SELECT 1'", + format!("{}", create_function) + ); + let create_function = Statement::CreateFunction { + temporary: false, + or_replace: false, + name: ObjectName(vec![Ident::new_unchecked("foo")]), + args: Some(vec![OperateFunctionArg::unnamed(DataType::Int)]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some(Ident::new_unchecked("python")), + behavior: Some(FunctionBehavior::Immutable), + as_: Some(FunctionDefinition::SingleQuotedDef("SELECT 1".to_string())), + return_: None, + using: None, + }, + with_options: CreateFunctionWithOptions { + always_retry_on_network_error: Some(true), + }, + }; + assert_eq!( + "CREATE FUNCTION foo(INT) RETURNS INT LANGUAGE python IMMUTABLE AS 'SELECT 1' WITH ( ALWAYS_RETRY_NETWORK_ERRORS = true )", + format!("{}", create_function) + ); + } } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 87263dff16ee2..c152371993306 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -2238,7 +2238,8 @@ impl Parser { }; let params = self.parse_create_function_body()?; - + let with_options = self.parse_options_with_preceding_keyword(Keyword::WITH)?; + let with_options = with_options.try_into()?; Ok(Statement::CreateFunction { or_replace, temporary, @@ -2246,6 +2247,7 @@ impl Parser { args, returns: return_type, params, + with_options, }) } diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 99e2c185fdcff..6a5dec5d809c1 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -765,6 +765,7 @@ fn parse_create_function() { )), ..Default::default() }, + with_options: Default::default(), } ); @@ -786,7 +787,8 @@ fn parse_create_function() { "select $1 - $2;".into() )), ..Default::default() - } + }, + with_options: Default::default(), }, ); @@ -811,7 +813,8 @@ fn parse_create_function() { right: Box::new(Expr::Parameter { index: 2 }), }), ..Default::default() - } + }, + with_options: Default::default(), }, ); @@ -842,6 +845,7 @@ fn parse_create_function() { }), ..Default::default() }, + with_options: Default::default(), } ); @@ -865,6 +869,7 @@ fn parse_create_function() { return_: Some(Expr::Identifier("a".into())), ..Default::default() }, + with_options: Default::default(), } ); }