From a430baf3f52a7b4a4029a4b9a244d06350bf61c4 Mon Sep 17 00:00:00 2001 From: Klimenty Tsoutsman Date: Sun, 17 Dec 2023 09:14:40 +1100 Subject: [PATCH] Add `async_channel` without cancel safety Signed-off-by: Klimenty Tsoutsman --- kernel/async_channel/Cargo.toml | 14 +++ kernel/async_channel/src/lib.rs | 164 +++++++++++++++++++++++++++++ kernel/async_wait_queue/Cargo.toml | 12 +++ kernel/async_wait_queue/src/lib.rs | 105 ++++++++++++++++++ 4 files changed, 295 insertions(+) create mode 100644 kernel/async_channel/Cargo.toml create mode 100644 kernel/async_channel/src/lib.rs create mode 100644 kernel/async_wait_queue/Cargo.toml create mode 100644 kernel/async_wait_queue/src/lib.rs diff --git a/kernel/async_channel/Cargo.toml b/kernel/async_channel/Cargo.toml new file mode 100644 index 0000000000..fe267bb1df --- /dev/null +++ b/kernel/async_channel/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "async_channel" +version = "0.1.0" +authors = ["Klim Tsoutsman "] +description = "A bounded, multi-producer, multi-consumer asynchronous channel" +edition = "2021" + +[dependencies] +async_wait_queue = { path = "../async_wait_queue" } +dreadnought = { path = "../dreadnought" } +futures = { version = "0.3.28", default-features = false } +mpmc = "0.1.6" +sync = { path = "../../libs/sync" } +sync_spin = { path = "../../libs/sync_spin" } diff --git a/kernel/async_channel/src/lib.rs b/kernel/async_channel/src/lib.rs new file mode 100644 index 0000000000..c8a7a5296a --- /dev/null +++ b/kernel/async_channel/src/lib.rs @@ -0,0 +1,164 @@ +//! A bounded, multi-producer, multi-consumer asynchronous channel. +//! +//! See [`Channel`] for more details. + +#![no_std] + +use core::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_wait_queue::WaitQueue; +use futures::stream::{FusedStream, Stream}; +use mpmc::Queue; +use sync::DeadlockPrevention; +use sync_spin::Spin; + +/// A bounded, multi-producer, multi-consumer asynchronous channel. +/// +/// The channel can also be used outside of an asynchronous runtime with the +/// [`blocking_send`], and [`blocking_recv`] methods. +/// +/// [`blocking_send`]: Self::blocking_send +/// [`blocking_recv`]: Self::blocking_recv +#[derive(Clone)] +pub struct Channel +where + T: Send, + P: DeadlockPrevention, +{ + inner: Queue, + senders: WaitQueue

, + receivers: WaitQueue

