diff --git a/.github/workflows/codegen.yml b/.github/workflows/codegen.yml index c892c70..122f6e1 100644 --- a/.github/workflows/codegen.yml +++ b/.github/workflows/codegen.yml @@ -2,15 +2,15 @@ name: CODEGEN on: push: - branches: [ main ] + branches: [main] paths: - - "codegen/**" + - "codegen/**" tags: - - 'weld-codegen-v*' + - "weld-codegen-v*" pull_request: - branches: [ main ] + branches: [main] paths: - - "codegen/**" + - "codegen/**" env: CARGO_TERM_COLOR: always @@ -20,34 +20,33 @@ jobs: rust_test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: update_rust - run: | - rustup toolchain install nightly - rustup update stable nightly - rustup component add --toolchain nightly rustfmt clippy - rustup default nightly - - name: run_all_tests_clippy_fmt - run: | - cd ${{ env.working-directory }} - cargo clippy --all-targets --all-features - rustfmt --check src/*.rs + - uses: actions/checkout@v2 + - name: Update rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + components: rustfmt, clippy + - name: run_all_tests_clippy_fmt + run: | + cd ${{ env.working-directory }} + make test + make rust-check github_release: if: startswith(github.ref, 'refs/tags/') # Only run on tag push needs: rust_test runs-on: ubuntu-latest steps: - - name: Create Release - id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ github.ref }} - release_name: Release ${{ github.ref }} - draft: false - prerelease: true + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref }} + release_name: Release ${{ github.ref }} + draft: false + prerelease: true crates_release: if: startswith(github.ref, 'refs/tags/') # Only run on tag push diff --git a/.github/workflows/macros.yml b/.github/workflows/macros.yml index 6b36de7..3ff0881 100644 --- a/.github/workflows/macros.yml +++ b/.github/workflows/macros.yml @@ -2,45 +2,51 @@ name: MACROS on: push: - branches: [ main ] + branches: [main] paths: - - "macros/**" + - "macros/**" tags: - - 'wasmbus-macros-v*' + - "wasmbus-macros-v*" pull_request: - branches: [ main ] + branches: [main] paths: - - "macros/**" + - "macros/**" env: CARGO_TERM_COLOR: always working-directory: ./macros jobs: - rust_check: + rust_test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - id: rust-check-action - uses: wasmcloud/common-actions/rust-check@main - with: + - uses: actions/checkout@v2 + - name: Update rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + components: rustfmt, clippy + - name: run_all_tests_clippy_fmt working-directory: ${{ env.working-directory }} + run: | + make test + make rust-check github_release: if: startswith(github.ref, 'refs/tags/') # Only run on tag push - needs: rust_check + needs: rust_test runs-on: ubuntu-latest steps: - - name: Create Release - id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ github.ref }} - release_name: Release ${{ github.ref }} - draft: false - prerelease: true + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref }} + release_name: Release ${{ github.ref }} + draft: false + prerelease: true crates_release: if: startswith(github.ref, 'refs/tags/') # Only run on tag push diff --git a/.github/workflows/wasmbus-rpc.yml b/.github/workflows/wasmbus-rpc.yml index ef52599..4617023 100644 --- a/.github/workflows/wasmbus-rpc.yml +++ b/.github/workflows/wasmbus-rpc.yml @@ -21,13 +21,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: update_rust - run: | - rustup update + - name: Update rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + components: rustfmt, clippy - name: run_all_tests_clippy_fmt - run: | - cd ${{ env.working-directory }} - make test + working-directory: ${{ env.working-directory }} + run: make test github_release: if: startswith(github.ref, 'refs/tags/') # Only run on tag push diff --git a/Cargo.toml b/Cargo.toml deleted file mode 100644 index 7741dc7..0000000 --- a/Cargo.toml +++ /dev/null @@ -1,2 +0,0 @@ -[workspace] -members = [ "codegen", "macros", "rpc-rs" ] diff --git a/Makefile b/Makefile index 34e9a96..9028cde 100644 --- a/Makefile +++ b/Makefile @@ -1,34 +1,35 @@ -# weld top-level Makefile +# wasmcloud/weld top-level Makefile # # Makefiles in this repository assume you have GNU Make (version 4.x) # If you're on mac, `brew install make` # and ensure `/usr/local/opt/make/libexec/gnubin` is in your PATH before /usr/bin +subdirs = codegen macros rpc-rs + MODEL_OUTPUT := codegen/src/wasmbus_model.rs rpc-rs/src/wasmbus_model.rs #MODEL_SRC := examples/interface/wasmbus-core/wasmcloud-model.smithy \ # examples/interface/wasmbus-core/codegen.toml -#WELD := target/debug/weld - -all: build -build clean: - cargo $@ +all build release clean test update lint validate rust-check:: + for dir in $(subdirs); do \ + $(MAKE) -C $$dir $@ ; \ + done -test: - # run clippy on all features and tests, and fail on warnings - cargo clippy --all-targets --all-features -- -D warnings - cargo test - -release: - cargo build --release +test:: + $(MAKE) check-model check-model: $(MODEL_OUTPUT) @diff $(MODEL_OUTPUT) || (echo ERROR: Model files differ && exit 1) -WELD_SRC := bin/Cargo.toml bin/src/*.rs codegen/Cargo.toml codegen/templates/*.toml \ - codegen/templates/*.hbs codegen/templates/rust/*.hbs -target/debug/weld: $(WELD_SRC) - cargo build --package weld-bin +gen: + $(MAKE) -C codegen release + (cd codegen && target/release/codegen) + (cd rpc-rs && ../codegen/target/release/codegen) + +#WELD_SRC := bin/Cargo.toml bin/src/*.rs codegen/Cargo.toml codegen/templates/*.toml \ +# codegen/templates/*.hbs codegen/templates/rust/*.hbs +#target/debug/weld: $(WELD_SRC) +# cargo build --package weld-bin .PHONY: all build release clean test check-model .NOTPARALLEL: diff --git a/codegen/src/codegen_rust.rs b/codegen/src/codegen_rust.rs index b573bbd..370ef2a 100644 --- a/codegen/src/codegen_rust.rs +++ b/codegen/src/codegen_rust.rs @@ -309,7 +309,7 @@ impl<'model> CodeGen for RustCodeGen<'model> { r#" #[allow(unused_imports)] use {}::{{ - //cbor::*, + cbor::*, common::{{ Context, deserialize, Message, MessageFormat, message_format, MessageDispatch, SendOpts, serialize, Transport, diff --git a/macros/Makefile b/macros/Makefile new file mode 100644 index 0000000..dab957e --- /dev/null +++ b/macros/Makefile @@ -0,0 +1,18 @@ +# weld/macros/Makefile + +all: build + +build clean update: + cargo $@ + +release: + cargo build --release + +test:: + cargo test --all-features --all-targets -- --nocapture + +rust-check:: + cargo clippy --all-features --all-targets + rustfmt --edition 2021 --check src/*.rs + +.PHONY: all build release clean lint validate test update rust-check diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1c596b1..1f8fbbe 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -39,9 +39,7 @@ impl syn::parse::Parse for ReceiverDef { fn parse(input: syn::parse::ParseStream) -> ParseResult { let derive_input: syn::DeriveInput = input.parse()?; let attrs_span = derive_input.span(); - let syn::DeriveInput { - attrs, ident, data, .. - } = derive_input; + let syn::DeriveInput { attrs, ident, data, .. } = derive_input; let ident_span = ident.span(); let fields = match data { syn::Data::Struct(data) => data.fields, diff --git a/rpc-rs/CHANGELOG.md b/rpc-rs/CHANGELOG.md index 5335b4a..69180f7 100644 --- a/rpc-rs/CHANGELOG.md +++ b/rpc-rs/CHANGELOG.md @@ -1,10 +1,45 @@ # wasmbus-rpc Changelog +## BREAKING CHANGES from 0.8.x to 0.9.0 + +- provider_main has a new parameter: friendly_name, which is displayed on OTEL tracing dashboards. + Instead of `provider_main(MyAwesomeProvider::default())`, use: + `provider_main(MyAwesomeProvider::default(), Some("My Awesome Provider".to_string()))` + +- nats-aflowt is replaced with async-nats! + - removed 'wasmbus_rpc::anats' + - anats::ServerAddress renamed to async_nats::ServerAddr + - anats::Subscription is not public, replaced with async_nats::Subscriber + - anats::Subscription.close() replaced with async_nats::Subscriber.unsubscribe() + - anats::Options renamed to async_nats::ConnectOptions + - anats::Connection is no longer public. Use async_nats::Client instead. + - anats::Message.data renamed to async_nats::Message.payload +- HostBridge::new() changes + - first parameter is async_nats::Client instead of anats::Connection +- RpcClient::new() changes + - new() parameter takes async_nats Client instead of anats::Client + - lattice prefix removed from constructor, added in to some of the method parameters +- got rid of enum NatsClientType, replaced with async_nats::Client +- removed feature "chunkify" (it is always enabled for non-wasm32 targets) + +- RpcError does not implement Serialize, Deserialize, or PartialEq + (unless/until we can find a good reason to support these) + + +## non-breaking changes +- new feature flag "otel" enables OpenTelemetry tracing spans + - set environment variable `OTEL_TRACES_EXPORTER` to "otlp" + - set environment variable `OTEL_EXPORTER_OTLP_ENDPOINT` to the desired collector. Defaults to "http://127.0.0.1:55681/v1/traces" + - ("/v1/traces" will always be appended to this setting if it doesn't already end with "/v1/traces") +- dependencies (minicbor, uuid, and others) +- replaced ring with sha2 for sha256 + + ## 0.7.0 ### Breaking changes (since 0.6.x) -- Some of the crate exported symbols have moved to sub-modules. The intent is to resolve some linking problems +- Some of the crate's exported symbols have moved to submodules. The intent is to resolve some linking problems resulting from multiple inconsistent references to these symbols. Most of these changes will require only a recompile, for Actors and Providers that import `wasmbus_rpc::actor::prelude::*` or `wasmbus_rpc::provider::prelude::*`, respectively. diff --git a/rpc-rs/Cargo.toml b/rpc-rs/Cargo.toml index d2160d4..bd9ed54 100644 --- a/rpc-rs/Cargo.toml +++ b/rpc-rs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wasmbus-rpc" -version = "0.8.5" +version = "0.9.0-alpha.1.1" authors = [ "wasmcloud Team" ] license = "Apache-2.0" description = "Runtime library for actors and capability providers" @@ -14,20 +14,20 @@ edition = "2021" exclude = [ "build.rs" ] [features] -default = [ "chunkify" ] +default = [ ] BigInteger = [ "num-bigint" ] BigDecimal = [ "bigdecimal" ] -# non-working feature - used for internal testing of async-rewrite branch of nats-io/nats.rs -#async_rewrite = [ "nats-experimental" ] -chunkify = [ "nats" ] -otel = ["opentelemetry", "tracing-opentelemetry", "nats", "opentelemetry-otlp"] +prometheus = [ "dep:prometheus" ] +otel = ["opentelemetry", "tracing-opentelemetry", "opentelemetry-otlp"] [dependencies] +anyhow = "1.0.57" async-trait = "0.1" base64 = "0.13" +bytes = "1.1.0" cfg-if = "1.0" -minicbor = { version = "0.13", features = ["std", "partial-skip-support"] } -rmp-serde = { version = "0.15.4" } +minicbor = { version = "0.17.1", features = ["std"] } +rmp-serde = "1.1.0" serde_bytes = "0.11" serde_json = "1.0" serde = { version = "1.0", features = ["derive"] } @@ -40,21 +40,21 @@ tracing-futures = "0.2" wasmbus-macros = { path = "../macros", version = "0.1.8" } minicbor-ser = "0.1.2" -#BigInteger support +# BigInteger support num-bigint = { version = "0.4", optional = true } -#BigDecimal support +# BigDecimal support bigdecimal = { version = "0.3", optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio = { version = "1", features = ["full"] } futures = "0.3" -nats-aflowt = "0.16.104" -nats = { version = "0.20", optional = true } +async-nats = "0.15.0" +nats = "0.20.1" nkeys = "0.2" once_cell = "1.8" -uuid = { version = "0.8", features = ["v4", "serde"] } +uuid = { version = "1.0", features = ["v4", "serde"] } wascap = "0.8.0" -ring = "0.16" +sha2 = "0.10.2" data-encoding = "2.3" tracing-subscriber = { version = "0.3.7", features = ["env-filter", "json"] } atty = "0.2" @@ -63,9 +63,11 @@ tracing-opentelemetry = { version = "0.17", optional = true } lazy_static = "1.4" opentelemetry-otlp = { version = "0.10", features = ["http-proto", "reqwest-client"], optional = true } +prometheus = { version = "0.13", optional = true } + [dev-dependencies] regex = "1" -env_logger = "0.9.0" +clap = { version = "3.2.5", features = ["derive"] } [build-dependencies] weld-codegen = { version = "0.4.3", path = "../codegen" } diff --git a/rpc-rs/Makefile b/rpc-rs/Makefile index 1a189ef..e18bf38 100644 --- a/rpc-rs/Makefile +++ b/rpc-rs/Makefile @@ -1,23 +1,16 @@ -# rpc-rs/Makefile +# weld/rpc-rs/Makefile all: build -build: - cargo build +build clean update: + cargo $@ release: cargo build --release -clean: - cargo clean - -# Run lint check on all smithy models in the models/smithy folder -lint: - $(WELD) lint - -# Run validation checks on all smithy models in the models/smithy folder -validate: - $(WELD) validate +# Run lint check on all smithy models +lint validate: + wash $@ ifeq ($(shell nc -czt -w1 127.0.0.1 4222 || echo fail),fail) test:: @@ -33,4 +26,4 @@ test:: rustfmt --edition 2021 --check src/*.rs endif -.PHONY: all build release clean lint validate +.PHONY: all build release clean lint validate test update rust-check diff --git a/rpc-rs/build.rs b/rpc-rs/build.rs deleted file mode 100644 index 6552653..0000000 --- a/rpc-rs/build.rs +++ /dev/null @@ -1,9 +0,0 @@ -// build.rs - build smithy models into rust sources at compile tile - -// path to codegen.toml relative to location of Cargo.toml -const CONFIG: &str = "./codegen.toml"; - -fn main() -> Result<(), Box> { - weld_codegen::rust_build(CONFIG)?; - Ok(()) -} diff --git a/rpc-rs/codegen.toml b/rpc-rs/codegen.toml index 6d00ce0..80e67da 100644 --- a/rpc-rs/codegen.toml +++ b/rpc-rs/codegen.toml @@ -2,11 +2,11 @@ [[models]] # wasmbus-core -url = "https://cdn.jsdelivr.net/gh/wasmcloud/interfaces@e69a8b22ce46457dbc48f8bbffd4734b68b739ce/core/wasmcloud-core.smithy" +url = "https://cdn.jsdelivr.net/gh/wasmcloud/interfaces@2c89bb706c6f2785a926dcde78ebc6a511a33206/core/wasmcloud-core.smithy" [[models]] # wasmbus-model -url = "https://cdn.jsdelivr.net/gh/wasmcloud/interfaces@e0f205da8a0e1549497571c3e994a1851480621c/core/wasmcloud-model.smithy" +url = "https://cdn.jsdelivr.net/gh/wasmcloud/interfaces/core/wasmcloud-model.smithy" [rust] output_dir = "." diff --git a/rpc-rs/examples/README.md b/rpc-rs/examples/README.md new file mode 100644 index 0000000..3f37c05 --- /dev/null +++ b/rpc-rs/examples/README.md @@ -0,0 +1,13 @@ + + +## Request + +```shell +cargo run --example request -- -d "data" -t 2000 "subject" +``` + +## Subscribe +```shell +cargo run --example sub -- "subject" +``` + diff --git a/rpc-rs/examples/request.rs b/rpc-rs/examples/request.rs new file mode 100644 index 0000000..05e18fb --- /dev/null +++ b/rpc-rs/examples/request.rs @@ -0,0 +1,63 @@ +use anyhow::{anyhow, Result}; +use clap::Parser; +use nkeys::KeyPairType; +use std::path::PathBuf; +use std::sync::Arc; +use wascap::prelude::KeyPair; +use wasmbus_rpc::rpc_client::RpcClient; + +/// RpcClient test CLI for making nats request +#[derive(Parser)] +#[clap(version, about, long_about = None)] +struct Args { + /// Nats uri. Defaults to 'nats://127.0.0.1:4222' + #[clap(short, long)] + nats: Option, + + /// File source for payload + #[clap(short, long)] + file: Option, + + /// Raw data for payload (as string) + #[clap(short, long)] + data: Option, + + /// Optional timeout in milliseconds + #[clap(short, long)] + timeout_ms: Option, + + /// Subject (topic) + #[clap(value_parser)] + subject: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + let data = match (args.data, args.file) { + (Some(d), None) => d.as_bytes().to_vec(), + (None, Some(f)) => { + if !f.is_file() { + return Err(anyhow!("missing data file {}", f.display())); + } + std::fs::read(&f) + .map_err(|e| anyhow!("error reading data source (path={}): {}", f.display(), e))? + } + _ => { + return Err(anyhow!("please specify --file or --data for data source")); + } + }; + if args.subject.is_empty() { + return Err(anyhow!("subject may not be empty")); + } + + let timeout = args.timeout_ms.map(|n| std::time::Duration::from_millis(n as u64)); + let kp = Arc::new(KeyPair::new(KeyPairType::User)); + let nats_uri = args.nats.unwrap_or_else(|| "nats://127.0.0.1:4222".to_string()); + let nc = async_nats::connect(&nats_uri).await?; + let client = RpcClient::new(nc, "HOST".into(), timeout, kp); + + let resp = client.request(args.subject, data).await?; + println!("{}", String::from_utf8_lossy(&resp)); + Ok(()) +} diff --git a/rpc-rs/examples/sub.rs b/rpc-rs/examples/sub.rs new file mode 100644 index 0000000..e411cc1 --- /dev/null +++ b/rpc-rs/examples/sub.rs @@ -0,0 +1,40 @@ +use anyhow::{anyhow, Result}; +use clap::Parser; +use futures::StreamExt; +use nkeys::KeyPairType; +use std::sync::Arc; +use wascap::prelude::KeyPair; +use wasmbus_rpc::rpc_client::RpcClient; + +/// RpcClient test CLI for connection and subscription +#[derive(Parser)] +#[clap(version, about, long_about = None)] +struct Args { + /// Nats uri. Defaults to 'nats://127.0.0.1:4222' + #[clap(short, long)] + nats: Option, + + /// Subject (topic) + #[clap(value_parser)] + subject: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + if args.subject.is_empty() { + return Err(anyhow!("subject may not be empty")); + } + let kp = Arc::new(KeyPair::new(KeyPairType::User)); + let nats_uri = args.nats.unwrap_or_else(|| "nats://127.0.0.1:4222".to_string()); + let nc = async_nats::connect(&nats_uri).await?; + let client = RpcClient::new(nc, "HOST".into(), None, kp); + + println!("Subscribing to {}", &args.subject); + + let mut sub = client.client().subscribe(args.subject).await?; + while let Some(msg) = sub.next().await { + println!("{}", String::from_utf8_lossy(&msg.payload)); + } + Ok(()) +} diff --git a/rpc-rs/src/actor_wasm.rs b/rpc-rs/src/actor_wasm.rs index 101b1fd..829a56e 100644 --- a/rpc-rs/src/actor_wasm.rs +++ b/rpc-rs/src/actor_wasm.rs @@ -30,12 +30,7 @@ extern "C" { } /// The function through which all host calls (from actors) take place. -pub fn host_call( - binding: &str, - ns: &str, - op: &str, - msg: &[u8], -) -> crate::error::RpcResult> { +pub fn host_call(binding: &str, ns: &str, op: &str, msg: &[u8]) -> RpcResult> { let callresult = unsafe { __host_call( binding.as_ptr() as _, @@ -104,7 +99,7 @@ impl crate::common::Transport for WasmHost { _ctx: &crate::common::Context, req: Message<'_>, _opts: Option, - ) -> std::result::Result, RpcError> { + ) -> Result, RpcError> { let res = if !self.target.public_key.is_empty() { // actor-to-actor calls use namespace for the actor target identifier host_call("", &self.target.public_key, req.method, req.arg.as_ref())? diff --git a/rpc-rs/src/cbor.rs b/rpc-rs/src/cbor.rs index 040c306..5669891 100644 --- a/rpc-rs/src/cbor.rs +++ b/rpc-rs/src/cbor.rs @@ -176,8 +176,8 @@ pub trait Decode<'b>: Sized { } } -pub trait MDecodeOwned: for<'de> crate::minicbor::Decode<'de> {} -impl MDecodeOwned for T where T: for<'de> crate::minicbor::Decode<'de> {} +pub trait MDecodeOwned: for<'de> crate::minicbor::Decode<'de, C> {} +impl MDecodeOwned for T where T: for<'de> crate::minicbor::Decode<'de, C> {} pub trait DecodeOwned: for<'de> crate::cbor::Decode<'de> {} impl DecodeOwned for T where T: for<'de> crate::cbor::Decode<'de> {} @@ -203,9 +203,10 @@ pub fn vec_encoder(header: bool) -> Encoder> { } /// A non-allocating CBOR encoder -impl Encoder +impl Encoder where - RpcError: From::Error>>, + W::Error: std::fmt::Display, + W: Write, { /// Constructs an Encoder around the writer pub fn new(writer: W) -> Self { @@ -348,7 +349,7 @@ where /// Returns the inner writer pub fn into_inner(self) -> W { - self.inner.into_inner() + self.inner.into_writer() } /// Write a tag @@ -383,6 +384,7 @@ pub enum Type { I16, I32, I64, + Int, F16, F32, F64, @@ -415,6 +417,7 @@ impl From for Type { MT::I16 => Type::I16, MT::I32 => Type::I32, MT::I64 => Type::I64, + MT::Int => Type::Int, MT::F16 => Type::F16, MT::F32 => Type::F32, MT::F64 => Type::F64, @@ -448,6 +451,7 @@ impl std::fmt::Display for Type { Type::I16 => f.write_str("i16"), Type::I32 => f.write_str("i32"), Type::I64 => f.write_str("i64"), + Type::Int => f.write_str("int"), Type::F16 => f.write_str("f16"), Type::F32 => f.write_str("f32"), Type::F64 => f.write_str("f64"), diff --git a/rpc-rs/src/chunkify.rs b/rpc-rs/src/chunkify.rs index 61f150d..ffb399b 100644 --- a/rpc-rs/src/chunkify.rs +++ b/rpc-rs/src/chunkify.rs @@ -1,4 +1,4 @@ -#![cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] +#![cfg(not(target_arch = "wasm32"))] // You strike me as a message that has never been chunkified // I'm sure I don't know what you mean, You forget yourself @@ -119,11 +119,7 @@ impl ChunkEndpoint { } /// chunkify a portion of a response - pub fn chunkify_response( - &self, - inv_id: &str, - bytes: &mut impl std::io::Read, - ) -> Result<(), RpcError> { + pub fn chunkify_response(&self, inv_id: &str, bytes: &mut impl Read) -> Result<(), RpcError> { self.chunkify(&format!("{}-r", inv_id), bytes) } diff --git a/rpc-rs/src/common.rs b/rpc-rs/src/common.rs index 0def68f..5cc8a01 100644 --- a/rpc-rs/src/common.rs +++ b/rpc-rs/src/common.rs @@ -61,7 +61,7 @@ pub trait Transport: Send { ctx: &Context, req: Message<'_>, opts: Option, - ) -> std::result::Result, RpcError>; + ) -> Result, RpcError>; /// Sets rpc timeout fn set_timeout(&self, interval: std::time::Duration); @@ -69,7 +69,7 @@ pub trait Transport: Send { // select serialization/deserialization mode pub fn deserialize<'de, T: Deserialize<'de>>(buf: &'de [u8]) -> Result { - rmp_serde::from_read_ref(buf).map_err(|e| RpcError::Deser(e.to_string())) + rmp_serde::from_slice(buf).map_err(|e| RpcError::Deser(e.to_string())) } pub fn serialize(data: &T) -> Result, RpcError> { @@ -88,7 +88,7 @@ pub trait MessageDispatch { } /// Message encoding format -#[derive(Clone, PartialEq)] +#[derive(Clone, Eq, PartialEq)] pub enum MessageFormat { Msgpack, Cbor, @@ -253,8 +253,8 @@ impl AnySender { } /// Send rpc with serializable payload using cbor encode/decode - pub async fn send_cbor<'de, In: crate::minicbor::Encode, Out: crate::cbor::MDecodeOwned>( - &self, + pub async fn send_cbor<'de, In: minicbor::Encode<()>, Out: crate::cbor::MDecodeOwned<()>>( + &mut self, ctx: &Context, method: &str, arg: &In, diff --git a/rpc-rs/src/document.rs b/rpc-rs/src/document.rs index f21c854..d904b9c 100644 --- a/rpc-rs/src/document.rs +++ b/rpc-rs/src/document.rs @@ -531,7 +531,7 @@ macro_rules! from_num_fn { } macro_rules! impl_try_from { - ($t:ty, $p: ident) => { + ($t:ty, $p:ident) => { impl TryFrom for $t { type Error = Document; @@ -561,7 +561,7 @@ macro_rules! impl_try_from_num { } macro_rules! impl_try_from_ref { - ($t:ty, $p: ident) => { + ($t:ty, $p:ident) => { impl<'v> TryFrom> for $t { type Error = DocumentRef<'v>; @@ -682,7 +682,10 @@ impl FromIterator for Document { pub fn encode_document( e: &mut crate::cbor::Encoder, val: &Document, -) -> RpcResult<()> { +) -> RpcResult<()> +where + ::Error: std::fmt::Display, +{ e.array(2)?; match val { Document::Object(map) => { @@ -728,7 +731,10 @@ pub fn encode_document( pub fn encode_document_ref<'v, W: crate::cbor::Write>( e: &mut crate::cbor::Encoder, val: &DocumentRef<'v>, -) -> RpcResult<()> { +) -> RpcResult<()> +where + ::Error: std::fmt::Display, +{ e.array(2)?; match val { DocumentRef::Object(map) => { @@ -775,7 +781,10 @@ pub fn encode_document_ref<'v, W: crate::cbor::Write>( pub fn encode_number( e: &mut crate::cbor::Encoder, val: &Number, -) -> RpcResult<()> { +) -> RpcResult<()> +where + ::Error: std::fmt::Display, +{ e.array(2)?; match val { Number::PosInt(val) => { @@ -804,7 +813,7 @@ pub fn decode_document(d: &mut crate::cbor::Decoder<'_>) -> RpcResult 0 => { // Object let map_len = d.fixed_map()? as usize; - let mut map = std::collections::HashMap::with_capacity(map_len); + let mut map = HashMap::with_capacity(map_len); for _ in 0..map_len { let k = d.str()?.to_string(); let v = decode_document(d)?; diff --git a/rpc-rs/src/error.rs b/rpc-rs/src/error.rs index cb635d5..207d41c 100644 --- a/rpc-rs/src/error.rs +++ b/rpc-rs/src/error.rs @@ -1,8 +1,6 @@ -use serde::{Deserialize, Serialize}; - /// An error that can occur in the processing of an RPC. This is not request-specific errors but /// rather cross-cutting errors that can always occur. -#[derive(thiserror::Error, Debug, PartialEq, Serialize, Deserialize)] +#[derive(thiserror::Error, Debug)] #[non_exhaustive] pub enum RpcError { /// The request exceeded its deadline. @@ -51,8 +49,6 @@ pub enum RpcError { #[error("timeout: {0}")] Timeout(String), - //#[error("IO error")] - //IO([from] std::io::Error) /// Anything else #[error("{0}")] Other(String), @@ -78,46 +74,14 @@ impl From for RpcError { } } -impl core::convert::From> for RpcError { - fn from(e: minicbor::encode::Error) -> Self { - let msg = match e { - minicbor::encode::Error::Write(_) => "writing to buffer", - minicbor::encode::Error::Message(s) => s, - _ => "unspecified encoding error", - } - .to_string(); - RpcError::Ser(format!("encode: {}", msg)) - } -} - -impl core::convert::From for RpcError { - fn from(e: minicbor::decode::Error) -> Self { - RpcError::Ser(format!("decode: {}", e)) +impl From> for RpcError { + fn from(e: minicbor::encode::Error) -> RpcError { + RpcError::Other(format!("encode: {}", e)) } } -/* -impl>> From> for RpcError { - fn from(_ee: Error<::Error>) -> RpcError { - RpcError::Other("help".to_string()) - } -} - */ - -/* -impl From> for RpcError { - fn from(e: minicbor::encode::Error) -> RpcError - where - W: minicbor::encode::Write, - { - RpcError::Ser( - match e { - minicbor::encode::Error::Write(_) => "writing to buffer", - minicbor::encode::Error::Message(s) => s, - _ => "unspecified encoding error", - } - .to_string(), - ) +impl From for RpcError { + fn from(e: minicbor::decode::Error) -> RpcError { + RpcError::Other(format!("decode: {}", e)) } } - */ diff --git a/rpc-rs/src/lib.rs b/rpc-rs/src/lib.rs index 5baa3e5..a951702 100644 --- a/rpc-rs/src/lib.rs +++ b/rpc-rs/src/lib.rs @@ -1,4 +1,4 @@ -//! Wasmcloud Weld runtime library +//! wasmcloud-rpc runtime library //! //! This crate provides code generation and runtime support for wasmcloud rpc messages //! used by [wasmcloud](https://wasmcloud.dev) actors and capability providers. @@ -7,31 +7,23 @@ mod timestamp; // re-export Timestamp pub use timestamp::Timestamp; +// re-export wascap crate +#[cfg(not(target_arch = "wasm32"))] +pub use wascap; -mod actor_wasm; -pub mod common; -#[cfg(feature = "otel")] +#[cfg(all(not(target_arch = "wasm32"), feature = "otel"))] +#[macro_use] pub mod otel; -pub mod provider; -pub(crate) mod provider_main; -mod wasmbus_model; -pub mod model { - // re-export model lib as "model" - pub use crate::wasmbus_model::*; - // declare unit type - pub type Unit = (); -} +mod actor_wasm; pub mod cbor; +pub mod common; pub(crate) mod document; pub mod error; +pub mod provider; +pub(crate) mod provider_main; +mod wasmbus_model; -// re-export nats-aflowt -#[cfg(not(target_arch = "wasm32"))] -pub use nats_aflowt as anats; - -/// This will be removed in a later version - use cbor instead to avoid dependence on minicbor crate -/// @deprecated pub use minicbor; #[cfg(not(target_arch = "wasm32"))] @@ -48,14 +40,21 @@ pub const WASMBUS_RPC_VERSION: u32 = 0; /// This crate's published version pub const WELD_CRATE_VERSION: &str = env!("CARGO_PKG_VERSION"); -pub type CallResult = std::result::Result, Box>; -pub type HandlerResult = std::result::Result>; +pub type CallResult = Result, Box>; +pub type HandlerResult = Result>; pub type TomlMap = toml::value::Map; -#[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] +#[cfg(not(target_arch = "wasm32"))] pub(crate) mod chunkify; mod wasmbus_core; +#[macro_use] + +pub mod model { + // re-export model lib as "model" + pub use crate::wasmbus_model::*; +} + pub mod core { // re-export core lib as "core" use crate::error::{RpcError, RpcResult}; @@ -75,21 +74,20 @@ pub mod core { } /// Connect to nats using options provided by host - pub async fn nats_connect(&self) -> RpcResult { + pub async fn nats_connect(&self) -> RpcResult { use std::str::FromStr as _; let nats_addr = if !self.lattice_rpc_url.is_empty() { self.lattice_rpc_url.as_str() } else { crate::provider::DEFAULT_NATS_ADDR }; - let nats_server = nats_aflowt::ServerAddress::from_str(nats_addr).map_err(|e| { + let nats_server = async_nats::ServerAddr::from_str(nats_addr).map_err(|e| { RpcError::InvalidParameter(format!("Invalid nats server url '{}': {}", nats_addr, e)) })?; // Connect to nats - let nc = nats_aflowt::Options::default() - .max_reconnects(None) - .connect(vec![nats_server]) + let nc = async_nats::ConnectOptions::default() + .connect(nats_server) .await .map_err(|e| { RpcError::ProviderInit(format!("nats connection to {} failed: {}", nats_addr, e)) @@ -143,17 +141,6 @@ pub mod core { }) } - /* - /// create provider entity from link definition - pub fn from_link(link: &LinkDefinition) -> Self { - WasmCloudEntity { - public_key: link.provider_id.clone(), - contract_id: link.contract_id.clone(), - link_name: link.link_name.clone(), - } - } - */ - /// constructor for capability provider entity /// all parameters are required pub fn new_provider( @@ -182,7 +169,7 @@ pub mod core { /// Returns URL of the entity pub fn url(&self) -> String { if self.public_key.to_uppercase().starts_with('M') { - format!("{}://{}", crate::core::URL_SCHEME, self.public_key) + format!("{}://{}", URL_SCHEME, self.public_key) } else { format!( "{}://{}/{}/{}", @@ -275,3 +262,88 @@ pub mod actor { } } } + +#[cfg(test)] +mod test { + use anyhow::anyhow; + + fn ret_rpc_err(val: u8) -> Result { + let x = match val { + 0 => Ok(0), + 10 | 11 => Err(crate::error::RpcError::Other(format!("rpc:{}", val))), + _ => Ok(255), + }?; + Ok(x) + } + + fn ret_any(val: u8) -> anyhow::Result { + let x = match val { + 0 => Ok(0), + 20 | 21 => Err(anyhow!("any:{}", val)), + _ => Ok(255), + }?; + Ok(x) + } + + fn either(val: u8) -> anyhow::Result { + let x = match val { + 0 => 0, + 10 | 11 => ret_rpc_err(val)?, + 20 | 21 => ret_any(val)?, + _ => 255, + }; + Ok(x) + } + + #[test] + fn values() { + use crate::error::RpcError; + + let v0 = ret_rpc_err(0); + assert_eq!(v0.ok().unwrap(), 0); + + let v10 = either(10); + assert!(v10.is_err()); + assert_eq!(v10.as_ref().err().unwrap().to_string().as_str(), "rpc:10"); + if let Err(e) = &v10 { + if let Some(rpc_err) = e.downcast_ref::() { + eprintln!("10 is rpc error (ok)"); + match rpc_err { + RpcError::Other(s) => { + eprintln!("RpcError::Other({})", s); + } + RpcError::Nats(s) => { + eprintln!("RpcError::Nats({})", s); + } + _ => { + eprintln!("RpcError::unknown {}", rpc_err); + } + } + } else { + eprintln!("10 is not rpc error. value={}", e); + } + } + + let v20 = either(20); + assert!(v20.is_err()); + assert_eq!(v20.as_ref().err().unwrap().to_string().as_str(), "any:20"); + if let Err(e) = &v20 { + if let Some(rpc_err) = e.downcast_ref::() { + eprintln!("20 is rpc error (ok)"); + match rpc_err { + RpcError::Other(s) => { + eprintln!("RpcError::Other({})", s); + } + RpcError::Nats(s) => { + eprintln!("RpcError::Nats({})", s); + } + _ => { + eprintln!("RpcError::unknown {}", rpc_err); + } + } + } else { + eprintln!("20 is not rpc error. value={}", e); + } + } + } +} diff --git a/rpc-rs/src/otel.rs b/rpc-rs/src/otel.rs index 8b88b16..c81fe57 100644 --- a/rpc-rs/src/otel.rs +++ b/rpc-rs/src/otel.rs @@ -2,14 +2,11 @@ //! wasmbus-rpc calls. Please note that right now this is only supported for providers. This module //! is only available with the `otel` feature enabled -use std::collections::HashSet; - -/// NOTE: The commented out code in here will work once we upgrade to async_nats. I have left it so that we don't have to rewrite it again -// use async_nats::header::HeaderName; -// use async_nats::HeaderMap; -use nats_aflowt::header::HeaderMap; -use opentelemetry::propagation::{Extractor, Injector, TextMapPropagator}; -use opentelemetry::sdk::propagation::TraceContextPropagator; +use async_nats::header::HeaderMap; +use opentelemetry::{ + propagation::{Extractor, Injector, TextMapPropagator}, + sdk::propagation::TraceContextPropagator, +}; use tracing::span::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -17,123 +14,6 @@ lazy_static::lazy_static! { static ref EMPTY_HEADERS: HeaderMap = HeaderMap::default(); } -// /// A convenience type that wraps a NATS [`HeaderMap`] and implements the [`Extractor`] trait -// #[derive(Debug)] -// pub struct OtelHeaderExtractor<'a> { -// inner: &'a HeaderMap, -// } - -// impl<'a> OtelHeaderExtractor<'a> { -// /// Creates a new extractor using the given [`HeaderMap`] -// pub fn new(headers: &'a HeaderMap) -> Self { -// OtelHeaderExtractor { inner: headers } -// } - -// /// Creates a new extractor using the given message -// pub fn new_from_message(msg: &'a async_nats::Message) -> Self { -// OtelHeaderExtractor { -// inner: msg.headers.as_ref().unwrap_or(&EMPTY_HEADERS), -// } -// } -// } - -// impl<'a> Extractor for OtelHeaderExtractor<'a> { -// fn get(&self, key: &str) -> Option<&str> { -// self.inner.get(key).and_then(|s| s.to_str().ok()) -// } - -// fn keys(&self) -> Vec<&str> { -// self.inner.keys().map(|s| s.as_str()).collect() -// } -// } - -// impl<'a> AsRef for OtelHeaderExtractor<'a> { -// fn as_ref(&self) -> &'a HeaderMap { -// self.inner -// } -// } - -// /// A convenience type that wraps a NATS [`HeaderMap`] and implements the [`Injector`] trait -// #[derive(Debug, Default)] -// pub struct OtelHeaderInjector { -// inner: HeaderMap, -// } - -// impl OtelHeaderInjector { -// /// Creates a new injector using the given [`HeaderMap`] -// pub fn new(headers: HeaderMap) -> Self { -// OtelHeaderInjector { inner: headers } -// } - -// /// Convenience constructor that returns a new injector with the current span context already -// /// injected into the given header map -// pub fn new_with_span(headers: HeaderMap) -> Self { -// let mut header_map = Self::new(headers); -// header_map.inject_context(); -// header_map -// } - -// /// Convenience constructor that returns a new injector with the current span context already -// /// injected into a default [`HeaderMap`] -// pub fn default_with_span() -> Self { -// let mut header_map = Self::default(); -// header_map.inject_context(); -// header_map -// } - -// /// Injects the current context from the span into the headers -// pub fn inject_context(&mut self) { -// let ctx_propagator = TraceContextPropagator::new(); -// ctx_propagator.inject_context(&Span::current().context(), self); -// } -// } - -// impl Injector for OtelHeaderInjector { -// fn set(&mut self, key: &str, value: String) { -// // NOTE: Because the underlying headers are an http header, we are going to escape any -// // unicode values and non-printable ASCII chars, which sounds better than just silently -// // ignoring or using an empty string. Unfortunately this adds an extra allocation that is -// // probably ok for now as it is freed at the end, but I prefer telemetry stuff to be as -// // little overhead as possible. If anyone has a better idea of how to handle this, please PR -// // it in -// let header_name = key.escape_default().to_string().into_bytes(); -// let escaped = value.escape_default().to_string().into_bytes(); -// // SAFETY: All chars escaped above -// self.inner.insert( -// HeaderName::from_bytes(&header_name).unwrap(), -// async_nats::HeaderValue::from_bytes(&escaped).unwrap(), -// ); -// } -// } - -// impl AsRef for OtelHeaderInjector { -// fn as_ref(&self) -> &HeaderMap { -// &self.inner -// } -// } - -// impl From for OtelHeaderInjector { -// fn from(headers: HeaderMap) -> Self { -// OtelHeaderInjector::new(headers) -// } -// } - -// impl From for HeaderMap { -// fn from(inj: OtelHeaderInjector) -> Self { -// inj.inner -// } -// } - -// /// A convenience function that will extract the current context from NATS message headers and set -// /// the parent span for the current tracing Span. If you want to do something more advanced, use the -// /// [`OtelHeaderExtractor`] type directly -// pub fn attach_span_context(msg: &async_nats::Message) { -// let header_map = OtelHeaderExtractor::new_from_message(msg); -// let ctx_propagator = TraceContextPropagator::new(); -// let parent_ctx = ctx_propagator.extract(&header_map); -// Span::current().set_parent(parent_ctx); -// } - /// A convenience type that wraps a NATS [`HeaderMap`] and implements the [`Extractor`] trait #[derive(Debug)] pub struct OtelHeaderExtractor<'a> { @@ -147,7 +27,7 @@ impl<'a> OtelHeaderExtractor<'a> { } /// Creates a new extractor using the given message - pub fn new_from_message(msg: &'a nats_aflowt::Message) -> Self { + pub fn new_from_message(msg: &'a async_nats::Message) -> Self { OtelHeaderExtractor { inner: msg.headers.as_ref().unwrap_or(&EMPTY_HEADERS), } @@ -156,8 +36,7 @@ impl<'a> OtelHeaderExtractor<'a> { impl<'a> Extractor for OtelHeaderExtractor<'a> { fn get(&self, key: &str) -> Option<&str> { - // This will just take the first element of the header if it exists - self.inner.get(key).and_then(|s| s.iter().next().map(|s| s.as_str())) + self.inner.get(key).and_then(|s| s.to_str().ok()) } fn keys(&self) -> Vec<&str> { @@ -208,9 +87,19 @@ impl OtelHeaderInjector { impl Injector for OtelHeaderInjector { fn set(&mut self, key: &str, value: String) { - let mut settified_value = HashSet::new(); - settified_value.insert(value); - self.inner.inner.insert(key.to_owned(), settified_value); + // NOTE: Because the underlying headers are an http header, we are going to escape any + // unicode values and non-printable ASCII chars, which sounds better than just silently + // ignoring or using an empty string. Unfortunately this adds an extra allocation that is + // probably ok for now as it is freed at the end, but I prefer telemetry stuff to be as + // little overhead as possible. If anyone has a better idea of how to handle this, please PR + // it in + let header_name = key.escape_default().to_string().into_bytes(); + let escaped = value.escape_default().to_string().into_bytes(); + // SAFETY: All chars escaped above + self.inner.insert( + async_nats::header::HeaderName::from_bytes(&header_name).unwrap(), + async_nats::HeaderValue::from_bytes(&escaped).unwrap(), + ); } } @@ -235,7 +124,7 @@ impl From for HeaderMap { /// A convenience function that will extract the current context from NATS message headers and set /// the parent span for the current tracing Span. If you want to do something more advanced, use the /// [`OtelHeaderExtractor`] type directly -pub fn attach_span_context(msg: &nats_aflowt::Message) { +pub fn attach_span_context(msg: &async_nats::Message) { let header_map = OtelHeaderExtractor::new_from_message(msg); let ctx_propagator = TraceContextPropagator::new(); let parent_ctx = ctx_propagator.extract(&header_map); diff --git a/rpc-rs/src/provider.rs b/rpc-rs/src/provider.rs index c71f600..d1776ca 100644 --- a/rpc-rs/src/provider.rs +++ b/rpc-rs/src/provider.rs @@ -6,21 +6,23 @@ use std::{ borrow::Cow, collections::HashMap, convert::Infallible, + fmt::Formatter, ops::Deref, sync::{Arc, Mutex as StdMutex}, time::Duration, }; +use crate::wascap::{ + jwt, + prelude::{Claims, KeyPair}, +}; use async_trait::async_trait; -use cfg_if::cfg_if; -use futures::future::JoinAll; +use futures::{future::JoinAll, StreamExt}; use serde::de::DeserializeOwned; -use tokio::sync::{oneshot, RwLock}; -use tracing::{debug, error, info, trace, warn}; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; use tracing_futures::Instrument; -#[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] -use crate::chunkify::chunkify_endpoint; pub use crate::rpc_client::make_uuid; use crate::{ common::{deserialize, serialize, Context, Message, MessageDispatch, SendOpts, Transport}, @@ -29,7 +31,7 @@ use crate::{ LinkDefinition, }, error::{RpcError, RpcResult}, - rpc_client::{NatsClientType, RpcClient, DEFAULT_RPC_TIMEOUT_MILLIS}, + rpc_client::{RpcClient, DEFAULT_RPC_TIMEOUT_MILLIS}, }; // name of nats queue group for rpc subscription @@ -41,7 +43,7 @@ pub(crate) const DEFAULT_NATS_ADDR: &str = "nats://127.0.0.1:4222"; pub type HostShutdownEvent = String; pub trait ProviderDispatch: MessageDispatch + ProviderHandler {} -trait ProviderImpl: ProviderDispatch + Send + Sync + Clone + 'static {} +trait ProviderImpl: ProviderDispatch + Send + Sync + 'static {} pub mod prelude { pub use crate::{ @@ -98,6 +100,38 @@ pub trait ProviderHandler: Sync { /// format of log message sent to main thread for output to logger pub type LogEntry = (tracing::Level, String); +pub type QuitSignal = tokio::sync::broadcast::Receiver; + +#[doc(hidden)] +/// Process subscription, until closed or exhausted, or value is received on the channel. +/// `sub` is a mutable Subscriber (regular or queue subscription) +/// `channel` may be either tokio mpsc::Receiver or broadcast::Receiver, and is considered signaled +/// when a value is sent or the chanel is closed. +/// `msg` is the variable name to be used in the handler +/// `on_item` is an async handler +macro_rules! process_until_quit { + ($sub:ident, $channel:ident, $msg:ident, $on_item:tt) => { + tokio::spawn(async move { + loop { + tokio::select! { + _ = $channel.recv() => { + let _ = $sub.unsubscribe().await; + break; + }, + __msg = $sub.next() => { + match __msg { + None => break, + Some($msg) => { + $on_item + } + } + } + } + } + }) + }; +} + /// HostBridge manages the NATS connection to the host, /// and processes subscriptions for links, health-checks, and rpc messages. /// Callbacks from HostBridge are implemented by the provider in the [[ProviderHandler]] implementation. @@ -105,22 +139,14 @@ pub type LogEntry = (tracing::Level, String); #[derive(Clone)] pub struct HostBridge { inner: Arc, + #[allow(dead_code)] + key: Arc, host_data: HostData, } impl HostBridge { - pub fn new(nats: crate::anats::Connection, host_data: &HostData) -> RpcResult { - Self::new_client(NatsClientType::Async(nats), host_data) - } - - #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] + #[cfg(not(target_arch = "wasm32"))] pub(crate) fn new_sync_client(&self) -> RpcResult { - //let key = if self.host_data.is_test() { - // wascap::prelude::KeyPair::new_user() - //} else { - // wascap::prelude::KeyPair::from_seed(&self.host_data.invocation_seed) - // .map_err(|e| RpcError::NotInitialized(format!("key failure: {}", e)))? - //}; let nats_addr = if !self.host_data.lattice_rpc_url.is_empty() { self.host_data.lattice_rpc_url.as_str() } else { @@ -146,28 +172,31 @@ impl HostBridge { }) } - pub(crate) fn new_client(nats: NatsClientType, host_data: &HostData) -> RpcResult { - let key = if host_data.is_test() { - wascap::prelude::KeyPair::new_user() + pub(crate) fn new_client( + nats: async_nats::Client, + host_data: &HostData, + ) -> RpcResult { + let key = Arc::new(if host_data.is_test() { + KeyPair::new_user() } else { - wascap::prelude::KeyPair::from_seed(&host_data.invocation_seed) + KeyPair::from_seed(&host_data.invocation_seed) .map_err(|e| RpcError::NotInitialized(format!("key failure: {}", e)))? - }; + }); + let rpc_client = RpcClient::new_client( nats, - &host_data.lattice_rpc_prefix, - key, host_data.host_id.clone(), host_data.default_rpc_timeout_ms.map(|ms| Duration::from_millis(ms as u64)), + key.clone(), ); Ok(HostBridge { inner: Arc::new(HostBridgeInner { - subs: RwLock::new(Vec::new()), links: RwLock::new(HashMap::new()), rpc_client, lattice_prefix: host_data.lattice_rpc_prefix.clone(), }), + key, host_data: host_data.clone(), }) } @@ -200,11 +229,8 @@ impl Deref for HostBridge { /// Initialize host bridge for use by wasmbus-test-util. /// The purpose is so that test code can get the nats configuration /// This is never called inside a provider process (and will fail if a provider calls it) -pub fn init_host_bridge_for_test( - nc: crate::anats::Connection, - host_data: &HostData, -) -> crate::error::RpcResult<()> { - let hb = HostBridge::new(nc, host_data)?; +pub fn init_host_bridge_for_test(nc: async_nats::Client, host_data: &HostData) -> RpcResult<()> { + let hb = HostBridge::new_client(nc, host_data)?; crate::provider_main::set_host_bridge(hb) .map_err(|_| RpcError::Other("HostBridge already initialized".to_string()))?; Ok(()) @@ -212,7 +238,6 @@ pub fn init_host_bridge_for_test( #[doc(hidden)] pub struct HostBridgeInner { - subs: RwLock>, /// Table of actors that are bound to this provider /// Key is actor_id / actor public key links: RwLock>, @@ -220,46 +245,31 @@ pub struct HostBridgeInner { lattice_prefix: String, } +impl std::fmt::Debug for HostBridge { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HostBridge") + .field("provider_id", &self.host_data.provider_key) + .field("host_id", &self.host_data.host_id) + .field("link", &self.host_data.link_name) + .field("lattice_prefix", &self.lattice_prefix) + .finish() + } +} + impl HostBridge { /// Returns a reference to the rpc client fn rpc_client(&self) -> &RpcClient { &self.rpc_client } - /// Clear out all subscriptions - async fn unsubscribe_all(&self) { - let mut copy = Vec::new(); - { - let mut sub_lock = self.subs.write().await; - copy.append(&mut sub_lock); - }; - // `drop`ping the Subscription doesn't close it - we need to unsubscribe - for sub in copy.into_iter() { - if let Err(e) = sub.close().await { - debug!(error = %e, "failure to unsubscribe during shutdown"); - } - } - debug!("unsubscribed from all subscriptions"); - } - - // add subscription so we can unsubscribe_all later - async fn add_subscription(&self, sub: crate::anats::Subscription) { - let mut sub_lock = self.subs.write().await; - sub_lock.push(sub); - } - // parse incoming subscription message // if it fails deserialization, we can't really respond; // so log the error - fn parse_msg( - &self, - msg: &crate::anats::Message, - topic: &str, - ) -> Option { + fn parse_msg(&self, msg: &async_nats::Message, topic: &str) -> Option { match if self.host_data.is_test() { - serde_json::from_slice(&msg.data).map_err(|e| RpcError::Deser(e.to_string())) + serde_json::from_slice(&msg.payload).map_err(|e| RpcError::Deser(e.to_string())) } else { - deserialize(&msg.data) + deserialize(&msg.payload) } { Ok(item) => Some(item), Err(e) => { @@ -294,309 +304,246 @@ impl HostBridge { } /// Implement subscriber listener threads and provider callbacks - pub async fn connect

