diff --git a/example/firmware/Cargo.lock b/example/firmware/Cargo.lock
index 47821f1..a7baf7c 100644
--- a/example/firmware/Cargo.lock
+++ b/example/firmware/Cargo.lock
@@ -1150,7 +1150,7 @@ dependencies = [
 
 [[package]]
 name = "postcard-rpc"
-version = "0.5.1"
+version = "0.5.3"
 dependencies = [
  "embassy-executor",
  "embassy-sync",
diff --git a/example/workbook-host/Cargo.lock b/example/workbook-host/Cargo.lock
index a89ad6c..56a3f18 100644
--- a/example/workbook-host/Cargo.lock
+++ b/example/workbook-host/Cargo.lock
@@ -328,9 +328,9 @@ dependencies = [
 
 [[package]]
 name = "maitake-sync"
-version = "0.1.1"
+version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "27ff6bc892d1b738a544d20599bce0a1446454edaa0338020a7d1b046d78a80f"
+checksum = "6816ab14147f80234c675b80ed6dc4f440d8a1cefc158e766067aedb84c0bcd5"
 dependencies = [
  "cordyceps",
  "loom 0.7.2",
@@ -486,7 +486,7 @@ dependencies = [
 
 [[package]]
 name = "postcard-rpc"
-version = "0.5.1"
+version = "0.5.3"
 dependencies = [
  "heapless 0.8.0",
  "maitake-sync",
diff --git a/example/workbook-host/src/bin/comms-01.rs b/example/workbook-host/src/bin/comms-01.rs
index c820cd5..24306d8 100644
--- a/example/workbook-host/src/bin/comms-01.rs
+++ b/example/workbook-host/src/bin/comms-01.rs
@@ -6,6 +6,18 @@ use tokio::time::interval;
 #[tokio::main]
 pub async fn main() {
     let client = WorkbookClient::new();
+
+    tokio::select! {
+        _ = client.wait_closed() => {
+            println!("Client is closed, exiting...");
+        }
+        _ = run(&client) => {
+            println!("App is done")
+        }
+    }
+}
+
+async fn run(client: &WorkbookClient) {
     let mut ticker = interval(Duration::from_millis(250));
 
     for i in 0..10 {
diff --git a/example/workbook-host/src/client.rs b/example/workbook-host/src/client.rs
index a44e5e3..3e8ded2 100644
--- a/example/workbook-host/src/client.rs
+++ b/example/workbook-host/src/client.rs
@@ -44,6 +44,10 @@ impl WorkbookClient {
         Self { client }
     }
 
+    pub async fn wait_closed(&self) {
+        self.client.wait_closed().await;
+    }
+
     pub async fn ping(&self, id: u32) -> Result> {
         let val = self.client.send_resp::(&id).await?;
         Ok(val)
diff --git a/source/postcard-rpc-test/Cargo.lock b/source/postcard-rpc-test/Cargo.lock
index 1f8a3c7..b99c0e0 100644
--- a/source/postcard-rpc-test/Cargo.lock
+++ b/source/postcard-rpc-test/Cargo.lock
@@ -248,9 +248,9 @@ dependencies = [
 
 [[package]]
 name = "maitake-sync"
-version = "0.1.0"
+version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "68d76dcfa3b14b75b60ff187f5df11c10fa76f227741a42070f3d36215756f24"
+checksum = "6816ab14147f80234c675b80ed6dc4f440d8a1cefc158e766067aedb84c0bcd5"
 dependencies = [
  "cordyceps",
  "loom 0.7.1",
@@ -379,7 +379,7 @@ dependencies = [
 
 [[package]]
 name = "postcard-rpc"
-version = "0.5.1"
+version = "0.5.3"
 dependencies = [
  "heapless 0.8.0",
  "maitake-sync",
diff --git a/source/postcard-rpc-test/src/lib.rs b/source/postcard-rpc-test/src/lib.rs
index 5e832dd..255535f 100644
--- a/source/postcard-rpc-test/src/lib.rs
+++ b/source/postcard-rpc-test/src/lib.rs
@@ -1,2 +1 @@
 // I'm just here so we can write integration tests
-
diff --git a/source/postcard-rpc-test/tests/basic.rs b/source/postcard-rpc-test/tests/basic.rs
index 61ba894..97283da 100644
--- a/source/postcard-rpc-test/tests/basic.rs
+++ b/source/postcard-rpc-test/tests/basic.rs
@@ -1,11 +1,12 @@
 use std::{collections::HashMap, time::Duration};
 
 use postcard::experimental::schema::Schema;
+use postcard_rpc::test_utils::local_setup;
 use postcard_rpc::{
     endpoint, headered::to_stdvec_keyed, topic, Dispatch, Endpoint, Key, Topic, WireHeader,
 };
-use postcard_rpc::test_utils::local_setup;
 use serde::{Deserialize, Serialize};
+use tokio::task::yield_now;
 use tokio::time::timeout;
 
 endpoint!(EndpointOne, Req1, Resp1, "endpoint/one");
@@ -78,7 +79,7 @@ async fn smoke_reqresp() {
     });
 
     // As the wire, get the outgoing request
-    let out1 = srv.from_client.recv().await.unwrap();
+    let out1 = srv.recv_from_client().await.unwrap();
 
     // Does the outgoing value match what we expect?
     let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap();
@@ -127,7 +128,7 @@ async fn smoke_publish() {
         .unwrap();
 
     // As the wire, get the outgoing request
-    let out1 = srv.from_client.recv().await.unwrap();
+    let out1 = srv.recv_from_client().await.unwrap();
 
     // Does the outgoing value match what we expect?
     let exp_out = to_stdvec_keyed(123, TopicOne::TOPIC_KEY, &Req1 { a: 10, b: 100 }).unwrap();
@@ -158,3 +159,147 @@ async fn smoke_subscribe() {
 
     assert_eq!(publ, VAL);
 }
+
+#[tokio::test]
+async fn smoke_io_error() {
+    let (mut srv, client) = local_setup::(8, "error");
+
+    // Do one round trip to make sure the connection works
+    {
+        let rr_rt = tokio::task::spawn({
+            let client = client.clone();
+            async move {
+                client
+                    .send_resp::<EndpointOne>(&Req1 { a: 10, b: 100 })
+                    .await
+            }
+        });
+
+        // As the wire, get the outgoing request
+        let out1 = srv.recv_from_client().await.unwrap();
+
+        // Does the outgoing value match what we expect?
+        let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap();
+        let act_out = out1.to_bytes();
+        assert_eq!(act_out, exp_out);
+
+        // The request is still awaiting a response
+        assert!(!rr_rt.is_finished());
+
+        // Feed a simulated response "from the wire" back to the
+        // awaiting request
+        const RESP_001: Resp1 = Resp1 {
+            c: [1, 2, 3, 4, 5, 6, 7, 8],
+            d: -10,
+        };
+        srv.reply::<EndpointOne>(out1.header.seq_no, &RESP_001)
+            .await
+            .unwrap();
+
+        // Now wait for the request to complete
+        let end = rr_rt.await.unwrap().unwrap();
+
+        // We got the simulated value back
+        assert_eq!(end, RESP_001);
+    }
+
+    // Now, simulate an I/O error
+    srv.cause_fatal_error();
+
+    // Give the clients some time to halt
+    yield_now().await;
+
+    // Our server channels should now be closed - the tasks hung up
+    assert!(srv.from_client.recv().await.is_none());
+    assert!(srv.to_client.send(Vec::new()).await.is_err());
+
+    // Try again, but nothing should work because all the worker tasks just halted
+    {
+        let rr_rt = tokio::task::spawn({
+            let client = client.clone();
+            async move {
+                client
+                    .send_resp::<EndpointOne>(&Req1 { a: 10, b: 100 })
+                    .await
+            }
+        });
+
+        // As the wire, get the outgoing request - didn't happen
+        assert!(srv.recv_from_client().await.is_err());
+
+        // Now wait for the request to complete - it failed
+        rr_rt.await.unwrap().unwrap_err();
+    }
+}
+
+#[tokio::test]
+async fn smoke_closed() {
+    let (mut srv, client) = local_setup::(8, "error");
+
+    // Do one round trip to make sure the connection works
+    {
+        let rr_rt = tokio::task::spawn({
+            let client = client.clone();
+            async move {
+                client
+                    .send_resp::<EndpointOne>(&Req1 { a: 10, b: 100 })
+                    .await
+            }
+        });
+
+        // As the wire, get the outgoing request
+        let out1 = srv.recv_from_client().await.unwrap();
+
+        // Does the outgoing value match what we expect?
+        let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap();
+        let act_out = out1.to_bytes();
+        assert_eq!(act_out, exp_out);
+
+        // The request is still awaiting a response
+        assert!(!rr_rt.is_finished());
+
+        // Feed a simulated response "from the wire" back to the
+        // awaiting request
+        const RESP_001: Resp1 = Resp1 {
+            c: [1, 2, 3, 4, 5, 6, 7, 8],
+            d: -10,
+        };
+        srv.reply::<EndpointOne>(out1.header.seq_no, &RESP_001)
+            .await
+            .unwrap();
+
+        // Now wait for the request to complete
+        let end = rr_rt.await.unwrap().unwrap();
+
+        // We got the simulated value back
+        assert_eq!(end, RESP_001);
+    }
+
+    // Now, use the *client* to close the connection
+    client.close();
+
+    // Give the clients some time to halt
+    yield_now().await;
+
+    // Our server channels should now be closed - the tasks hung up
+    assert!(srv.from_client.recv().await.is_none());
+    assert!(srv.to_client.send(Vec::new()).await.is_err());
+
+    // Try again, but nothing should work because all the worker tasks just halted
+    {
+        let rr_rt = tokio::task::spawn({
+            let client = client.clone();
+            async move {
+                client
+                    .send_resp::<EndpointOne>(&Req1 { a: 10, b: 100 })
+                    .await
+            }
+        });
+
+        // As the wire, get the outgoing request - didn't happen
+        assert!(srv.recv_from_client().await.is_err());
+
+        // Now wait for the request to complete - it failed
+        rr_rt.await.unwrap().unwrap_err();
+    }
+}
diff --git a/source/postcard-rpc/Cargo.toml b/source/postcard-rpc/Cargo.toml
index e289bfa..72bf622 100644
--- a/source/postcard-rpc/Cargo.toml
+++ b/source/postcard-rpc/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "postcard-rpc"
-version = "0.5.1"
+version = "0.5.3"
 authors = ["James Munns "]
 edition = "2021"
 repository = "https://github.com/jamesmunns/postcard-rpc"
@@ -43,7 +43,7 @@ version = "5.4.4"
 optional = true
 
 [dependencies.maitake-sync]
-version = "0.1.0"
+version = "0.1.2"
 optional = true
 
 [dependencies.tokio]
diff --git a/source/postcard-rpc/src/host_client/mod.rs b/source/postcard-rpc/src/host_client/mod.rs
index 1bee600..6b51144 100644
--- a/source/postcard-rpc/src/host_client/mod.rs
+++ b/source/postcard-rpc/src/host_client/mod.rs
@@ -25,6 +25,8 @@ use tokio::{
 
 use crate::{Endpoint, Key, Topic, WireHeader};
 
+use self::util::Stopper;
+
 #[cfg(all(feature = "raw-nusb", not(target_family = "wasm")))]
 mod raw_nusb;
 
@@ -34,7 +36,7 @@ mod serial;
 #[cfg(all(feature = "webusb", target_family = "wasm"))]
 pub mod webusb;
 
-mod util;
+pub(crate) mod util;
 
 /// Host Error Kind
 #[derive(Debug, PartialEq)]
@@ -153,6 +155,7 @@ pub struct HostClient<WireErr> {
     out: Sender<RpcFrame>,
     subber: Sender<SubInfo>,
     err_key: Key,
+    stopper: Stopper,
     _pd: PhantomData<fn() -> WireErr>,
 }
 
@@ -190,6 +193,7 @@ where
             err_key,
             _pd: PhantomData,
             subber: tx_si.clone(),
+            stopper: Stopper::new(),
         };
 
         let wire = WireContext {
@@ -215,6 +219,23 @@
     pub async fn send_resp<E: Endpoint>(
         &self,
         t: &E::Request,
     ) -> Result<E::Response, HostErr<WireErr>>
+    where
+        E::Request: Serialize + Schema,
+        E::Response: DeserializeOwned + Schema,
+    {
+        let cancel_fut = self.stopper.wait_stopped();
+        let operate_fut = self.send_resp_inner::<E>(t);
+        select! {
+            _ = cancel_fut => Err(HostErr::Closed),
+            res = operate_fut => res,
+        }
+    }
+
+    /// Inner function version of [Self::send_resp]
+    async fn send_resp_inner<E: Endpoint>(
+        &self,
+        t: &E::Request,
+    ) -> Result<E::Response, HostErr<WireErr>>
     where
         E::Request: Serialize + Schema,
         E::Response: DeserializeOwned + Schema,
     {
@@ -257,6 +278,19 @@
     /// There is no feedback if the server received our message. If the I/O worker is
     /// closed, an error is returned.
     pub async fn publish<T: Topic>(&self, seq_no: u32, msg: &T::Message) -> Result<(), IoClosed>
+    where
+        T::Message: Serialize,
+    {
+        let cancel_fut = self.stopper.wait_stopped();
+        let operate_fut = self.publish_inner::<T>(seq_no, msg);
+        select! {
+            _ = cancel_fut => Err(IoClosed),
+            res = operate_fut => res,
+        }
+    }
+
+    /// Inner function version of [Self::publish]
+    async fn publish_inner<T: Topic>(&self, seq_no: u32, msg: &T::Message) -> Result<(), IoClosed>
     where
         T::Message: Serialize,
     {
@@ -284,6 +318,22 @@
         &self,
         depth: usize,
     ) -> Result<Subscription<T::Message>, IoClosed>
+    where
+        T::Message: DeserializeOwned,
+    {
+        let cancel_fut = self.stopper.wait_stopped();
+        let operate_fut = self.subscribe_inner::<T>(depth);
+        select! {
+            _ = cancel_fut => Err(IoClosed),
+            res = operate_fut => res,
+        }
+    }
+
+    /// Inner function version of [Self::subscribe]
+    async fn subscribe_inner<T: Topic>(
+        &self,
+        depth: usize,
+    ) -> Result<Subscription<T::Message>, IoClosed>
     where
         T::Message: DeserializeOwned,
     {
@@ -300,6 +350,27 @@
             _pd: PhantomData,
         })
     }
+
+    /// Permanently close the connection to the client
+    ///
+    /// All other HostClients sharing the connection (e.g. created by cloning
+    /// a single HostClient) will also stop, and no further communication will
+    /// succeed. The in-flight messages will not be flushed.
+    ///
+    /// This will also signal any I/O worker tasks to halt immediately as well.
+    pub fn close(&self) {
+        self.stopper.stop()
+    }
+
+    /// Has this host client been closed?
+    pub fn is_closed(&self) -> bool {
+        self.stopper.is_stopped()
+    }
+
+    /// Wait for the host client to be closed
+    pub async fn wait_closed(&self) {
+        self.stopper.wait_stopped().await;
+    }
 }
 
 /// A structure that represents a subscription to the given topic
@@ -334,6 +405,7 @@ impl<WireErr> Clone for HostClient<WireErr> {
             err_key: self.err_key,
             _pd: PhantomData,
             subber: self.subber.clone(),
+            stopper: self.stopper.clone(),
         }
     }
 }
diff --git a/source/postcard-rpc/src/host_client/util.rs b/source/postcard-rpc/src/host_client/util.rs
index 3d49c38..dd422ab 100644
--- a/source/postcard-rpc/src/host_client/util.rs
+++ b/source/postcard-rpc/src/host_client/util.rs
@@ -1,11 +1,15 @@
 // the contents of this file can probably be moved up to `mod.rs`
 
 use std::{collections::HashMap, fmt::Debug, sync::Arc};
 
+use maitake_sync::WaitQueue;
 use postcard::experimental::schema::Schema;
 use serde::de::DeserializeOwned;
-use tokio::sync::{
-    mpsc::{error::TrySendError, Receiver, Sender},
-    Mutex,
+use tokio::{
+    select,
+    sync::{
+        mpsc::{error::TrySendError, Receiver, Sender},
+        Mutex,
+    },
 };
 use tracing::{debug, trace, warn};
@@ -20,6 +24,44 @@ use crate::{
 
 pub(crate) type Subscriptions = HashMap<Key, Sender<RpcFrame>>;
 
+/// A basic cancellation-token
+///
+/// Used to terminate (and signal termination of) worker tasks
+#[derive(Clone)]
+pub struct Stopper {
+    inner: Arc<WaitQueue>,
+}
+
+impl Stopper {
+    /// Create a new Stopper
+    pub fn new() -> Self {
+        Self {
+            inner: Arc::new(WaitQueue::new()),
+        }
+    }
+
+    /// Wait until the stopper has been stopped.
+    ///
+    /// Once this completes, the stopper has been permanently stopped
+    pub async fn wait_stopped(&self) {
+        // This completes if we are awoken OR if the queue is closed: either
+        // means we're cancelled
+        let _ = self.inner.wait().await;
+    }
+
+    /// Have we been stopped?
+    pub fn is_stopped(&self) -> bool {
+        self.inner.is_closed()
+    }
+
+    /// Stop the stopper
+    ///
+    /// All current and future calls to [Self::wait_stopped] will complete immediately
+    pub fn stop(&self) {
+        self.inner.close();
+    }
+}
+
 impl<WireErr> HostClient<WireErr>
 where
     WireErr: DeserializeOwned + Schema,
@@ -50,16 +92,37 @@ where
         let subscriptions: Arc<Mutex<Subscriptions>> =
             Arc::new(Mutex::new(Subscriptions::new()));
 
-        sp.spawn(out_worker(tx, outgoing));
-        sp.spawn(in_worker(rx, incoming, subscriptions.clone()));
-        sp.spawn(sub_worker(new_subs, subscriptions));
+        sp.spawn(out_worker(tx, outgoing, me.stopper.clone()));
+        sp.spawn(in_worker(
+            rx,
+            incoming,
+            subscriptions.clone(),
+            me.stopper.clone(),
+        ));
+        sp.spawn(sub_worker(new_subs, subscriptions, me.stopper.clone()));
 
         me
     }
 }
 
 /// Output worker, feeding frames to the `Client`.
-async fn out_worker<W>(mut wire: W, mut rec: Receiver<RpcFrame>)
+async fn out_worker<W>(wire: W, rec: Receiver<RpcFrame>, stop: Stopper)
+where
+    W: WireTx,
+    W::Error: Debug,
+{
+    let cancel_fut = stop.wait_stopped();
+    let operate_fut = out_worker_inner(wire, rec);
+    select! {
+        _ = cancel_fut => {},
+        _ = operate_fut => {
+            // if WE exited, notify everyone else it's stoppin time
+            stop.stop();
+        },
+    }
+}
+
+async fn out_worker_inner<W>(mut wire: W, mut rec: Receiver<RpcFrame>)
 where
     W: WireTx,
     W::Error: Debug,
@@ -78,6 +141,26 @@
 
 /// Input worker, getting frames from the `Client`
 async fn in_worker<W>(
+    wire: W,
+    host_ctx: Arc<HostContext>,
+    subscriptions: Arc<Mutex<Subscriptions>>,
+    stop: Stopper,
+) where
+    W: WireRx,
+    W::Error: Debug,
+{
+    let cancel_fut = stop.wait_stopped();
+    let operate_fut = in_worker_inner(wire, host_ctx, subscriptions);
+    select! {
+        _ = cancel_fut => {},
+        _ = operate_fut => {
+            // if WE exited, notify everyone else it's stoppin time
+            stop.stop();
+        },
+    }
+}
+
+async fn in_worker_inner<W>(
     mut wire: W,
     host_ctx: Arc<HostContext>,
     subscriptions: Arc<Mutex<Subscriptions>>,
@@ -154,7 +237,26 @@
     }
 }
 
-async fn sub_worker(mut new_subs: Receiver<SubInfo>, subscriptions: Arc<Mutex<Subscriptions>>) {
+async fn sub_worker(
+    new_subs: Receiver<SubInfo>,
+    subscriptions: Arc<Mutex<Subscriptions>>,
+    stop: Stopper,
+) {
+    let cancel_fut = stop.wait_stopped();
+    let operate_fut = sub_worker_inner(new_subs, subscriptions);
+    select! {
+        _ = cancel_fut => {},
+        _ = operate_fut => {
+            // if WE exited, notify everyone else it's stoppin time
+            stop.stop();
+        },
+    }
+}
+
+async fn sub_worker_inner(
+    mut new_subs: Receiver<SubInfo>,
+    subscriptions: Arc<Mutex<Subscriptions>>,
+) {
     while let Some(sub) = new_subs.recv().await {
         let mut sub_guard = subscriptions.lock().await;
         if let Some(_old) = sub_guard.insert(sub.key, sub.tx) {
diff --git a/source/postcard-rpc/src/test_utils.rs b/source/postcard-rpc/src/test_utils.rs
index 5af7a60..80bbe76 100644
--- a/source/postcard-rpc/src/test_utils.rs
+++ b/source/postcard-rpc/src/test_utils.rs
@@ -1,10 +1,12 @@
 //! Test utilities for doctests and integration tests
 
-use std::collections::HashMap;
+use core::{fmt::Display, future::Future};
 
+use crate::host_client::util::Stopper;
 use crate::{
-    host_client::{HostClient, ProcessError, RpcFrame, WireContext},
-    Endpoint, Key, Topic, WireHeader,
+    headered::extract_header_from_bytes,
+    host_client::{HostClient, RpcFrame, WireRx, WireSpawn, WireTx},
+    Endpoint, Topic, WireHeader,
 };
 use postcard::experimental::schema::Schema;
 use serde::{de::DeserializeOwned, Serialize};
@@ -13,121 +15,187 @@ use tokio::{
     sync::mpsc::{channel, Receiver, Sender},
 };
 
-/// This function creates a directly-linked Server and Client.
-///
-/// This is useful for testing and demonstrating server/client behavior,
-/// without actually requiring an external device.
-pub fn local_setup<E>(bound: usize, err_uri_path: &str) -> (LocalServer, HostClient<E>)
-where
-    E: Schema + DeserializeOwned,
-{
-    let (srv_tx, srv_rx) = channel(bound);
-    let (cli_tx, cli_rx) = channel(bound);
-    let srv = LocalServer {
-        from_client: cli_rx,
-        to_client: srv_tx,
-    };
-    let cli = LocalClient {
-        to_server: cli_tx,
-        from_server: srv_rx,
-    };
-    let cli = make_client::<E>(cli, bound, err_uri_path);
-    (srv, cli)
+pub struct LocalRx {
+    fake_error: Stopper,
+    from_server: Receiver<Vec<u8>>,
 }
-
-pub struct LocalServer {
-    pub from_client: Receiver<RpcFrame>,
-    pub to_client: Sender<RpcFrame>,
+pub struct LocalTx {
+    fake_error: Stopper,
+    to_server: Sender<Vec<u8>>,
+}
+pub struct LocalSpawn;
+pub struct LocalFakeServer {
+    fake_error: Stopper,
+    pub from_client: Receiver<Vec<u8>>,
+    pub to_client: Sender<Vec<u8>>,
 }
 
-impl LocalServer {
-    pub async fn reply<E: Endpoint>(&mut self, seq_no: u32, msg: &E::Response) -> Result<(), ()>
+impl LocalFakeServer {
+    pub async fn recv_from_client(&mut self) -> Result<RpcFrame, LocalError> {
+        let msg = self.from_client.recv().await.ok_or(LocalError::TxClosed)?;
+        let Ok((hdr, body)) = extract_header_from_bytes(&msg) else {
+            return Err(LocalError::BadFrame);
+        };
+        Ok(RpcFrame {
+            header: hdr,
+            body: body.to_vec(),
+        })
+    }
+
+    pub async fn reply<E: Endpoint>(
+        &mut self,
+        seq_no: u32,
+        data: &E::Response,
+    ) -> Result<(), LocalError>
     where
         E::Response: Serialize,
     {
+        let frame = RpcFrame {
+            header: WireHeader {
+                key: E::RESP_KEY,
+                seq_no,
+            },
+            body: postcard::to_stdvec(data).unwrap(),
+        };
         self.to_client
-            .send(RpcFrame {
-                header: WireHeader {
-                    key: E::RESP_KEY,
-                    seq_no,
-                },
-                body: postcard::to_stdvec(msg).unwrap(),
-            })
+            .send(frame.to_bytes())
             .await
-            .map_err(drop)
+            .map_err(|_| LocalError::RxClosed)
     }
 
-    pub async fn publish<T: Topic>(&mut self, seq_no: u32, msg: &T::Message) -> Result<(), ()>
+    pub async fn publish<T: Topic>(
+        &mut self,
+        seq_no: u32,
+        data: &T::Message,
+    ) -> Result<(), LocalError>
     where
         T::Message: Serialize,
     {
+        let frame = RpcFrame {
+            header: WireHeader {
+                key: T::TOPIC_KEY,
+                seq_no,
+            },
+            body: postcard::to_stdvec(data).unwrap(),
+        };
         self.to_client
-            .send(RpcFrame {
-                header: WireHeader {
-                    key: T::TOPIC_KEY,
-                    seq_no,
-                },
-                body: postcard::to_stdvec(msg).unwrap(),
-            })
+            .send(frame.to_bytes())
            .await
-            .map_err(drop)
+            .map_err(|_| LocalError::RxClosed)
+    }
+
+    pub fn cause_fatal_error(&self) {
+        self.fake_error.stop();
     }
 }
 
-pub struct LocalClient {
-    pub to_server: Sender<RpcFrame>,
-    pub from_server: Receiver<RpcFrame>,
+#[derive(Debug, PartialEq)]
+pub enum LocalError {
+    RxClosed,
+    TxClosed,
+    BadFrame,
+    FatalError,
 }
 
-pub fn make_client<E>(cli: LocalClient, depth: usize, err_uri_path: &str) -> HostClient<E>
-where
-    E: Schema + DeserializeOwned,
-{
-    let (hcli, hcli_ctx) = HostClient::<E>::new_manual_priv(err_uri_path, depth);
-    tokio::task::spawn(wire_worker(cli, hcli_ctx));
-    hcli
+impl Display for LocalError {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        <Self as core::fmt::Debug>::fmt(self, f)
+    }
 }
 
-async fn wire_worker(mut cli: LocalClient, mut ctx: WireContext) {
-    let mut subs: HashMap<Key, Sender<RpcFrame>> = HashMap::new();
-    loop {
-        // Wait for EITHER a serialized request, OR some data from the embedded device
-        select! {
-            sub = ctx.new_subs.recv() => {
-                let Some(si) = sub else {
-                    return;
-                };
-
-                subs.insert(si.key, si.tx);
+impl std::error::Error for LocalError {}
+
+impl WireRx for LocalRx {
+    type Error = LocalError;
+
+    #[allow(clippy::manual_async_fn)]
+    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
+        async {
+            // This is not usually necessary - HostClient machinery takes care of listening
+            // to the stopper, but we have an EXTRA one to simulate I/O failure
+            let recv_fut = self.from_server.recv();
+            let error_fut = self.fake_error.wait_stopped();
+
+            // Before we await, do a quick check to see if an error occured, this way
+            // recv can't accidentally win the select
+            if self.fake_error.is_stopped() {
+                return Err(LocalError::FatalError);
+            }
+
+            select! {
+                recv = recv_fut => recv.ok_or(LocalError::RxClosed),
+                _err = error_fut => Err(LocalError::FatalError),
             }
-            out = ctx.outgoing.recv() => {
-                let Some(msg) = out else {
-                    return;
-                };
-                if cli.to_server.send(msg).await.is_err() {
-                    return;
-                }
+        }
+    }
+}
+
+impl WireTx for LocalTx {
+    type Error = LocalError;
+
+    #[allow(clippy::manual_async_fn)]
+    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
+        async {
+            // This is not usually necessary - HostClient machinery takes care of listening
+            // to the stopper, but we have an EXTRA one to simulate I/O failure
+            let send_fut = self.to_server.send(data);
+            let error_fut = self.fake_error.wait_stopped();
+
+            // Before we await, do a quick check to see if an error occured, this way
+            // send can't accidentally win the select
+            if self.fake_error.is_stopped() {
+                return Err(LocalError::FatalError);
            }
-            inc = cli.from_server.recv() => {
-                let Some(msg) = inc else {
-                    return;
-                };
-                // Give priority to subscriptions. TBH I only do this because I know a hashmap
-                // lookup is cheaper than a waitmap search.
-                let key = msg.header.key;
-                if let Some(tx) = subs.get_mut(&key) {
-                    // Yup, we have a subscription
-                    if tx.send(msg).await.is_err() {
-                        // But if sending failed, the listener is gone, so drop it
-                        subs.remove(&key);
-                    }
-                } else {
-                    // Wake the given sequence number. If the WaitMap is closed, we're done here
-                    if let Err(ProcessError::Closed) = ctx.incoming.process(msg) {
-                        return;
-                    }
-                }
+
+            select! {
+                send = send_fut => send.map_err(|_| LocalError::TxClosed),
+                _err = error_fut => Err(LocalError::FatalError),
             }
         }
     }
 }
+
+impl WireSpawn for LocalSpawn {
+    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
+        tokio::task::spawn(fut);
+    }
+}
+
+/// This function creates a directly-linked Server and Client.
+///
+/// This is useful for testing and demonstrating server/client behavior,
+/// without actually requiring an external device.
+pub fn local_setup<E>(bound: usize, err_uri_path: &str) -> (LocalFakeServer, HostClient<E>)
+where
+    E: Schema + DeserializeOwned,
+{
+    let (c2s_tx, c2s_rx) = channel(bound);
+    let (s2c_tx, s2c_rx) = channel(bound);
+
+    // NOTE: the normal HostClient machinery has it's own Stopper used for signalling
+    // errors, this is an EXTRA stopper we use to simulate the error occurring, like
+    // if our USB device disconnected or the serial port was closed
+    let fake_error = Stopper::new();
+
+    let client = HostClient::<E>::new_with_wire(
+        LocalTx {
+            to_server: c2s_tx,
+            fake_error: fake_error.clone(),
+        },
+        LocalRx {
+            from_server: s2c_rx,
+            fake_error: fake_error.clone(),
+        },
+        LocalSpawn,
+        err_uri_path,
+        bound,
+    );
+
+    let lfs = LocalFakeServer {
+        from_client: c2s_rx,
+        to_client: s2c_tx,
+        fake_error: fake_error.clone(),
+    };
+
+    (lfs, client)
+}
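For orientation, here is a minimal host-side sketch of how the `close()` / `wait_closed()` / `is_closed()` API introduced above is meant to be driven, built only on items added in this patch (the `test_utils::local_setup` fake wire stands in for a real transport). This is a sketch, not part of the patch: it assumes the std/host features of postcard-rpc 0.5.3 and a tokio runtime, and `WireError` is a hypothetical stand-in for a real ICD error type.

```rust
use postcard::experimental::schema::Schema;
use postcard_rpc::test_utils::local_setup;
use serde::{Deserialize, Serialize};

// Hypothetical wire-error type for the sketch; any `Schema + DeserializeOwned`
// type can fill the error parameter of `local_setup`.
#[derive(Serialize, Deserialize, Schema, Debug, PartialEq)]
pub struct WireError;

#[tokio::main]
async fn main() {
    // Fake wire from `test_utils` stands in for a real USB/serial transport
    let (mut srv, client) = local_setup::<WireError>(8, "error");

    // Clones share the same Stopper, so closing any handle closes them all
    let watcher = client.clone();
    let watch_task = tokio::spawn(async move {
        watcher.wait_closed().await;
        println!("client closed");
    });

    assert!(!client.is_closed());
    client.close();
    assert!(client.is_closed());
    watch_task.await.unwrap();

    // Once the I/O workers halt, the fake server sees the wire hang up
    assert!(srv.recv_from_client().await.is_err());
}
```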