From 5d0ab79bbfb6312df3bba68fd5ae5a78432169b1 Mon Sep 17 00:00:00 2001 From: James Munns Date: Sun, 27 Oct 2024 18:00:53 +0100 Subject: [PATCH] Rework basically the entire protocol (#53) This is a major change to postcard-rpc. This is a **very breaking wire change!** In particular, this PR changes the following: ## Server Rework Signifcantly reworks how "servers" are implemented, removing the previous embassy-usb-specific one in favor of more reusable parts, that allow for implementing a server generically over different transports. This new version still has an implementation for embassy-usb 0.3, but ALSO provides a channel-based implementation for testing, and I am likely to port a TCP-based one I have in `poststation` as well. This unlocks the ability to reuse the bulk of the existing code for supporting other transports, like UART, SPI, I2C, Ethernet, or even over radio. ## `define_dispatch` macro rework As part of the Server Rework, I also mostly rewrote the `define_dispatch!` macro, which now can be used with ANY transport, not just embassy-usb. This change also now allows servers to define topic handlers, so incoming published messages can be dispatched similar to endpoint dispatching. CC #15 ## Automatic Key Shrinking Previously, we would always send the full 8-byte "hash of the path and schema" ID in every message, as well as checking at compile time whether there is a collision or not. This PR takes that up a notch, now calculating the *smallest* hash key we can use (1, 2, 4, or 8 bytes) automatically at compile time, and uses that as our "native" mapping. This is similar to the concept of "Perfect Hash Functions". ## Variable Sequence Number size Additionally, the client can now send sequence numbers of 1, 2, or 4 bytes. Previously, sequence numbers were a `varint(u32)`. Servers will always respond back with the same sequence number they received when replying to requests. ## Completely redo message headers As we now have variable sized keys and sequence numbers, headers can now scale dynamically to the necessary size. CC #51 Before, we had 8 byte keys (fixed) and 1-5 bytes (`varint(u32)`), meaning headers were between **9-13 bytes**. Now, we use one byte as a discriminant, containing the key and seqno len, as well as a 4-bit version field, as well as the variable key and variable sequence number. This now means that headers are **3-13 bytes** (1B discriminant + 1/2/4/8B key + 1/2/4B sequence number), and will be **3 bytes in many common cases** where it is not necessary to disambiguate more than 256 in-flight messages (via sequence numbers) or 256 endpoints (though the liklihood of having a collision at 8 bits is higher than that due to the birthday problem). When a Client first connects to a Server, it will always start by sending an 8B key. If the Server replies with a shorter key, the Client will then switch to using keys of that size. It is not necessary to ever hardcode what size keys are necessary, as this is calculated when the Server is compiled, and is automatically detected by the Client. In general: * The CLIENT (usually the PC) is in charge of picking the size of the sequence number * The SERVER (usually the MCU) is in charge of picking the size of the message keys --- example/firmware/Cargo.lock | 30 +- example/firmware/src/bin/comms-01.rs | 124 ++- example/firmware/src/bin/comms-02.rs | 173 ++-- example/workbook-host/Cargo.lock | 24 + example/workbook-host/src/bin/comms-01.rs | 2 +- example/workbook-host/src/client.rs | 17 +- example/workbook-icd/Cargo.lock | 9 +- example/workbook-icd/src/lib.rs | 35 +- source/postcard-rpc-test/Cargo.lock | 24 + source/postcard-rpc-test/tests/basic.rs | 658 +++++++----- source/postcard-rpc/Cargo.toml | 4 +- source/postcard-rpc/src/accumulator.rs | 101 +- source/postcard-rpc/src/hash.rs | 14 +- source/postcard-rpc/src/header.rs | 630 ++++++++++++ source/postcard-rpc/src/headered.rs | 126 --- source/postcard-rpc/src/host_client/mod.rs | 162 +-- .../postcard-rpc/src/host_client/raw_nusb.rs | 18 +- source/postcard-rpc/src/host_client/serial.rs | 10 +- .../src/host_client/test_channels.rs | 85 ++ source/postcard-rpc/src/host_client/util.rs | 74 +- source/postcard-rpc/src/host_client/webusb.rs | 19 +- source/postcard-rpc/src/lib.rs | 569 ++++++----- source/postcard-rpc/src/macros.rs | 224 ++++- .../postcard-rpc/src/server/dispatch_macro.rs | 454 +++++++++ .../src/server/impls/embassy_usb_v0_3.rs | 669 +++++++++++++ source/postcard-rpc/src/server/impls/mod.rs | 9 + .../src/server/impls/test_channels.rs | 199 ++++ source/postcard-rpc/src/server/mod.rs | 522 ++++++++++ .../postcard-rpc/src/target_server/buffers.rs | 36 - .../src/target_server/dispatch_macro.rs | 472 --------- source/postcard-rpc/src/target_server/mod.rs | 195 ---- .../postcard-rpc/src/target_server/sender.rs | 319 ------ source/postcard-rpc/src/test_utils.rs | 34 +- source/postcard-rpc/src/uniques.rs | 944 ++++++++++++++++++ 34 files changed, 4964 insertions(+), 2021 deletions(-) create mode 100644 source/postcard-rpc/src/header.rs delete mode 100644 source/postcard-rpc/src/headered.rs create mode 100644 source/postcard-rpc/src/host_client/test_channels.rs create mode 100644 source/postcard-rpc/src/server/dispatch_macro.rs create mode 100644 source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs create mode 100644 source/postcard-rpc/src/server/impls/mod.rs create mode 100644 source/postcard-rpc/src/server/impls/test_channels.rs create mode 100644 source/postcard-rpc/src/server/mod.rs delete mode 100644 source/postcard-rpc/src/target_server/buffers.rs delete mode 100644 source/postcard-rpc/src/target_server/dispatch_macro.rs delete mode 100644 source/postcard-rpc/src/target_server/mod.rs delete mode 100644 source/postcard-rpc/src/target_server/sender.rs create mode 100644 source/postcard-rpc/src/uniques.rs diff --git a/example/firmware/Cargo.lock b/example/firmware/Cargo.lock index 3a594ba..da38593 100644 --- a/example/firmware/Cargo.lock +++ b/example/firmware/Cargo.lock @@ -344,17 +344,6 @@ dependencies = [ "nb 1.1.0", ] -[[package]] -name = "embassy-executor" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec648daedd2143466eff4b3e8002024f9f6c1de4ab7666bb679688752624c925" -dependencies = [ - "critical-section", - "document-features", - "embassy-executor-macros 0.4.1", -] - [[package]] name = "embassy-executor" version = "0.6.0" @@ -364,23 +353,11 @@ dependencies = [ "critical-section", "defmt", "document-features", - "embassy-executor-macros 0.5.0", + "embassy-executor-macros", "embassy-time-driver", "embassy-time-queue-driver", ] -[[package]] -name = "embassy-executor-macros" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad454accf80050e9cf7a51e994132ba0e56286b31f9317b68703897c328c59b5" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "embassy-executor-macros" version = "0.5.0" @@ -1154,12 +1131,13 @@ dependencies = [ name = "postcard-rpc" version = "0.8.0" dependencies = [ - "embassy-executor 0.5.0", + "embassy-executor", "embassy-sync", "embassy-usb", "embassy-usb-driver", "futures-util", "heapless 0.8.0", + "paste", "postcard", "postcard-schema", "serde", @@ -1759,7 +1737,7 @@ dependencies = [ "cortex-m-rt", "defmt", "defmt-rtt", - "embassy-executor 0.6.0", + "embassy-executor", "embassy-rp", "embassy-sync", "embassy-time", diff --git a/example/firmware/src/bin/comms-01.rs b/example/firmware/src/bin/comms-01.rs index f188857..2aa2de8 100644 --- a/example/firmware/src/bin/comms-01.rs +++ b/example/firmware/src/bin/comms-01.rs @@ -3,77 +3,119 @@ use defmt::info; use embassy_executor::Spawner; -use embassy_rp::{ - peripherals::USB, - usb::{self, Driver, Endpoint, Out}, -}; - +use embassy_rp::{peripherals::USB, usb}; use embassy_sync::blocking_mutex::raw::ThreadModeRawMutex; - -use embassy_usb::UsbDevice; +use embassy_usb::{Config, UsbDevice}; use postcard_rpc::{ define_dispatch, - target_server::{buffers::AllBuffers, configure_usb, example_config, rpc_dispatch}, - WireHeader, + header::VarHeader, + server::{ + impls::embassy_usb_v0_3::{ + dispatch_impl::{WireRxBuf, WireRxImpl, WireSpawnImpl, WireStorage, WireTxImpl}, + PacketBuffers, + }, + Dispatch, Server, + }, }; - use static_cell::ConstStaticCell; use workbook_fw::{get_unique_id, Irqs}; -use workbook_icd::PingEndpoint; +use workbook_icd::{PingEndpoint, ENDPOINT_LIST, TOPICS_IN_LIST, TOPICS_OUT_LIST}; +use {defmt_rtt as _, panic_probe as _}; + +pub struct Context; -static ALL_BUFFERS: ConstStaticCell> = - ConstStaticCell::new(AllBuffers::new()); +type AppDriver = usb::Driver<'static, USB>; +type AppStorage = WireStorage; +type BufStorage = PacketBuffers<1024, 1024>; +type AppTx = WireTxImpl; +type AppRx = WireRxImpl; +type AppServer = Server; -pub struct Context {} +static PBUFS: ConstStaticCell = ConstStaticCell::new(BufStorage::new()); +static STORAGE: AppStorage = AppStorage::new(); +fn usb_config() -> Config<'static> { + let mut config = Config::new(0x16c0, 0x27DD); + config.manufacturer = Some("OneVariable"); + config.product = Some("ov-twin"); + config.serial_number = Some("12345678"); + + // Required for windows compatibility. + // https://developer.nordicsemi.com/nRF_Connect_SDK/doc/1.9.1/kconfig/CONFIG_CDC_ACM_IAD.html#help + config.device_class = 0xEF; + config.device_sub_class = 0x02; + config.device_protocol = 0x01; + config.composite_with_iads = true; + + config +} define_dispatch! { - dispatcher: Dispatcher< - Mutex = ThreadModeRawMutex, - Driver = usb::Driver<'static, USB>, - Context = Context - >; - PingEndpoint => blocking ping_handler, + app: MyApp; + spawn_fn: spawn_fn; + tx_impl: AppTx; + spawn_impl: WireSpawnImpl; + context: Context; + + endpoints: { + list: ENDPOINT_LIST; + + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | PingEndpoint | blocking | ping_handler | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + }; + topics_out: { + list: TOPICS_OUT_LIST; + }; } #[embassy_executor::main] async fn main(spawner: Spawner) { // SYSTEM INIT info!("Start"); - let mut p = embassy_rp::init(Default::default()); - let unique_id = get_unique_id(&mut p.FLASH).unwrap(); + let unique_id = defmt::unwrap!(get_unique_id(&mut p.FLASH)); info!("id: {=u64:016X}", unique_id); // USB/RPC INIT let driver = usb::Driver::new(p.USB, Irqs); - let mut config = example_config(); - config.manufacturer = Some("OneVariable"); - config.product = Some("ov-twin"); - let buffers = ALL_BUFFERS.take(); - let (device, ep_in, ep_out) = configure_usb(driver, &mut buffers.usb_device, config); - let dispatch = Dispatcher::new(&mut buffers.tx_buf, ep_in, Context {}); + let pbufs = PBUFS.take(); + let config = usb_config(); - spawner.must_spawn(dispatch_task(ep_out, dispatch, &mut buffers.rx_buf)); + let context = Context; + let (device, tx_impl, rx_impl) = STORAGE.init(driver, config, pbufs.tx_buf.as_mut_slice()); + let dispatcher = MyApp::new(context, spawner.into()); + let vkk = dispatcher.min_key_len(); + let mut server: AppServer = Server::new( + &tx_impl, + rx_impl, + pbufs.rx_buf.as_mut_slice(), + dispatcher, + vkk, + ); spawner.must_spawn(usb_task(device)); -} -/// This actually runs the dispatcher -#[embassy_executor::task] -async fn dispatch_task( - ep_out: Endpoint<'static, USB, Out>, - dispatch: Dispatcher, - rx_buf: &'static mut [u8], -) { - rpc_dispatch(ep_out, dispatch, rx_buf).await; + loop { + // If the host disconnects, we'll return an error here. + // If this happens, just wait until the host reconnects + let _ = server.run().await; + } } /// This handles the low level USB management #[embassy_executor::task] -pub async fn usb_task(mut usb: UsbDevice<'static, Driver<'static, USB>>) { +pub async fn usb_task(mut usb: UsbDevice<'static, AppDriver>) { usb.run().await; } -fn ping_handler(_context: &mut Context, header: WireHeader, rqst: u32) -> u32 { - info!("ping: seq - {=u32}", header.seq_no); +// --- + +fn ping_handler(_context: &mut Context, _header: VarHeader, rqst: u32) -> u32 { + info!("ping"); rqst } diff --git a/example/firmware/src/bin/comms-02.rs b/example/firmware/src/bin/comms-02.rs index ad1efff..4461a34 100644 --- a/example/firmware/src/bin/comms-02.rs +++ b/example/firmware/src/bin/comms-02.rs @@ -8,31 +8,38 @@ use embassy_rp::{ peripherals::{PIO0, SPI0, USB}, pio::Pio, spi::{self, Spi}, - usb::{self, Driver, Endpoint, Out}, + usb, }; use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex}; use embassy_time::{Delay, Duration, Ticker}; -use embassy_usb::UsbDevice; +use embassy_usb::{Config, UsbDevice}; use embedded_hal_bus::spi::ExclusiveDevice; use lis3dh_async::{Lis3dh, Lis3dhSPI}; use portable_atomic::{AtomicBool, Ordering}; use postcard_rpc::{ define_dispatch, - target_server::{ - buffers::AllBuffers, configure_usb, example_config, rpc_dispatch, sender::Sender, - SpawnContext, + header::VarHeader, + server::{ + impls::embassy_usb_v0_3::{ + dispatch_impl::{ + spawn_fn, WireRxBuf, WireRxImpl, WireSpawnImpl, WireStorage, WireTxImpl, + }, + PacketBuffers, + }, + Dispatch, Sender, Server, SpawnContext, }, - WireHeader, }; use smart_leds::{colors::BLACK, RGB8}; use static_cell::{ConstStaticCell, StaticCell}; use workbook_fw::{ - get_unique_id, ws2812::{self, Ws2812}, Irqs + get_unique_id, + ws2812::{self, Ws2812}, + Irqs, }; use workbook_icd::{ AccelTopic, Acceleration, BadPositionError, GetUniqueIdEndpoint, PingEndpoint, Rgb8, SetAllLedEndpoint, SetSingleLedEndpoint, SingleLed, StartAccel, StartAccelerationEndpoint, - StopAccelerationEndpoint, + StopAccelerationEndpoint, ENDPOINT_LIST, TOPICS_IN_LIST, TOPICS_OUT_LIST, }; use {defmt_rtt as _, panic_probe as _}; @@ -40,9 +47,6 @@ pub type Accel = Lis3dh, Output<'static>, Delay>>>; static ACCEL: StaticCell> = StaticCell::new(); -static ALL_BUFFERS: ConstStaticCell> = - ConstStaticCell::new(AllBuffers::new()); - pub struct Context { pub unique_id: u64, pub ws2812: Ws2812<'static, PIO0, 0, 24>, @@ -61,18 +65,60 @@ impl SpawnContext for Context { } } +type AppDriver = usb::Driver<'static, USB>; +type AppStorage = WireStorage; +type BufStorage = PacketBuffers<1024, 1024>; +type AppTx = WireTxImpl; +type AppRx = WireRxImpl; +type AppServer = Server; + +static PBUFS: ConstStaticCell = ConstStaticCell::new(BufStorage::new()); +static STORAGE: AppStorage = AppStorage::new(); + +fn usb_config() -> Config<'static> { + let mut config = Config::new(0x16c0, 0x27DD); + config.manufacturer = Some("OneVariable"); + config.product = Some("ov-twin"); + config.serial_number = Some("12345678"); + + // Required for windows compatibility. + // https://developer.nordicsemi.com/nRF_Connect_SDK/doc/1.9.1/kconfig/CONFIG_CDC_ACM_IAD.html#help + config.device_class = 0xEF; + config.device_sub_class = 0x02; + config.device_protocol = 0x01; + config.composite_with_iads = true; + + config +} + define_dispatch! { - dispatcher: Dispatcher< - Mutex = ThreadModeRawMutex, - Driver = usb::Driver<'static, USB>, - Context = Context, - >; - PingEndpoint => blocking ping_handler, - GetUniqueIdEndpoint => blocking unique_id_handler, - SetSingleLedEndpoint => async set_led_handler, - SetAllLedEndpoint => async set_all_led_handler, - StartAccelerationEndpoint => spawn accelerometer_handler, - StopAccelerationEndpoint => blocking accelerometer_stop_handler, + app: MyApp; + spawn_fn: spawn_fn; + tx_impl: AppTx; + spawn_impl: WireSpawnImpl; + context: Context; + + endpoints: { + list: ENDPOINT_LIST; + + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | PingEndpoint | blocking | ping_handler | + | GetUniqueIdEndpoint | blocking | unique_id_handler | + | SetSingleLedEndpoint | async | set_led_handler | + | SetAllLedEndpoint | async | set_all_led_handler | + | StartAccelerationEndpoint | spawn | accelerometer_handler | + | StopAccelerationEndpoint | blocking | accelerometer_stop_handler | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + }; + topics_out: { + list: TOPICS_OUT_LIST; + }; } #[embassy_executor::main] @@ -106,60 +152,59 @@ async fn main(spawner: Spawner) { // USB/RPC INIT let driver = usb::Driver::new(p.USB, Irqs); - let mut config = example_config(); - config.manufacturer = Some("OneVariable"); - config.product = Some("ov-twin"); - let buffers = ALL_BUFFERS.take(); - let (device, ep_in, ep_out) = configure_usb(driver, &mut buffers.usb_device, config); - let dispatch = Dispatcher::new( - &mut buffers.tx_buf, - ep_in, - Context { - unique_id, - ws2812, - ws2812_state: [BLACK; 24], - accel: accel_ref, - }, - ); + let pbufs = PBUFS.take(); + let config = usb_config(); - spawner.must_spawn(dispatch_task(ep_out, dispatch, &mut buffers.rx_buf)); + let context = Context { + unique_id, + ws2812, + ws2812_state: [BLACK; 24], + accel: accel_ref, + }; + + let (device, tx_impl, rx_impl) = STORAGE.init(driver, config, pbufs.tx_buf.as_mut_slice()); + let dispatcher = MyApp::new(context, spawner.into()); + let vkk = dispatcher.min_key_len(); + let mut server: AppServer = Server::new( + &tx_impl, + rx_impl, + pbufs.rx_buf.as_mut_slice(), + dispatcher, + vkk, + ); spawner.must_spawn(usb_task(device)); -} -/// This actually runs the dispatcher -#[embassy_executor::task] -async fn dispatch_task( - ep_out: Endpoint<'static, USB, Out>, - dispatch: Dispatcher, - rx_buf: &'static mut [u8], -) { - rpc_dispatch(ep_out, dispatch, rx_buf).await; + loop { + // If the host disconnects, we'll return an error here. + // If this happens, just wait until the host reconnects + let _ = server.run().await; + } } /// This handles the low level USB management #[embassy_executor::task] -pub async fn usb_task(mut usb: UsbDevice<'static, Driver<'static, USB>>) { +pub async fn usb_task(mut usb: UsbDevice<'static, AppDriver>) { usb.run().await; } // --- -fn ping_handler(_context: &mut Context, header: WireHeader, rqst: u32) -> u32 { - info!("ping: seq - {=u32}", header.seq_no); +fn ping_handler(_context: &mut Context, _header: VarHeader, rqst: u32) -> u32 { + info!("ping"); rqst } -fn unique_id_handler(context: &mut Context, header: WireHeader, _rqst: ()) -> u64 { - info!("unique_id: seq - {=u32}", header.seq_no); +fn unique_id_handler(context: &mut Context, _header: VarHeader, _rqst: ()) -> u64 { + info!("unique_id"); context.unique_id } async fn set_led_handler( context: &mut Context, - header: WireHeader, + _header: VarHeader, rqst: SingleLed, ) -> Result<(), BadPositionError> { - info!("set_led: seq - {=u32}", header.seq_no); + info!("set_led"); if rqst.position >= 24 { return Err(BadPositionError); } @@ -173,8 +218,8 @@ async fn set_led_handler( Ok(()) } -async fn set_all_led_handler(context: &mut Context, header: WireHeader, rqst: [Rgb8; 24]) { - info!("set_all_led: seq - {=u32}", header.seq_no); +async fn set_all_led_handler(context: &mut Context, _header: VarHeader, rqst: [Rgb8; 24]) { + info!("set_all_led"); context .ws2812_state .iter_mut() @@ -192,9 +237,9 @@ static STOP: AtomicBool = AtomicBool::new(false); #[embassy_executor::task] async fn accelerometer_handler( context: SpawnCtx, - header: WireHeader, + header: VarHeader, rqst: StartAccel, - sender: Sender>, + sender: Sender, ) { let mut accel = context.accel.lock().await; if sender @@ -209,7 +254,7 @@ async fn accelerometer_handler( defmt::unwrap!(accel.set_range(lis3dh_async::Range::G8).await.map_err(drop)); let mut ticker = Ticker::every(Duration::from_millis(rqst.interval_ms.into())); - let mut seq = 0; + let mut seq = 0u8; while !STOP.load(Ordering::Acquire) { ticker.next().await; let acc = defmt::unwrap!(accel.accel_raw().await.map_err(drop)); @@ -219,7 +264,11 @@ async fn accelerometer_handler( y: acc.y, z: acc.z, }; - if sender.publish::(seq, &msg).await.is_err() { + if sender + .publish::(seq.into(), &msg) + .await + .is_err() + { defmt::error!("Send error!"); break; } @@ -229,8 +278,8 @@ async fn accelerometer_handler( STOP.store(false, Ordering::Release); } -fn accelerometer_stop_handler(context: &mut Context, header: WireHeader, _rqst: ()) -> bool { - info!("accel_stop: seq - {=u32}", header.seq_no); +fn accelerometer_stop_handler(context: &mut Context, _header: VarHeader, _rqst: ()) -> bool { + info!("accel_stop"); let was_busy = context.accel.try_lock().is_err(); if was_busy { STOP.store(true, Ordering::Release); diff --git a/example/workbook-host/Cargo.lock b/example/workbook-host/Cargo.lock index d50a8c7..86a1103 100644 --- a/example/workbook-host/Cargo.lock +++ b/example/workbook-host/Cargo.lock @@ -136,6 +136,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "errno" version = "0.3.8" @@ -407,6 +413,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project" version = "1.1.5" @@ -469,9 +481,11 @@ dependencies = [ "heapless 0.8.0", "maitake-sync", "nusb", + "paste", "postcard", "postcard-schema", "serde", + "ssmarshal", "thiserror", "tokio", "tracing", @@ -655,6 +669,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "ssmarshal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e6ad23b128192ed337dfa4f1b8099ced0c2bf30d61e551b65fda5916dbb850" +dependencies = [ + "encode_unicode", + "serde", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/example/workbook-host/src/bin/comms-01.rs b/example/workbook-host/src/bin/comms-01.rs index 24306d8..1313f6e 100644 --- a/example/workbook-host/src/bin/comms-01.rs +++ b/example/workbook-host/src/bin/comms-01.rs @@ -1,7 +1,7 @@ use std::time::Duration; -use workbook_host_client::client::WorkbookClient; use tokio::time::interval; +use workbook_host_client::client::WorkbookClient; #[tokio::main] pub async fn main() { diff --git a/example/workbook-host/src/client.rs b/example/workbook-host/src/client.rs index 3e8ded2..0688a7a 100644 --- a/example/workbook-host/src/client.rs +++ b/example/workbook-host/src/client.rs @@ -1,9 +1,14 @@ -use std::convert::Infallible; use postcard_rpc::{ + header::VarSeqKind, host_client::{HostClient, HostErr}, standard_icd::{WireError, ERROR_PATH}, }; -use workbook_icd::{AccelRange, BadPositionError, GetUniqueIdEndpoint, PingEndpoint, Rgb8, SetAllLedEndpoint, SetSingleLedEndpoint, SingleLed, StartAccel, StartAccelerationEndpoint, StopAccelerationEndpoint}; +use std::convert::Infallible; +use workbook_icd::{ + AccelRange, BadPositionError, GetUniqueIdEndpoint, PingEndpoint, Rgb8, SetAllLedEndpoint, + SetSingleLedEndpoint, SingleLed, StartAccel, StartAccelerationEndpoint, + StopAccelerationEndpoint, +}; pub struct WorkbookClient { pub client: HostClient, @@ -39,8 +44,12 @@ impl FlattenErr for Result { impl WorkbookClient { pub fn new() -> Self { - let client = - HostClient::new_raw_nusb(|d| d.product_string() == Some("ov-twin"), ERROR_PATH, 8); + let client = HostClient::new_raw_nusb( + |d| d.product_string() == Some("ov-twin"), + ERROR_PATH, + 8, + VarSeqKind::Seq2, + ); Self { client } } diff --git a/example/workbook-icd/Cargo.lock b/example/workbook-icd/Cargo.lock index a2cc8d1..a39eb63 100644 --- a/example/workbook-icd/Cargo.lock +++ b/example/workbook-icd/Cargo.lock @@ -87,6 +87,12 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "postcard" version = "1.0.8" @@ -111,9 +117,10 @@ dependencies = [ [[package]] name = "postcard-rpc" -version = "0.7.0" +version = "0.8.0" dependencies = [ "heapless 0.8.0", + "paste", "postcard", "postcard-schema", "serde", diff --git a/example/workbook-icd/src/lib.rs b/example/workbook-icd/src/lib.rs index 2783dc1..ecf109c 100644 --- a/example/workbook-icd/src/lib.rs +++ b/example/workbook-icd/src/lib.rs @@ -1,22 +1,35 @@ #![no_std] +use postcard_rpc::{endpoints, topics}; use postcard_schema::Schema; -use postcard_rpc::{endpoint, topic}; use serde::{Deserialize, Serialize}; -endpoint!(PingEndpoint, u32, u32, "ping"); - // --- -endpoint!(GetUniqueIdEndpoint, (), u64, "unique_id/get"); - -endpoint!(SetSingleLedEndpoint, SingleLed, Result<(), BadPositionError>, "led/set_one"); -endpoint!(SetAllLedEndpoint, [Rgb8; 24], (), "led/set_all"); +endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | PingEndpoint | u32 | u32 | "ping" | + | GetUniqueIdEndpoint | () | u64 | "unique_id/get" | + | SetSingleLedEndpoint | SingleLed | Result<(), BadPositionError> | "led/set_one" | + | SetAllLedEndpoint | [Rgb8; 24] | () | "led/set_all" | + | StartAccelerationEndpoint | StartAccel | () | "accel/start" | + | StopAccelerationEndpoint | () | bool | "accel/stop" | +} -endpoint!(StartAccelerationEndpoint, StartAccel, (), "accel/start"); -endpoint!(StopAccelerationEndpoint, (), bool, "accel/stop"); +topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ------- | --------- | ---- | +} -topic!(AccelTopic, Acceleration, "accel/data"); +topics! { + list = TOPICS_OUT_LIST; + | TopicTy | MessageTy | Path | + | ------- | --------- | ---- | + | AccelTopic | Acceleration | "accel/data" | +} #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub struct SingleLed { @@ -28,7 +41,7 @@ pub struct SingleLed { pub struct Rgb8 { pub r: u8, pub g: u8, - pub b: u8 + pub b: u8, } #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] diff --git a/source/postcard-rpc-test/Cargo.lock b/source/postcard-rpc-test/Cargo.lock index 0306838..5c791ae 100644 --- a/source/postcard-rpc-test/Cargo.lock +++ b/source/postcard-rpc-test/Cargo.lock @@ -131,6 +131,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "generator" version = "0.7.5" @@ -320,6 +326,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project" version = "1.1.3" @@ -394,9 +406,11 @@ version = "0.8.0" dependencies = [ "heapless 0.8.0", "maitake-sync", + "paste", "postcard", "postcard-schema", "serde", + "ssmarshal", "thiserror", "tokio", "tracing", @@ -569,6 +583,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "ssmarshal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e6ad23b128192ed337dfa4f1b8099ced0c2bf30d61e551b65fda5916dbb850" +dependencies = [ + "encode_unicode", + "serde", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/source/postcard-rpc-test/tests/basic.rs b/source/postcard-rpc-test/tests/basic.rs index e830fb1..33ffff7 100644 --- a/source/postcard-rpc-test/tests/basic.rs +++ b/source/postcard-rpc-test/tests/basic.rs @@ -1,305 +1,449 @@ -use std::{collections::HashMap, time::Duration}; - -use postcard_schema::Schema; -use postcard_rpc::test_utils::local_setup; -use postcard_rpc::{ - endpoint, headered::to_stdvec_keyed, topic, Dispatch, Endpoint, Key, Topic, WireHeader, +use core::{ + sync::atomic::{AtomicUsize, Ordering}, + time::Duration, }; +use std::{sync::Arc, time::Instant}; + +use postcard_schema::{schema::owned::OwnedNamedType, Schema}; use serde::{Deserialize, Serialize}; -use tokio::task::yield_now; -use tokio::time::timeout; +use tokio::{sync::mpsc, task::yield_now}; -endpoint!(EndpointOne, Req1, Resp1, "endpoint/one"); -topic!(TopicOne, Req1, "unsolicited/topic1"); +use postcard_rpc::{ + define_dispatch, endpoints, + header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind}, + host_client::test_channels as client, + server::{ + impls::test_channels::{ + dispatch_impl::{new_server, spawn_fn, Settings, WireSpawnImpl, WireTxImpl}, + ChannelWireRx, ChannelWireSpawn, ChannelWireTx, + }, + Dispatch, Sender, SpawnContext, + }, + topics, Endpoint, Topic, +}; -#[derive(Debug, PartialEq, Serialize, Deserialize, Schema)] -pub struct Req1 { - a: u8, - b: u64, +#[derive(Serialize, Deserialize, Schema)] +pub struct AReq(pub u8); +#[derive(Serialize, Deserialize, Schema)] +pub struct AResp(pub u8); +#[derive(Serialize, Deserialize, Schema)] +pub struct BReq(pub u16); +#[derive(Serialize, Deserialize, Schema)] +pub struct BResp(pub u32); +#[derive(Serialize, Deserialize, Schema)] +pub struct GReq; +#[derive(Serialize, Deserialize, Schema)] +pub struct GResp; +#[derive(Serialize, Deserialize, Schema)] +pub struct DReq; +#[derive(Serialize, Deserialize, Schema)] +pub struct DResp; +#[derive(Serialize, Deserialize, Schema)] +pub struct EReq; +#[derive(Serialize, Deserialize, Schema)] +pub struct EResp; +#[derive(Serialize, Deserialize, Schema)] +pub struct ZMsg(pub i16); + +endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | AlphaEndpoint | AReq | AResp | "alpha" | + | BetaEndpoint | BReq | BResp | "beta" | + | GammaEndpoint | GReq | GResp | "gamma" | + | DeltaEndpoint | DReq | DResp | "delta" | + | EpsilonEndpoint | EReq | EResp | "epsilon" | } -#[derive(Debug, PartialEq, Serialize, Deserialize, Schema)] -pub struct Resp1 { - c: [u8; 8], - d: i32, +topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic1 | ZMsg | "zeta1" | + | ZetaTopic2 | ZMsg | "zeta2" | + | ZetaTopic3 | ZMsg | "zeta3" | } -#[derive(Debug, PartialEq, Serialize, Deserialize, Schema)] -pub enum WireError { - LeastBad, - MediumBad, - MostBad, +topics! { + list = TOPICS_OUT_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic10 | ZMsg | "zeta10" | } -struct SmokeContext { - got: HashMap)>, - next_err: bool, +pub struct TestContext { + pub ctr: Arc, + pub topic_ctr: Arc, } -fn store_disp(hdr: &WireHeader, ctx: &mut SmokeContext, body: &[u8]) -> Result<(), WireError> { - if ctx.next_err { - ctx.next_err = false; - return Err(WireError::MediumBad); - } - ctx.got.insert(hdr.key, (hdr.clone(), body.to_vec())); - Ok(()) +pub struct TestSpawnContext { + pub ctr: Arc, + pub topic_ctr: Arc, } -impl SmokeDispatch { - pub fn new() -> Self { - let ctx = SmokeContext { - got: HashMap::new(), - next_err: false, - }; - let disp = Dispatch::new(ctx); - Self { disp } - } -} +impl SpawnContext for TestContext { + type SpawnCtxt = TestSpawnContext; -struct SmokeDispatch { - disp: Dispatch, -} - -#[tokio::test] -async fn smoke_reqresp() { - let (mut srv, client) = local_setup::(8, "error"); - - // Create the Dispatch Server - let mut disp = SmokeDispatch::new(); - disp.disp.add_handler::(store_disp).unwrap(); - - // Start the request - let send1 = tokio::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await + fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { + TestSpawnContext { + ctr: self.ctr.clone(), + topic_ctr: self.topic_ctr.clone(), } - }); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); - - // The request is still awaiting a response - assert!(!send1.is_finished()); - - // Feed the request through the dispatcher - disp.disp.dispatch(&act_out).unwrap(); + } +} - // Make sure we "dispatched" it right - let disp_got = disp.disp.context().got.remove(&out1.header.key).unwrap(); - assert_eq!(disp_got.0, out1.header); - assert!(act_out.ends_with(&disp_got.1)); +define_dispatch! { + app: SingleDispatcher; + spawn_fn: spawn_fn; + tx_impl: WireTxImpl; + spawn_impl: WireSpawnImpl; + context: TestContext; - // The request is still awaiting a response - assert!(!send1.is_finished()); + endpoints: { + list: ENDPOINT_LIST; - // Feed a simulated response "from the wire" back to the - // awaiting request - const RESP_001: Resp1 = Resp1 { - c: [1, 2, 3, 4, 5, 6, 7, 8], - d: -10, + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | AlphaEndpoint | async | test_alpha_handler | + | BetaEndpoint | spawn | test_beta_handler | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + | ZetaTopic1 | blocking | test_zeta_blocking | + | ZetaTopic2 | async | test_zeta_async | + | ZetaTopic3 | spawn | test_zeta_spawn | + }; + topics_out: { + list: TOPICS_OUT_LIST; }; - srv.reply::(out1.header.seq_no, &RESP_001) - .await - .unwrap(); - - // Now wait for the request to complete - let end = send1.await.unwrap().unwrap(); - - // We got the simulated value back - assert_eq!(end, RESP_001); } -#[tokio::test] -async fn smoke_publish() { - let (mut srv, client) = local_setup::(8, "error"); - - // Start the request - client - .publish::(123, &Req1 { a: 10, b: 100 }) - .await - .unwrap(); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(123, TopicOne::TOPIC_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); +fn test_zeta_blocking( + context: &mut TestContext, + _header: VarHeader, + _body: ZMsg, + _out: &Sender, +) { + context.topic_ctr.fetch_add(1, Ordering::Relaxed); } -#[tokio::test] -async fn smoke_subscribe() { - let (mut srv, client) = local_setup::(8, "error"); - - // Do a subscription - let mut sub = client.subscribe::(8).await.unwrap(); - - // Start the listen - let recv1 = timeout(Duration::from_millis(100), sub.recv()); - let _ = recv1.await.unwrap_err(); +async fn test_zeta_async( + context: &mut TestContext, + _header: VarHeader, + _body: ZMsg, + _out: &Sender, +) { + context.topic_ctr.fetch_add(1, Ordering::Relaxed); +} - // Send a message on the topic - const VAL: Req1 = Req1 { a: 10, b: 100 }; - srv.publish::(123, &VAL).await.unwrap(); +async fn test_zeta_spawn( + context: TestSpawnContext, + _header: VarHeader, + _body: ZMsg, + _out: Sender, +) { + context.topic_ctr.fetch_add(1, Ordering::Relaxed); +} - // Now the request resolves - let publ = timeout(Duration::from_millis(100), sub.recv()) - .await - .unwrap() - .unwrap(); +async fn test_alpha_handler(context: &mut TestContext, _header: VarHeader, body: AReq) -> AResp { + context.ctr.fetch_add(1, Ordering::Relaxed); + AResp(body.0) +} - assert_eq!(publ, VAL); +async fn test_beta_handler( + context: TestSpawnContext, + header: VarHeader, + body: BReq, + out: Sender, +) { + context.ctr.fetch_add(1, Ordering::Relaxed); + let _ = out + .reply::(header.seq_no, &BResp(body.0.into())) + .await; } #[tokio::test] -async fn smoke_io_error() { - let (mut srv, client) = local_setup::(8, "error"); - - // Do one round trip to make sure the connection works - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await +async fn smoke() { + let (client_tx, server_rx) = mpsc::channel(16); + let (server_tx, mut client_rx) = mpsc::channel(16); + let topic_ctr = Arc::new(AtomicUsize::new(0)); + + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + let cwrx = ChannelWireRx::new(server_rx); + let cwtx = ChannelWireTx::new(server_tx); + let kkind = app.min_key_len(); + let mut server = new_server( + app, + Settings { + tx: cwtx, + rx: cwrx, + buf: 1024, + kkind, + }, + ); + tokio::task::spawn(async move { + server.run().await; + }); + + // manually build request - Alpha + let mut msg = VarHeader { + key: VarKey::Key8(AlphaEndpoint::REQ_KEY), + seq_no: VarSeq::Seq4(123), + } + .write_to_vec(); + let body = postcard::to_stdvec(&AReq(42)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + let resp = client_rx.recv().await.unwrap(); + + // manually extract response + let (hdr, body) = VarHeader::take_from_slice(&resp).unwrap(); + let resp = postcard::from_bytes::<::Response>(body).unwrap(); + assert_eq!(resp.0, 42); + assert_eq!(hdr.key, VarKey::Key8(AlphaEndpoint::RESP_KEY)); + assert_eq!(hdr.seq_no, VarSeq::Seq4(123)); + + // manually build request - Beta + let mut msg = VarHeader { + key: VarKey::Key8(BetaEndpoint::REQ_KEY), + seq_no: VarSeq::Seq4(234), + } + .write_to_vec(); + let body = postcard::to_stdvec(&BReq(1000)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + let resp = client_rx.recv().await.unwrap(); + + // manually extract response + let (hdr, body) = VarHeader::take_from_slice(&resp).unwrap(); + let resp = postcard::from_bytes::<::Response>(body).unwrap(); + assert_eq!(resp.0, 1000); + assert_eq!(hdr.key, VarKey::Key8(BetaEndpoint::RESP_KEY)); + assert_eq!(hdr.seq_no, VarSeq::Seq4(234)); + + // blocking topic handler + for i in 0..3 { + let mut msg = VarHeader { + key: VarKey::Key8(ZetaTopic1::TOPIC_KEY), + seq_no: VarSeq::Seq4(i), + } + .write_to_vec(); + + let body = postcard::to_stdvec(&ZMsg(456)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + + let start = Instant::now(); + let mut good = false; + while start.elapsed() < Duration::from_millis(100) { + let ct = topic_ctr.load(Ordering::Relaxed); + if ct == (i + 1) as usize { + good = true; + break; + } else { + yield_now().await } - }); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); - - // The request is still awaiting a response - assert!(!rr_rt.is_finished()); - - // Feed a simulated response "from the wire" back to the - // awaiting request - const RESP_001: Resp1 = Resp1 { - c: [1, 2, 3, 4, 5, 6, 7, 8], - d: -10, - }; - srv.reply::(out1.header.seq_no, &RESP_001) - .await - .unwrap(); - - // Now wait for the request to complete - let end = rr_rt.await.unwrap().unwrap(); - - // We got the simulated value back - assert_eq!(end, RESP_001); + } + assert!(good); } - // Now, simulate an I/O error - srv.cause_fatal_error(); - - // Give the clients some time to halt - yield_now().await; - - // Our server channels should now be closed - the tasks hung up - assert!(srv.from_client.recv().await.is_none()); - assert!(srv.to_client.send(Vec::new()).await.is_err()); + // async topic handler + for i in 0..3 { + let mut msg = VarHeader { + key: VarKey::Key8(ZetaTopic2::TOPIC_KEY), + seq_no: VarSeq::Seq4(i), + } + .write_to_vec(); + let body = postcard::to_stdvec(&ZMsg(456)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + + let start = Instant::now(); + let mut good = false; + while start.elapsed() < Duration::from_millis(100) { + let ct = topic_ctr.load(Ordering::Relaxed); + if ct == (i + 4) as usize { + good = true; + break; + } else { + yield_now().await + } + } + assert!(good); + } - // Try again, but nothing should work because all the worker tasks just halted - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await + // spawn topic handler + for i in 0..3 { + let mut msg = VarHeader { + key: VarKey::Key8(ZetaTopic3::TOPIC_KEY), + seq_no: VarSeq::Seq4(i), + } + .write_to_vec(); + let body = postcard::to_stdvec(&ZMsg(456)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + + let start = Instant::now(); + let mut good = false; + while start.elapsed() < Duration::from_millis(100) { + let ct = topic_ctr.load(Ordering::Relaxed); + if ct == (i + 7) as usize { + good = true; + break; + } else { + yield_now().await } - }); + } + assert!(good); + } +} - // As the wire, get the outgoing request - didn't happen - assert!(srv.recv_from_client().await.is_err()); +#[tokio::test] +async fn end_to_end() { + let (client_tx, server_rx) = mpsc::channel(16); + let (server_tx, client_rx) = mpsc::channel(16); + let topic_ctr = Arc::new(AtomicUsize::new(0)); + + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + let cwrx = ChannelWireRx::new(server_rx); + let cwtx = ChannelWireTx::new(server_tx); + + let kkind = app.min_key_len(); + let mut server = new_server( + app, + Settings { + tx: cwtx, + rx: cwrx, + buf: 1024, + kkind, + }, + ); + tokio::task::spawn(async move { + server.run().await; + }); - // Now wait for the request to complete - it failed - rr_rt.await.unwrap().unwrap_err(); - } + let cli = client::new_from_channels(client_tx, client_rx, VarSeqKind::Seq1); + + let resp = cli.send_resp::(&AReq(42)).await.unwrap(); + assert_eq!(resp.0, 42); + let resp = cli.send_resp::(&BReq(1234)).await.unwrap(); + assert_eq!(resp.0, 1234); } #[tokio::test] -async fn smoke_closed() { - let (mut srv, client) = local_setup::(8, "error"); - - // Do one round trip to make sure the connection works - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await - } - }); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); - - // The request is still awaiting a response - assert!(!rr_rt.is_finished()); - - // Feed a simulated response "from the wire" back to the - // awaiting request - const RESP_001: Resp1 = Resp1 { - c: [1, 2, 3, 4, 5, 6, 7, 8], - d: -10, - }; - srv.reply::(out1.header.seq_no, &RESP_001) - .await - .unwrap(); - - // Now wait for the request to complete - let end = rr_rt.await.unwrap().unwrap(); - - // We got the simulated value back - assert_eq!(end, RESP_001); - } +async fn end_to_end_force8() { + let (client_tx, server_rx) = mpsc::channel(16); + let (server_tx, client_rx) = mpsc::channel(16); + let topic_ctr = Arc::new(AtomicUsize::new(0)); + + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + let cwrx = ChannelWireRx::new(server_rx); + let cwtx = ChannelWireTx::new(server_tx); + + let kkind = VarKeyKind::Key8; + let mut server = new_server( + app, + Settings { + tx: cwtx, + rx: cwrx, + buf: 1024, + kkind, + }, + ); + tokio::task::spawn(async move { + server.run().await; + }); - // Now, use the *client* to close the connection - client.close(); + let cli = client::new_from_channels(client_tx, client_rx, VarSeqKind::Seq4); - // Give the clients some time to halt - yield_now().await; + let resp = cli.send_resp::(&AReq(42)).await.unwrap(); + assert_eq!(resp.0, 42); + let resp = cli.send_resp::(&BReq(1234)).await.unwrap(); + assert_eq!(resp.0, 1234); +} - // Our server channels should now be closed - the tasks hung up - assert!(srv.from_client.recv().await.is_none()); - assert!(srv.to_client.send(Vec::new()).await.is_err()); +#[test] +fn device_map() { + let topic_ctr = Arc::new(AtomicUsize::new(0)); + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + println!("# SingleDispatcher"); + println!(); + + println!("## Types"); + println!(); + for ty in app.device_map.types { + let ty = OwnedNamedType::from(*ty); + println!("* {ty}"); + } - // Try again, but nothing should work because all the worker tasks just halted - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await - } - }); + println!(); + println!("## Endpoints"); + println!(); + for ep in app.device_map.endpoints { + println!( + "* {} ({:016X} -> {:016X})", + ep.0, + u64::from_le_bytes(ep.1.to_bytes()), + u64::from_le_bytes(ep.2.to_bytes()), + ); + } - // As the wire, get the outgoing request - didn't happen - assert!(srv.recv_from_client().await.is_err()); + println!(); + println!("## Topics (In)"); + println!(); + for tp in app.device_map.topics_in { + println!( + "* {} <- ({:016X})", + tp.0, + u64::from_le_bytes(tp.1.to_bytes()), + ); + } - // Now wait for the request to complete - it failed - rr_rt.await.unwrap().unwrap_err(); + println!(); + println!("## Topics (Out)"); + println!(); + for tp in app.device_map.topics_out { + println!( + "* {} -> ({:016X})", + tp.0, + u64::from_le_bytes(tp.1.to_bytes()), + ); } + println!(); + println!("## Min Key Length"); + println!(); + println!("{:?}", app.device_map.min_key_len); + println!(); } diff --git a/source/postcard-rpc/Cargo.toml b/source/postcard-rpc/Cargo.toml index be20af9..704f41d 100644 --- a/source/postcard-rpc/Cargo.toml +++ b/source/postcard-rpc/Cargo.toml @@ -30,6 +30,7 @@ heapless = "0.8.0" postcard = { version = "1.0.8" } serde = { version = "1.0.192", default-features = false, features = ["derive"] } postcard-schema = { version = "0.1.0", features = ["derive"] } +paste = "1.0.15" # # std-only features @@ -124,7 +125,7 @@ version = "2.1" optional = true [dependencies.embassy-executor] -version = "0.5" +version = "0.6" optional = true [dependencies.futures-util] @@ -155,6 +156,7 @@ use-std = [ "dep:thiserror", "dep:tracing", "dep:trait-variant", + "dep:ssmarshal", ] # Cobs Serial support. diff --git a/source/postcard-rpc/src/accumulator.rs b/source/postcard-rpc/src/accumulator.rs index 421c885..493ba80 100644 --- a/source/postcard-rpc/src/accumulator.rs +++ b/source/postcard-rpc/src/accumulator.rs @@ -9,6 +9,7 @@ pub mod raw { use cobs::decode_in_place; + /// A header-aware COBS accumulator pub struct CobsAccumulator { buf: [u8; N], idx: usize, @@ -116,103 +117,3 @@ pub mod raw { } } } - -/// Accumulate and Dispatch -pub mod dispatch { - use super::raw::{CobsAccumulator, FeedResult}; - use crate::Dispatch; - - /// An error containing the handler-specific error, as well as the unprocessed - /// feed bytes - #[derive(Debug, PartialEq)] - pub struct FeedError<'a, E> { - pub err: E, - pub remainder: &'a [u8], - } - - /// A COBS-flavored version of [Dispatch] - /// - /// [Dispatch]: crate::Dispatch - /// - /// CobsDispatch is generic over four types: - /// - /// 1. The `Context`, which will be passed as a mutable reference - /// to each of the handlers. It typically should contain - /// whatever resource is necessary to send replies back to - /// the host. - /// 2. The `Error` type, which can be returned by handlers - /// 3. `N`, for the maximum number of handlers - /// 4. `BUF`, for the maximum number of bytes to buffer for a single - /// COBS-encoded message - pub struct CobsDispatch { - dispatch: Dispatch, - acc: CobsAccumulator, - } - - impl CobsDispatch { - /// Create a new `CobsDispatch` - pub fn new(c: Context) -> Self { - Self { - dispatch: Dispatch::new(c), - acc: CobsAccumulator::new(), - } - } - - /// Access the contained [Dispatch]` - pub fn dispatcher(&mut self) -> &mut Dispatch { - &mut self.dispatch - } - - /// Feed the given bytes into the dispatcher, attempting to dispatch each framed - /// message found. - /// - /// Line format errors, such as an overfull buffer or bad COBS frames will be - /// silently ignored. - /// - /// If an error in dispatching occurs, this function will return immediately, - /// yielding the error and the remaining unprocessed bytes for further processing. - pub fn feed<'a>( - &mut self, - buf: &'a [u8], - ) -> Result<(), FeedError<'a, crate::Error>> { - let mut window = buf; - let CobsDispatch { dispatch, acc } = self; - 'cobs: while !window.is_empty() { - window = match acc.feed(window) { - FeedResult::Consumed => break 'cobs, - FeedResult::OverFull(new_wind) => new_wind, - FeedResult::DeserError(new_wind) => new_wind, - FeedResult::Success { data, remaining } => { - dispatch.dispatch(data).map_err(|e| FeedError { - err: e, - remainder: remaining, - })?; - remaining - } - }; - } - - // We have dispatched all (if any) messages, and consumed the buffer - // without dispatch errors. - Ok(()) - } - - /// Similar to [CobsDispatch::feed], but the provided closure is called on each - /// error, allowing for handling. - /// - /// Useful if you need to do something blocking on each error case. - /// - /// If you need to handle the error in an async context, you may want to use - /// [CobsDispatch::feed] instead. - pub fn feed_with_err(&mut self, buf: &[u8], mut f: F) - where - F: FnMut(&mut Context, crate::Error), - { - let mut window = buf; - while let Err(FeedError { err, remainder }) = self.feed(window) { - f(&mut self.dispatch.context, err); - window = remainder; - } - } - } -} diff --git a/source/postcard-rpc/src/hash.rs b/source/postcard-rpc/src/hash.rs index 8d629df..3929de3 100644 --- a/source/postcard-rpc/src/hash.rs +++ b/source/postcard-rpc/src/hash.rs @@ -13,6 +13,7 @@ use postcard_schema::{ Schema, }; +/// A const compatible Fnv1a64 hasher pub struct Fnv1a64Hasher { state: u64, } @@ -22,10 +23,12 @@ impl Fnv1a64Hasher { const BASIS: u64 = 0xcbf2_9ce4_8422_2325; const PRIME: u64 = 0x0000_0100_0000_01b3; + /// Create a new hasher with the default basis as state contents pub fn new() -> Self { Self { state: Self::BASIS } } + /// Calculate the hash for each of the given data bytes pub fn update(&mut self, data: &[u8]) { for b in data { let ext = u64::from(*b); @@ -34,10 +37,12 @@ impl Fnv1a64Hasher { } } + /// Extract the current state for finalizing the hash pub fn digest(self) -> u64 { self.state } + /// Same as digest but as bytes pub fn digest_bytes(self) -> [u8; 8] { self.digest().to_le_bytes() } @@ -50,10 +55,12 @@ impl Default for Fnv1a64Hasher { } pub mod fnv1a64 { + //! Const and no-std helper methods and types for perfoming hash calculation use postcard_schema::schema::DataModelVariant; use super::*; + /// Calculate the Key hash for the given path and type T pub const fn hash_ty_path(path: &str) -> [u8; 8] { let schema = T::SCHEMA; let state = hash_update_str(Fnv1a64Hasher::BASIS, path); @@ -205,7 +212,7 @@ pub mod fnv1a64 { let mut state = hash_update(state, &[0xC7]); let mut idx = 0; while idx < nts.len() { - state = hash_named_type(state, &nts[idx]); + state = hash_named_type(state, nts[idx]); idx += 1; } state @@ -214,7 +221,7 @@ pub mod fnv1a64 { let mut state = hash_update(state, &[0x67]); let mut idx = 0; while idx < nvs.len() { - state = hash_named_value(state, &nvs[idx]); + state = hash_named_value(state, nvs[idx]); idx += 1; } state @@ -230,6 +237,8 @@ pub mod fnv1a64 { #[cfg(feature = "use-std")] pub mod fnv1a64_owned { + //! Heapful helpers and versions of hashing for use on `std` targets + use postcard_schema::schema::owned::{ OwnedDataModelType, OwnedDataModelVariant, OwnedNamedType, OwnedNamedValue, OwnedNamedVariant, @@ -238,6 +247,7 @@ pub mod fnv1a64_owned { use super::fnv1a64::*; use super::*; + /// Calculate the Key hash for the given path and OwnedNamedType pub fn hash_ty_path_owned(path: &str, nt: &OwnedNamedType) -> [u8; 8] { let state = hash_update_str(Fnv1a64Hasher::BASIS, path); hash_named_type_owned(state, nt).to_le_bytes() diff --git a/source/postcard-rpc/src/header.rs b/source/postcard-rpc/src/header.rs new file mode 100644 index 0000000..7e34a09 --- /dev/null +++ b/source/postcard-rpc/src/header.rs @@ -0,0 +1,630 @@ +//! # Postcard-RPC Header Format +//! +//! Postcard-RPC's header is made up of three main parts: +//! +//! 1. A one-byte discriminant +//! 2. A 1-8 byte "Key" +//! 3. A 1-4 byte "Sequence Number" +//! +//! The Postcard-RPC Header is NOT encoded using `postcard`'s wire format. +//! +//! ## Discriminant +//! +//! The discriminant field is always one byte, and consists of three subfields +//! in the form `0bNNMM_VVVV`. +//! +//! * The two msbits are "key length", where the two N length bits represent +//! a key length of 2^N. All values are valid. +//! * The next two msbits are "sequence number length", where the two M length +//! bits represent a sequence number length of 2^M. Values 00, 01, and 10 +//! are valid. +//! * The four lsbits are "protocol version", where the four V version bits +//! represent an unsigned 4-bit number. Currently only 0000 is a valid value. +//! +//! ## Key +//! +//! The Key consists of an fnv1a hash of the path string and schema of the +//! contained message. These are calculated using the [`hash` module](crate::hash), +//! and are natively calculated as an 8-byte hash. +//! +//! Keys may be encoded with variable fidelity on the wire, as follows: +//! +//! * For 8-byte keys, all key bytes appear in the form `[A, B, C, D, E, F, G, H]`. +//! * For 4-byte keys, the 8-byte form is compressed as `[A^B, C^D, E^F, G^H]`. +//! * For 2-byte keys, the 8-byte form is compressed as `[A^B^C^D, E^F^G^H]`. +//! * For 1-byte keys, the 8-byte form is compressed as `A^B^C^D^E^F^G^H`. +//! +//! The length of the Key is determined by the two `NN` bits in the discriminant. +//! +//! The length of the key is usually chosen by the **Server**, as the server is +//! able to calculate the minimum number of bits necessary to avoid collisions. +//! +//! When Clients receive a server response, they shall note the Key length used, +//! and match that for all subsequent messages. When Clients make first connection, +//! they shall use the 8-byte form by default. +//! +//! ## Sequence Number +//! +//! The Sequence Number is an unsigned integer used to match request-response pairs, +//! and disambiguate between multiple in-flight messages. +//! +//! Sequence Numbers may be encoded with variable fidelity on the wire, always in +//! little-endian order, of 1, 2, or 4 bytes. +//! +//! The length of the Sequence Number is determined by the two `MM` bits in the +//! discriminant. +//! +//! The length of the key is chosen by the "originator" of the message. For Endpoints +//! this is the client making the request. For Topics, this is the device sending the +//! topic message. + +use crate::{Key, Key1, Key2, Key4}; + +////////////////////////////////////////////////////////////////////////////// +// VARKEY +////////////////////////////////////////////////////////////////////////////// + +/// A variably sized header Key +/// +/// NOTE: We DO NOT impl Serialize/Deserialize for this type because +/// we use non-postcard-compatible format (externally tagged) on the wire. +/// +/// NOTE: VarKey implements `PartialEq` by reducing two VarKeys down to the +/// smaller of the two forms, and checking whether they match. This allows +/// a key in 8-byte form to be compared to a key in 1, 2, or 4-byte form +/// for equality. +#[derive(Debug, Copy, Clone)] +pub enum VarKey { + /// A one byte key + Key1(Key1), + /// A two byte key + Key2(Key2), + /// A four byte key + Key4(Key4), + /// An eight byte key + Key8(Key), +} + +/// We implement PartialEq MANUALLY for VarKey, because keys of different lengths SHOULD compare +/// as equal. +impl PartialEq for VarKey { + fn eq(&self, other: &Self) -> bool { + // figure out the minimum length + match (self, other) { + // Matching kinds + (VarKey::Key1(self_key), VarKey::Key1(other_key)) => self_key.0.eq(&other_key.0), + (VarKey::Key2(self_key), VarKey::Key2(other_key)) => self_key.0.eq(&other_key.0), + (VarKey::Key4(self_key), VarKey::Key4(other_key)) => self_key.0.eq(&other_key.0), + (VarKey::Key8(self_key), VarKey::Key8(other_key)) => self_key.0.eq(&other_key.0), + + // For the rest of the options, degrade the LARGER key to the SMALLER key, and then + // check for equivalence after that. + (VarKey::Key1(this), VarKey::Key2(other)) => { + let other = Key1::from_key2(*other); + this.0.eq(&other.0) + } + (VarKey::Key1(this), VarKey::Key4(other)) => { + let other = Key1::from_key4(*other); + this.0.eq(&other.0) + } + (VarKey::Key1(this), VarKey::Key8(other)) => { + let other = Key1::from_key8(*other); + this.0.eq(&other.0) + } + (VarKey::Key2(this), VarKey::Key1(other)) => { + let this = Key1::from_key2(*this); + this.0.eq(&other.0) + } + (VarKey::Key2(this), VarKey::Key4(other)) => { + let other = Key2::from_key4(*other); + this.0.eq(&other.0) + } + (VarKey::Key2(this), VarKey::Key8(other)) => { + let other = Key2::from_key8(*other); + this.0.eq(&other.0) + } + (VarKey::Key4(this), VarKey::Key1(other)) => { + let this = Key1::from_key4(*this); + this.0.eq(&other.0) + } + (VarKey::Key4(this), VarKey::Key2(other)) => { + let this = Key2::from_key4(*this); + this.0.eq(&other.0) + } + (VarKey::Key4(this), VarKey::Key8(other)) => { + let other = Key4::from_key8(*other); + this.0.eq(&other.0) + } + (VarKey::Key8(this), VarKey::Key1(other)) => { + let this = Key1::from_key8(*this); + this.0.eq(&other.0) + } + (VarKey::Key8(this), VarKey::Key2(other)) => { + let this = Key2::from_key8(*this); + this.0.eq(&other.0) + } + (VarKey::Key8(this), VarKey::Key4(other)) => { + let this = Key4::from_key8(*this); + this.0.eq(&other.0) + } + } + } +} + +impl VarKey { + /// Keys can not be reaised, but instead only shrunk. + /// + /// This method will shrink to the requested length if that length is + /// smaller than the current representation, or if the requested length + /// is the same or larger than the current representation, it will be + /// kept unchanged + pub fn shrink_to(&mut self, kind: VarKeyKind) { + match (&self, kind) { + (VarKey::Key1(_), _) => { + // Nothing to shrink + } + (VarKey::Key2(key2), VarKeyKind::Key1) => { + *self = VarKey::Key1(Key1::from_key2(*key2)); + } + (VarKey::Key2(_), _) => { + // We are already as small or smaller than the request + } + (VarKey::Key4(key4), VarKeyKind::Key1) => { + *self = VarKey::Key1(Key1::from_key4(*key4)); + } + (VarKey::Key4(key4), VarKeyKind::Key2) => { + *self = VarKey::Key2(Key2::from_key4(*key4)); + } + (VarKey::Key4(_), _) => { + // We are already as small or smaller than the request + } + (VarKey::Key8(key), VarKeyKind::Key1) => { + *self = VarKey::Key1(Key1::from_key8(*key)); + } + (VarKey::Key8(key), VarKeyKind::Key2) => { + *self = VarKey::Key2(Key2::from_key8(*key)); + } + (VarKey::Key8(key), VarKeyKind::Key4) => { + *self = VarKey::Key4(Key4::from_key8(*key)); + } + (VarKey::Key8(_), VarKeyKind::Key8) => { + // Nothing to do + } + } + } + + /// The current kind/length of the key + pub fn kind(&self) -> VarKeyKind { + match self { + VarKey::Key1(_) => VarKeyKind::Key1, + VarKey::Key2(_) => VarKeyKind::Key2, + VarKey::Key4(_) => VarKeyKind::Key4, + VarKey::Key8(_) => VarKeyKind::Key8, + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// VARKEYKIND +////////////////////////////////////////////////////////////////////////////// + +/// The kind or length of the variably sized header Key +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum VarKeyKind { + /// A one byte key + Key1, + /// A two byte key + Key2, + /// A four byte key + Key4, + /// An eight byte key + Key8, +} + +////////////////////////////////////////////////////////////////////////////// +// VARSEQ +////////////////////////////////////////////////////////////////////////////// + +/// A variably sized sequence number +/// +/// NOTE: We use the standard PartialEq here, as we DO NOT treat sequence +/// numbers of different lengths as equivalent. +/// +/// We DO NOT impl Serialize/Deserialize for this type because we use +/// non-postcard-compatible format (externally tagged) +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum VarSeq { + /// A one byte sequence number + Seq1(u8), + /// A two byte sequence number + Seq2(u16), + /// A four byte sequence number + Seq4(u32), +} + +impl From for VarSeq { + fn from(value: u8) -> Self { + Self::Seq1(value) + } +} + +impl From for VarSeq { + fn from(value: u16) -> Self { + Self::Seq2(value) + } +} + +impl From for VarSeq { + fn from(value: u32) -> Self { + Self::Seq4(value) + } +} + +impl VarSeq { + /// Resize (up or down) to the requested kind. + /// + /// When increasing size, the number is left-extended, e.g. `0x42u8` becomes + /// `0x0000_0042u32` when resizing 1 -> 4. + /// + /// When decreasing size, the number is truncated, e.g. `0xABCD_EF12u32` + /// becomes `0x12u8` when resizing 4 -> 1. + pub fn resize(&mut self, kind: VarSeqKind) { + match (&self, kind) { + (VarSeq::Seq1(_), VarSeqKind::Seq1) => {} + (VarSeq::Seq2(_), VarSeqKind::Seq2) => {} + (VarSeq::Seq4(_), VarSeqKind::Seq4) => {} + (VarSeq::Seq1(s), VarSeqKind::Seq2) => { + *self = VarSeq::Seq2((*s).into()); + } + (VarSeq::Seq1(s), VarSeqKind::Seq4) => { + *self = VarSeq::Seq4((*s).into()); + } + (VarSeq::Seq2(s), VarSeqKind::Seq1) => { + *self = VarSeq::Seq1((*s) as u8); + } + (VarSeq::Seq2(s), VarSeqKind::Seq4) => { + *self = VarSeq::Seq4((*s).into()); + } + (VarSeq::Seq4(s), VarSeqKind::Seq1) => { + *self = VarSeq::Seq1((*s) as u8); + } + (VarSeq::Seq4(s), VarSeqKind::Seq2) => { + *self = VarSeq::Seq2((*s) as u16); + } + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// VARSEQKIND +////////////////////////////////////////////////////////////////////////////// + +/// The Kind or Length of a VarSeq +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum VarSeqKind { + /// A one byte sequence number + Seq1, + /// A two byte sequence number + Seq2, + /// A four byte sequence number + Seq4, +} + +////////////////////////////////////////////////////////////////////////////// +// VARHEADER +////////////////////////////////////////////////////////////////////////////// + +/// A variably sized message header +/// +/// NOTE: We use the standard PartialEq here as it will do the correct things. +/// +/// Sequence numbers must be EXACTLY the same, and keys must be equivalent when +/// degraded to the smaller of the two. +/// +/// We DO NOT impl Serialize/Deserialize for this type because we use +/// non-postcard-compatible format (externally tagged) +#[derive(Debug, PartialEq, Clone, Copy)] +pub struct VarHeader { + /// The variably sized Key + pub key: VarKey, + /// The variably sized Sequence Number + pub seq_no: VarSeq, +} + +#[allow(clippy::unusual_byte_groupings)] +impl VarHeader { + /// Bits for a key of ONE byte + pub const KEY_ONE_BITS: u8 = 0b00_00_0000; + /// Bits for a key of TWO bytes + pub const KEY_TWO_BITS: u8 = 0b01_00_0000; + /// Bits for a key of FOUR bytes + pub const KEY_FOUR_BITS: u8 = 0b10_00_0000; + /// Bits for a key of EIGHT bytes + pub const KEY_EIGHT_BITS: u8 = 0b11_00_0000; + /// Mask bits + pub const KEY_MASK_BITS: u8 = 0b11_00_0000; + + /// Bits for a sequence number of ONE bytes + pub const SEQ_ONE_BITS: u8 = 0b00_00_0000; + /// Bits for a sequence number of TWO bytes + pub const SEQ_TWO_BITS: u8 = 0b00_01_0000; + /// Bits for a sequence number of FOUR bytes + pub const SEQ_FOUR_BITS: u8 = 0b00_10_0000; + /// Mask bits + pub const SEQ_MASK_BITS: u8 = 0b00_11_0000; + + /// Bits for a version number of ZERO + pub const VER_ZERO_BITS: u8 = 0b00_00_0000; + /// Mask bits + pub const VER_MASK_BITS: u8 = 0b00_00_1111; + + /// Encode the header to a Vec of bytes + #[cfg(feature = "use-std")] + pub fn write_to_vec(&self) -> Vec { + // start with placeholder byte + let mut out = vec![0u8; 1]; + let mut disc_out: u8; + match &self.key { + VarKey::Key1(k) => { + disc_out = Self::KEY_ONE_BITS; + out.push(k.0); + } + VarKey::Key2(k) => { + disc_out = Self::KEY_TWO_BITS; + out.extend_from_slice(&k.0); + } + VarKey::Key4(k) => { + disc_out = Self::KEY_FOUR_BITS; + out.extend_from_slice(&k.0); + } + VarKey::Key8(k) => { + disc_out = Self::KEY_EIGHT_BITS; + out.extend_from_slice(&k.0); + } + } + match &self.seq_no { + VarSeq::Seq1(s) => { + disc_out |= Self::SEQ_ONE_BITS; + out.push(*s); + } + VarSeq::Seq2(s) => { + disc_out |= Self::SEQ_TWO_BITS; + out.extend_from_slice(&s.to_le_bytes()); + } + VarSeq::Seq4(s) => { + disc_out |= Self::SEQ_FOUR_BITS; + out.extend_from_slice(&s.to_le_bytes()); + } + } + // push discriminant to the end... + out.push(disc_out); + // ...and swap-remove the placeholder byte, moving the discriminant to the front + out.swap_remove(0); + out + } + + /// Attempt to write the header to the given slice + /// + /// If the slice is large enough, a `Some` will be returned with the bytes used + /// to encode the header, as well as the remaining unused bytes. + /// + /// If the slice is not large enough, a `None` will be returned, and some bytes + /// of the buffer may have been modified. + pub fn write_to_slice<'a>(&self, buf: &'a mut [u8]) -> Option<(&'a mut [u8], &'a mut [u8])> { + let (disc_out, mut remain) = buf.split_first_mut()?; + let mut used = 1; + + match &self.key { + VarKey::Key1(k) => { + *disc_out = Self::KEY_ONE_BITS; + let (keybs, remain2) = remain.split_first_mut()?; + *keybs = k.0; + remain = remain2; + used += 1; + } + VarKey::Key2(k) => { + *disc_out = Self::KEY_TWO_BITS; + let (keybs, remain2) = remain.split_at_mut_checked(2)?; + keybs.copy_from_slice(&k.0); + remain = remain2; + used += 2; + } + VarKey::Key4(k) => { + *disc_out = Self::KEY_FOUR_BITS; + let (keybs, remain2) = remain.split_at_mut_checked(4)?; + keybs.copy_from_slice(&k.0); + remain = remain2; + used += 4; + } + VarKey::Key8(k) => { + *disc_out = Self::KEY_EIGHT_BITS; + let (keybs, remain2) = remain.split_at_mut_checked(8)?; + keybs.copy_from_slice(&k.0); + remain = remain2; + used += 8; + } + } + match &self.seq_no { + VarSeq::Seq1(s) => { + *disc_out |= Self::SEQ_ONE_BITS; + let (seqbs, _) = remain.split_first_mut()?; + *seqbs = *s; + used += 1; + } + VarSeq::Seq2(s) => { + *disc_out |= Self::SEQ_TWO_BITS; + let (seqbs, _) = remain.split_at_mut_checked(2)?; + seqbs.copy_from_slice(&s.to_le_bytes()); + used += 2; + } + VarSeq::Seq4(s) => { + *disc_out |= Self::SEQ_FOUR_BITS; + let (seqbs, _) = remain.split_at_mut_checked(4)?; + seqbs.copy_from_slice(&s.to_le_bytes()); + used += 4; + } + } + Some(buf.split_at_mut(used)) + } + + /// Attempt to decode a header from the given bytes. + /// + /// If a well-formed header was found, a `Some` will be returned with the + /// decoded header and unused remaining bytes. + /// + /// If no well-formed header was found, a `None` will be returned. + pub fn take_from_slice(buf: &[u8]) -> Option<(Self, &[u8])> { + let (disc, mut remain) = buf.split_first()?; + + // For now, we only trust version zero + if (*disc & Self::VER_MASK_BITS) != Self::VER_ZERO_BITS { + return None; + } + + let key = match (*disc) & Self::KEY_MASK_BITS { + Self::KEY_ONE_BITS => { + let (keybs, remain2) = remain.split_first()?; + remain = remain2; + VarKey::Key1(Key1(*keybs)) + } + Self::KEY_TWO_BITS => { + let (keybs, remain2) = remain.split_at_checked(2)?; + remain = remain2; + let mut buf = [0u8; 2]; + buf.copy_from_slice(keybs); + VarKey::Key2(Key2(buf)) + } + Self::KEY_FOUR_BITS => { + let (keybs, remain2) = remain.split_at_checked(4)?; + remain = remain2; + let mut buf = [0u8; 4]; + buf.copy_from_slice(keybs); + VarKey::Key4(Key4(buf)) + } + Self::KEY_EIGHT_BITS => { + let (keybs, remain2) = remain.split_at_checked(8)?; + remain = remain2; + let mut buf = [0u8; 8]; + buf.copy_from_slice(keybs); + VarKey::Key8(Key(buf)) + } + // Impossible: all bits covered + _ => unreachable!(), + }; + let seq_no = match (*disc) & Self::SEQ_MASK_BITS { + Self::SEQ_ONE_BITS => { + let (seqbs, remain3) = remain.split_first()?; + remain = remain3; + VarSeq::Seq1(*seqbs) + } + Self::SEQ_TWO_BITS => { + let (seqbs, remain3) = remain.split_at_checked(2)?; + remain = remain3; + let mut buf = [0u8; 2]; + buf.copy_from_slice(seqbs); + VarSeq::Seq2(u16::from_le_bytes(buf)) + } + Self::SEQ_FOUR_BITS => { + let (seqbs, remain3) = remain.split_at_checked(4)?; + remain = remain3; + let mut buf = [0u8; 4]; + buf.copy_from_slice(seqbs); + VarSeq::Seq4(u32::from_le_bytes(buf)) + } + // Possible (could be 0b11), is invalid + _ => return None, + }; + Some((Self { key, seq_no }, remain)) + } +} + +#[cfg(test)] +mod test { + use super::{VarHeader, VarKey, VarSeq}; + use crate::{Key, Key1, Key2}; + + #[test] + fn wire_format() { + let checks: &[(_, &[u8])] = &[ + ( + VarHeader { + key: VarKey::Key1(Key1(0)), + seq_no: VarSeq::Seq1(0x00), + }, + &[ + VarHeader::KEY_ONE_BITS | VarHeader::SEQ_ONE_BITS, + 0x00, + 0x00, + ], + ), + ( + VarHeader { + key: VarKey::Key1(Key1(1)), + seq_no: VarSeq::Seq1(0x02), + }, + &[ + VarHeader::KEY_ONE_BITS | VarHeader::SEQ_ONE_BITS, + 0x01, + 0x02, + ], + ), + ( + VarHeader { + key: VarKey::Key2(Key2([0x42, 0xAF])), + seq_no: VarSeq::Seq1(0x02), + }, + &[ + VarHeader::KEY_TWO_BITS | VarHeader::SEQ_ONE_BITS, + 0x42, + 0xAF, + 0x02, + ], + ), + ( + VarHeader { + key: VarKey::Key1(Key1(1)), + seq_no: VarSeq::Seq2(0x42_AF), + }, + &[ + VarHeader::KEY_ONE_BITS | VarHeader::SEQ_TWO_BITS, + 0x01, + 0xAF, + 0x42, + ], + ), + ( + VarHeader { + key: VarKey::Key8(Key([0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89])), + seq_no: VarSeq::Seq4(0x42_AF_AA_BB), + }, + &[ + VarHeader::KEY_EIGHT_BITS | VarHeader::SEQ_FOUR_BITS, + 0x12, + 0x23, + 0x34, + 0x45, + 0x56, + 0x67, + 0x78, + 0x89, + 0xBB, + 0xAA, + 0xAF, + 0x42, + ], + ), + ]; + + let mut buf = [0u8; 1 + 8 + 4]; + + for (val, exp) in checks.iter() { + let (used, _) = val.write_to_slice(&mut buf).unwrap(); + assert_eq!(used, *exp); + let v = val.write_to_vec(); + assert_eq!(&v, *exp); + let (deser, remain) = VarHeader::take_from_slice(used).unwrap(); + assert!(remain.is_empty()); + assert_eq!(val, &deser); + } + } +} diff --git a/source/postcard-rpc/src/headered.rs b/source/postcard-rpc/src/headered.rs deleted file mode 100644 index 0521e86..0000000 --- a/source/postcard-rpc/src/headered.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! Helper functions for encoding/decoding messages with postcard-rpc headers - -use crate::{Key, WireHeader}; -use postcard::{ - ser_flavors::{Cobs, Flavor as SerFlavor, Slice}, - Serializer, -}; -use postcard_schema::Schema; - -use serde::Serialize; - -struct Headered { - flavor: B, -} - -impl Headered { - fn try_new_keyed(b: B, seq_no: u32, key: Key) -> Result { - let mut serializer = Serializer { output: b }; - let hdr = WireHeader { key, seq_no }; - hdr.serialize(&mut serializer)?; - Ok(Self { - flavor: serializer.output, - }) - } - - fn try_new(b: B, seq_no: u32, path: &str) -> Result { - let key = Key::for_path::(path); - Self::try_new_keyed(b, seq_no, key) - } -} - -impl SerFlavor for Headered { - type Output = B::Output; - - #[inline] - fn try_push(&mut self, data: u8) -> postcard::Result<()> { - self.flavor.try_push(data) - } - - #[inline] - fn finalize(self) -> postcard::Result { - self.flavor.finalize() - } - - #[inline] - fn try_extend(&mut self, data: &[u8]) -> postcard::Result<()> { - self.flavor.try_extend(data) - } -} - -/// Serialize to a slice with a prepended header -/// -/// WARNING: This rehashes the schema! Prefer [to_slice_keyed]! -pub fn to_slice<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - path: &str, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new::(Slice::new(buf), seq_no, path)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a slice with a prepended header -pub fn to_slice_keyed<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - key: Key, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new_keyed(Slice::new(buf), seq_no, key)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a COBS-encoded slice with a prepended header -/// -/// WARNING: This rehashes the schema! Prefer [to_slice_cobs_keyed]! -pub fn to_slice_cobs<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - path: &str, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new::(Cobs::try_new(Slice::new(buf))?, seq_no, path)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a COBS-encoded slice with a prepended header -pub fn to_slice_cobs_keyed<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - key: Key, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new_keyed(Cobs::try_new(Slice::new(buf))?, seq_no, key)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a Vec with a prepended header -/// -/// WARNING: This rehashes the schema! Prefer [to_slice_keyed]! -#[cfg(feature = "use-std")] -pub fn to_stdvec( - seq_no: u32, - path: &str, - value: &T, -) -> Result, postcard::Error> { - let flavor = Headered::try_new::(postcard::ser_flavors::StdVec::new(), seq_no, path)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a Vec with a prepended header -#[cfg(feature = "use-std")] -pub fn to_stdvec_keyed( - seq_no: u32, - key: Key, - value: &T, -) -> Result, postcard::Error> { - let flavor = Headered::try_new_keyed(postcard::ser_flavors::StdVec::new(), seq_no, key)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Extract the header from a slice of bytes -pub fn extract_header_from_bytes(slice: &[u8]) -> Result<(WireHeader, &[u8]), postcard::Error> { - postcard::take_from_bytes::(slice) -} diff --git a/source/postcard-rpc/src/host_client/mod.rs b/source/postcard-rpc/src/host_client/mod.rs index aafec4d..f3d435e 100644 --- a/source/postcard-rpc/src/host_client/mod.rs +++ b/source/postcard-rpc/src/host_client/mod.rs @@ -8,7 +8,7 @@ use std::{ marker::PhantomData, sync::{ atomic::{AtomicU32, Ordering}, - Arc, + Arc, RwLock, }, }; @@ -20,10 +20,17 @@ use postcard_schema::Schema; use serde::{de::DeserializeOwned, Serialize}; use tokio::{ select, - sync::mpsc::{Receiver, Sender}, + sync::{ + mpsc::{Receiver, Sender}, + Mutex, + }, }; +use util::Subscriptions; -use crate::{Endpoint, Key, Topic, WireHeader}; +use crate::{ + header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind}, + Endpoint, Key, Topic, +}; use self::util::Stopper; @@ -38,6 +45,9 @@ pub mod webusb; pub(crate) mod util; +#[cfg(feature = "test-utils")] +pub mod test_channels; + /// Host Error Kind #[derive(Debug, PartialEq)] pub enum HostErr { @@ -76,7 +86,9 @@ impl From for HostErr { /// be returned to the caller. #[cfg(target_family = "wasm")] pub trait WireTx: 'static { - type Error: std::error::Error; // or std? + /// Transmit error type + type Error: std::error::Error; + /// Send a single frame fn send(&mut self, data: Vec) -> impl Future>; } @@ -89,7 +101,9 @@ pub trait WireTx: 'static { /// be returned to the caller. #[cfg(target_family = "wasm")] pub trait WireRx: 'static { + /// Receive error type type Error: std::error::Error; // or std? + /// Receive a single frame fn receive(&mut self) -> impl Future, Self::Error>>; } @@ -98,6 +112,7 @@ pub trait WireRx: 'static { /// Should be suitable for spawning a task in the host executor. #[cfg(target_family = "wasm")] pub trait WireSpawn: 'static { + /// Spawn a task fn spawn(&mut self, fut: impl Future + 'static); } @@ -113,7 +128,9 @@ pub trait WireSpawn: 'static { /// be returned to the caller. #[cfg(not(target_family = "wasm"))] pub trait WireTx: Send + 'static { - type Error: std::error::Error; // or std? + /// Transmit error type + type Error: std::error::Error; + /// Send a single frame fn send(&mut self, data: Vec) -> impl Future> + Send; } @@ -126,7 +143,9 @@ pub trait WireTx: Send + 'static { /// be returned to the caller. #[cfg(not(target_family = "wasm"))] pub trait WireRx: Send + 'static { - type Error: std::error::Error; // or std? + /// Receive error type + type Error: std::error::Error; + /// Receive a single frame fn receive(&mut self) -> impl Future, Self::Error>> + Send; } @@ -135,6 +154,7 @@ pub trait WireRx: Send + 'static { /// Should be suitable for spawning a task in the host executor. #[cfg(not(target_family = "wasm"))] pub trait WireSpawn: 'static { + /// Spawn a task fn spawn(&mut self, fut: impl Future + Send + 'static); } @@ -153,9 +173,10 @@ pub trait WireSpawn: 'static { pub struct HostClient { ctx: Arc, out: Sender, - subber: Sender, + subscriptions: Arc>, err_key: Key, stopper: Stopper, + seq_kind: VarSeqKind, _pd: PhantomData WireErr>, } @@ -164,23 +185,16 @@ impl HostClient where WireErr: DeserializeOwned + Schema, { - /// Create a new manually implemented [HostClient]. - /// - /// This is now deprecated, it is recommended to use [`HostClient::new_with_wire`] instead. - #[deprecated = "HostClient::new_manual will become private in the future, use HostClient::new_with_wire instead"] - pub fn new_manual(err_uri_path: &str, outgoing_depth: usize) -> (Self, WireContext) { - Self::new_manual_priv(err_uri_path, outgoing_depth) - } - /// Private method for creating internal context pub(crate) fn new_manual_priv( err_uri_path: &str, outgoing_depth: usize, + seq_kind: VarSeqKind, ) -> (Self, WireContext) { let (tx_pc, rx_pc) = tokio::sync::mpsc::channel(outgoing_depth); - let (tx_si, rx_si) = tokio::sync::mpsc::channel(outgoing_depth); let ctx = Arc::new(HostContext { + kkind: RwLock::new(VarKeyKind::Key8), map: WaitMap::new(), seq: AtomicU32::new(0), }); @@ -192,14 +206,14 @@ where out: tx_pc, err_key, _pd: PhantomData, - subber: tx_si.clone(), + subscriptions: Arc::new(Mutex::new(Subscriptions::default())), stopper: Stopper::new(), + seq_kind, }; let wire = WireContext { outgoing: rx_pc, incoming: ctx, - new_subs: rx_si, }; (me, wire) @@ -224,11 +238,14 @@ where E::Response: DeserializeOwned + Schema, { let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed); + let msg = postcard::to_stdvec(&t).expect("Allocations should not ever fail"); let frame = RpcFrame { - header: WireHeader { - key: E::REQ_KEY, - seq_no, + // NOTE: send_resp_raw automatically shrinks down key and sequence + // kinds to the appropriate amount + header: VarHeader { + key: VarKey::Key8(E::REQ_KEY), + seq_no: VarSeq::Seq4(seq_no), }, body: msg, }; @@ -237,33 +254,47 @@ where Ok(r) } + /// Perform an endpoint request/response,but without handling the + /// Ser/De automatically pub async fn send_resp_raw( &self, - rqst: RpcFrame, + mut rqst: RpcFrame, resp_key: Key, ) -> Result> { let cancel_fut = self.stopper.wait_stopped(); + let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap(); + rqst.header.key.shrink_to(kkind); + rqst.header.seq_no.resize(self.seq_kind); + let mut resp_key = VarKey::Key8(resp_key); + let mut err_key = VarKey::Key8(self.err_key); + resp_key.shrink_to(kkind); + err_key.shrink_to(kkind); // TODO: Do I need something like a .subscribe method to ensure this is enqueued? - let ok_resp = self.ctx.map.wait(WireHeader { + let ok_resp = self.ctx.map.wait(VarHeader { seq_no: rqst.header.seq_no, key: resp_key, }); - let err_resp = self.ctx.map.wait(WireHeader { + let err_resp = self.ctx.map.wait(VarHeader { seq_no: rqst.header.seq_no, - key: self.err_key, + key: err_key, }); - let seq_no = rqst.header.seq_no; self.out.send(rqst).await.map_err(|_| HostErr::Closed)?; select! { _c = cancel_fut => Err(HostErr::Closed), o = ok_resp => { - let resp = o?; - Ok(RpcFrame { header: WireHeader { key: resp_key, seq_no }, body: resp }) + let (hdr, resp) = o?; + if hdr.key.kind() != kkind { + *self.ctx.kkind.write().unwrap() = hdr.key.kind(); + } + Ok(RpcFrame { header: hdr, body: resp }) }, e = err_resp => { - let resp = e?; + let (hdr, resp) = e?; + if hdr.key.kind() != kkind { + *self.ctx.kkind.write().unwrap() = hdr.key.kind(); + } let r = postcard::from_bytes::(&resp)?; Err(HostErr::Wire(r)) }, @@ -274,14 +305,14 @@ where /// /// There is no feedback if the server received our message. If the I/O worker is /// closed, an error is returned. - pub async fn publish(&self, seq_no: u32, msg: &T::Message) -> Result<(), IoClosed> + pub async fn publish(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), IoClosed> where T::Message: Serialize, { let smsg = postcard::to_stdvec(msg).expect("alloc should never fail"); let frame = RpcFrame { - header: WireHeader { - key: T::TOPIC_KEY, + header: VarHeader { + key: VarKey::Key8(T::TOPIC_KEY), seq_no, }, body: smsg, @@ -289,7 +320,12 @@ where self.publish_raw(frame).await } - pub async fn publish_raw(&self, frame: RpcFrame) -> Result<(), IoClosed> { + /// Publish the given raw frame + pub async fn publish_raw(&self, mut frame: RpcFrame) -> Result<(), IoClosed> { + let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap(); + frame.header.key.shrink_to(kkind); + frame.header.seq_no.resize(self.seq_kind); + let cancel_fut = self.stopper.wait_stopped(); let operate_fut = self.out.send(frame); @@ -330,19 +366,25 @@ where T::Message: DeserializeOwned, { let (tx, rx) = tokio::sync::mpsc::channel(depth); - self.subber - .send(SubInfo { - key: T::TOPIC_KEY, - tx, - }) - .await - .map_err(|_| IoClosed)?; + { + let mut guard = self.subscriptions.lock().await; + if guard.stopped { + return Err(IoClosed); + } + if let Some(entry) = guard.list.iter_mut().find(|(k, _)| *k == T::TOPIC_KEY) { + tracing::warn!("replacing subscription for topic path '{}'", T::PATH); + entry.1 = tx; + } else { + guard.list.push((T::TOPIC_KEY, tx)); + } + } Ok(Subscription { rx, _pd: PhantomData, }) } + /// Subscribe to the given [`Key`], without automatically handling deserialization pub async fn subscribe_raw(&self, key: Key, depth: usize) -> Result { let cancel_fut = self.stopper.wait_stopped(); let operate_fut = self.subscribe_inner_raw(key, depth); @@ -359,10 +401,18 @@ where depth: usize, ) -> Result { let (tx, rx) = tokio::sync::mpsc::channel(depth); - self.subber - .send(SubInfo { key, tx }) - .await - .map_err(|_| IoClosed)?; + { + let mut guard = self.subscriptions.lock().await; + if guard.stopped { + return Err(IoClosed); + } + if let Some(entry) = guard.list.iter_mut().find(|(k, _)| *k == key) { + tracing::warn!("replacing subscription for raw topic key '{:?}'", key); + entry.1 = tx; + } else { + guard.list.push((key, tx)); + } + } Ok(RawSubscription { rx }) } @@ -388,6 +438,8 @@ where } } +/// Like Subscription, but receives Raw frames that are not +/// automatically deserialized pub struct RawSubscription { rx: Receiver, } @@ -432,18 +484,13 @@ impl Clone for HostClient { out: self.out.clone(), err_key: self.err_key, _pd: PhantomData, - subber: self.subber.clone(), + subscriptions: self.subscriptions.clone(), stopper: self.stopper.clone(), + seq_kind: self.seq_kind, } } } -/// A new subscription that should be accounted for -pub struct SubInfo { - pub key: Key, - pub tx: Sender, -} - /// Items necessary for implementing a custom I/O Task pub struct WireContext { /// This is a stream of frames that should be placed on the @@ -452,14 +499,12 @@ pub struct WireContext { /// This shared information contains the WaitMap used for replying to /// open requests. pub incoming: Arc, - /// This is a stream of new subscriptions that should be tracked - pub new_subs: Receiver, } /// A single postcard-rpc frame pub struct RpcFrame { /// The wire header - pub header: WireHeader, + pub header: VarHeader, /// The serialized message payload pub body: Vec, } @@ -467,7 +512,7 @@ pub struct RpcFrame { impl RpcFrame { /// Serialize the `RpcFrame` into a Vec of bytes pub fn to_bytes(&self) -> Vec { - let mut out = postcard::to_stdvec(&self.header).expect("Alloc should never fail"); + let mut out = self.header.write_to_vec(); out.extend_from_slice(&self.body); out } @@ -475,7 +520,8 @@ impl RpcFrame { /// Shared context between [HostClient] and the I/O worker task pub struct HostContext { - map: WaitMap>, + kkind: RwLock, + map: WaitMap)>, seq: AtomicU32, } @@ -495,7 +541,7 @@ impl HostContext { /// Like `HostContext::process` but tells you if we processed the message or /// nobody wanted it pub fn process_did_wake(&self, frame: RpcFrame) -> Result { - match self.map.wake(&frame.header, frame.body) { + match self.map.wake(&frame.header, (frame.header, frame.body)) { WakeOutcome::Woke => Ok(true), WakeOutcome::NoMatch(_) => Ok(false), WakeOutcome::Closed(_) => Err(ProcessError::Closed), @@ -506,7 +552,7 @@ impl HostContext { /// /// Returns an Err if the map was closed. pub fn process(&self, frame: RpcFrame) -> Result<(), ProcessError> { - if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, frame.body) { + if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, (frame.header, frame.body)) { Err(ProcessError::Closed) } else { Ok(()) diff --git a/source/postcard-rpc/src/host_client/raw_nusb.rs b/source/postcard-rpc/src/host_client/raw_nusb.rs index e1f2adb..55752f6 100644 --- a/source/postcard-rpc/src/host_client/raw_nusb.rs +++ b/source/postcard-rpc/src/host_client/raw_nusb.rs @@ -1,3 +1,5 @@ +//! Implementation of transport using nusb + use std::future::Future; use nusb::{ @@ -7,7 +9,10 @@ use nusb::{ use postcard_schema::Schema; use serde::de::DeserializeOwned; -use crate::host_client::{HostClient, WireRx, WireSpawn, WireTx}; +use crate::{ + header::VarSeqKind, + host_client::{HostClient, WireRx, WireSpawn, WireTx}, +}; // TODO: These should all be configurable, PRs welcome @@ -48,6 +53,7 @@ where /// /// ```rust,no_run /// use postcard_rpc::host_client::HostClient; + /// use postcard_rpc::header::VarSeqKind; /// use serde::{Serialize, Deserialize}; /// use postcard_schema::Schema; /// @@ -65,12 +71,15 @@ where /// "error", /// // Outgoing queue depth in messages /// 8, + /// // Use one-byte sequence numbers + /// VarSeqKind::Seq1, /// ).unwrap(); /// ``` pub fn try_new_raw_nusb bool>( func: F, err_uri_path: &str, outgoing_depth: usize, + seq_no_kind: VarSeqKind, ) -> Result { let x = nusb::list_devices() .map_err(|e| format!("Error listing devices: {e:?}"))? @@ -93,6 +102,7 @@ where consecutive_errs: 0, }, NusbSpawn, + seq_no_kind, err_uri_path, outgoing_depth, )) @@ -108,6 +118,7 @@ where /// /// ```rust,no_run /// use postcard_rpc::host_client::HostClient; + /// use postcard_rpc::header::VarSeqKind; /// use serde::{Serialize, Deserialize}; /// use postcard_schema::Schema; /// @@ -125,14 +136,17 @@ where /// "error", /// // Outgoing queue depth in messages /// 8, + /// // Use one-byte sequence numbers + /// VarSeqKind::Seq1, /// ); /// ``` pub fn new_raw_nusb bool>( func: F, err_uri_path: &str, outgoing_depth: usize, + seq_no_kind: VarSeqKind, ) -> Self { - Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth) + Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind) .expect("should have found nusb device") } } diff --git a/source/postcard-rpc/src/host_client/serial.rs b/source/postcard-rpc/src/host_client/serial.rs index 9f748b9..98e2a34 100644 --- a/source/postcard-rpc/src/host_client/serial.rs +++ b/source/postcard-rpc/src/host_client/serial.rs @@ -8,6 +8,7 @@ use tokio_serial::{SerialPortBuilderExt, SerialStream}; use crate::{ accumulator::raw::{CobsAccumulator, FeedResult}, + header::VarSeqKind, host_client::{HostClient, WireRx, WireSpawn, WireTx}, }; @@ -31,6 +32,7 @@ where /// /// ```rust,no_run /// use postcard_rpc::host_client::HostClient; + /// use postcard_rpc::header::VarSeqKind; /// use serde::{Serialize, Deserialize}; /// use postcard_schema::Schema; /// @@ -51,6 +53,8 @@ where /// // Baud rate of serial (does not generally matter for /// // USB UART/CDC-ACM serial connections) /// 115_200, + /// // Use one-byte sequence numbers + /// VarSeqKind::Seq1, /// ); /// ``` pub fn try_new_serial_cobs( @@ -58,6 +62,7 @@ where err_uri_path: &str, outgoing_depth: usize, baud: u32, + seq_no_kind: VarSeqKind, ) -> Result { let port = tokio_serial::new(serial_path, baud) .open_native_async() @@ -74,6 +79,7 @@ where pending: VecDeque::new(), }, SerialSpawn, + seq_no_kind, err_uri_path, outgoing_depth, )) @@ -89,8 +95,10 @@ where err_uri_path: &str, outgoing_depth: usize, baud: u32, + seq_no_kind: VarSeqKind, ) -> Self { - Self::try_new_serial_cobs(serial_path, err_uri_path, outgoing_depth, baud).unwrap() + Self::try_new_serial_cobs(serial_path, err_uri_path, outgoing_depth, baud, seq_no_kind) + .unwrap() } } diff --git a/source/postcard-rpc/src/host_client/test_channels.rs b/source/postcard-rpc/src/host_client/test_channels.rs new file mode 100644 index 0000000..a6df2d3 --- /dev/null +++ b/source/postcard-rpc/src/host_client/test_channels.rs @@ -0,0 +1,85 @@ +//! A Client implementation using channels for testing + +use crate::{ + header::VarSeqKind, + host_client::{HostClient, WireRx, WireSpawn, WireTx}, + standard_icd::WireError, +}; +use core::fmt::Display; +use tokio::sync::mpsc; + +/// Create a new HostClient from the given server channels +pub fn new_from_channels( + tx: mpsc::Sender>, + rx: mpsc::Receiver>, + seq_kind: VarSeqKind, +) -> HostClient { + HostClient::new_with_wire( + ChannelTx { tx }, + ChannelRx { rx }, + TokSpawn, + seq_kind, + crate::standard_icd::ERROR_PATH, + 64, + ) +} + +/// Server error kinds +#[derive(Debug)] +pub enum ChannelError { + /// Rx was closed + RxClosed, + /// Tx was closed + TxClosed, +} + +impl Display for ChannelError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +impl std::error::Error for ChannelError {} + +/// Trait impl for channels +pub struct ChannelRx { + rx: mpsc::Receiver>, +} +/// Trait impl for channels +pub struct ChannelTx { + tx: mpsc::Sender>, +} +/// Trait impl for channels +pub struct ChannelSpawn; + +impl WireRx for ChannelRx { + type Error = ChannelError; + + async fn receive(&mut self) -> Result, Self::Error> { + match self.rx.recv().await { + Some(v) => { + #[cfg(test)] + println!("c<-s: {v:?}"); + Ok(v) + } + None => Err(ChannelError::RxClosed), + } + } +} + +impl WireTx for ChannelTx { + type Error = ChannelError; + + async fn send(&mut self, data: Vec) -> Result<(), Self::Error> { + #[cfg(test)] + println!("c->s: {data:?}"); + self.tx.send(data).await.map_err(|_| ChannelError::TxClosed) + } +} + +struct TokSpawn; +impl WireSpawn for TokSpawn { + fn spawn(&mut self, fut: impl std::future::Future + Send + 'static) { + _ = tokio::task::spawn(fut); + } +} diff --git a/source/postcard-rpc/src/host_client/util.rs b/source/postcard-rpc/src/host_client/util.rs index 95c60fe..8222da1 100644 --- a/source/postcard-rpc/src/host_client/util.rs +++ b/source/postcard-rpc/src/host_client/util.rs @@ -1,5 +1,5 @@ // the contents of this file can probably be moved up to `mod.rs` -use std::{collections::HashMap, fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc}; use maitake_sync::WaitQueue; use postcard_schema::Schema; @@ -14,15 +14,18 @@ use tokio::{ use tracing::{debug, trace, warn}; use crate::{ - headered::extract_header_from_bytes, + header::{VarHeader, VarKey, VarSeqKind}, host_client::{ - HostClient, HostContext, ProcessError, RpcFrame, SubInfo, WireContext, WireRx, WireSpawn, - WireTx, + HostClient, HostContext, ProcessError, RpcFrame, WireContext, WireRx, WireSpawn, WireTx, }, Key, }; -pub(crate) type Subscriptions = HashMap>; +#[derive(Default, Debug)] +pub(crate) struct Subscriptions { + pub(crate) list: Vec<(Key, Sender)>, + pub(crate) stopped: bool, +} /// A basic cancellation-token /// @@ -74,6 +77,7 @@ where tx: WTX, rx: WRX, mut sp: WSP, + seq_kind: VarSeqKind, err_uri_path: &str, outgoing_depth: usize, ) -> Self @@ -82,24 +86,17 @@ where WRX: WireRx, WSP: WireSpawn, { - let (me, wire_ctx) = Self::new_manual_priv(err_uri_path, outgoing_depth); - - let WireContext { - outgoing, - incoming, - new_subs, - } = wire_ctx; + let (me, wire_ctx) = Self::new_manual_priv(err_uri_path, outgoing_depth, seq_kind); - let subscriptions: Arc> = Arc::new(Mutex::new(Subscriptions::new())); + let WireContext { outgoing, incoming } = wire_ctx; sp.spawn(out_worker(tx, outgoing, me.stopper.clone())); sp.spawn(in_worker( rx, incoming, - subscriptions.clone(), + me.subscriptions.clone(), me.stopper.clone(), )); - sp.spawn(sub_worker(new_subs, subscriptions, me.stopper.clone())); me } @@ -150,7 +147,7 @@ async fn in_worker( W::Error: Debug, { let cancel_fut = stop.wait_stopped(); - let operate_fut = in_worker_inner(wire, host_ctx, subscriptions); + let operate_fut = in_worker_inner(wire, host_ctx, subscriptions.clone()); select! { _ = cancel_fut => {}, _ = operate_fut => { @@ -158,6 +155,11 @@ async fn in_worker( stop.stop(); }, } + // If we stop, purge the subscription list so that it is clear that no more messages are coming + // TODO: Have a "stopped" flag to prevent later additions (e.g. sub after store?) + let mut guard = subscriptions.lock().await; + guard.stopped = true; + guard.list.clear(); } async fn in_worker_inner( @@ -174,7 +176,7 @@ async fn in_worker_inner( return; }; - let Ok((hdr, body)) = extract_header_from_bytes(&res) else { + let Some((hdr, body)) = VarHeader::take_from_slice(&res) else { warn!("Header decode error!"); continue; }; @@ -188,10 +190,14 @@ async fn in_worker_inner( let key = hdr.key; // Remove if sending fails - let remove_sub = if let Some(m) = subs_guard.get(&key) { + let remove_sub = if let Some((_h, m)) = subs_guard + .list + .iter() + .find(|(k, _)| VarKey::Key8(*k) == key) + { handled = true; let frame = RpcFrame { - header: hdr.clone(), + header: hdr, body: body.to_vec(), }; let res = m.try_send(frame); @@ -213,7 +219,7 @@ async fn in_worker_inner( if remove_sub { debug!("Dropping subscription"); - subs_guard.remove(&key); + subs_guard.list.retain(|(k, _)| VarKey::Key8(*k) != key); } } @@ -236,31 +242,3 @@ async fn in_worker_inner( } } } - -async fn sub_worker( - new_subs: Receiver, - subscriptions: Arc>, - stop: Stopper, -) { - let cancel_fut = stop.wait_stopped(); - let operate_fut = sub_worker_inner(new_subs, subscriptions); - select! { - _ = cancel_fut => {}, - _ = operate_fut => { - // if WE exited, notify everyone else it's stoppin time - stop.stop(); - }, - } -} - -async fn sub_worker_inner( - mut new_subs: Receiver, - subscriptions: Arc>, -) { - while let Some(sub) = new_subs.recv().await { - let mut sub_guard = subscriptions.lock().await; - if let Some(_old) = sub_guard.insert(sub.key, sub.tx) { - warn!("Replacing old subscription for {:?}", sub.key); - } - } -} diff --git a/source/postcard-rpc/src/host_client/webusb.rs b/source/postcard-rpc/src/host_client/webusb.rs index d2f8d48..8609f82 100644 --- a/source/postcard-rpc/src/host_client/webusb.rs +++ b/source/postcard-rpc/src/host_client/webusb.rs @@ -1,3 +1,5 @@ +//! Implementation of transport using webusb + use gloo::utils::format::JsValueSerdeExt; use postcard_schema::Schema; use serde::de::DeserializeOwned; @@ -7,8 +9,10 @@ use wasm_bindgen::{prelude::*, JsCast}; use wasm_bindgen_futures::{spawn_local, JsFuture}; use web_sys::{UsbDevice, UsbInTransferResult, UsbTransferStatus}; +use crate::header::VarSeqKind; use crate::host_client::{HostClient, WireRx, WireSpawn, WireTx}; +/// Implementation of the wire interface for WebUsb #[derive(Clone)] pub struct WebUsbWire { device: UsbDevice, @@ -17,10 +21,13 @@ pub struct WebUsbWire { ep_out: u8, } +/// WebUsb Error type #[derive(thiserror::Error, Debug, Clone)] pub enum Error { + /// Error originating from the browser #[error("Browser error: {0}")] Browser(String), + /// Error originating from the USB stack #[error("USB transfer error: {0}")] UsbTransfer(&'static str), } @@ -49,6 +56,7 @@ impl HostClient where WireErr: DeserializeOwned + Schema, { + /// Create a new webusb connection instance pub async fn try_new_webusb( vendor_id: u16, interface: u8, @@ -57,6 +65,7 @@ where ep_out: u8, err_uri_path: &str, outgoing_depth: usize, + seq_no_len: VarSeqKind, ) -> Result { let wire = WebUsbWire::new(vendor_id, interface, transfer_max_length, ep_in, ep_out).await?; @@ -64,6 +73,7 @@ where wire.clone(), wire.clone(), wire, + seq_no_len, err_uri_path, outgoing_depth, )) @@ -71,6 +81,7 @@ where } /// # Example usage () +/// /// ```no_run /// let wire = WebUsbWire::new(0x16c0, 0, 1000, 1, 1) /// .await @@ -82,10 +93,12 @@ where /// wire, /// crate::standard_icd::ERROR_PATH, /// 8, +/// VarSeqKind::Seq1, /// ) /// .expect("could not create HostClient"); /// ``` impl WebUsbWire { + /// Create a new instance of [`WebUsbWire`] pub async fn new( vendor_id: u16, interface: u8, @@ -146,11 +159,7 @@ impl WireTx for WebUsbWire { let data: js_sys::Uint8Array = data.as_slice().into(); // TODO for reasons unknown, web-sys wants mutable access to the send buffer. // tracking issue: https://github.com/rustwasm/wasm-bindgen/issues/3963 - JsFuture::from( - self.device - .transfer_out_with_u8_array(self.ep_out, &data)?, - ) - .await?; + JsFuture::from(self.device.transfer_out_with_u8_array(self.ep_out, &data)?).await?; Ok(()) } } diff --git a/source/postcard-rpc/src/lib.rs b/source/postcard-rpc/src/lib.rs index 5d34d15..34bde2b 100644 --- a/source/postcard-rpc/src/lib.rs +++ b/source/postcard-rpc/src/lib.rs @@ -1,12 +1,67 @@ //! The goal of `postcard-rpc` is to make it easier for a //! host PC to talk to a constrained device, like a microcontroller. //! -//! See [the repo] for examples, and [the overview] for more details on how -//! to use this crate. +//! See [the repo] for examples //! //! [the repo]: https://github.com/jamesmunns/postcard-rpc //! [the overview]: https://github.com/jamesmunns/postcard-rpc/blob/main/docs/overview.md //! +//! ## Architecture overview +//! +//! ```text +//! ┌──────────┐ ┌─────────┐ ┌───────────┐ +//! │ Endpoint │ │ Publish │ │ Subscribe │ +//! └──────────┘ └─────────┘ └───────────┘ +//! │ ▲ message│ │ ▲ +//! ┌────────┐ rqst│ │resp │ subscribe│ │messages +//! ┌─┤ CLIENT ├─────┼─────┼──────────────┼────────────────┼────────┼──┐ +//! │ └────────┘ ▼ │ ▼ ▼ │ │ +//! │ ┌─────────────────────────────────────────────────────┐ │ │ +//! │ │ HostClient │ │ │ +//! │ └─────────────────────────────────────────────────────┘ │ │ +//! │ │ │ ▲ │ | │ +//! │ │ │ │ │ │ │ +//! │ │ │ │ ▼ │ │ +//! │ │ │ ┌──────────────┬──────────────┐│ +//! │ │ └─────▶│ Pending Resp │ Subscription ││ +//! │ │ └──────────────┴──────────────┘│ +//! │ │ ▲ ▲ │ +//! │ │ └───────┬──────┘ │ +//! │ ▼ │ │ +//! │ ┌────────────────────┐ ┌────────────────────┐ │ +//! │ ││ Task: out_worker │ │ Task: in_worker ▲│ │ +//! │ ├┼───────────────────┤ ├───────────────────┼┤ │ +//! │ │▼ Trait: WireTx │ │ Trait: WireRx ││ │ +//! └──────┴────────────────────┴────────────┴────────────────────┴────┘ +//! │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ ▲ +//! │ The Server + Client WireRx │ +//! │ │ and WireTx traits can be │ │ +//! │ impl'd for any wire │ +//! │ │ transport like USB, TCP, │ │ +//! │ I2C, UART, etc. │ +//! ▼ └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ │ +//! ┌─────┬────────────────────┬────────────┬────────────────────┬─────┐ +//! │ ││ Trait: WireRx │ │ Trait: WireTx ▲│ │ +//! │ ├┼───────────────────┤ ├───────────────────┼┤ │ +//! │ ││ Server │ ┌───▶│ Sender ││ │ +//! │ ├┼───────────────────┤ │ └────────────────────┘ │ +//! │ │▼ Macro: Dispatch │ │ ▲ │ +//! │ └────────────────────┘ │ │ │ +//! │ ┌─────────┐ │ ┌──────────┐ │ ┌───────────┐ │ ┌───────────┐ │ +//! │ │ Topic │ │ │ Endpoint │ │ │ Publisher │ │ │ Publisher │ │ +//! │ │ fn │◀┼▶│ async fn │────┤ │ Task │─┼─│ Task │ │ +//! │ │ Handler │ │ │ Handler │ │ └───────────┘ │ └───────────┘ │ +//! │ └─────────┘ │ └──────────┘ │ │ │ +//! │ ┌─────────┐ │ ┌──────────┐ │ ┌───────────┐ │ ┌───────────┐ │ +//! │ │ Topic │ │ │ Endpoint │ │ │ Publisher │ │ │ Publisher │ │ +//! │ │async fn │◀┴▶│ task │────┘ │ Task │─┴─│ Task │ │ +//! │ │ Handler │ │ Handler │ └───────────┘ └───────────┘ │ +//! │ └─────────┘ └──────────┘ │ +//! │ ┌────────┐ │ +//! └─┤ SERVER ├───────────────────────────────────────────────────────┘ +//! └────────┘ +//! ``` +//! //! ## Defining a schema //! //! Typically, you will define your "wire types" in a shared schema crate. This @@ -24,47 +79,9 @@ //! convert a type into bytes on the wire //! * [`serde`]'s [`Deserialize`] trait - which defines how we //! can convert bytes on the wire into a type -//! * [`postcard`]'s [`Schema`] trait - which generates a reflection-style +//! * [`postcard_schema`]'s [`Schema`] trait - which generates a reflection-style //! schema value for a given type. //! -//! Here's an example of three types we'll use in future examples: -//! -//! ```rust -//! // Consider making your shared "wire types" crate conditionally no-std, -//! // if you want to use it with no-std embedded targets! This makes it no_std -//! // except for testing and when the "use-std" feature is active. -//! // -//! // You may need to also ensure that `std`/`use-std` features are not active -//! // in any dependencies as well. -//! #![cfg_attr(not(any(test, feature = "use-std")), no_std)] -//! # fn main() {} -//! -//! use serde::{Serialize, Deserialize}; -//! use postcard_schema::Schema; -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub struct Alpha { -//! pub one: u8, -//! pub two: i64, -//! } -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub enum Beta { -//! Bib, -//! Bim(i16), -//! Bap, -//! } -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub struct Delta(pub [u8; 32]); -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub enum WireError { -//! ALittleBad, -//! VeryBad, -//! } -//! ``` -//! //! ### Endpoints //! //! Now that we have some basic types that will be used on the wire, we need @@ -80,42 +97,6 @@ //! * The type of the Response //! * A string "path", like an HTTP URI that uniquely identifies the endpoint. //! -//! The easiest way to define an Endpoint is to use the [`endpoint!`][endpoint] -//! macro. -//! -//! ```rust -//! # use serde::{Serialize, Deserialize}; -//! # use postcard_schema::Schema; -//! # -//! # #[derive(Serialize, Deserialize, Schema)] -//! # pub struct Alpha { -//! # pub one: u8, -//! # pub two: i64, -//! # } -//! # -//! # #[derive(Serialize, Deserialize, Schema)] -//! # pub enum Beta { -//! # Bib, -//! # Bim(i16), -//! # Bap, -//! # } -//! # -//! use postcard_rpc::endpoint; -//! -//! // Define an endpoint -//! endpoint!( -//! // This is the name of a marker type that represents our Endpoint, -//! // and implements the `Endpoint` trait. -//! FirstEndpoint, -//! // This is the request type for this endpoint -//! Alpha, -//! // This is the response type for this endpoint -//! Beta, -//! // This is the path/URI of the endpoint -//! "endpoints/first", -//! ); -//! ``` -//! //! ### Topics //! //! Sometimes, you would just like to send data in a single direction, with no @@ -129,60 +110,20 @@ //! //! * The type of the Message //! * A string "path", like an HTTP URI that uniquely identifies the topic. -//! -//! The easiest way to define a Topic is to use the [`topic!`][topic] -//! macro. -//! -//! ```rust -//! # use serde::{Serialize, Deserialize}; -//! # use postcard_schema::Schema; -//! # -//! # #[derive(Serialize, Deserialize, Schema)] -//! # pub struct Delta(pub [u8; 32]); -//! # -//! use postcard_rpc::topic; -//! -//! // Define a topic -//! topic!( -//! // This is the name of a marker type that represents our Topic, -//! // and implements `Topic` trait. -//! FirstTopic, -//! // This is the message type for the endpoint (note there is no -//! // response type!) -//! Delta, -//! // This is the path/URI of the topic -//! "topics/first", -//! ); -//! ``` -//! -//! ## Using a schema -//! -//! At the moment, this library is primarily oriented around: -//! -//! * A single Client, usually a PC, with access to `std` -//! * A single Server, usually an MCU, without access to `std` -//! -//! For Client facilities, check out the [`host_client`] module, -//! particularly the [`HostClient`][host_client::HostClient] struct. -//! This is only available with the `use-std` feature active. -//! -//! A serial-port transport using cobs encoding is available with the `cobs-serial` feature. -//! This feature will add the [`new_serial_cobs`][host_client::HostClient::new_serial_cobs] constructor to [`HostClient`][host_client::HostClient]. -//! -//! For Server facilities, check out the [`Dispatch`] struct. This is -//! available with or without the standard library. #![cfg_attr(not(any(test, feature = "use-std")), no_std)] +#![deny(missing_docs)] +#![deny(rustdoc::broken_intra_doc_links)] -use headered::extract_header_from_bytes; -use postcard_schema::Schema; +use header::{VarKey, VarKeyKind}; +use postcard_schema::{schema::NamedType, Schema}; use serde::{Deserialize, Serialize}; #[cfg(feature = "cobs")] pub mod accumulator; pub mod hash; -pub mod headered; +pub mod header; #[cfg(feature = "use-std")] pub mod host_client; @@ -190,121 +131,11 @@ pub mod host_client; #[cfg(any(test, feature = "test-utils"))] pub mod test_utils; -#[cfg(feature = "embassy-usb-0_3-server")] -pub mod target_server; - mod macros; -/// Error type for [Dispatch] -#[derive(Debug, PartialEq)] -pub enum Error { - /// No handler was found for the given message. - /// The decoded key and sequence number are returned - NoMatchingHandler { key: Key, seq_no: u32 }, - /// The handler returned an error - DispatchFailure(E), - /// An error when decoding messages - Postcard(postcard::Error), -} +pub mod server; -impl From for Error { - fn from(value: postcard::Error) -> Self { - Self::Postcard(value) - } -} - -/// Dispatch is the primary interface for MCU "server" devices. -/// -/// Dispatch is generic over three types: -/// -/// 1. The `Context`, which will be passed as a mutable reference -/// to each of the handlers. It typically should contain -/// whatever resource is necessary to send replies back to -/// the host. -/// 2. The `Error` type, which can be returned by handlers -/// 3. `N`, for the maximum number of handlers -/// -/// If you plan to use COBS encoding, you can also use [CobsDispatch]. -/// which will automatically handle accumulating bytes from the wire. -/// -/// [CobsDispatch]: crate::accumulator::dispatch::CobsDispatch -/// Note: This will be available when the `cobs` or `cobs-serial` feature is enabled. -pub struct Dispatch { - items: heapless::Vec<(Key, Handler), N>, - context: Context, -} - -impl Dispatch { - /// Create a new [Dispatch] - pub fn new(c: Context) -> Self { - Self { - items: heapless::Vec::new(), - context: c, - } - } - - /// Add a handler to the [Dispatch] for the given path and type - /// - /// Returns an error if the given type+path have already been added, - /// or if Dispatch is full. - pub fn add_handler( - &mut self, - handler: Handler, - ) -> Result<(), &'static str> { - if self.items.is_full() { - return Err("full"); - } - let id = E::REQ_KEY; - if self.items.iter().any(|(k, _)| k == &id) { - return Err("dupe"); - } - let _ = self.items.push((id, handler)); - - // TODO: Why does this throw lifetime errors? - // self.items.sort_unstable_by_key(|(k, _)| k); - Ok(()) - } - - /// Accessor function for the Context field - pub fn context(&mut self) -> &mut Context { - &mut self.context - } - - /// Attempt to dispatch the given message - /// - /// The bytes should consist of exactly one message (including the header). - /// - /// Returns an error in any of the following cases: - /// - /// * We failed to decode a header - /// * No handler was found for the decoded key - /// * The handler ran, but returned an error - pub fn dispatch(&mut self, bytes: &[u8]) -> Result<(), Error> { - let (hdr, remain) = extract_header_from_bytes(bytes)?; - - // TODO: switch to binary search once we sort? - let Some(disp) = self - .items - .iter() - .find_map(|(k, d)| if k == &hdr.key { Some(d) } else { None }) - else { - return Err(Error::::NoMatchingHandler { - key: hdr.key, - seq_no: hdr.seq_no, - }); - }; - (disp)(&hdr, &mut self.context, remain).map_err(Error::DispatchFailure) - } -} - -type Handler = fn(&WireHeader, &mut C, &[u8]) -> Result<(), E>; - -/// The WireHeader is appended to all messages -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct WireHeader { - pub key: Key, - pub seq_no: u32, -} +pub mod uniques; /// The `Key` uniquely identifies what "kind" of message this is. /// @@ -381,12 +212,213 @@ mod key_owned { use super::*; use postcard_schema::schema::owned::OwnedNamedType; impl Key { + /// Calculate the Key for the given path and [`OwnedNamedType`] pub fn for_owned_schema_path(path: &str, nt: &OwnedNamedType) -> Key { Key(crate::hash::fnv1a64_owned::hash_ty_path_owned(path, nt)) } } } +/// A compacted 2-byte key +/// +/// This is defined specifically as the following conversion: +/// +/// * Key8 bytes (`[u8; 8]`): `[a, b, c, d, e, f, g, h]` +/// * Key4 bytes (`u8`): `a ^ b ^ c ^ d ^ e ^ f ^ g ^ h` +#[derive(Debug, Copy, Clone)] +pub struct Key1(u8); + +/// A compacted 2-byte key +/// +/// This is defined specifically as the following conversion: +/// +/// * Key8 bytes (`[u8; 8]`): `[a, b, c, d, e, f, g, h]` +/// * Key4 bytes (`[u8; 2]`): `[a ^ b ^ c ^ d, e ^ f ^ g ^ h]` +#[derive(Debug, Copy, Clone)] +pub struct Key2([u8; 2]); + +/// A compacted 4-byte key +/// +/// This is defined specifically as the following conversion: +/// +/// * Key8 bytes (`[u8; 8]`): `[a, b, c, d, e, f, g, h]` +/// * Key4 bytes (`[u8; 4]`): `[a ^ b, c ^ d, e ^ f, g ^ h]` +#[derive(Debug, Copy, Clone)] +pub struct Key4([u8; 4]); + +impl Key1 { + /// Convert from a 2-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key2(value: Key2) -> Self { + let [a, b] = value.0; + Self(a ^ b) + } + + /// Convert from a 4-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key4(value: Key4) -> Self { + let [a, b, c, d] = value.0; + Self(a ^ b ^ c ^ d) + } + + /// Convert from a full size 8-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key8(value: Key) -> Self { + let [a, b, c, d, e, f, g, h] = value.0; + Self(a ^ b ^ c ^ d ^ e ^ f ^ g ^ h) + } + + /// Convert to the inner byte representation + #[inline] + pub const fn to_bytes(&self) -> u8 { + self.0 + } + + /// Create a `Key1` from a [`VarKey`] + /// + /// This method can never fail, but has the same API as other key + /// types for consistency reasons. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + Some(match value { + VarKey::Key1(key1) => *key1, + VarKey::Key2(key2) => Key1::from_key2(*key2), + VarKey::Key4(key4) => Key1::from_key4(*key4), + VarKey::Key8(key) => Key1::from_key8(*key), + }) + } +} + +impl Key2 { + /// Convert from a 4-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key4(value: Key4) -> Self { + let [a, b, c, d] = value.0; + Self([a ^ b, c ^ d]) + } + + /// Convert from a full size 8-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key8(value: Key) -> Self { + let [a, b, c, d, e, f, g, h] = value.0; + Self([a ^ b ^ c ^ d, e ^ f ^ g ^ h]) + } + + /// Convert to the inner byte representation + #[inline] + pub const fn to_bytes(&self) -> [u8; 2] { + self.0 + } + + /// Attempt to create a [`Key2`] from a [`VarKey`]. + /// + /// Only succeeds if `value` is a `VarKey::Key2`, `VarKey::Key4`, or `VarKey::Key8`. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + Some(match value { + VarKey::Key1(_) => return None, + VarKey::Key2(key2) => *key2, + VarKey::Key4(key4) => Key2::from_key4(*key4), + VarKey::Key8(key) => Key2::from_key8(*key), + }) + } +} + +impl Key4 { + /// Convert from a full size 8-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key8(value: Key) -> Self { + let [a, b, c, d, e, f, g, h] = value.0; + Self([a ^ b, c ^ d, e ^ f, g ^ h]) + } + + /// Convert to the inner byte representation + #[inline] + pub const fn to_bytes(&self) -> [u8; 4] { + self.0 + } + + /// Attempt to create a [`Key4`] from a [`VarKey`]. + /// + /// Only succeeds if `value` is a `VarKey::Key4` or `VarKey::Key8`. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + Some(match value { + VarKey::Key1(_) => return None, + VarKey::Key2(_) => return None, + VarKey::Key4(key4) => *key4, + VarKey::Key8(key) => Key4::from_key8(*key), + }) + } +} + +impl Key { + /// This is an identity function, used for consistency + #[inline] + pub const fn from_key8(value: Key) -> Self { + value + } + + /// Attempt to create a [`Key`] from a [`VarKey`]. + /// + /// Only succeeds if `value` is a `VarKey::Key8`. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + match value { + VarKey::Key8(key) => Some(*key), + _ => None, + } + } +} + +impl From for Key1 { + fn from(value: Key2) -> Self { + Self::from_key2(value) + } +} + +impl From for Key1 { + fn from(value: Key4) -> Self { + Self::from_key4(value) + } +} + +impl From for Key1 { + fn from(value: Key) -> Self { + Self::from_key8(value) + } +} + +impl From for Key2 { + fn from(value: Key4) -> Self { + Self::from_key4(value) + } +} + +impl From for Key2 { + fn from(value: Key) -> Self { + Self::from_key8(value) + } +} + +impl From for Key4 { + fn from(value: Key) -> Self { + Self::from_key8(value) + } +} + /// A marker trait denoting a single endpoint /// /// Typically used with the [endpoint] macro. @@ -427,28 +459,49 @@ pub mod standard_icd { use postcard_schema::Schema; use serde::{Deserialize, Serialize}; + /// The calculated Key for the type [`WireError`] and the path [`ERROR_PATH`] pub const ERROR_KEY: Key = Key::for_path::(ERROR_PATH); + + /// The path string used for the error type pub const ERROR_PATH: &str = "error"; + /// The given frame was too long #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub struct FrameTooLong { + /// The length of the too-long frame pub len: u32, + /// The maximum frame length supported pub max: u32, } + /// The given frame was too short #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub struct FrameTooShort { + /// The length of the too-short frame pub len: u32, } + /// A protocol error that is handled outside of the normal request type, usually + /// indicating a protocol-level error #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub enum WireError { + /// The frame exceeded the buffering capabilities of the server FrameTooLong(FrameTooLong), + /// The frame was shorter than the minimum frame size and was rejected FrameTooShort(FrameTooShort), + /// Deserialization of a message failed DeserFailed, + /// Serialization of a message failed, usually due to a lack of space to + /// buffer the serialized form SerFailed, - UnknownKey([u8; 8]), + /// The key associated with this request was unknown + UnknownKey, + /// The server was unable to spawn the associated handler, typically due + /// to an exhaustion of resources FailedToSpawn, + /// The provided key is below the minimum key size calculated to avoid hash + /// collisions, and was rejected to avoid potential misunderstanding + KeyTooSmall, } #[cfg(not(feature = "use-std"))] @@ -457,3 +510,49 @@ pub mod standard_icd { #[cfg(feature = "use-std")] crate::topic!(Logging, Vec, "logs/formatted"); } + +/// An overview of all topics (in and out) and endpoints +/// +/// Typically generated by the [`define_dispatch!()`] macro. Contains a list +/// of all unique types across endpoints and topics, as well as the endpoints, +/// topics in (client to server), topics out (server to client), as well as a +/// calculated minimum key length required to avoid collisions in either the in +/// or out direction. +pub struct DeviceMap { + /// The set of unique types used by all endpoints and topics in this map + pub types: &'static [&'static NamedType], + /// The list of endpoints by path string, request key, and response key + pub endpoints: &'static [(&'static str, Key, Key)], + /// The list of topics (client to server) by path string and topic key + pub topics_in: &'static [(&'static str, Key)], + /// The list of topics (server to client) by path string and topic key + pub topics_out: &'static [(&'static str, Key)], + /// The minimum key size required to avoid hash collisions + pub min_key_len: VarKeyKind, +} + +/// An overview of a list of endpoints +/// +/// Typically generated by the [`endpoints!()`] macro. Contains a list of +/// all unique types used by a list of endpoints, as well as the list of these +/// endpoints by path, request key, and response key +#[derive(Debug)] +pub struct EndpointMap { + /// The set of unique types used by all endpoints in this map + pub types: &'static [&'static NamedType], + /// The list of endpoints by path string, request key, and response key + pub endpoints: &'static [(&'static str, Key, Key)], +} + +/// An overview of a list of topics +/// +/// Typically generated by the [`topics!()`] macro. Contains a list of all +/// unique types used by a list of topics as well as the list of the topics +/// by path and key +#[derive(Debug)] +pub struct TopicMap { + /// The set of unique types used by all topics in this map + pub types: &'static [&'static NamedType], + /// The list of topics by path string and topic key + pub topics: &'static [(&'static str, Key)], +} diff --git a/source/postcard-rpc/src/macros.rs b/source/postcard-rpc/src/macros.rs index 7816364..bd047c7 100644 --- a/source/postcard-rpc/src/macros.rs +++ b/source/postcard-rpc/src/macros.rs @@ -3,6 +3,8 @@ /// Used to define a single Endpoint marker type that implements the /// [Endpoint][crate::Endpoint] trait. /// +/// Prefer the [`endpoints!()`][crate::endpoints]` macro instead. +/// /// ```rust /// # use postcard_schema::Schema; /// # use serde::{Serialize, Deserialize}; @@ -45,11 +47,96 @@ macro_rules! endpoint { }; } +/// ## Endpoints macro +/// +/// Used to define multiple Endpoint marker types that implements the +/// [Endpoint][crate::Endpoint] trait. +/// +/// ```rust +/// # use postcard_schema::Schema; +/// # use serde::{Serialize, Deserialize}; +/// use postcard_rpc::endpoints; +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Req1 { +/// a: u8, +/// b: u64, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Resp1 { +/// c: [u8; 4], +/// d: i32, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Req2 { +/// a: i8, +/// b: i64, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Resp2 { +/// c: [i8; 4], +/// d: u32, +/// } +/// +/// endpoints!{ +/// list = ENDPOINTS_LIST; +/// | EndpointTy | RequestTy | ResponseTy | Path | +/// | ---------- | --------- | ---------- | ---- | +/// | Endpoint1 | Req1 | Resp1 | "endpoints/one" | +/// | Endpoint2 | Req2 | Resp2 | "endpoints/two" | +/// } +/// ``` +#[macro_export] +macro_rules! endpoints { + ( + list = $list_name:ident; + | EndpointTy | RequestTy | ResponseTy | Path | + | $(-)* | $(-)* | $(-)* | $(-)* | + $( | $ep_name:ident | $req_ty:ty | $resp_ty:ty | $path_str:literal | )* + ) => { + // struct definitions and trait impls + $( + pub struct $ep_name; + + impl $crate::Endpoint for $ep_name { + type Request = $req_ty; + type Response = $resp_ty; + const PATH: &'static str = $path_str; + const REQ_KEY: $crate::Key = $crate::Key::for_path::<$req_ty>($path_str); + const RESP_KEY: $crate::Key = $crate::Key::for_path::<$resp_ty>($path_str); + } + )* + + pub const $list_name: $crate::EndpointMap = $crate::EndpointMap { + types: &[ + $( + <$ep_name as $crate::Endpoint>::Request::SCHEMA, + <$ep_name as $crate::Endpoint>::Response::SCHEMA, + )* + ], + endpoints: &[ + $( + ( + <$ep_name as $crate::Endpoint>::PATH, + <$ep_name as $crate::Endpoint>::REQ_KEY, + <$ep_name as $crate::Endpoint>::RESP_KEY, + ), + )* + ], + }; + }; +} + /// ## Topic macro /// /// Used to define a single Topic marker type that implements the /// [Topic][crate::Topic] trait. /// +/// Prefer the [`topics!()` macro](crate::topics) macro. +/// /// ```rust /// # use postcard_schema::Schema; /// # use serde::{Serialize, Deserialize}; @@ -74,6 +161,9 @@ macro_rules! topic { topic!($tyname, $msg, $path) }; ($tyname:ident, $msg:ty, $path:expr) => { + /// $tyname - A Topic definition type + /// + /// Generated by the `topic!()` macro pub struct $tyname; impl $crate::Topic for $tyname { @@ -84,16 +174,132 @@ macro_rules! topic { }; } -#[cfg(feature = "embassy-usb-0_3-server")] +/// ## Topics macro +/// +/// Used to define multiple Topic marker types that implements the +/// [Topic][crate::Topic] trait. +/// +/// ```rust +/// # use postcard_schema::Schema; +/// # use serde::{Serialize, Deserialize}; +/// use postcard_rpc::topics; +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Message1 { +/// a: u8, +/// b: u64, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Message2 { +/// a: i8, +/// b: i64, +/// } +/// +/// topics!{ +/// list = TOPIC_LIST_NAME; +/// | TopicTy | MessageTy | Path | +/// | ------- | --------- | ---- | +/// | Topic1 | Message1 | "topics/one" | +/// | Topic2 | Message2 | "topics/two" | +/// } +/// ``` #[macro_export] -macro_rules! sender_log { - ($sender:ident, $($arg:tt)*) => { - $sender.fmt_publish::<$crate::standard_icd::Logging>(format_args!($($arg)*)) - }; - ($sender:ident, $s:expr) => { - $sender.str_publish::<$crate::standard_icd::Logging>($s) +macro_rules! topics { + ( + list = $list_name:ident; + | TopicTy | MessageTy | Path | + | $(-)* | $(-)* | $(-)* | + $( | $tp_name:ident | $msg_ty:ty | $path_str:literal | )* + ) => { + // struct definitions and trait impls + $( + /// $tp_name - A Topic definition type + /// + /// Generated by the `topics!()` macro + pub struct $tp_name; + + impl $crate::Topic for $tp_name { + type Message = $msg_ty; + const PATH: &'static str = $path_str; + const TOPIC_KEY: $crate::Key = $crate::Key::for_path::<$msg_ty>($path_str); + } + )* + + pub const $list_name: $crate::TopicMap = $crate::TopicMap { + types: &[ + $( + <$tp_name as $crate::Topic>::Message::SCHEMA, + )* + ], + topics: &[ + $( + ( + <$tp_name as $crate::Topic>::PATH, + <$tp_name as $crate::Topic>::TOPIC_KEY, + ), + )* + ], + }; }; - ($($arg:tt)*) => { - compile_error!("You must pass the sender to `sender_log`!"); +} + +// TODO: bring this back when I sort out how to do formatting in the sender! +// This might require WireTx impls +// +// #[cfg(feature = "embassy-usb-0_3-server")] +// #[macro_export] +// macro_rules! sender_log { +// ($sender:ident, $($arg:tt)*) => { +// $sender.fmt_publish::<$crate::standard_icd::Logging>(format_args!($($arg)*)) +// }; +// ($sender:ident, $s:expr) => { +// $sender.str_publish::<$crate::standard_icd::Logging>($s) +// }; +// ($($arg:tt)*) => { +// compile_error!("You must pass the sender to `sender_log`!"); +// } +// } + +#[cfg(test)] +mod endpoints_test { + use postcard_schema::Schema; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Schema)] + pub struct AReq(pub u8); + #[derive(Serialize, Deserialize, Schema)] + pub struct AResp(pub u16); + #[derive(Serialize, Deserialize, Schema)] + pub struct BTopic(pub u32); + + endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | AlphaEndpoint1 | AReq | AResp | "test/alpha1" | + | AlphaEndpoint2 | AReq | AResp | "test/alpha2" | + | AlphaEndpoint3 | AReq | AResp | "test/alpha3" | + } + + topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | BetaTopic1 | BTopic | "test/beta1" | + | BetaTopic2 | BTopic | "test/beta2" | + | BetaTopic3 | BTopic | "test/beta3" | + } + + #[test] + fn eps() { + assert_eq!(ENDPOINT_LIST.types.len(), 6); + assert_eq!(ENDPOINT_LIST.endpoints.len(), 3); + } + + #[test] + fn tps() { + assert_eq!(TOPICS_IN_LIST.types.len(), 3); + assert_eq!(TOPICS_IN_LIST.topics.len(), 3); } } diff --git a/source/postcard-rpc/src/server/dispatch_macro.rs b/source/postcard-rpc/src/server/dispatch_macro.rs new file mode 100644 index 0000000..9e68731 --- /dev/null +++ b/source/postcard-rpc/src/server/dispatch_macro.rs @@ -0,0 +1,454 @@ +#[doc(hidden)] +pub mod export { + pub use paste::paste; +} + +/// Define Dispatch Macro +#[macro_export] +macro_rules! define_dispatch { + ////////////////////////////////////////////////////////////////////////////// + // ENDPOINT HANDLER EXPANSION ARMS + ////////////////////////////////////////////////////////////////////////////// + + // This is the "blocking execution" arm for defining an endpoint + (@ep_arm blocking ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let reply = $handler($context, $header.clone(), $req); + if $outputter.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { + let err = $crate::standard_icd::WireError::SerFailed; + $outputter.error($header.seq_no, err).await + } else { + Ok(()) + } + } + }; + // This is the "async execution" arm for defining an endpoint + (@ep_arm async ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let reply = $handler($context, $header.clone(), $req).await; + if $outputter.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { + let err = $crate::standard_icd::WireError::SerFailed; + $outputter.error($header.seq_no, err).await + } else { + Ok(()) + } + } + }; + // This is the "spawn an embassy task" arm for defining an endpoint + (@ep_arm spawn ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let context = $crate::server::SpawnContext::spawn_ctxt($context); + if $spawn_fn($spawner, $handler(context, $header.clone(), $req, $outputter.clone())).is_err() { + let err = $crate::standard_icd::WireError::FailedToSpawn; + $outputter.error($header.seq_no, err).await + } else { + Ok(()) + } + } + }; + + ////////////////////////////////////////////////////////////////////////////// + // TOPIC HANDLER EXPANSION ARMS + ////////////////////////////////////////////////////////////////////////////// + + // This is the "blocking execution" arm for defining a topic + (@tp_arm blocking $handler:ident $context:ident $header:ident $msg:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + $handler($context, $header.clone(), $msg, $outputter); + } + }; + // This is the "async execution" arm for defining a topic + (@tp_arm async $handler:ident $context:ident $header:ident $msg:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + $handler($context, $header.clone(), $msg, $outputter).await; + } + }; + (@tp_arm spawn $handler:ident $context:ident $header:ident $msg:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let context = $crate::server::SpawnContext::spawn_ctxt($context); + let _ = $spawn_fn($spawner, $handler(context, $header.clone(), $msg, $outputter.clone())); + } + }; + + ////////////////////////////////////////////////////////////////////////////// + // Implementation of the dispatch trait for the app, where the Key length + // is N, where N is 1, 2, 4, or 8 + ////////////////////////////////////////////////////////////////////////////// + (@matcher + $n:literal $app_name:ident $tx_impl:ty; $spawn_fn:ident $key_ty:ty; $key_kind:expr; + ($($endpoint:ty | $ep_flavor:tt | $ep_handler:ident)*) + ($($topic_in:ty | $tp_flavor:tt | $tp_handler:ident)*) + ) => { + impl $crate::server::Dispatch for $app_name<$n> { + type Tx = $tx_impl; + + fn min_key_len(&self) -> $crate::header::VarKeyKind { + $key_kind + } + + /// Handle dispatching of a single frame + async fn handle( + &mut self, + tx: &$crate::server::Sender, + hdr: &$crate::header::VarHeader, + body: &[u8], + ) -> Result<(), ::Error> { + let key = hdr.key; + let Some(keyb) = <$key_ty>::try_from_varkey(&key) else { + let err = $crate::standard_icd::WireError::KeyTooSmall; + return tx.error(hdr.seq_no, err).await; + }; + let keyb = keyb.to_bytes(); + use consts::*; + match keyb { + $( + $crate::server::dispatch_macro::export::paste! { [<$endpoint:upper _KEY $n>] } => { + // Can we deserialize the request? + let Ok(req) = postcard::from_bytes::<<$endpoint as $crate::Endpoint>::Request>(body) else { + let err = $crate::standard_icd::WireError::DeserFailed; + return tx.error(hdr.seq_no, err).await; + }; + + // Store some items as named bindings, so we can use `ident` in the + // recursive macro expansion. Load bearing order: we borrow `context` + // from `dispatch` because we need `dispatch` AFTER `context`, so NLL + // allows this to still borrowck + let dispatch = self; + let context = &mut dispatch.context; + #[allow(unused)] + let spawninfo = &dispatch.spawn; + + // This will expand to the right "flavor" of handler + define_dispatch!(@ep_arm $ep_flavor ($endpoint) $ep_handler context hdr req tx ($spawn_fn) spawninfo) + } + )* + $( + $crate::server::dispatch_macro::export::paste! { [<$topic_in:upper _KEY $n>] } => { + // Can we deserialize the request? + let Ok(msg) = postcard::from_bytes::<<$topic_in as $crate::Topic>::Message>(body) else { + // This is a topic, not much to be done + return Ok(()); + }; + + // Store some items as named bindings, so we can use `ident` in the + // recursive macro expansion. Load bearing order: we borrow `context` + // from `dispatch` because we need `dispatch` AFTER `context`, so NLL + // allows this to still borrowck + let dispatch = self; + let context = &mut dispatch.context; + #[allow(unused)] + let spawninfo = &dispatch.spawn; + + // (@tp_arm async $handler:ident $context:ident $header:ident $req:ident $outputter:ident) + define_dispatch!(@tp_arm $tp_flavor $tp_handler context hdr msg tx ($spawn_fn) spawninfo); + Ok(()) + } + )* + _other => { + // huh! We have no idea what this key is supposed to be! + let err = $crate::standard_icd::WireError::UnknownKey; + tx.error(hdr.seq_no, err).await + }, + } + } + } + }; + + ////////////////////////////////////////////////////////////////////////////// + // MAIN EXPANSION ENTRYPOINT + ////////////////////////////////////////////////////////////////////////////// + ( + app: $app_name:ident; + + spawn_fn: $spawn_fn:ident; + tx_impl: $tx_impl:ty; + spawn_impl: $spawn_impl:ty; + context: $context_ty:ty; + + endpoints: { + list: $endpoint_list:ident; + + | EndpointTy | kind | handler | + | $(-)* | $(-)* | $(-)* | + $( | $endpoint:ty | $ep_flavor:tt | $ep_handler:ident | )* + }; + topics_in: { + list: $topic_in_list:ident; + + | TopicTy | kind | handler | + | $(-)* | $(-)* | $(-)* | + $( | $topic_in:ty | $tp_flavor:tt | $tp_handler:ident | )* + }; + topics_out: { + list: $topic_out_list:ident; + }; + ) => { + + // Here, we calculate how many bytes (1, 2, 4, or 8) are required to uniquely + // match on the given messages we receive and send†. + // + // This serves as a sort of "perfect hash function", allowing us to use fewer + // bytes on the wire. + // + // †: We don't calculate sending keys yet, oops. This probably requires hashing + // TX/RX differently so endpoints with the same TX and RX don't collide, or + // calculating them separately and taking the max + mod sizer { + use super::*; + use $crate::Key; + + // Create a list of JUST the REQUEST keys from the endpoint report + const EP_IN_KEYS_SZ: usize = $endpoint_list.endpoints.len(); + const EP_IN_KEYS: [Key; EP_IN_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; EP_IN_KEYS_SZ]; + let mut i = 0; + while i < EP_IN_KEYS_SZ { + keys[i] = $endpoint_list.endpoints[i].1; + i += 1; + } + keys + }; + // Create a list of JUST the RESPONSE keys from the endpoint report + const EP_OUT_KEYS_SZ: usize = $endpoint_list.endpoints.len(); + const EP_OUT_KEYS: [Key; EP_OUT_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; EP_OUT_KEYS_SZ]; + let mut i = 0; + while i < EP_OUT_KEYS_SZ { + keys[i] = $endpoint_list.endpoints[i].2; + i += 1; + } + keys + }; + // Create a list of JUST the MESSAGE keys from the TOPICS IN report + const TP_IN_KEYS_SZ: usize = $topic_in_list.topics.len(); + const TP_IN_KEYS: [Key; TP_IN_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; TP_IN_KEYS_SZ]; + let mut i = 0; + while i < TP_IN_KEYS_SZ { + keys[i] = $topic_in_list.topics[i].1; + i += 1; + } + keys + }; + // Create a list of JUST the MESSAGE keys from the TOPICS OUT report + const TP_OUT_KEYS_SZ: usize = $topic_out_list.topics.len(); + const TP_OUT_KEYS: [Key; TP_OUT_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; TP_OUT_KEYS_SZ]; + let mut i = 0; + while i < TP_OUT_KEYS_SZ { + keys[i] = $topic_out_list.topics[i].1; + i += 1; + } + keys + }; + + // This is a list of all REQUEST KEYS in the actual handlers + // + // This should be a SUBSET of the REQUEST KEYS in the Endpoint report + const EP_HANDLER_IN_KEYS: &[Key] = &[ + $(<$endpoint as $crate::Endpoint>::REQ_KEY,)* + ]; + // This is a list of all RESPONSE KEYS in the actual handlers + // + // This should be a SUBSET of the RESPONSE KEYS in the Endpoint report + const EP_HANDLER_OUT_KEYS: &[Key] = &[ + $(<$endpoint as $crate::Endpoint>::RESP_KEY,)* + ]; + // This is a list of all TOPIC KEYS in the actual handlers + // + // This should be a SUBSET of the TOPIC KEYS in the Topic IN report + // (we can't check the out, we have no way of enumerating that yet, + // which would require linkme-like crimes I think) + const TP_HANDLER_IN_KEYS: &[Key] = &[ + $(<$topic_in as $crate::Topic>::TOPIC_KEY,)* + ]; + + const fn a_is_subset_of_b(a: &[Key], b: &[Key]) -> bool { + let mut i = 0; + while i < a.len() { + let x = u64::from_le_bytes(a[i].to_bytes()); + let mut matched = false; + let mut j = 0; + while j < b.len() { + let y = u64::from_le_bytes(b[j].to_bytes()); + if x == y { + matched = true; + break; + } + j += 1; + } + if !matched { + return false; + } + i += 1; + } + true + } + + // TODO: Warn/error if the list doesn't match the defined handlers? + pub const NEEDED_SZ_IN: usize = $crate::server::min_key_needed(&[ + &EP_IN_KEYS, + &TP_IN_KEYS, + ]); + pub const NEEDED_SZ_OUT: usize = $crate::server::min_key_needed(&[ + &EP_OUT_KEYS, + &TP_OUT_KEYS, + ]); + pub const NEEDED_SZ: usize = const { + assert!( + a_is_subset_of_b(EP_HANDLER_IN_KEYS, &EP_IN_KEYS), + "All listed endpoint handlers must be listed in endpoints->list! Missing Requst Type found!", + ); + assert!( + a_is_subset_of_b(EP_HANDLER_OUT_KEYS, &EP_OUT_KEYS), + "All listed endpoint handlers must be listed in endpoints->list! Missing Response Type found!", + ); + assert!( + a_is_subset_of_b(TP_HANDLER_IN_KEYS, &TP_IN_KEYS), + "All listed endpoint handlers must be listed in endpoints->list! Missing Response Type found!", + ); + if NEEDED_SZ_IN > NEEDED_SZ_OUT { + NEEDED_SZ_IN + } else { + NEEDED_SZ_OUT + } + }; + } + + + // Here, we calculate at const time the keys we need to match against. This is done with + // paste, which is unfortunate, but allows us to match on this correctly later. + mod consts { + use super::*; + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY1>]: u8 = $crate::Key1::from_key8(<$endpoint as $crate::Endpoint>::REQ_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY1>]: u8 = $crate::Key1::from_key8(<$topic_in as $crate::Topic>::TOPIC_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY2>]: [u8; 2] = $crate::Key2::from_key8(<$endpoint as $crate::Endpoint>::REQ_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY2>]: [u8; 2] = $crate::Key2::from_key8(<$topic_in as $crate::Topic>::TOPIC_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY4>]: [u8; 4] = $crate::Key4::from_key8(<$endpoint as $crate::Endpoint>::REQ_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY4>]: [u8; 4] = $crate::Key4::from_key8(<$topic_in as $crate::Topic>::TOPIC_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY8>]: [u8; 8] = <$endpoint as $crate::Endpoint>::REQ_KEY.to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY8>]: [u8; 8] = <$topic_in as $crate::Topic>::TOPIC_KEY.to_bytes(); + } + )* + } + + // This is the fun part. + // + // For... reasons, we need to generate a match function to allow for dispatching + // different async handlers without degrading to dyn Future, because no alloc on + // embedded systems. + // + // The easiest way I've found to achieve this is actually to implement this + // handler for ALL of 1, 2, 4, 8, BUT to hide that from the user, and instead + // use THIS alias to give them the one that they need. + // + // This is overly complicated because I'm mixing const-time capabilities with + // macro-time capabilities. I'm very open to other suggestions that achieve the + // same outcome. + pub type $app_name = impls::$app_name<{ sizer::NEEDED_SZ }>; + + + + mod impls { + use super::*; + + pub struct $app_name { + pub context: $context_ty, + pub spawn: $spawn_impl, + pub device_map: &'static $crate::DeviceMap, + } + + impl $app_name { + /// Create a new instance of the dispatcher + pub fn new( + context: $context_ty, + spawn: $spawn_impl, + ) -> Self { + const MAP: &$crate::DeviceMap = &$crate::DeviceMap { + types: const { + const LISTS: &[&[&'static postcard_schema::schema::NamedType]] = &[ + $endpoint_list.types, + $topic_in_list.types, + $topic_out_list.types, + ]; + const TTL_COUNT: usize = $endpoint_list.types.len() + $topic_in_list.types.len() + $topic_out_list.types.len(); + + const BIG_RPT: ([Option<&'static postcard_schema::schema::NamedType>; TTL_COUNT], usize) = $crate::uniques::merge_nty_lists(LISTS); + const SMALL_RPT: [&'static postcard_schema::schema::NamedType; BIG_RPT.1] = $crate::uniques::cruncher(BIG_RPT.0.as_slice()); + SMALL_RPT.as_slice() + }, + endpoints: &$endpoint_list.endpoints, + topics_in: &$topic_in_list.topics, + topics_out: &$topic_out_list.topics, + min_key_len: const { + match sizer::NEEDED_SZ { + 1 => $crate::header::VarKeyKind::Key1, + 2 => $crate::header::VarKeyKind::Key2, + 4 => $crate::header::VarKeyKind::Key4, + 8 => $crate::header::VarKeyKind::Key8, + _ => unreachable!(), + } + } + }; + $app_name { + context, + spawn, + device_map: MAP, + } + } + } + + define_dispatch! { + @matcher 1 $app_name $tx_impl; $spawn_fn $crate::Key1; $crate::header::VarKeyKind::Key1; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + define_dispatch! { + @matcher 2 $app_name $tx_impl; $spawn_fn $crate::Key2; $crate::header::VarKeyKind::Key2; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + define_dispatch! { + @matcher 4 $app_name $tx_impl; $spawn_fn $crate::Key4; $crate::header::VarKeyKind::Key4; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + define_dispatch! { + @matcher 8 $app_name $tx_impl; $spawn_fn $crate::Key; $crate::header::VarKeyKind::Key8; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + } + + } +} diff --git a/source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs b/source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs new file mode 100644 index 0000000..3135aed --- /dev/null +++ b/source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs @@ -0,0 +1,669 @@ +//! Implementation using `embassy-usb` and bulk interfaces + +use embassy_executor::{SpawnError, SpawnToken, Spawner}; +use embassy_sync::{blocking_mutex::raw::RawMutex, mutex::Mutex}; +use embassy_usb_driver::{Driver, Endpoint, EndpointError, EndpointIn, EndpointOut}; +use futures_util::FutureExt; +use serde::Serialize; + +use crate::{ + header::VarHeader, + server::{WireRx, WireRxErrorKind, WireSpawn, WireTx, WireTxErrorKind}, +}; + +/// A collection of types and aliases useful for importing the correct types +pub mod dispatch_impl { + pub use super::embassy_spawn as spawn_fn; + use super::{EUsbWireRx, EUsbWireTx, EUsbWireTxInner, UsbDeviceBuffers}; + + /// Used for defining the USB interface + pub const DEVICE_INTERFACE_GUIDS: &[&str] = &["{AFB9A6FB-30BA-44BC-9232-806CFC875321}"]; + + use embassy_sync::{blocking_mutex::raw::RawMutex, mutex::Mutex}; + use embassy_usb::{ + msos::{self, windows_version}, + Builder, Config, UsbDevice, + }; + use embassy_usb_driver::Driver; + use static_cell::{ConstStaticCell, StaticCell}; + + /// Type alias for `WireTx` impl + pub type WireTxImpl = super::EUsbWireTx; + /// Type alias for `WireRx` impl + pub type WireRxImpl = super::EUsbWireRx; + /// Type alias for `WireSpawn` impl + pub type WireSpawnImpl = super::EUsbWireSpawn; + /// Type alias for the receive buffer + pub type WireRxBuf = &'static mut [u8]; + + /// A helper type for `static` storage of buffers and driver components + pub struct WireStorage< + M: RawMutex + 'static, + D: Driver<'static> + 'static, + const CONFIG: usize = 256, + const BOS: usize = 256, + const CONTROL: usize = 64, + const MSOS: usize = 256, + > { + /// Usb buffer storage + pub bufs_usb: ConstStaticCell>, + /// WireTx/Sender static storage + pub cell: StaticCell>>, + } + + impl< + M: RawMutex + 'static, + D: Driver<'static> + 'static, + const CONFIG: usize, + const BOS: usize, + const CONTROL: usize, + const MSOS: usize, + > WireStorage + { + /// Create a new, uninitialized static set of buffers + pub const fn new() -> Self { + Self { + bufs_usb: ConstStaticCell::new(UsbDeviceBuffers::new()), + cell: StaticCell::new(), + } + } + + /// Initialize the static storage. + /// + /// This must only be called once. + pub fn init( + &'static self, + driver: D, + config: Config<'static>, + tx_buf: &'static mut [u8], + ) -> (UsbDevice<'static, D>, WireTxImpl, WireRxImpl) { + let bufs = self.bufs_usb.take(); + + let mut builder = Builder::new( + driver, + config, + &mut bufs.config_descriptor, + &mut bufs.bos_descriptor, + &mut bufs.msos_descriptor, + &mut bufs.control_buf, + ); + + // Add the Microsoft OS Descriptor (MSOS/MOD) descriptor. + // We tell Windows that this entire device is compatible with the "WINUSB" feature, + // which causes it to use the built-in WinUSB driver automatically, which in turn + // can be used by libusb/rusb software without needing a custom driver or INF file. + // In principle you might want to call msos_feature() just on a specific function, + // if your device also has other functions that still use standard class drivers. + builder.msos_descriptor(windows_version::WIN8_1, 0); + builder.msos_feature(msos::CompatibleIdFeatureDescriptor::new("WINUSB", "")); + builder.msos_feature(msos::RegistryPropertyFeatureDescriptor::new( + "DeviceInterfaceGUIDs", + msos::PropertyData::RegMultiSz(DEVICE_INTERFACE_GUIDS), + )); + + // Add a vendor-specific function (class 0xFF), and corresponding interface, + // that uses our custom handler. + let mut function = builder.function(0xFF, 0, 0); + let mut interface = function.interface(); + let mut alt = interface.alt_setting(0xFF, 0, 0, None); + let ep_out = alt.endpoint_bulk_out(64); + let ep_in = alt.endpoint_bulk_in(64); + drop(function); + + let wtx = self.cell.init(Mutex::new(EUsbWireTxInner { + ep_in, + _log_seq: 0, + tx_buf, + _max_log_len: 0, + })); + + // Build the builder. + let usb = builder.build(); + + (usb, EUsbWireTx { inner: wtx }, EUsbWireRx { ep_out }) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// TX +////////////////////////////////////////////////////////////////////////////// + +/// Implementation detail, holding the endpoint and scratch buffer used for sending +pub struct EUsbWireTxInner> { + ep_in: D::EndpointIn, + _log_seq: u32, + tx_buf: &'static mut [u8], + _max_log_len: usize, +} + +/// A [`WireTx`] implementation for embassy-usb 0.3. +#[derive(Copy)] +pub struct EUsbWireTx + 'static> { + inner: &'static Mutex>, +} + +impl + 'static> Clone for EUsbWireTx { + fn clone(&self) -> Self { + EUsbWireTx { inner: self.inner } + } +} + +impl + 'static> WireTx for EUsbWireTx { + type Error = WireTxErrorKind; + + async fn send( + &self, + hdr: VarHeader, + msg: &T, + ) -> Result<(), Self::Error> { + let mut inner = self.inner.lock().await; + + let EUsbWireTxInner { + ep_in, + _log_seq: _, + tx_buf, + _max_log_len: _, + }: &mut EUsbWireTxInner = &mut inner; + + let (hdr_used, remain) = hdr.write_to_slice(tx_buf).ok_or(WireTxErrorKind::Other)?; + let bdy_used = postcard::to_slice(msg, remain).map_err(|_| WireTxErrorKind::Other)?; + let used_ttl = hdr_used.len() + bdy_used.len(); + + if let Some(used) = tx_buf.get(..used_ttl) { + send_all::(ep_in, used).await + } else { + Err(WireTxErrorKind::Other) + } + } + + async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error> { + let mut inner = self.inner.lock().await; + send_all::(&mut inner.ep_in, buf).await + } +} + +#[inline] +async fn send_all(ep_in: &mut D::EndpointIn, out: &[u8]) -> Result<(), WireTxErrorKind> +where + D: Driver<'static>, +{ + if out.is_empty() { + return Ok(()); + } + // TODO: Timeout? + if ep_in.wait_enabled().now_or_never().is_none() { + return Ok(()); + } + + // write in segments of 64. The last chunk may + // be 0 < len <= 64. + for ch in out.chunks(64) { + if ep_in.write(ch).await.is_err() { + return Err(WireTxErrorKind::ConnectionClosed); + } + } + // If the total we sent was a multiple of 64, send an + // empty message to "flush" the transaction. We already checked + // above that the len != 0. + if (out.len() & (64 - 1)) == 0 && ep_in.write(&[]).await.is_err() { + return Err(WireTxErrorKind::ConnectionClosed); + } + + Ok(()) +} + +////////////////////////////////////////////////////////////////////////////// +// RX +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireRx`] implementation for embassy-usb 0.3. +pub struct EUsbWireRx> { + ep_out: D::EndpointOut, +} + +impl> WireRx for EUsbWireRx { + type Error = WireRxErrorKind; + + async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error> { + let buflen = buf.len(); + let mut window = &mut buf[..]; + while !window.is_empty() { + let n = match self.ep_out.read(window).await { + Ok(n) => n, + Err(EndpointError::BufferOverflow) => { + return Err(WireRxErrorKind::ReceivedMessageTooLarge) + } + Err(EndpointError::Disabled) => return Err(WireRxErrorKind::ConnectionClosed), + }; + + let (_now, later) = window.split_at_mut(n); + window = later; + if n != 64 { + // We now have a full frame! Great! + let wlen = window.len(); + let len = buflen - wlen; + let frame = &mut buf[..len]; + + return Ok(frame); + } + } + + // If we got here, we've run out of space. That's disappointing. Accumulate to the + // end of this packet + loop { + match self.ep_out.read(buf).await { + Ok(64) => {} + Ok(_) => return Err(WireRxErrorKind::ReceivedMessageTooLarge), + Err(EndpointError::BufferOverflow) => { + return Err(WireRxErrorKind::ReceivedMessageTooLarge) + } + Err(EndpointError::Disabled) => return Err(WireRxErrorKind::ConnectionClosed), + }; + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWN +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireSpawn`] impl using the embassy executor +#[derive(Clone)] +pub struct EUsbWireSpawn { + /// The embassy-executor spawner + pub spawner: Spawner, +} + +impl From for EUsbWireSpawn { + fn from(value: Spawner) -> Self { + Self { spawner: value } + } +} + +impl WireSpawn for EUsbWireSpawn { + type Error = SpawnError; + + type Info = Spawner; + + fn info(&self) -> &Self::Info { + &self.spawner + } +} + +/// Attempt to spawn the given token +pub fn embassy_spawn(sp: &Sp, tok: SpawnToken) -> Result<(), Sp::Error> +where + Sp: WireSpawn, +{ + let info = sp.info(); + info.spawn(tok) +} + +////////////////////////////////////////////////////////////////////////////// +// OTHER +////////////////////////////////////////////////////////////////////////////// + +/// A generically sized storage type for buffers +pub struct UsbDeviceBuffers< + const CONFIG: usize = 256, + const BOS: usize = 256, + const CONTROL: usize = 64, + const MSOS: usize = 256, +> { + /// Config descriptor storage + pub config_descriptor: [u8; CONFIG], + /// BOS descriptor storage + pub bos_descriptor: [u8; BOS], + /// CONTROL endpoint buffer storage + pub control_buf: [u8; CONTROL], + /// MSOS descriptor buffer storage + pub msos_descriptor: [u8; MSOS], +} + +impl + UsbDeviceBuffers +{ + /// Create a new, empty set of buffers + pub const fn new() -> Self { + Self { + config_descriptor: [0u8; CONFIG], + bos_descriptor: [0u8; BOS], + msos_descriptor: [0u8; MSOS], + control_buf: [0u8; CONTROL], + } + } +} + +/// Static storage for generically sized input and output packet buffers +pub struct PacketBuffers { + /// the transmit buffer + pub tx_buf: [u8; TX], + /// thereceive buffer + pub rx_buf: [u8; RX], +} + +impl PacketBuffers { + /// Create new empty buffers + pub const fn new() -> Self { + Self { + tx_buf: [0u8; TX], + rx_buf: [0u8; RX], + } + } +} + +/// This is a basic example that everything compiles. It is intended to exercise the macro above, +/// as well as provide impls for docs. Don't rely on any of this! +#[doc(hidden)] +#[allow(dead_code)] +#[cfg(feature = "test-utils")] +pub mod fake { + use crate::{ + define_dispatch, endpoints, + server::{Sender, SpawnContext}, + topics, + }; + use crate::{header::VarHeader, Schema}; + use embassy_usb_driver::{Bus, ControlPipe, EndpointIn, EndpointOut}; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Schema)] + pub struct AReq(pub u8); + #[derive(Serialize, Deserialize, Schema)] + pub struct AResp(pub u8); + #[derive(Serialize, Deserialize, Schema)] + pub struct BReq(pub u16); + #[derive(Serialize, Deserialize, Schema)] + pub struct BResp(pub u32); + #[derive(Serialize, Deserialize, Schema)] + pub struct GReq; + #[derive(Serialize, Deserialize, Schema)] + pub struct GResp; + #[derive(Serialize, Deserialize, Schema)] + pub struct DReq; + #[derive(Serialize, Deserialize, Schema)] + pub struct DResp; + #[derive(Serialize, Deserialize, Schema)] + pub struct EReq; + #[derive(Serialize, Deserialize, Schema)] + pub struct EResp; + #[derive(Serialize, Deserialize, Schema)] + pub struct ZMsg(pub i16); + + endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | AlphaEndpoint | AReq | AResp | "alpha" | + | BetaEndpoint | BReq | BResp | "beta" | + | GammaEndpoint | GReq | GResp | "gamma" | + | DeltaEndpoint | DReq | DResp | "delta" | + | EpsilonEndpoint | EReq | EResp | "epsilon" | + } + + topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic1 | ZMsg | "zeta1" | + | ZetaTopic2 | ZMsg | "zeta2" | + | ZetaTopic3 | ZMsg | "zeta3" | + } + + topics! { + list = TOPICS_OUT_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic10 | ZMsg | "zeta10" | + } + + pub struct FakeMutex; + pub struct FakeDriver; + pub struct FakeEpOut; + pub struct FakeEpIn; + pub struct FakeCtlPipe; + pub struct FakeBus; + + impl embassy_usb_driver::Endpoint for FakeEpOut { + fn info(&self) -> &embassy_usb_driver::EndpointInfo { + todo!() + } + + async fn wait_enabled(&mut self) { + todo!() + } + } + + impl EndpointOut for FakeEpOut { + async fn read( + &mut self, + _buf: &mut [u8], + ) -> Result { + todo!() + } + } + + impl embassy_usb_driver::Endpoint for FakeEpIn { + fn info(&self) -> &embassy_usb_driver::EndpointInfo { + todo!() + } + + async fn wait_enabled(&mut self) { + todo!() + } + } + + impl EndpointIn for FakeEpIn { + async fn write(&mut self, _buf: &[u8]) -> Result<(), embassy_usb_driver::EndpointError> { + todo!() + } + } + + impl ControlPipe for FakeCtlPipe { + fn max_packet_size(&self) -> usize { + todo!() + } + + async fn setup(&mut self) -> [u8; 8] { + todo!() + } + + async fn data_out( + &mut self, + _buf: &mut [u8], + _first: bool, + _last: bool, + ) -> Result { + todo!() + } + + async fn data_in( + &mut self, + _data: &[u8], + _first: bool, + _last: bool, + ) -> Result<(), embassy_usb_driver::EndpointError> { + todo!() + } + + async fn accept(&mut self) { + todo!() + } + + async fn reject(&mut self) { + todo!() + } + + async fn accept_set_address(&mut self, _addr: u8) { + todo!() + } + } + + impl Bus for FakeBus { + async fn enable(&mut self) { + todo!() + } + + async fn disable(&mut self) { + todo!() + } + + async fn poll(&mut self) -> embassy_usb_driver::Event { + todo!() + } + + fn endpoint_set_enabled( + &mut self, + _ep_addr: embassy_usb_driver::EndpointAddress, + _enabled: bool, + ) { + todo!() + } + + fn endpoint_set_stalled( + &mut self, + _ep_addr: embassy_usb_driver::EndpointAddress, + _stalled: bool, + ) { + todo!() + } + + fn endpoint_is_stalled(&mut self, _ep_addr: embassy_usb_driver::EndpointAddress) -> bool { + todo!() + } + + async fn remote_wakeup(&mut self) -> Result<(), embassy_usb_driver::Unsupported> { + todo!() + } + } + + impl embassy_usb_driver::Driver<'static> for FakeDriver { + type EndpointOut = FakeEpOut; + + type EndpointIn = FakeEpIn; + + type ControlPipe = FakeCtlPipe; + + type Bus = FakeBus; + + fn alloc_endpoint_out( + &mut self, + _ep_type: embassy_usb_driver::EndpointType, + _max_packet_size: u16, + _interval_ms: u8, + ) -> Result { + todo!() + } + + fn alloc_endpoint_in( + &mut self, + _ep_type: embassy_usb_driver::EndpointType, + _max_packet_size: u16, + _interval_ms: u8, + ) -> Result { + todo!() + } + + fn start(self, _control_max_packet_size: u16) -> (Self::Bus, Self::ControlPipe) { + todo!() + } + } + + unsafe impl embassy_sync::blocking_mutex::raw::RawMutex for FakeMutex { + const INIT: Self = Self; + + fn lock(&self, _f: impl FnOnce() -> R) -> R { + todo!() + } + } + + pub struct TestContext { + pub a: u32, + pub b: u32, + } + + impl SpawnContext for TestContext { + type SpawnCtxt = TestSpawnContext; + + fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { + TestSpawnContext { b: self.b } + } + } + + pub struct TestSpawnContext { + b: u32, + } + + // TODO: How to do module path concat? + use crate::server::impls::embassy_usb_v0_3::dispatch_impl::{ + spawn_fn, WireSpawnImpl, WireTxImpl, + }; + + define_dispatch! { + app: SingleDispatcher; + spawn_fn: spawn_fn; + tx_impl: WireTxImpl; + spawn_impl: WireSpawnImpl; + context: TestContext; + + endpoints: { + list: ENDPOINT_LIST; + + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | AlphaEndpoint | async | test_alpha_handler | + | EpsilonEndpoint | spawn | test_epsilon_handler_task | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + // | ZetaTopic1 | blocking | test_zeta_blocking | + // | ZetaTopic2 | async | test_zeta_async | + // | ZetaTopic3 | spawn | test_zeta_spawn | + }; + topics_out: { + list: TOPICS_OUT_LIST; + }; + } + + async fn test_alpha_handler( + _context: &mut TestContext, + _header: VarHeader, + _body: AReq, + ) -> AResp { + todo!() + } + + async fn test_beta_handler( + _context: &mut TestContext, + _header: VarHeader, + _body: BReq, + ) -> BResp { + todo!() + } + + async fn test_gamma_handler( + _context: &mut TestContext, + _header: VarHeader, + _body: GReq, + ) -> GResp { + todo!() + } + + fn test_delta_handler(_context: &mut TestContext, _header: VarHeader, _body: DReq) -> DResp { + todo!() + } + + #[embassy_executor::task] + async fn test_epsilon_handler_task( + _context: TestSpawnContext, + _header: VarHeader, + _body: EReq, + _sender: Sender>, + ) { + todo!() + } +} diff --git a/source/postcard-rpc/src/server/impls/mod.rs b/source/postcard-rpc/src/server/impls/mod.rs new file mode 100644 index 0000000..104cc87 --- /dev/null +++ b/source/postcard-rpc/src/server/impls/mod.rs @@ -0,0 +1,9 @@ +//! Implementations of various Server traits +//! +//! The implementations in this module typically require feature flags to be set. + +#[cfg(feature = "embassy-usb-0_3-server")] +pub mod embassy_usb_v0_3; + +#[cfg(feature = "test-utils")] +pub mod test_channels; diff --git a/source/postcard-rpc/src/server/impls/test_channels.rs b/source/postcard-rpc/src/server/impls/test_channels.rs new file mode 100644 index 0000000..db4aab1 --- /dev/null +++ b/source/postcard-rpc/src/server/impls/test_channels.rs @@ -0,0 +1,199 @@ +//! Implementation that uses channels for local testing + +use core::{convert::Infallible, future::Future}; + +use crate::server::{ + AsWireRxErrorKind, AsWireTxErrorKind, WireRx, WireRxErrorKind, WireSpawn, WireTx, + WireTxErrorKind, +}; +use tokio::sync::mpsc; + +////////////////////////////////////////////////////////////////////////////// +// DISPATCH IMPL +////////////////////////////////////////////////////////////////////////////// + +/// A collection of types and aliases useful for importing the correct types +pub mod dispatch_impl { + use crate::{ + header::VarKeyKind, + server::{Dispatch, Server}, + }; + + pub use super::tokio_spawn as spawn_fn; + + /// The settings necessary for creating a new channel server + pub struct Settings { + /// The frame sender + pub tx: WireTxImpl, + /// The frame receiver + pub rx: WireRxImpl, + /// The size of the receive buffer + pub buf: usize, + /// The sender key size to use + pub kkind: VarKeyKind, + } + + /// Type alias for `WireTx` impl + pub type WireTxImpl = super::ChannelWireTx; + /// Type alias for `WireRx` impl + pub type WireRxImpl = super::ChannelWireRx; + /// Type alias for `WireSpawn` impl + pub type WireSpawnImpl = super::ChannelWireSpawn; + /// Type alias for the receive buffer + pub type WireRxBuf = Box<[u8]>; + + /// Create a new server using the [`Settings`] and [`Dispatch`] implementation + pub fn new_server( + dispatch: D, + settings: Settings, + ) -> crate::server::Server + where + D: Dispatch, + { + let buf = vec![0; settings.buf]; + Server::new( + &settings.tx, + settings.rx, + buf.into_boxed_slice(), + dispatch, + settings.kkind, + ) + } +} + +////////////////////////////////////////////////////////////////////////////// +// TX +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireTx`] impl using tokio mpsc channels +#[derive(Clone)] +pub struct ChannelWireTx { + tx: mpsc::Sender>, +} + +impl ChannelWireTx { + /// Create a new [`ChannelWireTx`] + pub fn new(tx: mpsc::Sender>) -> Self { + Self { tx } + } +} + +impl WireTx for ChannelWireTx { + type Error = ChannelWireTxError; + + async fn send( + &self, + hdr: crate::header::VarHeader, + msg: &T, + ) -> Result<(), Self::Error> { + let mut hdr_ser = hdr.write_to_vec(); + let bdy_ser = postcard::to_stdvec(msg).unwrap(); + hdr_ser.extend_from_slice(&bdy_ser); + self.tx + .send(hdr_ser) + .await + .map_err(|_| ChannelWireTxError::ChannelClosed)?; + Ok(()) + } + + async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error> { + let buf = buf.to_vec(); + self.tx + .send(buf) + .await + .map_err(|_| ChannelWireTxError::ChannelClosed)?; + Ok(()) + } +} + +/// A wire tx error +#[derive(Debug)] +pub enum ChannelWireTxError { + /// The receiver closed the channel + ChannelClosed, +} + +impl AsWireTxErrorKind for ChannelWireTxError { + fn as_kind(&self) -> WireTxErrorKind { + match self { + ChannelWireTxError::ChannelClosed => WireTxErrorKind::ConnectionClosed, + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// RX +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireRx`] impl using tokio mpsc channels +pub struct ChannelWireRx { + rx: mpsc::Receiver>, +} + +impl ChannelWireRx { + /// Create a new [`ChannelWireRx`] + pub fn new(rx: mpsc::Receiver>) -> Self { + Self { rx } + } +} + +impl WireRx for ChannelWireRx { + type Error = ChannelWireRxError; + + async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error> { + // todo: some kind of receive_owned? + let msg = self.rx.recv().await; + let msg = msg.ok_or(ChannelWireRxError::ChannelClosed)?; + let out = buf + .get_mut(..msg.len()) + .ok_or(ChannelWireRxError::MessageTooLarge)?; + out.copy_from_slice(&msg); + Ok(out) + } +} + +/// A wire rx error +#[derive(Debug)] +pub enum ChannelWireRxError { + /// The sender closed the channel + ChannelClosed, + /// The sender sent a too-large message + MessageTooLarge, +} + +impl AsWireRxErrorKind for ChannelWireRxError { + fn as_kind(&self) -> WireRxErrorKind { + match self { + ChannelWireRxError::ChannelClosed => WireRxErrorKind::ConnectionClosed, + ChannelWireRxError::MessageTooLarge => WireRxErrorKind::ReceivedMessageTooLarge, + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWN +////////////////////////////////////////////////////////////////////////////// + +/// A wire spawn implementation +#[derive(Clone)] +pub struct ChannelWireSpawn; + +impl WireSpawn for ChannelWireSpawn { + type Error = Infallible; + + type Info = (); + + fn info(&self) -> &Self::Info { + &() + } +} + +/// Spawn a task using tokio +pub fn tokio_spawn(_sp: &Sp, fut: F) -> Result<(), Sp::Error> +where + Sp: WireSpawn, + F: Future + 'static + Send, +{ + tokio::task::spawn(fut); + Ok(()) +} diff --git a/source/postcard-rpc/src/server/mod.rs b/source/postcard-rpc/src/server/mod.rs new file mode 100644 index 0000000..eb34327 --- /dev/null +++ b/source/postcard-rpc/src/server/mod.rs @@ -0,0 +1,522 @@ +//! Definitions of a postcard-rpc Server +//! +//! The Server role is responsible for accepting endpoint requests, issuing +//! endpoint responses, receiving client topic messages, and sending server +//! topic messages +//! +//! ## Impls +//! +//! It is intended to allow postcard-rpc servers to be implemented for many +//! different transport types, as well as many different operating environments. +//! +//! Examples of impls include: +//! +//! * A no-std impl using embassy and embassy-usb to provide transport over USB +//! * A std impl using Tokio channels to provide transport for testing +//! +//! Impls are expected to implement three traits: +//! +//! * [`WireTx`]: how the server sends frames to the client +//! * [`WireRx`]: how the server receives frames from the client +//! * [`WireSpawn`]: how the server spawns worker tasks for certain handlers + +#![allow(async_fn_in_trait)] + +#[doc(hidden)] +pub mod dispatch_macro; + +pub mod impls; + +use core::ops::DerefMut; + +use postcard_schema::Schema; +use serde::Serialize; + +use crate::{ + header::{VarHeader, VarKey, VarKeyKind, VarSeq}, + Key, +}; + +////////////////////////////////////////////////////////////////////////////// +// TX +////////////////////////////////////////////////////////////////////////////// + +/// This trait defines how the server sends frames to the client +pub trait WireTx: Clone { + /// The error type of this connection. + /// + /// For simple cases, you can use [`WireTxErrorKind`] directly. You can also + /// use your own custom type that implements [`AsWireTxErrorKind`]. + type Error: AsWireTxErrorKind; + + /// Send a single frame to the client, returning when send is complete. + async fn send(&self, hdr: VarHeader, msg: &T) + -> Result<(), Self::Error>; + + /// Send a single frame to the client, without handling serialization + async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error>; +} + +/// The base [`WireTx`] Error Kind +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum WireTxErrorKind { + /// The connection has been closed, and is unlikely to succeed until + /// the connection is re-established. This will cause the Server run + /// loop to terminate. + ConnectionClosed, + /// Other unspecified errors + Other, +} + +/// A conversion trait to convert a user error into a base Kind type +pub trait AsWireTxErrorKind { + /// Convert the error type into a base type + fn as_kind(&self) -> WireTxErrorKind; +} + +impl AsWireTxErrorKind for WireTxErrorKind { + #[inline] + fn as_kind(&self) -> WireTxErrorKind { + *self + } +} + +////////////////////////////////////////////////////////////////////////////// +// RX +////////////////////////////////////////////////////////////////////////////// + +/// This trait defines how to receive a single frame from a client +pub trait WireRx { + /// The error type of this connection. + /// + /// For simple cases, you can use [`WireRxErrorKind`] directly. You can also + /// use your own custom type that implements [`AsWireRxErrorKind`]. + type Error: AsWireRxErrorKind; + + /// Receive a single frame + /// + /// On success, the portion of `buf` that contains a single frame is returned. + async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error>; +} + +/// The base [`WireRx`] Error Kind +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum WireRxErrorKind { + /// The connection has been closed, and is unlikely to succeed until + /// the connection is re-established. This will cause the Server run + /// loop to terminate. + ConnectionClosed, + /// The received message was too large for the server to handle + ReceivedMessageTooLarge, + /// Other message kinds + Other, +} + +/// A conversion trait to convert a user error into a base Kind type +pub trait AsWireRxErrorKind { + /// Convert the error type into a base type + fn as_kind(&self) -> WireRxErrorKind; +} + +impl AsWireRxErrorKind for WireRxErrorKind { + #[inline] + fn as_kind(&self) -> WireRxErrorKind { + *self + } +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWN +////////////////////////////////////////////////////////////////////////////// + +/// A trait to assist in spawning a handler task +/// +/// This trait is weird, and mostly exists to abstract over how "normal" async +/// executors like tokio spawn tasks, taking a future, and how unusual async +/// executors like embassy spawn tasks, taking a task token that maps to static +/// storage +pub trait WireSpawn: Clone { + /// An error type returned when spawning fails. If this cannot happen, + /// [`Infallible`][core::convert::Infallible] can be used. + type Error; + /// The context used for spawning a task. + /// + /// For example, in tokio this is `()`, and in embassy this is `Spawner`. + type Info; + + /// Retrieve [`Self::Info`] + fn info(&self) -> &Self::Info; +} + +////////////////////////////////////////////////////////////////////////////// +// SENDER (wrapper of WireTx) +////////////////////////////////////////////////////////////////////////////// + +/// The [`Sender`] type wraps a [`WireTx`] impl, and provides higher level functionality +/// over it +#[derive(Clone)] +pub struct Sender { + tx: Tx, + kkind: VarKeyKind, +} + +impl Sender { + /// Create a new Sender + /// + /// Takes a [`WireTx`] impl, as well as the [`VarKeyKind`] used when sending messages + /// to the client. + /// + /// `kkind` should usually come from [`Dispatch::min_key_len()`]. + pub fn new(tx: Tx, kkind: VarKeyKind) -> Self { + Self { tx, kkind } + } + + /// Send a reply for the given endpoint + #[inline] + pub async fn reply(&self, seq_no: VarSeq, resp: &E::Response) -> Result<(), Tx::Error> + where + E: crate::Endpoint, + E::Response: Serialize + Schema, + { + let mut key = VarKey::Key8(E::RESP_KEY); + key.shrink_to(self.kkind); + let wh = VarHeader { key, seq_no }; + self.tx.send::(wh, resp).await + } + + /// Send a reply with the given Key + /// + /// This is useful when replying with "unusual" keys, for example Error responses + /// not tied to any specific Endpoint. + #[inline] + pub async fn reply_keyed(&self, seq_no: VarSeq, key: Key, resp: &T) -> Result<(), Tx::Error> + where + T: Serialize + Schema, + { + let mut key = VarKey::Key8(key); + key.shrink_to(self.kkind); + let wh = VarHeader { key, seq_no }; + self.tx.send::(wh, resp).await + } + + /// Publish a Topic message + #[inline] + pub async fn publish(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), Tx::Error> + where + T: crate::Topic, + T::Message: Serialize + Schema, + { + let mut key = VarKey::Key8(T::TOPIC_KEY); + key.shrink_to(self.kkind); + let wh = VarHeader { key, seq_no }; + self.tx.send::(wh, msg).await + } + + /// Send a single error message + pub async fn error( + &self, + seq_no: VarSeq, + error: crate::standard_icd::WireError, + ) -> Result<(), Tx::Error> { + self.reply_keyed(seq_no, crate::standard_icd::ERROR_KEY, &error) + .await + } +} + +////////////////////////////////////////////////////////////////////////////// +// SERVER +////////////////////////////////////////////////////////////////////////////// + +/// The [`Server`] is the main interface for handling communication +pub struct Server +where + Tx: WireTx, + Rx: WireRx, + Buf: DerefMut, + D: Dispatch, +{ + tx: Sender, + rx: Rx, + buf: Buf, + dis: D, +} + +/// A type representing the different errors [`Server::run()`] may return +pub enum ServerError +where + Tx: WireTx, + Rx: WireRx, +{ + /// A fatal error occurred with the [`WireTx::send()`] implementation + TxFatal(Tx::Error), + /// A fatal error occurred with the [`WireRx::receive()`] implementation + RxFatal(Rx::Error), +} + +impl Server +where + Tx: WireTx, + Rx: WireRx, + Buf: DerefMut, + D: Dispatch, +{ + /// Create a new Server + /// + /// Takes: + /// + /// * a [`WireTx`] impl for sending + /// * a [`WireRx`] impl for receiving + /// * a buffer used for receiving frames + /// * The user provided dispatching method, usually generated by [`define_dispatch!()`][crate::define_dispatch] + /// * a [`VarKeyKind`], which controls the key sizes sent by the [`WireTx`] impl + pub fn new(tx: &Tx, rx: Rx, buf: Buf, dis: D, kkind: VarKeyKind) -> Self { + Self { + tx: Sender { + tx: tx.clone(), + kkind, + }, + rx, + buf, + dis, + } + } + + /// Run until a fatal error occurs + /// + /// The server will receive frames, and dispatch them. When a fatal error occurs, + /// this method will return with the fatal error. + /// + /// The caller may decide to wait until a connection is re-established, reset any + /// state, or immediately begin re-running. + pub async fn run(&mut self) -> ServerError { + loop { + let Self { + tx, + rx, + buf, + dis: d, + } = self; + let used = match rx.receive(buf).await { + Ok(u) => u, + Err(e) => { + let kind = e.as_kind(); + match kind { + WireRxErrorKind::ConnectionClosed => return ServerError::RxFatal(e), + WireRxErrorKind::ReceivedMessageTooLarge => continue, + WireRxErrorKind::Other => continue, + } + } + }; + let Some((hdr, body)) = VarHeader::take_from_slice(used) else { + // TODO: send a nak on badly formed messages? We don't have + // much to say because we don't have a key or seq no or anything + continue; + }; + let fut = d.handle(tx, &hdr, body); + if let Err(e) = fut.await { + let kind = e.as_kind(); + match kind { + WireTxErrorKind::ConnectionClosed => return ServerError::TxFatal(e), + WireTxErrorKind::Other => {} + } + } + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// DISPATCH TRAIT +////////////////////////////////////////////////////////////////////////////// + +/// The dispatch trait handles an incoming endpoint or topic message +/// +/// The implementations of this trait are typically implemented by the +/// [`define_dispatch!`][crate::define_dispatch] macro. +pub trait Dispatch { + /// The [`WireTx`] impl used by this dispatcher + type Tx: WireTx; + + /// The minimum key length required to avoid hash collisions + fn min_key_len(&self) -> VarKeyKind; + + /// Handle a single incoming frame (endpoint or topic), and dispatch appropriately + async fn handle( + &mut self, + tx: &Sender, + hdr: &VarHeader, + body: &[u8], + ) -> Result<(), ::Error>; +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWNCONTEXT TRAIT +////////////////////////////////////////////////////////////////////////////// + +/// A conversion trait for taking the Context and making a SpawnContext +/// +/// This is necessary if you use the `spawn` variant of `define_dispatch!`. +pub trait SpawnContext { + /// The spawn context type + type SpawnCtxt: 'static; + /// A method to convert the regular context into [`Self::SpawnCtxt`] + fn spawn_ctxt(&mut self) -> Self::SpawnCtxt; +} + +// Hilarious quadruply nested loop. Hope our lists are relatively small! +macro_rules! keycheck { + ( + $lists:ident; + $($num:literal => $func:ident;)* + ) => { + $( + { + let mut i = 0; + let mut good = true; + // For each list... + 'dupe: while i < $lists.len() { + let ilist = $lists[i]; + let mut j = 0; + // And for each key in the list + while j < ilist.len() { + let jkey = ilist[j]; + let akey = $func(jkey); + + // + // We now start checking against items later in the lists... + // + + // For each list (starting with the one we are on) + let mut x = i; + while x < $lists.len() { + // For each item... + // + // Note that for the STARTING list we continue where we started, + // but on subsequent lists start from the beginning + let xlist = $lists[x]; + let mut y = if x == i { + j + 1 + } else { + 0 + }; + + while y < xlist.len() { + let ykey = xlist[y]; + let bkey = $func(ykey); + + if akey == bkey { + good = false; + break 'dupe; + } + y += 1; + } + x += 1; + } + j += 1; + } + i += 1; + } + if good { + return $num; + } + } + )* + }; +} + +/// Calculates at const time the minimum number of bytes (1, 2, 4, or 8) to avoid +/// hash collisions in the lists of keys provided. +/// +/// If there are any duplicates, this function will panic at compile time. Otherwise, +/// this function will return 1, 2, 4, or 8. +/// +/// This function takes a very dumb "brute force" approach, that is of the order +/// `O(4 * N^2 * M^2)`, where `N` is `lists.len()`, and `M` is the length of each +/// sub-list. It is not recommended to call this outside of const context. +pub const fn min_key_needed(lists: &[&[Key]]) -> usize { + const fn one(key: Key) -> u8 { + crate::Key1::from_key8(key).0 + } + const fn two(key: Key) -> u16 { + u16::from_le_bytes(crate::Key2::from_key8(key).0) + } + const fn four(key: Key) -> u32 { + u32::from_le_bytes(crate::Key4::from_key8(key).0) + } + const fn eight(key: Key) -> u64 { + u64::from_le_bytes(key.0) + } + + keycheck! { + lists; + 1 => one; + 2 => two; + 4 => four; + 8 => eight; + }; + + panic!("Collision requiring more than 8 bytes!"); +} + +#[cfg(test)] +mod test { + use crate::{server::min_key_needed, Key}; + + #[test] + fn min_test_1() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]) }, + ]]); + assert_eq!(1, MINA); + + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]) }], + ]); + assert_eq!(1, MINB); + } + + #[test] + fn min_test_2() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01]) }, + ]]); + assert_eq!(2, MINA); + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01]) }], + ]); + assert_eq!(2, MINB); + } + + #[test] + fn min_test_4() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01]) }, + ]]); + assert_eq!(4, MINA); + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01]) }], + ]); + assert_eq!(4, MINB); + } + + #[test] + fn min_test_8() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01]) }, + ]]); + assert_eq!(8, MINA); + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01]) }], + ]); + assert_eq!(8, MINB); + } +} diff --git a/source/postcard-rpc/src/target_server/buffers.rs b/source/postcard-rpc/src/target_server/buffers.rs deleted file mode 100644 index fa7cc70..0000000 --- a/source/postcard-rpc/src/target_server/buffers.rs +++ /dev/null @@ -1,36 +0,0 @@ -pub struct AllBuffers { - pub usb_device: UsbDeviceBuffers, - pub endpoint_out: [u8; EO], - pub tx_buf: [u8; TX], - pub rx_buf: [u8; RX], -} - -impl AllBuffers { - pub const fn new() -> Self { - Self { - usb_device: UsbDeviceBuffers::new(), - endpoint_out: [0u8; EO], - tx_buf: [0u8; TX], - rx_buf: [0u8; RX], - } - } -} - -/// Buffers used by the [`UsbDevice`][embassy_usb::UsbDevice] of `embassy-usb` -pub struct UsbDeviceBuffers { - pub config_descriptor: [u8; 256], - pub bos_descriptor: [u8; 256], - pub control_buf: [u8; 64], - pub msos_descriptor: [u8; 256], -} - -impl UsbDeviceBuffers { - pub const fn new() -> Self { - Self { - config_descriptor: [0u8; 256], - bos_descriptor: [0u8; 256], - msos_descriptor: [0u8; 256], - control_buf: [0u8; 64], - } - } -} diff --git a/source/postcard-rpc/src/target_server/dispatch_macro.rs b/source/postcard-rpc/src/target_server/dispatch_macro.rs deleted file mode 100644 index 0358876..0000000 --- a/source/postcard-rpc/src/target_server/dispatch_macro.rs +++ /dev/null @@ -1,472 +0,0 @@ -/// # Define Dispatch Macro -/// -/// ```rust -/// # use postcard_rpc::target_server::dispatch_macro::fake::*; -/// # use postcard_rpc::{endpoint, target_server::{sender::Sender, SpawnContext}, WireHeader, define_dispatch}; -/// # use postcard_schema::Schema; -/// # use embassy_usb_driver::{Bus, ControlPipe, EndpointIn, EndpointOut}; -/// # use serde::{Deserialize, Serialize}; -/// -/// pub struct DispatchCtx; -/// pub struct SpawnCtx; -/// -/// // This trait impl is necessary if you want to use the `spawn` variant, -/// // as spawned tasks must take ownership of any context they need. -/// impl SpawnContext for DispatchCtx { -/// type SpawnCtxt = SpawnCtx; -/// fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { -/// SpawnCtx -/// } -/// } -/// -/// define_dispatch! { -/// dispatcher: Dispatcher< -/// Mutex = FakeMutex, -/// Driver = FakeDriver, -/// Context = DispatchCtx, -/// >; -/// AlphaEndpoint => async alpha_handler, -/// BetaEndpoint => async beta_handler, -/// GammaEndpoint => async gamma_handler, -/// DeltaEndpoint => blocking delta_handler, -/// EpsilonEndpoint => spawn epsilon_handler_task, -/// } -/// -/// async fn alpha_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: AReq) -> AResp { -/// todo!() -/// } -/// -/// async fn beta_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: BReq) -> BResp { -/// todo!() -/// } -/// -/// async fn gamma_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: GReq) -> GResp { -/// todo!() -/// } -/// -/// fn delta_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: DReq) -> DResp { -/// todo!() -/// } -/// -/// #[embassy_executor::task] -/// async fn epsilon_handler_task(_c: SpawnCtx, _h: WireHeader, _b: EReq, _sender: Sender) { -/// todo!() -/// } -/// ``` - -#[macro_export] -macro_rules! define_dispatch { - // This is the "blocking execution" arm for defining an endpoint - (@arm blocking ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $dispatch:ident) => { - { - let reply = $handler($context, $header.clone(), $req); - if $dispatch.sender.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { - let err = $crate::standard_icd::WireError::SerFailed; - $dispatch.error($header.seq_no, err).await; - } - } - }; - // This is the "async execution" arm for defining an endpoint - (@arm async ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $dispatch:ident) => { - { - let reply = $handler($context, $header.clone(), $req).await; - if $dispatch.sender.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { - let err = $crate::standard_icd::WireError::SerFailed; - $dispatch.error($header.seq_no, err).await; - } - } - }; - // This is the "spawn an embassy task" arm for defining an endpoint - (@arm spawn ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $dispatch:ident) => { - { - let spawner = ::embassy_executor::Spawner::for_current_executor().await; - let context = $crate::target_server::SpawnContext::spawn_ctxt($context); - if spawner.spawn($handler(context, $header.clone(), $req, $dispatch.sender())).is_err() { - let err = $crate::standard_icd::WireError::FailedToSpawn; - $dispatch.error($header.seq_no, err).await; - } - } - }; - // Optional trailing comma lol - ( - dispatcher: $name:ident; - $($endpoint:ty => $flavor:tt $handler:ident,)* - ) => { - define_dispatch! { - dispatcher: $name; - $( - $endpoint => $flavor $handler, - )* - } - }; - // This is the main entrypoint - ( - dispatcher: $name:ident; - $($endpoint:ty => $flavor:tt $handler:ident,)* - ) => { - /// This is a structure that handles dispatching, generated by the - /// `postcard-rpc::define_dispatch!()` macro. - pub struct $name { - pub sender: $crate::target_server::sender::Sender<$mutex, $driver>, - pub context: $context, - } - - impl $name { - /// Create a new instance of the dispatcher - pub fn new( - tx_buf: &'static mut [u8], - ep_in: <$driver as ::embassy_usb::driver::Driver<'static>>::EndpointIn, - context: $context, - ) -> Self { - static SENDER_INNER: ::static_cell::StaticCell< - ::embassy_sync::mutex::Mutex<$mutex, $crate::target_server::sender::SenderInner<$driver>>, - > = ::static_cell::StaticCell::new(); - $name { - sender: $crate::target_server::sender::Sender::init_sender(&SENDER_INNER, tx_buf, ep_in), - context, - } - } - } - - impl $crate::target_server::Dispatch for $name { - type Mutex = $mutex; - type Driver = $driver; - - /// Handle dispatching of a single frame - async fn dispatch( - &mut self, - hdr: $crate::WireHeader, - body: &[u8], - ) { - const _REQ_KEYS_MUST_BE_UNIQUE: () = { - let keys = [$(<$endpoint as $crate::Endpoint>::REQ_KEY),*]; - - let mut i = 0; - - while i < keys.len() { - let mut j = i + 1; - while j < keys.len() { - if keys[i].const_cmp(&keys[j]) { - panic!("Keys are not unique, there is a collision!"); - } - j += 1; - } - - i += 1; - } - }; - - let _ = _REQ_KEYS_MUST_BE_UNIQUE; - - match hdr.key { - $( - <$endpoint as $crate::Endpoint>::REQ_KEY => { - // Can we deserialize the request? - let Ok(req) = postcard::from_bytes::<<$endpoint as $crate::Endpoint>::Request>(body) else { - let err = $crate::standard_icd::WireError::DeserFailed; - self.error(hdr.seq_no, err).await; - return; - }; - - // Store some items as named bindings, so we can use `ident` in the - // recursive macro expansion. Load bearing order: we borrow `context` - // from `dispatch` because we need `dispatch` AFTER `context`, so NLL - // allows this to still borrowck - let dispatch = self; - let context = &mut dispatch.context; - - // This will expand to the right "flavor" of handler - define_dispatch!(@arm $flavor ($endpoint) $handler context hdr req dispatch); - } - )* - other => { - // huh! We have no idea what this key is supposed to be! - let err = $crate::standard_icd::WireError::UnknownKey(other.to_bytes()); - self.error(hdr.seq_no, err).await; - return; - }, - } - } - - /// Send a single error message - async fn error( - &self, - seq_no: u32, - error: $crate::standard_icd::WireError, - ) { - // If we get an error while sending an error, welp there's not much we can do - let _ = self.sender.reply_keyed(seq_no, $crate::standard_icd::ERROR_KEY, &error).await; - } - - /// Get a clone of the Sender of this Dispatch impl - fn sender(&self) -> $crate::target_server::sender::Sender { - self.sender.clone() - } - } - - } -} - -/// This is a basic example that everything compiles. It is intended to exercise the macro above, -/// as well as provide impls for docs. Don't rely on any of this! -#[doc(hidden)] -#[allow(dead_code)] -#[cfg(feature = "test-utils")] -pub mod fake { - use crate::target_server::SpawnContext; - #[allow(unused_imports)] - use crate::{endpoint, target_server::sender::Sender, Schema, WireHeader}; - use embassy_usb_driver::{Bus, ControlPipe, EndpointIn, EndpointOut}; - use serde::{Deserialize, Serialize}; - - #[derive(Serialize, Deserialize, Schema)] - pub struct AReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct AResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct BReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct BResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct GReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct GResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct DReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct DResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct EReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct EResp; - - endpoint!(AlphaEndpoint, AReq, AResp, "alpha"); - endpoint!(BetaEndpoint, BReq, BResp, "beta"); - endpoint!(GammaEndpoint, GReq, GResp, "gamma"); - endpoint!(DeltaEndpoint, DReq, DResp, "delta"); - endpoint!(EpsilonEndpoint, EReq, EResp, "epsilon"); - - pub struct FakeMutex; - pub struct FakeDriver; - pub struct FakeEpOut; - pub struct FakeEpIn; - pub struct FakeCtlPipe; - pub struct FakeBus; - - impl embassy_usb_driver::Endpoint for FakeEpOut { - fn info(&self) -> &embassy_usb_driver::EndpointInfo { - todo!() - } - - async fn wait_enabled(&mut self) { - todo!() - } - } - - impl EndpointOut for FakeEpOut { - async fn read( - &mut self, - _buf: &mut [u8], - ) -> Result { - todo!() - } - } - - impl embassy_usb_driver::Endpoint for FakeEpIn { - fn info(&self) -> &embassy_usb_driver::EndpointInfo { - todo!() - } - - async fn wait_enabled(&mut self) { - todo!() - } - } - - impl EndpointIn for FakeEpIn { - async fn write(&mut self, _buf: &[u8]) -> Result<(), embassy_usb_driver::EndpointError> { - todo!() - } - } - - impl ControlPipe for FakeCtlPipe { - fn max_packet_size(&self) -> usize { - todo!() - } - - async fn setup(&mut self) -> [u8; 8] { - todo!() - } - - async fn data_out( - &mut self, - _buf: &mut [u8], - _first: bool, - _last: bool, - ) -> Result { - todo!() - } - - async fn data_in( - &mut self, - _data: &[u8], - _first: bool, - _last: bool, - ) -> Result<(), embassy_usb_driver::EndpointError> { - todo!() - } - - async fn accept(&mut self) { - todo!() - } - - async fn reject(&mut self) { - todo!() - } - - async fn accept_set_address(&mut self, _addr: u8) { - todo!() - } - } - - impl Bus for FakeBus { - async fn enable(&mut self) { - todo!() - } - - async fn disable(&mut self) { - todo!() - } - - async fn poll(&mut self) -> embassy_usb_driver::Event { - todo!() - } - - fn endpoint_set_enabled( - &mut self, - _ep_addr: embassy_usb_driver::EndpointAddress, - _enabled: bool, - ) { - todo!() - } - - fn endpoint_set_stalled( - &mut self, - _ep_addr: embassy_usb_driver::EndpointAddress, - _stalled: bool, - ) { - todo!() - } - - fn endpoint_is_stalled(&mut self, _ep_addr: embassy_usb_driver::EndpointAddress) -> bool { - todo!() - } - - async fn remote_wakeup(&mut self) -> Result<(), embassy_usb_driver::Unsupported> { - todo!() - } - } - - impl embassy_usb_driver::Driver<'static> for FakeDriver { - type EndpointOut = FakeEpOut; - - type EndpointIn = FakeEpIn; - - type ControlPipe = FakeCtlPipe; - - type Bus = FakeBus; - - fn alloc_endpoint_out( - &mut self, - _ep_type: embassy_usb_driver::EndpointType, - _max_packet_size: u16, - _interval_ms: u8, - ) -> Result { - todo!() - } - - fn alloc_endpoint_in( - &mut self, - _ep_type: embassy_usb_driver::EndpointType, - _max_packet_size: u16, - _interval_ms: u8, - ) -> Result { - todo!() - } - - fn start(self, _control_max_packet_size: u16) -> (Self::Bus, Self::ControlPipe) { - todo!() - } - } - - unsafe impl embassy_sync::blocking_mutex::raw::RawMutex for FakeMutex { - const INIT: Self = Self; - - fn lock(&self, _f: impl FnOnce() -> R) -> R { - todo!() - } - } - - pub struct TestContext { - pub a: u32, - pub b: u32, - } - - impl SpawnContext for TestContext { - type SpawnCtxt = TestSpawnContext; - - fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { - TestSpawnContext { b: self.b } - } - } - - pub struct TestSpawnContext { - b: u32, - } - - define_dispatch! { - dispatcher: TestDispatcher; - AlphaEndpoint => async test_alpha_handler, - BetaEndpoint => async test_beta_handler, - GammaEndpoint => async test_gamma_handler, - DeltaEndpoint => blocking test_delta_handler, - EpsilonEndpoint => spawn test_epsilon_handler_task, - } - - async fn test_alpha_handler( - _context: &mut TestContext, - _header: WireHeader, - _body: AReq, - ) -> AResp { - todo!() - } - - async fn test_beta_handler( - _context: &mut TestContext, - _header: WireHeader, - _body: BReq, - ) -> BResp { - todo!() - } - - async fn test_gamma_handler( - _context: &mut TestContext, - _header: WireHeader, - _body: GReq, - ) -> GResp { - todo!() - } - - fn test_delta_handler(_context: &mut TestContext, _header: WireHeader, _body: DReq) -> DResp { - todo!() - } - - #[embassy_executor::task] - async fn test_epsilon_handler_task( - _context: TestSpawnContext, - _header: WireHeader, - _body: EReq, - _sender: Sender, - ) { - todo!() - } -} diff --git a/source/postcard-rpc/src/target_server/mod.rs b/source/postcard-rpc/src/target_server/mod.rs deleted file mode 100644 index cfa1707..0000000 --- a/source/postcard-rpc/src/target_server/mod.rs +++ /dev/null @@ -1,195 +0,0 @@ -#![allow(async_fn_in_trait)] - -use crate::{ - headered::extract_header_from_bytes, - standard_icd::{FrameTooLong, FrameTooShort, WireError}, - WireHeader, -}; -use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_usb::{ - driver::Driver, - msos::{self, windows_version}, - Builder, UsbDevice, -}; -use embassy_usb_driver::{Endpoint, EndpointError, EndpointOut}; -use sender::Sender; - -pub mod buffers; -pub mod dispatch_macro; -pub mod sender; - -const DEVICE_INTERFACE_GUIDS: &[&str] = &["{AFB9A6FB-30BA-44BC-9232-806CFC875321}"]; - -/// A trait that defines the postcard-rpc message dispatching behavior -/// -/// This is normally generated automatically by the [`define_dispatch!()`][crate::define_dispatch] -/// macro. -pub trait Dispatch { - type Mutex: RawMutex; - type Driver: Driver<'static>; - - /// Handle a single message, with the header deserialized and the - /// body not yet deserialized. - /// - /// This function must handle replying (either immediately or - /// in the future, for example if spawning a task) - async fn dispatch(&mut self, hdr: WireHeader, body: &[u8]); - - /// Send an error message, of the path and key defined for - /// the connection - async fn error(&self, seq_no: u32, error: WireError); - - /// Obtain an owned sender - fn sender(&self) -> Sender; -} - -/// A conversion trait for taking the Context and making a SpawnContext -/// -/// This is necessary if you use the `spawn` variant of `define_dispatch!`. -pub trait SpawnContext { - type SpawnCtxt: 'static; - fn spawn_ctxt(&mut self) -> Self::SpawnCtxt; -} - -/// A basic example of embassy_usb configuration values -pub fn example_config() -> embassy_usb::Config<'static> { - // Create embassy-usb Config - let mut config = embassy_usb::Config::new(0x16c0, 0x27DD); - config.manufacturer = Some("Embassy"); - config.product = Some("postcard-rpc example"); - config.serial_number = Some("12345678"); - - // Required for windows compatibility. - // https://developer.nordicsemi.com/nRF_Connect_SDK/doc/1.9.1/kconfig/CONFIG_CDC_ACM_IAD.html#help - config.device_class = 0xEF; - config.device_sub_class = 0x02; - config.device_protocol = 0x01; - config.composite_with_iads = true; - - config -} - -/// Configure the USB driver for use with postcard-rpc -/// -/// At the moment this is very geared towards USB FS. -pub fn configure_usb>( - driver: D, - bufs: &'static mut buffers::UsbDeviceBuffers, - config: embassy_usb::Config<'static>, -) -> (UsbDevice<'static, D>, D::EndpointIn, D::EndpointOut) { - let mut builder = Builder::new( - driver, - config, - &mut bufs.config_descriptor, - &mut bufs.bos_descriptor, - &mut bufs.msos_descriptor, - &mut bufs.control_buf, - ); - - // Add the Microsoft OS Descriptor (MSOS/MOD) descriptor. - // We tell Windows that this entire device is compatible with the "WINUSB" feature, - // which causes it to use the built-in WinUSB driver automatically, which in turn - // can be used by libusb/rusb software without needing a custom driver or INF file. - // In principle you might want to call msos_feature() just on a specific function, - // if your device also has other functions that still use standard class drivers. - builder.msos_descriptor(windows_version::WIN8_1, 0); - builder.msos_feature(msos::CompatibleIdFeatureDescriptor::new("WINUSB", "")); - builder.msos_feature(msos::RegistryPropertyFeatureDescriptor::new( - "DeviceInterfaceGUIDs", - msos::PropertyData::RegMultiSz(DEVICE_INTERFACE_GUIDS), - )); - - // Add a vendor-specific function (class 0xFF), and corresponding interface, - // that uses our custom handler. - let mut function = builder.function(0xFF, 0, 0); - let mut interface = function.interface(); - let mut alt = interface.alt_setting(0xFF, 0, 0, None); - let ep_out = alt.endpoint_bulk_out(64); - let ep_in = alt.endpoint_bulk_in(64); - drop(function); - - // Build the builder. - let usb = builder.build(); - - (usb, ep_in, ep_out) -} - -/// Handle RPC Dispatching -pub async fn rpc_dispatch( - mut ep_out: D::EndpointOut, - mut dispatch: T, - rx_buf: &'static mut [u8], -) -> ! -where - M: RawMutex + 'static, - D: Driver<'static> + 'static, - T: Dispatch, -{ - 'connect: loop { - // Wait for connection - ep_out.wait_enabled().await; - - // For each packet... - 'packet: loop { - // Accumulate a whole frame - let mut window = &mut rx_buf[..]; - 'buffer: loop { - if window.is_empty() { - #[cfg(feature = "defmt")] - defmt::warn!("Overflow!"); - let mut bonus: usize = 0; - loop { - // Just drain until the end of the overflow frame - match ep_out.read(rx_buf).await { - Ok(n) if n < 64 => { - bonus = bonus.saturating_add(n); - let err = WireError::FrameTooLong(FrameTooLong { - len: u32::try_from(bonus.saturating_add(rx_buf.len())) - .unwrap_or(u32::MAX), - max: u32::try_from(rx_buf.len()).unwrap_or(u32::MAX), - }); - dispatch.error(0, err).await; - continue 'packet; - } - Ok(n) => { - bonus = bonus.saturating_add(n); - } - Err(EndpointError::BufferOverflow) => panic!(), - Err(EndpointError::Disabled) => continue 'connect, - }; - } - } - - let n = match ep_out.read(window).await { - Ok(n) => n, - Err(EndpointError::BufferOverflow) => panic!(), - Err(EndpointError::Disabled) => continue 'connect, - }; - - let (_now, later) = window.split_at_mut(n); - window = later; - if n != 64 { - break 'buffer; - } - } - - // We now have a full frame! Great! - let wlen = window.len(); - let len = rx_buf.len() - wlen; - let frame = &rx_buf[..len]; - - #[cfg(feature = "defmt")] - defmt::debug!("got frame: {=usize}", frame.len()); - - // If it's for us, process it - if let Ok((hdr, body)) = extract_header_from_bytes(frame) { - dispatch.dispatch(hdr, body).await; - } else { - let err = WireError::FrameTooShort(FrameTooShort { - len: u32::try_from(frame.len()).unwrap_or(u32::MAX), - }); - dispatch.error(0, err).await; - } - } - } -} diff --git a/source/postcard-rpc/src/target_server/sender.rs b/source/postcard-rpc/src/target_server/sender.rs deleted file mode 100644 index 87bae4c..0000000 --- a/source/postcard-rpc/src/target_server/sender.rs +++ /dev/null @@ -1,319 +0,0 @@ -use embassy_sync::{blocking_mutex::raw::RawMutex, mutex::Mutex}; -use embassy_usb_driver::{Driver, Endpoint, EndpointIn}; -use futures_util::FutureExt; -use postcard_schema::Schema; -use serde::Serialize; -use static_cell::StaticCell; - -use crate::Key; - -/// This is the interface for sending information to the client. -/// -/// This is normally used by postcard-rpc itself, as well as for cases where -/// you have to manually send data, like publishing on a topic or delayed -/// replies (e.g. when spawning a task). -#[derive(Copy)] -pub struct Sender + 'static> { - inner: &'static Mutex>, -} - -impl + 'static> Sender { - /// Initialize the Sender, giving it the pieces it needs - /// - /// Panics if called more than once. - pub fn init_sender( - sc: &'static StaticCell>>, - tx_buf: &'static mut [u8], - ep_in: D::EndpointIn, - ) -> Self { - let max_log_len = actual_varint_max_len(tx_buf.len()); - let x = sc.init(Mutex::new(SenderInner { - ep_in, - tx_buf, - log_seq: 0, - max_log_len, - })); - Sender { inner: x } - } - - /// Send a reply for the given endpoint - #[inline] - pub async fn reply(&self, seq_no: u32, resp: &E::Response) -> Result<(), ()> - where - E: crate::Endpoint, - E::Response: Serialize + Schema, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq: _, - max_log_len: _, - } = &mut *inner; - if let Ok(used) = crate::headered::to_slice_keyed(seq_no, E::RESP_KEY, resp, tx_buf) { - send_all::(ep_in, used, true).await - } else { - Err(()) - } - } - - /// Send a reply with the given Key - /// - /// This is useful when replying with "unusual" keys, for example Error responses - /// not tied to any specific Endpoint. - #[inline] - pub async fn reply_keyed(&self, seq_no: u32, key: Key, resp: &T) -> Result<(), ()> - where - T: Serialize + Schema, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq: _, - max_log_len: _, - } = &mut *inner; - if let Ok(used) = crate::headered::to_slice_keyed(seq_no, key, resp, tx_buf) { - send_all::(ep_in, used, true).await - } else { - Err(()) - } - } - - /// Publish a Topic message - #[inline] - pub async fn publish(&self, seq_no: u32, msg: &T::Message) -> Result<(), ()> - where - T: crate::Topic, - T::Message: Serialize + Schema, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq: _, - max_log_len: _, - } = &mut *inner; - - if let Ok(used) = crate::headered::to_slice_keyed(seq_no, T::TOPIC_KEY, msg, tx_buf) { - send_all::(ep_in, used, true).await - } else { - Err(()) - } - } - - pub async fn str_publish<'a, T>(&self, s: &'a str) - where - T: crate::Topic, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq, - max_log_len: _, - } = &mut *inner; - let seq_no = *log_seq; - *log_seq = log_seq.wrapping_add(1); - if let Ok(used) = - crate::headered::to_slice_keyed(seq_no, T::TOPIC_KEY, s.as_bytes(), tx_buf) - { - let _ = send_all::(ep_in, used, false).await; - } - } - - pub async fn fmt_publish<'a, T>(&self, args: core::fmt::Arguments<'a>) - where - T: crate::Topic, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq, - max_log_len, - } = &mut *inner; - let ttl_len = tx_buf.len(); - - // First, populate the header - let hdr = crate::WireHeader { - key: T::TOPIC_KEY, - seq_no: *log_seq, - }; - *log_seq = log_seq.wrapping_add(1); - let Ok(hdr_used) = postcard::to_slice(&hdr, tx_buf) else { - return; - }; - let hdr_used = hdr_used.len(); - - // Then, reserve space for non-canonical length fields - // We also set all but the last bytes to be "continuation" - // bytes - let (_, remaining) = tx_buf.split_at_mut(hdr_used); - if remaining.len() < *max_log_len { - return; - } - let (len_field, body) = remaining.split_at_mut(*max_log_len); - for b in len_field.iter_mut() { - *b = 0x80; - } - len_field.last_mut().map(|b| *b = 0x00); - - // Then, do the formatting - let body_len = body.len(); - let mut sw = SliceWriter(body); - let res = core::fmt::write(&mut sw, args); - - // Calculate the number of bytes used *for formatting*. - let remain = sw.0.len(); - let used = body_len - remain; - - // If we had an error, that's probably because we ran out - // of room. If we had an error, AND there is at least three - // bytes, then replace those with '.'s like ... - if res.is_err() && (body.len() >= 3) { - let start = body.len() - 3; - body[start..].iter_mut().for_each(|b| *b = b'.'); - } - - // then go back and fill in the len - we write the len - // directly to the reserved bytes, and if we DIDN'T use - // the full space, we mark the end of the real length as - // a continuation field. This will result in a non-canonical - // "extended" length in postcard, and will "spill into" the - // bytes we wrote previously above - let mut len_bytes = [0u8; varint_max::()]; - let len_used = varint_usize(used, &mut len_bytes); - if len_used.len() != len_field.len() { - len_used.last_mut().map(|b| *b = *b | 0x80); - } - len_field[..len_used.len()].copy_from_slice(len_used); - - // Calculate the TOTAL amount - let act_used = ttl_len - remain; - - let _ = send_all::(ep_in, &tx_buf[..act_used], false).await; - } -} - -impl + 'static> Clone for Sender { - fn clone(&self) -> Self { - Sender { inner: self.inner } - } -} - -/// Implementation detail, holding the endpoint and scratch buffer used for sending -pub struct SenderInner> { - ep_in: D::EndpointIn, - tx_buf: &'static mut [u8], - log_seq: u32, - max_log_len: usize, -} - -/// Helper function for sending a single frame. -/// -/// If an empty slice is provided, no bytes will be sent. -#[inline] -async fn send_all( - ep_in: &mut D::EndpointIn, - out: &[u8], - wait_for_enabled: bool, -) -> Result<(), ()> -where - D: Driver<'static>, -{ - if out.is_empty() { - return Ok(()); - } - if wait_for_enabled { - ep_in.wait_enabled().await; - } else if ep_in.wait_enabled().now_or_never().is_none() { - return Ok(()); - } - - // write in segments of 64. The last chunk may - // be 0 < len <= 64. - for ch in out.chunks(64) { - if ep_in.write(ch).await.is_err() { - return Err(()); - } - } - // If the total we sent was a multiple of 64, send an - // empty message to "flush" the transaction. We already checked - // above that the len != 0. - if (out.len() & (64 - 1)) == 0 && ep_in.write(&[]).await.is_err() { - return Err(()); - } - - Ok(()) -} - -struct SliceWriter<'a>(&'a mut [u8]); - -impl<'a> core::fmt::Write for SliceWriter<'a> { - fn write_str(&mut self, s: &str) -> Result<(), core::fmt::Error> { - let sli = core::mem::take(&mut self.0); - - // If this write would overflow us, note that, but still take - // as much as we possibly can here - let bad = s.len() > sli.len(); - let to_write = s.len().min(sli.len()); - let (now, later) = sli.split_at_mut(to_write); - now.copy_from_slice(s.as_bytes()); - self.0 = later; - - // Now, report whether we overflowed or not - if bad { - Err(core::fmt::Error) - } else { - Ok(()) - } - } -} - -/// Returns the maximum number of bytes required to encode T. -const fn varint_max() -> usize { - const BITS_PER_BYTE: usize = 8; - const BITS_PER_VARINT_BYTE: usize = 7; - - // How many data bits do we need for this type? - let bits = core::mem::size_of::() * BITS_PER_BYTE; - - // We add (BITS_PER_VARINT_BYTE - 1), to ensure any integer divisions - // with a remainder will always add exactly one full byte, but - // an evenly divided number of bits will be the same - let roundup_bits = bits + (BITS_PER_VARINT_BYTE - 1); - - // Apply division, using normal "round down" integer division - roundup_bits / BITS_PER_VARINT_BYTE -} - -#[inline] -fn varint_usize(n: usize, out: &mut [u8; varint_max::()]) -> &mut [u8] { - let mut value = n; - for i in 0..varint_max::() { - out[i] = value.to_le_bytes()[0]; - if value < 128 { - return &mut out[..=i]; - } - - out[i] |= 0x80; - value >>= 7; - } - debug_assert_eq!(value, 0); - &mut out[..] -} - -fn actual_varint_max_len(largest: usize) -> usize { - if largest < (2 << 7) { - 1 - } else if largest < (2 << 14) { - 2 - } else if largest < (2 << 21) { - 3 - } else if largest < (2 << 28) { - 4 - } else { - varint_max::() - } -} diff --git a/source/postcard-rpc/src/test_utils.rs b/source/postcard-rpc/src/test_utils.rs index f5b3901..d51c7db 100644 --- a/source/postcard-rpc/src/test_utils.rs +++ b/source/postcard-rpc/src/test_utils.rs @@ -2,11 +2,11 @@ use core::{fmt::Display, future::Future}; +use crate::header::{VarHeader, VarKey, VarSeq, VarSeqKind}; use crate::host_client::util::Stopper; use crate::{ - headered::extract_header_from_bytes, host_client::{HostClient, RpcFrame, WireRx, WireSpawn, WireTx}, - Endpoint, Topic, WireHeader, + Endpoint, Topic, }; use postcard_schema::Schema; use serde::{de::DeserializeOwned, Serialize}; @@ -15,25 +15,32 @@ use tokio::{ sync::mpsc::{channel, Receiver, Sender}, }; +/// Rx Helper type pub struct LocalRx { fake_error: Stopper, from_server: Receiver>, } +/// Tx Helper type pub struct LocalTx { fake_error: Stopper, to_server: Sender>, } +/// Spawn helper type pub struct LocalSpawn; +/// Server type pub struct LocalFakeServer { fake_error: Stopper, + /// from client to server pub from_client: Receiver>, + /// from server to client pub to_client: Sender>, } impl LocalFakeServer { + /// receive a frame pub async fn recv_from_client(&mut self) -> Result { let msg = self.from_client.recv().await.ok_or(LocalError::TxClosed)?; - let Ok((hdr, body)) = extract_header_from_bytes(&msg) else { + let Some((hdr, body)) = VarHeader::take_from_slice(&msg) else { return Err(LocalError::BadFrame); }; Ok(RpcFrame { @@ -42,6 +49,7 @@ impl LocalFakeServer { }) } + /// Reply pub async fn reply( &mut self, seq_no: u32, @@ -51,9 +59,9 @@ impl LocalFakeServer { E::Response: Serialize, { let frame = RpcFrame { - header: WireHeader { - key: E::RESP_KEY, - seq_no, + header: VarHeader { + key: VarKey::Key8(E::RESP_KEY), + seq_no: VarSeq::Seq4(seq_no), }, body: postcard::to_stdvec(data).unwrap(), }; @@ -63,6 +71,7 @@ impl LocalFakeServer { .map_err(|_| LocalError::RxClosed) } + /// Publish pub async fn publish( &mut self, seq_no: u32, @@ -72,9 +81,9 @@ impl LocalFakeServer { T::Message: Serialize, { let frame = RpcFrame { - header: WireHeader { - key: T::TOPIC_KEY, - seq_no, + header: VarHeader { + key: VarKey::Key8(T::TOPIC_KEY), + seq_no: VarSeq::Seq4(seq_no), }, body: postcard::to_stdvec(data).unwrap(), }; @@ -84,16 +93,22 @@ impl LocalFakeServer { .map_err(|_| LocalError::RxClosed) } + /// oops pub fn cause_fatal_error(&self) { self.fake_error.stop(); } } +/// Local error type #[derive(Debug, PartialEq)] pub enum LocalError { + /// RxClosed RxClosed, + /// TxClosed TxClosed, + /// BadFrame BadFrame, + /// FatalError FatalError, } @@ -187,6 +202,7 @@ where fake_error: fake_error.clone(), }, LocalSpawn, + VarSeqKind::Seq2, err_uri_path, bound, ); diff --git a/source/postcard-rpc/src/uniques.rs b/source/postcard-rpc/src/uniques.rs new file mode 100644 index 0000000..aa7f90f --- /dev/null +++ b/source/postcard-rpc/src/uniques.rs @@ -0,0 +1,944 @@ +//! Create unique type lists at compile time +//! +//! This is an excercise in the capabilities of macros and const fns. +//! +//! From a very high level, the process goes like this: +//! +//! 1. We recursively look at a type, counting how many types it contains, +//! WITHOUT considering de-duplication. This is used as an "upper bound" +//! of the number of potential types we could have to report +//! 2. Create an array of `[Option<&NamedType>; MAX]` that we use something +//! like an append-only vec. +//! 3. Recursively traverse the type AGAIN, this time collecting all unique +//! non-primitive types we encounter, and adding them to the list. This +//! is outrageously inefficient, but it is done at const time with all +//! the restrictions it entails, because we don't pay at runtime. +//! 4. Record how many types we ACTUALLY collected in step 3, and create a +//! new array, `[&NamedType; ACTUAL]`, and copy the unique types into +//! this new array +//! 5. Convert this `[&NamedType; N]` array into a `&'static [&NamedType]` +//! array to make it possible to handle with multiple types +//! 6. If we are collecting MULTIPLE types into a single aggregate report, +//! then we make a new array of `[Option<&NamedType>; sum(all types)]`, +//! by calculating the sum of types contained for each list calculated +//! in step 4. +//! 7. We then perform the same "merging" process from 3, pushing any unique +//! type we find into the aggregate list, and recording the number of +//! unique types we found in the entire set. +//! 8. We then perform the same "shrinking" process from step 4, leaving us +//! with a single array, `[&NamedType; TOTAL]` containing all unique types +//! 9. We then perform the same "slicing" process from step 5, to get our +//! final `&'static [&NamedType]`. + +use postcard_schema::{ + schema::{DataModelType, DataModelVariant, NamedType, NamedValue, NamedVariant}, + Schema, +}; + +////////////////////////////////////////////////////////////////////////////// +// STAGE 0 - HELPERS +////////////////////////////////////////////////////////////////////////////// + +/// `is_prim` returns whether the type is a *primitive*, or a built-in type that +/// does not need to be sent over the wire. +const fn is_prim(dmt: &DataModelType) -> bool { + match dmt { + // These are all primitives + DataModelType::Bool => true, + DataModelType::I8 => true, + DataModelType::U8 => true, + DataModelType::I16 => true, + DataModelType::I32 => true, + DataModelType::I64 => true, + DataModelType::I128 => true, + DataModelType::U16 => true, + DataModelType::U32 => true, + DataModelType::U64 => true, + DataModelType::U128 => true, + DataModelType::Usize => true, + DataModelType::Isize => true, + DataModelType::F32 => true, + DataModelType::F64 => true, + DataModelType::Char => true, + DataModelType::String => true, + DataModelType::ByteArray => true, + DataModelType::Unit => true, + DataModelType::Schema => true, + + // A unit-struct is always named, so it is not primitive, as the + // name has meaning even without a value + DataModelType::UnitStruct => false, + // Items with subtypes are composite, and therefore not primitives, as + // we need to convey this information. + DataModelType::Option(_) | DataModelType::NewtypeStruct(_) | DataModelType::Seq(_) => false, + DataModelType::Tuple(_) | DataModelType::TupleStruct(_) => false, + DataModelType::Map { .. } => false, + DataModelType::Struct(_) => false, + DataModelType::Enum(_) => false, + } +} + +/// A const version of `::eq` +const fn str_eq(a: &str, b: &str) -> bool { + let mut i = 0; + if a.len() != b.len() { + return false; + } + let a_by = a.as_bytes(); + let b_by = b.as_bytes(); + while i < a.len() { + if a_by[i] != b_by[i] { + return false; + } + i += 1; + } + true +} + +/// A const version of `::eq` +const fn nty_eq(a: &NamedType, b: &NamedType) -> bool { + str_eq(a.name, b.name) && dmt_eq(a.ty, b.ty) +} + +/// A const version of `<[&NamedType] as PartialEq>::eq` +const fn ntys_eq(a: &[&NamedType], b: &[&NamedType]) -> bool { + if a.len() != b.len() { + return false; + } + let mut i = 0; + while i < a.len() { + if !nty_eq(a[i], b[i]) { + return false; + } + i += 1; + } + true +} + +/// A const version of `::eq` +const fn dmt_eq(a: &DataModelType, b: &DataModelType) -> bool { + match (a, b) { + // Data model types are ONLY matching if they are both the same variant + // + // For primitives (and unit structs), we only check the discriminant matches. + (DataModelType::Bool, DataModelType::Bool) => true, + (DataModelType::I8, DataModelType::I8) => true, + (DataModelType::U8, DataModelType::U8) => true, + (DataModelType::I16, DataModelType::I16) => true, + (DataModelType::I32, DataModelType::I32) => true, + (DataModelType::I64, DataModelType::I64) => true, + (DataModelType::I128, DataModelType::I128) => true, + (DataModelType::U16, DataModelType::U16) => true, + (DataModelType::U32, DataModelType::U32) => true, + (DataModelType::U64, DataModelType::U64) => true, + (DataModelType::U128, DataModelType::U128) => true, + (DataModelType::Usize, DataModelType::Usize) => true, + (DataModelType::Isize, DataModelType::Isize) => true, + (DataModelType::F32, DataModelType::F32) => true, + (DataModelType::F64, DataModelType::F64) => true, + (DataModelType::Char, DataModelType::Char) => true, + (DataModelType::String, DataModelType::String) => true, + (DataModelType::ByteArray, DataModelType::ByteArray) => true, + (DataModelType::Unit, DataModelType::Unit) => true, + (DataModelType::UnitStruct, DataModelType::UnitStruct) => true, + (DataModelType::Schema, DataModelType::Schema) => true, + + // For non-primitive types, we check whether all children are equivalent as well. + (DataModelType::Option(nta), DataModelType::Option(ntb)) => nty_eq(nta, ntb), + (DataModelType::NewtypeStruct(nta), DataModelType::NewtypeStruct(ntb)) => nty_eq(nta, ntb), + (DataModelType::Seq(nta), DataModelType::Seq(ntb)) => nty_eq(nta, ntb), + + (DataModelType::Tuple(ntsa), DataModelType::Tuple(ntsb)) => ntys_eq(ntsa, ntsb), + (DataModelType::TupleStruct(ntsa), DataModelType::TupleStruct(ntsb)) => ntys_eq(ntsa, ntsb), + ( + DataModelType::Map { + key: keya, + val: vala, + }, + DataModelType::Map { + key: keyb, + val: valb, + }, + ) => nty_eq(keya, keyb) && nty_eq(vala, valb), + (DataModelType::Struct(nvalsa), DataModelType::Struct(nvalsb)) => vals_eq(nvalsa, nvalsb), + (DataModelType::Enum(nvarsa), DataModelType::Enum(nvarsb)) => vars_eq(nvarsa, nvarsb), + + // Any mismatches are not equal + _ => false, + } +} + +/// A const version of `::eq` +const fn var_eq(a: &NamedVariant, b: &NamedVariant) -> bool { + str_eq(a.name, b.name) && dmv_eq(a.ty, b.ty) +} + +/// A const version of `<&[&NamedVariant] as PartialEq>::eq` +const fn vars_eq(a: &[&NamedVariant], b: &[&NamedVariant]) -> bool { + if a.len() != b.len() { + return false; + } + let mut i = 0; + while i < a.len() { + if !var_eq(a[i], b[i]) { + return false; + } + i += 1; + } + true +} + +/// A const version of `<&[&NamedValue] as PartialEq>::eq` +const fn vals_eq(a: &[&NamedValue], b: &[&NamedValue]) -> bool { + if a.len() != b.len() { + return false; + } + let mut i = 0; + while i < a.len() { + if !str_eq(a[i].name, b[i].name) { + return false; + } + if !nty_eq(a[i].ty, b[i].ty) { + return false; + } + + i += 1; + } + true +} + +/// A const version of `::eq` +const fn dmv_eq(a: &DataModelVariant, b: &DataModelVariant) -> bool { + match (a, b) { + (DataModelVariant::UnitVariant, DataModelVariant::UnitVariant) => true, + (DataModelVariant::NewtypeVariant(nta), DataModelVariant::NewtypeVariant(ntb)) => { + nty_eq(nta, ntb) + } + (DataModelVariant::TupleVariant(ntsa), DataModelVariant::TupleVariant(ntsb)) => { + ntys_eq(ntsa, ntsb) + } + (DataModelVariant::StructVariant(nvarsa), DataModelVariant::StructVariant(nvarsb)) => { + vals_eq(nvarsa, nvarsb) + } + _ => false, + } +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 1 - UPPER BOUND CALCULATION +////////////////////////////////////////////////////////////////////////////// + +/// Count the number of unique types contained by this NamedType, +/// including children and this type itself. +/// +/// For built-in primitives, this could be zero. For non-primitive +/// types, this will be at least one. +/// +/// This function does NOT attempt to perform any de-duplication. +pub const fn unique_types_nty_upper(nty: &NamedType) -> usize { + let child_ct = unique_types_dmt_upper(nty.ty); + if is_prim(nty.ty) { + child_ct + } else { + child_ct + 1 + } +} + +/// Count the number of unique types contained by this DataModelType, +/// ONLY counting children, and not this type, as this will be counted +/// when considering the NamedType instead. +// +// TODO: We could attempt to do LOCAL de-duplication, for example +// a `[u8; 32]` would end up as a tuple of 32 items, drastically +// inflating the total. +const fn unique_types_dmt_upper(dmt: &DataModelType) -> usize { + match dmt { + // These are all primitives + DataModelType::Bool => 0, + DataModelType::I8 => 0, + DataModelType::U8 => 0, + DataModelType::I16 => 0, + DataModelType::I32 => 0, + DataModelType::I64 => 0, + DataModelType::I128 => 0, + DataModelType::U16 => 0, + DataModelType::U32 => 0, + DataModelType::U64 => 0, + DataModelType::U128 => 0, + DataModelType::Usize => 0, + DataModelType::Isize => 0, + DataModelType::F32 => 0, + DataModelType::F64 => 0, + DataModelType::Char => 0, + DataModelType::String => 0, + DataModelType::ByteArray => 0, + DataModelType::Unit => 0, + DataModelType::UnitStruct => 0, + DataModelType::Schema => 0, + + // Items with one subtype + DataModelType::Option(nt) | DataModelType::NewtypeStruct(nt) | DataModelType::Seq(nt) => { + unique_types_nty_upper(nt) + } + // tuple-ish + DataModelType::Tuple(nts) | DataModelType::TupleStruct(nts) => { + let mut uniq = 0; + let mut i = 0; + while i < nts.len() { + uniq += unique_types_nty_upper(nts[i]); + i += 1; + } + uniq + } + DataModelType::Map { key, val } => { + unique_types_nty_upper(key) + unique_types_nty_upper(val) + } + DataModelType::Struct(nvals) => { + let mut uniq = 0; + let mut i = 0; + while i < nvals.len() { + uniq += unique_types_nty_upper(nvals[i].ty); + i += 1; + } + uniq + } + DataModelType::Enum(nvars) => { + let mut uniq = 0; + let mut i = 0; + while i < nvars.len() { + uniq += unique_types_var_upper(nvars[i]); + i += 1; + } + uniq + } + } +} + +/// Count the number of unique types contained by this NamedVariant, +/// ONLY counting children, and not this type, as this will be counted +/// when considering the NamedType instead. +// +// TODO: We could attempt to do LOCAL de-duplication, for example +// a `[u8; 32]` would end up as a tuple of 32 items, drastically +// inflating the total. +const fn unique_types_var_upper(nvar: &NamedVariant) -> usize { + match nvar.ty { + DataModelVariant::UnitVariant => 0, + DataModelVariant::NewtypeVariant(nt) => unique_types_nty_upper(nt), + DataModelVariant::TupleVariant(nts) => { + let mut uniq = 0; + let mut i = 0; + while i < nts.len() { + uniq += unique_types_nty_upper(nts[i]); + i += 1; + } + uniq + } + DataModelVariant::StructVariant(nvals) => { + let mut uniq = 0; + let mut i = 0; + while i < nvals.len() { + uniq += unique_types_nty_upper(nvals[i].ty); + i += 1; + } + uniq + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 2/3 - COLLECTION OF UNIQUES AND CALCULATION OF EXACT SIZE +////////////////////////////////////////////////////////////////////////////// + +/// This function collects the set of unique types, reporting the entire list +/// (which might only be partially used), as well as the *used* length. +/// +/// The parameter MAX should be the highest possible number of unique types, +/// if NONE of the types have any duplication. This should be calculated using +/// [`unique_types_nty_upper()`]. This upper bound allows us to pre-allocate +/// enough storage for the collection process. +pub const fn type_chewer_nty( + nty: &NamedType, +) -> ([Option<&NamedType>; MAX], usize) { + // Calculate the number of unique items in the children of this type + let (mut arr, mut used) = type_chewer_dmt::(nty.ty); + let mut i = 0; + + // determine if this is a single-item primitive - if so, skip adding + // this type to the unique list + let mut found = is_prim(nty.ty); + + while !found && i < used { + let Some(ty) = arr[i] else { panic!() }; + if nty_eq(nty, ty) { + found = true; + } + i += 1; + } + if !found { + arr[used] = Some(nty); + used += 1; + } + (arr, used) +} + +/// This function collects the set of unique types, reporting the entire list +/// (which might only be partially used), as well as the *used* length. +/// +/// The parameter MAX should be the highest possible number of unique types, +/// if NONE of the types have any duplication. This should be calculated using +/// [`unique_types_nty_upper()`]. This upper bound allows us to pre-allocate +/// enough storage for the collection process. +// +// TODO: There is a LOT of duplicated code here. This is to reduce the number of +// intermediate `[Option; MAX]` arrays we contain, as well as the total amount +// of recursion depth. I am open to suggestions of how to reduce this. Part of +// this restriction is that we can't take an `&mut` as a const fn arg, so we +// always have to do it by value, then merge-in the changes. +const fn type_chewer_dmt( + dmt: &DataModelType, +) -> ([Option<&NamedType>; MAX], usize) { + match dmt { + // These are all primitives - they never have any children to report. + DataModelType::Bool => ([None; MAX], 0), + DataModelType::I8 => ([None; MAX], 0), + DataModelType::U8 => ([None; MAX], 0), + DataModelType::I16 => ([None; MAX], 0), + DataModelType::I32 => ([None; MAX], 0), + DataModelType::I64 => ([None; MAX], 0), + DataModelType::I128 => ([None; MAX], 0), + DataModelType::U16 => ([None; MAX], 0), + DataModelType::U32 => ([None; MAX], 0), + DataModelType::U64 => ([None; MAX], 0), + DataModelType::U128 => ([None; MAX], 0), + DataModelType::Usize => ([None; MAX], 0), + DataModelType::Isize => ([None; MAX], 0), + DataModelType::F32 => ([None; MAX], 0), + DataModelType::F64 => ([None; MAX], 0), + DataModelType::Char => ([None; MAX], 0), + DataModelType::String => ([None; MAX], 0), + DataModelType::ByteArray => ([None; MAX], 0), + DataModelType::Unit => ([None; MAX], 0), + DataModelType::Schema => ([None; MAX], 0), + + // A unit struct *as a namedtype* can be a unique/non-primitive type, + // but DataModelType calculation is only concerned with CHILDREN, and + // a unit struct has none. + DataModelType::UnitStruct => ([None; MAX], 0), + + // Items with one subtype + DataModelType::Option(nt) | DataModelType::NewtypeStruct(nt) | DataModelType::Seq(nt) => { + type_chewer_nty::(nt) + } + // tuple-ish + DataModelType::Tuple(nts) | DataModelType::TupleStruct(nts) => { + let mut out = [None; MAX]; + let mut i = 0; + let mut outidx = 0; + + // For each type in the tuple... + while i < nts.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nts[i]); + let mut j = 0; + // For each type in this field... + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + i += 1; + } + (out, outidx) + } + DataModelType::Map { key, val } => { + let mut out = [None; MAX]; + let mut outidx = 0; + + // Do key + let (arr, used) = type_chewer_nty::(key); + let mut j = 0; + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + + // Then do val + let (arr, used) = type_chewer_nty::(val); + let mut j = 0; + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + + (out, outidx) + } + DataModelType::Struct(nvals) => { + let mut out = [None; MAX]; + let mut i = 0; + let mut outidx = 0; + + // For each type in the tuple... + while i < nvals.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nvals[i].ty); + let mut j = 0; + // For each type in this field... + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + i += 1; + } + (out, outidx) + } + DataModelType::Enum(nvars) => { + let mut out = [None; MAX]; + let mut i = 0; + let mut outidx = 0; + + // For each type in the variant... + while i < nvars.len() { + match nvars[i].ty { + DataModelVariant::UnitVariant => continue, + DataModelVariant::NewtypeVariant(nt) => { + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, nt); + k += 1; + } + if !found { + out[outidx] = Some(nt); + outidx += 1; + } + } + DataModelVariant::TupleVariant(nts) => { + let mut x = 0; + + // For each type in the tuple... + while x < nts.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nts[x]); + let mut j = 0; + // For each type in this field... + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + x += 1; + } + } + DataModelVariant::StructVariant(nvals) => { + let mut x = 0; + + // For each type in the struct... + while x < nvals.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nvals[x].ty); + let mut j = 0; + // For each type in this field... + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + x += 1; + } + } + } + i += 1; + } + (out, outidx) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 4 - REDUCTION TO CORRECT SIZE +////////////////////////////////////////////////////////////////////////////// + +/// This function reduces a `&[Option<&NamedType>]` to a `[&NamedType; A]`. +/// +/// The parameter `A` should be calculated by [`type_chewer_nty()`]. +/// +/// We also validate that all items >= idx `A` are in fact None. +pub const fn cruncher( + opts: &[Option<&'static NamedType>], +) -> [&'static NamedType; A] { + let mut out = [<() as Schema>::SCHEMA; A]; + let mut i = 0; + while i < A { + let Some(ty) = opts[i] else { panic!() }; + out[i] = ty; + i += 1; + } + while i < opts.len() { + assert!(opts[i].is_none()); + i += 1; + } + out +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 1-5 (macro op) +////////////////////////////////////////////////////////////////////////////// + +/// `unique_types` collects all unique, non-primitive types contained by the given +/// single type. It can be used with any type that implements the [`Schema`] trait, +/// and returns a `&'static [&'static NamedType]`. +#[macro_export] +macro_rules! unique_types { + ($t:ty) => { + const { + const MAX_TYS: usize = + $crate::uniques::unique_types_nty_upper(<$t as postcard_schema::Schema>::SCHEMA); + const BIG_RPT: ( + [Option<&'static postcard_schema::schema::NamedType>; MAX_TYS], + usize, + ) = $crate::uniques::type_chewer_nty(<$t as postcard_schema::Schema>::SCHEMA); + const SMALL_RPT: [&'static postcard_schema::schema::NamedType; BIG_RPT.1] = + $crate::uniques::cruncher(BIG_RPT.0.as_slice()); + SMALL_RPT.as_slice() + } + }; +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 6 - COLLECTION OF UNIQUES ACROSS MULTIPLE TYPES +////////////////////////////////////////////////////////////////////////////// + +/// This function turns an array of type lists into a single list of unique types +/// +/// The type parameter `M` is the maximum potential output size, it should be +/// equal to `lists.iter().map(|l| l.len()).sum()`, and should generally be +/// calculated as part of [`merge_unique_types!()`][crate::merge_unique_types]. +pub const fn merge_nty_lists( + lists: &[&[&'static NamedType]], +) -> ([Option<&'static NamedType>; M], usize) { + let mut out: [Option<&NamedType>; M] = [None; M]; + let mut out_ct = 0; + let mut i = 0; + + while i < lists.len() { + let mut j = 0; + let list = lists[i]; + while j < list.len() { + let item = list[j]; + let mut k = 0; + let mut found = false; + while !found && k < out_ct { + let Some(oitem) = out[k] else { panic!() }; + if nty_eq(item, oitem) { + found = true; + } + k += 1; + } + if !found { + out[out_ct] = Some(item); + out_ct += 1; + } + j += 1; + } + i += 1; + } + + (out, out_ct) +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 6-9 (macro op) +////////////////////////////////////////////////////////////////////////////// + +/// `merge_unique_types` collects all unique, non-primitive types contained by +/// the given comma separated types. It can be used with any types that implement +/// the [`Schema`] trait, and returns a `&'static [&'static NamedType]`. +#[macro_export] +macro_rules! merge_unique_types { + ($($t:ty,)*) => { + const { + const LISTS: &[&[&'static postcard_schema::schema::NamedType]] = &[ + $( + $crate::unique_types!($t), + )* + ]; + const TTL_COUNT: usize = const { + let mut i = 0; + let mut ct = 0; + while i < LISTS.len() { + ct += LISTS[i].len(); + i += 1; + } + ct + }; + const BIG_RPT: ([Option<&'static postcard_schema::schema::NamedType>; TTL_COUNT], usize) = $crate::uniques::merge_nty_lists(LISTS); + const SMALL_RPT: [&'static postcard_schema::schema::NamedType; BIG_RPT.1] = $crate::uniques::cruncher(BIG_RPT.0.as_slice()); + SMALL_RPT.as_slice() + } + } +} + +#[cfg(test)] +mod test { + #![allow(dead_code)] + use postcard_schema::{ + schema::{owned::OwnedNamedType, NamedType}, + Schema, + }; + + use crate::uniques::{ + is_prim, type_chewer_nty, unique_types_dmt_upper, unique_types_nty_upper, + }; + + #[derive(Schema)] + struct Example0; + + #[derive(Schema)] + struct ExampleA { + a: u32, + } + + #[derive(Schema)] + struct Example1 { + a: u32, + b: Option, + } + + #[derive(Schema)] + struct Example2 { + x: i32, + y: Option, + c: Example1, + } + + #[derive(Schema)] + struct Example3 { + a: u32, + b: Option, + c: Example2, + d: Example2, + e: Example2, + } + + #[test] + fn subpar_arrs() { + const MAXARR: usize = unique_types_nty_upper(<[Example0; 32]>::SCHEMA); + // I don't *like* this, it really should be 2. Leaving it as a test so + // I can remember that it's here. See TODO on unique_types_dmt_upper. + assert_eq!(MAXARR, 33); + } + + #[test] + fn uniqlo() { + const MAX0: usize = unique_types_nty_upper(Example0::SCHEMA); + const MAXA: usize = unique_types_nty_upper(ExampleA::SCHEMA); + const MAX1: usize = unique_types_nty_upper(Example1::SCHEMA); + const MAX2: usize = unique_types_nty_upper(Example2::SCHEMA); + const MAX3: usize = unique_types_nty_upper(Example3::SCHEMA); + assert_eq!(MAX0, 1); + assert_eq!(MAXA, 1); + assert_eq!(MAX1, 2); + assert_eq!(MAX2, 4); + assert_eq!(MAX3, 14); + + println!(); + println!("Example0"); + let (arr0, used): ([Option<_>; MAX0], usize) = type_chewer_nty(Example0::SCHEMA); + assert_eq!(used, 1); + println!("max: {MAX0} used: {used}"); + for a in arr0 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("ExampleA"); + let (arra, used): ([Option<_>; MAXA], usize) = type_chewer_nty(ExampleA::SCHEMA); + assert_eq!(used, 1); + println!("max: {MAXA} used: {used}"); + for a in arra { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Option"); + let (arr1, used): ( + [Option<_>; unique_types_nty_upper(Option::::SCHEMA)], + usize, + ) = type_chewer_nty(Option::::SCHEMA); + assert_eq!(used, 1); + println!( + "max: {} used: {used}", + unique_types_nty_upper(Option::::SCHEMA) + ); + for a in arr1 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Example1"); + let (arr1, used): ([Option<_>; MAX1], usize) = type_chewer_nty(Example1::SCHEMA); + assert!(!is_prim(Example1::SCHEMA.ty)); + let child_ct = unique_types_dmt_upper(Example1::SCHEMA.ty); + assert_eq!(child_ct, 1); + assert_eq!(used, 2); + println!("max: {MAX1} used: {used}"); + for a in arr1 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Example2"); + let (arr2, used): ([Option<_>; MAX2], usize) = type_chewer_nty(Example2::SCHEMA); + println!("max: {MAX2} used: {used}"); + for a in arr2 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Example3"); + let (arr3, used): ([Option<_>; MAX3], usize) = type_chewer_nty(Example3::SCHEMA); + println!("max: {MAX3} used: {used}"); + for a in arr3 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + let rpt0 = unique_types!(Example0); + println!("{}", rpt0.len()); + for a in rpt0 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpta = unique_types!(ExampleA); + println!("{}", rpta.len()); + for a in rpta { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpt1 = unique_types!(Example1); + println!("{}", rpt1.len()); + for a in rpt1 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpt2 = unique_types!(Example2); + println!("{}", rpt2.len()); + for a in rpt2 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpt3 = unique_types!(Example3); + println!("{}", rpt3.len()); + for a in rpt3 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + const MERGED: &[&NamedType] = merge_unique_types![Example3, ExampleA, Example0,]; + println!("{}", MERGED.len()); + for a in MERGED { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + println!(); + println!(); + println!(); + + // panic!("test passed but I want to see the data"); + } +}