( + pub(crate) async fn connect

( &'static self, provider: P, - shutdown_tx: oneshot::Sender, - ) -> RpcResult>>> + shutdown_tx: &tokio::sync::broadcast::Sender, + lattice: &str, + ) -> JoinAll>> where P: ProviderDispatch + Send + Sync + Clone + 'static, { - let join = futures::future::join_all(vec![ - tokio::task::spawn(self.subscribe_rpc(provider.clone())), - tokio::task::spawn(self.subscribe_link_put(provider.clone())), - tokio::task::spawn(self.subscribe_link_del(provider.clone())), - tokio::task::spawn(self.subscribe_shutdown(provider.clone(), shutdown_tx)), - tokio::task::spawn(self.subscribe_health(provider)), - ]); - Ok(join) + let lattice = lattice.to_string(); + futures::future::join_all(vec![ + tokio::task::spawn(self.subscribe_rpc( + provider.clone(), + shutdown_tx.subscribe(), + lattice, + )), + tokio::task::spawn(self.subscribe_link_put(provider.clone(), shutdown_tx.subscribe())), + tokio::task::spawn(self.subscribe_link_del(provider.clone(), shutdown_tx.subscribe())), + tokio::task::spawn(self.subscribe_shutdown(provider.clone(), shutdown_tx.clone())), + // subscribe to health last, after receivers are set up + tokio::task::spawn(self.subscribe_health(provider, shutdown_tx.subscribe())), + ]) } - async fn subscribe_rpc

(&self, provider: P) -> RpcResult<()> - where - P: ProviderDispatch + Send + Sync + Clone + 'static, - { - let rpc_topic = format!( + /// flush nats - called before main process exits + pub(crate) async fn flush(&self) { + if let Err(error) = self.inner.rpc_client.client().flush().await { + error!(%error, "flushing nats connection"); + } + } + + /// Returns the nats rpc topic for capability providers + pub fn provider_rpc_topic(&self) -> String { + format!( "wasmbus.rpc.{}.{}.{}", &self.lattice_prefix, &self.host_data.provider_key, self.host_data.link_name - ); + ) + } - debug!(%rpc_topic, "subscribing for rpc"); - let sub = self - .rpc_client() - .get_async() - .unwrap() // we are only async - .queue_subscribe(&rpc_topic, RPC_SUBSCRIPTION_QUEUE_GROUP) + /// Subscribe to a nats topic for rpc messages. + /// This method starts a separate async task and returns immediately. + /// It will exit if the nats client disconnects, or if a signal is received on the quit channel. + pub async fn subscribe_rpc

( + &self, + provider: P, + mut quit: QuitSignal, + lattice: String, + ) -> RpcResult<()> + where + P: ProviderDispatch + Send + Sync + Clone + 'static, + { + let mut sub = self + .rpc_client + .client() + .queue_subscribe( + self.provider_rpc_topic(), + RPC_SUBSCRIPTION_QUEUE_GROUP.to_string(), + ) .await .map_err(|e| RpcError::Nats(e.to_string()))?; - self.add_subscription(sub.clone()).await; let this = self.clone(); tokio::spawn(async move { - while let Some(msg) = sub.next().await { - let span = tracing::debug_span!("subscribe_rpc", %rpc_topic); - let _enter = span.enter(); - #[cfg(feature = "otel")] - crate::otel::attach_span_context(&msg); - match deserialize::(&msg.data) { - Ok(mut inv) => { - match this.dechunk_validate(&mut inv).in_current_span().await { - Ok(()) => { - let provider = provider.clone(); - let rpc_client = this.rpc_client().clone(); - // NOTE: This should just spawn with its own span as we aren't - // awaiting anything here. If for some reason there are some funky - // traces, we can try manually exiting the span here - tokio::task::spawn( - async move { - tracing::span::Span::current().record( - "operation", - &tracing::field::display(&inv.operation), - ); - tracing::span::Span::current().record( - "public_key", - &tracing::field::display(&inv.origin.public_key), - ); - trace!("Dispatching RPC invocation"); - let response = match provider - .dispatch( - &Context { - actor: Some(inv.origin.public_key.clone()), - ..Default::default() - }, - Message { - method: &inv.operation, - arg: Cow::from(inv.msg), - }, - ) - .await - { - Ok(msg) => InvocationResponse { - invocation_id: inv.id, - msg: msg.arg.to_vec(), + loop { + tokio::select! { + _ = quit.recv() => { + let _ = sub.unsubscribe().await; + break; + }, + nats_msg = sub.next() => { + let msg = match nats_msg { + None => break, + Some(msg) => msg + }; + let this = this.clone(); + let provider = provider.clone(); + let lattice = lattice.clone(); + tokio::spawn( async move { + let span = tracing::debug_span!("rpc"); + let _enter = span.enter(); + #[cfg(feature = "otel")] + crate::otel::attach_span_context(&msg); + match crate::common::deserialize::(&msg.payload) { + Ok(inv) => { + let inv_id = inv.id.clone(); + span.record("operation", &tracing::field::display(&inv.operation)); + span.record("lattice_id", &tracing::field::display(&lattice)); + span.record("actor_id", &tracing::field::display(&inv.origin)); + span.record("inv_id", &tracing::field::display(&inv.id)); + span.record("host_id", &tracing::field::display(&inv.host_id)); + span.record("provider_id", &tracing::field::display(&inv.target.public_key)); + span.record("contract_id", &tracing::field::display(&inv.target.contract_id)); + span.record("link_name", &tracing::field::display(&inv.target.link_name)); + span.record("payload_size", &tracing::field::display(&inv.content_length.unwrap_or_default())); + let provider = provider.clone(); + let resp = match this.handle_rpc(provider, inv).in_current_span().await { + Err(error) => { + error!( + %error, + "Invocation failed" + ); + InvocationResponse{ + invocation_id: inv_id, + error: Some(error.to_string()), ..Default::default() - }, - Err(e) => { - error!( - error = %e, - "RPC invocation failed", - ); - InvocationResponse { - invocation_id: inv.id, - error: Some(e.to_string()), - ..Default::default() - } } - }; - if let Some(reply_to) = msg.reply { - // Errors are published from inside the function, safe to ignore Result - let _ = publish_invocation_response( - &rpc_client, - reply_to, - response, - ) - .await; + }, + Ok(bytes) => { + InvocationResponse{ + invocation_id: inv_id, + content_length: Some(bytes.len() as u64), + msg: bytes, + ..Default::default() + } + } + }; + if let Some(reply) = msg.reply { + // send reply + if let Err(error) = this.rpc_client() + .publish_invocation_response(reply, resp, &lattice).in_current_span().await { + error!(%error, "rpc sending response"); } } - .instrument( - tracing::error_span!( - "invocation_dispatch", - operation = tracing::field::Empty, - public_key = tracing::field::Empty - ), - ), - ); - } - Err(s) => { - error!( - operation = %inv.operation, - public_key = %inv.origin.public_key, - invocation_id = %inv.id, - host_id = %inv.host_id, - error = %s, - "Invocation validation failure" - ); - - if let Some(reply_to) = msg.reply { - // Errors are published from inside the function, safe to ignore Result - let _ = publish_invocation_response( - this.rpc_client(), - reply_to, - InvocationResponse { - invocation_id: inv.id, - error: Some(s.to_string()), - ..Default::default() - }, - ) - .in_current_span() - .await; - } - } - } - } - Err(e) => { - error!(error = %e, "Invocation deserialization failure"); - if let Some(reply_to) = msg.reply { - if let Err(e) = publish_invocation_response( - this.rpc_client(), - reply_to, - InvocationResponse { - invocation_id: "invalid".to_string(), - error: Some(format!("Corrupt invocation: {}", e)), - ..Default::default() }, - ) - .in_current_span() - .await - { - error!(error = %e, "error when replying to rpc"); - } - } - } + Err(error) => { + error!(%error, "invalid rpc message received (not deserializable)"); + if let Some(reply) = msg.reply { + if let Err(e) = this.rpc_client().publish_invocation_response(reply, + InvocationResponse{ + error: Some(format!("deser error: {}", error)), + ..Default::default() + }, + &lattice + ).in_current_span().await { + error!(error = %e, "unable to publish error message to invocation response"); + } + } + } + }; + }); /* spawn */ + } /* next */ } - } + } /* loop */ }); Ok(()) } - async fn dechunk_validate(&self, inv: &mut Invocation) -> RpcResult<()> { - #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] - if inv.content_length.is_some() && inv.content_length.unwrap() > inv.msg.len() as u64 { - let inv_id = inv.id.clone(); - let lattice = self.rpc_client.lattice_prefix().to_string(); - inv.msg = tokio::task::spawn_blocking(move || { - let ce = chunkify_endpoint(None, lattice) - .map_err(|e| format!("connecting for de-chunkifying: {}", &e.to_string()))?; - ce.get_unchunkified(&inv_id).map_err(|e| e.to_string()) - }) - .await - .map_err(|je| format!("join/dechunk-validate: {}", je))??; + async fn handle_rpc

(&self, provider: P, inv: Invocation) -> Result, RpcError> + where + P: ProviderDispatch + Send + Sync + Clone + 'static, + { + let lattice = &self.host_data.lattice_rpc_prefix; + #[cfg(feature = "prometheus")] + { + if let Some(len) = inv.content_length { + self.rpc_client.stats.rpc_recv_bytes.inc_by(len); + } + self.rpc_client.stats.rpc_recv.inc(); } - self.validate_invocation(inv).await.map_err(RpcError::Rpc)?; - Ok(()) - } + let inv = self.rpc_client().dechunk(inv, lattice).await?; + let (inv, claims) = self.rpc_client.validate_invocation(inv).await?; + self.validate_provider_invocation(&inv, &claims).await?; + + let rc = provider + .dispatch( + &Context { + actor: Some(inv.origin.public_key.clone()), + ..Default::default() + }, + Message { + method: &inv.operation, + arg: Cow::from(inv.msg), + }, + ) + .instrument(tracing::debug_span!("dispatch", public_key = %inv.origin.public_key, operation = %inv.operation)) + .await + .map(|m| m.arg.to_vec()); - async fn validate_invocation(&self, inv: &Invocation) -> Result<(), String> { - let vr = wascap::jwt::validate_token::(&inv.encoded_claims) - .map_err(|e| format!("{}", e))?; - if vr.expired { - return Err("Invocation claims token expired".into()); - } - if !vr.signature_valid { - return Err("Invocation claims signature invalid".into()); - } - if vr.cannot_use_yet { - return Err("Attempt to use invocation before claims token allows".into()); - } - let target_url = format!("{}/{}", inv.target.url(), &inv.operation); - let hash = crate::rpc_client::invocation_hash( - &target_url, - &inv.origin.url(), - &inv.operation, - &inv.msg, - ); - let claims = - wascap::prelude::Claims::::decode(&inv.encoded_claims) - .map_err(|e| format!("{}", e))?; - let inv_claims = claims - .metadata - .ok_or_else(|| "No wascap metadata found on claims".to_string())?; - if inv_claims.invocation_hash != hash { - return Err(format!( - "Invocation hash does not match signed claims hash ({} / {})", - inv_claims.invocation_hash, hash - )); - } - if !inv.host_id.starts_with('N') && inv.host_id.len() != 56 { - return Err(format!("Invalid host ID on invocation: '{}'", inv.host_id)); - } - if !self.host_data.cluster_issuers.contains(&claims.issuer) { - return Err("Issuer of this invocation is not in list of cluster issuers".into()); - } - if inv_claims.target_url != target_url { - return Err(format!( - "Invocation claims and invocation target URL do not match: {} != {}", - &inv_claims.target_url, &target_url - )); - } - if inv_claims.origin_url != inv.origin.url() { - return Err("Invocation claims and invocation origin URL do not match".into()); - } - // verify target public key is my key - if inv.target.public_key != self.host_data.provider_key { - return Err(format!( - "target key mismatch: {} != {}", - &inv.target.public_key, &self.host_data.host_id - )); - } - // verify that the sending actor is linked with this provider - if !self.is_linked(&inv.origin.public_key).await { - return Err(format!("unlinked actor: '{}'", &inv.origin.public_key)); + #[cfg(feature = "prometheus")] + match &rc { + Err(_) => { + self.rpc_client.stats.rpc_recv_err.inc(); + } + Ok(vec) => { + self.rpc_client.stats.rpc_recv_resp_bytes.inc_by(vec.len() as u64); + } } - Ok(()) + rc } async fn subscribe_shutdown

( &self, provider: P, - shutdown_tx: oneshot::Sender, + shutdown_tx: tokio::sync::broadcast::Sender, ) -> RpcResult<()> where - P: ProviderDispatch + Send + Sync + Clone + 'static, + P: ProviderDispatch + Send + Sync + 'static, { let shutdown_topic = format!( "wasmbus.rpc.{}.{}.{}.shutdown", &self.lattice_prefix, &self.host_data.provider_key, self.host_data.link_name ); debug!("subscribing for shutdown : {}", &shutdown_topic); - let sub = self + let mut sub = self .rpc_client() - .get_async() - .unwrap() // we are only async - .subscribe(&shutdown_topic) + .client() + .subscribe(shutdown_topic) .await .map_err(|e| RpcError::Nats(e.to_string()))?; - // TODO: there should be validation on this message, but it's not signed by host yet let msg = sub.next().await; - + // TODO: there should be validation on this message, but it's not signed by host yet // Shutdown messages are unsigned (see https://github.com/wasmCloud/wasmcloud-otp/issues/256) // so we can't verify that this came from a trusted source. // When the above issue is fixed, verify the source and keep looping if it's invalid. - eprintln!("Received termination signal. Shutting down capability provider."); - debug!("Received termination signal. Shutting down capability provider."); - let (this, provider) = (self.clone(), provider.clone()); - if let Err(e) = tokio::spawn(async move { - // Tell provider to shutdown - before we shut down nats subscriptions, - // in case it needs to do any message passing during shutdown - if let Err(e) = provider.shutdown().await { - error!(error = %e, "got error during provider shutdown processing"); - } - // drain all subscriptions except this one - this.unsubscribe_all().await; - }) - .await - { - error!(error = %e, "joining thread shutdown/unsubscribe task"); + info!("Received termination signal. Shutting down capability provider."); + // Tell provider to shutdown - before we shut down nats subscriptions, + // in case it needs to do any message passing during shutdown + if let Err(e) = provider.shutdown().await { + error!(error = %e, "got error during provider shutdown processing"); } // send ack to host - if let Some(crate::anats::Message { reply: Some(reply_to), .. }) = msg.as_ref() { + if let Some(async_nats::Message { reply: Some(reply_to), .. }) = msg { let data = b"shutting down".to_vec(); - if let Err(e) = self.rpc_client().publish(reply_to, &data).await { + if let Err(e) = self.rpc_client().publish(reply_to, data).await { error!(error = %e, "failed to send shutdown ack"); } } - // unsubscribe from shutdown messages - let _ = sub.close().await; // ignore errors + // unsubscribe from shutdown topic + let _ = sub.unsubscribe().await; - // signal main thread to quit - if let Err(e) = shutdown_tx.send("bye".to_string()) { + // send shutdown signal to all listeners: quit all subscribers and signal main thread to quit + if let Err(e) = shutdown_tx.send(true) { error!(error = %e, "Problem shutting down: failure to send signal"); } + Ok(()) } - async fn subscribe_link_put

(&self, provider: P) -> RpcResult<()> + async fn subscribe_link_put

(&self, provider: P, mut quit: QuitSignal) -> RpcResult<()> where P: ProviderDispatch + Send + Sync + Clone + 'static, { @@ -605,55 +552,48 @@ impl HostBridge { &self.lattice_prefix, &self.host_data.provider_key, &self.host_data.link_name ); - debug!("subscribing for link put : {}", &ldput_topic); - let sub = self + let mut sub = self .rpc_client() - .get_async() - .unwrap() // we are only async - .subscribe(&ldput_topic) + .client() + .subscribe(ldput_topic) .await .map_err(|e| RpcError::Nats(e.to_string()))?; - self.add_subscription(sub.clone()).await; - //let provider = provider.clone(); let (this, provider) = (self.clone(), provider.clone()); - tokio::spawn(async move { - // TODO(ss): do we need to pin it with stream() before iterating? - while let Some(msg) = sub.next().await { - let span = tracing::error_span!( - "subscribe_link_put", - actor_id = tracing::field::Empty, - provider_id = tracing::field::Empty - ); - let _enter = span.enter(); - #[cfg(feature = "otel")] - crate::otel::attach_span_context(&msg); - if let Some(ld) = this.parse_msg::(&msg, "link.put") { - span.record("actor_id", &tracing::field::display(&ld.actor_id)); - span.record("provider_id", &tracing::field::display(&ld.provider_id)); - if this.is_linked(&ld.actor_id).in_current_span().await { - warn!("Ignoring duplicate link put"); - } else { - info!("Linking actor with provider"); - match provider.put_link(&ld).in_current_span().await { - Ok(true) => { - this.put_link(ld).in_current_span().await; - } - Ok(false) => { - // authorization failed or parameters were invalid - warn!("put_link denied"); - } - Err(e) => { - error!(error = %e, "put_link failed"); - } + process_until_quit!(sub, quit, msg, { + let span = tracing::error_span!( + "subscribe_link_put", + actor_id = tracing::field::Empty, + provider_id = tracing::field::Empty + ); + let _enter = span.enter(); + if let Some(ld) = this.parse_msg::(&msg, "link.put") { + span.record("actor_id", &tracing::field::display(&ld.actor_id)); + span.record("provider_id", &tracing::field::display(&ld.provider_id)); + span.record("contract_id", &tracing::field::display(&ld.contract_id)); + span.record("link_name", &tracing::field::display(&ld.link_name)); + if this.is_linked(&ld.actor_id).await { + warn!("Ignoring duplicate link put"); + } else { + info!("Linking actor with provider"); + match provider.put_link(&ld).await { + Ok(true) => { + this.put_link(ld).await; + } + Ok(false) => { + // authorization failed or parameters were invalid + warn!("put_link denied"); + } + Err(error) => { + error!(%error, "put_link failed"); } } } - } - }); + } // msg is "link.put" + }); // process until quit Ok(()) } - async fn subscribe_link_del

(&self, provider: P) -> RpcResult<()> + async fn subscribe_link_del

(&self, provider: P, mut quit: QuitSignal) -> RpcResult<()> where P: ProviderDispatch + Send + Sync + Clone + 'static, { @@ -663,150 +603,106 @@ impl HostBridge { &self.lattice_prefix, &self.host_data.provider_key, &self.host_data.link_name ); debug!(topic = %link_del_topic, "subscribing for link del"); - let sub = self + let mut sub = self .rpc_client() - .get_async() - .unwrap() // we are only async - .subscribe(&link_del_topic) + .client() + .subscribe(link_del_topic.clone()) .await .map_err(|e| RpcError::Nats(e.to_string()))?; - self.add_subscription(sub.clone()).await; let (this, provider) = (self.clone(), provider.clone()); - tokio::spawn(async move { - while let Some(msg) = sub.next().await { - let span = tracing::trace_span!("subscribe_link_del", topic = %link_del_topic); - let _enter = span.enter(); - if let Some(ld) = &this.parse_msg::(&msg, "link.del") { - this.delete_link(&ld.actor_id).in_current_span().await; - // notify provider that link is deleted - provider.delete_link(&ld.actor_id).in_current_span().await; - } + process_until_quit!(sub, quit, msg, { + let span = tracing::trace_span!("subscribe_link_del", topic = %link_del_topic); + let _enter = span.enter(); + if let Some(ld) = &this.parse_msg::(&msg, "link.del") { + this.delete_link(&ld.actor_id).await; + // notify provider that link is deleted + provider.delete_link(&ld.actor_id).await; } }); Ok(()) } - async fn subscribe_health

(&self, provider: P) -> RpcResult<()> + async fn subscribe_health

(&self, provider: P, mut quit: QuitSignal) -> RpcResult<()> where - P: ProviderDispatch + Send + Sync + Clone + 'static, + P: ProviderDispatch + Send + Sync + 'static, { let topic = format!( "wasmbus.rpc.{}.{}.{}.health", &self.lattice_prefix, &self.host_data.provider_key, &self.host_data.link_name ); - let sub = self + let mut sub = self .rpc_client() - .get_async() - .unwrap() // we are only async - .subscribe(&topic) + .client() + .subscribe(topic) .await .map_err(|e| RpcError::Nats(e.to_string()))?; - self.add_subscription(sub.clone()).await; let this = self.clone(); - tokio::spawn(async move { - while let Some(msg) = sub.next().await { - // placeholder arg - let arg = HealthCheckRequest {}; - let resp = match provider.health_request(&arg).await { - Ok(resp) => resp, - Err(e) => { - error!(error = %e, "error generating health check response"); - HealthCheckResponse { - healthy: false, - message: Some(e.to_string()), - } + process_until_quit!(sub, quit, msg, { + let arg = HealthCheckRequest {}; + let resp = match provider.health_request(&arg).await { + Ok(resp) => resp, + Err(e) => { + error!(error = %e, "error generating health check response"); + HealthCheckResponse { + healthy: false, + message: Some(e.to_string()), } - }; - let buf = if this.host_data.is_test() { - Ok(serde_json::to_vec(&resp).unwrap()) - } else { - serialize(&resp) - }; - match buf { - Ok(t) => { - if let Some(reply_to) = msg.reply.as_ref() { - if let Err(e) = this.rpc_client().publish(reply_to, &t).await { - error!(error = %e, "failed sending health check response"); - } + } + }; + let buf = if this.host_data.is_test() { + Ok(serde_json::to_vec(&resp).unwrap()) + } else { + serialize(&resp) + }; + match buf { + Ok(t) => { + if let Some(reply_to) = msg.reply { + if let Err(e) = this.rpc_client().publish(reply_to, t).await { + error!(error = %e, "failed sending health check response"); } } - Err(e) => { - // extremely unlikely that InvocationResponse would fail to serialize - error!(error = %e, "failed serializing HealthCheckResponse"); - } + } + Err(e) => { + // extremely unlikely that InvocationResponse would fail to serialize + error!(error = %e, "failed serializing HealthCheckResponse"); } } }); Ok(()) } -} -async fn publish_invocation_response( - rpc_client: &RpcClient, - reply_to: String, - response: InvocationResponse, -) -> Result<(), String> { - let content_length = Some(response.msg.len() as u64); - - let response = { - cfg_if! { - if #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] { - let inv_id = response.invocation_id.clone(); - if crate::chunkify::needs_chunking(response.msg.len()) { - let msg = response.msg; - let lattice = rpc_client.lattice_prefix().to_string(); - tokio::task::spawn_blocking(move || { - let ce = chunkify_endpoint(None, lattice) - .map_err(|e| format!("connecting for chunkifying: {}", &e.to_string()))?; - ce.chunkify_response(&inv_id, &mut msg.as_slice()) - .map_err(|e| e.to_string()) - }) - .await - .map_err(|je| format!("join/response-chunk: {}", je))??; - InvocationResponse { - msg: Vec::new(), - content_length, - ..response - } - } else { - InvocationResponse { - content_length, - ..response - } + /// extra validation performed by providers + async fn validate_provider_invocation( + &self, + inv: &Invocation, + claims: &Claims, + ) -> Result<(), String> { + if !self.host_data.cluster_issuers.contains(&claims.issuer) { + return Err("Issuer of this invocation is not in list of cluster issuers".into()); } - } else { - InvocationResponse { - content_length, - ..response - } - } - } - }; - - match serialize(&response) { - Ok(t) => { - if let Err(e) = rpc_client.publish(&reply_to, &t).await { - error!( - %reply_to, - error = %e, - "failed sending rpc response", - ); - } + // verify target public key is my key + if inv.target.public_key != self.host_data.provider_key { + return Err(format!( + "target key mismatch: {} != {}", + &inv.target.public_key, &self.host_data.host_id + )); } - Err(e) => { - // extremely unlikely that InvocationResponse would fail to serialize - error!(error = %e, "failed serializing InvocationResponse"); + + // verify that the sending actor is linked with this provider + if !self.is_linked(&inv.origin.public_key).await { + return Err(format!("unlinked actor: '{}'", &inv.origin.public_key)); } + + Ok(()) } - Ok(()) } pub struct ProviderTransport<'send> { pub bridge: &'send HostBridge, pub ld: &'send LinkDefinition, - timeout: StdMutex, + timeout: StdMutex, } impl<'send> ProviderTransport<'send> { @@ -822,7 +718,7 @@ impl<'send> ProviderTransport<'send> { pub fn new_with_timeout( ld: &'send LinkDefinition, bridge: Option<&'send HostBridge>, - timeout: Option, + timeout: Option, ) -> Self { #[allow(clippy::redundant_closure)] let bridge = bridge.unwrap_or_else(|| crate::provider_main::get_host_bridge()); @@ -860,7 +756,11 @@ impl<'send> Transport for ProviderTransport<'send> { .unwrap_or(DEFAULT_RPC_TIMEOUT_MILLIS) } }; - self.bridge.rpc_client().send_timeout(origin, target, req, timeout).await + let lattice = &self.bridge.lattice_prefix; + self.bridge + .rpc_client() + .send_timeout(origin, target, lattice, req, timeout) + .await } fn set_timeout(&self, interval: Duration) { diff --git a/rpc-rs/src/provider_main.rs b/rpc-rs/src/provider_main.rs index 04d4237..ca02e05 100644 --- a/rpc-rs/src/provider_main.rs +++ b/rpc-rs/src/provider_main.rs @@ -27,7 +27,6 @@ use crate::{ core::HostData, error::RpcError, provider::{HostBridge, ProviderDispatch}, - rpc_client::NatsClientType, }; lazy_static::lazy_static! { @@ -131,10 +130,11 @@ where { configure_tracing( friendly_name.unwrap_or_else(|| host_data.provider_key.clone()), - host_data.structured_logging_enabled, + host_data.structured_logging, ); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::broadcast::channel::(1); + eprintln!( "Starting capability provider {} instance {} with nats url {}", &host_data.provider_key, &host_data.instance_id, &host_data.lattice_rpc_url, @@ -145,45 +145,28 @@ where } else { crate::provider::DEFAULT_NATS_ADDR }; - let nats_server = nats_aflowt::ServerAddress::from_str(nats_addr).map_err(|e| { + let nats_server = async_nats::ServerAddr::from_str(nats_addr).map_err(|e| { RpcError::InvalidParameter(format!("Invalid nats server url '{}': {}", nats_addr, e)) })?; - let nc = { - cfg_if::cfg_if! { - if #[cfg(feature="async_rewrite")] { - - NatsClientType::AsyncRewrite(nats_experimental::connect(nats_addr).await - .map_err(|e| { - RpcError::ProviderInit(format!("nats connection to {} failed: {}", nats_addr, e)) - })?) - - } else { - let nats_opts = match ( - host_data.lattice_rpc_user_jwt.trim(), - host_data.lattice_rpc_user_seed.trim(), - ) { - ("", "") => nats_aflowt::Options::default(), - (rpc_jwt, rpc_seed) => { - let kp = nkeys::KeyPair::from_seed(rpc_seed).unwrap(); - let jwt = rpc_jwt.to_owned(); - nats_aflowt::Options::with_jwt( - move || Ok(jwt.to_owned()), - move |nonce| kp.sign(nonce).unwrap(), - ) - } - }; - // Connect to nats - NatsClientType::Async(nats_opts - .max_reconnects(None) - .connect(vec![nats_server]) - .await - .map_err(|e| { - RpcError::ProviderInit(format!("nats connection to {} failed: {}", nats_addr, e)) - })?) + let nc = crate::rpc_client::with_connection_event_logging( + match ( + host_data.lattice_rpc_user_jwt.trim(), + host_data.lattice_rpc_user_seed.trim(), + ) { + ("", "") => async_nats::ConnectOptions::default(), + (rpc_jwt, rpc_seed) => { + let key_pair = std::sync::Arc::new(nkeys::KeyPair::from_seed(rpc_seed).unwrap()); + let jwt = rpc_jwt.to_owned(); + async_nats::ConnectOptions::with_jwt(jwt, move |nonce| { + let key_pair = key_pair.clone(); + async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) } + }) } - } - }; + }, + ) + .connect(nats_server) + .await?; // initialize HostBridge let bridge = HostBridge::new_client(nc, &host_data)?; @@ -205,16 +188,22 @@ where } // subscribe to nats topics - let _join = bridge.connect(provider_dispatch, shutdown_tx).await.map_err(|e| { - RpcError::ProviderInit(format!("fatal error setting up subscriptions: {}", e)) - })?; + let _join = bridge + .connect( + provider_dispatch, + &shutdown_tx, + &host_data.lattice_rpc_prefix, + ) + .await; - // process subscription events and log messages, waiting for shutdown signal - let _ = shutdown_rx.await; + // run until we receive a shutdown request from host + let _ = shutdown_rx.recv().await; // close chunkifiers - #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] - crate::chunkify::shutdown(); + let _ = tokio::task::spawn_blocking(crate::chunkify::shutdown).await; + + // flush async_nats client + bridge.flush().await; Ok(()) } diff --git a/rpc-rs/src/rpc_client.rs b/rpc-rs/src/rpc_client.rs index 7fc6dc6..3b10b7b 100644 --- a/rpc-rs/src/rpc_client.rs +++ b/rpc-rs/src/rpc_client.rs @@ -2,21 +2,25 @@ use std::{ convert::{TryFrom, TryInto}, + fmt, sync::Arc, time::Duration, }; -use nats_aflowt::{header::HeaderMap, Connection}; -use ring::digest::{Context, SHA256}; +use async_nats::HeaderMap; +use futures::Future; +#[cfg(feature = "prometheus")] +use prometheus::{IntCounter, Opts}; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value as JsonValue; -use tracing::{debug, error, instrument, trace}; +use tracing::{debug, error, info, instrument, trace, warn}; -#[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] -use crate::chunkify::chunkify_endpoint; #[cfg(feature = "otel")] use crate::otel::OtelHeaderInjector; + +use crate::wascap::{jwt, prelude::Claims}; use crate::{ + chunkify, common::Message, core::{Invocation, InvocationResponse, WasmCloudEntity}, error::{RpcError, RpcResult}, @@ -42,26 +46,54 @@ pub(crate) const CHUNK_RPC_EXTRA_TIME: Duration = Duration::from_secs(13); /// #[derive(Clone)] pub struct RpcClient { - /// sync or async nats client - client: NatsClientType, - /// lattice rpc prefix - lattice_prefix: String, - /// secrets for signing invocations + client: async_nats::Client, key: Arc, - /// host id for invocations + /// host id (public key) for invocations host_id: String, /// timeout for rpc messages timeout: Option, + + #[cfg(feature = "prometheus")] + pub(crate) stats: Arc, } -#[derive(Clone)] -#[non_exhaustive] -pub(crate) enum NatsClientType { - Async(crate::anats::Connection), - #[cfg(feature = "async_rewrite")] - AsyncRewrite(nats_experimental::Client), - //#[cfg(feature = "chunkify")] - //Sync(nats::Connection), +// just so RpcClient can be included in other Debug structs +impl fmt::Debug for RpcClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("RpcClient()") + } +} + +#[cfg(feature = "prometheus")] +#[derive(Debug)] +pub struct RpcStats { + // number of rpc nats messages sent + pub(crate) rpc_sent: IntCounter, + // number of errors sending - of all types: including errors while receiving responses, and timeouts + pub(crate) rpc_sent_err: IntCounter, + // number of messages sent that required chunking + pub(crate) rpc_sent_chunky: IntCounter, + // number of responses received that were chunked + pub(crate) rpc_sent_resp_chunky: IntCounter, + // total bytes sent (chunked & not chunked). bytes are for sent requests. + pub(crate) rpc_sent_bytes: IntCounter, + // total bytes received in response to sent messages + pub(crate) rpc_sent_resp_bytes: IntCounter, + // number of timeout errors sending. Note that timeout errors are also included in rpc_sent_err + pub(crate) rpc_sent_timeouts: IntCounter, + + // number of rpc messages received from rpc subscription + pub(crate) rpc_recv: IntCounter, + // recv errors include errors receiving subscription messages and replying to them + pub(crate) rpc_recv_err: IntCounter, + // number of rpc messages received that were chunkified + pub(crate) rpc_recv_chunky: IntCounter, + // number of rpc message responses that were chunkified + pub(crate) rpc_recv_resp_chunky: IntCounter, + // bytes received in rpc (subscription) messages + pub(crate) rpc_recv_bytes: IntCounter, + // bytes sent in response to received rpc messages + pub(crate) rpc_recv_resp_bytes: IntCounter, } /// Returns the rpc topic (subject) name for sending to an actor or provider. @@ -83,69 +115,39 @@ pub fn rpc_topic(entity: &WasmCloudEntity, lattice_prefix: &str) -> String { impl RpcClient { /// Constructs a new RpcClient with an async nats connection. - /// parameters: async nats client, lattice rpc prefix (usually "default"), + /// parameters: async nats client, rpc timeout /// secret key for signing messages, host_id, and optional timeout. pub fn new( - nats: crate::anats::Connection, - lattice_prefix: &str, - key: wascap::prelude::KeyPair, + nats: async_nats::Client, host_id: String, timeout: Option, + key_pair: Arc, ) -> Self { - Self::new_client( - NatsClientType::Async(nats), - lattice_prefix, - key, - host_id, - timeout, - ) + Self::new_client(nats, host_id, timeout, key_pair) } /// Constructs a new RpcClient with a nats connection. /// parameters: nats client, lattice rpc prefix (usually "default"), /// secret key for signing messages, host_id, and optional timeout. pub(crate) fn new_client( - nats: NatsClientType, - lattice_prefix: &str, - key: wascap::prelude::KeyPair, + nats: async_nats::Client, host_id: String, timeout: Option, + key_pair: Arc, ) -> Self { RpcClient { client: nats, - lattice_prefix: lattice_prefix.to_string(), - key: Arc::new(key), host_id, timeout, + #[cfg(feature = "prometheus")] + stats: Arc::new(RpcStats::init(key_pair.public_key())), + key: key_pair, } } /// convenience method for returning async client - /// If the client is not the correct type, returns None - #[cfg(feature = "async_rewrite")] - pub fn get_async(&self) -> Option { - use std::borrow::Borrow; - match self.client.borrow() { - NatsClientType::AsyncRewrite(nc) => Some(nc.clone()), - _ => None, - } - } - - /// convenience method for returning async client - /// If the client is not the correct type, returns None - #[cfg(not(feature = "async_rewrite"))] - pub fn get_async(&self) -> Option { - use std::borrow::Borrow; - #[allow(unreachable_patterns)] - match self.client.borrow() { - NatsClientType::Async(nats) => Some(nats.clone()), - _ => None, - } - } - - /// returns the lattice prefix - pub fn lattice_prefix(&self) -> &str { - self.lattice_prefix.as_str() + pub fn client(&self) -> async_nats::Client { + self.client.clone() } /// Replace the default timeout with the specified value. @@ -159,6 +161,7 @@ impl RpcClient { &self, origin: WasmCloudEntity, target: Target, + lattice: &str, method: &str, data: JsonValue, ) -> RpcResult @@ -168,14 +171,13 @@ impl RpcClient { Target: Into, { let msg = JsonMessage(method, data).try_into()?; - let bytes = self.send(origin, target, msg).await?; + let bytes = self.send(origin, target, lattice, msg).await?; let resp = response_to_json::(&bytes)?; Ok(resp) } /// Send a wasmbus rpc message by wrapping with an Invocation before sending over nats. /// 'target' may be &str or String for sending to an actor, or a WasmCloudEntity (for actor or provider) - /// If nats client is sync, this can block the current thread. /// If a response is not received within the default timeout, the Error RpcError::Timeout is returned. /// If the client timeout has been set, this call is equivalent to send_timeout passing in the /// default timeout. @@ -183,31 +185,45 @@ impl RpcClient { &self, origin: WasmCloudEntity, target: Target, + lattice: &str, message: Message<'_>, ) -> RpcResult> where Target: Into, { - self.inner_rpc(origin, target, message, true, self.timeout).await + let rc = self.inner_rpc(origin, target, lattice, message, true, self.timeout).await; + #[cfg(feature = "prometheus")] + { + if rc.is_err() { + self.stats.rpc_sent_err.inc() + } + } + rc } /// Send a wasmbus rpc message, with a timeout. /// The rpc message is wrapped with an Invocation before sending over nats. /// 'target' may be &str or String for sending to an actor, or a WasmCloudEntity (for actor or provider) - /// If nats client is sync, this can block the current thread until either the response is received, - /// or the timeout expires. If the timeout expires before the response is received, - /// this returns Error RpcError::Timeout. + /// If the timeout expires before the response is received, this returns Error RpcError::Timeout. pub async fn send_timeout( &self, origin: WasmCloudEntity, target: Target, + lattice: &str, message: Message<'_>, timeout: Duration, ) -> RpcResult> where Target: Into, { - self.inner_rpc(origin, target, message, true, Some(timeout)).await + let rc = self.inner_rpc(origin, target, lattice, message, true, Some(timeout)).await; + #[cfg(feature = "prometheus")] + { + if rc.is_err() { + self.stats.rpc_sent_err.inc(); + } + } + rc } /// Send a wasmbus rpc message without waiting for response. @@ -221,21 +237,30 @@ impl RpcClient { &self, origin: WasmCloudEntity, target: Target, + lattice: &str, message: Message<'_>, ) -> RpcResult<()> where Target: Into, { - let _ = self.inner_rpc(origin, target, message, false, None).await?; - Ok(()) + let rc = self.inner_rpc(origin, target, lattice, message, false, None).await; + match rc { + Err(e) => { + #[cfg(feature = "prometheus")] + self.stats.rpc_sent_err.inc(); + Err(e) + } + Ok(_) => Ok(()), + } } /// request or publish an rpc invocation - #[instrument(level = "debug", skip(self, origin, target, message), fields(issuer = tracing::field::Empty, origin_url = tracing::field::Empty, subject = tracing::field::Empty, target_url = tracing::field::Empty, method = tracing::field::Empty))] + #[instrument(level = "debug", skip(self, origin, target, message), fields(issuer = tracing::field::Empty, origin_url = tracing::field::Empty, inv_id = tracing::field::Empty, target_url = tracing::field::Empty, method = tracing::field::Empty, provider_id = tracing::field::Empty))] async fn inner_rpc( &self, origin: WasmCloudEntity, target: Target, + lattice: &str, message: Message<'_>, expect_response: bool, timeout: Option, @@ -246,20 +271,22 @@ impl RpcClient { let target = target.into(); let origin_url = origin.url(); let subject = make_uuid(); - let issuer = &self.key.public_key(); + let issuer = self.key.public_key(); let raw_target_url = target.url(); let target_url = format!("{}/{}", raw_target_url, &message.method); // Record all of the fields on the span. To avoid extra allocations, we are only going to // record here after we generate/derive the values - let current_span = tracing::span::Span::current(); - current_span.record("issuer", &tracing::field::display(issuer)); - current_span.record("origin_url", &tracing::field::display(&origin_url)); - current_span.record("subject", &tracing::field::display(&subject)); - current_span.record("target_url", &tracing::field::display(&raw_target_url)); - current_span.record("method", &tracing::field::display(message.method)); - - let claims = wascap::prelude::Claims::::new( + let span = tracing::span::Span::current(); + span.record("provider_id", &tracing::field::display(&issuer)); + span.record("method", &tracing::field::display(&message.method)); + span.record("lattice_id", &tracing::field::display(&lattice)); + span.record("target_id", &tracing::field::display(&target.public_key)); + span.record("subject", &tracing::field::display(&subject)); + span.record("issuer", &tracing::field::display(&issuer)); + + //debug!("rpc_client sending"); + let claims = Claims::::new( issuer.clone(), subject.clone(), &target_url, @@ -267,27 +294,18 @@ impl RpcClient { &invocation_hash(&target_url, &origin_url, message.method, &message.arg), ); - let topic = rpc_topic(&target, &self.lattice_prefix); + let topic = rpc_topic(&target, lattice); let method = message.method.to_string(); let len = message.arg.len(); - let chunkify = { - cfg_if::cfg_if! { - if #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] { - crate::chunkify::needs_chunking(len) - } else { - false - } - } - }; + let chunkify = chunkify::needs_chunking(len); - #[allow(unused_variables)] let (invocation, body) = { let mut inv = Invocation { origin, target, operation: method.clone(), id: subject, - encoded_claims: claims.encode(&self.key).unwrap(), + encoded_claims: claims.encode(&self.key).unwrap_or_default(), host_id: self.host_id.clone(), content_length: Some(len as u64), ..Default::default() @@ -300,24 +318,26 @@ impl RpcClient { } }; let nats_body = crate::common::serialize(&invocation)?; - - #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] if let Some(body) = body { let inv_id = invocation.id.clone(); debug!(invocation_id = %inv_id, %len, "chunkifying invocation"); // start chunking thread - let lattice = self.lattice_prefix().to_string(); - tokio::task::spawn_blocking(move || { - let ce = chunkify_endpoint(None, lattice)?; + let lattice = lattice.to_string(); + if let Err(error) = tokio::task::spawn_blocking(move || { + let ce = chunkify::chunkify_endpoint(None, lattice)?; ce.chunkify(&inv_id, &mut body.as_slice()) }) + .await + .map_err(|join_e| RpcError::Other(join_e.to_string()))? + { + error!(%error, "chunking error"); + return Err(RpcError::Other(error.to_string())); + } // I tried starting the send to ObjectStore in background thread, // and then send the rpc, but if the objectstore hasn't completed, // the recipient gets a missing object error, // so we need to flush this first - .await // any errors sending chunks will cause send to fail with RpcError::Nats - .map_err(|e| RpcError::Other(e.to_string()))??; } let timeout = if chunkify { @@ -325,31 +345,39 @@ impl RpcClient { } else { timeout }; - trace!("rpc send"); + #[cfg(feature = "prometheus")] + { + self.stats.rpc_sent.inc(); + if let Some(len) = invocation.content_length { + self.stats.rpc_sent_bytes.inc_by(len); + } + if chunkify { + self.stats.rpc_sent_chunky.inc(); + } + } if expect_response { + let this = self.clone(); + let topic_ = topic.clone(); let payload = if let Some(timeout) = timeout { - // if we expect a response before timeout, finish sending all chunks, then wait for response - match tokio::time::timeout(timeout, self.request(&topic, &nats_body)).await { - Ok(Ok(result)) => Ok(result), - Ok(Err(rpc_err)) => Err(RpcError::Nats(format!( - "rpc send error: {}: {}", - target_url, rpc_err - ))), - Err(timeout_err) => { - error!(error = %timeout_err, "rpc timeout: sending to target"); - Err(RpcError::Timeout(format!( - "sending to {}: {}", - &target_url, timeout_err - ))) + match tokio::time::timeout(timeout, this.request(topic, nats_body)).await { + Err(elapsed) => { + #[cfg(feature = "prometheus")] + self.stats.rpc_sent_timeouts.inc(); + Err(RpcError::Timeout(elapsed.to_string())) } + Ok(Ok(data)) => Ok(data), + Ok(Err(err)) => Err(RpcError::Nats(err.to_string())), } } else { - // no timeout, wait indefinitely or until host times out - self.request(&topic, &nats_body) + this.request(topic, nats_body) .await - .map_err(|e| RpcError::Nats(format!("rpc send error: {}: {}", target_url, e))) - }?; + .map_err(|e| RpcError::Nats(e.to_string())) + } + .map_err(|error| { + error!(%error, topic=%topic_, "sending request"); + error + })?; let inv_response = crate::common::deserialize::(&payload).map_err(|e| { @@ -357,14 +385,23 @@ impl RpcClient { })?; match inv_response.error { None => { + #[cfg(feature = "prometheus")] + { + if let Some(len) = inv_response.content_length { + self.stats.rpc_sent_resp_bytes.inc_by(len); + } + } // was response chunked? - #[cfg(all(feature = "chunkify", not(target_arch = "wasm32")))] let msg = if inv_response.content_length.is_some() && inv_response.content_length.unwrap() > inv_response.msg.len() as u64 { - let lattice = self.lattice_prefix().to_string(); + let lattice = lattice.to_string(); + #[cfg(feature = "prometheus")] + { + self.stats.rpc_sent_resp_chunky.inc(); + } tokio::task::spawn_blocking(move || { - let ce = chunkify_endpoint(None, lattice)?; + let ce = chunkify::chunkify_endpoint(None, lattice)?; ce.get_unchunkified_response(&inv_response.invocation_id) }) .await @@ -372,8 +409,6 @@ impl RpcClient { } else { inv_response.msg }; - #[cfg(not(feature = "chunkify"))] - let msg = inv_response.msg; trace!("rpc ok response"); Ok(msg) } @@ -384,9 +419,9 @@ impl RpcClient { } } } else { - self.publish(&topic, &nats_body) + self.publish(topic, nats_body) .await - .map_err(|e| RpcError::Nats(format!("rpc send error: {}: {}", target_url, e)))?; + .map_err(|e| RpcError::Nats(format!("publish error: {}: {}", target_url, e)))?; Ok(Vec::new()) } } @@ -395,93 +430,221 @@ impl RpcClient { /// This can be used for general nats messages, not just wasmbus actor/provider messages. /// If this client has a default timeout, and a response is not received within /// the appropriate time, an error will be returned. - pub async fn request(&self, subject: &str, data: &[u8]) -> RpcResult> { - use std::borrow::Borrow as _; - - let bytes = match self.client.borrow() { - NatsClientType::Async(ref nats) => { - #[cfg(feature = "otel")] - let headers: Option = - Some(OtelHeaderInjector::default_with_span().into()); - #[cfg(not(feature = "otel"))] - let headers: Option = None; - let resp = request_with_headers_or_timeout( - nats, - subject, - headers.as_ref(), - self.timeout, - data, - ) - .await - .map_err(|e| RpcError::Nats(e.to_string()))?; - resp.data + #[instrument(level = "debug", skip_all, fields(subject = %subject))] + pub async fn request(&self, subject: String, payload: Vec) -> RpcResult> { + #[cfg(feature = "otel")] + let headers: Option = Some(OtelHeaderInjector::default_with_span().into()); + #[cfg(not(feature = "otel"))] + let headers: Option = None; + + let nc = self.client(); + match self + .maybe_timeout(self.timeout, async move { + if let Some(headers) = headers { + nc.request_with_headers(subject, headers, payload.into()).await + } else { + nc.request(subject, payload.into()).await + } + }) + .await + { + Err(error) => { + error!(%error, "sending request"); + Err(error) } - // These two never get invoked - #[cfg(feature = "async_rewrite")] - NatsClientType::AsyncRewrite(_) => unimplemented!(), - //NatsClientType::Sync(_) => unimplemented!(), - }; - Ok(bytes) + Ok(message) => Ok(message.payload.to_vec()), + } } /// Send a nats message with no reply-to. Do not wait for a response. /// This can be used for general nats messages, not just wasmbus actor/provider messages. - pub async fn publish(&self, subject: &str, data: &[u8]) -> RpcResult<()> { - use std::borrow::Borrow as _; - - match self.client.borrow() { - NatsClientType::Async(nats) => { - #[cfg(feature = "otel")] - let headers: Option = - Some(OtelHeaderInjector::default_with_span().into()); - #[cfg(not(feature = "otel"))] - let headers: Option = None; - nats.publish_with_reply_or_headers(subject, None, headers.as_ref(), data) + #[instrument(level = "debug", skip_all, fields(subject = %subject))] + pub async fn publish(&self, subject: String, payload: Vec) -> RpcResult<()> { + #[cfg(feature = "otel")] + let headers: Option = Some(OtelHeaderInjector::default_with_span().into()); + #[cfg(not(feature = "otel"))] + let headers: Option = None; + + let nc = self.client(); + self.maybe_timeout(self.timeout, async move { + if let Some(headers) = headers { + nc.publish_with_headers(subject, headers, payload.into()) + .await + .map_err(|e| RpcError::Nats(e.to_string())) + } else { + nc.publish(subject, payload.into()) .await - .map_err(|e| RpcError::Nats(e.to_string()))? + .map_err(|e| RpcError::Nats(e.to_string())) + } + }) + .await?; + Ok(()) + } + + pub async fn publish_invocation_response( + &self, + reply_to: String, + response: InvocationResponse, + lattice: &str, + ) -> Result<(), String> { + let content_length = Some(response.msg.len() as u64); + let response = { + let inv_id = response.invocation_id.clone(); + if chunkify::needs_chunking(response.msg.len()) { + #[cfg(feature = "prometheus")] + { + self.stats.rpc_recv_resp_chunky.inc(); + } + let msg = response.msg; + let lattice = lattice.to_string(); + tokio::task::spawn_blocking(move || { + let ce = chunkify::chunkify_endpoint(None, lattice) + .map_err(|e| format!("connecting for chunkifying: {}", &e.to_string()))?; + ce.chunkify_response(&inv_id, &mut msg.as_slice()) + .map_err(|e| e.to_string()) + }) + .await + .map_err(|je| format!("join/response-chunk: {}", je))??; + InvocationResponse { + msg: Vec::new(), + content_length, + ..response + } + } else { + InvocationResponse { content_length, ..response } + } + }; + + match crate::common::serialize(&response) { + Ok(t) => { + if let Err(e) = self.client().publish(reply_to.clone(), t.into()).await { + error!( + %reply_to, + error = %e, + "failed sending rpc response", + ); + } + } + Err(e) => { + // extremely unlikely that InvocationResponse would fail to serialize + error!(error = %e, "failed serializing InvocationResponse"); } - // These two never get invoked - #[cfg(feature = "async_rewrite")] - NatsClientType::AsyncRewrite(_) => unimplemented!(), } Ok(()) } -} -/// Copied straight from aflowt because the function isn't public -async fn request_with_headers_or_timeout( - nats: &Connection, - subject: &str, - maybe_headers: Option<&HeaderMap>, - maybe_timeout: Option, - msg: impl AsRef<[u8]>, -) -> std::io::Result { - // Publish a request. - let reply = nats.new_inbox(); - let sub = nats.subscribe(&reply).await?; - nats.publish_with_reply_or_headers(subject, Some(reply.as_str()), maybe_headers, msg) - .await?; + pub async fn dechunk(&self, mut inv: Invocation, lattice: &str) -> RpcResult { + if inv.content_length.is_some() && inv.content_length.unwrap() > inv.msg.len() as u64 { + #[cfg(feature = "prometheus")] + { + self.stats.rpc_recv_chunky.inc(); + } + let inv_id = inv.id.clone(); + let lattice = lattice.to_string(); + inv.msg = tokio::task::spawn_blocking(move || { + let ce = chunkify::chunkify_endpoint(None, lattice) + .map_err(|e| format!("connecting for de-chunkifying: {}", &e.to_string()))?; + ce.get_unchunkified(&inv_id).map_err(|e| e.to_string()) + }) + .await + .map_err(|je| format!("join/dechunk-validate: {}", je))??; + } + Ok(inv) + } - // Wait for the response - let result = if let Some(timeout) = maybe_timeout { - sub.next_timeout(timeout).await - } else if let Some(msg) = sub.next().await { - Ok(msg) - } else { - Err(std::io::ErrorKind::ConnectionReset.into()) - }; - - // Check for no responder status. - if let Ok(msg) = result.as_ref() { - if msg.is_no_responders() { - return Err(std::io::Error::new( - std::io::ErrorKind::NotFound, - "no responders", + /// Initial validation of received message. See provider::validate_provider_invocation for second part. + pub async fn validate_invocation( + &self, + inv: Invocation, + ) -> Result<(Invocation, Claims), String> { + let vr = jwt::validate_token::(&inv.encoded_claims) + .map_err(|e| format!("{}", e))?; + if vr.expired { + return Err("Invocation claims token expired".into()); + } + if !vr.signature_valid { + return Err("Invocation claims signature invalid".into()); + } + if vr.cannot_use_yet { + return Err("Attempt to use invocation before claims token allows".into()); + } + let target_url = format!("{}/{}", inv.target.url(), &inv.operation); + let hash = invocation_hash(&target_url, &inv.origin.url(), &inv.operation, &inv.msg); + let claims = + Claims::::decode(&inv.encoded_claims).map_err(|e| format!("{}", e))?; + let inv_claims = claims + .metadata + .as_ref() + .ok_or_else(|| "No wascap metadata found on claims".to_string())?; + if inv_claims.invocation_hash != hash { + return Err(format!( + "Invocation hash does not match signed claims hash ({} / {})", + inv_claims.invocation_hash, hash + )); + } + if !inv.host_id.starts_with('N') && inv.host_id.len() != 56 { + return Err(format!("Invalid host ID on invocation: '{}'", inv.host_id)); + } + + if inv_claims.target_url != target_url { + return Err(format!( + "Invocation claims and invocation target URL do not match: {} != {}", + &inv_claims.target_url, &target_url )); } + if inv_claims.origin_url != inv.origin.url() { + return Err("Invocation claims and invocation origin URL do not match".into()); + } + Ok((inv, claims)) } - result + /// Invoke future with optional timeout. This is to work around async_nats + /// not implementing request_with_timeout or publish_with_timeout anymore. + async fn maybe_timeout(&self, t: Option, f: F) -> RpcResult + where + F: Future> + Send + Sync + 'static, + T: 'static, + E: ToString, + { + if let Some(timeout) = t { + match tokio::time::timeout(timeout, f).await { + Err(elapsed) => { + #[cfg(feature = "prometheus")] + self.stats.rpc_sent_timeouts.inc(); + Err(RpcError::Timeout(elapsed.to_string())) + } + Ok(Ok(data)) => Ok(data), + Ok(Err(err)) => Err(RpcError::Nats(err.to_string())), + } + } else { + f.await.map_err(|e| RpcError::Nats(e.to_string())) + } + } +} + +/// helper method to add logging to a nats connection. Logs disconnection (warn level), reconnection (info level), error (error), and lame duck(warn) events. +pub fn with_connection_event_logging( + opts: async_nats::ConnectOptions, +) -> async_nats::ConnectOptions { + opts.disconnect_callback(|| async { + warn!("nats connection has disconnected. Attempting reconnection ..."); + }) + .reconnect_callback(|| async { info!("nats connection has been reestablished.") }) + .error_callback(|error| async move { error!(%error, "nats connection encountered error ") }) + .lame_duck_callback(|| async { warn!("nats connection has entered lame duck mode") }) + .ping_interval(Duration::from_secs(17)) +} + +#[derive(Clone)] +pub struct InvocationArg { + /// Sender of the message + pub origin: String, + + /// Method name, usually of the form Service.Method + pub operation: String, + + /// Message payload (could be empty array). May need to be serialized + pub arg: Vec, } pub(crate) fn invocation_hash( @@ -490,13 +653,15 @@ pub(crate) fn invocation_hash( method: &str, args: &[u8], ) -> String { - let mut context = Context::new(&SHA256); - context.update(origin_url.as_bytes()); - context.update(target_url.as_bytes()); - context.update(method.as_bytes()); - context.update(args); - let digest = context.finish(); - data_encoding::HEXUPPER.encode(digest.as_ref()) + use sha2::Digest as _; + + let mut hasher = sha2::Sha256::new(); + hasher.update(origin_url.as_bytes()); + hasher.update(target_url.as_bytes()); + hasher.update(method.as_bytes()); + hasher.update(args); + let digest = hasher.finalize(); + data_encoding::HEXUPPER.encode(digest.as_slice()) } /// Create a new random uuid for invocations. @@ -509,7 +674,7 @@ pub fn make_uuid() -> String { // uuid uses getrandom, which uses the operating system's RNG // as the source of random numbers. Uuid::new_v4() - .to_simple() + .as_simple() .encode_lower(&mut Uuid::encode_buffer()) .to_string() } @@ -551,3 +716,100 @@ where serde_json::to_value(crate::common::deserialize::(msg)?) .map_err(|e| RpcError::Ser(format!("response serialization : {}.", e))) } + +#[cfg(feature = "prometheus")] +impl RpcStats { + fn init(public_key: String) -> RpcStats { + let mut map = std::collections::HashMap::new(); + map.insert("public_key".to_string(), public_key); + + RpcStats { + rpc_sent: IntCounter::with_opts( + Opts::new("rpc_sent", "number of rpc nats messages sent").const_labels(map.clone()), + ) + .unwrap(), + rpc_sent_err: IntCounter::with_opts( + Opts::new("rpc_sent_err", "number of errors sending rpc").const_labels(map.clone()), + ) + .unwrap(), + rpc_sent_chunky: IntCounter::with_opts( + Opts::new( + "rpc_sent_chunky", + "number of rpc messages that were chunkified", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_sent_resp_chunky: IntCounter::with_opts( + Opts::new( + "rpc_sent_resp_chunky", + "number of responses to sent rpc that were chunkified", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_sent_bytes: IntCounter::with_opts( + Opts::new("rpc_sent_bytes", "total bytes sent in rpc requests") + .const_labels(map.clone()), + ) + .unwrap(), + rpc_sent_resp_bytes: IntCounter::with_opts( + Opts::new( + "rpc_sent_resp_bytes", + "total bytes sent in responses to incoming rpc", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_sent_timeouts: IntCounter::with_opts( + Opts::new( + "rpc_sent_timeouts", + "number of rpc messages that incurred timeout error", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_recv: IntCounter::with_opts( + Opts::new("rpc_recv", "number of rpc messages received").const_labels(map.clone()), + ) + .unwrap(), + rpc_recv_err: IntCounter::with_opts( + Opts::new( + "rpc_recv_err", + "number of errors encountered responding to incoming rpc", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_recv_chunky: IntCounter::with_opts( + Opts::new( + "rpc_recv_chunky", + "number of received rpc that were chunkified", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_recv_resp_chunky: IntCounter::with_opts( + Opts::new( + "rpc_recv_resp_chunky", + "number of chunkified responses to received rpc", + ) + .const_labels(map.clone()), + ) + .unwrap(), + rpc_recv_bytes: IntCounter::with_opts( + Opts::new("rpc_recv_bytes", "total bytes in received rpc") + .const_labels(map.clone()), + ) + .unwrap(), + rpc_recv_resp_bytes: IntCounter::with_opts( + Opts::new( + "rpc_recv_resp_bytes", + "total bytes in responses to incoming rpc", + ) + .const_labels(map.clone()), + ) + .unwrap(), + } + } +} diff --git a/rpc-rs/src/timestamp.rs b/rpc-rs/src/timestamp.rs index 26149e8..fb88eec 100644 --- a/rpc-rs/src/timestamp.rs +++ b/rpc-rs/src/timestamp.rs @@ -218,5 +218,5 @@ fn timestamp_ordering() { assert!(t1 > t4); // not equals - assert!(t1 != t4); + assert_ne!(t1, t4); } diff --git a/rpc-rs/src/wasmbus_core.rs b/rpc-rs/src/wasmbus_core.rs index 3f09c96..161e156 100644 --- a/rpc-rs/src/wasmbus_core.rs +++ b/rpc-rs/src/wasmbus_core.rs @@ -4,7 +4,7 @@ #[allow(unused_imports)] use crate::{ - //cbor::*, + cbor::*, common::{ deserialize, message_format, serialize, Context, Message, MessageDispatch, MessageFormat, SendOpts, Transport, @@ -325,8 +325,9 @@ pub struct HostData { /// Host-wide default RPC timeout for rpc messages, in milliseconds. Defaults to 2000. #[serde(default, skip_serializing_if = "Option::is_none")] pub default_rpc_timeout_ms: Option, + /// True if structured logging is enabled for the host. Providers should use the same setting as the host. #[serde(default)] - pub structured_logging_enabled: bool, + pub structured_logging: bool, } // Encode HostData as CBOR and append to output stream @@ -362,7 +363,7 @@ where } else { e.null()?; } - e.bool(val.structured_logging_enabled)?; + e.bool(val.structured_logging)?; Ok(()) } @@ -384,7 +385,7 @@ pub fn decode_host_data(d: &mut crate::cbor::Decoder<'_>) -> Result = None; let mut config_json: Option> = Some(None); let mut default_rpc_timeout_ms: Option> = Some(None); - let mut structured_logging_enabled: Option = None; + let mut structured_logging: Option = None; let is_array = match d.datatype()? { crate::cbor::Type::Array => true, @@ -439,7 +440,7 @@ pub fn decode_host_data(d: &mut crate::cbor::Decoder<'_>) -> Result structured_logging_enabled = Some(d.bool()?), + 14 => structured_logging = Some(d.bool()?), _ => d.skip()?, } } @@ -487,7 +488,7 @@ pub fn decode_host_data(d: &mut crate::cbor::Decoder<'_>) -> Result structured_logging_enabled = Some(d.bool()?), + "structuredLogging" => structured_logging = Some(d.bool()?), _ => d.skip()?, } } @@ -591,11 +592,11 @@ pub fn decode_host_data(d: &mut crate::cbor::Decoder<'_>) -> Result) -> Result { Ok(__result) } /// list of identifiers +/// This declaration supports code generations and is not part of an actor or provider sdk pub type IdentifierList = Vec; // Encode IdentifierList as CBOR and append to output stream diff --git a/rpc-rs/tests/nats_sub.rs b/rpc-rs/tests/nats_sub.rs index b11e7e8..af72094 100644 --- a/rpc-rs/tests/nats_sub.rs +++ b/rpc-rs/tests/nats_sub.rs @@ -1,95 +1,102 @@ //! test nats subscriptions (queue and non-queue) with rpc_client #![cfg(test)] -const THREE_SEC: Duration = Duration::from_secs(3); - -use std::{str::FromStr as _, time::Duration}; +use std::{str::FromStr, sync::Arc, time::Duration}; -use tracing::debug; +use tracing::{debug, error}; +use wascap::prelude::KeyPair; use wasmbus_rpc::{ error::{RpcError, RpcResult}, rpc_client::RpcClient, }; -//const DEFAULT_NATS_ADDR: &str = "nats://127.0.0.1:4222"; +const ONE_SEC: Duration = Duration::from_secs(1); +const THREE_SEC: Duration = Duration::from_secs(3); const TEST_NATS_ADDR: &str = "nats://127.0.0.1:4222"; -const LATTICE_PREFIX: &str = "test_nats_sub"; const HOST_ID: &str = "HOST_test_nats_sub"; +fn nats_url() -> String { + if let Ok(addr) = std::env::var("NATS_URL") { + addr + } else { + TEST_NATS_ADDR.to_string() + } +} + +fn is_demo() -> bool { + nats_url().contains("demo.nats.io") +} + /// create async nats client for test (sender or receiver) -async fn make_client() -> RpcResult { - let server_addr = wasmbus_rpc::anats::ServerAddress::from_str(TEST_NATS_ADDR).unwrap(); - let nc = wasmbus_rpc::anats::Options::default() - .max_reconnects(None) - .connect(vec![server_addr]) +/// Parameter is optional RPC timeout +async fn make_client(timeout: Option) -> RpcResult { + let nats_url = nats_url(); + let server_addr = async_nats::ServerAddr::from_str(&nats_url).unwrap(); + let nc = async_nats::ConnectOptions::default() + .connect(server_addr) .await .map_err(|e| { - RpcError::ProviderInit(format!( - "nats connection to {} failed: {}", - TEST_NATS_ADDR, e - )) + RpcError::ProviderInit(format!("nats connection to {} failed: {}", nats_url, e)) })?; - let kp = wascap::prelude::KeyPair::new_user(); - let client = RpcClient::new( - nc, - LATTICE_PREFIX, - kp, - HOST_ID.to_string(), - Some(Duration::from_secs(5)), - ); + + let key_pair = KeyPair::new_user(); + let client = RpcClient::new(nc, HOST_ID.to_string(), timeout, Arc::new(key_pair)); Ok(client) } async fn listen(client: RpcClient, subject: &str, pattern: &str) -> tokio::task::JoinHandle { + use futures::StreamExt; + let subject = subject.to_string(); let pattern = pattern.to_string(); - let nc = client.get_async().unwrap(); + let nc = client.client(); let pattern = regex::Regex::new(&pattern).unwrap(); - let sub = nc.subscribe(&subject).await.expect("subscriber"); + let mut sub = nc.subscribe(subject.clone()).await.expect("subscriber"); tokio::task::spawn(async move { let mut count: u64 = 0; while let Some(msg) = sub.next().await { - let payload = String::from_utf8_lossy(&msg.data); + let payload = String::from_utf8_lossy(&msg.payload); if !pattern.is_match(payload.as_ref()) && &payload != "exit" { - println!("ERROR: payload on {}: {}", &subject, &payload); + println!("ERROR: payload on {}: {}", subject, &payload); } if let Some(reply_to) = msg.reply { - client.publish(&reply_to, b"ok").await.expect("reply"); + client.publish(reply_to, b"ok".to_vec()).await.expect("reply"); } if payload == "exit" { break; } count += 1; } - println!("exiting: {}", count); - let _ = sub.close().await; + println!("received {} message(s)", count); count }) } async fn listen_bin(client: RpcClient, subject: &str) -> tokio::task::JoinHandle { + use futures::StreamExt; let subject = subject.to_string(); - let nc = client.get_async().unwrap(); + let nc = client.client(); - let sub = nc.subscribe(&subject).await.expect("subscriber"); + let mut sub = nc.subscribe(subject.clone()).await.expect("subscriber"); tokio::task::spawn(async move { let mut count: u64 = 0; - println!("listening subj: {}", &subject); while let Some(msg) = sub.next().await { - let size = msg.data.len(); + let size = msg.payload.len(); let response = format!("{}", size); if let Some(reply_to) = msg.reply { - client.publish(&reply_to, response.as_bytes()).await.expect("reply"); + if let Err(e) = nc.publish(reply_to, response.as_bytes().to_vec().into()).await { + error!("error publishing subscriber response: {}", e); + } } count += 1; if size == 1 { break; } } - let _ = sub.close().await; - println!("exiting: {}", count); + let _ = sub.unsubscribe().await; + debug!("listen_bin exiting with count {}", count); count }) } @@ -100,34 +107,36 @@ async fn listen_queue( queue: &str, pattern: &str, ) -> tokio::task::JoinHandle { + use futures::StreamExt; let subject = subject.to_string(); let queue = queue.to_string(); let pattern = pattern.to_string(); - let nc = client.get_async().unwrap(); + let nc = client.client(); tokio::task::spawn(async move { let mut count: u64 = 0; let pattern = regex::Regex::new(&pattern).unwrap(); - let sub = nc.queue_subscribe(&subject, &queue).await.expect("group subscriber"); - debug!("listening subj: {} queue: {}", &subject, &queue); + let mut sub = nc + .queue_subscribe(subject.clone(), queue.clone()) + .await + .expect("group subscriber"); while let Some(msg) = sub.next().await { - let payload = String::from_utf8_lossy(&msg.data); + let payload = String::from_utf8_lossy(&msg.payload); if !pattern.is_match(payload.as_ref()) && &payload != "exit" { debug!("ERROR: payload on {}: {}", &subject, &payload); break; } if let Some(reply_to) = msg.reply { debug!("listener {} replying ok", &subject); - client.publish(&reply_to, b"ok").await.expect("reply"); + client.publish(reply_to, b"ok".to_vec()).await.expect("reply"); } if &payload == "exit" { - debug!("listener {} received 'exit'", &subject); - //let _ = sub.close().await; + let _ = sub.unsubscribe().await; break; } count += 1; } - println!("listener {} exiting with count {}", &subject, count); + println!("subscriber '{}' exiting count={}", &subject, count); count }) } @@ -138,60 +147,57 @@ async fn simple_sub() -> Result<(), Box> { let sub_name = uuid::Uuid::new_v4().to_string(); let topic = format!("one_{}", &sub_name); - let l1 = listen(make_client().await?, &topic, "^abc").await; + let l1 = listen(make_client(None).await?, &topic, "^abc").await; - let sender = make_client().await.expect("creating sender"); - sender.publish(&topic, b"abc").await.expect("send"); - sender.publish(&topic, b"exit").await.expect("send"); + let sender = make_client(None).await.expect("creating sender"); + sender.publish(topic.clone(), b"abc".to_vec()).await.expect("send"); + sender.publish(topic, b"exit".to_vec()).await.expect("send"); let val = l1.await.expect("join"); assert_eq!(val, 1); Ok(()) } -/// send large messages to find size limits +/// send large messages - this uses request() and does not test chunking #[tokio::test] async fn test_message_size() -> Result<(), Box> { - if env_logger::try_init().is_err() {}; // create unique subscription name for this test let sub_name = uuid::Uuid::new_v4().to_string(); let topic = format!("bin_{}", &sub_name); - let l1 = listen_bin(make_client().await?, &topic).await; + let l1 = listen_bin(make_client(Some(THREE_SEC)).await?, &topic).await; let mut pass_count = 0; - let sender = make_client().await.expect("creating bin sender"); - const TEST_SIZES: &[u32] = &[ - 100, 200, - // NOTE: if using 'demo.nats.io' as the test server, - // don't abuse it by running this test - only use larger sizes - // if testing against a local nats server. + let sender = make_client(Some(THREE_SEC)).await.expect("creating bin sender"); + // messages sizes to test + let test_sizes = if is_demo() { + // if using 'demo.nats.io' as the test server, + // don't abuse it by running this test with very large sizes // - // 100_000, 200_000, 300_000, 400_000, 500_000, 600_000, 700_000, 800_000, 900_000, - //1_000_000, (1024 * 1024), - // The last size must be 1: signal to listen_bin to exit - 1, - ]; - for size in TEST_SIZES.iter() { + // The last size must be 1 to signal to listen_bin to exit + &[10u32, 25, 100, 200, 500, 1000, 1] + } else { + // The last size must be 1 to signal to listen_bin to exit + &[10u32, 25, 500, 10_000, 800_000, 1_000_000, 1] + }; + for size in test_sizes.iter() { let mut data = Vec::with_capacity(*size as usize); data.resize(*size as usize, 255u8); - let resp = - match tokio::time::timeout(Duration::from_millis(2000), sender.request(&topic, &data)) - .await - { - Ok(Ok(result)) => result, - Ok(Err(rpc_err)) => { - eprintln!("rpc send error on msg size {}: {}", *size, rpc_err); - continue; - } - Err(timeout_err) => { - eprintln!( - "rpc timeout: sending msg of size {}: {}", - *size, timeout_err - ); - continue; - } - }; + let resp = match tokio::time::timeout(THREE_SEC, sender.request(topic.clone(), data)).await + { + Ok(Ok(result)) => result, + Ok(Err(rpc_err)) => { + eprintln!("send error on msg size {}: {}", *size, rpc_err); + continue; + } + Err(timeout_err) => { + eprintln!( + "rpc timeout: sending msg of size {}: {}", + *size, timeout_err + ); + continue; + } + }; let sbody = String::from_utf8_lossy(&resp); let received_size = sbody.parse::().expect("response contains int size"); if *size == received_size { @@ -201,18 +207,18 @@ async fn test_message_size() -> Result<(), Box> { eprintln!("FAIL: message_size: {}, got: {}", size, received_size); } } - assert_eq!(pass_count, TEST_SIZES.len(), "some size tests did not pass"); + assert_eq!(pass_count, test_sizes.len(), "some size tests did not pass"); let val = l1.await.expect("join"); assert_eq!( val as usize, - TEST_SIZES.len(), + test_sizes.len(), "some messages were not received" ); Ok(()) } async fn sleep(millis: u64) { - tokio::time::sleep(tokio::time::Duration::from_millis(millis)).await; + tokio::time::sleep(Duration::from_millis(millis)).await; } fn check_ok(data: Vec) -> Result<(), RpcError> { @@ -236,34 +242,33 @@ async fn queue_sub() -> Result<(), Box> { // This confirms that publishing to queue subscription divides the load, // and also confirms that a queue group name ('X') is only applicable // within a topic. - let _ = env_logger::try_init(); let sub_name = uuid::Uuid::new_v4().to_string(); let topic_one = format!("one_{}", &sub_name); let topic_two = format!("two_{}", &sub_name); let queue_name = uuid::Uuid::new_v4().to_string(); - let thread1 = listen_queue(make_client().await?, &topic_one, &queue_name, "^one").await; - let thread2 = listen_queue(make_client().await?, &topic_one, &queue_name, "^one").await; - let thread3 = listen_queue(make_client().await?, &topic_two, &queue_name, "^two").await; - sleep(2000).await; + let thread1 = listen_queue(make_client(None).await?, &topic_one, &queue_name, "^one").await; + let thread2 = listen_queue(make_client(None).await?, &topic_one, &queue_name, "^one").await; + let thread3 = listen_queue(make_client(None).await?, &topic_two, &queue_name, "^two").await; + sleep(200).await; - let sender = make_client().await?; + let sender = make_client(None).await?; const SPLIT_TOTAL: usize = 6; const SINGLE_TOTAL: usize = 6; for _ in 0..SPLIT_TOTAL { - check_ok(sender.request(&topic_one, b"one").await?)?; + check_ok(sender.request(topic_one.clone(), b"one".to_vec()).await?)?; } for _ in 0..SINGLE_TOTAL { - check_ok(sender.request(&topic_two, b"two").await?)?; + check_ok(sender.request(topic_two.clone(), b"two".to_vec()).await?)?; } - check_ok(sender.request(&topic_one, b"exit").await?)?; - check_ok(sender.request(&topic_one, b"exit").await?)?; - check_ok(sender.request(&topic_two, b"exit").await?)?; + check_ok(sender.request(topic_one.clone(), b"exit".to_vec()).await?)?; + check_ok(sender.request(topic_one.clone(), b"exit".to_vec()).await?)?; + check_ok(sender.request(topic_two.clone(), b"exit".to_vec()).await?)?; - let v3 = wait_for(thread3, THREE_SEC).await??; - let v2 = wait_for(thread2, THREE_SEC).await??; - let v1 = wait_for(thread1, THREE_SEC).await??; + let v3 = wait_for(thread3, ONE_SEC).await??; + let v2 = wait_for(thread2, ONE_SEC).await??; + let v1 = wait_for(thread1, ONE_SEC).await??; assert_eq!(v1 + v2, SPLIT_TOTAL as u64, "no loss in queue"); assert_eq!(v3, SINGLE_TOTAL as u64, "no overlap between queues");