forked from theseus-os/Theseus
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
async_channel
without cancel safety
Signed-off-by: Klimenty Tsoutsman <[email protected]>
- Loading branch information
Showing
4 changed files
with
295 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[package] | ||
name = "async_channel" | ||
version = "0.1.0" | ||
authors = ["Klim Tsoutsman <[email protected]>"] | ||
description = "A bounded, multi-producer, multi-consumer asynchronous channel" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
async_wait_queue = { path = "../async_wait_queue" } | ||
dreadnought = { path = "../dreadnought" } | ||
futures = { version = "0.3.28", default-features = false } | ||
mpmc = "0.1.6" | ||
sync = { path = "../../libs/sync" } | ||
sync_spin = { path = "../../libs/sync_spin" } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
//! A bounded, multi-producer, multi-consumer asynchronous channel. | ||
//! | ||
//! See [`Channel`] for more details. | ||
#![no_std] | ||
|
||
use core::{ | ||
pin::Pin, | ||
task::{Context, Poll}, | ||
}; | ||
|
||
use async_wait_queue::WaitQueue; | ||
use futures::stream::{FusedStream, Stream}; | ||
use mpmc::Queue; | ||
use sync::DeadlockPrevention; | ||
use sync_spin::Spin; | ||
|
||
/// A bounded, multi-producer, multi-consumer asynchronous channel. | ||
/// | ||
/// The channel can also be used outside of an asynchronous runtime with the | ||
/// [`blocking_send`], and [`blocking_recv`] methods. | ||
/// | ||
/// [`blocking_send`]: Self::blocking_send | ||
/// [`blocking_recv`]: Self::blocking_recv | ||
#[derive(Clone)] | ||
pub struct Channel<T, P = Spin> | ||
where | ||
T: Send, | ||
P: DeadlockPrevention, | ||
{ | ||
inner: Queue<T>, | ||
senders: WaitQueue<P>, | ||
receivers: WaitQueue<P>, | ||
} | ||
|
||
impl<T, P> Channel<T, P> | ||
where | ||
T: Send, | ||
P: DeadlockPrevention, | ||
{ | ||
/// Creates a new channel. | ||
/// | ||
/// The provided capacity dictates how many messages can be stored in the | ||
/// queue before the sender blocks. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ``` | ||
/// use async_channel::Channel; | ||
/// | ||
/// let channel = Channel::new(2); | ||
/// | ||
/// assert!(channel.try_send(1).is_ok()); | ||
/// assert!(channel.try_send(2).is_ok()); | ||
/// // The channel is full. | ||
/// assert!(channel.try_send(3).is_err()); | ||
/// | ||
/// assert_eq!(channel.try_recv(), Some(1)); | ||
/// assert_eq!(channel.try_recv(), Some(2)); | ||
/// assert!(channel.try_recv().is_none()); | ||
/// ``` | ||
pub fn new(capacity: usize) -> Self { | ||
Self { | ||
inner: Queue::with_capacity(capacity), | ||
senders: WaitQueue::new(), | ||
receivers: WaitQueue::new(), | ||
} | ||
} | ||
|
||
/// Sends `value`. | ||
/// | ||
/// # Cancel safety | ||
/// | ||
/// This method is cancel safe, in that if it is dropped prior to | ||
/// completion, `value` is guaranteed to have not been set. However, in that | ||
/// case `value` will be dropped. | ||
pub async fn send(&self, value: T) { | ||
let mut temp = Some(value); | ||
|
||
self.senders | ||
.wait_until(|| match self.inner.push(temp.take().unwrap()) { | ||
Ok(()) => { | ||
self.receivers.notify_one(); | ||
Some(()) | ||
} | ||
Err(value) => { | ||
temp = Some(value); | ||
None | ||
} | ||
}) | ||
.await | ||
} | ||
|
||
/// Tries to send `value`. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error containing `value` if the channel was full. | ||
pub fn try_send(&self, value: T) -> Result<(), T> { | ||
self.inner.push(value)?; | ||
self.receivers.notify_one(); | ||
Ok(()) | ||
} | ||
|
||
/// Blocks the current thread until `value` is sent. | ||
pub fn blocking_send(&self, value: T) { | ||
dreadnought::block_on(self.send(value)) | ||
} | ||
|
||
/// Receives the next value. | ||
/// | ||
/// # Cancel safety | ||
/// | ||
/// This method is cancel safe. | ||
pub async fn recv(&self) -> T { | ||
let value = self.receivers.wait_until(|| self.inner.pop()).await; | ||
self.senders.notify_one(); | ||
value | ||
} | ||
|
||
/// Tries to receive the next value. | ||
pub fn try_recv(&self) -> Option<T> { | ||
let value = self.inner.pop()?; | ||
self.senders.notify_one(); | ||
Some(value) | ||
} | ||
|
||
/// Blocks the current thread until a value is received. | ||
pub fn blocking_recv(&self) -> T { | ||
dreadnought::block_on(self.recv()) | ||
} | ||
} | ||
|
||
impl<T, P> Stream for Channel<T, P> | ||
where | ||
T: Send, | ||
P: DeadlockPrevention, | ||
{ | ||
type Item = T; | ||
|
||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | ||
match self | ||
.receivers | ||
.poll_wait_until(ctx, &mut || self.inner.pop()) | ||
{ | ||
Poll::Ready(value) => { | ||
self.senders.notify_one(); | ||
Poll::Ready(Some(value)) | ||
} | ||
Poll::Pending => Poll::Pending, | ||
} | ||
} | ||
} | ||
|
||
impl<T, P> FusedStream for Channel<T, P> | ||
where | ||
T: Send, | ||
P: DeadlockPrevention, | ||
{ | ||
fn is_terminated(&self) -> bool { | ||
// NOTE: If we ever implement disconnections, this will need to be modified. | ||
false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "async_wait_queue" | ||
version = "0.1.0" | ||
authors = ["Klim Tsoutsman <[email protected]>"] | ||
description = "An asynchronous wait queue" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
dreadnought = { path = "../dreadnought" } | ||
mpmc_queue = { path = "../../libs/mpmc_queue" } | ||
sync = { path = "../../libs/sync" } | ||
sync_spin = { path = "../../libs/sync_spin" } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
//! An asynchronous wait queue. | ||
//! | ||
//! See [`WaitQueue`] for more details. | ||
#![no_std] | ||
|
||
extern crate alloc; | ||
|
||
use alloc::sync::Arc; | ||
use core::{ | ||
future::poll_fn, | ||
task::{Context, Poll, Waker}, | ||
}; | ||
|
||
use mpmc_queue::Queue; | ||
use sync::DeadlockPrevention; | ||
use sync_spin::Spin; | ||
|
||
/// An asynchronous queue of tasks waiting to be notified. | ||
#[derive(Clone)] | ||
pub struct WaitQueue<P = Spin> | ||
where | ||
P: DeadlockPrevention, | ||
{ | ||
inner: Arc<Queue<Waker, P>>, | ||
} | ||
|
||
impl<P> Default for WaitQueue<P> | ||
where | ||
P: DeadlockPrevention, | ||
{ | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} | ||
|
||
impl<P> WaitQueue<P> | ||
where | ||
P: DeadlockPrevention, | ||
{ | ||
/// Creates a new empty wait queue. | ||
pub fn new() -> Self { | ||
Self { | ||
inner: Arc::new(Queue::new()), | ||
} | ||
} | ||
|
||
pub async fn wait_until<F, T>(&self, mut condition: F) -> T | ||
where | ||
F: FnMut() -> Option<T>, | ||
{ | ||
poll_fn(move |context| self.poll_wait_until(context, &mut condition)).await | ||
} | ||
|
||
pub fn poll_wait_until<F, T>(&self, ctx: &mut Context, condition: &mut F) -> Poll<T> | ||
where | ||
F: FnMut() -> Option<T>, | ||
{ | ||
let wrapped_condition = || { | ||
if let Some(value) = condition() { | ||
Ok(value) | ||
} else { | ||
Err(()) | ||
} | ||
}; | ||
|
||
match self | ||
.inner | ||
.push_if_fail(ctx.waker().clone(), wrapped_condition) | ||
{ | ||
Ok(value) => Poll::Ready(value), | ||
Err(()) => Poll::Pending, | ||
} | ||
} | ||
|
||
pub fn blocking_wait_until<F, T>(&self, condition: F) -> T | ||
where | ||
F: FnMut() -> Option<T>, | ||
{ | ||
dreadnought::block_on(self.wait_until(condition)) | ||
} | ||
|
||
/// Notifies the first task in the wait queue. | ||
/// | ||
/// Returns whether or not a task was awoken. | ||
pub fn notify_one(&self) -> bool { | ||
match self.inner.pop() { | ||
Some(waker) => { | ||
waker.wake(); | ||
// From the `Waker` documentation: | ||
// > As long as the executor keeps running and the task is not | ||
// finished, it is guaranteed that each invocation of `wake()` | ||
// will be followed by at least one `poll()` of the task to | ||
// which this `Waker` belongs. | ||
true | ||
} | ||
None => false, | ||
} | ||
} | ||
|
||
/// Notifies all the tasks in the wait queue. | ||
pub fn notify_all(&self) { | ||
while self.notify_one() {} | ||
} | ||
} |