Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): store WASM UDF in meta store #15269

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion e2e_test/batch/catalog/pg_settings.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ internal data_directory
internal parallel_compact_size_mb
internal sstable_size_mb
internal state_store
internal wasm_storage_url
postmaster backup_storage_directory
postmaster backup_storage_url
postmaster barrier_interval_ms
Expand Down
3 changes: 3 additions & 0 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ message Function {
string language = 7;
optional string link = 8;
optional string identifier = 10;
// The source code of the function.
optional string body = 14;
// The zstd-compressed binary of the function.
optional bytes compressed_binary = 17;
bool always_retry_on_network_error = 16;

oneof kind {
Expand Down
16 changes: 9 additions & 7 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,17 @@ message UserDefinedFunction {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
// The link to the external function service.
optional string link = 5;
// An unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
// For JavaScript UDF, it's the name of the function.
// A unique identifier for the function.
// - If `link` is not empty, the name of the function in the external function service.
// - If `language` is `rust` or `wasm`, the name of the function in the wasm binary file.
// - If `language` is `javascript`, the name of the function.
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
// - If `language` is `javascript`, the source code of the function.
optional string body = 7;
// - If `language` is `rust` or `wasm`, the zstd-compressed wasm binary.
optional bytes compressed_binary = 10;
bool always_retry_on_network_error = 9;
}

Expand All @@ -538,4 +539,5 @@ message UserDefinedTableFunction {
optional string link = 5;
optional string identifier = 6;
optional string body = 7;
optional bytes compressed_binary = 10;
}
2 changes: 1 addition & 1 deletion proto/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ message SystemParams {
optional uint32 parallel_compact_size_mb = 11;
optional uint32 max_concurrent_creating_streaming_jobs = 12;
optional bool pause_on_next_bootstrap = 13;
optional string wasm_storage_url = 14;
optional string wasm_storage_url = 14 [deprecated = true];
optional bool enable_tracing = 15;
}

Expand Down
2 changes: 0 additions & 2 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ macro_rules! for_all_params {
{ backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", },
{ max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", },
{ pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", },
{ wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", },
{ enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", },
}
};
Expand Down Expand Up @@ -440,7 +439,6 @@ mod tests {
(BACKUP_STORAGE_DIRECTORY_KEY, "a"),
(MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"),
(PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"),
(WASM_STORAGE_URL_KEY, "a"),
(ENABLE_TRACING_KEY, "true"),
("a_deprecated_param", "foo"),
];
Expand Down
7 changes: 0 additions & 7 deletions src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,4 @@ where
.enable_tracing
.unwrap_or_else(default::enable_tracing)
}

fn wasm_storage_url(&self) -> &str {
self.inner()
.wasm_storage_url
.as_ref()
.unwrap_or(&default::WASM_STORAGE_URL)
}
}
1 change: 0 additions & 1 deletion src/config/docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,3 @@ This page is automatically generated by `./risedev generate-example-config`
| pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false |
| sstable_size_mb | Target size of the Sstable. | 256 |
| state_store | | |
| wasm_storage_url | | "fs://.risingwave/data" |
1 change: 0 additions & 1 deletion src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,5 +195,4 @@ block_size_kb = 64
bloom_false_positive = 0.001
max_concurrent_creating_streaming_jobs = 1
pause_on_next_bootstrap = false
wasm_storage_url = "fs://.risingwave/data"
enable_tracing = false
5 changes: 3 additions & 2 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ futures-async-stream = { workspace = true }
futures-util = "0.3"
itertools = "0.12"
linkme = { version = "0.3", features = ["used_linker"] }
moka = { version = "0.12", features = ["future"] }
md5 = "0.7"
moka = { version = "0.12", features = ["sync"] }
num-traits = "0.2"
openssl = { version = "0.10", features = ["vendored"] }
parse-display = "0.9"
paste = "1"
risingwave_common = { workspace = true }
risingwave_expr_macro = { path = "../macro" }
risingwave_object_store = { workspace = true }
risingwave_pb = { workspace = true }
risingwave_udf = { workspace = true }
smallvec = "1"
Expand All @@ -57,6 +57,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [
"macros",
] }
tracing = "0.1"
zstd = { version = "0.13", default-features = false }

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../../workspace-hack" }
Expand Down
47 changes: 13 additions & 34 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
use moka::future::Cache;
use moka::sync::Cache;
use risingwave_common::array::{ArrayError, ArrayRef, DataChunk};
use risingwave_common::config::ObjectStoreConfig;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_object_store::object::build_remote_object_store;
use risingwave_object_store::object::object_metrics::ObjectStoreMetrics;
use risingwave_pb::expr::ExprNode;
use risingwave_udf::ArrowFlightUdfClient;
use thiserror_ext::AsReport;
Expand Down Expand Up @@ -188,12 +185,11 @@ impl Build for UserDefinedFunction {
let imp = match udf.language.as_str() {
#[cfg(not(madsim))]
"wasm" | "rust" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
let compressed_wasm_binary = udf.get_compressed_binary()?;
let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice())
.context("failed to decompress wasm binary")?;
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
UdfImpl::Wasm(runtime)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
Expand Down Expand Up @@ -271,38 +267,21 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightU
/// Get or create a wasm runtime.
///
/// Runtimes returned by this function are cached internally and kept for at least 60 seconds after their last use.
/// Later calls with the same link will reuse the same runtime.
/// Later calls with the same binary will reuse the same runtime.
#[cfg_or_panic(not(madsim))]
pub async fn get_or_create_wasm_runtime(link: &str) -> Result<Arc<WasmRuntime>> {
static RUNTIMES: LazyLock<Cache<String, Arc<WasmRuntime>>> = LazyLock::new(|| {
pub fn get_or_create_wasm_runtime(binary: &[u8]) -> Result<Arc<WasmRuntime>> {
static RUNTIMES: LazyLock<Cache<md5::Digest, Arc<WasmRuntime>>> = LazyLock::new(|| {
Cache::builder()
.time_to_idle(Duration::from_secs(60))
.build()
});

if let Some(runtime) = RUNTIMES.get(link).await {
let md5 = md5::compute(binary);
if let Some(runtime) = RUNTIMES.get(&md5) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using get_with here? BTW, is it possible that there's an md5 collision? 🤣

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know about get_with until you mentioned it. It looks like a good fit here.
I think the possibility of md5 collision is negligible as this is not a security critical case. 😄

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have the UDF's identifier? Why bother using md5? 🤔

Copy link
Contributor Author

@wangrunji0408 wangrunji0408 Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are different. The identifier is used for finding functions within a WASM binary, while multiple functions may share the same binary. Here we cache the WASM runtime for each binary, but we don't want to store full binaries in memory, so their md5 is used as the cache key.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. But my point is that the case where multiple functions share the same binary is also rare. Anyway, both are OK with me.

return Ok(runtime.clone());
}

// create new runtime
let (wasm_storage_url, object_name) = link
.rsplit_once('/')
.context("invalid link for wasm function")?;

// load wasm binary from object store
let object_store = build_remote_object_store(
wasm_storage_url,
Arc::new(ObjectStoreMetrics::unused()),
"Wasm Engine",
ObjectStoreConfig::default(),
)
.await;
let binary = object_store
.read(object_name, ..)
.await
.context("failed to load wasm binary from object storage")?;

let runtime = Arc::new(arrow_udf_wasm::Runtime::new(&binary)?);
RUNTIMES.insert(link.into(), runtime.clone()).await;
let runtime = Arc::new(arrow_udf_wasm::Runtime::new(binary)?);
RUNTIMES.insert(md5, runtime.clone());
Ok(runtime)
}
15 changes: 7 additions & 8 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::sync::Arc;

use anyhow::Context;
use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
Expand Down Expand Up @@ -188,14 +189,12 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
let return_type = DataType::from(prost.get_return_type()?);

let client = match udtf.language.as_str() {
"wasm" => {
let link = udtf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(crate::expr::expr_udf::get_or_create_wasm_runtime(link))
})?)
"wasm" | "rust" => {
let compressed_wasm_binary = udtf.get_compressed_binary()?;
let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice())
.context("failed to decompress wasm binary")?;
let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?;
UdfImpl::Wasm(runtime)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ tokio-stream = "0.1"
tonic = { workspace = true }
tracing = "0.1"
uuid = "1"
zstd = { version = "0.13", default-features = false }

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../workspace-hack" }
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct FunctionCatalog {
pub identifier: Option<String>,
pub body: Option<String>,
pub link: Option<String>,
pub compressed_binary: Option<Vec<u8>>,
pub always_retry_on_network_error: bool,
}

Expand Down Expand Up @@ -69,6 +70,7 @@ impl From<&PbFunction> for FunctionCatalog {
identifier: prost.identifier.clone(),
body: prost.body.clone(),
link: prost.link.clone(),
compressed_binary: prost.compressed_binary.clone(),
always_retry_on_network_error: prost.always_retry_on_network_error,
}
}
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl TableFunction {
link: c.link.clone(),
identifier: c.identifier.clone(),
body: c.body.clone(),
compressed_binary: c.compressed_binary.clone(),
}),
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/expr/user_defined_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl UserDefinedFunction {
identifier: udf.identifier.clone(),
body: udf.body.clone(),
link: udf.link.clone(),
compressed_binary: udf.compressed_binary.clone(),
always_retry_on_network_error: udf.always_retry_on_network_error,
};

Expand Down Expand Up @@ -93,6 +94,7 @@ impl Expr for UserDefinedFunction {
identifier: self.catalog.identifier.clone(),
link: self.catalog.link.clone(),
body: self.catalog.body.clone(),
compressed_binary: self.catalog.compressed_binary.clone(),
always_retry_on_network_error: self.catalog.always_retry_on_network_error,
})),
}
Expand Down
Loading
Loading