, +} + +impl Channel +where + T: Send, + P: DeadlockPrevention, +{ + /// Creates a new channel. + /// + /// The provided capacity dictates how many messages can be stored in the + /// queue before the sender blocks. + /// + /// # Examples + /// + /// ``` + /// use async_channel::Channel; + /// + /// let channel = Channel::new(2); + /// + /// assert!(channel.try_send(1).is_ok()); + /// assert!(channel.try_send(2).is_ok()); + /// // The channel is full. + /// assert!(channel.try_send(3).is_err()); + /// + /// assert_eq!(channel.try_recv(), Some(1)); + /// assert_eq!(channel.try_recv(), Some(2)); + /// assert!(channel.try_recv().is_none()); + /// ``` + pub fn new(capacity: usize) -> Self { + Self { + inner: Queue::with_capacity(capacity), + senders: WaitQueue::new(), + receivers: WaitQueue::new(), + } + } + + /// Sends `value`. + /// + /// # Cancel safety + /// + /// This method is cancel safe, in that if it is dropped prior to + /// completion, `value` is guaranteed to have not been set. However, in that + /// case `value` will be dropped. + pub async fn send(&self, value: T) { + let mut temp = Some(value); + + self.senders + .wait_until(|| match self.inner.push(temp.take().unwrap()) { + Ok(()) => { + self.receivers.notify_one(); + Some(()) + } + Err(value) => { + temp = Some(value); + None + } + }) + .await + } + + /// Tries to send `value`. + /// + /// # Errors + /// + /// Returns an error containing `value` if the channel was full. + pub fn try_send(&self, value: T) -> Result<(), T> { + self.inner.push(value)?; + self.receivers.notify_one(); + Ok(()) + } + + /// Blocks the current thread until `value` is sent. + pub fn blocking_send(&self, value: T) { + dreadnought::block_on(self.send(value)) + } + + /// Receives the next value. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub async fn recv(&self) -> T { + let value = self.receivers.wait_until(|| self.inner.pop()).await; + self.senders.notify_one(); + value + } + + /// Tries to receive the next value. + pub fn try_recv(&self) -> Option { + let value = self.inner.pop()?; + self.senders.notify_one(); + Some(value) + } + + /// Blocks the current thread until a value is received. + pub fn blocking_recv(&self) -> T { + dreadnought::block_on(self.recv()) + } +} + +impl Stream for Channel +where + T: Send, + P: DeadlockPrevention, +{ + type Item = T; + + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + match self + .receivers + .poll_wait_until(ctx, &mut || self.inner.pop()) + { + Poll::Ready(value) => { + self.senders.notify_one(); + Poll::Ready(Some(value)) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl FusedStream for Channel +where + T: Send, + P: DeadlockPrevention, +{ + fn is_terminated(&self) -> bool { + // NOTE: If we ever implement disconnections, this will need to be modified. + false + } +} diff --git a/kernel/async_wait_queue/Cargo.toml b/kernel/async_wait_queue/Cargo.toml new file mode 100644 index 0000000000..ef341895c3 --- /dev/null +++ b/kernel/async_wait_queue/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "async_wait_queue" +version = "0.1.0" +authors = ["Klim Tsoutsman "] +description = "An asynchronous wait queue" +edition = "2021" + +[dependencies] +dreadnought = { path = "../dreadnought" } +mpmc_queue = { path = "../../libs/mpmc_queue" } +sync = { path = "../../libs/sync" } +sync_spin = { path = "../../libs/sync_spin" } diff --git a/kernel/async_wait_queue/src/lib.rs b/kernel/async_wait_queue/src/lib.rs new file mode 100644 index 0000000000..32a617abc0 --- /dev/null +++ b/kernel/async_wait_queue/src/lib.rs @@ -0,0 +1,105 @@ +//! An asynchronous wait queue. +//! +//! See [`WaitQueue`] for more details. + +#![no_std] + +extern crate alloc; + +use alloc::sync::Arc; +use core::{ + future::poll_fn, + task::{Context, Poll, Waker}, +}; + +use mpmc_queue::Queue; +use sync::DeadlockPrevention; +use sync_spin::Spin; + +/// An asynchronous queue of tasks waiting to be notified. +#[derive(Clone)] +pub struct WaitQueue

+where + P: DeadlockPrevention, +{ + inner: Arc>, +} + +impl

Default for WaitQueue

+where + P: DeadlockPrevention, +{ + fn default() -> Self { + Self::new() + } +} + +impl

WaitQueue

+where + P: DeadlockPrevention, +{ + /// Creates a new empty wait queue. + pub fn new() -> Self { + Self { + inner: Arc::new(Queue::new()), + } + } + + pub async fn wait_until(&self, mut condition: F) -> T + where + F: FnMut() -> Option, + { + poll_fn(move |context| self.poll_wait_until(context, &mut condition)).await + } + + pub fn poll_wait_until(&self, ctx: &mut Context, condition: &mut F) -> Poll + where + F: FnMut() -> Option, + { + let wrapped_condition = || { + if let Some(value) = condition() { + Ok(value) + } else { + Err(()) + } + }; + + match self + .inner + .push_if_fail(ctx.waker().clone(), wrapped_condition) + { + Ok(value) => Poll::Ready(value), + Err(()) => Poll::Pending, + } + } + + pub fn blocking_wait_until(&self, condition: F) -> T + where + F: FnMut() -> Option, + { + dreadnought::block_on(self.wait_until(condition)) + } + + /// Notifies the first task in the wait queue. + /// + /// Returns whether or not a task was awoken. + pub fn notify_one(&self) -> bool { + match self.inner.pop() { + Some(waker) => { + waker.wake(); + // From the `Waker` documentation: + // > As long as the executor keeps running and the task is not + // finished, it is guaranteed that each invocation of `wake()` + // will be followed by at least one `poll()` of the task to + // which this `Waker` belongs. + true + } + None => false, + } + } + + /// Notifies all the tasks in the wait queue. + pub fn notify_all(&self) { + while self.notify_one() {} + } +}