Skip to content

Commit

Permalink
feat(udf): support client side load balancer for UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
KeXiangWang committed Feb 23, 2024
1 parent d2c547a commit d8cd658
Show file tree
Hide file tree
Showing 19 changed files with 269 additions and 23 deletions.
114 changes: 113 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ message Function {
optional string link = 8;
optional string identifier = 10;
optional string body = 14;
bool enable_dns_resolver = 16;

oneof kind {
ScalarFunction scalar = 11;
Expand Down
2 changes: 2 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ message UserDefinedFunction {
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
optional string body = 7;
bool enable_dns_resolver = 9;
}

// Additional information for user defined table functions.
Expand All @@ -533,4 +534,5 @@ message UserDefinedTableFunction {
optional string link = 5;
optional string identifier = 6;
optional string body = 7;
bool enable_dns_resolver = 9;
}
12 changes: 9 additions & 3 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl Build for UserDefinedFunction {
#[cfg(not(madsim))]
_ => {
let link = udf.get_link()?;
UdfImpl::External(get_or_create_flight_client(link)?)
UdfImpl::External(get_or_create_flight_client(link, udf.enable_dns_resolver)?)
}
#[cfg(madsim)]
l => panic!("UDF language {l:?} is not supported on madsim"),
Expand Down Expand Up @@ -257,7 +257,10 @@ impl Build for UserDefinedFunction {
/// Get or create a client for the given UDF service.
///
/// There is a global cache for clients, so that we can reuse the same client for the same service.
pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightUdfClient>> {
pub(crate) fn get_or_create_flight_client(
link: &str,
enable_dns_resolver: bool,
) -> Result<Arc<ArrowFlightUdfClient>> {
static CLIENTS: LazyLock<std::sync::Mutex<HashMap<String, Weak<ArrowFlightUdfClient>>>> =
LazyLock::new(Default::default);
let mut clients = CLIENTS.lock().unwrap();
Expand All @@ -266,7 +269,10 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightU
Ok(client)
} else {
// create new client
let client = Arc::new(ArrowFlightUdfClient::connect_lazy(link)?);
let client = Arc::new(ArrowFlightUdfClient::connect_lazy(
link,
enable_dns_resolver,
)?);
clients.insert(link.into(), Arc::downgrade(&client));
Ok(client)
}
Expand Down
5 changes: 4 additions & 1 deletion src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
// connect to UDF service
_ => {
let link = udtf.get_link()?;
UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(link)?)
UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(
link,
udtf.enable_dns_resolver,
)?)
}
};

Expand Down
2 changes: 2 additions & 0 deletions src/expr/udf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ arrow-flight = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
cfg-or-panic = "0.2"
futures = "0.3"
futures-util = "0.3.28"
ginepro = "0.7.0"
prometheus = "0.13"
risingwave_common = { workspace = true }
static_assertions = "1"
Expand Down
2 changes: 1 addition & 1 deletion src/expr/udf/examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_udf::ArrowFlightUdfClient;
#[tokio::main]
async fn main() {
let addr = "http://localhost:8815";
let client = ArrowFlightUdfClient::connect(addr).await.unwrap();
let client = ArrowFlightUdfClient::connect(addr, false).await.unwrap();

// build `RecordBatch` to send (equivalent to our `DataChunk`)
let array1 = Arc::new(Int32Array::from_iter(vec![1, 6, 10]));
Expand Down
67 changes: 56 additions & 11 deletions src/expr/udf/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::str::FromStr;
use std::time::Duration;

use arrow_array::RecordBatch;
Expand All @@ -22,8 +23,12 @@ use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::{FlightData, FlightDescriptor};
use arrow_schema::Schema;
use cfg_or_panic::cfg_or_panic;
use futures::executor::block_on;
use futures_util::{stream, Stream, StreamExt, TryStreamExt};
use ginepro::{LoadBalancedChannel, ResolutionStrategy};
use risingwave_common::util::addr::HostAddr;
use thiserror_ext::AsReport;
use tokio::time::Duration as TokioDuration;
use tonic::transport::Channel;

use crate::metrics::GLOBAL_METRICS;
Expand All @@ -40,12 +45,22 @@ pub struct ArrowFlightUdfClient {
#[cfg_or_panic(not(madsim))]
impl ArrowFlightUdfClient {
/// Connect to a UDF service.
pub async fn connect(addr: &str) -> Result<Self> {
let conn = tonic::transport::Endpoint::new(addr.to_string())?
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5))
.connect()
.await?;
pub async fn connect(addr: &str, enable_dns_resolver: bool) -> Result<Self> {
let conn = if enable_dns_resolver {
Self::connect_with_dns(
addr,
ResolutionStrategy::Eager {
timeout: TokioDuration::from_secs(5),
},
)
.await?
} else {
tonic::transport::Endpoint::new(addr.to_string())?
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5))
.connect()
.await?
};
let client = FlightServiceClient::new(conn);
Ok(Self {
client,
Expand All @@ -54,18 +69,48 @@ impl ArrowFlightUdfClient {
}

/// Connect to a UDF service lazily (i.e. only when the first request is sent).
pub fn connect_lazy(addr: &str) -> Result<Self> {
let conn = tonic::transport::Endpoint::new(addr.to_string())?
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5))
.connect_lazy();
pub fn connect_lazy(addr: &str, enable_dns_resolver: bool) -> Result<Self> {
let conn = if enable_dns_resolver {
block_on(async { Self::connect_with_dns(addr, ResolutionStrategy::Lazy).await })?
} else {
tonic::transport::Endpoint::new(addr.to_string())?
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5))
.connect_lazy()
};
let client = FlightServiceClient::new(conn);
Ok(Self {
client,
addr: addr.into(),
})
}

async fn connect_with_dns(
addr: &str,
resolution_strategy: ResolutionStrategy,
) -> Result<Channel> {
let addr = addr.strip_prefix("http://").ok_or_else(|| {
Error::service_error(format!("udf address must starts with http://: {}", addr))
})?;
let host_addr = HostAddr::from_str(addr)
.map_err(|e| Error::service_error(format!("invalid address: {}, err: {}", addr, e)))?;
let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port))
.dns_probe_interval(std::time::Duration::from_secs(5))
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5))
.resolution_strategy(resolution_strategy)
.channel()
.await
.map_err(|e| {
Error::service_error(format!(
"failed to create LoadBalancedChannel, address: {}, err: {}",
host_addr, e
))
})?;

Ok(channel.into())
}

/// Check if the function is available and the schema is match.
pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> {
let descriptor = FlightDescriptor::new_path(vec![id.into()]);
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct FunctionCatalog {
pub identifier: Option<String>,
pub body: Option<String>,
pub link: Option<String>,
pub enable_dns_resolver: bool,
}

#[derive(Clone, Display, PartialEq, Eq, Hash, Debug)]
Expand Down Expand Up @@ -68,6 +69,7 @@ impl From<&PbFunction> for FunctionCatalog {
identifier: prost.identifier.clone(),
body: prost.body.clone(),
link: prost.link.clone(),
enable_dns_resolver: prost.enable_dns_resolver,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl TableFunction {
link: c.link.clone(),
identifier: c.identifier.clone(),
body: c.body.clone(),
enable_dns_resolver: c.enable_dns_resolver,
}),
}
}
Expand Down
Loading

0 comments on commit d8cd658

Please sign in to comment.