Skip to content

Commit

Permalink
Make test_channel server stoppable
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesmunns committed Nov 1, 2024
1 parent 291c591 commit 066db48
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 31 deletions.
2 changes: 1 addition & 1 deletion example/firmware/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion example/workbook-host/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion source/postcard-rpc-test/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 46 additions & 2 deletions source/postcard-rpc-test/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use std::{sync::Arc, time::Instant};

use postcard_schema::{schema::owned::OwnedNamedType, Schema};
use serde::{Deserialize, Serialize};
use tokio::{sync::mpsc, task::yield_now};
use tokio::{sync::mpsc, task::yield_now, time::timeout};

use postcard_rpc::{
define_dispatch, endpoints,
header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind},
host_client::{test_channels as client, HostClient},
server::{
impls::test_channels::{
dispatch_impl::{new_server, spawn_fn, Settings, WireSpawnImpl, WireTxImpl},
dispatch_impl::{new_server, new_server_stoppable, spawn_fn, Settings, WireSpawnImpl, WireTxImpl},
ChannelWireRx, ChannelWireSpawn, ChannelWireTx,
},
Dispatch, Sender, SpawnContext,
Expand Down Expand Up @@ -561,3 +561,47 @@ fn device_map() {
println!("{:?}", app.device_map.min_key_len);
println!();
}

#[tokio::test]
async fn end_to_end_stoppable() {
let (client_tx, server_rx) = mpsc::channel(16);
let (server_tx, client_rx) = mpsc::channel(16);
let topic_ctr = Arc::new(AtomicUsize::new(0));

let app = SingleDispatcher::new(
TestContext {
ctr: Arc::new(AtomicUsize::new(0)),
topic_ctr: topic_ctr.clone(),
msg: String::from("hello"),
},
ChannelWireSpawn {},
);

let cwrx = ChannelWireRx::new(server_rx);
let cwtx = ChannelWireTx::new(server_tx);

let kkind = app.min_key_len();
let (mut server, stopper) = new_server_stoppable(
app,
Settings {
tx: cwtx,
rx: cwrx,
buf: 1024,
kkind,
},
);
let hdl = tokio::task::spawn(async move {
server.run().await;
});

let cli = client::new_from_channels(client_tx, client_rx, VarSeqKind::Seq1);

let resp = cli.send_resp::<AlphaEndpoint>(&AReq(42)).await.unwrap();
assert_eq!(resp.0, 42);
stopper.stop();
match timeout(Duration::from_millis(100), hdl).await {
Ok(Ok(())) => {},
Ok(Err(e)) => panic!("Server task panicked? {e:?}"),
Err(_) => panic!("Server task did not stop!"),
}
}
2 changes: 1 addition & 1 deletion source/postcard-rpc/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "postcard-rpc"
version = "0.10.1"
version = "0.10.2"
authors = ["James Munns <[email protected]>"]
edition = "2021"
repository = "https://github.com/jamesmunns/postcard-rpc"
Expand Down
116 changes: 91 additions & 25 deletions source/postcard-rpc/src/server/impls/test_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
use core::{
convert::Infallible,
future::Future,
future::{pending, Future},
sync::atomic::{AtomicU32, Ordering},
};
use std::sync::Arc;

use crate::{
header::{VarHeader, VarKey, VarKeyKind, VarSeq},
host_client::util::Stopper,
server::{
AsWireRxErrorKind, AsWireTxErrorKind, WireRx, WireRxErrorKind, WireSpawn, WireTx,
WireTxErrorKind,
Expand All @@ -17,14 +18,15 @@ use crate::{
Topic,
};
use core::fmt::Arguments;
use tokio::sync::mpsc;
use tokio::{select, sync::mpsc};

//////////////////////////////////////////////////////////////////////////////
// DISPATCH IMPL
//////////////////////////////////////////////////////////////////////////////

/// A collection of types and aliases useful for importing the correct types
pub mod dispatch_impl {
pub use crate::host_client::util::Stopper;
use crate::{
header::VarKeyKind,
server::{Dispatch, Server},
Expand Down Expand Up @@ -70,6 +72,33 @@ pub mod dispatch_impl {
settings.kkind,
)
}

/// Create a new server using the [`Settings`] and [`Dispatch`] implementation
///
/// Also returns a [`Stopper`] that can be used to halt the server's operation
pub fn new_server_stoppable<D>(
dispatch: D,
mut settings: Settings,
) -> (
crate::server::Server<WireTxImpl, WireRxImpl, WireRxBuf, D>,
Stopper,
)
where
D: Dispatch<Tx = WireTxImpl>,
{
let stopper = Stopper::new();
settings.tx.set_stopper(stopper.clone());
settings.rx.set_stopper(stopper.clone());
let buf = vec![0; settings.buf];
let me = Server::new(
&settings.tx,
settings.rx,
buf.into_boxed_slice(),
dispatch,
settings.kkind,
);
(me, stopper)
}
}

//////////////////////////////////////////////////////////////////////////////
Expand All @@ -81,6 +110,7 @@ pub mod dispatch_impl {
pub struct ChannelWireTx {
tx: mpsc::Sender<Vec<u8>>,
log_ctr: Arc<AtomicU32>,
stopper: Option<Stopper>,
}

impl ChannelWireTx {
Expand All @@ -89,6 +119,33 @@ impl ChannelWireTx {
Self {
tx,
log_ctr: Arc::new(AtomicU32::new(0)),
stopper: None,
}
}

/// Add a stopper to listen for "close" methods
pub fn set_stopper(&mut self, stopper: Stopper) {
self.stopper = Some(stopper);
}

async fn inner_send(&self, msg: Vec<u8>) -> Result<(), ChannelWireTxError> {
let stop_fut = async {
if let Some(s) = self.stopper.as_ref() {
s.wait_stopped().await;
} else {
pending::<()>().await;
}
};
select! {
_ = stop_fut => {
Err(ChannelWireTxError::ChannelClosed)
}
res = self.tx.send(msg) => {
match res {
Ok(()) => Ok(()),
Err(_) => Err(ChannelWireTxError::ChannelClosed)
}
}
}
}
}
Expand All @@ -104,20 +161,12 @@ impl WireTx for ChannelWireTx {
let mut hdr_ser = hdr.write_to_vec();
let bdy_ser = postcard::to_stdvec(msg).unwrap();
hdr_ser.extend_from_slice(&bdy_ser);
self.tx
.send(hdr_ser)
.await
.map_err(|_| ChannelWireTxError::ChannelClosed)?;
Ok(())
self.inner_send(hdr_ser).await
}

async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error> {
let buf = buf.to_vec();
self.tx
.send(buf)
.await
.map_err(|_| ChannelWireTxError::ChannelClosed)?;
Ok(())
self.inner_send(buf).await
}

async fn send_log_str(&self, kkind: VarKeyKind, s: &str) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -158,11 +207,7 @@ impl WireTx for ChannelWireTx {
let msg = format!("{a}");
let msg = postcard::to_stdvec(&msg).unwrap();
buf.extend_from_slice(&msg);
self.tx
.send(buf)
.await
.map_err(|_| ChannelWireTxError::ChannelClosed)?;
Ok(())
self.inner_send(buf).await
}
}

Expand All @@ -188,12 +233,18 @@ impl AsWireTxErrorKind for ChannelWireTxError {
/// A [`WireRx`] impl using tokio mpsc channels
pub struct ChannelWireRx {
rx: mpsc::Receiver<Vec<u8>>,
stopper: Option<Stopper>,
}

impl ChannelWireRx {
/// Create a new [`ChannelWireRx`]
pub fn new(rx: mpsc::Receiver<Vec<u8>>) -> Self {
Self { rx }
Self { rx, stopper: None }
}

/// Add a stopper to listen for "close" methods
pub fn set_stopper(&mut self, stopper: Stopper) {
self.stopper = Some(stopper);
}
}

Expand All @@ -202,13 +253,28 @@ impl WireRx for ChannelWireRx {

async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error> {
// todo: some kind of receive_owned?
let msg = self.rx.recv().await;
let msg = msg.ok_or(ChannelWireRxError::ChannelClosed)?;
let out = buf
.get_mut(..msg.len())
.ok_or(ChannelWireRxError::MessageTooLarge)?;
out.copy_from_slice(&msg);
Ok(out)
let ChannelWireRx { rx, stopper } = self;
let stop_fut = async {
if let Some(s) = stopper.as_ref() {
s.wait_stopped().await;
} else {
pending::<()>().await;
}
};

select! {
_ = stop_fut => {
Err(ChannelWireRxError::ChannelClosed)
}
msg = rx.recv() => {
let msg = msg.ok_or(ChannelWireRxError::ChannelClosed)?;
let out = buf
.get_mut(..msg.len())
.ok_or(ChannelWireRxError::MessageTooLarge)?;
out.copy_from_slice(&msg);
Ok(out)
}
}
}
}

Expand Down

0 comments on commit 066db48

Please sign in to comment.