From d8cd658d8865cb201e2f462d82b967678805d0aa Mon Sep 17 00:00:00 2001 From: Kexiang Wang Date: Tue, 20 Feb 2024 00:41:31 -0500 Subject: [PATCH] feat(udf): support client side load balancer for UDF --- Cargo.lock | 114 +++++++++++++++++- proto/catalog.proto | 1 + proto/expr.proto | 2 + src/expr/core/src/expr/expr_udf.rs | 12 +- .../core/src/table_function/user_defined.rs | 5 +- src/expr/udf/Cargo.toml | 2 + src/expr/udf/examples/client.rs | 2 +- src/expr/udf/src/external.rs | 67 ++++++++-- src/frontend/src/catalog/function_catalog.rs | 2 + src/frontend/src/expr/table_function.rs | 1 + .../src/expr/user_defined_function.rs | 2 + src/frontend/src/handler/create_function.rs | 11 +- .../src/handler/create_sql_function.rs | 2 + src/frontend/src/handler/mod.rs | 3 + src/meta/model_v2/src/function.rs | 2 + src/meta/src/controller/mod.rs | 1 + src/sqlparser/src/ast/mod.rs | 50 ++++++++ src/sqlparser/src/parser.rs | 4 +- src/sqlparser/tests/sqlparser_postgres.rs | 9 +- 19 files changed, 269 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ed7677d652954..02c889ab96620 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4593,6 +4593,23 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "ginepro" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eedbff62a689be48f58f32571dbf3d60c4a73b39740141dfe7ac942536ea27f7" +dependencies = [ + "anyhow", + "async-trait", + "http 0.2.9", + "thiserror", + "tokio", + "tonic 0.10.2", + "tower", + "tracing", + "trust-dns-resolver", +] + [[package]] name = "glob" version = "0.3.1" @@ -5121,6 +5138,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" +[[package]] +name = "idna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "idna" version = "0.5.0" @@ -5265,6 +5292,18 @@ version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a611371471e98973dbcab4e0ec66c31a10bc356eeb4d54a0e05eac8158fe38c" +[[package]] +name = "ipconfig" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" +dependencies = [ + "socket2 0.5.3", + "widestring", + "windows-sys 0.48.0", + "winreg", +] + [[package]] name = "ipnet" version = "2.8.0" @@ -5779,6 +5818,15 @@ dependencies = [ "hashbrown 0.14.0", ] +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "lz4" version = "1.24.0" @@ -8422,6 +8470,16 @@ dependencies = [ "winreg", ] +[[package]] +name = "resolv-conf" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" +dependencies = [ + "hostname", + "quick-error", +] + [[package]] name = "retain_mut" version = "0.1.7" @@ -10036,7 +10094,9 @@ dependencies = [ "arrow-schema 50.0.0", "arrow-select 50.0.0", "cfg-or-panic", + "futures", "futures-util", + "ginepro", "madsim-tokio", "madsim-tonic", "prometheus", @@ -12467,6 +12527,52 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622b09ce2fe2df4618636fb92176d205662f59803f39e70d1c333393082de96c" +[[package]] +name = "trust-dns-proto" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3119112651c157f4488931a01e586aa459736e9d6046d3bd9105ffb69352d374" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.4.0", + "ipnet", + "once_cell", + "rand", + "smallvec", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "trust-dns-resolver" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a3e6c3aff1718b3c73e395d1f35202ba2ffa847c6a62eea0db8fb4cfe30be6" +dependencies = [ + "cfg-if", + "futures-util", + "ipconfig", + "lru-cache", + "once_cell", + "parking_lot 0.12.1", + "rand", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", + "trust-dns-proto", +] + [[package]] name = "try-lock" version = "0.2.4" @@ -12618,7 +12724,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna", + "idna 0.5.0", "percent-encoding", "serde", ] @@ -13307,6 +13413,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" + [[package]] name = "wiggle" version = "17.0.0" diff --git a/proto/catalog.proto b/proto/catalog.proto index 99fd1b0a69514..6123025ba865f 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -220,6 +220,7 @@ message Function { optional string link = 8; optional string identifier = 10; optional string body = 14; + bool enable_dns_resolver = 16; oneof kind { ScalarFunction scalar = 11; diff --git a/proto/expr.proto b/proto/expr.proto index 14f9eb8c102cd..0a1eb2e0812cb 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -523,6 +523,7 @@ message UserDefinedFunction { optional string identifier = 6; // For JavaScript UDF, it's the body of the function. optional string body = 7; + bool enable_dns_resolver = 9; } // Additional information for user defined table functions. @@ -533,4 +534,5 @@ message UserDefinedTableFunction { optional string link = 5; optional string identifier = 6; optional string body = 7; + bool enable_dns_resolver = 9; } diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index b162691896a43..05bc4a536497a 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -219,7 +219,7 @@ impl Build for UserDefinedFunction { #[cfg(not(madsim))] _ => { let link = udf.get_link()?; - UdfImpl::External(get_or_create_flight_client(link)?) + UdfImpl::External(get_or_create_flight_client(link, udf.enable_dns_resolver)?) } #[cfg(madsim)] l => panic!("UDF language {l:?} is not supported on madsim"), @@ -257,7 +257,10 @@ impl Build for UserDefinedFunction { /// Get or create a client for the given UDF service. /// /// There is a global cache for clients, so that we can reuse the same client for the same service. -pub(crate) fn get_or_create_flight_client(link: &str) -> Result> { +pub(crate) fn get_or_create_flight_client( + link: &str, + enable_dns_resolver: bool, +) -> Result> { static CLIENTS: LazyLock>>> = LazyLock::new(Default::default); let mut clients = CLIENTS.lock().unwrap(); @@ -266,7 +269,10 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result Result { let link = udtf.get_link()?; - UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(link)?) + UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client( + link, + udtf.enable_dns_resolver, + )?) } }; diff --git a/src/expr/udf/Cargo.toml b/src/expr/udf/Cargo.toml index 838f2e62958c3..b17ad7acadfc1 100644 --- a/src/expr/udf/Cargo.toml +++ b/src/expr/udf/Cargo.toml @@ -16,7 +16,9 @@ arrow-flight = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } cfg-or-panic = "0.2" +futures = "0.3" futures-util = "0.3.28" +ginepro = "0.7.0" prometheus = "0.13" risingwave_common = { workspace = true } static_assertions = "1" diff --git a/src/expr/udf/examples/client.rs b/src/expr/udf/examples/client.rs index 92f93ae13614e..94d6f1515321e 100644 --- a/src/expr/udf/examples/client.rs +++ b/src/expr/udf/examples/client.rs @@ -21,7 +21,7 @@ use risingwave_udf::ArrowFlightUdfClient; #[tokio::main] async fn main() { let addr = "http://localhost:8815"; - let client = ArrowFlightUdfClient::connect(addr).await.unwrap(); + let client = ArrowFlightUdfClient::connect(addr, false).await.unwrap(); // build `RecordBatch` to send (equivalent to our `DataChunk`) let array1 = Arc::new(Int32Array::from_iter(vec![1, 6, 10])); diff --git a/src/expr/udf/src/external.rs b/src/expr/udf/src/external.rs index f8d4cf6cc379e..4c87f39226da0 100644 --- a/src/expr/udf/src/external.rs +++ b/src/expr/udf/src/external.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::str::FromStr; use std::time::Duration; use arrow_array::RecordBatch; @@ -22,8 +23,12 @@ use arrow_flight::flight_service_client::FlightServiceClient; use arrow_flight::{FlightData, FlightDescriptor}; use arrow_schema::Schema; use cfg_or_panic::cfg_or_panic; +use futures::executor::block_on; use futures_util::{stream, Stream, StreamExt, TryStreamExt}; +use ginepro::{LoadBalancedChannel, ResolutionStrategy}; +use risingwave_common::util::addr::HostAddr; use thiserror_ext::AsReport; +use tokio::time::Duration as TokioDuration; use tonic::transport::Channel; use crate::metrics::GLOBAL_METRICS; @@ -40,12 +45,22 @@ pub struct ArrowFlightUdfClient { #[cfg_or_panic(not(madsim))] impl ArrowFlightUdfClient { /// Connect to a UDF service. - pub async fn connect(addr: &str) -> Result { - let conn = tonic::transport::Endpoint::new(addr.to_string())? - .timeout(Duration::from_secs(5)) - .connect_timeout(Duration::from_secs(5)) - .connect() - .await?; + pub async fn connect(addr: &str, enable_dns_resolver: bool) -> Result { + let conn = if enable_dns_resolver { + Self::connect_with_dns( + addr, + ResolutionStrategy::Eager { + timeout: TokioDuration::from_secs(5), + }, + ) + .await? + } else { + tonic::transport::Endpoint::new(addr.to_string())? + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) + .connect() + .await? + }; let client = FlightServiceClient::new(conn); Ok(Self { client, @@ -54,11 +69,15 @@ impl ArrowFlightUdfClient { } /// Connect to a UDF service lazily (i.e. only when the first request is sent). - pub fn connect_lazy(addr: &str) -> Result { - let conn = tonic::transport::Endpoint::new(addr.to_string())? - .timeout(Duration::from_secs(5)) - .connect_timeout(Duration::from_secs(5)) - .connect_lazy(); + pub fn connect_lazy(addr: &str, enable_dns_resolver: bool) -> Result { + let conn = if enable_dns_resolver { + block_on(async { Self::connect_with_dns(addr, ResolutionStrategy::Lazy).await })? + } else { + tonic::transport::Endpoint::new(addr.to_string())? + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) + .connect_lazy() + }; let client = FlightServiceClient::new(conn); Ok(Self { client, @@ -66,6 +85,32 @@ impl ArrowFlightUdfClient { }) } + async fn connect_with_dns( + addr: &str, + resolution_strategy: ResolutionStrategy, + ) -> Result { + let addr = addr.strip_prefix("http://").ok_or_else(|| { + Error::service_error(format!("udf address must starts with http://: {}", addr)) + })?; + let host_addr = HostAddr::from_str(addr) + .map_err(|e| Error::service_error(format!("invalid address: {}, err: {}", addr, e)))?; + let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port)) + .dns_probe_interval(std::time::Duration::from_secs(5)) + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) + .resolution_strategy(resolution_strategy) + .channel() + .await + .map_err(|e| { + Error::service_error(format!( + "failed to create LoadBalancedChannel, address: {}, err: {}", + host_addr, e + )) + })?; + + Ok(channel.into()) + } + /// Check if the function is available and the schema is match. pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> { let descriptor = FlightDescriptor::new_path(vec![id.into()]); diff --git a/src/frontend/src/catalog/function_catalog.rs b/src/frontend/src/catalog/function_catalog.rs index 96dbbe77c2a12..a4bbe176fea98 100644 --- a/src/frontend/src/catalog/function_catalog.rs +++ b/src/frontend/src/catalog/function_catalog.rs @@ -33,6 +33,7 @@ pub struct FunctionCatalog { pub identifier: Option, pub body: Option, pub link: Option, + pub enable_dns_resolver: bool, } #[derive(Clone, Display, PartialEq, Eq, Hash, Debug)] @@ -68,6 +69,7 @@ impl From<&PbFunction> for FunctionCatalog { identifier: prost.identifier.clone(), body: prost.body.clone(), link: prost.link.clone(), + enable_dns_resolver: prost.enable_dns_resolver, } } } diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index e3000d0c245ab..c59041e3aa819 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -79,6 +79,7 @@ impl TableFunction { link: c.link.clone(), identifier: c.identifier.clone(), body: c.body.clone(), + enable_dns_resolver: c.enable_dns_resolver, }), } } diff --git a/src/frontend/src/expr/user_defined_function.rs b/src/frontend/src/expr/user_defined_function.rs index 323d74b04be08..14ee39acb92c1 100644 --- a/src/frontend/src/expr/user_defined_function.rs +++ b/src/frontend/src/expr/user_defined_function.rs @@ -58,6 +58,7 @@ impl UserDefinedFunction { identifier: udf.identifier.clone(), body: udf.body.clone(), link: udf.link.clone(), + enable_dns_resolver: udf.enable_dns_resolver, }; Ok(Self { @@ -92,6 +93,7 @@ impl Expr for UserDefinedFunction { identifier: self.catalog.identifier.clone(), link: self.catalog.link.clone(), body: self.catalog.body.clone(), + enable_dns_resolver: self.catalog.enable_dns_resolver, })), } } diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 0a8329e54be08..3cc515b99742e 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -40,6 +40,7 @@ pub async fn handle_create_function( args: Option>, returns: Option, params: CreateFunctionBody, + with_options: CreateFunctionWithOptions, ) -> Result { if or_replace { bail_not_implemented!("CREATE OR REPLACE FUNCTION"); @@ -141,9 +142,12 @@ pub async fn handle_create_function( // check UDF server { - let client = ArrowFlightUdfClient::connect(&l) - .await - .map_err(|e| anyhow!(e))?; + let client = ArrowFlightUdfClient::connect( + &l, + with_options.enable_dns_resolver.unwrap_or_default(), + ) + .await + .map_err(|e| anyhow!(e))?; /// A helper function to create a unnamed field from data type. fn to_field(data_type: arrow_schema::DataType) -> arrow_schema::Field { arrow_schema::Field::new("", data_type, true) @@ -285,6 +289,7 @@ pub async fn handle_create_function( link, body, owner: session.user_id(), + enable_dns_resolver: with_options.enable_dns_resolver.unwrap_or_default(), }; let catalog_writer = session.catalog_writer()?; diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 311664735603f..b0431efeda259 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -61,6 +61,7 @@ pub async fn handle_create_sql_function( args: Option>, returns: Option, params: CreateFunctionBody, + with_options: CreateFunctionWithOptions, ) -> Result { if or_replace { bail_not_implemented!("CREATE OR REPLACE FUNCTION"); @@ -236,6 +237,7 @@ pub async fn handle_create_sql_function( body: Some(body), link: None, owner: session.user_id(), + enable_dns_resolver: with_options.enable_dns_resolver.unwrap_or_default(), }; let catalog_writer = session.catalog_writer()?; diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 3cdc4b191da92..5f488fe30e66f 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -206,6 +206,7 @@ pub async fn handle( args, returns, params, + with_options, } => { // For general udf, `language` clause could be ignored // refer: https://github.com/risingwavelabs/risingwave/pull/10608 @@ -226,6 +227,7 @@ pub async fn handle( args, returns, params, + with_options, ) .await } else { @@ -237,6 +239,7 @@ pub async fn handle( args, returns, params, + with_options, ) .await } diff --git a/src/meta/model_v2/src/function.rs b/src/meta/model_v2/src/function.rs index ae68782a50fd1..b17d2abf1f0fc 100644 --- a/src/meta/model_v2/src/function.rs +++ b/src/meta/model_v2/src/function.rs @@ -45,6 +45,7 @@ pub struct Model { pub identifier: Option, pub body: Option, pub kind: FunctionKind, + pub enable_dns_resolver: bool, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] @@ -100,6 +101,7 @@ impl From for ActiveModel { identifier: Set(function.identifier), body: Set(function.body), kind: Set(function.kind.unwrap().into()), + enable_dns_resolver: Set(function.enable_dns_resolver), } } } diff --git a/src/meta/src/controller/mod.rs b/src/meta/src/controller/mod.rs index 4873a42809b0b..a59562dcd2a53 100644 --- a/src/meta/src/controller/mod.rs +++ b/src/meta/src/controller/mod.rs @@ -287,6 +287,7 @@ impl From> for PbFunction { identifier: value.0.identifier, body: value.0.body, kind: Some(value.0.kind.into()), + enable_dns_resolver: value.0.enable_dns_resolver, } } } diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 952ce05eb933f..3b4eb7057468d 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -1168,6 +1168,7 @@ pub enum Statement { returns: Option, /// Optional parameters. params: CreateFunctionBody, + with_options: CreateFunctionWithOptions, }, /// CREATE AGGREGATE /// @@ -1536,6 +1537,7 @@ impl fmt::Display for Statement { args, returns, params, + with_options } => { write!( f, @@ -1550,6 +1552,7 @@ impl fmt::Display for Statement { write!(f, " {}", return_type)?; } write!(f, "{params}")?; + write!(f, "{with_options}")?; Ok(()) } Statement::CreateAggregate { @@ -2737,6 +2740,53 @@ impl fmt::Display for CreateFunctionBody { } } +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CreateFunctionWithOptions { + /// Whether to enable DNS resolver for the UDF client. Should be enabled if the UDF servers are behind a L4 load balancer. + pub enable_dns_resolver: Option, +} + +impl CreateFunctionWithOptions { + fn is_empty(&self) -> bool { + self.enable_dns_resolver.is_none() + } +} + +impl TryFrom> for CreateFunctionWithOptions { + type Error = ParserError; + + fn try_from(with_options: Vec) -> Result { + let mut enable_dns_resolver = None; + for option in with_options { + if option.name.to_string().to_lowercase() == "enable_dns_resolver" { + enable_dns_resolver = Some(option.value == Value::Boolean(true)); + } else { + return Err(ParserError::ParserError(format!( + "Unsupported option: {}", + option.name + ))); + } + } + Ok(Self { + enable_dns_resolver, + }) + } +} + +impl fmt::Display for CreateFunctionWithOptions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_empty() { + return Ok(()); + } + let mut options = vec![]; + if let Some(enable_dns_resolver) = self.enable_dns_resolver { + options.push(format!("ENABLE_DNS_RESOLVER = {}", enable_dns_resolver)); + } + write!(f, " WITH ( {} )", display_comma_separated(&options)) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CreateFunctionUsing { diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 87263dff16ee2..c152371993306 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -2238,7 +2238,8 @@ impl Parser { }; let params = self.parse_create_function_body()?; - + let with_options = self.parse_options_with_preceding_keyword(Keyword::WITH)?; + let with_options = with_options.try_into()?; Ok(Statement::CreateFunction { or_replace, temporary, @@ -2246,6 +2247,7 @@ impl Parser { args, returns: return_type, params, + with_options, }) } diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 99e2c185fdcff..6a5dec5d809c1 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -765,6 +765,7 @@ fn parse_create_function() { )), ..Default::default() }, + with_options: Default::default(), } ); @@ -786,7 +787,8 @@ fn parse_create_function() { "select $1 - $2;".into() )), ..Default::default() - } + }, + with_options: Default::default(), }, ); @@ -811,7 +813,8 @@ fn parse_create_function() { right: Box::new(Expr::Parameter { index: 2 }), }), ..Default::default() - } + }, + with_options: Default::default(), }, ); @@ -842,6 +845,7 @@ fn parse_create_function() { }), ..Default::default() }, + with_options: Default::default(), } ); @@ -865,6 +869,7 @@ fn parse_create_function() { return_: Some(Expr::Identifier("a".into())), ..Default::default() }, + with_options: Default::default(), } ); }