diff --git a/example/firmware/Cargo.lock b/example/firmware/Cargo.lock index 3a594ba..da38593 100644 --- a/example/firmware/Cargo.lock +++ b/example/firmware/Cargo.lock @@ -344,17 +344,6 @@ dependencies = [ "nb 1.1.0", ] -[[package]] -name = "embassy-executor" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec648daedd2143466eff4b3e8002024f9f6c1de4ab7666bb679688752624c925" -dependencies = [ - "critical-section", - "document-features", - "embassy-executor-macros 0.4.1", -] - [[package]] name = "embassy-executor" version = "0.6.0" @@ -364,23 +353,11 @@ dependencies = [ "critical-section", "defmt", "document-features", - "embassy-executor-macros 0.5.0", + "embassy-executor-macros", "embassy-time-driver", "embassy-time-queue-driver", ] -[[package]] -name = "embassy-executor-macros" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad454accf80050e9cf7a51e994132ba0e56286b31f9317b68703897c328c59b5" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "embassy-executor-macros" version = "0.5.0" @@ -1154,12 +1131,13 @@ dependencies = [ name = "postcard-rpc" version = "0.8.0" dependencies = [ - "embassy-executor 0.5.0", + "embassy-executor", "embassy-sync", "embassy-usb", "embassy-usb-driver", "futures-util", "heapless 0.8.0", + "paste", "postcard", "postcard-schema", "serde", @@ -1759,7 +1737,7 @@ dependencies = [ "cortex-m-rt", "defmt", "defmt-rtt", - "embassy-executor 0.6.0", + "embassy-executor", "embassy-rp", "embassy-sync", "embassy-time", diff --git a/example/firmware/src/bin/comms-01.rs b/example/firmware/src/bin/comms-01.rs index f188857..2aa2de8 100644 --- a/example/firmware/src/bin/comms-01.rs +++ b/example/firmware/src/bin/comms-01.rs @@ -3,77 +3,119 @@ use defmt::info; use embassy_executor::Spawner; -use embassy_rp::{ - peripherals::USB, - usb::{self, Driver, Endpoint, Out}, -}; - +use embassy_rp::{peripherals::USB, usb}; use embassy_sync::blocking_mutex::raw::ThreadModeRawMutex; - -use embassy_usb::UsbDevice; +use embassy_usb::{Config, UsbDevice}; use postcard_rpc::{ define_dispatch, - target_server::{buffers::AllBuffers, configure_usb, example_config, rpc_dispatch}, - WireHeader, + header::VarHeader, + server::{ + impls::embassy_usb_v0_3::{ + dispatch_impl::{WireRxBuf, WireRxImpl, WireSpawnImpl, WireStorage, WireTxImpl}, + PacketBuffers, + }, + Dispatch, Server, + }, }; - use static_cell::ConstStaticCell; use workbook_fw::{get_unique_id, Irqs}; -use workbook_icd::PingEndpoint; +use workbook_icd::{PingEndpoint, ENDPOINT_LIST, TOPICS_IN_LIST, TOPICS_OUT_LIST}; +use {defmt_rtt as _, panic_probe as _}; + +pub struct Context; -static ALL_BUFFERS: ConstStaticCell> = - ConstStaticCell::new(AllBuffers::new()); +type AppDriver = usb::Driver<'static, USB>; +type AppStorage = WireStorage; +type BufStorage = PacketBuffers<1024, 1024>; +type AppTx = WireTxImpl; +type AppRx = WireRxImpl; +type AppServer = Server; -pub struct Context {} +static PBUFS: ConstStaticCell = ConstStaticCell::new(BufStorage::new()); +static STORAGE: AppStorage = AppStorage::new(); +fn usb_config() -> Config<'static> { + let mut config = Config::new(0x16c0, 0x27DD); + config.manufacturer = Some("OneVariable"); + config.product = Some("ov-twin"); + config.serial_number = Some("12345678"); + + // Required for windows compatibility. 
+ // https://developer.nordicsemi.com/nRF_Connect_SDK/doc/1.9.1/kconfig/CONFIG_CDC_ACM_IAD.html#help + config.device_class = 0xEF; + config.device_sub_class = 0x02; + config.device_protocol = 0x01; + config.composite_with_iads = true; + + config +} define_dispatch! { - dispatcher: Dispatcher< - Mutex = ThreadModeRawMutex, - Driver = usb::Driver<'static, USB>, - Context = Context - >; - PingEndpoint => blocking ping_handler, + app: MyApp; + spawn_fn: spawn_fn; + tx_impl: AppTx; + spawn_impl: WireSpawnImpl; + context: Context; + + endpoints: { + list: ENDPOINT_LIST; + + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | PingEndpoint | blocking | ping_handler | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + }; + topics_out: { + list: TOPICS_OUT_LIST; + }; } #[embassy_executor::main] async fn main(spawner: Spawner) { // SYSTEM INIT info!("Start"); - let mut p = embassy_rp::init(Default::default()); - let unique_id = get_unique_id(&mut p.FLASH).unwrap(); + let unique_id = defmt::unwrap!(get_unique_id(&mut p.FLASH)); info!("id: {=u64:016X}", unique_id); // USB/RPC INIT let driver = usb::Driver::new(p.USB, Irqs); - let mut config = example_config(); - config.manufacturer = Some("OneVariable"); - config.product = Some("ov-twin"); - let buffers = ALL_BUFFERS.take(); - let (device, ep_in, ep_out) = configure_usb(driver, &mut buffers.usb_device, config); - let dispatch = Dispatcher::new(&mut buffers.tx_buf, ep_in, Context {}); + let pbufs = PBUFS.take(); + let config = usb_config(); - spawner.must_spawn(dispatch_task(ep_out, dispatch, &mut buffers.rx_buf)); + let context = Context; + let (device, tx_impl, rx_impl) = STORAGE.init(driver, config, pbufs.tx_buf.as_mut_slice()); + let dispatcher = MyApp::new(context, spawner.into()); + let vkk = dispatcher.min_key_len(); + let mut server: AppServer = Server::new( + &tx_impl, + rx_impl, + pbufs.rx_buf.as_mut_slice(), + dispatcher, + vkk, + ); spawner.must_spawn(usb_task(device)); -} -/// This actually runs the dispatcher -#[embassy_executor::task] -async fn dispatch_task( - ep_out: Endpoint<'static, USB, Out>, - dispatch: Dispatcher, - rx_buf: &'static mut [u8], -) { - rpc_dispatch(ep_out, dispatch, rx_buf).await; + loop { + // If the host disconnects, we'll return an error here. 
+ // If this happens, just wait until the host reconnects + let _ = server.run().await; + } } /// This handles the low level USB management #[embassy_executor::task] -pub async fn usb_task(mut usb: UsbDevice<'static, Driver<'static, USB>>) { +pub async fn usb_task(mut usb: UsbDevice<'static, AppDriver>) { usb.run().await; } -fn ping_handler(_context: &mut Context, header: WireHeader, rqst: u32) -> u32 { - info!("ping: seq - {=u32}", header.seq_no); +// --- + +fn ping_handler(_context: &mut Context, _header: VarHeader, rqst: u32) -> u32 { + info!("ping"); rqst } diff --git a/example/firmware/src/bin/comms-02.rs b/example/firmware/src/bin/comms-02.rs index ad1efff..4461a34 100644 --- a/example/firmware/src/bin/comms-02.rs +++ b/example/firmware/src/bin/comms-02.rs @@ -8,31 +8,38 @@ use embassy_rp::{ peripherals::{PIO0, SPI0, USB}, pio::Pio, spi::{self, Spi}, - usb::{self, Driver, Endpoint, Out}, + usb, }; use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex}; use embassy_time::{Delay, Duration, Ticker}; -use embassy_usb::UsbDevice; +use embassy_usb::{Config, UsbDevice}; use embedded_hal_bus::spi::ExclusiveDevice; use lis3dh_async::{Lis3dh, Lis3dhSPI}; use portable_atomic::{AtomicBool, Ordering}; use postcard_rpc::{ define_dispatch, - target_server::{ - buffers::AllBuffers, configure_usb, example_config, rpc_dispatch, sender::Sender, - SpawnContext, + header::VarHeader, + server::{ + impls::embassy_usb_v0_3::{ + dispatch_impl::{ + spawn_fn, WireRxBuf, WireRxImpl, WireSpawnImpl, WireStorage, WireTxImpl, + }, + PacketBuffers, + }, + Dispatch, Sender, Server, SpawnContext, }, - WireHeader, }; use smart_leds::{colors::BLACK, RGB8}; use static_cell::{ConstStaticCell, StaticCell}; use workbook_fw::{ - get_unique_id, ws2812::{self, Ws2812}, Irqs + get_unique_id, + ws2812::{self, Ws2812}, + Irqs, }; use workbook_icd::{ AccelTopic, Acceleration, BadPositionError, GetUniqueIdEndpoint, PingEndpoint, Rgb8, SetAllLedEndpoint, SetSingleLedEndpoint, SingleLed, StartAccel, StartAccelerationEndpoint, - StopAccelerationEndpoint, + StopAccelerationEndpoint, ENDPOINT_LIST, TOPICS_IN_LIST, TOPICS_OUT_LIST, }; use {defmt_rtt as _, panic_probe as _}; @@ -40,9 +47,6 @@ pub type Accel = Lis3dh, Output<'static>, Delay>>>; static ACCEL: StaticCell> = StaticCell::new(); -static ALL_BUFFERS: ConstStaticCell> = - ConstStaticCell::new(AllBuffers::new()); - pub struct Context { pub unique_id: u64, pub ws2812: Ws2812<'static, PIO0, 0, 24>, @@ -61,18 +65,60 @@ impl SpawnContext for Context { } } +type AppDriver = usb::Driver<'static, USB>; +type AppStorage = WireStorage; +type BufStorage = PacketBuffers<1024, 1024>; +type AppTx = WireTxImpl; +type AppRx = WireRxImpl; +type AppServer = Server; + +static PBUFS: ConstStaticCell = ConstStaticCell::new(BufStorage::new()); +static STORAGE: AppStorage = AppStorage::new(); + +fn usb_config() -> Config<'static> { + let mut config = Config::new(0x16c0, 0x27DD); + config.manufacturer = Some("OneVariable"); + config.product = Some("ov-twin"); + config.serial_number = Some("12345678"); + + // Required for windows compatibility. + // https://developer.nordicsemi.com/nRF_Connect_SDK/doc/1.9.1/kconfig/CONFIG_CDC_ACM_IAD.html#help + config.device_class = 0xEF; + config.device_sub_class = 0x02; + config.device_protocol = 0x01; + config.composite_with_iads = true; + + config +} + define_dispatch! 
{ - dispatcher: Dispatcher< - Mutex = ThreadModeRawMutex, - Driver = usb::Driver<'static, USB>, - Context = Context, - >; - PingEndpoint => blocking ping_handler, - GetUniqueIdEndpoint => blocking unique_id_handler, - SetSingleLedEndpoint => async set_led_handler, - SetAllLedEndpoint => async set_all_led_handler, - StartAccelerationEndpoint => spawn accelerometer_handler, - StopAccelerationEndpoint => blocking accelerometer_stop_handler, + app: MyApp; + spawn_fn: spawn_fn; + tx_impl: AppTx; + spawn_impl: WireSpawnImpl; + context: Context; + + endpoints: { + list: ENDPOINT_LIST; + + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | PingEndpoint | blocking | ping_handler | + | GetUniqueIdEndpoint | blocking | unique_id_handler | + | SetSingleLedEndpoint | async | set_led_handler | + | SetAllLedEndpoint | async | set_all_led_handler | + | StartAccelerationEndpoint | spawn | accelerometer_handler | + | StopAccelerationEndpoint | blocking | accelerometer_stop_handler | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + }; + topics_out: { + list: TOPICS_OUT_LIST; + }; } #[embassy_executor::main] @@ -106,60 +152,59 @@ async fn main(spawner: Spawner) { // USB/RPC INIT let driver = usb::Driver::new(p.USB, Irqs); - let mut config = example_config(); - config.manufacturer = Some("OneVariable"); - config.product = Some("ov-twin"); - let buffers = ALL_BUFFERS.take(); - let (device, ep_in, ep_out) = configure_usb(driver, &mut buffers.usb_device, config); - let dispatch = Dispatcher::new( - &mut buffers.tx_buf, - ep_in, - Context { - unique_id, - ws2812, - ws2812_state: [BLACK; 24], - accel: accel_ref, - }, - ); + let pbufs = PBUFS.take(); + let config = usb_config(); - spawner.must_spawn(dispatch_task(ep_out, dispatch, &mut buffers.rx_buf)); + let context = Context { + unique_id, + ws2812, + ws2812_state: [BLACK; 24], + accel: accel_ref, + }; + + let (device, tx_impl, rx_impl) = STORAGE.init(driver, config, pbufs.tx_buf.as_mut_slice()); + let dispatcher = MyApp::new(context, spawner.into()); + let vkk = dispatcher.min_key_len(); + let mut server: AppServer = Server::new( + &tx_impl, + rx_impl, + pbufs.rx_buf.as_mut_slice(), + dispatcher, + vkk, + ); spawner.must_spawn(usb_task(device)); -} -/// This actually runs the dispatcher -#[embassy_executor::task] -async fn dispatch_task( - ep_out: Endpoint<'static, USB, Out>, - dispatch: Dispatcher, - rx_buf: &'static mut [u8], -) { - rpc_dispatch(ep_out, dispatch, rx_buf).await; + loop { + // If the host disconnects, we'll return an error here. 
+ // If this happens, just wait until the host reconnects + let _ = server.run().await; + } } /// This handles the low level USB management #[embassy_executor::task] -pub async fn usb_task(mut usb: UsbDevice<'static, Driver<'static, USB>>) { +pub async fn usb_task(mut usb: UsbDevice<'static, AppDriver>) { usb.run().await; } // --- -fn ping_handler(_context: &mut Context, header: WireHeader, rqst: u32) -> u32 { - info!("ping: seq - {=u32}", header.seq_no); +fn ping_handler(_context: &mut Context, _header: VarHeader, rqst: u32) -> u32 { + info!("ping"); rqst } -fn unique_id_handler(context: &mut Context, header: WireHeader, _rqst: ()) -> u64 { - info!("unique_id: seq - {=u32}", header.seq_no); +fn unique_id_handler(context: &mut Context, _header: VarHeader, _rqst: ()) -> u64 { + info!("unique_id"); context.unique_id } async fn set_led_handler( context: &mut Context, - header: WireHeader, + _header: VarHeader, rqst: SingleLed, ) -> Result<(), BadPositionError> { - info!("set_led: seq - {=u32}", header.seq_no); + info!("set_led"); if rqst.position >= 24 { return Err(BadPositionError); } @@ -173,8 +218,8 @@ async fn set_led_handler( Ok(()) } -async fn set_all_led_handler(context: &mut Context, header: WireHeader, rqst: [Rgb8; 24]) { - info!("set_all_led: seq - {=u32}", header.seq_no); +async fn set_all_led_handler(context: &mut Context, _header: VarHeader, rqst: [Rgb8; 24]) { + info!("set_all_led"); context .ws2812_state .iter_mut() @@ -192,9 +237,9 @@ static STOP: AtomicBool = AtomicBool::new(false); #[embassy_executor::task] async fn accelerometer_handler( context: SpawnCtx, - header: WireHeader, + header: VarHeader, rqst: StartAccel, - sender: Sender>, + sender: Sender, ) { let mut accel = context.accel.lock().await; if sender @@ -209,7 +254,7 @@ async fn accelerometer_handler( defmt::unwrap!(accel.set_range(lis3dh_async::Range::G8).await.map_err(drop)); let mut ticker = Ticker::every(Duration::from_millis(rqst.interval_ms.into())); - let mut seq = 0; + let mut seq = 0u8; while !STOP.load(Ordering::Acquire) { ticker.next().await; let acc = defmt::unwrap!(accel.accel_raw().await.map_err(drop)); @@ -219,7 +264,11 @@ async fn accelerometer_handler( y: acc.y, z: acc.z, }; - if sender.publish::(seq, &msg).await.is_err() { + if sender + .publish::(seq.into(), &msg) + .await + .is_err() + { defmt::error!("Send error!"); break; } @@ -229,8 +278,8 @@ async fn accelerometer_handler( STOP.store(false, Ordering::Release); } -fn accelerometer_stop_handler(context: &mut Context, header: WireHeader, _rqst: ()) -> bool { - info!("accel_stop: seq - {=u32}", header.seq_no); +fn accelerometer_stop_handler(context: &mut Context, _header: VarHeader, _rqst: ()) -> bool { + info!("accel_stop"); let was_busy = context.accel.try_lock().is_err(); if was_busy { STOP.store(true, Ordering::Release); diff --git a/example/workbook-host/Cargo.lock b/example/workbook-host/Cargo.lock index d50a8c7..86a1103 100644 --- a/example/workbook-host/Cargo.lock +++ b/example/workbook-host/Cargo.lock @@ -136,6 +136,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "errno" version = "0.3.8" @@ -407,6 +413,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project" version = "1.1.5" @@ -469,9 +481,11 @@ dependencies = [ "heapless 0.8.0", "maitake-sync", "nusb", + "paste", "postcard", "postcard-schema", "serde", + "ssmarshal", "thiserror", "tokio", "tracing", @@ -655,6 +669,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "ssmarshal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e6ad23b128192ed337dfa4f1b8099ced0c2bf30d61e551b65fda5916dbb850" +dependencies = [ + "encode_unicode", + "serde", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/example/workbook-host/src/bin/comms-01.rs b/example/workbook-host/src/bin/comms-01.rs index 24306d8..1313f6e 100644 --- a/example/workbook-host/src/bin/comms-01.rs +++ b/example/workbook-host/src/bin/comms-01.rs @@ -1,7 +1,7 @@ use std::time::Duration; -use workbook_host_client::client::WorkbookClient; use tokio::time::interval; +use workbook_host_client::client::WorkbookClient; #[tokio::main] pub async fn main() { diff --git a/example/workbook-host/src/client.rs b/example/workbook-host/src/client.rs index 3e8ded2..0688a7a 100644 --- a/example/workbook-host/src/client.rs +++ b/example/workbook-host/src/client.rs @@ -1,9 +1,14 @@ -use std::convert::Infallible; use postcard_rpc::{ + header::VarSeqKind, host_client::{HostClient, HostErr}, standard_icd::{WireError, ERROR_PATH}, }; -use workbook_icd::{AccelRange, BadPositionError, GetUniqueIdEndpoint, PingEndpoint, Rgb8, SetAllLedEndpoint, SetSingleLedEndpoint, SingleLed, StartAccel, StartAccelerationEndpoint, StopAccelerationEndpoint}; +use std::convert::Infallible; +use workbook_icd::{ + AccelRange, BadPositionError, GetUniqueIdEndpoint, PingEndpoint, Rgb8, SetAllLedEndpoint, + SetSingleLedEndpoint, SingleLed, StartAccel, StartAccelerationEndpoint, + StopAccelerationEndpoint, +}; pub struct WorkbookClient { pub client: HostClient, @@ -39,8 +44,12 @@ impl FlattenErr for Result { impl WorkbookClient { pub fn new() -> Self { - let client = - HostClient::new_raw_nusb(|d| d.product_string() == Some("ov-twin"), ERROR_PATH, 8); + let client = HostClient::new_raw_nusb( + |d| d.product_string() == Some("ov-twin"), + ERROR_PATH, + 8, + VarSeqKind::Seq2, + ); Self { client } } diff --git a/example/workbook-icd/Cargo.lock b/example/workbook-icd/Cargo.lock index a2cc8d1..a39eb63 100644 --- a/example/workbook-icd/Cargo.lock +++ b/example/workbook-icd/Cargo.lock @@ -87,6 +87,12 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "postcard" version = "1.0.8" @@ -111,9 +117,10 @@ dependencies = [ [[package]] name = "postcard-rpc" -version = "0.7.0" +version = "0.8.0" dependencies = [ "heapless 0.8.0", + "paste", "postcard", "postcard-schema", "serde", diff --git a/example/workbook-icd/src/lib.rs b/example/workbook-icd/src/lib.rs index 2783dc1..ecf109c 100644 --- a/example/workbook-icd/src/lib.rs +++ b/example/workbook-icd/src/lib.rs @@ -1,22 +1,35 @@ #![no_std] +use postcard_rpc::{endpoints, topics}; use postcard_schema::Schema; -use postcard_rpc::{endpoint, topic}; use serde::{Deserialize, Serialize}; 
-endpoint!(PingEndpoint, u32, u32, "ping"); - // --- -endpoint!(GetUniqueIdEndpoint, (), u64, "unique_id/get"); - -endpoint!(SetSingleLedEndpoint, SingleLed, Result<(), BadPositionError>, "led/set_one"); -endpoint!(SetAllLedEndpoint, [Rgb8; 24], (), "led/set_all"); +endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | PingEndpoint | u32 | u32 | "ping" | + | GetUniqueIdEndpoint | () | u64 | "unique_id/get" | + | SetSingleLedEndpoint | SingleLed | Result<(), BadPositionError> | "led/set_one" | + | SetAllLedEndpoint | [Rgb8; 24] | () | "led/set_all" | + | StartAccelerationEndpoint | StartAccel | () | "accel/start" | + | StopAccelerationEndpoint | () | bool | "accel/stop" | +} -endpoint!(StartAccelerationEndpoint, StartAccel, (), "accel/start"); -endpoint!(StopAccelerationEndpoint, (), bool, "accel/stop"); +topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ------- | --------- | ---- | +} -topic!(AccelTopic, Acceleration, "accel/data"); +topics! { + list = TOPICS_OUT_LIST; + | TopicTy | MessageTy | Path | + | ------- | --------- | ---- | + | AccelTopic | Acceleration | "accel/data" | +} #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub struct SingleLed { @@ -28,7 +41,7 @@ pub struct SingleLed { pub struct Rgb8 { pub r: u8, pub g: u8, - pub b: u8 + pub b: u8, } #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] diff --git a/source/postcard-rpc-test/Cargo.lock b/source/postcard-rpc-test/Cargo.lock index 0306838..5c791ae 100644 --- a/source/postcard-rpc-test/Cargo.lock +++ b/source/postcard-rpc-test/Cargo.lock @@ -131,6 +131,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "generator" version = "0.7.5" @@ -320,6 +326,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project" version = "1.1.3" @@ -394,9 +406,11 @@ version = "0.8.0" dependencies = [ "heapless 0.8.0", "maitake-sync", + "paste", "postcard", "postcard-schema", "serde", + "ssmarshal", "thiserror", "tokio", "tracing", @@ -569,6 +583,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "ssmarshal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e6ad23b128192ed337dfa4f1b8099ced0c2bf30d61e551b65fda5916dbb850" +dependencies = [ + "encode_unicode", + "serde", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/source/postcard-rpc-test/tests/basic.rs b/source/postcard-rpc-test/tests/basic.rs index e830fb1..33ffff7 100644 --- a/source/postcard-rpc-test/tests/basic.rs +++ b/source/postcard-rpc-test/tests/basic.rs @@ -1,305 +1,449 @@ -use std::{collections::HashMap, time::Duration}; - -use postcard_schema::Schema; -use postcard_rpc::test_utils::local_setup; -use postcard_rpc::{ - endpoint, headered::to_stdvec_keyed, topic, Dispatch, Endpoint, Key, Topic, WireHeader, +use 
core::{ + sync::atomic::{AtomicUsize, Ordering}, + time::Duration, }; +use std::{sync::Arc, time::Instant}; + +use postcard_schema::{schema::owned::OwnedNamedType, Schema}; use serde::{Deserialize, Serialize}; -use tokio::task::yield_now; -use tokio::time::timeout; +use tokio::{sync::mpsc, task::yield_now}; -endpoint!(EndpointOne, Req1, Resp1, "endpoint/one"); -topic!(TopicOne, Req1, "unsolicited/topic1"); +use postcard_rpc::{ + define_dispatch, endpoints, + header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind}, + host_client::test_channels as client, + server::{ + impls::test_channels::{ + dispatch_impl::{new_server, spawn_fn, Settings, WireSpawnImpl, WireTxImpl}, + ChannelWireRx, ChannelWireSpawn, ChannelWireTx, + }, + Dispatch, Sender, SpawnContext, + }, + topics, Endpoint, Topic, +}; -#[derive(Debug, PartialEq, Serialize, Deserialize, Schema)] -pub struct Req1 { - a: u8, - b: u64, +#[derive(Serialize, Deserialize, Schema)] +pub struct AReq(pub u8); +#[derive(Serialize, Deserialize, Schema)] +pub struct AResp(pub u8); +#[derive(Serialize, Deserialize, Schema)] +pub struct BReq(pub u16); +#[derive(Serialize, Deserialize, Schema)] +pub struct BResp(pub u32); +#[derive(Serialize, Deserialize, Schema)] +pub struct GReq; +#[derive(Serialize, Deserialize, Schema)] +pub struct GResp; +#[derive(Serialize, Deserialize, Schema)] +pub struct DReq; +#[derive(Serialize, Deserialize, Schema)] +pub struct DResp; +#[derive(Serialize, Deserialize, Schema)] +pub struct EReq; +#[derive(Serialize, Deserialize, Schema)] +pub struct EResp; +#[derive(Serialize, Deserialize, Schema)] +pub struct ZMsg(pub i16); + +endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | AlphaEndpoint | AReq | AResp | "alpha" | + | BetaEndpoint | BReq | BResp | "beta" | + | GammaEndpoint | GReq | GResp | "gamma" | + | DeltaEndpoint | DReq | DResp | "delta" | + | EpsilonEndpoint | EReq | EResp | "epsilon" | } -#[derive(Debug, PartialEq, Serialize, Deserialize, Schema)] -pub struct Resp1 { - c: [u8; 8], - d: i32, +topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic1 | ZMsg | "zeta1" | + | ZetaTopic2 | ZMsg | "zeta2" | + | ZetaTopic3 | ZMsg | "zeta3" | } -#[derive(Debug, PartialEq, Serialize, Deserialize, Schema)] -pub enum WireError { - LeastBad, - MediumBad, - MostBad, +topics! 
{ + list = TOPICS_OUT_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic10 | ZMsg | "zeta10" | } -struct SmokeContext { - got: HashMap)>, - next_err: bool, +pub struct TestContext { + pub ctr: Arc, + pub topic_ctr: Arc, } -fn store_disp(hdr: &WireHeader, ctx: &mut SmokeContext, body: &[u8]) -> Result<(), WireError> { - if ctx.next_err { - ctx.next_err = false; - return Err(WireError::MediumBad); - } - ctx.got.insert(hdr.key, (hdr.clone(), body.to_vec())); - Ok(()) +pub struct TestSpawnContext { + pub ctr: Arc, + pub topic_ctr: Arc, } -impl SmokeDispatch { - pub fn new() -> Self { - let ctx = SmokeContext { - got: HashMap::new(), - next_err: false, - }; - let disp = Dispatch::new(ctx); - Self { disp } - } -} +impl SpawnContext for TestContext { + type SpawnCtxt = TestSpawnContext; -struct SmokeDispatch { - disp: Dispatch, -} - -#[tokio::test] -async fn smoke_reqresp() { - let (mut srv, client) = local_setup::(8, "error"); - - // Create the Dispatch Server - let mut disp = SmokeDispatch::new(); - disp.disp.add_handler::(store_disp).unwrap(); - - // Start the request - let send1 = tokio::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await + fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { + TestSpawnContext { + ctr: self.ctr.clone(), + topic_ctr: self.topic_ctr.clone(), } - }); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); - - // The request is still awaiting a response - assert!(!send1.is_finished()); - - // Feed the request through the dispatcher - disp.disp.dispatch(&act_out).unwrap(); + } +} - // Make sure we "dispatched" it right - let disp_got = disp.disp.context().got.remove(&out1.header.key).unwrap(); - assert_eq!(disp_got.0, out1.header); - assert!(act_out.ends_with(&disp_got.1)); +define_dispatch! { + app: SingleDispatcher; + spawn_fn: spawn_fn; + tx_impl: WireTxImpl; + spawn_impl: WireSpawnImpl; + context: TestContext; - // The request is still awaiting a response - assert!(!send1.is_finished()); + endpoints: { + list: ENDPOINT_LIST; - // Feed a simulated response "from the wire" back to the - // awaiting request - const RESP_001: Resp1 = Resp1 { - c: [1, 2, 3, 4, 5, 6, 7, 8], - d: -10, + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | AlphaEndpoint | async | test_alpha_handler | + | BetaEndpoint | spawn | test_beta_handler | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + | ZetaTopic1 | blocking | test_zeta_blocking | + | ZetaTopic2 | async | test_zeta_async | + | ZetaTopic3 | spawn | test_zeta_spawn | + }; + topics_out: { + list: TOPICS_OUT_LIST; }; - srv.reply::(out1.header.seq_no, &RESP_001) - .await - .unwrap(); - - // Now wait for the request to complete - let end = send1.await.unwrap().unwrap(); - - // We got the simulated value back - assert_eq!(end, RESP_001); } -#[tokio::test] -async fn smoke_publish() { - let (mut srv, client) = local_setup::(8, "error"); - - // Start the request - client - .publish::(123, &Req1 { a: 10, b: 100 }) - .await - .unwrap(); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? 
- let exp_out = to_stdvec_keyed(123, TopicOne::TOPIC_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); +fn test_zeta_blocking( + context: &mut TestContext, + _header: VarHeader, + _body: ZMsg, + _out: &Sender, +) { + context.topic_ctr.fetch_add(1, Ordering::Relaxed); } -#[tokio::test] -async fn smoke_subscribe() { - let (mut srv, client) = local_setup::(8, "error"); - - // Do a subscription - let mut sub = client.subscribe::(8).await.unwrap(); - - // Start the listen - let recv1 = timeout(Duration::from_millis(100), sub.recv()); - let _ = recv1.await.unwrap_err(); +async fn test_zeta_async( + context: &mut TestContext, + _header: VarHeader, + _body: ZMsg, + _out: &Sender, +) { + context.topic_ctr.fetch_add(1, Ordering::Relaxed); +} - // Send a message on the topic - const VAL: Req1 = Req1 { a: 10, b: 100 }; - srv.publish::(123, &VAL).await.unwrap(); +async fn test_zeta_spawn( + context: TestSpawnContext, + _header: VarHeader, + _body: ZMsg, + _out: Sender, +) { + context.topic_ctr.fetch_add(1, Ordering::Relaxed); +} - // Now the request resolves - let publ = timeout(Duration::from_millis(100), sub.recv()) - .await - .unwrap() - .unwrap(); +async fn test_alpha_handler(context: &mut TestContext, _header: VarHeader, body: AReq) -> AResp { + context.ctr.fetch_add(1, Ordering::Relaxed); + AResp(body.0) +} - assert_eq!(publ, VAL); +async fn test_beta_handler( + context: TestSpawnContext, + header: VarHeader, + body: BReq, + out: Sender, +) { + context.ctr.fetch_add(1, Ordering::Relaxed); + let _ = out + .reply::(header.seq_no, &BResp(body.0.into())) + .await; } #[tokio::test] -async fn smoke_io_error() { - let (mut srv, client) = local_setup::(8, "error"); - - // Do one round trip to make sure the connection works - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await +async fn smoke() { + let (client_tx, server_rx) = mpsc::channel(16); + let (server_tx, mut client_rx) = mpsc::channel(16); + let topic_ctr = Arc::new(AtomicUsize::new(0)); + + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + let cwrx = ChannelWireRx::new(server_rx); + let cwtx = ChannelWireTx::new(server_tx); + let kkind = app.min_key_len(); + let mut server = new_server( + app, + Settings { + tx: cwtx, + rx: cwrx, + buf: 1024, + kkind, + }, + ); + tokio::task::spawn(async move { + server.run().await; + }); + + // manually build request - Alpha + let mut msg = VarHeader { + key: VarKey::Key8(AlphaEndpoint::REQ_KEY), + seq_no: VarSeq::Seq4(123), + } + .write_to_vec(); + let body = postcard::to_stdvec(&AReq(42)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + let resp = client_rx.recv().await.unwrap(); + + // manually extract response + let (hdr, body) = VarHeader::take_from_slice(&resp).unwrap(); + let resp = postcard::from_bytes::<::Response>(body).unwrap(); + assert_eq!(resp.0, 42); + assert_eq!(hdr.key, VarKey::Key8(AlphaEndpoint::RESP_KEY)); + assert_eq!(hdr.seq_no, VarSeq::Seq4(123)); + + // manually build request - Beta + let mut msg = VarHeader { + key: VarKey::Key8(BetaEndpoint::REQ_KEY), + seq_no: VarSeq::Seq4(234), + } + .write_to_vec(); + let body = postcard::to_stdvec(&BReq(1000)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + let resp = client_rx.recv().await.unwrap(); + + // manually extract response + 
let (hdr, body) = VarHeader::take_from_slice(&resp).unwrap(); + let resp = postcard::from_bytes::<::Response>(body).unwrap(); + assert_eq!(resp.0, 1000); + assert_eq!(hdr.key, VarKey::Key8(BetaEndpoint::RESP_KEY)); + assert_eq!(hdr.seq_no, VarSeq::Seq4(234)); + + // blocking topic handler + for i in 0..3 { + let mut msg = VarHeader { + key: VarKey::Key8(ZetaTopic1::TOPIC_KEY), + seq_no: VarSeq::Seq4(i), + } + .write_to_vec(); + + let body = postcard::to_stdvec(&ZMsg(456)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + + let start = Instant::now(); + let mut good = false; + while start.elapsed() < Duration::from_millis(100) { + let ct = topic_ctr.load(Ordering::Relaxed); + if ct == (i + 1) as usize { + good = true; + break; + } else { + yield_now().await } - }); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); - - // The request is still awaiting a response - assert!(!rr_rt.is_finished()); - - // Feed a simulated response "from the wire" back to the - // awaiting request - const RESP_001: Resp1 = Resp1 { - c: [1, 2, 3, 4, 5, 6, 7, 8], - d: -10, - }; - srv.reply::(out1.header.seq_no, &RESP_001) - .await - .unwrap(); - - // Now wait for the request to complete - let end = rr_rt.await.unwrap().unwrap(); - - // We got the simulated value back - assert_eq!(end, RESP_001); + } + assert!(good); } - // Now, simulate an I/O error - srv.cause_fatal_error(); - - // Give the clients some time to halt - yield_now().await; - - // Our server channels should now be closed - the tasks hung up - assert!(srv.from_client.recv().await.is_none()); - assert!(srv.to_client.send(Vec::new()).await.is_err()); + // async topic handler + for i in 0..3 { + let mut msg = VarHeader { + key: VarKey::Key8(ZetaTopic2::TOPIC_KEY), + seq_no: VarSeq::Seq4(i), + } + .write_to_vec(); + let body = postcard::to_stdvec(&ZMsg(456)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + + let start = Instant::now(); + let mut good = false; + while start.elapsed() < Duration::from_millis(100) { + let ct = topic_ctr.load(Ordering::Relaxed); + if ct == (i + 4) as usize { + good = true; + break; + } else { + yield_now().await + } + } + assert!(good); + } - // Try again, but nothing should work because all the worker tasks just halted - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await + // spawn topic handler + for i in 0..3 { + let mut msg = VarHeader { + key: VarKey::Key8(ZetaTopic3::TOPIC_KEY), + seq_no: VarSeq::Seq4(i), + } + .write_to_vec(); + let body = postcard::to_stdvec(&ZMsg(456)).unwrap(); + msg.extend_from_slice(&body); + client_tx.send(msg).await.unwrap(); + + let start = Instant::now(); + let mut good = false; + while start.elapsed() < Duration::from_millis(100) { + let ct = topic_ctr.load(Ordering::Relaxed); + if ct == (i + 7) as usize { + good = true; + break; + } else { + yield_now().await } - }); + } + assert!(good); + } +} - // As the wire, get the outgoing request - didn't happen - assert!(srv.recv_from_client().await.is_err()); +#[tokio::test] +async fn end_to_end() { + let (client_tx, server_rx) = mpsc::channel(16); + let (server_tx, client_rx) = mpsc::channel(16); + let topic_ctr = Arc::new(AtomicUsize::new(0)); 
+ + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + let cwrx = ChannelWireRx::new(server_rx); + let cwtx = ChannelWireTx::new(server_tx); + + let kkind = app.min_key_len(); + let mut server = new_server( + app, + Settings { + tx: cwtx, + rx: cwrx, + buf: 1024, + kkind, + }, + ); + tokio::task::spawn(async move { + server.run().await; + }); - // Now wait for the request to complete - it failed - rr_rt.await.unwrap().unwrap_err(); - } + let cli = client::new_from_channels(client_tx, client_rx, VarSeqKind::Seq1); + + let resp = cli.send_resp::(&AReq(42)).await.unwrap(); + assert_eq!(resp.0, 42); + let resp = cli.send_resp::(&BReq(1234)).await.unwrap(); + assert_eq!(resp.0, 1234); } #[tokio::test] -async fn smoke_closed() { - let (mut srv, client) = local_setup::(8, "error"); - - // Do one round trip to make sure the connection works - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await - } - }); - - // As the wire, get the outgoing request - let out1 = srv.recv_from_client().await.unwrap(); - - // Does the outgoing value match what we expect? - let exp_out = to_stdvec_keyed(0, EndpointOne::REQ_KEY, &Req1 { a: 10, b: 100 }).unwrap(); - let act_out = out1.to_bytes(); - assert_eq!(act_out, exp_out); - - // The request is still awaiting a response - assert!(!rr_rt.is_finished()); - - // Feed a simulated response "from the wire" back to the - // awaiting request - const RESP_001: Resp1 = Resp1 { - c: [1, 2, 3, 4, 5, 6, 7, 8], - d: -10, - }; - srv.reply::(out1.header.seq_no, &RESP_001) - .await - .unwrap(); - - // Now wait for the request to complete - let end = rr_rt.await.unwrap().unwrap(); - - // We got the simulated value back - assert_eq!(end, RESP_001); - } +async fn end_to_end_force8() { + let (client_tx, server_rx) = mpsc::channel(16); + let (server_tx, client_rx) = mpsc::channel(16); + let topic_ctr = Arc::new(AtomicUsize::new(0)); + + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + let cwrx = ChannelWireRx::new(server_rx); + let cwtx = ChannelWireTx::new(server_tx); + + let kkind = VarKeyKind::Key8; + let mut server = new_server( + app, + Settings { + tx: cwtx, + rx: cwrx, + buf: 1024, + kkind, + }, + ); + tokio::task::spawn(async move { + server.run().await; + }); - // Now, use the *client* to close the connection - client.close(); + let cli = client::new_from_channels(client_tx, client_rx, VarSeqKind::Seq4); - // Give the clients some time to halt - yield_now().await; + let resp = cli.send_resp::(&AReq(42)).await.unwrap(); + assert_eq!(resp.0, 42); + let resp = cli.send_resp::(&BReq(1234)).await.unwrap(); + assert_eq!(resp.0, 1234); +} - // Our server channels should now be closed - the tasks hung up - assert!(srv.from_client.recv().await.is_none()); - assert!(srv.to_client.send(Vec::new()).await.is_err()); +#[test] +fn device_map() { + let topic_ctr = Arc::new(AtomicUsize::new(0)); + let app = SingleDispatcher::new( + TestContext { + ctr: Arc::new(AtomicUsize::new(0)), + topic_ctr: topic_ctr.clone(), + }, + ChannelWireSpawn {}, + ); + + println!("# SingleDispatcher"); + println!(); + + println!("## Types"); + println!(); + for ty in app.device_map.types { + let ty = OwnedNamedType::from(*ty); + println!("* {ty}"); + } - // Try again, but nothing should work because all the worker tasks 
just halted - { - let rr_rt = tokio::task::spawn({ - let client = client.clone(); - async move { - client - .send_resp::(&Req1 { a: 10, b: 100 }) - .await - } - }); + println!(); + println!("## Endpoints"); + println!(); + for ep in app.device_map.endpoints { + println!( + "* {} ({:016X} -> {:016X})", + ep.0, + u64::from_le_bytes(ep.1.to_bytes()), + u64::from_le_bytes(ep.2.to_bytes()), + ); + } - // As the wire, get the outgoing request - didn't happen - assert!(srv.recv_from_client().await.is_err()); + println!(); + println!("## Topics (In)"); + println!(); + for tp in app.device_map.topics_in { + println!( + "* {} <- ({:016X})", + tp.0, + u64::from_le_bytes(tp.1.to_bytes()), + ); + } - // Now wait for the request to complete - it failed - rr_rt.await.unwrap().unwrap_err(); + println!(); + println!("## Topics (Out)"); + println!(); + for tp in app.device_map.topics_out { + println!( + "* {} -> ({:016X})", + tp.0, + u64::from_le_bytes(tp.1.to_bytes()), + ); } + println!(); + println!("## Min Key Length"); + println!(); + println!("{:?}", app.device_map.min_key_len); + println!(); } diff --git a/source/postcard-rpc/Cargo.toml b/source/postcard-rpc/Cargo.toml index be20af9..704f41d 100644 --- a/source/postcard-rpc/Cargo.toml +++ b/source/postcard-rpc/Cargo.toml @@ -30,6 +30,7 @@ heapless = "0.8.0" postcard = { version = "1.0.8" } serde = { version = "1.0.192", default-features = false, features = ["derive"] } postcard-schema = { version = "0.1.0", features = ["derive"] } +paste = "1.0.15" # # std-only features @@ -124,7 +125,7 @@ version = "2.1" optional = true [dependencies.embassy-executor] -version = "0.5" +version = "0.6" optional = true [dependencies.futures-util] @@ -155,6 +156,7 @@ use-std = [ "dep:thiserror", "dep:tracing", "dep:trait-variant", + "dep:ssmarshal", ] # Cobs Serial support. diff --git a/source/postcard-rpc/src/accumulator.rs b/source/postcard-rpc/src/accumulator.rs index 421c885..493ba80 100644 --- a/source/postcard-rpc/src/accumulator.rs +++ b/source/postcard-rpc/src/accumulator.rs @@ -9,6 +9,7 @@ pub mod raw { use cobs::decode_in_place; + /// A header-aware COBS accumulator pub struct CobsAccumulator { buf: [u8; N], idx: usize, @@ -116,103 +117,3 @@ pub mod raw { } } } - -/// Accumulate and Dispatch -pub mod dispatch { - use super::raw::{CobsAccumulator, FeedResult}; - use crate::Dispatch; - - /// An error containing the handler-specific error, as well as the unprocessed - /// feed bytes - #[derive(Debug, PartialEq)] - pub struct FeedError<'a, E> { - pub err: E, - pub remainder: &'a [u8], - } - - /// A COBS-flavored version of [Dispatch] - /// - /// [Dispatch]: crate::Dispatch - /// - /// CobsDispatch is generic over four types: - /// - /// 1. The `Context`, which will be passed as a mutable reference - /// to each of the handlers. It typically should contain - /// whatever resource is necessary to send replies back to - /// the host. - /// 2. The `Error` type, which can be returned by handlers - /// 3. `N`, for the maximum number of handlers - /// 4. 
`BUF`, for the maximum number of bytes to buffer for a single - /// COBS-encoded message - pub struct CobsDispatch { - dispatch: Dispatch, - acc: CobsAccumulator, - } - - impl CobsDispatch { - /// Create a new `CobsDispatch` - pub fn new(c: Context) -> Self { - Self { - dispatch: Dispatch::new(c), - acc: CobsAccumulator::new(), - } - } - - /// Access the contained [Dispatch]` - pub fn dispatcher(&mut self) -> &mut Dispatch { - &mut self.dispatch - } - - /// Feed the given bytes into the dispatcher, attempting to dispatch each framed - /// message found. - /// - /// Line format errors, such as an overfull buffer or bad COBS frames will be - /// silently ignored. - /// - /// If an error in dispatching occurs, this function will return immediately, - /// yielding the error and the remaining unprocessed bytes for further processing. - pub fn feed<'a>( - &mut self, - buf: &'a [u8], - ) -> Result<(), FeedError<'a, crate::Error>> { - let mut window = buf; - let CobsDispatch { dispatch, acc } = self; - 'cobs: while !window.is_empty() { - window = match acc.feed(window) { - FeedResult::Consumed => break 'cobs, - FeedResult::OverFull(new_wind) => new_wind, - FeedResult::DeserError(new_wind) => new_wind, - FeedResult::Success { data, remaining } => { - dispatch.dispatch(data).map_err(|e| FeedError { - err: e, - remainder: remaining, - })?; - remaining - } - }; - } - - // We have dispatched all (if any) messages, and consumed the buffer - // without dispatch errors. - Ok(()) - } - - /// Similar to [CobsDispatch::feed], but the provided closure is called on each - /// error, allowing for handling. - /// - /// Useful if you need to do something blocking on each error case. - /// - /// If you need to handle the error in an async context, you may want to use - /// [CobsDispatch::feed] instead. - pub fn feed_with_err(&mut self, buf: &[u8], mut f: F) - where - F: FnMut(&mut Context, crate::Error), - { - let mut window = buf; - while let Err(FeedError { err, remainder }) = self.feed(window) { - f(&mut self.dispatch.context, err); - window = remainder; - } - } - } -} diff --git a/source/postcard-rpc/src/hash.rs b/source/postcard-rpc/src/hash.rs index 8d629df..3929de3 100644 --- a/source/postcard-rpc/src/hash.rs +++ b/source/postcard-rpc/src/hash.rs @@ -13,6 +13,7 @@ use postcard_schema::{ Schema, }; +/// A const compatible Fnv1a64 hasher pub struct Fnv1a64Hasher { state: u64, } @@ -22,10 +23,12 @@ impl Fnv1a64Hasher { const BASIS: u64 = 0xcbf2_9ce4_8422_2325; const PRIME: u64 = 0x0000_0100_0000_01b3; + /// Create a new hasher with the default basis as state contents pub fn new() -> Self { Self { state: Self::BASIS } } + /// Calculate the hash for each of the given data bytes pub fn update(&mut self, data: &[u8]) { for b in data { let ext = u64::from(*b); @@ -34,10 +37,12 @@ impl Fnv1a64Hasher { } } + /// Extract the current state for finalizing the hash pub fn digest(self) -> u64 { self.state } + /// Same as digest but as bytes pub fn digest_bytes(self) -> [u8; 8] { self.digest().to_le_bytes() } @@ -50,10 +55,12 @@ impl Default for Fnv1a64Hasher { } pub mod fnv1a64 { + //! 
Const and no-std helper methods and types for perfoming hash calculation use postcard_schema::schema::DataModelVariant; use super::*; + /// Calculate the Key hash for the given path and type T pub const fn hash_ty_path(path: &str) -> [u8; 8] { let schema = T::SCHEMA; let state = hash_update_str(Fnv1a64Hasher::BASIS, path); @@ -205,7 +212,7 @@ pub mod fnv1a64 { let mut state = hash_update(state, &[0xC7]); let mut idx = 0; while idx < nts.len() { - state = hash_named_type(state, &nts[idx]); + state = hash_named_type(state, nts[idx]); idx += 1; } state @@ -214,7 +221,7 @@ pub mod fnv1a64 { let mut state = hash_update(state, &[0x67]); let mut idx = 0; while idx < nvs.len() { - state = hash_named_value(state, &nvs[idx]); + state = hash_named_value(state, nvs[idx]); idx += 1; } state @@ -230,6 +237,8 @@ pub mod fnv1a64 { #[cfg(feature = "use-std")] pub mod fnv1a64_owned { + //! Heapful helpers and versions of hashing for use on `std` targets + use postcard_schema::schema::owned::{ OwnedDataModelType, OwnedDataModelVariant, OwnedNamedType, OwnedNamedValue, OwnedNamedVariant, @@ -238,6 +247,7 @@ pub mod fnv1a64_owned { use super::fnv1a64::*; use super::*; + /// Calculate the Key hash for the given path and OwnedNamedType pub fn hash_ty_path_owned(path: &str, nt: &OwnedNamedType) -> [u8; 8] { let state = hash_update_str(Fnv1a64Hasher::BASIS, path); hash_named_type_owned(state, nt).to_le_bytes() diff --git a/source/postcard-rpc/src/header.rs b/source/postcard-rpc/src/header.rs new file mode 100644 index 0000000..7e34a09 --- /dev/null +++ b/source/postcard-rpc/src/header.rs @@ -0,0 +1,630 @@ +//! # Postcard-RPC Header Format +//! +//! Postcard-RPC's header is made up of three main parts: +//! +//! 1. A one-byte discriminant +//! 2. A 1-8 byte "Key" +//! 3. A 1-4 byte "Sequence Number" +//! +//! The Postcard-RPC Header is NOT encoded using `postcard`'s wire format. +//! +//! ## Discriminant +//! +//! The discriminant field is always one byte, and consists of three subfields +//! in the form `0bNNMM_VVVV`. +//! +//! * The two msbits are "key length", where the two N length bits represent +//! a key length of 2^N. All values are valid. +//! * The next two msbits are "sequence number length", where the two M length +//! bits represent a sequence number length of 2^M. Values 00, 01, and 10 +//! are valid. +//! * The four lsbits are "protocol version", where the four V version bits +//! represent an unsigned 4-bit number. Currently only 0000 is a valid value. +//! +//! ## Key +//! +//! The Key consists of an fnv1a hash of the path string and schema of the +//! contained message. These are calculated using the [`hash` module](crate::hash), +//! and are natively calculated as an 8-byte hash. +//! +//! Keys may be encoded with variable fidelity on the wire, as follows: +//! +//! * For 8-byte keys, all key bytes appear in the form `[A, B, C, D, E, F, G, H]`. +//! * For 4-byte keys, the 8-byte form is compressed as `[A^B, C^D, E^F, G^H]`. +//! * For 2-byte keys, the 8-byte form is compressed as `[A^B^C^D, E^F^G^H]`. +//! * For 1-byte keys, the 8-byte form is compressed as `A^B^C^D^E^F^G^H`. +//! +//! The length of the Key is determined by the two `NN` bits in the discriminant. +//! +//! The length of the key is usually chosen by the **Server**, as the server is +//! able to calculate the minimum number of bits necessary to avoid collisions. +//! +//! When Clients receive a server response, they shall note the Key length used, +//! and match that for all subsequent messages. 
When Clients make first connection, +//! they shall use the 8-byte form by default. +//! +//! ## Sequence Number +//! +//! The Sequence Number is an unsigned integer used to match request-response pairs, +//! and disambiguate between multiple in-flight messages. +//! +//! Sequence Numbers may be encoded with variable fidelity on the wire, always in +//! little-endian order, of 1, 2, or 4 bytes. +//! +//! The length of the Sequence Number is determined by the two `MM` bits in the +//! discriminant. +//! +//! The length of the key is chosen by the "originator" of the message. For Endpoints +//! this is the client making the request. For Topics, this is the device sending the +//! topic message. + +use crate::{Key, Key1, Key2, Key4}; + +////////////////////////////////////////////////////////////////////////////// +// VARKEY +////////////////////////////////////////////////////////////////////////////// + +/// A variably sized header Key +/// +/// NOTE: We DO NOT impl Serialize/Deserialize for this type because +/// we use non-postcard-compatible format (externally tagged) on the wire. +/// +/// NOTE: VarKey implements `PartialEq` by reducing two VarKeys down to the +/// smaller of the two forms, and checking whether they match. This allows +/// a key in 8-byte form to be compared to a key in 1, 2, or 4-byte form +/// for equality. +#[derive(Debug, Copy, Clone)] +pub enum VarKey { + /// A one byte key + Key1(Key1), + /// A two byte key + Key2(Key2), + /// A four byte key + Key4(Key4), + /// An eight byte key + Key8(Key), +} + +/// We implement PartialEq MANUALLY for VarKey, because keys of different lengths SHOULD compare +/// as equal. +impl PartialEq for VarKey { + fn eq(&self, other: &Self) -> bool { + // figure out the minimum length + match (self, other) { + // Matching kinds + (VarKey::Key1(self_key), VarKey::Key1(other_key)) => self_key.0.eq(&other_key.0), + (VarKey::Key2(self_key), VarKey::Key2(other_key)) => self_key.0.eq(&other_key.0), + (VarKey::Key4(self_key), VarKey::Key4(other_key)) => self_key.0.eq(&other_key.0), + (VarKey::Key8(self_key), VarKey::Key8(other_key)) => self_key.0.eq(&other_key.0), + + // For the rest of the options, degrade the LARGER key to the SMALLER key, and then + // check for equivalence after that. 
+ (VarKey::Key1(this), VarKey::Key2(other)) => { + let other = Key1::from_key2(*other); + this.0.eq(&other.0) + } + (VarKey::Key1(this), VarKey::Key4(other)) => { + let other = Key1::from_key4(*other); + this.0.eq(&other.0) + } + (VarKey::Key1(this), VarKey::Key8(other)) => { + let other = Key1::from_key8(*other); + this.0.eq(&other.0) + } + (VarKey::Key2(this), VarKey::Key1(other)) => { + let this = Key1::from_key2(*this); + this.0.eq(&other.0) + } + (VarKey::Key2(this), VarKey::Key4(other)) => { + let other = Key2::from_key4(*other); + this.0.eq(&other.0) + } + (VarKey::Key2(this), VarKey::Key8(other)) => { + let other = Key2::from_key8(*other); + this.0.eq(&other.0) + } + (VarKey::Key4(this), VarKey::Key1(other)) => { + let this = Key1::from_key4(*this); + this.0.eq(&other.0) + } + (VarKey::Key4(this), VarKey::Key2(other)) => { + let this = Key2::from_key4(*this); + this.0.eq(&other.0) + } + (VarKey::Key4(this), VarKey::Key8(other)) => { + let other = Key4::from_key8(*other); + this.0.eq(&other.0) + } + (VarKey::Key8(this), VarKey::Key1(other)) => { + let this = Key1::from_key8(*this); + this.0.eq(&other.0) + } + (VarKey::Key8(this), VarKey::Key2(other)) => { + let this = Key2::from_key8(*this); + this.0.eq(&other.0) + } + (VarKey::Key8(this), VarKey::Key4(other)) => { + let this = Key4::from_key8(*this); + this.0.eq(&other.0) + } + } + } +} + +impl VarKey { + /// Keys can not be reaised, but instead only shrunk. + /// + /// This method will shrink to the requested length if that length is + /// smaller than the current representation, or if the requested length + /// is the same or larger than the current representation, it will be + /// kept unchanged + pub fn shrink_to(&mut self, kind: VarKeyKind) { + match (&self, kind) { + (VarKey::Key1(_), _) => { + // Nothing to shrink + } + (VarKey::Key2(key2), VarKeyKind::Key1) => { + *self = VarKey::Key1(Key1::from_key2(*key2)); + } + (VarKey::Key2(_), _) => { + // We are already as small or smaller than the request + } + (VarKey::Key4(key4), VarKeyKind::Key1) => { + *self = VarKey::Key1(Key1::from_key4(*key4)); + } + (VarKey::Key4(key4), VarKeyKind::Key2) => { + *self = VarKey::Key2(Key2::from_key4(*key4)); + } + (VarKey::Key4(_), _) => { + // We are already as small or smaller than the request + } + (VarKey::Key8(key), VarKeyKind::Key1) => { + *self = VarKey::Key1(Key1::from_key8(*key)); + } + (VarKey::Key8(key), VarKeyKind::Key2) => { + *self = VarKey::Key2(Key2::from_key8(*key)); + } + (VarKey::Key8(key), VarKeyKind::Key4) => { + *self = VarKey::Key4(Key4::from_key8(*key)); + } + (VarKey::Key8(_), VarKeyKind::Key8) => { + // Nothing to do + } + } + } + + /// The current kind/length of the key + pub fn kind(&self) -> VarKeyKind { + match self { + VarKey::Key1(_) => VarKeyKind::Key1, + VarKey::Key2(_) => VarKeyKind::Key2, + VarKey::Key4(_) => VarKeyKind::Key4, + VarKey::Key8(_) => VarKeyKind::Key8, + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// VARKEYKIND +////////////////////////////////////////////////////////////////////////////// + +/// The kind or length of the variably sized header Key +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum VarKeyKind { + /// A one byte key + Key1, + /// A two byte key + Key2, + /// A four byte key + Key4, + /// An eight byte key + Key8, +} + +////////////////////////////////////////////////////////////////////////////// +// VARSEQ +////////////////////////////////////////////////////////////////////////////// + +/// A variably sized sequence number 
+/// +/// NOTE: We use the standard PartialEq here, as we DO NOT treat sequence +/// numbers of different lengths as equivalent. +/// +/// We DO NOT impl Serialize/Deserialize for this type because we use +/// non-postcard-compatible format (externally tagged) +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum VarSeq { + /// A one byte sequence number + Seq1(u8), + /// A two byte sequence number + Seq2(u16), + /// A four byte sequence number + Seq4(u32), +} + +impl From for VarSeq { + fn from(value: u8) -> Self { + Self::Seq1(value) + } +} + +impl From for VarSeq { + fn from(value: u16) -> Self { + Self::Seq2(value) + } +} + +impl From for VarSeq { + fn from(value: u32) -> Self { + Self::Seq4(value) + } +} + +impl VarSeq { + /// Resize (up or down) to the requested kind. + /// + /// When increasing size, the number is left-extended, e.g. `0x42u8` becomes + /// `0x0000_0042u32` when resizing 1 -> 4. + /// + /// When decreasing size, the number is truncated, e.g. `0xABCD_EF12u32` + /// becomes `0x12u8` when resizing 4 -> 1. + pub fn resize(&mut self, kind: VarSeqKind) { + match (&self, kind) { + (VarSeq::Seq1(_), VarSeqKind::Seq1) => {} + (VarSeq::Seq2(_), VarSeqKind::Seq2) => {} + (VarSeq::Seq4(_), VarSeqKind::Seq4) => {} + (VarSeq::Seq1(s), VarSeqKind::Seq2) => { + *self = VarSeq::Seq2((*s).into()); + } + (VarSeq::Seq1(s), VarSeqKind::Seq4) => { + *self = VarSeq::Seq4((*s).into()); + } + (VarSeq::Seq2(s), VarSeqKind::Seq1) => { + *self = VarSeq::Seq1((*s) as u8); + } + (VarSeq::Seq2(s), VarSeqKind::Seq4) => { + *self = VarSeq::Seq4((*s).into()); + } + (VarSeq::Seq4(s), VarSeqKind::Seq1) => { + *self = VarSeq::Seq1((*s) as u8); + } + (VarSeq::Seq4(s), VarSeqKind::Seq2) => { + *self = VarSeq::Seq2((*s) as u16); + } + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// VARSEQKIND +////////////////////////////////////////////////////////////////////////////// + +/// The Kind or Length of a VarSeq +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum VarSeqKind { + /// A one byte sequence number + Seq1, + /// A two byte sequence number + Seq2, + /// A four byte sequence number + Seq4, +} + +////////////////////////////////////////////////////////////////////////////// +// VARHEADER +////////////////////////////////////////////////////////////////////////////// + +/// A variably sized message header +/// +/// NOTE: We use the standard PartialEq here as it will do the correct things. +/// +/// Sequence numbers must be EXACTLY the same, and keys must be equivalent when +/// degraded to the smaller of the two. 
+/// +/// We DO NOT impl Serialize/Deserialize for this type because we use +/// non-postcard-compatible format (externally tagged) +#[derive(Debug, PartialEq, Clone, Copy)] +pub struct VarHeader { + /// The variably sized Key + pub key: VarKey, + /// The variably sized Sequence Number + pub seq_no: VarSeq, +} + +#[allow(clippy::unusual_byte_groupings)] +impl VarHeader { + /// Bits for a key of ONE byte + pub const KEY_ONE_BITS: u8 = 0b00_00_0000; + /// Bits for a key of TWO bytes + pub const KEY_TWO_BITS: u8 = 0b01_00_0000; + /// Bits for a key of FOUR bytes + pub const KEY_FOUR_BITS: u8 = 0b10_00_0000; + /// Bits for a key of EIGHT bytes + pub const KEY_EIGHT_BITS: u8 = 0b11_00_0000; + /// Mask bits + pub const KEY_MASK_BITS: u8 = 0b11_00_0000; + + /// Bits for a sequence number of ONE bytes + pub const SEQ_ONE_BITS: u8 = 0b00_00_0000; + /// Bits for a sequence number of TWO bytes + pub const SEQ_TWO_BITS: u8 = 0b00_01_0000; + /// Bits for a sequence number of FOUR bytes + pub const SEQ_FOUR_BITS: u8 = 0b00_10_0000; + /// Mask bits + pub const SEQ_MASK_BITS: u8 = 0b00_11_0000; + + /// Bits for a version number of ZERO + pub const VER_ZERO_BITS: u8 = 0b00_00_0000; + /// Mask bits + pub const VER_MASK_BITS: u8 = 0b00_00_1111; + + /// Encode the header to a Vec of bytes + #[cfg(feature = "use-std")] + pub fn write_to_vec(&self) -> Vec { + // start with placeholder byte + let mut out = vec![0u8; 1]; + let mut disc_out: u8; + match &self.key { + VarKey::Key1(k) => { + disc_out = Self::KEY_ONE_BITS; + out.push(k.0); + } + VarKey::Key2(k) => { + disc_out = Self::KEY_TWO_BITS; + out.extend_from_slice(&k.0); + } + VarKey::Key4(k) => { + disc_out = Self::KEY_FOUR_BITS; + out.extend_from_slice(&k.0); + } + VarKey::Key8(k) => { + disc_out = Self::KEY_EIGHT_BITS; + out.extend_from_slice(&k.0); + } + } + match &self.seq_no { + VarSeq::Seq1(s) => { + disc_out |= Self::SEQ_ONE_BITS; + out.push(*s); + } + VarSeq::Seq2(s) => { + disc_out |= Self::SEQ_TWO_BITS; + out.extend_from_slice(&s.to_le_bytes()); + } + VarSeq::Seq4(s) => { + disc_out |= Self::SEQ_FOUR_BITS; + out.extend_from_slice(&s.to_le_bytes()); + } + } + // push discriminant to the end... + out.push(disc_out); + // ...and swap-remove the placeholder byte, moving the discriminant to the front + out.swap_remove(0); + out + } + + /// Attempt to write the header to the given slice + /// + /// If the slice is large enough, a `Some` will be returned with the bytes used + /// to encode the header, as well as the remaining unused bytes. + /// + /// If the slice is not large enough, a `None` will be returned, and some bytes + /// of the buffer may have been modified. 
+ pub fn write_to_slice<'a>(&self, buf: &'a mut [u8]) -> Option<(&'a mut [u8], &'a mut [u8])> { + let (disc_out, mut remain) = buf.split_first_mut()?; + let mut used = 1; + + match &self.key { + VarKey::Key1(k) => { + *disc_out = Self::KEY_ONE_BITS; + let (keybs, remain2) = remain.split_first_mut()?; + *keybs = k.0; + remain = remain2; + used += 1; + } + VarKey::Key2(k) => { + *disc_out = Self::KEY_TWO_BITS; + let (keybs, remain2) = remain.split_at_mut_checked(2)?; + keybs.copy_from_slice(&k.0); + remain = remain2; + used += 2; + } + VarKey::Key4(k) => { + *disc_out = Self::KEY_FOUR_BITS; + let (keybs, remain2) = remain.split_at_mut_checked(4)?; + keybs.copy_from_slice(&k.0); + remain = remain2; + used += 4; + } + VarKey::Key8(k) => { + *disc_out = Self::KEY_EIGHT_BITS; + let (keybs, remain2) = remain.split_at_mut_checked(8)?; + keybs.copy_from_slice(&k.0); + remain = remain2; + used += 8; + } + } + match &self.seq_no { + VarSeq::Seq1(s) => { + *disc_out |= Self::SEQ_ONE_BITS; + let (seqbs, _) = remain.split_first_mut()?; + *seqbs = *s; + used += 1; + } + VarSeq::Seq2(s) => { + *disc_out |= Self::SEQ_TWO_BITS; + let (seqbs, _) = remain.split_at_mut_checked(2)?; + seqbs.copy_from_slice(&s.to_le_bytes()); + used += 2; + } + VarSeq::Seq4(s) => { + *disc_out |= Self::SEQ_FOUR_BITS; + let (seqbs, _) = remain.split_at_mut_checked(4)?; + seqbs.copy_from_slice(&s.to_le_bytes()); + used += 4; + } + } + Some(buf.split_at_mut(used)) + } + + /// Attempt to decode a header from the given bytes. + /// + /// If a well-formed header was found, a `Some` will be returned with the + /// decoded header and unused remaining bytes. + /// + /// If no well-formed header was found, a `None` will be returned. + pub fn take_from_slice(buf: &[u8]) -> Option<(Self, &[u8])> { + let (disc, mut remain) = buf.split_first()?; + + // For now, we only trust version zero + if (*disc & Self::VER_MASK_BITS) != Self::VER_ZERO_BITS { + return None; + } + + let key = match (*disc) & Self::KEY_MASK_BITS { + Self::KEY_ONE_BITS => { + let (keybs, remain2) = remain.split_first()?; + remain = remain2; + VarKey::Key1(Key1(*keybs)) + } + Self::KEY_TWO_BITS => { + let (keybs, remain2) = remain.split_at_checked(2)?; + remain = remain2; + let mut buf = [0u8; 2]; + buf.copy_from_slice(keybs); + VarKey::Key2(Key2(buf)) + } + Self::KEY_FOUR_BITS => { + let (keybs, remain2) = remain.split_at_checked(4)?; + remain = remain2; + let mut buf = [0u8; 4]; + buf.copy_from_slice(keybs); + VarKey::Key4(Key4(buf)) + } + Self::KEY_EIGHT_BITS => { + let (keybs, remain2) = remain.split_at_checked(8)?; + remain = remain2; + let mut buf = [0u8; 8]; + buf.copy_from_slice(keybs); + VarKey::Key8(Key(buf)) + } + // Impossible: all bits covered + _ => unreachable!(), + }; + let seq_no = match (*disc) & Self::SEQ_MASK_BITS { + Self::SEQ_ONE_BITS => { + let (seqbs, remain3) = remain.split_first()?; + remain = remain3; + VarSeq::Seq1(*seqbs) + } + Self::SEQ_TWO_BITS => { + let (seqbs, remain3) = remain.split_at_checked(2)?; + remain = remain3; + let mut buf = [0u8; 2]; + buf.copy_from_slice(seqbs); + VarSeq::Seq2(u16::from_le_bytes(buf)) + } + Self::SEQ_FOUR_BITS => { + let (seqbs, remain3) = remain.split_at_checked(4)?; + remain = remain3; + let mut buf = [0u8; 4]; + buf.copy_from_slice(seqbs); + VarSeq::Seq4(u32::from_le_bytes(buf)) + } + // Possible (could be 0b11), is invalid + _ => return None, + }; + Some((Self { key, seq_no }, remain)) + } +} + +#[cfg(test)] +mod test { + use super::{VarHeader, VarKey, VarSeq}; + use crate::{Key, Key1, Key2}; + + #[test] + fn 
wire_format() { + let checks: &[(_, &[u8])] = &[ + ( + VarHeader { + key: VarKey::Key1(Key1(0)), + seq_no: VarSeq::Seq1(0x00), + }, + &[ + VarHeader::KEY_ONE_BITS | VarHeader::SEQ_ONE_BITS, + 0x00, + 0x00, + ], + ), + ( + VarHeader { + key: VarKey::Key1(Key1(1)), + seq_no: VarSeq::Seq1(0x02), + }, + &[ + VarHeader::KEY_ONE_BITS | VarHeader::SEQ_ONE_BITS, + 0x01, + 0x02, + ], + ), + ( + VarHeader { + key: VarKey::Key2(Key2([0x42, 0xAF])), + seq_no: VarSeq::Seq1(0x02), + }, + &[ + VarHeader::KEY_TWO_BITS | VarHeader::SEQ_ONE_BITS, + 0x42, + 0xAF, + 0x02, + ], + ), + ( + VarHeader { + key: VarKey::Key1(Key1(1)), + seq_no: VarSeq::Seq2(0x42_AF), + }, + &[ + VarHeader::KEY_ONE_BITS | VarHeader::SEQ_TWO_BITS, + 0x01, + 0xAF, + 0x42, + ], + ), + ( + VarHeader { + key: VarKey::Key8(Key([0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89])), + seq_no: VarSeq::Seq4(0x42_AF_AA_BB), + }, + &[ + VarHeader::KEY_EIGHT_BITS | VarHeader::SEQ_FOUR_BITS, + 0x12, + 0x23, + 0x34, + 0x45, + 0x56, + 0x67, + 0x78, + 0x89, + 0xBB, + 0xAA, + 0xAF, + 0x42, + ], + ), + ]; + + let mut buf = [0u8; 1 + 8 + 4]; + + for (val, exp) in checks.iter() { + let (used, _) = val.write_to_slice(&mut buf).unwrap(); + assert_eq!(used, *exp); + let v = val.write_to_vec(); + assert_eq!(&v, *exp); + let (deser, remain) = VarHeader::take_from_slice(used).unwrap(); + assert!(remain.is_empty()); + assert_eq!(val, &deser); + } + } +} diff --git a/source/postcard-rpc/src/headered.rs b/source/postcard-rpc/src/headered.rs deleted file mode 100644 index 0521e86..0000000 --- a/source/postcard-rpc/src/headered.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! Helper functions for encoding/decoding messages with postcard-rpc headers - -use crate::{Key, WireHeader}; -use postcard::{ - ser_flavors::{Cobs, Flavor as SerFlavor, Slice}, - Serializer, -}; -use postcard_schema::Schema; - -use serde::Serialize; - -struct Headered { - flavor: B, -} - -impl Headered { - fn try_new_keyed(b: B, seq_no: u32, key: Key) -> Result { - let mut serializer = Serializer { output: b }; - let hdr = WireHeader { key, seq_no }; - hdr.serialize(&mut serializer)?; - Ok(Self { - flavor: serializer.output, - }) - } - - fn try_new(b: B, seq_no: u32, path: &str) -> Result { - let key = Key::for_path::(path); - Self::try_new_keyed(b, seq_no, key) - } -} - -impl SerFlavor for Headered { - type Output = B::Output; - - #[inline] - fn try_push(&mut self, data: u8) -> postcard::Result<()> { - self.flavor.try_push(data) - } - - #[inline] - fn finalize(self) -> postcard::Result { - self.flavor.finalize() - } - - #[inline] - fn try_extend(&mut self, data: &[u8]) -> postcard::Result<()> { - self.flavor.try_extend(data) - } -} - -/// Serialize to a slice with a prepended header -/// -/// WARNING: This rehashes the schema! Prefer [to_slice_keyed]! -pub fn to_slice<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - path: &str, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new::(Slice::new(buf), seq_no, path)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a slice with a prepended header -pub fn to_slice_keyed<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - key: Key, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new_keyed(Slice::new(buf), seq_no, key)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a COBS-encoded slice with a prepended header -/// -/// WARNING: This rehashes the schema! Prefer [to_slice_cobs_keyed]! 
-pub fn to_slice_cobs<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - path: &str, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new::(Cobs::try_new(Slice::new(buf))?, seq_no, path)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a COBS-encoded slice with a prepended header -pub fn to_slice_cobs_keyed<'a, T: Serialize + ?Sized + Schema>( - seq_no: u32, - key: Key, - value: &T, - buf: &'a mut [u8], -) -> Result<&'a mut [u8], postcard::Error> { - let flavor = Headered::try_new_keyed(Cobs::try_new(Slice::new(buf))?, seq_no, key)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a Vec with a prepended header -/// -/// WARNING: This rehashes the schema! Prefer [to_slice_keyed]! -#[cfg(feature = "use-std")] -pub fn to_stdvec( - seq_no: u32, - path: &str, - value: &T, -) -> Result, postcard::Error> { - let flavor = Headered::try_new::(postcard::ser_flavors::StdVec::new(), seq_no, path)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Serialize to a Vec with a prepended header -#[cfg(feature = "use-std")] -pub fn to_stdvec_keyed( - seq_no: u32, - key: Key, - value: &T, -) -> Result, postcard::Error> { - let flavor = Headered::try_new_keyed(postcard::ser_flavors::StdVec::new(), seq_no, key)?; - postcard::serialize_with_flavor(value, flavor) -} - -/// Extract the header from a slice of bytes -pub fn extract_header_from_bytes(slice: &[u8]) -> Result<(WireHeader, &[u8]), postcard::Error> { - postcard::take_from_bytes::(slice) -} diff --git a/source/postcard-rpc/src/host_client/mod.rs b/source/postcard-rpc/src/host_client/mod.rs index aafec4d..f3d435e 100644 --- a/source/postcard-rpc/src/host_client/mod.rs +++ b/source/postcard-rpc/src/host_client/mod.rs @@ -8,7 +8,7 @@ use std::{ marker::PhantomData, sync::{ atomic::{AtomicU32, Ordering}, - Arc, + Arc, RwLock, }, }; @@ -20,10 +20,17 @@ use postcard_schema::Schema; use serde::{de::DeserializeOwned, Serialize}; use tokio::{ select, - sync::mpsc::{Receiver, Sender}, + sync::{ + mpsc::{Receiver, Sender}, + Mutex, + }, }; +use util::Subscriptions; -use crate::{Endpoint, Key, Topic, WireHeader}; +use crate::{ + header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind}, + Endpoint, Key, Topic, +}; use self::util::Stopper; @@ -38,6 +45,9 @@ pub mod webusb; pub(crate) mod util; +#[cfg(feature = "test-utils")] +pub mod test_channels; + /// Host Error Kind #[derive(Debug, PartialEq)] pub enum HostErr { @@ -76,7 +86,9 @@ impl From for HostErr { /// be returned to the caller. #[cfg(target_family = "wasm")] pub trait WireTx: 'static { - type Error: std::error::Error; // or std? + /// Transmit error type + type Error: std::error::Error; + /// Send a single frame fn send(&mut self, data: Vec) -> impl Future>; } @@ -89,7 +101,9 @@ pub trait WireTx: 'static { /// be returned to the caller. #[cfg(target_family = "wasm")] pub trait WireRx: 'static { + /// Receive error type type Error: std::error::Error; // or std? + /// Receive a single frame fn receive(&mut self) -> impl Future, Self::Error>>; } @@ -98,6 +112,7 @@ pub trait WireRx: 'static { /// Should be suitable for spawning a task in the host executor. #[cfg(target_family = "wasm")] pub trait WireSpawn: 'static { + /// Spawn a task fn spawn(&mut self, fut: impl Future + 'static); } @@ -113,7 +128,9 @@ pub trait WireSpawn: 'static { /// be returned to the caller. #[cfg(not(target_family = "wasm"))] pub trait WireTx: Send + 'static { - type Error: std::error::Error; // or std? 
+ /// Transmit error type + type Error: std::error::Error; + /// Send a single frame fn send(&mut self, data: Vec) -> impl Future> + Send; } @@ -126,7 +143,9 @@ pub trait WireTx: Send + 'static { /// be returned to the caller. #[cfg(not(target_family = "wasm"))] pub trait WireRx: Send + 'static { - type Error: std::error::Error; // or std? + /// Receive error type + type Error: std::error::Error; + /// Receive a single frame fn receive(&mut self) -> impl Future, Self::Error>> + Send; } @@ -135,6 +154,7 @@ pub trait WireRx: Send + 'static { /// Should be suitable for spawning a task in the host executor. #[cfg(not(target_family = "wasm"))] pub trait WireSpawn: 'static { + /// Spawn a task fn spawn(&mut self, fut: impl Future + Send + 'static); } @@ -153,9 +173,10 @@ pub trait WireSpawn: 'static { pub struct HostClient { ctx: Arc, out: Sender, - subber: Sender, + subscriptions: Arc>, err_key: Key, stopper: Stopper, + seq_kind: VarSeqKind, _pd: PhantomData WireErr>, } @@ -164,23 +185,16 @@ impl HostClient where WireErr: DeserializeOwned + Schema, { - /// Create a new manually implemented [HostClient]. - /// - /// This is now deprecated, it is recommended to use [`HostClient::new_with_wire`] instead. - #[deprecated = "HostClient::new_manual will become private in the future, use HostClient::new_with_wire instead"] - pub fn new_manual(err_uri_path: &str, outgoing_depth: usize) -> (Self, WireContext) { - Self::new_manual_priv(err_uri_path, outgoing_depth) - } - /// Private method for creating internal context pub(crate) fn new_manual_priv( err_uri_path: &str, outgoing_depth: usize, + seq_kind: VarSeqKind, ) -> (Self, WireContext) { let (tx_pc, rx_pc) = tokio::sync::mpsc::channel(outgoing_depth); - let (tx_si, rx_si) = tokio::sync::mpsc::channel(outgoing_depth); let ctx = Arc::new(HostContext { + kkind: RwLock::new(VarKeyKind::Key8), map: WaitMap::new(), seq: AtomicU32::new(0), }); @@ -192,14 +206,14 @@ where out: tx_pc, err_key, _pd: PhantomData, - subber: tx_si.clone(), + subscriptions: Arc::new(Mutex::new(Subscriptions::default())), stopper: Stopper::new(), + seq_kind, }; let wire = WireContext { outgoing: rx_pc, incoming: ctx, - new_subs: rx_si, }; (me, wire) @@ -224,11 +238,14 @@ where E::Response: DeserializeOwned + Schema, { let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed); + let msg = postcard::to_stdvec(&t).expect("Allocations should not ever fail"); let frame = RpcFrame { - header: WireHeader { - key: E::REQ_KEY, - seq_no, + // NOTE: send_resp_raw automatically shrinks down key and sequence + // kinds to the appropriate amount + header: VarHeader { + key: VarKey::Key8(E::REQ_KEY), + seq_no: VarSeq::Seq4(seq_no), }, body: msg, }; @@ -237,33 +254,47 @@ where Ok(r) } + /// Perform an endpoint request/response,but without handling the + /// Ser/De automatically pub async fn send_resp_raw( &self, - rqst: RpcFrame, + mut rqst: RpcFrame, resp_key: Key, ) -> Result> { let cancel_fut = self.stopper.wait_stopped(); + let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap(); + rqst.header.key.shrink_to(kkind); + rqst.header.seq_no.resize(self.seq_kind); + let mut resp_key = VarKey::Key8(resp_key); + let mut err_key = VarKey::Key8(self.err_key); + resp_key.shrink_to(kkind); + err_key.shrink_to(kkind); // TODO: Do I need something like a .subscribe method to ensure this is enqueued? 
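        // NOTE on the width negotiation happening here: `kkind` starts out as
        // `VarKeyKind::Key8` (see `new_manual_priv`) and is updated in the response
        // arms below whenever the server replies with a different key width. Shrinking
        // the outgoing request key *and* the response/error keys we wait on to that
        // same width keeps the wait-map lookups aligned with the headers the server
        // will actually send back.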
- let ok_resp = self.ctx.map.wait(WireHeader { + let ok_resp = self.ctx.map.wait(VarHeader { seq_no: rqst.header.seq_no, key: resp_key, }); - let err_resp = self.ctx.map.wait(WireHeader { + let err_resp = self.ctx.map.wait(VarHeader { seq_no: rqst.header.seq_no, - key: self.err_key, + key: err_key, }); - let seq_no = rqst.header.seq_no; self.out.send(rqst).await.map_err(|_| HostErr::Closed)?; select! { _c = cancel_fut => Err(HostErr::Closed), o = ok_resp => { - let resp = o?; - Ok(RpcFrame { header: WireHeader { key: resp_key, seq_no }, body: resp }) + let (hdr, resp) = o?; + if hdr.key.kind() != kkind { + *self.ctx.kkind.write().unwrap() = hdr.key.kind(); + } + Ok(RpcFrame { header: hdr, body: resp }) }, e = err_resp => { - let resp = e?; + let (hdr, resp) = e?; + if hdr.key.kind() != kkind { + *self.ctx.kkind.write().unwrap() = hdr.key.kind(); + } let r = postcard::from_bytes::(&resp)?; Err(HostErr::Wire(r)) }, @@ -274,14 +305,14 @@ where /// /// There is no feedback if the server received our message. If the I/O worker is /// closed, an error is returned. - pub async fn publish(&self, seq_no: u32, msg: &T::Message) -> Result<(), IoClosed> + pub async fn publish(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), IoClosed> where T::Message: Serialize, { let smsg = postcard::to_stdvec(msg).expect("alloc should never fail"); let frame = RpcFrame { - header: WireHeader { - key: T::TOPIC_KEY, + header: VarHeader { + key: VarKey::Key8(T::TOPIC_KEY), seq_no, }, body: smsg, @@ -289,7 +320,12 @@ where self.publish_raw(frame).await } - pub async fn publish_raw(&self, frame: RpcFrame) -> Result<(), IoClosed> { + /// Publish the given raw frame + pub async fn publish_raw(&self, mut frame: RpcFrame) -> Result<(), IoClosed> { + let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap(); + frame.header.key.shrink_to(kkind); + frame.header.seq_no.resize(self.seq_kind); + let cancel_fut = self.stopper.wait_stopped(); let operate_fut = self.out.send(frame); @@ -330,19 +366,25 @@ where T::Message: DeserializeOwned, { let (tx, rx) = tokio::sync::mpsc::channel(depth); - self.subber - .send(SubInfo { - key: T::TOPIC_KEY, - tx, - }) - .await - .map_err(|_| IoClosed)?; + { + let mut guard = self.subscriptions.lock().await; + if guard.stopped { + return Err(IoClosed); + } + if let Some(entry) = guard.list.iter_mut().find(|(k, _)| *k == T::TOPIC_KEY) { + tracing::warn!("replacing subscription for topic path '{}'", T::PATH); + entry.1 = tx; + } else { + guard.list.push((T::TOPIC_KEY, tx)); + } + } Ok(Subscription { rx, _pd: PhantomData, }) } + /// Subscribe to the given [`Key`], without automatically handling deserialization pub async fn subscribe_raw(&self, key: Key, depth: usize) -> Result { let cancel_fut = self.stopper.wait_stopped(); let operate_fut = self.subscribe_inner_raw(key, depth); @@ -359,10 +401,18 @@ where depth: usize, ) -> Result { let (tx, rx) = tokio::sync::mpsc::channel(depth); - self.subber - .send(SubInfo { key, tx }) - .await - .map_err(|_| IoClosed)?; + { + let mut guard = self.subscriptions.lock().await; + if guard.stopped { + return Err(IoClosed); + } + if let Some(entry) = guard.list.iter_mut().find(|(k, _)| *k == key) { + tracing::warn!("replacing subscription for raw topic key '{:?}'", key); + entry.1 = tx; + } else { + guard.list.push((key, tx)); + } + } Ok(RawSubscription { rx }) } @@ -388,6 +438,8 @@ where } } +/// Like Subscription, but receives Raw frames that are not +/// automatically deserialized pub struct RawSubscription { rx: Receiver, } @@ -432,18 +484,13 @@ impl Clone for 
HostClient { out: self.out.clone(), err_key: self.err_key, _pd: PhantomData, - subber: self.subber.clone(), + subscriptions: self.subscriptions.clone(), stopper: self.stopper.clone(), + seq_kind: self.seq_kind, } } } -/// A new subscription that should be accounted for -pub struct SubInfo { - pub key: Key, - pub tx: Sender, -} - /// Items necessary for implementing a custom I/O Task pub struct WireContext { /// This is a stream of frames that should be placed on the @@ -452,14 +499,12 @@ pub struct WireContext { /// This shared information contains the WaitMap used for replying to /// open requests. pub incoming: Arc, - /// This is a stream of new subscriptions that should be tracked - pub new_subs: Receiver, } /// A single postcard-rpc frame pub struct RpcFrame { /// The wire header - pub header: WireHeader, + pub header: VarHeader, /// The serialized message payload pub body: Vec, } @@ -467,7 +512,7 @@ pub struct RpcFrame { impl RpcFrame { /// Serialize the `RpcFrame` into a Vec of bytes pub fn to_bytes(&self) -> Vec { - let mut out = postcard::to_stdvec(&self.header).expect("Alloc should never fail"); + let mut out = self.header.write_to_vec(); out.extend_from_slice(&self.body); out } @@ -475,7 +520,8 @@ impl RpcFrame { /// Shared context between [HostClient] and the I/O worker task pub struct HostContext { - map: WaitMap>, + kkind: RwLock, + map: WaitMap)>, seq: AtomicU32, } @@ -495,7 +541,7 @@ impl HostContext { /// Like `HostContext::process` but tells you if we processed the message or /// nobody wanted it pub fn process_did_wake(&self, frame: RpcFrame) -> Result { - match self.map.wake(&frame.header, frame.body) { + match self.map.wake(&frame.header, (frame.header, frame.body)) { WakeOutcome::Woke => Ok(true), WakeOutcome::NoMatch(_) => Ok(false), WakeOutcome::Closed(_) => Err(ProcessError::Closed), @@ -506,7 +552,7 @@ impl HostContext { /// /// Returns an Err if the map was closed. pub fn process(&self, frame: RpcFrame) -> Result<(), ProcessError> { - if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, frame.body) { + if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, (frame.header, frame.body)) { Err(ProcessError::Closed) } else { Ok(()) diff --git a/source/postcard-rpc/src/host_client/raw_nusb.rs b/source/postcard-rpc/src/host_client/raw_nusb.rs index e1f2adb..55752f6 100644 --- a/source/postcard-rpc/src/host_client/raw_nusb.rs +++ b/source/postcard-rpc/src/host_client/raw_nusb.rs @@ -1,3 +1,5 @@ +//! Implementation of transport using nusb + use std::future::Future; use nusb::{ @@ -7,7 +9,10 @@ use nusb::{ use postcard_schema::Schema; use serde::de::DeserializeOwned; -use crate::host_client::{HostClient, WireRx, WireSpawn, WireTx}; +use crate::{ + header::VarSeqKind, + host_client::{HostClient, WireRx, WireSpawn, WireTx}, +}; // TODO: These should all be configurable, PRs welcome @@ -48,6 +53,7 @@ where /// /// ```rust,no_run /// use postcard_rpc::host_client::HostClient; + /// use postcard_rpc::header::VarSeqKind; /// use serde::{Serialize, Deserialize}; /// use postcard_schema::Schema; /// @@ -65,12 +71,15 @@ where /// "error", /// // Outgoing queue depth in messages /// 8, + /// // Use one-byte sequence numbers + /// VarSeqKind::Seq1, /// ).unwrap(); /// ``` pub fn try_new_raw_nusb bool>( func: F, err_uri_path: &str, outgoing_depth: usize, + seq_no_kind: VarSeqKind, ) -> Result { let x = nusb::list_devices() .map_err(|e| format!("Error listing devices: {e:?}"))? 
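// A rough usage sketch for the widened constructor signature shown in this hunk: the
// only call-site change is the trailing `VarSeqKind` argument; requests and publishes
// keep their existing shape. `MyEndpoint`, `MyRequest`, `MyTopic`, `MyMessage`, and the
// "my-device" product string below are hypothetical placeholders rather than items from
// this diff, and the sketch assumes an ICD crate defining them via the usual
// `endpoints!`/`topics!` macros.

use postcard_rpc::header::{VarSeq, VarSeqKind};
use postcard_rpc::host_client::HostClient;
use postcard_rpc::standard_icd::WireError;

async fn nusb_demo() {
    // Open the first matching device, reporting errors on the "error" path and using
    // two-byte sequence numbers on the wire.
    let client = HostClient::<WireError>::try_new_raw_nusb(
        |d| d.product_string() == Some("my-device"),
        "error",
        8,
        VarSeqKind::Seq2,
    )
    .expect("no matching nusb device found");

    // Endpoint request/response: key and sequence number widths are negotiated
    // internally by `send_resp_raw`.
    let _resp = client
        .send_resp::<MyEndpoint>(&MyRequest { value: 1 })
        .await
        .expect("request failed");

    // Fire-and-forget topic publish; the sequence number is resized to `Seq2` before
    // it hits the wire.
    let _ = client
        .publish::<MyTopic>(VarSeq::Seq4(0), &MyMessage { value: 2 })
        .await;
}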
@@ -93,6 +102,7 @@ where consecutive_errs: 0, }, NusbSpawn, + seq_no_kind, err_uri_path, outgoing_depth, )) @@ -108,6 +118,7 @@ where /// /// ```rust,no_run /// use postcard_rpc::host_client::HostClient; + /// use postcard_rpc::header::VarSeqKind; /// use serde::{Serialize, Deserialize}; /// use postcard_schema::Schema; /// @@ -125,14 +136,17 @@ where /// "error", /// // Outgoing queue depth in messages /// 8, + /// // Use one-byte sequence numbers + /// VarSeqKind::Seq1, /// ); /// ``` pub fn new_raw_nusb bool>( func: F, err_uri_path: &str, outgoing_depth: usize, + seq_no_kind: VarSeqKind, ) -> Self { - Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth) + Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind) .expect("should have found nusb device") } } diff --git a/source/postcard-rpc/src/host_client/serial.rs b/source/postcard-rpc/src/host_client/serial.rs index 9f748b9..98e2a34 100644 --- a/source/postcard-rpc/src/host_client/serial.rs +++ b/source/postcard-rpc/src/host_client/serial.rs @@ -8,6 +8,7 @@ use tokio_serial::{SerialPortBuilderExt, SerialStream}; use crate::{ accumulator::raw::{CobsAccumulator, FeedResult}, + header::VarSeqKind, host_client::{HostClient, WireRx, WireSpawn, WireTx}, }; @@ -31,6 +32,7 @@ where /// /// ```rust,no_run /// use postcard_rpc::host_client::HostClient; + /// use postcard_rpc::header::VarSeqKind; /// use serde::{Serialize, Deserialize}; /// use postcard_schema::Schema; /// @@ -51,6 +53,8 @@ where /// // Baud rate of serial (does not generally matter for /// // USB UART/CDC-ACM serial connections) /// 115_200, + /// // Use one-byte sequence numbers + /// VarSeqKind::Seq1, /// ); /// ``` pub fn try_new_serial_cobs( @@ -58,6 +62,7 @@ where err_uri_path: &str, outgoing_depth: usize, baud: u32, + seq_no_kind: VarSeqKind, ) -> Result { let port = tokio_serial::new(serial_path, baud) .open_native_async() @@ -74,6 +79,7 @@ where pending: VecDeque::new(), }, SerialSpawn, + seq_no_kind, err_uri_path, outgoing_depth, )) @@ -89,8 +95,10 @@ where err_uri_path: &str, outgoing_depth: usize, baud: u32, + seq_no_kind: VarSeqKind, ) -> Self { - Self::try_new_serial_cobs(serial_path, err_uri_path, outgoing_depth, baud).unwrap() + Self::try_new_serial_cobs(serial_path, err_uri_path, outgoing_depth, baud, seq_no_kind) + .unwrap() } } diff --git a/source/postcard-rpc/src/host_client/test_channels.rs b/source/postcard-rpc/src/host_client/test_channels.rs new file mode 100644 index 0000000..a6df2d3 --- /dev/null +++ b/source/postcard-rpc/src/host_client/test_channels.rs @@ -0,0 +1,85 @@ +//! 
A Client implementation using channels for testing + +use crate::{ + header::VarSeqKind, + host_client::{HostClient, WireRx, WireSpawn, WireTx}, + standard_icd::WireError, +}; +use core::fmt::Display; +use tokio::sync::mpsc; + +/// Create a new HostClient from the given server channels +pub fn new_from_channels( + tx: mpsc::Sender>, + rx: mpsc::Receiver>, + seq_kind: VarSeqKind, +) -> HostClient { + HostClient::new_with_wire( + ChannelTx { tx }, + ChannelRx { rx }, + TokSpawn, + seq_kind, + crate::standard_icd::ERROR_PATH, + 64, + ) +} + +/// Server error kinds +#[derive(Debug)] +pub enum ChannelError { + /// Rx was closed + RxClosed, + /// Tx was closed + TxClosed, +} + +impl Display for ChannelError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +impl std::error::Error for ChannelError {} + +/// Trait impl for channels +pub struct ChannelRx { + rx: mpsc::Receiver>, +} +/// Trait impl for channels +pub struct ChannelTx { + tx: mpsc::Sender>, +} +/// Trait impl for channels +pub struct ChannelSpawn; + +impl WireRx for ChannelRx { + type Error = ChannelError; + + async fn receive(&mut self) -> Result, Self::Error> { + match self.rx.recv().await { + Some(v) => { + #[cfg(test)] + println!("c<-s: {v:?}"); + Ok(v) + } + None => Err(ChannelError::RxClosed), + } + } +} + +impl WireTx for ChannelTx { + type Error = ChannelError; + + async fn send(&mut self, data: Vec) -> Result<(), Self::Error> { + #[cfg(test)] + println!("c->s: {data:?}"); + self.tx.send(data).await.map_err(|_| ChannelError::TxClosed) + } +} + +struct TokSpawn; +impl WireSpawn for TokSpawn { + fn spawn(&mut self, fut: impl std::future::Future + Send + 'static) { + _ = tokio::task::spawn(fut); + } +} diff --git a/source/postcard-rpc/src/host_client/util.rs b/source/postcard-rpc/src/host_client/util.rs index 95c60fe..8222da1 100644 --- a/source/postcard-rpc/src/host_client/util.rs +++ b/source/postcard-rpc/src/host_client/util.rs @@ -1,5 +1,5 @@ // the contents of this file can probably be moved up to `mod.rs` -use std::{collections::HashMap, fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc}; use maitake_sync::WaitQueue; use postcard_schema::Schema; @@ -14,15 +14,18 @@ use tokio::{ use tracing::{debug, trace, warn}; use crate::{ - headered::extract_header_from_bytes, + header::{VarHeader, VarKey, VarSeqKind}, host_client::{ - HostClient, HostContext, ProcessError, RpcFrame, SubInfo, WireContext, WireRx, WireSpawn, - WireTx, + HostClient, HostContext, ProcessError, RpcFrame, WireContext, WireRx, WireSpawn, WireTx, }, Key, }; -pub(crate) type Subscriptions = HashMap>; +#[derive(Default, Debug)] +pub(crate) struct Subscriptions { + pub(crate) list: Vec<(Key, Sender)>, + pub(crate) stopped: bool, +} /// A basic cancellation-token /// @@ -74,6 +77,7 @@ where tx: WTX, rx: WRX, mut sp: WSP, + seq_kind: VarSeqKind, err_uri_path: &str, outgoing_depth: usize, ) -> Self @@ -82,24 +86,17 @@ where WRX: WireRx, WSP: WireSpawn, { - let (me, wire_ctx) = Self::new_manual_priv(err_uri_path, outgoing_depth); - - let WireContext { - outgoing, - incoming, - new_subs, - } = wire_ctx; + let (me, wire_ctx) = Self::new_manual_priv(err_uri_path, outgoing_depth, seq_kind); - let subscriptions: Arc> = Arc::new(Mutex::new(Subscriptions::new())); + let WireContext { outgoing, incoming } = wire_ctx; sp.spawn(out_worker(tx, outgoing, me.stopper.clone())); sp.spawn(in_worker( rx, incoming, - subscriptions.clone(), + me.subscriptions.clone(), me.stopper.clone(), )); - sp.spawn(sub_worker(new_subs, 
subscriptions, me.stopper.clone())); me } @@ -150,7 +147,7 @@ async fn in_worker( W::Error: Debug, { let cancel_fut = stop.wait_stopped(); - let operate_fut = in_worker_inner(wire, host_ctx, subscriptions); + let operate_fut = in_worker_inner(wire, host_ctx, subscriptions.clone()); select! { _ = cancel_fut => {}, _ = operate_fut => { @@ -158,6 +155,11 @@ async fn in_worker( stop.stop(); }, } + // If we stop, purge the subscription list so that it is clear that no more messages are coming + // TODO: Have a "stopped" flag to prevent later additions (e.g. sub after store?) + let mut guard = subscriptions.lock().await; + guard.stopped = true; + guard.list.clear(); } async fn in_worker_inner( @@ -174,7 +176,7 @@ async fn in_worker_inner( return; }; - let Ok((hdr, body)) = extract_header_from_bytes(&res) else { + let Some((hdr, body)) = VarHeader::take_from_slice(&res) else { warn!("Header decode error!"); continue; }; @@ -188,10 +190,14 @@ async fn in_worker_inner( let key = hdr.key; // Remove if sending fails - let remove_sub = if let Some(m) = subs_guard.get(&key) { + let remove_sub = if let Some((_h, m)) = subs_guard + .list + .iter() + .find(|(k, _)| VarKey::Key8(*k) == key) + { handled = true; let frame = RpcFrame { - header: hdr.clone(), + header: hdr, body: body.to_vec(), }; let res = m.try_send(frame); @@ -213,7 +219,7 @@ async fn in_worker_inner( if remove_sub { debug!("Dropping subscription"); - subs_guard.remove(&key); + subs_guard.list.retain(|(k, _)| VarKey::Key8(*k) != key); } } @@ -236,31 +242,3 @@ async fn in_worker_inner( } } } - -async fn sub_worker( - new_subs: Receiver, - subscriptions: Arc>, - stop: Stopper, -) { - let cancel_fut = stop.wait_stopped(); - let operate_fut = sub_worker_inner(new_subs, subscriptions); - select! { - _ = cancel_fut => {}, - _ = operate_fut => { - // if WE exited, notify everyone else it's stoppin time - stop.stop(); - }, - } -} - -async fn sub_worker_inner( - mut new_subs: Receiver, - subscriptions: Arc>, -) { - while let Some(sub) = new_subs.recv().await { - let mut sub_guard = subscriptions.lock().await; - if let Some(_old) = sub_guard.insert(sub.key, sub.tx) { - warn!("Replacing old subscription for {:?}", sub.key); - } - } -} diff --git a/source/postcard-rpc/src/host_client/webusb.rs b/source/postcard-rpc/src/host_client/webusb.rs index d2f8d48..8609f82 100644 --- a/source/postcard-rpc/src/host_client/webusb.rs +++ b/source/postcard-rpc/src/host_client/webusb.rs @@ -1,3 +1,5 @@ +//! 
Implementation of transport using webusb + use gloo::utils::format::JsValueSerdeExt; use postcard_schema::Schema; use serde::de::DeserializeOwned; @@ -7,8 +9,10 @@ use wasm_bindgen::{prelude::*, JsCast}; use wasm_bindgen_futures::{spawn_local, JsFuture}; use web_sys::{UsbDevice, UsbInTransferResult, UsbTransferStatus}; +use crate::header::VarSeqKind; use crate::host_client::{HostClient, WireRx, WireSpawn, WireTx}; +/// Implementation of the wire interface for WebUsb #[derive(Clone)] pub struct WebUsbWire { device: UsbDevice, @@ -17,10 +21,13 @@ pub struct WebUsbWire { ep_out: u8, } +/// WebUsb Error type #[derive(thiserror::Error, Debug, Clone)] pub enum Error { + /// Error originating from the browser #[error("Browser error: {0}")] Browser(String), + /// Error originating from the USB stack #[error("USB transfer error: {0}")] UsbTransfer(&'static str), } @@ -49,6 +56,7 @@ impl HostClient where WireErr: DeserializeOwned + Schema, { + /// Create a new webusb connection instance pub async fn try_new_webusb( vendor_id: u16, interface: u8, @@ -57,6 +65,7 @@ where ep_out: u8, err_uri_path: &str, outgoing_depth: usize, + seq_no_len: VarSeqKind, ) -> Result { let wire = WebUsbWire::new(vendor_id, interface, transfer_max_length, ep_in, ep_out).await?; @@ -64,6 +73,7 @@ where wire.clone(), wire.clone(), wire, + seq_no_len, err_uri_path, outgoing_depth, )) @@ -71,6 +81,7 @@ where } /// # Example usage () +/// /// ```no_run /// let wire = WebUsbWire::new(0x16c0, 0, 1000, 1, 1) /// .await @@ -82,10 +93,12 @@ where /// wire, /// crate::standard_icd::ERROR_PATH, /// 8, +/// VarSeqKind::Seq1, /// ) /// .expect("could not create HostClient"); /// ``` impl WebUsbWire { + /// Create a new instance of [`WebUsbWire`] pub async fn new( vendor_id: u16, interface: u8, @@ -146,11 +159,7 @@ impl WireTx for WebUsbWire { let data: js_sys::Uint8Array = data.as_slice().into(); // TODO for reasons unknown, web-sys wants mutable access to the send buffer. // tracking issue: https://github.com/rustwasm/wasm-bindgen/issues/3963 - JsFuture::from( - self.device - .transfer_out_with_u8_array(self.ep_out, &data)?, - ) - .await?; + JsFuture::from(self.device.transfer_out_with_u8_array(self.ep_out, &data)?).await?; Ok(()) } } diff --git a/source/postcard-rpc/src/lib.rs b/source/postcard-rpc/src/lib.rs index 5d34d15..34bde2b 100644 --- a/source/postcard-rpc/src/lib.rs +++ b/source/postcard-rpc/src/lib.rs @@ -1,12 +1,67 @@ //! The goal of `postcard-rpc` is to make it easier for a //! host PC to talk to a constrained device, like a microcontroller. //! -//! See [the repo] for examples, and [the overview] for more details on how -//! to use this crate. +//! See [the repo] for examples //! //! [the repo]: https://github.com/jamesmunns/postcard-rpc //! [the overview]: https://github.com/jamesmunns/postcard-rpc/blob/main/docs/overview.md //! +//! ## Architecture overview +//! +//! ```text +//! ┌──────────┐ ┌─────────┐ ┌───────────┐ +//! │ Endpoint │ │ Publish │ │ Subscribe │ +//! └──────────┘ └─────────┘ └───────────┘ +//! │ ▲ message│ │ ▲ +//! ┌────────┐ rqst│ │resp │ subscribe│ │messages +//! ┌─┤ CLIENT ├─────┼─────┼──────────────┼────────────────┼────────┼──┐ +//! │ └────────┘ ▼ │ ▼ ▼ │ │ +//! │ ┌─────────────────────────────────────────────────────┐ │ │ +//! │ │ HostClient │ │ │ +//! │ └─────────────────────────────────────────────────────┘ │ │ +//! │ │ │ ▲ │ | │ +//! │ │ │ │ │ │ │ +//! │ │ │ │ ▼ │ │ +//! │ │ │ ┌──────────────┬──────────────┐│ +//! │ │ └─────▶│ Pending Resp │ Subscription ││ +//! 
│ │ └──────────────┴──────────────┘│ +//! │ │ ▲ ▲ │ +//! │ │ └───────┬──────┘ │ +//! │ ▼ │ │ +//! │ ┌────────────────────┐ ┌────────────────────┐ │ +//! │ ││ Task: out_worker │ │ Task: in_worker ▲│ │ +//! │ ├┼───────────────────┤ ├───────────────────┼┤ │ +//! │ │▼ Trait: WireTx │ │ Trait: WireRx ││ │ +//! └──────┴────────────────────┴────────────┴────────────────────┴────┘ +//! │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ ▲ +//! │ The Server + Client WireRx │ +//! │ │ and WireTx traits can be │ │ +//! │ impl'd for any wire │ +//! │ │ transport like USB, TCP, │ │ +//! │ I2C, UART, etc. │ +//! ▼ └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ │ +//! ┌─────┬────────────────────┬────────────┬────────────────────┬─────┐ +//! │ ││ Trait: WireRx │ │ Trait: WireTx ▲│ │ +//! │ ├┼───────────────────┤ ├───────────────────┼┤ │ +//! │ ││ Server │ ┌───▶│ Sender ││ │ +//! │ ├┼───────────────────┤ │ └────────────────────┘ │ +//! │ │▼ Macro: Dispatch │ │ ▲ │ +//! │ └────────────────────┘ │ │ │ +//! │ ┌─────────┐ │ ┌──────────┐ │ ┌───────────┐ │ ┌───────────┐ │ +//! │ │ Topic │ │ │ Endpoint │ │ │ Publisher │ │ │ Publisher │ │ +//! │ │ fn │◀┼▶│ async fn │────┤ │ Task │─┼─│ Task │ │ +//! │ │ Handler │ │ │ Handler │ │ └───────────┘ │ └───────────┘ │ +//! │ └─────────┘ │ └──────────┘ │ │ │ +//! │ ┌─────────┐ │ ┌──────────┐ │ ┌───────────┐ │ ┌───────────┐ │ +//! │ │ Topic │ │ │ Endpoint │ │ │ Publisher │ │ │ Publisher │ │ +//! │ │async fn │◀┴▶│ task │────┘ │ Task │─┴─│ Task │ │ +//! │ │ Handler │ │ Handler │ └───────────┘ └───────────┘ │ +//! │ └─────────┘ └──────────┘ │ +//! │ ┌────────┐ │ +//! └─┤ SERVER ├───────────────────────────────────────────────────────┘ +//! └────────┘ +//! ``` +//! //! ## Defining a schema //! //! Typically, you will define your "wire types" in a shared schema crate. This @@ -24,47 +79,9 @@ //! convert a type into bytes on the wire //! * [`serde`]'s [`Deserialize`] trait - which defines how we //! can convert bytes on the wire into a type -//! * [`postcard`]'s [`Schema`] trait - which generates a reflection-style +//! * [`postcard_schema`]'s [`Schema`] trait - which generates a reflection-style //! schema value for a given type. //! -//! Here's an example of three types we'll use in future examples: -//! -//! ```rust -//! // Consider making your shared "wire types" crate conditionally no-std, -//! // if you want to use it with no-std embedded targets! This makes it no_std -//! // except for testing and when the "use-std" feature is active. -//! // -//! // You may need to also ensure that `std`/`use-std` features are not active -//! // in any dependencies as well. -//! #![cfg_attr(not(any(test, feature = "use-std")), no_std)] -//! # fn main() {} -//! -//! use serde::{Serialize, Deserialize}; -//! use postcard_schema::Schema; -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub struct Alpha { -//! pub one: u8, -//! pub two: i64, -//! } -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub enum Beta { -//! Bib, -//! Bim(i16), -//! Bap, -//! } -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub struct Delta(pub [u8; 32]); -//! -//! #[derive(Serialize, Deserialize, Schema)] -//! pub enum WireError { -//! ALittleBad, -//! VeryBad, -//! } -//! ``` -//! //! ### Endpoints //! //! Now that we have some basic types that will be used on the wire, we need @@ -80,42 +97,6 @@ //! * The type of the Response //! * A string "path", like an HTTP URI that uniquely identifies the endpoint. //! -//! The easiest way to define an Endpoint is to use the [`endpoint!`][endpoint] -//! macro. -//! -//! ```rust -//! 
# use serde::{Serialize, Deserialize}; -//! # use postcard_schema::Schema; -//! # -//! # #[derive(Serialize, Deserialize, Schema)] -//! # pub struct Alpha { -//! # pub one: u8, -//! # pub two: i64, -//! # } -//! # -//! # #[derive(Serialize, Deserialize, Schema)] -//! # pub enum Beta { -//! # Bib, -//! # Bim(i16), -//! # Bap, -//! # } -//! # -//! use postcard_rpc::endpoint; -//! -//! // Define an endpoint -//! endpoint!( -//! // This is the name of a marker type that represents our Endpoint, -//! // and implements the `Endpoint` trait. -//! FirstEndpoint, -//! // This is the request type for this endpoint -//! Alpha, -//! // This is the response type for this endpoint -//! Beta, -//! // This is the path/URI of the endpoint -//! "endpoints/first", -//! ); -//! ``` -//! //! ### Topics //! //! Sometimes, you would just like to send data in a single direction, with no @@ -129,60 +110,20 @@ //! //! * The type of the Message //! * A string "path", like an HTTP URI that uniquely identifies the topic. -//! -//! The easiest way to define a Topic is to use the [`topic!`][topic] -//! macro. -//! -//! ```rust -//! # use serde::{Serialize, Deserialize}; -//! # use postcard_schema::Schema; -//! # -//! # #[derive(Serialize, Deserialize, Schema)] -//! # pub struct Delta(pub [u8; 32]); -//! # -//! use postcard_rpc::topic; -//! -//! // Define a topic -//! topic!( -//! // This is the name of a marker type that represents our Topic, -//! // and implements `Topic` trait. -//! FirstTopic, -//! // This is the message type for the endpoint (note there is no -//! // response type!) -//! Delta, -//! // This is the path/URI of the topic -//! "topics/first", -//! ); -//! ``` -//! -//! ## Using a schema -//! -//! At the moment, this library is primarily oriented around: -//! -//! * A single Client, usually a PC, with access to `std` -//! * A single Server, usually an MCU, without access to `std` -//! -//! For Client facilities, check out the [`host_client`] module, -//! particularly the [`HostClient`][host_client::HostClient] struct. -//! This is only available with the `use-std` feature active. -//! -//! A serial-port transport using cobs encoding is available with the `cobs-serial` feature. -//! This feature will add the [`new_serial_cobs`][host_client::HostClient::new_serial_cobs] constructor to [`HostClient`][host_client::HostClient]. -//! -//! For Server facilities, check out the [`Dispatch`] struct. This is -//! available with or without the standard library. #![cfg_attr(not(any(test, feature = "use-std")), no_std)] +#![deny(missing_docs)] +#![deny(rustdoc::broken_intra_doc_links)] -use headered::extract_header_from_bytes; -use postcard_schema::Schema; +use header::{VarKey, VarKeyKind}; +use postcard_schema::{schema::NamedType, Schema}; use serde::{Deserialize, Serialize}; #[cfg(feature = "cobs")] pub mod accumulator; pub mod hash; -pub mod headered; +pub mod header; #[cfg(feature = "use-std")] pub mod host_client; @@ -190,121 +131,11 @@ pub mod host_client; #[cfg(any(test, feature = "test-utils"))] pub mod test_utils; -#[cfg(feature = "embassy-usb-0_3-server")] -pub mod target_server; - mod macros; -/// Error type for [Dispatch] -#[derive(Debug, PartialEq)] -pub enum Error { - /// No handler was found for the given message. 
- /// The decoded key and sequence number are returned - NoMatchingHandler { key: Key, seq_no: u32 }, - /// The handler returned an error - DispatchFailure(E), - /// An error when decoding messages - Postcard(postcard::Error), -} +pub mod server; -impl From for Error { - fn from(value: postcard::Error) -> Self { - Self::Postcard(value) - } -} - -/// Dispatch is the primary interface for MCU "server" devices. -/// -/// Dispatch is generic over three types: -/// -/// 1. The `Context`, which will be passed as a mutable reference -/// to each of the handlers. It typically should contain -/// whatever resource is necessary to send replies back to -/// the host. -/// 2. The `Error` type, which can be returned by handlers -/// 3. `N`, for the maximum number of handlers -/// -/// If you plan to use COBS encoding, you can also use [CobsDispatch]. -/// which will automatically handle accumulating bytes from the wire. -/// -/// [CobsDispatch]: crate::accumulator::dispatch::CobsDispatch -/// Note: This will be available when the `cobs` or `cobs-serial` feature is enabled. -pub struct Dispatch { - items: heapless::Vec<(Key, Handler), N>, - context: Context, -} - -impl Dispatch { - /// Create a new [Dispatch] - pub fn new(c: Context) -> Self { - Self { - items: heapless::Vec::new(), - context: c, - } - } - - /// Add a handler to the [Dispatch] for the given path and type - /// - /// Returns an error if the given type+path have already been added, - /// or if Dispatch is full. - pub fn add_handler( - &mut self, - handler: Handler, - ) -> Result<(), &'static str> { - if self.items.is_full() { - return Err("full"); - } - let id = E::REQ_KEY; - if self.items.iter().any(|(k, _)| k == &id) { - return Err("dupe"); - } - let _ = self.items.push((id, handler)); - - // TODO: Why does this throw lifetime errors? - // self.items.sort_unstable_by_key(|(k, _)| k); - Ok(()) - } - - /// Accessor function for the Context field - pub fn context(&mut self) -> &mut Context { - &mut self.context - } - - /// Attempt to dispatch the given message - /// - /// The bytes should consist of exactly one message (including the header). - /// - /// Returns an error in any of the following cases: - /// - /// * We failed to decode a header - /// * No handler was found for the decoded key - /// * The handler ran, but returned an error - pub fn dispatch(&mut self, bytes: &[u8]) -> Result<(), Error> { - let (hdr, remain) = extract_header_from_bytes(bytes)?; - - // TODO: switch to binary search once we sort? - let Some(disp) = self - .items - .iter() - .find_map(|(k, d)| if k == &hdr.key { Some(d) } else { None }) - else { - return Err(Error::::NoMatchingHandler { - key: hdr.key, - seq_no: hdr.seq_no, - }); - }; - (disp)(&hdr, &mut self.context, remain).map_err(Error::DispatchFailure) - } -} - -type Handler = fn(&WireHeader, &mut C, &[u8]) -> Result<(), E>; - -/// The WireHeader is appended to all messages -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct WireHeader { - pub key: Key, - pub seq_no: u32, -} +pub mod uniques; /// The `Key` uniquely identifies what "kind" of message this is. 
/// @@ -381,12 +212,213 @@ mod key_owned { use super::*; use postcard_schema::schema::owned::OwnedNamedType; impl Key { + /// Calculate the Key for the given path and [`OwnedNamedType`] pub fn for_owned_schema_path(path: &str, nt: &OwnedNamedType) -> Key { Key(crate::hash::fnv1a64_owned::hash_ty_path_owned(path, nt)) } } } +/// A compacted 2-byte key +/// +/// This is defined specifically as the following conversion: +/// +/// * Key8 bytes (`[u8; 8]`): `[a, b, c, d, e, f, g, h]` +/// * Key4 bytes (`u8`): `a ^ b ^ c ^ d ^ e ^ f ^ g ^ h` +#[derive(Debug, Copy, Clone)] +pub struct Key1(u8); + +/// A compacted 2-byte key +/// +/// This is defined specifically as the following conversion: +/// +/// * Key8 bytes (`[u8; 8]`): `[a, b, c, d, e, f, g, h]` +/// * Key4 bytes (`[u8; 2]`): `[a ^ b ^ c ^ d, e ^ f ^ g ^ h]` +#[derive(Debug, Copy, Clone)] +pub struct Key2([u8; 2]); + +/// A compacted 4-byte key +/// +/// This is defined specifically as the following conversion: +/// +/// * Key8 bytes (`[u8; 8]`): `[a, b, c, d, e, f, g, h]` +/// * Key4 bytes (`[u8; 4]`): `[a ^ b, c ^ d, e ^ f, g ^ h]` +#[derive(Debug, Copy, Clone)] +pub struct Key4([u8; 4]); + +impl Key1 { + /// Convert from a 2-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key2(value: Key2) -> Self { + let [a, b] = value.0; + Self(a ^ b) + } + + /// Convert from a 4-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key4(value: Key4) -> Self { + let [a, b, c, d] = value.0; + Self(a ^ b ^ c ^ d) + } + + /// Convert from a full size 8-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key8(value: Key) -> Self { + let [a, b, c, d, e, f, g, h] = value.0; + Self(a ^ b ^ c ^ d ^ e ^ f ^ g ^ h) + } + + /// Convert to the inner byte representation + #[inline] + pub const fn to_bytes(&self) -> u8 { + self.0 + } + + /// Create a `Key1` from a [`VarKey`] + /// + /// This method can never fail, but has the same API as other key + /// types for consistency reasons. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + Some(match value { + VarKey::Key1(key1) => *key1, + VarKey::Key2(key2) => Key1::from_key2(*key2), + VarKey::Key4(key4) => Key1::from_key4(*key4), + VarKey::Key8(key) => Key1::from_key8(*key), + }) + } +} + +impl Key2 { + /// Convert from a 4-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key4(value: Key4) -> Self { + let [a, b, c, d] = value.0; + Self([a ^ b, c ^ d]) + } + + /// Convert from a full size 8-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key8(value: Key) -> Self { + let [a, b, c, d, e, f, g, h] = value.0; + Self([a ^ b ^ c ^ d, e ^ f ^ g ^ h]) + } + + /// Convert to the inner byte representation + #[inline] + pub const fn to_bytes(&self) -> [u8; 2] { + self.0 + } + + /// Attempt to create a [`Key2`] from a [`VarKey`]. + /// + /// Only succeeds if `value` is a `VarKey::Key2`, `VarKey::Key4`, or `VarKey::Key8`. 
+ #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + Some(match value { + VarKey::Key1(_) => return None, + VarKey::Key2(key2) => *key2, + VarKey::Key4(key4) => Key2::from_key4(*key4), + VarKey::Key8(key) => Key2::from_key8(*key), + }) + } +} + +impl Key4 { + /// Convert from a full size 8-byte key + /// + /// This is a lossy conversion, and can never fail + #[inline] + pub const fn from_key8(value: Key) -> Self { + let [a, b, c, d, e, f, g, h] = value.0; + Self([a ^ b, c ^ d, e ^ f, g ^ h]) + } + + /// Convert to the inner byte representation + #[inline] + pub const fn to_bytes(&self) -> [u8; 4] { + self.0 + } + + /// Attempt to create a [`Key4`] from a [`VarKey`]. + /// + /// Only succeeds if `value` is a `VarKey::Key4` or `VarKey::Key8`. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + Some(match value { + VarKey::Key1(_) => return None, + VarKey::Key2(_) => return None, + VarKey::Key4(key4) => *key4, + VarKey::Key8(key) => Key4::from_key8(*key), + }) + } +} + +impl Key { + /// This is an identity function, used for consistency + #[inline] + pub const fn from_key8(value: Key) -> Self { + value + } + + /// Attempt to create a [`Key`] from a [`VarKey`]. + /// + /// Only succeeds if `value` is a `VarKey::Key8`. + #[inline] + pub fn try_from_varkey(value: &VarKey) -> Option { + match value { + VarKey::Key8(key) => Some(*key), + _ => None, + } + } +} + +impl From for Key1 { + fn from(value: Key2) -> Self { + Self::from_key2(value) + } +} + +impl From for Key1 { + fn from(value: Key4) -> Self { + Self::from_key4(value) + } +} + +impl From for Key1 { + fn from(value: Key) -> Self { + Self::from_key8(value) + } +} + +impl From for Key2 { + fn from(value: Key4) -> Self { + Self::from_key4(value) + } +} + +impl From for Key2 { + fn from(value: Key) -> Self { + Self::from_key8(value) + } +} + +impl From for Key4 { + fn from(value: Key) -> Self { + Self::from_key8(value) + } +} + /// A marker trait denoting a single endpoint /// /// Typically used with the [endpoint] macro. 
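// A small worked check of the XOR folding defined above: folding an 8-byte `Key` down
// in stages (8 -> 4 -> 2 -> 1) matches folding straight to the target width, which is
// what keeps mixed-width `VarKey` comparisons consistent. This is an illustrative
// sketch as it might appear in an in-crate test; the "example/path" string and the
// `u32` payload type are arbitrary.
#[cfg(test)]
mod key_fold_sketch {
    use super::{Key, Key1, Key2, Key4};

    #[test]
    fn staged_and_direct_folds_agree() {
        let key8: Key = Key::for_path::<u32>("example/path");

        // Fold one step at a time...
        let key4 = Key4::from_key8(key8);
        let key2 = Key2::from_key4(key4);
        let key1 = Key1::from_key2(key2);

        // ...and compare against folding directly from the wider keys.
        assert_eq!(Key2::from_key8(key8).to_bytes(), key2.to_bytes());
        assert_eq!(Key1::from_key8(key8).to_bytes(), key1.to_bytes());
        assert_eq!(Key1::from_key4(key4).to_bytes(), key1.to_bytes());
    }
}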
@@ -427,28 +459,49 @@ pub mod standard_icd { use postcard_schema::Schema; use serde::{Deserialize, Serialize}; + /// The calculated Key for the type [`WireError`] and the path [`ERROR_PATH`] pub const ERROR_KEY: Key = Key::for_path::(ERROR_PATH); + + /// The path string used for the error type pub const ERROR_PATH: &str = "error"; + /// The given frame was too long #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub struct FrameTooLong { + /// The length of the too-long frame pub len: u32, + /// The maximum frame length supported pub max: u32, } + /// The given frame was too short #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub struct FrameTooShort { + /// The length of the too-short frame pub len: u32, } + /// A protocol error that is handled outside of the normal request type, usually + /// indicating a protocol-level error #[derive(Serialize, Deserialize, Schema, Debug, PartialEq)] pub enum WireError { + /// The frame exceeded the buffering capabilities of the server FrameTooLong(FrameTooLong), + /// The frame was shorter than the minimum frame size and was rejected FrameTooShort(FrameTooShort), + /// Deserialization of a message failed DeserFailed, + /// Serialization of a message failed, usually due to a lack of space to + /// buffer the serialized form SerFailed, - UnknownKey([u8; 8]), + /// The key associated with this request was unknown + UnknownKey, + /// The server was unable to spawn the associated handler, typically due + /// to an exhaustion of resources FailedToSpawn, + /// The provided key is below the minimum key size calculated to avoid hash + /// collisions, and was rejected to avoid potential misunderstanding + KeyTooSmall, } #[cfg(not(feature = "use-std"))] @@ -457,3 +510,49 @@ pub mod standard_icd { #[cfg(feature = "use-std")] crate::topic!(Logging, Vec, "logs/formatted"); } + +/// An overview of all topics (in and out) and endpoints +/// +/// Typically generated by the [`define_dispatch!()`] macro. Contains a list +/// of all unique types across endpoints and topics, as well as the endpoints, +/// topics in (client to server), topics out (server to client), as well as a +/// calculated minimum key length required to avoid collisions in either the in +/// or out direction. +pub struct DeviceMap { + /// The set of unique types used by all endpoints and topics in this map + pub types: &'static [&'static NamedType], + /// The list of endpoints by path string, request key, and response key + pub endpoints: &'static [(&'static str, Key, Key)], + /// The list of topics (client to server) by path string and topic key + pub topics_in: &'static [(&'static str, Key)], + /// The list of topics (server to client) by path string and topic key + pub topics_out: &'static [(&'static str, Key)], + /// The minimum key size required to avoid hash collisions + pub min_key_len: VarKeyKind, +} + +/// An overview of a list of endpoints +/// +/// Typically generated by the [`endpoints!()`] macro. Contains a list of +/// all unique types used by a list of endpoints, as well as the list of these +/// endpoints by path, request key, and response key +#[derive(Debug)] +pub struct EndpointMap { + /// The set of unique types used by all endpoints in this map + pub types: &'static [&'static NamedType], + /// The list of endpoints by path string, request key, and response key + pub endpoints: &'static [(&'static str, Key, Key)], +} + +/// An overview of a list of topics +/// +/// Typically generated by the [`topics!()`] macro. 
Contains a list of all +/// unique types used by a list of topics as well as the list of the topics +/// by path and key +#[derive(Debug)] +pub struct TopicMap { + /// The set of unique types used by all topics in this map + pub types: &'static [&'static NamedType], + /// The list of topics by path string and topic key + pub topics: &'static [(&'static str, Key)], +} diff --git a/source/postcard-rpc/src/macros.rs b/source/postcard-rpc/src/macros.rs index 7816364..bd047c7 100644 --- a/source/postcard-rpc/src/macros.rs +++ b/source/postcard-rpc/src/macros.rs @@ -3,6 +3,8 @@ /// Used to define a single Endpoint marker type that implements the /// [Endpoint][crate::Endpoint] trait. /// +/// Prefer the [`endpoints!()`][crate::endpoints]` macro instead. +/// /// ```rust /// # use postcard_schema::Schema; /// # use serde::{Serialize, Deserialize}; @@ -45,11 +47,96 @@ macro_rules! endpoint { }; } +/// ## Endpoints macro +/// +/// Used to define multiple Endpoint marker types that implements the +/// [Endpoint][crate::Endpoint] trait. +/// +/// ```rust +/// # use postcard_schema::Schema; +/// # use serde::{Serialize, Deserialize}; +/// use postcard_rpc::endpoints; +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Req1 { +/// a: u8, +/// b: u64, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Resp1 { +/// c: [u8; 4], +/// d: i32, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Req2 { +/// a: i8, +/// b: i64, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Resp2 { +/// c: [i8; 4], +/// d: u32, +/// } +/// +/// endpoints!{ +/// list = ENDPOINTS_LIST; +/// | EndpointTy | RequestTy | ResponseTy | Path | +/// | ---------- | --------- | ---------- | ---- | +/// | Endpoint1 | Req1 | Resp1 | "endpoints/one" | +/// | Endpoint2 | Req2 | Resp2 | "endpoints/two" | +/// } +/// ``` +#[macro_export] +macro_rules! endpoints { + ( + list = $list_name:ident; + | EndpointTy | RequestTy | ResponseTy | Path | + | $(-)* | $(-)* | $(-)* | $(-)* | + $( | $ep_name:ident | $req_ty:ty | $resp_ty:ty | $path_str:literal | )* + ) => { + // struct definitions and trait impls + $( + pub struct $ep_name; + + impl $crate::Endpoint for $ep_name { + type Request = $req_ty; + type Response = $resp_ty; + const PATH: &'static str = $path_str; + const REQ_KEY: $crate::Key = $crate::Key::for_path::<$req_ty>($path_str); + const RESP_KEY: $crate::Key = $crate::Key::for_path::<$resp_ty>($path_str); + } + )* + + pub const $list_name: $crate::EndpointMap = $crate::EndpointMap { + types: &[ + $( + <$ep_name as $crate::Endpoint>::Request::SCHEMA, + <$ep_name as $crate::Endpoint>::Response::SCHEMA, + )* + ], + endpoints: &[ + $( + ( + <$ep_name as $crate::Endpoint>::PATH, + <$ep_name as $crate::Endpoint>::REQ_KEY, + <$ep_name as $crate::Endpoint>::RESP_KEY, + ), + )* + ], + }; + }; +} + /// ## Topic macro /// /// Used to define a single Topic marker type that implements the /// [Topic][crate::Topic] trait. /// +/// Prefer the [`topics!()` macro](crate::topics) macro. +/// /// ```rust /// # use postcard_schema::Schema; /// # use serde::{Serialize, Deserialize}; @@ -74,6 +161,9 @@ macro_rules! topic { topic!($tyname, $msg, $path) }; ($tyname:ident, $msg:ty, $path:expr) => { + /// $tyname - A Topic definition type + /// + /// Generated by the `topic!()` macro pub struct $tyname; impl $crate::Topic for $tyname { @@ -84,16 +174,132 @@ macro_rules! 
topic { }; } -#[cfg(feature = "embassy-usb-0_3-server")] +/// ## Topics macro +/// +/// Used to define multiple Topic marker types that implements the +/// [Topic][crate::Topic] trait. +/// +/// ```rust +/// # use postcard_schema::Schema; +/// # use serde::{Serialize, Deserialize}; +/// use postcard_rpc::topics; +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Message1 { +/// a: u8, +/// b: u64, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize, Schema)] +/// pub struct Message2 { +/// a: i8, +/// b: i64, +/// } +/// +/// topics!{ +/// list = TOPIC_LIST_NAME; +/// | TopicTy | MessageTy | Path | +/// | ------- | --------- | ---- | +/// | Topic1 | Message1 | "topics/one" | +/// | Topic2 | Message2 | "topics/two" | +/// } +/// ``` #[macro_export] -macro_rules! sender_log { - ($sender:ident, $($arg:tt)*) => { - $sender.fmt_publish::<$crate::standard_icd::Logging>(format_args!($($arg)*)) - }; - ($sender:ident, $s:expr) => { - $sender.str_publish::<$crate::standard_icd::Logging>($s) +macro_rules! topics { + ( + list = $list_name:ident; + | TopicTy | MessageTy | Path | + | $(-)* | $(-)* | $(-)* | + $( | $tp_name:ident | $msg_ty:ty | $path_str:literal | )* + ) => { + // struct definitions and trait impls + $( + /// $tp_name - A Topic definition type + /// + /// Generated by the `topics!()` macro + pub struct $tp_name; + + impl $crate::Topic for $tp_name { + type Message = $msg_ty; + const PATH: &'static str = $path_str; + const TOPIC_KEY: $crate::Key = $crate::Key::for_path::<$msg_ty>($path_str); + } + )* + + pub const $list_name: $crate::TopicMap = $crate::TopicMap { + types: &[ + $( + <$tp_name as $crate::Topic>::Message::SCHEMA, + )* + ], + topics: &[ + $( + ( + <$tp_name as $crate::Topic>::PATH, + <$tp_name as $crate::Topic>::TOPIC_KEY, + ), + )* + ], + }; }; - ($($arg:tt)*) => { - compile_error!("You must pass the sender to `sender_log`!"); +} + +// TODO: bring this back when I sort out how to do formatting in the sender! +// This might require WireTx impls +// +// #[cfg(feature = "embassy-usb-0_3-server")] +// #[macro_export] +// macro_rules! sender_log { +// ($sender:ident, $($arg:tt)*) => { +// $sender.fmt_publish::<$crate::standard_icd::Logging>(format_args!($($arg)*)) +// }; +// ($sender:ident, $s:expr) => { +// $sender.str_publish::<$crate::standard_icd::Logging>($s) +// }; +// ($($arg:tt)*) => { +// compile_error!("You must pass the sender to `sender_log`!"); +// } +// } + +#[cfg(test)] +mod endpoints_test { + use postcard_schema::Schema; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Schema)] + pub struct AReq(pub u8); + #[derive(Serialize, Deserialize, Schema)] + pub struct AResp(pub u16); + #[derive(Serialize, Deserialize, Schema)] + pub struct BTopic(pub u32); + + endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | AlphaEndpoint1 | AReq | AResp | "test/alpha1" | + | AlphaEndpoint2 | AReq | AResp | "test/alpha2" | + | AlphaEndpoint3 | AReq | AResp | "test/alpha3" | + } + + topics! 
{ + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | BetaTopic1 | BTopic | "test/beta1" | + | BetaTopic2 | BTopic | "test/beta2" | + | BetaTopic3 | BTopic | "test/beta3" | + } + + #[test] + fn eps() { + assert_eq!(ENDPOINT_LIST.types.len(), 6); + assert_eq!(ENDPOINT_LIST.endpoints.len(), 3); + } + + #[test] + fn tps() { + assert_eq!(TOPICS_IN_LIST.types.len(), 3); + assert_eq!(TOPICS_IN_LIST.topics.len(), 3); } } diff --git a/source/postcard-rpc/src/server/dispatch_macro.rs b/source/postcard-rpc/src/server/dispatch_macro.rs new file mode 100644 index 0000000..9e68731 --- /dev/null +++ b/source/postcard-rpc/src/server/dispatch_macro.rs @@ -0,0 +1,454 @@ +#[doc(hidden)] +pub mod export { + pub use paste::paste; +} + +/// Define Dispatch Macro +#[macro_export] +macro_rules! define_dispatch { + ////////////////////////////////////////////////////////////////////////////// + // ENDPOINT HANDLER EXPANSION ARMS + ////////////////////////////////////////////////////////////////////////////// + + // This is the "blocking execution" arm for defining an endpoint + (@ep_arm blocking ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let reply = $handler($context, $header.clone(), $req); + if $outputter.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { + let err = $crate::standard_icd::WireError::SerFailed; + $outputter.error($header.seq_no, err).await + } else { + Ok(()) + } + } + }; + // This is the "async execution" arm for defining an endpoint + (@ep_arm async ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let reply = $handler($context, $header.clone(), $req).await; + if $outputter.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { + let err = $crate::standard_icd::WireError::SerFailed; + $outputter.error($header.seq_no, err).await + } else { + Ok(()) + } + } + }; + // This is the "spawn an embassy task" arm for defining an endpoint + (@ep_arm spawn ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let context = $crate::server::SpawnContext::spawn_ctxt($context); + if $spawn_fn($spawner, $handler(context, $header.clone(), $req, $outputter.clone())).is_err() { + let err = $crate::standard_icd::WireError::FailedToSpawn; + $outputter.error($header.seq_no, err).await + } else { + Ok(()) + } + } + }; + + ////////////////////////////////////////////////////////////////////////////// + // TOPIC HANDLER EXPANSION ARMS + ////////////////////////////////////////////////////////////////////////////// + + // This is the "blocking execution" arm for defining a topic + (@tp_arm blocking $handler:ident $context:ident $header:ident $msg:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + $handler($context, $header.clone(), $msg, $outputter); + } + }; + // This is the "async execution" arm for defining a topic + (@tp_arm async $handler:ident $context:ident $header:ident $msg:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + $handler($context, $header.clone(), $msg, $outputter).await; + } + }; + (@tp_arm spawn $handler:ident $context:ident $header:ident $msg:ident $outputter:ident ($spawn_fn:path) $spawner:ident) => { + { + let context = $crate::server::SpawnContext::spawn_ctxt($context); + let _ = $spawn_fn($spawner, $handler(context, $header.clone(), $msg, $outputter.clone())); 
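+            // Note: the spawn result above is deliberately discarded. Unlike the
+            // endpoint `spawn` arm, a topic message has no reply path on which a
+            // `FailedToSpawn` error could be reported back to the client.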
+ } + }; + + ////////////////////////////////////////////////////////////////////////////// + // Implementation of the dispatch trait for the app, where the Key length + // is N, where N is 1, 2, 4, or 8 + ////////////////////////////////////////////////////////////////////////////// + (@matcher + $n:literal $app_name:ident $tx_impl:ty; $spawn_fn:ident $key_ty:ty; $key_kind:expr; + ($($endpoint:ty | $ep_flavor:tt | $ep_handler:ident)*) + ($($topic_in:ty | $tp_flavor:tt | $tp_handler:ident)*) + ) => { + impl $crate::server::Dispatch for $app_name<$n> { + type Tx = $tx_impl; + + fn min_key_len(&self) -> $crate::header::VarKeyKind { + $key_kind + } + + /// Handle dispatching of a single frame + async fn handle( + &mut self, + tx: &$crate::server::Sender, + hdr: &$crate::header::VarHeader, + body: &[u8], + ) -> Result<(), ::Error> { + let key = hdr.key; + let Some(keyb) = <$key_ty>::try_from_varkey(&key) else { + let err = $crate::standard_icd::WireError::KeyTooSmall; + return tx.error(hdr.seq_no, err).await; + }; + let keyb = keyb.to_bytes(); + use consts::*; + match keyb { + $( + $crate::server::dispatch_macro::export::paste! { [<$endpoint:upper _KEY $n>] } => { + // Can we deserialize the request? + let Ok(req) = postcard::from_bytes::<<$endpoint as $crate::Endpoint>::Request>(body) else { + let err = $crate::standard_icd::WireError::DeserFailed; + return tx.error(hdr.seq_no, err).await; + }; + + // Store some items as named bindings, so we can use `ident` in the + // recursive macro expansion. Load bearing order: we borrow `context` + // from `dispatch` because we need `dispatch` AFTER `context`, so NLL + // allows this to still borrowck + let dispatch = self; + let context = &mut dispatch.context; + #[allow(unused)] + let spawninfo = &dispatch.spawn; + + // This will expand to the right "flavor" of handler + define_dispatch!(@ep_arm $ep_flavor ($endpoint) $ep_handler context hdr req tx ($spawn_fn) spawninfo) + } + )* + $( + $crate::server::dispatch_macro::export::paste! { [<$topic_in:upper _KEY $n>] } => { + // Can we deserialize the request? + let Ok(msg) = postcard::from_bytes::<<$topic_in as $crate::Topic>::Message>(body) else { + // This is a topic, not much to be done + return Ok(()); + }; + + // Store some items as named bindings, so we can use `ident` in the + // recursive macro expansion. Load bearing order: we borrow `context` + // from `dispatch` because we need `dispatch` AFTER `context`, so NLL + // allows this to still borrowck + let dispatch = self; + let context = &mut dispatch.context; + #[allow(unused)] + let spawninfo = &dispatch.spawn; + + // (@tp_arm async $handler:ident $context:ident $header:ident $req:ident $outputter:ident) + define_dispatch!(@tp_arm $tp_flavor $tp_handler context hdr msg tx ($spawn_fn) spawninfo); + Ok(()) + } + )* + _other => { + // huh! We have no idea what this key is supposed to be! 
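+                        // Reply with `UnknownKey` so the peer gets an error for this
+                        // sequence number rather than silence.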
+ let err = $crate::standard_icd::WireError::UnknownKey; + tx.error(hdr.seq_no, err).await + }, + } + } + } + }; + + ////////////////////////////////////////////////////////////////////////////// + // MAIN EXPANSION ENTRYPOINT + ////////////////////////////////////////////////////////////////////////////// + ( + app: $app_name:ident; + + spawn_fn: $spawn_fn:ident; + tx_impl: $tx_impl:ty; + spawn_impl: $spawn_impl:ty; + context: $context_ty:ty; + + endpoints: { + list: $endpoint_list:ident; + + | EndpointTy | kind | handler | + | $(-)* | $(-)* | $(-)* | + $( | $endpoint:ty | $ep_flavor:tt | $ep_handler:ident | )* + }; + topics_in: { + list: $topic_in_list:ident; + + | TopicTy | kind | handler | + | $(-)* | $(-)* | $(-)* | + $( | $topic_in:ty | $tp_flavor:tt | $tp_handler:ident | )* + }; + topics_out: { + list: $topic_out_list:ident; + }; + ) => { + + // Here, we calculate how many bytes (1, 2, 4, or 8) are required to uniquely + // match on the given messages we receive and send†. + // + // This serves as a sort of "perfect hash function", allowing us to use fewer + // bytes on the wire. + // + // †: We don't calculate sending keys yet, oops. This probably requires hashing + // TX/RX differently so endpoints with the same TX and RX don't collide, or + // calculating them separately and taking the max + mod sizer { + use super::*; + use $crate::Key; + + // Create a list of JUST the REQUEST keys from the endpoint report + const EP_IN_KEYS_SZ: usize = $endpoint_list.endpoints.len(); + const EP_IN_KEYS: [Key; EP_IN_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; EP_IN_KEYS_SZ]; + let mut i = 0; + while i < EP_IN_KEYS_SZ { + keys[i] = $endpoint_list.endpoints[i].1; + i += 1; + } + keys + }; + // Create a list of JUST the RESPONSE keys from the endpoint report + const EP_OUT_KEYS_SZ: usize = $endpoint_list.endpoints.len(); + const EP_OUT_KEYS: [Key; EP_OUT_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; EP_OUT_KEYS_SZ]; + let mut i = 0; + while i < EP_OUT_KEYS_SZ { + keys[i] = $endpoint_list.endpoints[i].2; + i += 1; + } + keys + }; + // Create a list of JUST the MESSAGE keys from the TOPICS IN report + const TP_IN_KEYS_SZ: usize = $topic_in_list.topics.len(); + const TP_IN_KEYS: [Key; TP_IN_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; TP_IN_KEYS_SZ]; + let mut i = 0; + while i < TP_IN_KEYS_SZ { + keys[i] = $topic_in_list.topics[i].1; + i += 1; + } + keys + }; + // Create a list of JUST the MESSAGE keys from the TOPICS OUT report + const TP_OUT_KEYS_SZ: usize = $topic_out_list.topics.len(); + const TP_OUT_KEYS: [Key; TP_OUT_KEYS_SZ] = const { + let mut keys = [unsafe { Key::from_bytes([0; 8]) }; TP_OUT_KEYS_SZ]; + let mut i = 0; + while i < TP_OUT_KEYS_SZ { + keys[i] = $topic_out_list.topics[i].1; + i += 1; + } + keys + }; + + // This is a list of all REQUEST KEYS in the actual handlers + // + // This should be a SUBSET of the REQUEST KEYS in the Endpoint report + const EP_HANDLER_IN_KEYS: &[Key] = &[ + $(<$endpoint as $crate::Endpoint>::REQ_KEY,)* + ]; + // This is a list of all RESPONSE KEYS in the actual handlers + // + // This should be a SUBSET of the RESPONSE KEYS in the Endpoint report + const EP_HANDLER_OUT_KEYS: &[Key] = &[ + $(<$endpoint as $crate::Endpoint>::RESP_KEY,)* + ]; + // This is a list of all TOPIC KEYS in the actual handlers + // + // This should be a SUBSET of the TOPIC KEYS in the Topic IN report + // (we can't check the out, we have no way of enumerating that yet, + // which would require 
linkme-like crimes I think) + const TP_HANDLER_IN_KEYS: &[Key] = &[ + $(<$topic_in as $crate::Topic>::TOPIC_KEY,)* + ]; + + const fn a_is_subset_of_b(a: &[Key], b: &[Key]) -> bool { + let mut i = 0; + while i < a.len() { + let x = u64::from_le_bytes(a[i].to_bytes()); + let mut matched = false; + let mut j = 0; + while j < b.len() { + let y = u64::from_le_bytes(b[j].to_bytes()); + if x == y { + matched = true; + break; + } + j += 1; + } + if !matched { + return false; + } + i += 1; + } + true + } + + // TODO: Warn/error if the list doesn't match the defined handlers? + pub const NEEDED_SZ_IN: usize = $crate::server::min_key_needed(&[ + &EP_IN_KEYS, + &TP_IN_KEYS, + ]); + pub const NEEDED_SZ_OUT: usize = $crate::server::min_key_needed(&[ + &EP_OUT_KEYS, + &TP_OUT_KEYS, + ]); + pub const NEEDED_SZ: usize = const { + assert!( + a_is_subset_of_b(EP_HANDLER_IN_KEYS, &EP_IN_KEYS), + "All listed endpoint handlers must be listed in endpoints->list! Missing Requst Type found!", + ); + assert!( + a_is_subset_of_b(EP_HANDLER_OUT_KEYS, &EP_OUT_KEYS), + "All listed endpoint handlers must be listed in endpoints->list! Missing Response Type found!", + ); + assert!( + a_is_subset_of_b(TP_HANDLER_IN_KEYS, &TP_IN_KEYS), + "All listed endpoint handlers must be listed in endpoints->list! Missing Response Type found!", + ); + if NEEDED_SZ_IN > NEEDED_SZ_OUT { + NEEDED_SZ_IN + } else { + NEEDED_SZ_OUT + } + }; + } + + + // Here, we calculate at const time the keys we need to match against. This is done with + // paste, which is unfortunate, but allows us to match on this correctly later. + mod consts { + use super::*; + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY1>]: u8 = $crate::Key1::from_key8(<$endpoint as $crate::Endpoint>::REQ_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY1>]: u8 = $crate::Key1::from_key8(<$topic_in as $crate::Topic>::TOPIC_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY2>]: [u8; 2] = $crate::Key2::from_key8(<$endpoint as $crate::Endpoint>::REQ_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY2>]: [u8; 2] = $crate::Key2::from_key8(<$topic_in as $crate::Topic>::TOPIC_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY4>]: [u8; 4] = $crate::Key4::from_key8(<$endpoint as $crate::Endpoint>::REQ_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY4>]: [u8; 4] = $crate::Key4::from_key8(<$topic_in as $crate::Topic>::TOPIC_KEY).to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$endpoint:upper _KEY8>]: [u8; 8] = <$endpoint as $crate::Endpoint>::REQ_KEY.to_bytes(); + } + )* + $( + $crate::server::dispatch_macro::export::paste! { + pub const [<$topic_in:upper _KEY8>]: [u8; 8] = <$topic_in as $crate::Topic>::TOPIC_KEY.to_bytes(); + } + )* + } + + // This is the fun part. + // + // For... reasons, we need to generate a match function to allow for dispatching + // different async handlers without degrading to dyn Future, because no alloc on + // embedded systems. + // + // The easiest way I've found to achieve this is actually to implement this + // handler for ALL of 1, 2, 4, 8, BUT to hide that from the user, and instead + // use THIS alias to give them the one that they need. 
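+    // For example: if two request keys collide in their 1-byte form but not in
+    // their 2-byte form, `sizer::NEEDED_SZ` works out to 2, and this alias
+    // resolves to `impls::$app_name<2>`.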
+ // + // This is overly complicated because I'm mixing const-time capabilities with + // macro-time capabilities. I'm very open to other suggestions that achieve the + // same outcome. + pub type $app_name = impls::$app_name<{ sizer::NEEDED_SZ }>; + + + + mod impls { + use super::*; + + pub struct $app_name { + pub context: $context_ty, + pub spawn: $spawn_impl, + pub device_map: &'static $crate::DeviceMap, + } + + impl $app_name { + /// Create a new instance of the dispatcher + pub fn new( + context: $context_ty, + spawn: $spawn_impl, + ) -> Self { + const MAP: &$crate::DeviceMap = &$crate::DeviceMap { + types: const { + const LISTS: &[&[&'static postcard_schema::schema::NamedType]] = &[ + $endpoint_list.types, + $topic_in_list.types, + $topic_out_list.types, + ]; + const TTL_COUNT: usize = $endpoint_list.types.len() + $topic_in_list.types.len() + $topic_out_list.types.len(); + + const BIG_RPT: ([Option<&'static postcard_schema::schema::NamedType>; TTL_COUNT], usize) = $crate::uniques::merge_nty_lists(LISTS); + const SMALL_RPT: [&'static postcard_schema::schema::NamedType; BIG_RPT.1] = $crate::uniques::cruncher(BIG_RPT.0.as_slice()); + SMALL_RPT.as_slice() + }, + endpoints: &$endpoint_list.endpoints, + topics_in: &$topic_in_list.topics, + topics_out: &$topic_out_list.topics, + min_key_len: const { + match sizer::NEEDED_SZ { + 1 => $crate::header::VarKeyKind::Key1, + 2 => $crate::header::VarKeyKind::Key2, + 4 => $crate::header::VarKeyKind::Key4, + 8 => $crate::header::VarKeyKind::Key8, + _ => unreachable!(), + } + } + }; + $app_name { + context, + spawn, + device_map: MAP, + } + } + } + + define_dispatch! { + @matcher 1 $app_name $tx_impl; $spawn_fn $crate::Key1; $crate::header::VarKeyKind::Key1; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + define_dispatch! { + @matcher 2 $app_name $tx_impl; $spawn_fn $crate::Key2; $crate::header::VarKeyKind::Key2; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + define_dispatch! { + @matcher 4 $app_name $tx_impl; $spawn_fn $crate::Key4; $crate::header::VarKeyKind::Key4; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + define_dispatch! { + @matcher 8 $app_name $tx_impl; $spawn_fn $crate::Key; $crate::header::VarKeyKind::Key8; + ($($endpoint | $ep_flavor | $ep_handler)*) + ($($topic_in | $tp_flavor | $tp_handler)*) + } + } + + } +} diff --git a/source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs b/source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs new file mode 100644 index 0000000..3135aed --- /dev/null +++ b/source/postcard-rpc/src/server/impls/embassy_usb_v0_3.rs @@ -0,0 +1,669 @@ +//! 
Implementation using `embassy-usb` and bulk interfaces + +use embassy_executor::{SpawnError, SpawnToken, Spawner}; +use embassy_sync::{blocking_mutex::raw::RawMutex, mutex::Mutex}; +use embassy_usb_driver::{Driver, Endpoint, EndpointError, EndpointIn, EndpointOut}; +use futures_util::FutureExt; +use serde::Serialize; + +use crate::{ + header::VarHeader, + server::{WireRx, WireRxErrorKind, WireSpawn, WireTx, WireTxErrorKind}, +}; + +/// A collection of types and aliases useful for importing the correct types +pub mod dispatch_impl { + pub use super::embassy_spawn as spawn_fn; + use super::{EUsbWireRx, EUsbWireTx, EUsbWireTxInner, UsbDeviceBuffers}; + + /// Used for defining the USB interface + pub const DEVICE_INTERFACE_GUIDS: &[&str] = &["{AFB9A6FB-30BA-44BC-9232-806CFC875321}"]; + + use embassy_sync::{blocking_mutex::raw::RawMutex, mutex::Mutex}; + use embassy_usb::{ + msos::{self, windows_version}, + Builder, Config, UsbDevice, + }; + use embassy_usb_driver::Driver; + use static_cell::{ConstStaticCell, StaticCell}; + + /// Type alias for `WireTx` impl + pub type WireTxImpl = super::EUsbWireTx; + /// Type alias for `WireRx` impl + pub type WireRxImpl = super::EUsbWireRx; + /// Type alias for `WireSpawn` impl + pub type WireSpawnImpl = super::EUsbWireSpawn; + /// Type alias for the receive buffer + pub type WireRxBuf = &'static mut [u8]; + + /// A helper type for `static` storage of buffers and driver components + pub struct WireStorage< + M: RawMutex + 'static, + D: Driver<'static> + 'static, + const CONFIG: usize = 256, + const BOS: usize = 256, + const CONTROL: usize = 64, + const MSOS: usize = 256, + > { + /// Usb buffer storage + pub bufs_usb: ConstStaticCell>, + /// WireTx/Sender static storage + pub cell: StaticCell>>, + } + + impl< + M: RawMutex + 'static, + D: Driver<'static> + 'static, + const CONFIG: usize, + const BOS: usize, + const CONTROL: usize, + const MSOS: usize, + > WireStorage + { + /// Create a new, uninitialized static set of buffers + pub const fn new() -> Self { + Self { + bufs_usb: ConstStaticCell::new(UsbDeviceBuffers::new()), + cell: StaticCell::new(), + } + } + + /// Initialize the static storage. + /// + /// This must only be called once. + pub fn init( + &'static self, + driver: D, + config: Config<'static>, + tx_buf: &'static mut [u8], + ) -> (UsbDevice<'static, D>, WireTxImpl, WireRxImpl) { + let bufs = self.bufs_usb.take(); + + let mut builder = Builder::new( + driver, + config, + &mut bufs.config_descriptor, + &mut bufs.bos_descriptor, + &mut bufs.msos_descriptor, + &mut bufs.control_buf, + ); + + // Add the Microsoft OS Descriptor (MSOS/MOD) descriptor. + // We tell Windows that this entire device is compatible with the "WINUSB" feature, + // which causes it to use the built-in WinUSB driver automatically, which in turn + // can be used by libusb/rusb software without needing a custom driver or INF file. + // In principle you might want to call msos_feature() just on a specific function, + // if your device also has other functions that still use standard class drivers. + builder.msos_descriptor(windows_version::WIN8_1, 0); + builder.msos_feature(msos::CompatibleIdFeatureDescriptor::new("WINUSB", "")); + builder.msos_feature(msos::RegistryPropertyFeatureDescriptor::new( + "DeviceInterfaceGUIDs", + msos::PropertyData::RegMultiSz(DEVICE_INTERFACE_GUIDS), + )); + + // Add a vendor-specific function (class 0xFF), and corresponding interface, + // that uses our custom handler. 
+ let mut function = builder.function(0xFF, 0, 0); + let mut interface = function.interface(); + let mut alt = interface.alt_setting(0xFF, 0, 0, None); + let ep_out = alt.endpoint_bulk_out(64); + let ep_in = alt.endpoint_bulk_in(64); + drop(function); + + let wtx = self.cell.init(Mutex::new(EUsbWireTxInner { + ep_in, + _log_seq: 0, + tx_buf, + _max_log_len: 0, + })); + + // Build the builder. + let usb = builder.build(); + + (usb, EUsbWireTx { inner: wtx }, EUsbWireRx { ep_out }) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// TX +////////////////////////////////////////////////////////////////////////////// + +/// Implementation detail, holding the endpoint and scratch buffer used for sending +pub struct EUsbWireTxInner> { + ep_in: D::EndpointIn, + _log_seq: u32, + tx_buf: &'static mut [u8], + _max_log_len: usize, +} + +/// A [`WireTx`] implementation for embassy-usb 0.3. +#[derive(Copy)] +pub struct EUsbWireTx + 'static> { + inner: &'static Mutex>, +} + +impl + 'static> Clone for EUsbWireTx { + fn clone(&self) -> Self { + EUsbWireTx { inner: self.inner } + } +} + +impl + 'static> WireTx for EUsbWireTx { + type Error = WireTxErrorKind; + + async fn send( + &self, + hdr: VarHeader, + msg: &T, + ) -> Result<(), Self::Error> { + let mut inner = self.inner.lock().await; + + let EUsbWireTxInner { + ep_in, + _log_seq: _, + tx_buf, + _max_log_len: _, + }: &mut EUsbWireTxInner = &mut inner; + + let (hdr_used, remain) = hdr.write_to_slice(tx_buf).ok_or(WireTxErrorKind::Other)?; + let bdy_used = postcard::to_slice(msg, remain).map_err(|_| WireTxErrorKind::Other)?; + let used_ttl = hdr_used.len() + bdy_used.len(); + + if let Some(used) = tx_buf.get(..used_ttl) { + send_all::(ep_in, used).await + } else { + Err(WireTxErrorKind::Other) + } + } + + async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error> { + let mut inner = self.inner.lock().await; + send_all::(&mut inner.ep_in, buf).await + } +} + +#[inline] +async fn send_all(ep_in: &mut D::EndpointIn, out: &[u8]) -> Result<(), WireTxErrorKind> +where + D: Driver<'static>, +{ + if out.is_empty() { + return Ok(()); + } + // TODO: Timeout? + if ep_in.wait_enabled().now_or_never().is_none() { + return Ok(()); + } + + // write in segments of 64. The last chunk may + // be 0 < len <= 64. + for ch in out.chunks(64) { + if ep_in.write(ch).await.is_err() { + return Err(WireTxErrorKind::ConnectionClosed); + } + } + // If the total we sent was a multiple of 64, send an + // empty message to "flush" the transaction. We already checked + // above that the len != 0. + if (out.len() & (64 - 1)) == 0 && ep_in.write(&[]).await.is_err() { + return Err(WireTxErrorKind::ConnectionClosed); + } + + Ok(()) +} + +////////////////////////////////////////////////////////////////////////////// +// RX +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireRx`] implementation for embassy-usb 0.3. 
+pub struct EUsbWireRx> { + ep_out: D::EndpointOut, +} + +impl> WireRx for EUsbWireRx { + type Error = WireRxErrorKind; + + async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error> { + let buflen = buf.len(); + let mut window = &mut buf[..]; + while !window.is_empty() { + let n = match self.ep_out.read(window).await { + Ok(n) => n, + Err(EndpointError::BufferOverflow) => { + return Err(WireRxErrorKind::ReceivedMessageTooLarge) + } + Err(EndpointError::Disabled) => return Err(WireRxErrorKind::ConnectionClosed), + }; + + let (_now, later) = window.split_at_mut(n); + window = later; + if n != 64 { + // We now have a full frame! Great! + let wlen = window.len(); + let len = buflen - wlen; + let frame = &mut buf[..len]; + + return Ok(frame); + } + } + + // If we got here, we've run out of space. That's disappointing. Accumulate to the + // end of this packet + loop { + match self.ep_out.read(buf).await { + Ok(64) => {} + Ok(_) => return Err(WireRxErrorKind::ReceivedMessageTooLarge), + Err(EndpointError::BufferOverflow) => { + return Err(WireRxErrorKind::ReceivedMessageTooLarge) + } + Err(EndpointError::Disabled) => return Err(WireRxErrorKind::ConnectionClosed), + }; + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWN +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireSpawn`] impl using the embassy executor +#[derive(Clone)] +pub struct EUsbWireSpawn { + /// The embassy-executor spawner + pub spawner: Spawner, +} + +impl From for EUsbWireSpawn { + fn from(value: Spawner) -> Self { + Self { spawner: value } + } +} + +impl WireSpawn for EUsbWireSpawn { + type Error = SpawnError; + + type Info = Spawner; + + fn info(&self) -> &Self::Info { + &self.spawner + } +} + +/// Attempt to spawn the given token +pub fn embassy_spawn(sp: &Sp, tok: SpawnToken) -> Result<(), Sp::Error> +where + Sp: WireSpawn, +{ + let info = sp.info(); + info.spawn(tok) +} + +////////////////////////////////////////////////////////////////////////////// +// OTHER +////////////////////////////////////////////////////////////////////////////// + +/// A generically sized storage type for buffers +pub struct UsbDeviceBuffers< + const CONFIG: usize = 256, + const BOS: usize = 256, + const CONTROL: usize = 64, + const MSOS: usize = 256, +> { + /// Config descriptor storage + pub config_descriptor: [u8; CONFIG], + /// BOS descriptor storage + pub bos_descriptor: [u8; BOS], + /// CONTROL endpoint buffer storage + pub control_buf: [u8; CONTROL], + /// MSOS descriptor buffer storage + pub msos_descriptor: [u8; MSOS], +} + +impl + UsbDeviceBuffers +{ + /// Create a new, empty set of buffers + pub const fn new() -> Self { + Self { + config_descriptor: [0u8; CONFIG], + bos_descriptor: [0u8; BOS], + msos_descriptor: [0u8; MSOS], + control_buf: [0u8; CONTROL], + } + } +} + +/// Static storage for generically sized input and output packet buffers +pub struct PacketBuffers { + /// the transmit buffer + pub tx_buf: [u8; TX], + /// thereceive buffer + pub rx_buf: [u8; RX], +} + +impl PacketBuffers { + /// Create new empty buffers + pub const fn new() -> Self { + Self { + tx_buf: [0u8; TX], + rx_buf: [0u8; RX], + } + } +} + +/// This is a basic example that everything compiles. It is intended to exercise the macro above, +/// as well as provide impls for docs. Don't rely on any of this! 
+#[doc(hidden)] +#[allow(dead_code)] +#[cfg(feature = "test-utils")] +pub mod fake { + use crate::{ + define_dispatch, endpoints, + server::{Sender, SpawnContext}, + topics, + }; + use crate::{header::VarHeader, Schema}; + use embassy_usb_driver::{Bus, ControlPipe, EndpointIn, EndpointOut}; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Schema)] + pub struct AReq(pub u8); + #[derive(Serialize, Deserialize, Schema)] + pub struct AResp(pub u8); + #[derive(Serialize, Deserialize, Schema)] + pub struct BReq(pub u16); + #[derive(Serialize, Deserialize, Schema)] + pub struct BResp(pub u32); + #[derive(Serialize, Deserialize, Schema)] + pub struct GReq; + #[derive(Serialize, Deserialize, Schema)] + pub struct GResp; + #[derive(Serialize, Deserialize, Schema)] + pub struct DReq; + #[derive(Serialize, Deserialize, Schema)] + pub struct DResp; + #[derive(Serialize, Deserialize, Schema)] + pub struct EReq; + #[derive(Serialize, Deserialize, Schema)] + pub struct EResp; + #[derive(Serialize, Deserialize, Schema)] + pub struct ZMsg(pub i16); + + endpoints! { + list = ENDPOINT_LIST; + | EndpointTy | RequestTy | ResponseTy | Path | + | ---------- | --------- | ---------- | ---- | + | AlphaEndpoint | AReq | AResp | "alpha" | + | BetaEndpoint | BReq | BResp | "beta" | + | GammaEndpoint | GReq | GResp | "gamma" | + | DeltaEndpoint | DReq | DResp | "delta" | + | EpsilonEndpoint | EReq | EResp | "epsilon" | + } + + topics! { + list = TOPICS_IN_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic1 | ZMsg | "zeta1" | + | ZetaTopic2 | ZMsg | "zeta2" | + | ZetaTopic3 | ZMsg | "zeta3" | + } + + topics! { + list = TOPICS_OUT_LIST; + | TopicTy | MessageTy | Path | + | ---------- | --------- | ---- | + | ZetaTopic10 | ZMsg | "zeta10" | + } + + pub struct FakeMutex; + pub struct FakeDriver; + pub struct FakeEpOut; + pub struct FakeEpIn; + pub struct FakeCtlPipe; + pub struct FakeBus; + + impl embassy_usb_driver::Endpoint for FakeEpOut { + fn info(&self) -> &embassy_usb_driver::EndpointInfo { + todo!() + } + + async fn wait_enabled(&mut self) { + todo!() + } + } + + impl EndpointOut for FakeEpOut { + async fn read( + &mut self, + _buf: &mut [u8], + ) -> Result { + todo!() + } + } + + impl embassy_usb_driver::Endpoint for FakeEpIn { + fn info(&self) -> &embassy_usb_driver::EndpointInfo { + todo!() + } + + async fn wait_enabled(&mut self) { + todo!() + } + } + + impl EndpointIn for FakeEpIn { + async fn write(&mut self, _buf: &[u8]) -> Result<(), embassy_usb_driver::EndpointError> { + todo!() + } + } + + impl ControlPipe for FakeCtlPipe { + fn max_packet_size(&self) -> usize { + todo!() + } + + async fn setup(&mut self) -> [u8; 8] { + todo!() + } + + async fn data_out( + &mut self, + _buf: &mut [u8], + _first: bool, + _last: bool, + ) -> Result { + todo!() + } + + async fn data_in( + &mut self, + _data: &[u8], + _first: bool, + _last: bool, + ) -> Result<(), embassy_usb_driver::EndpointError> { + todo!() + } + + async fn accept(&mut self) { + todo!() + } + + async fn reject(&mut self) { + todo!() + } + + async fn accept_set_address(&mut self, _addr: u8) { + todo!() + } + } + + impl Bus for FakeBus { + async fn enable(&mut self) { + todo!() + } + + async fn disable(&mut self) { + todo!() + } + + async fn poll(&mut self) -> embassy_usb_driver::Event { + todo!() + } + + fn endpoint_set_enabled( + &mut self, + _ep_addr: embassy_usb_driver::EndpointAddress, + _enabled: bool, + ) { + todo!() + } + + fn endpoint_set_stalled( + &mut self, + _ep_addr: 
embassy_usb_driver::EndpointAddress, + _stalled: bool, + ) { + todo!() + } + + fn endpoint_is_stalled(&mut self, _ep_addr: embassy_usb_driver::EndpointAddress) -> bool { + todo!() + } + + async fn remote_wakeup(&mut self) -> Result<(), embassy_usb_driver::Unsupported> { + todo!() + } + } + + impl embassy_usb_driver::Driver<'static> for FakeDriver { + type EndpointOut = FakeEpOut; + + type EndpointIn = FakeEpIn; + + type ControlPipe = FakeCtlPipe; + + type Bus = FakeBus; + + fn alloc_endpoint_out( + &mut self, + _ep_type: embassy_usb_driver::EndpointType, + _max_packet_size: u16, + _interval_ms: u8, + ) -> Result { + todo!() + } + + fn alloc_endpoint_in( + &mut self, + _ep_type: embassy_usb_driver::EndpointType, + _max_packet_size: u16, + _interval_ms: u8, + ) -> Result { + todo!() + } + + fn start(self, _control_max_packet_size: u16) -> (Self::Bus, Self::ControlPipe) { + todo!() + } + } + + unsafe impl embassy_sync::blocking_mutex::raw::RawMutex for FakeMutex { + const INIT: Self = Self; + + fn lock(&self, _f: impl FnOnce() -> R) -> R { + todo!() + } + } + + pub struct TestContext { + pub a: u32, + pub b: u32, + } + + impl SpawnContext for TestContext { + type SpawnCtxt = TestSpawnContext; + + fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { + TestSpawnContext { b: self.b } + } + } + + pub struct TestSpawnContext { + b: u32, + } + + // TODO: How to do module path concat? + use crate::server::impls::embassy_usb_v0_3::dispatch_impl::{ + spawn_fn, WireSpawnImpl, WireTxImpl, + }; + + define_dispatch! { + app: SingleDispatcher; + spawn_fn: spawn_fn; + tx_impl: WireTxImpl; + spawn_impl: WireSpawnImpl; + context: TestContext; + + endpoints: { + list: ENDPOINT_LIST; + + | EndpointTy | kind | handler | + | ---------- | ---- | ------- | + | AlphaEndpoint | async | test_alpha_handler | + | EpsilonEndpoint | spawn | test_epsilon_handler_task | + }; + topics_in: { + list: TOPICS_IN_LIST; + + | TopicTy | kind | handler | + | ---------- | ---- | ------- | + // | ZetaTopic1 | blocking | test_zeta_blocking | + // | ZetaTopic2 | async | test_zeta_async | + // | ZetaTopic3 | spawn | test_zeta_spawn | + }; + topics_out: { + list: TOPICS_OUT_LIST; + }; + } + + async fn test_alpha_handler( + _context: &mut TestContext, + _header: VarHeader, + _body: AReq, + ) -> AResp { + todo!() + } + + async fn test_beta_handler( + _context: &mut TestContext, + _header: VarHeader, + _body: BReq, + ) -> BResp { + todo!() + } + + async fn test_gamma_handler( + _context: &mut TestContext, + _header: VarHeader, + _body: GReq, + ) -> GResp { + todo!() + } + + fn test_delta_handler(_context: &mut TestContext, _header: VarHeader, _body: DReq) -> DResp { + todo!() + } + + #[embassy_executor::task] + async fn test_epsilon_handler_task( + _context: TestSpawnContext, + _header: VarHeader, + _body: EReq, + _sender: Sender>, + ) { + todo!() + } +} diff --git a/source/postcard-rpc/src/server/impls/mod.rs b/source/postcard-rpc/src/server/impls/mod.rs new file mode 100644 index 0000000..104cc87 --- /dev/null +++ b/source/postcard-rpc/src/server/impls/mod.rs @@ -0,0 +1,9 @@ +//! Implementations of various Server traits +//! +//! The implementations in this module typically require feature flags to be set. 
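+//!
+//! As a rough sketch, the `test-utils` channel implementation can be wired up
+//! like this (`MyDispatch` stands in for a [`Dispatch`][crate::server::Dispatch]
+//! impl, e.g. one generated by `define_dispatch!`; all names here are
+//! illustrative):
+//!
+//! ```rust,ignore
+//! use postcard_rpc::header::VarKeyKind;
+//! use postcard_rpc::server::impls::test_channels::{
+//!     dispatch_impl::{new_server, Settings},
+//!     ChannelWireRx, ChannelWireTx,
+//! };
+//! use tokio::sync::mpsc;
+//!
+//! async fn serve(dispatch: MyDispatch) {
+//!     // One channel for frames into the server, one for frames out of it. The
+//!     // `to_server`/`from_server` halves would be handed to the client under test.
+//!     let (to_server, server_rx) = mpsc::channel::<Vec<u8>>(16);
+//!     let (server_tx, from_server) = mpsc::channel::<Vec<u8>>(16);
+//!
+//!     let mut server = new_server(
+//!         dispatch,
+//!         Settings {
+//!             tx: ChannelWireTx::new(server_tx),
+//!             rx: ChannelWireRx::new(server_rx),
+//!             buf: 1024,
+//!             kkind: VarKeyKind::Key8,
+//!         },
+//!     );
+//!
+//!     // Runs until one side of the channels is dropped.
+//!     let _fatal = server.run().await;
+//! }
+//! ```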
+ +#[cfg(feature = "embassy-usb-0_3-server")] +pub mod embassy_usb_v0_3; + +#[cfg(feature = "test-utils")] +pub mod test_channels; diff --git a/source/postcard-rpc/src/server/impls/test_channels.rs b/source/postcard-rpc/src/server/impls/test_channels.rs new file mode 100644 index 0000000..db4aab1 --- /dev/null +++ b/source/postcard-rpc/src/server/impls/test_channels.rs @@ -0,0 +1,199 @@ +//! Implementation that uses channels for local testing + +use core::{convert::Infallible, future::Future}; + +use crate::server::{ + AsWireRxErrorKind, AsWireTxErrorKind, WireRx, WireRxErrorKind, WireSpawn, WireTx, + WireTxErrorKind, +}; +use tokio::sync::mpsc; + +////////////////////////////////////////////////////////////////////////////// +// DISPATCH IMPL +////////////////////////////////////////////////////////////////////////////// + +/// A collection of types and aliases useful for importing the correct types +pub mod dispatch_impl { + use crate::{ + header::VarKeyKind, + server::{Dispatch, Server}, + }; + + pub use super::tokio_spawn as spawn_fn; + + /// The settings necessary for creating a new channel server + pub struct Settings { + /// The frame sender + pub tx: WireTxImpl, + /// The frame receiver + pub rx: WireRxImpl, + /// The size of the receive buffer + pub buf: usize, + /// The sender key size to use + pub kkind: VarKeyKind, + } + + /// Type alias for `WireTx` impl + pub type WireTxImpl = super::ChannelWireTx; + /// Type alias for `WireRx` impl + pub type WireRxImpl = super::ChannelWireRx; + /// Type alias for `WireSpawn` impl + pub type WireSpawnImpl = super::ChannelWireSpawn; + /// Type alias for the receive buffer + pub type WireRxBuf = Box<[u8]>; + + /// Create a new server using the [`Settings`] and [`Dispatch`] implementation + pub fn new_server( + dispatch: D, + settings: Settings, + ) -> crate::server::Server + where + D: Dispatch, + { + let buf = vec![0; settings.buf]; + Server::new( + &settings.tx, + settings.rx, + buf.into_boxed_slice(), + dispatch, + settings.kkind, + ) + } +} + +////////////////////////////////////////////////////////////////////////////// +// TX +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireTx`] impl using tokio mpsc channels +#[derive(Clone)] +pub struct ChannelWireTx { + tx: mpsc::Sender>, +} + +impl ChannelWireTx { + /// Create a new [`ChannelWireTx`] + pub fn new(tx: mpsc::Sender>) -> Self { + Self { tx } + } +} + +impl WireTx for ChannelWireTx { + type Error = ChannelWireTxError; + + async fn send( + &self, + hdr: crate::header::VarHeader, + msg: &T, + ) -> Result<(), Self::Error> { + let mut hdr_ser = hdr.write_to_vec(); + let bdy_ser = postcard::to_stdvec(msg).unwrap(); + hdr_ser.extend_from_slice(&bdy_ser); + self.tx + .send(hdr_ser) + .await + .map_err(|_| ChannelWireTxError::ChannelClosed)?; + Ok(()) + } + + async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error> { + let buf = buf.to_vec(); + self.tx + .send(buf) + .await + .map_err(|_| ChannelWireTxError::ChannelClosed)?; + Ok(()) + } +} + +/// A wire tx error +#[derive(Debug)] +pub enum ChannelWireTxError { + /// The receiver closed the channel + ChannelClosed, +} + +impl AsWireTxErrorKind for ChannelWireTxError { + fn as_kind(&self) -> WireTxErrorKind { + match self { + ChannelWireTxError::ChannelClosed => WireTxErrorKind::ConnectionClosed, + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// RX +////////////////////////////////////////////////////////////////////////////// + +/// A [`WireRx`] 
impl using tokio mpsc channels +pub struct ChannelWireRx { + rx: mpsc::Receiver>, +} + +impl ChannelWireRx { + /// Create a new [`ChannelWireRx`] + pub fn new(rx: mpsc::Receiver>) -> Self { + Self { rx } + } +} + +impl WireRx for ChannelWireRx { + type Error = ChannelWireRxError; + + async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error> { + // todo: some kind of receive_owned? + let msg = self.rx.recv().await; + let msg = msg.ok_or(ChannelWireRxError::ChannelClosed)?; + let out = buf + .get_mut(..msg.len()) + .ok_or(ChannelWireRxError::MessageTooLarge)?; + out.copy_from_slice(&msg); + Ok(out) + } +} + +/// A wire rx error +#[derive(Debug)] +pub enum ChannelWireRxError { + /// The sender closed the channel + ChannelClosed, + /// The sender sent a too-large message + MessageTooLarge, +} + +impl AsWireRxErrorKind for ChannelWireRxError { + fn as_kind(&self) -> WireRxErrorKind { + match self { + ChannelWireRxError::ChannelClosed => WireRxErrorKind::ConnectionClosed, + ChannelWireRxError::MessageTooLarge => WireRxErrorKind::ReceivedMessageTooLarge, + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWN +////////////////////////////////////////////////////////////////////////////// + +/// A wire spawn implementation +#[derive(Clone)] +pub struct ChannelWireSpawn; + +impl WireSpawn for ChannelWireSpawn { + type Error = Infallible; + + type Info = (); + + fn info(&self) -> &Self::Info { + &() + } +} + +/// Spawn a task using tokio +pub fn tokio_spawn(_sp: &Sp, fut: F) -> Result<(), Sp::Error> +where + Sp: WireSpawn, + F: Future + 'static + Send, +{ + tokio::task::spawn(fut); + Ok(()) +} diff --git a/source/postcard-rpc/src/server/mod.rs b/source/postcard-rpc/src/server/mod.rs new file mode 100644 index 0000000..eb34327 --- /dev/null +++ b/source/postcard-rpc/src/server/mod.rs @@ -0,0 +1,522 @@ +//! Definitions of a postcard-rpc Server +//! +//! The Server role is responsible for accepting endpoint requests, issuing +//! endpoint responses, receiving client topic messages, and sending server +//! topic messages +//! +//! ## Impls +//! +//! It is intended to allow postcard-rpc servers to be implemented for many +//! different transport types, as well as many different operating environments. +//! +//! Examples of impls include: +//! +//! * A no-std impl using embassy and embassy-usb to provide transport over USB +//! * A std impl using Tokio channels to provide transport for testing +//! +//! Impls are expected to implement three traits: +//! +//! * [`WireTx`]: how the server sends frames to the client +//! * [`WireRx`]: how the server receives frames from the client +//! * [`WireSpawn`]: how the server spawns worker tasks for certain handlers + +#![allow(async_fn_in_trait)] + +#[doc(hidden)] +pub mod dispatch_macro; + +pub mod impls; + +use core::ops::DerefMut; + +use postcard_schema::Schema; +use serde::Serialize; + +use crate::{ + header::{VarHeader, VarKey, VarKeyKind, VarSeq}, + Key, +}; + +////////////////////////////////////////////////////////////////////////////// +// TX +////////////////////////////////////////////////////////////////////////////// + +/// This trait defines how the server sends frames to the client +pub trait WireTx: Clone { + /// The error type of this connection. + /// + /// For simple cases, you can use [`WireTxErrorKind`] directly. You can also + /// use your own custom type that implements [`AsWireTxErrorKind`]. 
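+    ///
+    /// A custom error type might map onto the base kinds roughly like this (an
+    /// illustrative sketch, not an API provided by this crate):
+    ///
+    /// ```rust,ignore
+    /// use postcard_rpc::server::{AsWireTxErrorKind, WireTxErrorKind};
+    ///
+    /// #[derive(Debug)]
+    /// enum MyTxError {
+    ///     Disconnected,
+    ///     Serialization,
+    /// }
+    ///
+    /// impl AsWireTxErrorKind for MyTxError {
+    ///     fn as_kind(&self) -> WireTxErrorKind {
+    ///         match self {
+    ///             // A closed connection terminates the server run loop
+    ///             MyTxError::Disconnected => WireTxErrorKind::ConnectionClosed,
+    ///             // Anything else is treated as non-fatal
+    ///             MyTxError::Serialization => WireTxErrorKind::Other,
+    ///         }
+    ///     }
+    /// }
+    /// ```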
+ type Error: AsWireTxErrorKind; + + /// Send a single frame to the client, returning when send is complete. + async fn send(&self, hdr: VarHeader, msg: &T) + -> Result<(), Self::Error>; + + /// Send a single frame to the client, without handling serialization + async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error>; +} + +/// The base [`WireTx`] Error Kind +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum WireTxErrorKind { + /// The connection has been closed, and is unlikely to succeed until + /// the connection is re-established. This will cause the Server run + /// loop to terminate. + ConnectionClosed, + /// Other unspecified errors + Other, +} + +/// A conversion trait to convert a user error into a base Kind type +pub trait AsWireTxErrorKind { + /// Convert the error type into a base type + fn as_kind(&self) -> WireTxErrorKind; +} + +impl AsWireTxErrorKind for WireTxErrorKind { + #[inline] + fn as_kind(&self) -> WireTxErrorKind { + *self + } +} + +////////////////////////////////////////////////////////////////////////////// +// RX +////////////////////////////////////////////////////////////////////////////// + +/// This trait defines how to receive a single frame from a client +pub trait WireRx { + /// The error type of this connection. + /// + /// For simple cases, you can use [`WireRxErrorKind`] directly. You can also + /// use your own custom type that implements [`AsWireRxErrorKind`]. + type Error: AsWireRxErrorKind; + + /// Receive a single frame + /// + /// On success, the portion of `buf` that contains a single frame is returned. + async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error>; +} + +/// The base [`WireRx`] Error Kind +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum WireRxErrorKind { + /// The connection has been closed, and is unlikely to succeed until + /// the connection is re-established. This will cause the Server run + /// loop to terminate. + ConnectionClosed, + /// The received message was too large for the server to handle + ReceivedMessageTooLarge, + /// Other message kinds + Other, +} + +/// A conversion trait to convert a user error into a base Kind type +pub trait AsWireRxErrorKind { + /// Convert the error type into a base type + fn as_kind(&self) -> WireRxErrorKind; +} + +impl AsWireRxErrorKind for WireRxErrorKind { + #[inline] + fn as_kind(&self) -> WireRxErrorKind { + *self + } +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWN +////////////////////////////////////////////////////////////////////////////// + +/// A trait to assist in spawning a handler task +/// +/// This trait is weird, and mostly exists to abstract over how "normal" async +/// executors like tokio spawn tasks, taking a future, and how unusual async +/// executors like embassy spawn tasks, taking a task token that maps to static +/// storage +pub trait WireSpawn: Clone { + /// An error type returned when spawning fails. If this cannot happen, + /// [`Infallible`][core::convert::Infallible] can be used. + type Error; + /// The context used for spawning a task. + /// + /// For example, in tokio this is `()`, and in embassy this is `Spawner`. 
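+    ///
+    /// Spawn glue (such as the provided `embassy_spawn` helper) reads this value
+    /// via [`Self::info()`] when the executor needs state to start a task; impls
+    /// that need nothing extra, like the tokio channel impl, can use `()`.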
+ type Info; + + /// Retrieve [`Self::Info`] + fn info(&self) -> &Self::Info; +} + +////////////////////////////////////////////////////////////////////////////// +// SENDER (wrapper of WireTx) +////////////////////////////////////////////////////////////////////////////// + +/// The [`Sender`] type wraps a [`WireTx`] impl, and provides higher level functionality +/// over it +#[derive(Clone)] +pub struct Sender { + tx: Tx, + kkind: VarKeyKind, +} + +impl Sender { + /// Create a new Sender + /// + /// Takes a [`WireTx`] impl, as well as the [`VarKeyKind`] used when sending messages + /// to the client. + /// + /// `kkind` should usually come from [`Dispatch::min_key_len()`]. + pub fn new(tx: Tx, kkind: VarKeyKind) -> Self { + Self { tx, kkind } + } + + /// Send a reply for the given endpoint + #[inline] + pub async fn reply(&self, seq_no: VarSeq, resp: &E::Response) -> Result<(), Tx::Error> + where + E: crate::Endpoint, + E::Response: Serialize + Schema, + { + let mut key = VarKey::Key8(E::RESP_KEY); + key.shrink_to(self.kkind); + let wh = VarHeader { key, seq_no }; + self.tx.send::(wh, resp).await + } + + /// Send a reply with the given Key + /// + /// This is useful when replying with "unusual" keys, for example Error responses + /// not tied to any specific Endpoint. + #[inline] + pub async fn reply_keyed(&self, seq_no: VarSeq, key: Key, resp: &T) -> Result<(), Tx::Error> + where + T: Serialize + Schema, + { + let mut key = VarKey::Key8(key); + key.shrink_to(self.kkind); + let wh = VarHeader { key, seq_no }; + self.tx.send::(wh, resp).await + } + + /// Publish a Topic message + #[inline] + pub async fn publish(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), Tx::Error> + where + T: crate::Topic, + T::Message: Serialize + Schema, + { + let mut key = VarKey::Key8(T::TOPIC_KEY); + key.shrink_to(self.kkind); + let wh = VarHeader { key, seq_no }; + self.tx.send::(wh, msg).await + } + + /// Send a single error message + pub async fn error( + &self, + seq_no: VarSeq, + error: crate::standard_icd::WireError, + ) -> Result<(), Tx::Error> { + self.reply_keyed(seq_no, crate::standard_icd::ERROR_KEY, &error) + .await + } +} + +////////////////////////////////////////////////////////////////////////////// +// SERVER +////////////////////////////////////////////////////////////////////////////// + +/// The [`Server`] is the main interface for handling communication +pub struct Server +where + Tx: WireTx, + Rx: WireRx, + Buf: DerefMut, + D: Dispatch, +{ + tx: Sender, + rx: Rx, + buf: Buf, + dis: D, +} + +/// A type representing the different errors [`Server::run()`] may return +pub enum ServerError +where + Tx: WireTx, + Rx: WireRx, +{ + /// A fatal error occurred with the [`WireTx::send()`] implementation + TxFatal(Tx::Error), + /// A fatal error occurred with the [`WireRx::receive()`] implementation + RxFatal(Rx::Error), +} + +impl Server +where + Tx: WireTx, + Rx: WireRx, + Buf: DerefMut, + D: Dispatch, +{ + /// Create a new Server + /// + /// Takes: + /// + /// * a [`WireTx`] impl for sending + /// * a [`WireRx`] impl for receiving + /// * a buffer used for receiving frames + /// * The user provided dispatching method, usually generated by [`define_dispatch!()`][crate::define_dispatch] + /// * a [`VarKeyKind`], which controls the key sizes sent by the [`WireTx`] impl + pub fn new(tx: &Tx, rx: Rx, buf: Buf, dis: D, kkind: VarKeyKind) -> Self { + Self { + tx: Sender { + tx: tx.clone(), + kkind, + }, + rx, + buf, + dis, + } + } + + /// Run until a fatal error occurs + /// + /// The server will 
receive frames, and dispatch them. When a fatal error occurs, + /// this method will return with the fatal error. + /// + /// The caller may decide to wait until a connection is re-established, reset any + /// state, or immediately begin re-running. + pub async fn run(&mut self) -> ServerError { + loop { + let Self { + tx, + rx, + buf, + dis: d, + } = self; + let used = match rx.receive(buf).await { + Ok(u) => u, + Err(e) => { + let kind = e.as_kind(); + match kind { + WireRxErrorKind::ConnectionClosed => return ServerError::RxFatal(e), + WireRxErrorKind::ReceivedMessageTooLarge => continue, + WireRxErrorKind::Other => continue, + } + } + }; + let Some((hdr, body)) = VarHeader::take_from_slice(used) else { + // TODO: send a nak on badly formed messages? We don't have + // much to say because we don't have a key or seq no or anything + continue; + }; + let fut = d.handle(tx, &hdr, body); + if let Err(e) = fut.await { + let kind = e.as_kind(); + match kind { + WireTxErrorKind::ConnectionClosed => return ServerError::TxFatal(e), + WireTxErrorKind::Other => {} + } + } + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// DISPATCH TRAIT +////////////////////////////////////////////////////////////////////////////// + +/// The dispatch trait handles an incoming endpoint or topic message +/// +/// The implementations of this trait are typically implemented by the +/// [`define_dispatch!`][crate::define_dispatch] macro. +pub trait Dispatch { + /// The [`WireTx`] impl used by this dispatcher + type Tx: WireTx; + + /// The minimum key length required to avoid hash collisions + fn min_key_len(&self) -> VarKeyKind; + + /// Handle a single incoming frame (endpoint or topic), and dispatch appropriately + async fn handle( + &mut self, + tx: &Sender, + hdr: &VarHeader, + body: &[u8], + ) -> Result<(), ::Error>; +} + +////////////////////////////////////////////////////////////////////////////// +// SPAWNCONTEXT TRAIT +////////////////////////////////////////////////////////////////////////////// + +/// A conversion trait for taking the Context and making a SpawnContext +/// +/// This is necessary if you use the `spawn` variant of `define_dispatch!`. +pub trait SpawnContext { + /// The spawn context type + type SpawnCtxt: 'static; + /// A method to convert the regular context into [`Self::SpawnCtxt`] + fn spawn_ctxt(&mut self) -> Self::SpawnCtxt; +} + +// Hilarious quadruply nested loop. Hope our lists are relatively small! +macro_rules! keycheck { + ( + $lists:ident; + $($num:literal => $func:ident;)* + ) => { + $( + { + let mut i = 0; + let mut good = true; + // For each list... + 'dupe: while i < $lists.len() { + let ilist = $lists[i]; + let mut j = 0; + // And for each key in the list + while j < ilist.len() { + let jkey = ilist[j]; + let akey = $func(jkey); + + // + // We now start checking against items later in the lists... + // + + // For each list (starting with the one we are on) + let mut x = i; + while x < $lists.len() { + // For each item... 
+ // + // Note that for the STARTING list we continue where we started, + // but on subsequent lists start from the beginning + let xlist = $lists[x]; + let mut y = if x == i { + j + 1 + } else { + 0 + }; + + while y < xlist.len() { + let ykey = xlist[y]; + let bkey = $func(ykey); + + if akey == bkey { + good = false; + break 'dupe; + } + y += 1; + } + x += 1; + } + j += 1; + } + i += 1; + } + if good { + return $num; + } + } + )* + }; +} + +/// Calculates at const time the minimum number of bytes (1, 2, 4, or 8) to avoid +/// hash collisions in the lists of keys provided. +/// +/// If there are any duplicates, this function will panic at compile time. Otherwise, +/// this function will return 1, 2, 4, or 8. +/// +/// This function takes a very dumb "brute force" approach, that is of the order +/// `O(4 * N^2 * M^2)`, where `N` is `lists.len()`, and `M` is the length of each +/// sub-list. It is not recommended to call this outside of const context. +pub const fn min_key_needed(lists: &[&[Key]]) -> usize { + const fn one(key: Key) -> u8 { + crate::Key1::from_key8(key).0 + } + const fn two(key: Key) -> u16 { + u16::from_le_bytes(crate::Key2::from_key8(key).0) + } + const fn four(key: Key) -> u32 { + u32::from_le_bytes(crate::Key4::from_key8(key).0) + } + const fn eight(key: Key) -> u64 { + u64::from_le_bytes(key.0) + } + + keycheck! { + lists; + 1 => one; + 2 => two; + 4 => four; + 8 => eight; + }; + + panic!("Collision requiring more than 8 bytes!"); +} + +#[cfg(test)] +mod test { + use crate::{server::min_key_needed, Key}; + + #[test] + fn min_test_1() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]) }, + ]]); + assert_eq!(1, MINA); + + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]) }], + ]); + assert_eq!(1, MINB); + } + + #[test] + fn min_test_2() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01]) }, + ]]); + assert_eq!(2, MINA); + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01]) }], + ]); + assert_eq!(2, MINB); + } + + #[test] + fn min_test_4() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01]) }, + ]]); + assert_eq!(4, MINA); + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01]) }], + ]); + assert_eq!(4, MINB); + } + + #[test] + fn min_test_8() { + const MINA: usize = min_key_needed(&[&[ + unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }, + unsafe { Key::from_bytes([0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01]) }, + ]]); + assert_eq!(8, MINA); + const MINB: usize = min_key_needed(&[ + &[unsafe { Key::from_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) }], + &[unsafe { Key::from_bytes([0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01]) }], + ]); + assert_eq!(8, 
MINB); + } +} diff --git a/source/postcard-rpc/src/target_server/buffers.rs b/source/postcard-rpc/src/target_server/buffers.rs deleted file mode 100644 index fa7cc70..0000000 --- a/source/postcard-rpc/src/target_server/buffers.rs +++ /dev/null @@ -1,36 +0,0 @@ -pub struct AllBuffers { - pub usb_device: UsbDeviceBuffers, - pub endpoint_out: [u8; EO], - pub tx_buf: [u8; TX], - pub rx_buf: [u8; RX], -} - -impl AllBuffers { - pub const fn new() -> Self { - Self { - usb_device: UsbDeviceBuffers::new(), - endpoint_out: [0u8; EO], - tx_buf: [0u8; TX], - rx_buf: [0u8; RX], - } - } -} - -/// Buffers used by the [`UsbDevice`][embassy_usb::UsbDevice] of `embassy-usb` -pub struct UsbDeviceBuffers { - pub config_descriptor: [u8; 256], - pub bos_descriptor: [u8; 256], - pub control_buf: [u8; 64], - pub msos_descriptor: [u8; 256], -} - -impl UsbDeviceBuffers { - pub const fn new() -> Self { - Self { - config_descriptor: [0u8; 256], - bos_descriptor: [0u8; 256], - msos_descriptor: [0u8; 256], - control_buf: [0u8; 64], - } - } -} diff --git a/source/postcard-rpc/src/target_server/dispatch_macro.rs b/source/postcard-rpc/src/target_server/dispatch_macro.rs deleted file mode 100644 index 0358876..0000000 --- a/source/postcard-rpc/src/target_server/dispatch_macro.rs +++ /dev/null @@ -1,472 +0,0 @@ -/// # Define Dispatch Macro -/// -/// ```rust -/// # use postcard_rpc::target_server::dispatch_macro::fake::*; -/// # use postcard_rpc::{endpoint, target_server::{sender::Sender, SpawnContext}, WireHeader, define_dispatch}; -/// # use postcard_schema::Schema; -/// # use embassy_usb_driver::{Bus, ControlPipe, EndpointIn, EndpointOut}; -/// # use serde::{Deserialize, Serialize}; -/// -/// pub struct DispatchCtx; -/// pub struct SpawnCtx; -/// -/// // This trait impl is necessary if you want to use the `spawn` variant, -/// // as spawned tasks must take ownership of any context they need. -/// impl SpawnContext for DispatchCtx { -/// type SpawnCtxt = SpawnCtx; -/// fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { -/// SpawnCtx -/// } -/// } -/// -/// define_dispatch! { -/// dispatcher: Dispatcher< -/// Mutex = FakeMutex, -/// Driver = FakeDriver, -/// Context = DispatchCtx, -/// >; -/// AlphaEndpoint => async alpha_handler, -/// BetaEndpoint => async beta_handler, -/// GammaEndpoint => async gamma_handler, -/// DeltaEndpoint => blocking delta_handler, -/// EpsilonEndpoint => spawn epsilon_handler_task, -/// } -/// -/// async fn alpha_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: AReq) -> AResp { -/// todo!() -/// } -/// -/// async fn beta_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: BReq) -> BResp { -/// todo!() -/// } -/// -/// async fn gamma_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: GReq) -> GResp { -/// todo!() -/// } -/// -/// fn delta_handler(_c: &mut DispatchCtx, _h: WireHeader, _b: DReq) -> DResp { -/// todo!() -/// } -/// -/// #[embassy_executor::task] -/// async fn epsilon_handler_task(_c: SpawnCtx, _h: WireHeader, _b: EReq, _sender: Sender) { -/// todo!() -/// } -/// ``` - -#[macro_export] -macro_rules! 
define_dispatch { - // This is the "blocking execution" arm for defining an endpoint - (@arm blocking ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $dispatch:ident) => { - { - let reply = $handler($context, $header.clone(), $req); - if $dispatch.sender.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { - let err = $crate::standard_icd::WireError::SerFailed; - $dispatch.error($header.seq_no, err).await; - } - } - }; - // This is the "async execution" arm for defining an endpoint - (@arm async ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $dispatch:ident) => { - { - let reply = $handler($context, $header.clone(), $req).await; - if $dispatch.sender.reply::<$endpoint>($header.seq_no, &reply).await.is_err() { - let err = $crate::standard_icd::WireError::SerFailed; - $dispatch.error($header.seq_no, err).await; - } - } - }; - // This is the "spawn an embassy task" arm for defining an endpoint - (@arm spawn ($endpoint:ty) $handler:ident $context:ident $header:ident $req:ident $dispatch:ident) => { - { - let spawner = ::embassy_executor::Spawner::for_current_executor().await; - let context = $crate::target_server::SpawnContext::spawn_ctxt($context); - if spawner.spawn($handler(context, $header.clone(), $req, $dispatch.sender())).is_err() { - let err = $crate::standard_icd::WireError::FailedToSpawn; - $dispatch.error($header.seq_no, err).await; - } - } - }; - // Optional trailing comma lol - ( - dispatcher: $name:ident; - $($endpoint:ty => $flavor:tt $handler:ident,)* - ) => { - define_dispatch! { - dispatcher: $name; - $( - $endpoint => $flavor $handler, - )* - } - }; - // This is the main entrypoint - ( - dispatcher: $name:ident; - $($endpoint:ty => $flavor:tt $handler:ident,)* - ) => { - /// This is a structure that handles dispatching, generated by the - /// `postcard-rpc::define_dispatch!()` macro. - pub struct $name { - pub sender: $crate::target_server::sender::Sender<$mutex, $driver>, - pub context: $context, - } - - impl $name { - /// Create a new instance of the dispatcher - pub fn new( - tx_buf: &'static mut [u8], - ep_in: <$driver as ::embassy_usb::driver::Driver<'static>>::EndpointIn, - context: $context, - ) -> Self { - static SENDER_INNER: ::static_cell::StaticCell< - ::embassy_sync::mutex::Mutex<$mutex, $crate::target_server::sender::SenderInner<$driver>>, - > = ::static_cell::StaticCell::new(); - $name { - sender: $crate::target_server::sender::Sender::init_sender(&SENDER_INNER, tx_buf, ep_in), - context, - } - } - } - - impl $crate::target_server::Dispatch for $name { - type Mutex = $mutex; - type Driver = $driver; - - /// Handle dispatching of a single frame - async fn dispatch( - &mut self, - hdr: $crate::WireHeader, - body: &[u8], - ) { - const _REQ_KEYS_MUST_BE_UNIQUE: () = { - let keys = [$(<$endpoint as $crate::Endpoint>::REQ_KEY),*]; - - let mut i = 0; - - while i < keys.len() { - let mut j = i + 1; - while j < keys.len() { - if keys[i].const_cmp(&keys[j]) { - panic!("Keys are not unique, there is a collision!"); - } - j += 1; - } - - i += 1; - } - }; - - let _ = _REQ_KEYS_MUST_BE_UNIQUE; - - match hdr.key { - $( - <$endpoint as $crate::Endpoint>::REQ_KEY => { - // Can we deserialize the request? - let Ok(req) = postcard::from_bytes::<<$endpoint as $crate::Endpoint>::Request>(body) else { - let err = $crate::standard_icd::WireError::DeserFailed; - self.error(hdr.seq_no, err).await; - return; - }; - - // Store some items as named bindings, so we can use `ident` in the - // recursive macro expansion. 
Load bearing order: we borrow `context` - // from `dispatch` because we need `dispatch` AFTER `context`, so NLL - // allows this to still borrowck - let dispatch = self; - let context = &mut dispatch.context; - - // This will expand to the right "flavor" of handler - define_dispatch!(@arm $flavor ($endpoint) $handler context hdr req dispatch); - } - )* - other => { - // huh! We have no idea what this key is supposed to be! - let err = $crate::standard_icd::WireError::UnknownKey(other.to_bytes()); - self.error(hdr.seq_no, err).await; - return; - }, - } - } - - /// Send a single error message - async fn error( - &self, - seq_no: u32, - error: $crate::standard_icd::WireError, - ) { - // If we get an error while sending an error, welp there's not much we can do - let _ = self.sender.reply_keyed(seq_no, $crate::standard_icd::ERROR_KEY, &error).await; - } - - /// Get a clone of the Sender of this Dispatch impl - fn sender(&self) -> $crate::target_server::sender::Sender { - self.sender.clone() - } - } - - } -} - -/// This is a basic example that everything compiles. It is intended to exercise the macro above, -/// as well as provide impls for docs. Don't rely on any of this! -#[doc(hidden)] -#[allow(dead_code)] -#[cfg(feature = "test-utils")] -pub mod fake { - use crate::target_server::SpawnContext; - #[allow(unused_imports)] - use crate::{endpoint, target_server::sender::Sender, Schema, WireHeader}; - use embassy_usb_driver::{Bus, ControlPipe, EndpointIn, EndpointOut}; - use serde::{Deserialize, Serialize}; - - #[derive(Serialize, Deserialize, Schema)] - pub struct AReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct AResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct BReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct BResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct GReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct GResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct DReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct DResp; - #[derive(Serialize, Deserialize, Schema)] - pub struct EReq; - #[derive(Serialize, Deserialize, Schema)] - pub struct EResp; - - endpoint!(AlphaEndpoint, AReq, AResp, "alpha"); - endpoint!(BetaEndpoint, BReq, BResp, "beta"); - endpoint!(GammaEndpoint, GReq, GResp, "gamma"); - endpoint!(DeltaEndpoint, DReq, DResp, "delta"); - endpoint!(EpsilonEndpoint, EReq, EResp, "epsilon"); - - pub struct FakeMutex; - pub struct FakeDriver; - pub struct FakeEpOut; - pub struct FakeEpIn; - pub struct FakeCtlPipe; - pub struct FakeBus; - - impl embassy_usb_driver::Endpoint for FakeEpOut { - fn info(&self) -> &embassy_usb_driver::EndpointInfo { - todo!() - } - - async fn wait_enabled(&mut self) { - todo!() - } - } - - impl EndpointOut for FakeEpOut { - async fn read( - &mut self, - _buf: &mut [u8], - ) -> Result { - todo!() - } - } - - impl embassy_usb_driver::Endpoint for FakeEpIn { - fn info(&self) -> &embassy_usb_driver::EndpointInfo { - todo!() - } - - async fn wait_enabled(&mut self) { - todo!() - } - } - - impl EndpointIn for FakeEpIn { - async fn write(&mut self, _buf: &[u8]) -> Result<(), embassy_usb_driver::EndpointError> { - todo!() - } - } - - impl ControlPipe for FakeCtlPipe { - fn max_packet_size(&self) -> usize { - todo!() - } - - async fn setup(&mut self) -> [u8; 8] { - todo!() - } - - async fn data_out( - &mut self, - _buf: &mut [u8], - _first: bool, - _last: bool, - ) -> Result { - todo!() - } - - async fn data_in( - &mut self, - _data: &[u8], - _first: bool, - _last: bool, - ) -> Result<(), 
embassy_usb_driver::EndpointError> { - todo!() - } - - async fn accept(&mut self) { - todo!() - } - - async fn reject(&mut self) { - todo!() - } - - async fn accept_set_address(&mut self, _addr: u8) { - todo!() - } - } - - impl Bus for FakeBus { - async fn enable(&mut self) { - todo!() - } - - async fn disable(&mut self) { - todo!() - } - - async fn poll(&mut self) -> embassy_usb_driver::Event { - todo!() - } - - fn endpoint_set_enabled( - &mut self, - _ep_addr: embassy_usb_driver::EndpointAddress, - _enabled: bool, - ) { - todo!() - } - - fn endpoint_set_stalled( - &mut self, - _ep_addr: embassy_usb_driver::EndpointAddress, - _stalled: bool, - ) { - todo!() - } - - fn endpoint_is_stalled(&mut self, _ep_addr: embassy_usb_driver::EndpointAddress) -> bool { - todo!() - } - - async fn remote_wakeup(&mut self) -> Result<(), embassy_usb_driver::Unsupported> { - todo!() - } - } - - impl embassy_usb_driver::Driver<'static> for FakeDriver { - type EndpointOut = FakeEpOut; - - type EndpointIn = FakeEpIn; - - type ControlPipe = FakeCtlPipe; - - type Bus = FakeBus; - - fn alloc_endpoint_out( - &mut self, - _ep_type: embassy_usb_driver::EndpointType, - _max_packet_size: u16, - _interval_ms: u8, - ) -> Result { - todo!() - } - - fn alloc_endpoint_in( - &mut self, - _ep_type: embassy_usb_driver::EndpointType, - _max_packet_size: u16, - _interval_ms: u8, - ) -> Result { - todo!() - } - - fn start(self, _control_max_packet_size: u16) -> (Self::Bus, Self::ControlPipe) { - todo!() - } - } - - unsafe impl embassy_sync::blocking_mutex::raw::RawMutex for FakeMutex { - const INIT: Self = Self; - - fn lock(&self, _f: impl FnOnce() -> R) -> R { - todo!() - } - } - - pub struct TestContext { - pub a: u32, - pub b: u32, - } - - impl SpawnContext for TestContext { - type SpawnCtxt = TestSpawnContext; - - fn spawn_ctxt(&mut self) -> Self::SpawnCtxt { - TestSpawnContext { b: self.b } - } - } - - pub struct TestSpawnContext { - b: u32, - } - - define_dispatch! 
{ - dispatcher: TestDispatcher; - AlphaEndpoint => async test_alpha_handler, - BetaEndpoint => async test_beta_handler, - GammaEndpoint => async test_gamma_handler, - DeltaEndpoint => blocking test_delta_handler, - EpsilonEndpoint => spawn test_epsilon_handler_task, - } - - async fn test_alpha_handler( - _context: &mut TestContext, - _header: WireHeader, - _body: AReq, - ) -> AResp { - todo!() - } - - async fn test_beta_handler( - _context: &mut TestContext, - _header: WireHeader, - _body: BReq, - ) -> BResp { - todo!() - } - - async fn test_gamma_handler( - _context: &mut TestContext, - _header: WireHeader, - _body: GReq, - ) -> GResp { - todo!() - } - - fn test_delta_handler(_context: &mut TestContext, _header: WireHeader, _body: DReq) -> DResp { - todo!() - } - - #[embassy_executor::task] - async fn test_epsilon_handler_task( - _context: TestSpawnContext, - _header: WireHeader, - _body: EReq, - _sender: Sender, - ) { - todo!() - } -} diff --git a/source/postcard-rpc/src/target_server/mod.rs b/source/postcard-rpc/src/target_server/mod.rs deleted file mode 100644 index cfa1707..0000000 --- a/source/postcard-rpc/src/target_server/mod.rs +++ /dev/null @@ -1,195 +0,0 @@ -#![allow(async_fn_in_trait)] - -use crate::{ - headered::extract_header_from_bytes, - standard_icd::{FrameTooLong, FrameTooShort, WireError}, - WireHeader, -}; -use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_usb::{ - driver::Driver, - msos::{self, windows_version}, - Builder, UsbDevice, -}; -use embassy_usb_driver::{Endpoint, EndpointError, EndpointOut}; -use sender::Sender; - -pub mod buffers; -pub mod dispatch_macro; -pub mod sender; - -const DEVICE_INTERFACE_GUIDS: &[&str] = &["{AFB9A6FB-30BA-44BC-9232-806CFC875321}"]; - -/// A trait that defines the postcard-rpc message dispatching behavior -/// -/// This is normally generated automatically by the [`define_dispatch!()`][crate::define_dispatch] -/// macro. -pub trait Dispatch { - type Mutex: RawMutex; - type Driver: Driver<'static>; - - /// Handle a single message, with the header deserialized and the - /// body not yet deserialized. - /// - /// This function must handle replying (either immediately or - /// in the future, for example if spawning a task) - async fn dispatch(&mut self, hdr: WireHeader, body: &[u8]); - - /// Send an error message, of the path and key defined for - /// the connection - async fn error(&self, seq_no: u32, error: WireError); - - /// Obtain an owned sender - fn sender(&self) -> Sender; -} - -/// A conversion trait for taking the Context and making a SpawnContext -/// -/// This is necessary if you use the `spawn` variant of `define_dispatch!`. -pub trait SpawnContext { - type SpawnCtxt: 'static; - fn spawn_ctxt(&mut self) -> Self::SpawnCtxt; -} - -/// A basic example of embassy_usb configuration values -pub fn example_config() -> embassy_usb::Config<'static> { - // Create embassy-usb Config - let mut config = embassy_usb::Config::new(0x16c0, 0x27DD); - config.manufacturer = Some("Embassy"); - config.product = Some("postcard-rpc example"); - config.serial_number = Some("12345678"); - - // Required for windows compatibility. - // https://developer.nordicsemi.com/nRF_Connect_SDK/doc/1.9.1/kconfig/CONFIG_CDC_ACM_IAD.html#help - config.device_class = 0xEF; - config.device_sub_class = 0x02; - config.device_protocol = 0x01; - config.composite_with_iads = true; - - config -} - -/// Configure the USB driver for use with postcard-rpc -/// -/// At the moment this is very geared towards USB FS. 
-pub fn configure_usb>( - driver: D, - bufs: &'static mut buffers::UsbDeviceBuffers, - config: embassy_usb::Config<'static>, -) -> (UsbDevice<'static, D>, D::EndpointIn, D::EndpointOut) { - let mut builder = Builder::new( - driver, - config, - &mut bufs.config_descriptor, - &mut bufs.bos_descriptor, - &mut bufs.msos_descriptor, - &mut bufs.control_buf, - ); - - // Add the Microsoft OS Descriptor (MSOS/MOD) descriptor. - // We tell Windows that this entire device is compatible with the "WINUSB" feature, - // which causes it to use the built-in WinUSB driver automatically, which in turn - // can be used by libusb/rusb software without needing a custom driver or INF file. - // In principle you might want to call msos_feature() just on a specific function, - // if your device also has other functions that still use standard class drivers. - builder.msos_descriptor(windows_version::WIN8_1, 0); - builder.msos_feature(msos::CompatibleIdFeatureDescriptor::new("WINUSB", "")); - builder.msos_feature(msos::RegistryPropertyFeatureDescriptor::new( - "DeviceInterfaceGUIDs", - msos::PropertyData::RegMultiSz(DEVICE_INTERFACE_GUIDS), - )); - - // Add a vendor-specific function (class 0xFF), and corresponding interface, - // that uses our custom handler. - let mut function = builder.function(0xFF, 0, 0); - let mut interface = function.interface(); - let mut alt = interface.alt_setting(0xFF, 0, 0, None); - let ep_out = alt.endpoint_bulk_out(64); - let ep_in = alt.endpoint_bulk_in(64); - drop(function); - - // Build the builder. - let usb = builder.build(); - - (usb, ep_in, ep_out) -} - -/// Handle RPC Dispatching -pub async fn rpc_dispatch( - mut ep_out: D::EndpointOut, - mut dispatch: T, - rx_buf: &'static mut [u8], -) -> ! -where - M: RawMutex + 'static, - D: Driver<'static> + 'static, - T: Dispatch, -{ - 'connect: loop { - // Wait for connection - ep_out.wait_enabled().await; - - // For each packet... - 'packet: loop { - // Accumulate a whole frame - let mut window = &mut rx_buf[..]; - 'buffer: loop { - if window.is_empty() { - #[cfg(feature = "defmt")] - defmt::warn!("Overflow!"); - let mut bonus: usize = 0; - loop { - // Just drain until the end of the overflow frame - match ep_out.read(rx_buf).await { - Ok(n) if n < 64 => { - bonus = bonus.saturating_add(n); - let err = WireError::FrameTooLong(FrameTooLong { - len: u32::try_from(bonus.saturating_add(rx_buf.len())) - .unwrap_or(u32::MAX), - max: u32::try_from(rx_buf.len()).unwrap_or(u32::MAX), - }); - dispatch.error(0, err).await; - continue 'packet; - } - Ok(n) => { - bonus = bonus.saturating_add(n); - } - Err(EndpointError::BufferOverflow) => panic!(), - Err(EndpointError::Disabled) => continue 'connect, - }; - } - } - - let n = match ep_out.read(window).await { - Ok(n) => n, - Err(EndpointError::BufferOverflow) => panic!(), - Err(EndpointError::Disabled) => continue 'connect, - }; - - let (_now, later) = window.split_at_mut(n); - window = later; - if n != 64 { - break 'buffer; - } - } - - // We now have a full frame! Great! 
- let wlen = window.len(); - let len = rx_buf.len() - wlen; - let frame = &rx_buf[..len]; - - #[cfg(feature = "defmt")] - defmt::debug!("got frame: {=usize}", frame.len()); - - // If it's for us, process it - if let Ok((hdr, body)) = extract_header_from_bytes(frame) { - dispatch.dispatch(hdr, body).await; - } else { - let err = WireError::FrameTooShort(FrameTooShort { - len: u32::try_from(frame.len()).unwrap_or(u32::MAX), - }); - dispatch.error(0, err).await; - } - } - } -} diff --git a/source/postcard-rpc/src/target_server/sender.rs b/source/postcard-rpc/src/target_server/sender.rs deleted file mode 100644 index 87bae4c..0000000 --- a/source/postcard-rpc/src/target_server/sender.rs +++ /dev/null @@ -1,319 +0,0 @@ -use embassy_sync::{blocking_mutex::raw::RawMutex, mutex::Mutex}; -use embassy_usb_driver::{Driver, Endpoint, EndpointIn}; -use futures_util::FutureExt; -use postcard_schema::Schema; -use serde::Serialize; -use static_cell::StaticCell; - -use crate::Key; - -/// This is the interface for sending information to the client. -/// -/// This is normally used by postcard-rpc itself, as well as for cases where -/// you have to manually send data, like publishing on a topic or delayed -/// replies (e.g. when spawning a task). -#[derive(Copy)] -pub struct Sender + 'static> { - inner: &'static Mutex>, -} - -impl + 'static> Sender { - /// Initialize the Sender, giving it the pieces it needs - /// - /// Panics if called more than once. - pub fn init_sender( - sc: &'static StaticCell>>, - tx_buf: &'static mut [u8], - ep_in: D::EndpointIn, - ) -> Self { - let max_log_len = actual_varint_max_len(tx_buf.len()); - let x = sc.init(Mutex::new(SenderInner { - ep_in, - tx_buf, - log_seq: 0, - max_log_len, - })); - Sender { inner: x } - } - - /// Send a reply for the given endpoint - #[inline] - pub async fn reply(&self, seq_no: u32, resp: &E::Response) -> Result<(), ()> - where - E: crate::Endpoint, - E::Response: Serialize + Schema, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq: _, - max_log_len: _, - } = &mut *inner; - if let Ok(used) = crate::headered::to_slice_keyed(seq_no, E::RESP_KEY, resp, tx_buf) { - send_all::(ep_in, used, true).await - } else { - Err(()) - } - } - - /// Send a reply with the given Key - /// - /// This is useful when replying with "unusual" keys, for example Error responses - /// not tied to any specific Endpoint. 
- #[inline] - pub async fn reply_keyed(&self, seq_no: u32, key: Key, resp: &T) -> Result<(), ()> - where - T: Serialize + Schema, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq: _, - max_log_len: _, - } = &mut *inner; - if let Ok(used) = crate::headered::to_slice_keyed(seq_no, key, resp, tx_buf) { - send_all::(ep_in, used, true).await - } else { - Err(()) - } - } - - /// Publish a Topic message - #[inline] - pub async fn publish(&self, seq_no: u32, msg: &T::Message) -> Result<(), ()> - where - T: crate::Topic, - T::Message: Serialize + Schema, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq: _, - max_log_len: _, - } = &mut *inner; - - if let Ok(used) = crate::headered::to_slice_keyed(seq_no, T::TOPIC_KEY, msg, tx_buf) { - send_all::(ep_in, used, true).await - } else { - Err(()) - } - } - - pub async fn str_publish<'a, T>(&self, s: &'a str) - where - T: crate::Topic, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq, - max_log_len: _, - } = &mut *inner; - let seq_no = *log_seq; - *log_seq = log_seq.wrapping_add(1); - if let Ok(used) = - crate::headered::to_slice_keyed(seq_no, T::TOPIC_KEY, s.as_bytes(), tx_buf) - { - let _ = send_all::(ep_in, used, false).await; - } - } - - pub async fn fmt_publish<'a, T>(&self, args: core::fmt::Arguments<'a>) - where - T: crate::Topic, - { - let mut inner = self.inner.lock().await; - let SenderInner { - ep_in, - tx_buf, - log_seq, - max_log_len, - } = &mut *inner; - let ttl_len = tx_buf.len(); - - // First, populate the header - let hdr = crate::WireHeader { - key: T::TOPIC_KEY, - seq_no: *log_seq, - }; - *log_seq = log_seq.wrapping_add(1); - let Ok(hdr_used) = postcard::to_slice(&hdr, tx_buf) else { - return; - }; - let hdr_used = hdr_used.len(); - - // Then, reserve space for non-canonical length fields - // We also set all but the last bytes to be "continuation" - // bytes - let (_, remaining) = tx_buf.split_at_mut(hdr_used); - if remaining.len() < *max_log_len { - return; - } - let (len_field, body) = remaining.split_at_mut(*max_log_len); - for b in len_field.iter_mut() { - *b = 0x80; - } - len_field.last_mut().map(|b| *b = 0x00); - - // Then, do the formatting - let body_len = body.len(); - let mut sw = SliceWriter(body); - let res = core::fmt::write(&mut sw, args); - - // Calculate the number of bytes used *for formatting*. - let remain = sw.0.len(); - let used = body_len - remain; - - // If we had an error, that's probably because we ran out - // of room. If we had an error, AND there is at least three - // bytes, then replace those with '.'s like ... - if res.is_err() && (body.len() >= 3) { - let start = body.len() - 3; - body[start..].iter_mut().for_each(|b| *b = b'.'); - } - - // then go back and fill in the len - we write the len - // directly to the reserved bytes, and if we DIDN'T use - // the full space, we mark the end of the real length as - // a continuation field. 
This will result in a non-canonical - // "extended" length in postcard, and will "spill into" the - // bytes we wrote previously above - let mut len_bytes = [0u8; varint_max::()]; - let len_used = varint_usize(used, &mut len_bytes); - if len_used.len() != len_field.len() { - len_used.last_mut().map(|b| *b = *b | 0x80); - } - len_field[..len_used.len()].copy_from_slice(len_used); - - // Calculate the TOTAL amount - let act_used = ttl_len - remain; - - let _ = send_all::(ep_in, &tx_buf[..act_used], false).await; - } -} - -impl + 'static> Clone for Sender { - fn clone(&self) -> Self { - Sender { inner: self.inner } - } -} - -/// Implementation detail, holding the endpoint and scratch buffer used for sending -pub struct SenderInner> { - ep_in: D::EndpointIn, - tx_buf: &'static mut [u8], - log_seq: u32, - max_log_len: usize, -} - -/// Helper function for sending a single frame. -/// -/// If an empty slice is provided, no bytes will be sent. -#[inline] -async fn send_all( - ep_in: &mut D::EndpointIn, - out: &[u8], - wait_for_enabled: bool, -) -> Result<(), ()> -where - D: Driver<'static>, -{ - if out.is_empty() { - return Ok(()); - } - if wait_for_enabled { - ep_in.wait_enabled().await; - } else if ep_in.wait_enabled().now_or_never().is_none() { - return Ok(()); - } - - // write in segments of 64. The last chunk may - // be 0 < len <= 64. - for ch in out.chunks(64) { - if ep_in.write(ch).await.is_err() { - return Err(()); - } - } - // If the total we sent was a multiple of 64, send an - // empty message to "flush" the transaction. We already checked - // above that the len != 0. - if (out.len() & (64 - 1)) == 0 && ep_in.write(&[]).await.is_err() { - return Err(()); - } - - Ok(()) -} - -struct SliceWriter<'a>(&'a mut [u8]); - -impl<'a> core::fmt::Write for SliceWriter<'a> { - fn write_str(&mut self, s: &str) -> Result<(), core::fmt::Error> { - let sli = core::mem::take(&mut self.0); - - // If this write would overflow us, note that, but still take - // as much as we possibly can here - let bad = s.len() > sli.len(); - let to_write = s.len().min(sli.len()); - let (now, later) = sli.split_at_mut(to_write); - now.copy_from_slice(s.as_bytes()); - self.0 = later; - - // Now, report whether we overflowed or not - if bad { - Err(core::fmt::Error) - } else { - Ok(()) - } - } -} - -/// Returns the maximum number of bytes required to encode T. -const fn varint_max() -> usize { - const BITS_PER_BYTE: usize = 8; - const BITS_PER_VARINT_BYTE: usize = 7; - - // How many data bits do we need for this type? - let bits = core::mem::size_of::() * BITS_PER_BYTE; - - // We add (BITS_PER_VARINT_BYTE - 1), to ensure any integer divisions - // with a remainder will always add exactly one full byte, but - // an evenly divided number of bits will be the same - let roundup_bits = bits + (BITS_PER_VARINT_BYTE - 1); - - // Apply division, using normal "round down" integer division - roundup_bits / BITS_PER_VARINT_BYTE -} - -#[inline] -fn varint_usize(n: usize, out: &mut [u8; varint_max::()]) -> &mut [u8] { - let mut value = n; - for i in 0..varint_max::() { - out[i] = value.to_le_bytes()[0]; - if value < 128 { - return &mut out[..=i]; - } - - out[i] |= 0x80; - value >>= 7; - } - debug_assert_eq!(value, 0); - &mut out[..] 
-} - -fn actual_varint_max_len(largest: usize) -> usize { - if largest < (2 << 7) { - 1 - } else if largest < (2 << 14) { - 2 - } else if largest < (2 << 21) { - 3 - } else if largest < (2 << 28) { - 4 - } else { - varint_max::() - } -} diff --git a/source/postcard-rpc/src/test_utils.rs b/source/postcard-rpc/src/test_utils.rs index f5b3901..d51c7db 100644 --- a/source/postcard-rpc/src/test_utils.rs +++ b/source/postcard-rpc/src/test_utils.rs @@ -2,11 +2,11 @@ use core::{fmt::Display, future::Future}; +use crate::header::{VarHeader, VarKey, VarSeq, VarSeqKind}; use crate::host_client::util::Stopper; use crate::{ - headered::extract_header_from_bytes, host_client::{HostClient, RpcFrame, WireRx, WireSpawn, WireTx}, - Endpoint, Topic, WireHeader, + Endpoint, Topic, }; use postcard_schema::Schema; use serde::{de::DeserializeOwned, Serialize}; @@ -15,25 +15,32 @@ use tokio::{ sync::mpsc::{channel, Receiver, Sender}, }; +/// Rx Helper type pub struct LocalRx { fake_error: Stopper, from_server: Receiver>, } +/// Tx Helper type pub struct LocalTx { fake_error: Stopper, to_server: Sender>, } +/// Spawn helper type pub struct LocalSpawn; +/// Server type pub struct LocalFakeServer { fake_error: Stopper, + /// from client to server pub from_client: Receiver>, + /// from server to client pub to_client: Sender>, } impl LocalFakeServer { + /// receive a frame pub async fn recv_from_client(&mut self) -> Result { let msg = self.from_client.recv().await.ok_or(LocalError::TxClosed)?; - let Ok((hdr, body)) = extract_header_from_bytes(&msg) else { + let Some((hdr, body)) = VarHeader::take_from_slice(&msg) else { return Err(LocalError::BadFrame); }; Ok(RpcFrame { @@ -42,6 +49,7 @@ impl LocalFakeServer { }) } + /// Reply pub async fn reply( &mut self, seq_no: u32, @@ -51,9 +59,9 @@ impl LocalFakeServer { E::Response: Serialize, { let frame = RpcFrame { - header: WireHeader { - key: E::RESP_KEY, - seq_no, + header: VarHeader { + key: VarKey::Key8(E::RESP_KEY), + seq_no: VarSeq::Seq4(seq_no), }, body: postcard::to_stdvec(data).unwrap(), }; @@ -63,6 +71,7 @@ impl LocalFakeServer { .map_err(|_| LocalError::RxClosed) } + /// Publish pub async fn publish( &mut self, seq_no: u32, @@ -72,9 +81,9 @@ impl LocalFakeServer { T::Message: Serialize, { let frame = RpcFrame { - header: WireHeader { - key: T::TOPIC_KEY, - seq_no, + header: VarHeader { + key: VarKey::Key8(T::TOPIC_KEY), + seq_no: VarSeq::Seq4(seq_no), }, body: postcard::to_stdvec(data).unwrap(), }; @@ -84,16 +93,22 @@ impl LocalFakeServer { .map_err(|_| LocalError::RxClosed) } + /// oops pub fn cause_fatal_error(&self) { self.fake_error.stop(); } } +/// Local error type #[derive(Debug, PartialEq)] pub enum LocalError { + /// RxClosed RxClosed, + /// TxClosed TxClosed, + /// BadFrame BadFrame, + /// FatalError FatalError, } @@ -187,6 +202,7 @@ where fake_error: fake_error.clone(), }, LocalSpawn, + VarSeqKind::Seq2, err_uri_path, bound, ); diff --git a/source/postcard-rpc/src/uniques.rs b/source/postcard-rpc/src/uniques.rs new file mode 100644 index 0000000..aa7f90f --- /dev/null +++ b/source/postcard-rpc/src/uniques.rs @@ -0,0 +1,944 @@ +//! Create unique type lists at compile time +//! +//! This is an excercise in the capabilities of macros and const fns. +//! +//! From a very high level, the process goes like this: +//! +//! 1. We recursively look at a type, counting how many types it contains, +//! WITHOUT considering de-duplication. This is used as an "upper bound" +//! of the number of potential types we could have to report +//! 2. 
Create an array of `[Option<&NamedType>; MAX]` that we use something +//! like an append-only vec. +//! 3. Recursively traverse the type AGAIN, this time collecting all unique +//! non-primitive types we encounter, and adding them to the list. This +//! is outrageously inefficient, but it is done at const time with all +//! the restrictions it entails, because we don't pay at runtime. +//! 4. Record how many types we ACTUALLY collected in step 3, and create a +//! new array, `[&NamedType; ACTUAL]`, and copy the unique types into +//! this new array +//! 5. Convert this `[&NamedType; N]` array into a `&'static [&NamedType]` +//! array to make it possible to handle with multiple types +//! 6. If we are collecting MULTIPLE types into a single aggregate report, +//! then we make a new array of `[Option<&NamedType>; sum(all types)]`, +//! by calculating the sum of types contained for each list calculated +//! in step 4. +//! 7. We then perform the same "merging" process from 3, pushing any unique +//! type we find into the aggregate list, and recording the number of +//! unique types we found in the entire set. +//! 8. We then perform the same "shrinking" process from step 4, leaving us +//! with a single array, `[&NamedType; TOTAL]` containing all unique types +//! 9. We then perform the same "slicing" process from step 5, to get our +//! final `&'static [&NamedType]`. + +use postcard_schema::{ + schema::{DataModelType, DataModelVariant, NamedType, NamedValue, NamedVariant}, + Schema, +}; + +////////////////////////////////////////////////////////////////////////////// +// STAGE 0 - HELPERS +////////////////////////////////////////////////////////////////////////////// + +/// `is_prim` returns whether the type is a *primitive*, or a built-in type that +/// does not need to be sent over the wire. +const fn is_prim(dmt: &DataModelType) -> bool { + match dmt { + // These are all primitives + DataModelType::Bool => true, + DataModelType::I8 => true, + DataModelType::U8 => true, + DataModelType::I16 => true, + DataModelType::I32 => true, + DataModelType::I64 => true, + DataModelType::I128 => true, + DataModelType::U16 => true, + DataModelType::U32 => true, + DataModelType::U64 => true, + DataModelType::U128 => true, + DataModelType::Usize => true, + DataModelType::Isize => true, + DataModelType::F32 => true, + DataModelType::F64 => true, + DataModelType::Char => true, + DataModelType::String => true, + DataModelType::ByteArray => true, + DataModelType::Unit => true, + DataModelType::Schema => true, + + // A unit-struct is always named, so it is not primitive, as the + // name has meaning even without a value + DataModelType::UnitStruct => false, + // Items with subtypes are composite, and therefore not primitives, as + // we need to convey this information. + DataModelType::Option(_) | DataModelType::NewtypeStruct(_) | DataModelType::Seq(_) => false, + DataModelType::Tuple(_) | DataModelType::TupleStruct(_) => false, + DataModelType::Map { .. 
} => false,
+        DataModelType::Struct(_) => false,
+        DataModelType::Enum(_) => false,
+    }
+}
+
+/// A const version of `<str as PartialEq>::eq`
+const fn str_eq(a: &str, b: &str) -> bool {
+    let mut i = 0;
+    if a.len() != b.len() {
+        return false;
+    }
+    let a_by = a.as_bytes();
+    let b_by = b.as_bytes();
+    while i < a.len() {
+        if a_by[i] != b_by[i] {
+            return false;
+        }
+        i += 1;
+    }
+    true
+}
+
+/// A const version of `<NamedType as PartialEq>::eq`
+const fn nty_eq(a: &NamedType, b: &NamedType) -> bool {
+    str_eq(a.name, b.name) && dmt_eq(a.ty, b.ty)
+}
+
+/// A const version of `<[&NamedType] as PartialEq>::eq`
+const fn ntys_eq(a: &[&NamedType], b: &[&NamedType]) -> bool {
+    if a.len() != b.len() {
+        return false;
+    }
+    let mut i = 0;
+    while i < a.len() {
+        if !nty_eq(a[i], b[i]) {
+            return false;
+        }
+        i += 1;
+    }
+    true
+}
+
+/// A const version of `<DataModelType as PartialEq>::eq`
+const fn dmt_eq(a: &DataModelType, b: &DataModelType) -> bool {
+    match (a, b) {
+        // Data model types are ONLY matching if they are both the same variant
+        //
+        // For primitives (and unit structs), we only check the discriminant matches.
+        (DataModelType::Bool, DataModelType::Bool) => true,
+        (DataModelType::I8, DataModelType::I8) => true,
+        (DataModelType::U8, DataModelType::U8) => true,
+        (DataModelType::I16, DataModelType::I16) => true,
+        (DataModelType::I32, DataModelType::I32) => true,
+        (DataModelType::I64, DataModelType::I64) => true,
+        (DataModelType::I128, DataModelType::I128) => true,
+        (DataModelType::U16, DataModelType::U16) => true,
+        (DataModelType::U32, DataModelType::U32) => true,
+        (DataModelType::U64, DataModelType::U64) => true,
+        (DataModelType::U128, DataModelType::U128) => true,
+        (DataModelType::Usize, DataModelType::Usize) => true,
+        (DataModelType::Isize, DataModelType::Isize) => true,
+        (DataModelType::F32, DataModelType::F32) => true,
+        (DataModelType::F64, DataModelType::F64) => true,
+        (DataModelType::Char, DataModelType::Char) => true,
+        (DataModelType::String, DataModelType::String) => true,
+        (DataModelType::ByteArray, DataModelType::ByteArray) => true,
+        (DataModelType::Unit, DataModelType::Unit) => true,
+        (DataModelType::UnitStruct, DataModelType::UnitStruct) => true,
+        (DataModelType::Schema, DataModelType::Schema) => true,
+
+        // For non-primitive types, we check whether all children are equivalent as well.
+ (DataModelType::Option(nta), DataModelType::Option(ntb)) => nty_eq(nta, ntb), + (DataModelType::NewtypeStruct(nta), DataModelType::NewtypeStruct(ntb)) => nty_eq(nta, ntb), + (DataModelType::Seq(nta), DataModelType::Seq(ntb)) => nty_eq(nta, ntb), + + (DataModelType::Tuple(ntsa), DataModelType::Tuple(ntsb)) => ntys_eq(ntsa, ntsb), + (DataModelType::TupleStruct(ntsa), DataModelType::TupleStruct(ntsb)) => ntys_eq(ntsa, ntsb), + ( + DataModelType::Map { + key: keya, + val: vala, + }, + DataModelType::Map { + key: keyb, + val: valb, + }, + ) => nty_eq(keya, keyb) && nty_eq(vala, valb), + (DataModelType::Struct(nvalsa), DataModelType::Struct(nvalsb)) => vals_eq(nvalsa, nvalsb), + (DataModelType::Enum(nvarsa), DataModelType::Enum(nvarsb)) => vars_eq(nvarsa, nvarsb), + + // Any mismatches are not equal + _ => false, + } +} + +/// A const version of `::eq` +const fn var_eq(a: &NamedVariant, b: &NamedVariant) -> bool { + str_eq(a.name, b.name) && dmv_eq(a.ty, b.ty) +} + +/// A const version of `<&[&NamedVariant] as PartialEq>::eq` +const fn vars_eq(a: &[&NamedVariant], b: &[&NamedVariant]) -> bool { + if a.len() != b.len() { + return false; + } + let mut i = 0; + while i < a.len() { + if !var_eq(a[i], b[i]) { + return false; + } + i += 1; + } + true +} + +/// A const version of `<&[&NamedValue] as PartialEq>::eq` +const fn vals_eq(a: &[&NamedValue], b: &[&NamedValue]) -> bool { + if a.len() != b.len() { + return false; + } + let mut i = 0; + while i < a.len() { + if !str_eq(a[i].name, b[i].name) { + return false; + } + if !nty_eq(a[i].ty, b[i].ty) { + return false; + } + + i += 1; + } + true +} + +/// A const version of `::eq` +const fn dmv_eq(a: &DataModelVariant, b: &DataModelVariant) -> bool { + match (a, b) { + (DataModelVariant::UnitVariant, DataModelVariant::UnitVariant) => true, + (DataModelVariant::NewtypeVariant(nta), DataModelVariant::NewtypeVariant(ntb)) => { + nty_eq(nta, ntb) + } + (DataModelVariant::TupleVariant(ntsa), DataModelVariant::TupleVariant(ntsb)) => { + ntys_eq(ntsa, ntsb) + } + (DataModelVariant::StructVariant(nvarsa), DataModelVariant::StructVariant(nvarsb)) => { + vals_eq(nvarsa, nvarsb) + } + _ => false, + } +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 1 - UPPER BOUND CALCULATION +////////////////////////////////////////////////////////////////////////////// + +/// Count the number of unique types contained by this NamedType, +/// including children and this type itself. +/// +/// For built-in primitives, this could be zero. For non-primitive +/// types, this will be at least one. +/// +/// This function does NOT attempt to perform any de-duplication. +pub const fn unique_types_nty_upper(nty: &NamedType) -> usize { + let child_ct = unique_types_dmt_upper(nty.ty); + if is_prim(nty.ty) { + child_ct + } else { + child_ct + 1 + } +} + +/// Count the number of unique types contained by this DataModelType, +/// ONLY counting children, and not this type, as this will be counted +/// when considering the NamedType instead. +// +// TODO: We could attempt to do LOCAL de-duplication, for example +// a `[u8; 32]` would end up as a tuple of 32 items, drastically +// inflating the total. 
+const fn unique_types_dmt_upper(dmt: &DataModelType) -> usize { + match dmt { + // These are all primitives + DataModelType::Bool => 0, + DataModelType::I8 => 0, + DataModelType::U8 => 0, + DataModelType::I16 => 0, + DataModelType::I32 => 0, + DataModelType::I64 => 0, + DataModelType::I128 => 0, + DataModelType::U16 => 0, + DataModelType::U32 => 0, + DataModelType::U64 => 0, + DataModelType::U128 => 0, + DataModelType::Usize => 0, + DataModelType::Isize => 0, + DataModelType::F32 => 0, + DataModelType::F64 => 0, + DataModelType::Char => 0, + DataModelType::String => 0, + DataModelType::ByteArray => 0, + DataModelType::Unit => 0, + DataModelType::UnitStruct => 0, + DataModelType::Schema => 0, + + // Items with one subtype + DataModelType::Option(nt) | DataModelType::NewtypeStruct(nt) | DataModelType::Seq(nt) => { + unique_types_nty_upper(nt) + } + // tuple-ish + DataModelType::Tuple(nts) | DataModelType::TupleStruct(nts) => { + let mut uniq = 0; + let mut i = 0; + while i < nts.len() { + uniq += unique_types_nty_upper(nts[i]); + i += 1; + } + uniq + } + DataModelType::Map { key, val } => { + unique_types_nty_upper(key) + unique_types_nty_upper(val) + } + DataModelType::Struct(nvals) => { + let mut uniq = 0; + let mut i = 0; + while i < nvals.len() { + uniq += unique_types_nty_upper(nvals[i].ty); + i += 1; + } + uniq + } + DataModelType::Enum(nvars) => { + let mut uniq = 0; + let mut i = 0; + while i < nvars.len() { + uniq += unique_types_var_upper(nvars[i]); + i += 1; + } + uniq + } + } +} + +/// Count the number of unique types contained by this NamedVariant, +/// ONLY counting children, and not this type, as this will be counted +/// when considering the NamedType instead. +// +// TODO: We could attempt to do LOCAL de-duplication, for example +// a `[u8; 32]` would end up as a tuple of 32 items, drastically +// inflating the total. +const fn unique_types_var_upper(nvar: &NamedVariant) -> usize { + match nvar.ty { + DataModelVariant::UnitVariant => 0, + DataModelVariant::NewtypeVariant(nt) => unique_types_nty_upper(nt), + DataModelVariant::TupleVariant(nts) => { + let mut uniq = 0; + let mut i = 0; + while i < nts.len() { + uniq += unique_types_nty_upper(nts[i]); + i += 1; + } + uniq + } + DataModelVariant::StructVariant(nvals) => { + let mut uniq = 0; + let mut i = 0; + while i < nvals.len() { + uniq += unique_types_nty_upper(nvals[i].ty); + i += 1; + } + uniq + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 2/3 - COLLECTION OF UNIQUES AND CALCULATION OF EXACT SIZE +////////////////////////////////////////////////////////////////////////////// + +/// This function collects the set of unique types, reporting the entire list +/// (which might only be partially used), as well as the *used* length. +/// +/// The parameter MAX should be the highest possible number of unique types, +/// if NONE of the types have any duplication. This should be calculated using +/// [`unique_types_nty_upper()`]. This upper bound allows us to pre-allocate +/// enough storage for the collection process. 
+pub const fn type_chewer_nty<const MAX: usize>(
+    nty: &NamedType,
+) -> ([Option<&NamedType>; MAX], usize) {
+    // Calculate the number of unique items in the children of this type
+    let (mut arr, mut used) = type_chewer_dmt::<MAX>(nty.ty);
+    let mut i = 0;
+
+    // determine if this is a single-item primitive - if so, skip adding
+    // this type to the unique list
+    let mut found = is_prim(nty.ty);
+
+    while !found && i < used {
+        let Some(ty) = arr[i] else { panic!() };
+        if nty_eq(nty, ty) {
+            found = true;
+        }
+        i += 1;
+    }
+    if !found {
+        arr[used] = Some(nty);
+        used += 1;
+    }
+    (arr, used)
+}
+
+/// This function collects the set of unique types, reporting the entire list
+/// (which might only be partially used), as well as the *used* length.
+///
+/// The parameter MAX should be the highest possible number of unique types,
+/// if NONE of the types have any duplication. This should be calculated using
+/// [`unique_types_nty_upper()`]. This upper bound allows us to pre-allocate
+/// enough storage for the collection process.
+//
+// TODO: There is a LOT of duplicated code here. This is to reduce the number of
+// intermediate `[Option; MAX]` arrays we contain, as well as the total amount
+// of recursion depth. I am open to suggestions of how to reduce this. Part of
+// this restriction is that we can't take an `&mut` as a const fn arg, so we
+// always have to do it by value, then merge-in the changes.
+const fn type_chewer_dmt<const MAX: usize>(
+    dmt: &DataModelType,
+) -> ([Option<&NamedType>; MAX], usize) {
+    match dmt {
+        // These are all primitives - they never have any children to report.
+        DataModelType::Bool => ([None; MAX], 0),
+        DataModelType::I8 => ([None; MAX], 0),
+        DataModelType::U8 => ([None; MAX], 0),
+        DataModelType::I16 => ([None; MAX], 0),
+        DataModelType::I32 => ([None; MAX], 0),
+        DataModelType::I64 => ([None; MAX], 0),
+        DataModelType::I128 => ([None; MAX], 0),
+        DataModelType::U16 => ([None; MAX], 0),
+        DataModelType::U32 => ([None; MAX], 0),
+        DataModelType::U64 => ([None; MAX], 0),
+        DataModelType::U128 => ([None; MAX], 0),
+        DataModelType::Usize => ([None; MAX], 0),
+        DataModelType::Isize => ([None; MAX], 0),
+        DataModelType::F32 => ([None; MAX], 0),
+        DataModelType::F64 => ([None; MAX], 0),
+        DataModelType::Char => ([None; MAX], 0),
+        DataModelType::String => ([None; MAX], 0),
+        DataModelType::ByteArray => ([None; MAX], 0),
+        DataModelType::Unit => ([None; MAX], 0),
+        DataModelType::Schema => ([None; MAX], 0),
+
+        // A unit struct *as a namedtype* can be a unique/non-primitive type,
+        // but DataModelType calculation is only concerned with CHILDREN, and
+        // a unit struct has none.
+        DataModelType::UnitStruct => ([None; MAX], 0),
+
+        // Items with one subtype
+        DataModelType::Option(nt) | DataModelType::NewtypeStruct(nt) | DataModelType::Seq(nt) => {
+            type_chewer_nty::<MAX>(nt)
+        }
+        // tuple-ish
+        DataModelType::Tuple(nts) | DataModelType::TupleStruct(nts) => {
+            let mut out = [None; MAX];
+            let mut i = 0;
+            let mut outidx = 0;
+
+            // For each type in the tuple...
+            while i < nts.len() {
+                // Get the types used by this field
+                let (arr, used) = type_chewer_nty::<MAX>(nts[i]);
+                let mut j = 0;
+                // For each type in this field...
+ while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + i += 1; + } + (out, outidx) + } + DataModelType::Map { key, val } => { + let mut out = [None; MAX]; + let mut outidx = 0; + + // Do key + let (arr, used) = type_chewer_nty::(key); + let mut j = 0; + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + + // Then do val + let (arr, used) = type_chewer_nty::(val); + let mut j = 0; + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + + (out, outidx) + } + DataModelType::Struct(nvals) => { + let mut out = [None; MAX]; + let mut i = 0; + let mut outidx = 0; + + // For each type in the tuple... + while i < nvals.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nvals[i].ty); + let mut j = 0; + // For each type in this field... + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + i += 1; + } + (out, outidx) + } + DataModelType::Enum(nvars) => { + let mut out = [None; MAX]; + let mut i = 0; + let mut outidx = 0; + + // For each type in the variant... + while i < nvars.len() { + match nvars[i].ty { + DataModelVariant::UnitVariant => continue, + DataModelVariant::NewtypeVariant(nt) => { + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, nt); + k += 1; + } + if !found { + out[outidx] = Some(nt); + outidx += 1; + } + } + DataModelVariant::TupleVariant(nts) => { + let mut x = 0; + + // For each type in the tuple... + while x < nts.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nts[x]); + let mut j = 0; + // For each type in this field... + while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + x += 1; + } + } + DataModelVariant::StructVariant(nvals) => { + let mut x = 0; + + // For each type in the struct... + while x < nvals.len() { + // Get the types used by this field + let (arr, used) = type_chewer_nty::(nvals[x].ty); + let mut j = 0; + // For each type in this field... 
+ while j < used { + let Some(ty) = arr[j] else { panic!() }; + let mut k = 0; + let mut found = false; + // Check against all currently known tys + while !found && k < outidx { + let Some(kty) = out[k] else { panic!() }; + found |= nty_eq(kty, ty); + k += 1; + } + if !found { + out[outidx] = Some(ty); + outidx += 1; + } + j += 1; + } + x += 1; + } + } + } + i += 1; + } + (out, outidx) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 4 - REDUCTION TO CORRECT SIZE +////////////////////////////////////////////////////////////////////////////// + +/// This function reduces a `&[Option<&NamedType>]` to a `[&NamedType; A]`. +/// +/// The parameter `A` should be calculated by [`type_chewer_nty()`]. +/// +/// We also validate that all items >= idx `A` are in fact None. +pub const fn cruncher( + opts: &[Option<&'static NamedType>], +) -> [&'static NamedType; A] { + let mut out = [<() as Schema>::SCHEMA; A]; + let mut i = 0; + while i < A { + let Some(ty) = opts[i] else { panic!() }; + out[i] = ty; + i += 1; + } + while i < opts.len() { + assert!(opts[i].is_none()); + i += 1; + } + out +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 1-5 (macro op) +////////////////////////////////////////////////////////////////////////////// + +/// `unique_types` collects all unique, non-primitive types contained by the given +/// single type. It can be used with any type that implements the [`Schema`] trait, +/// and returns a `&'static [&'static NamedType]`. +#[macro_export] +macro_rules! unique_types { + ($t:ty) => { + const { + const MAX_TYS: usize = + $crate::uniques::unique_types_nty_upper(<$t as postcard_schema::Schema>::SCHEMA); + const BIG_RPT: ( + [Option<&'static postcard_schema::schema::NamedType>; MAX_TYS], + usize, + ) = $crate::uniques::type_chewer_nty(<$t as postcard_schema::Schema>::SCHEMA); + const SMALL_RPT: [&'static postcard_schema::schema::NamedType; BIG_RPT.1] = + $crate::uniques::cruncher(BIG_RPT.0.as_slice()); + SMALL_RPT.as_slice() + } + }; +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 6 - COLLECTION OF UNIQUES ACROSS MULTIPLE TYPES +////////////////////////////////////////////////////////////////////////////// + +/// This function turns an array of type lists into a single list of unique types +/// +/// The type parameter `M` is the maximum potential output size, it should be +/// equal to `lists.iter().map(|l| l.len()).sum()`, and should generally be +/// calculated as part of [`merge_unique_types!()`][crate::merge_unique_types]. +pub const fn merge_nty_lists( + lists: &[&[&'static NamedType]], +) -> ([Option<&'static NamedType>; M], usize) { + let mut out: [Option<&NamedType>; M] = [None; M]; + let mut out_ct = 0; + let mut i = 0; + + while i < lists.len() { + let mut j = 0; + let list = lists[i]; + while j < list.len() { + let item = list[j]; + let mut k = 0; + let mut found = false; + while !found && k < out_ct { + let Some(oitem) = out[k] else { panic!() }; + if nty_eq(item, oitem) { + found = true; + } + k += 1; + } + if !found { + out[out_ct] = Some(item); + out_ct += 1; + } + j += 1; + } + i += 1; + } + + (out, out_ct) +} + +////////////////////////////////////////////////////////////////////////////// +// STAGE 6-9 (macro op) +////////////////////////////////////////////////////////////////////////////// + +/// `merge_unique_types` collects all unique, non-primitive types contained by +/// the given comma separated types. 
It can be used with any types that implement +/// the [`Schema`] trait, and returns a `&'static [&'static NamedType]`. +#[macro_export] +macro_rules! merge_unique_types { + ($($t:ty,)*) => { + const { + const LISTS: &[&[&'static postcard_schema::schema::NamedType]] = &[ + $( + $crate::unique_types!($t), + )* + ]; + const TTL_COUNT: usize = const { + let mut i = 0; + let mut ct = 0; + while i < LISTS.len() { + ct += LISTS[i].len(); + i += 1; + } + ct + }; + const BIG_RPT: ([Option<&'static postcard_schema::schema::NamedType>; TTL_COUNT], usize) = $crate::uniques::merge_nty_lists(LISTS); + const SMALL_RPT: [&'static postcard_schema::schema::NamedType; BIG_RPT.1] = $crate::uniques::cruncher(BIG_RPT.0.as_slice()); + SMALL_RPT.as_slice() + } + } +} + +#[cfg(test)] +mod test { + #![allow(dead_code)] + use postcard_schema::{ + schema::{owned::OwnedNamedType, NamedType}, + Schema, + }; + + use crate::uniques::{ + is_prim, type_chewer_nty, unique_types_dmt_upper, unique_types_nty_upper, + }; + + #[derive(Schema)] + struct Example0; + + #[derive(Schema)] + struct ExampleA { + a: u32, + } + + #[derive(Schema)] + struct Example1 { + a: u32, + b: Option, + } + + #[derive(Schema)] + struct Example2 { + x: i32, + y: Option, + c: Example1, + } + + #[derive(Schema)] + struct Example3 { + a: u32, + b: Option, + c: Example2, + d: Example2, + e: Example2, + } + + #[test] + fn subpar_arrs() { + const MAXARR: usize = unique_types_nty_upper(<[Example0; 32]>::SCHEMA); + // I don't *like* this, it really should be 2. Leaving it as a test so + // I can remember that it's here. See TODO on unique_types_dmt_upper. + assert_eq!(MAXARR, 33); + } + + #[test] + fn uniqlo() { + const MAX0: usize = unique_types_nty_upper(Example0::SCHEMA); + const MAXA: usize = unique_types_nty_upper(ExampleA::SCHEMA); + const MAX1: usize = unique_types_nty_upper(Example1::SCHEMA); + const MAX2: usize = unique_types_nty_upper(Example2::SCHEMA); + const MAX3: usize = unique_types_nty_upper(Example3::SCHEMA); + assert_eq!(MAX0, 1); + assert_eq!(MAXA, 1); + assert_eq!(MAX1, 2); + assert_eq!(MAX2, 4); + assert_eq!(MAX3, 14); + + println!(); + println!("Example0"); + let (arr0, used): ([Option<_>; MAX0], usize) = type_chewer_nty(Example0::SCHEMA); + assert_eq!(used, 1); + println!("max: {MAX0} used: {used}"); + for a in arr0 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("ExampleA"); + let (arra, used): ([Option<_>; MAXA], usize) = type_chewer_nty(ExampleA::SCHEMA); + assert_eq!(used, 1); + println!("max: {MAXA} used: {used}"); + for a in arra { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Option"); + let (arr1, used): ( + [Option<_>; unique_types_nty_upper(Option::::SCHEMA)], + usize, + ) = type_chewer_nty(Option::::SCHEMA); + assert_eq!(used, 1); + println!( + "max: {} used: {used}", + unique_types_nty_upper(Option::::SCHEMA) + ); + for a in arr1 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Example1"); + let (arr1, used): ([Option<_>; MAX1], usize) = type_chewer_nty(Example1::SCHEMA); + assert!(!is_prim(Example1::SCHEMA.ty)); + let child_ct = unique_types_dmt_upper(Example1::SCHEMA.ty); + assert_eq!(child_ct, 1); + assert_eq!(used, 2); + println!("max: {MAX1} used: {used}"); + for a in arr1 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => 
println!("None"), + } + } + + println!(); + println!("Example2"); + let (arr2, used): ([Option<_>; MAX2], usize) = type_chewer_nty(Example2::SCHEMA); + println!("max: {MAX2} used: {used}"); + for a in arr2 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + println!("Example3"); + let (arr3, used): ([Option<_>; MAX3], usize) = type_chewer_nty(Example3::SCHEMA); + println!("max: {MAX3} used: {used}"); + for a in arr3 { + match a { + Some(a) => println!("Some({})", OwnedNamedType::from(a)), + None => println!("None"), + } + } + + println!(); + let rpt0 = unique_types!(Example0); + println!("{}", rpt0.len()); + for a in rpt0 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpta = unique_types!(ExampleA); + println!("{}", rpta.len()); + for a in rpta { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpt1 = unique_types!(Example1); + println!("{}", rpt1.len()); + for a in rpt1 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpt2 = unique_types!(Example2); + println!("{}", rpt2.len()); + for a in rpt2 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + let rpt3 = unique_types!(Example3); + println!("{}", rpt3.len()); + for a in rpt3 { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + const MERGED: &[&NamedType] = merge_unique_types![Example3, ExampleA, Example0,]; + println!("{}", MERGED.len()); + for a in MERGED { + println!("{}", OwnedNamedType::from(*a)) + } + + println!(); + println!(); + println!(); + println!(); + + // panic!("test passed but I want to see the data"); + } +}