diff --git a/Cargo.toml b/Cargo.toml index d097566..e51c3f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,8 @@ categories = ["asynchronous", "network-programming", "MCU", "embedded"] exclude = ["/.*"] [features] -default = ["futures-io", "futures-lite"] +default = ["futures-io", "futures-lite", "embassy-time"] +embassy-time = ["embassy-time-driver", "embassy-time-queue-driver", "dep:embassy-time"] [dependencies] libc = "0.2" @@ -21,6 +22,8 @@ heapless = "0.8" log = { version = "0.4", default-features = false } futures-io = { version = "0.3", default-features = false, optional = true, features = ["std"] } futures-lite = { version = "1", default-features = false, optional = true } +embassy-time-driver = { version = "0.1", optional = true } +embassy-time-queue-driver = { version = "0.1", optional = true } embassy-time = { version = "0.3", optional = true } [dev-dependencies] @@ -32,4 +35,8 @@ env_logger = "0.10" [[test]] name = "async" -required-features = ["futures-io", "futures-lite"] +required-features = ["futures-io", "futures-lite", "embassy-time"] + +[[test]] +name = "timer" +required-features = ["futures-lite", "embassy-time"] diff --git a/README.md b/README.md index 6fff3ba..9b6d7f9 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,13 @@ [![Cargo](https://img.shields.io/crates/v/async-io-mini.svg)](https://crates.io/crates/async-io-mini) [![Documentation](https://docs.rs/async-io/badge.svg)](https://docs.rs/async-io-mini) -Async I/O. **EXPERIMENTAL!!** +Async I/O and timers. **Experimental** -This crate is an **experimental** fork of the splendid [`async-io`](https://github.com/smol-rs/async-io) crate targetting MCUs and ESP-IDF in particular. +This crate is an experimental fork of the splendid [`async-io`](https://github.com/smol-rs/async-io) crate targetting MCUs and ESP-IDF in particular. ## How to use? -`async-io-mini` is a drop-in, API-compatible replacement for the `Async` type from `async-io` (but does NOT have an equivalent of `Timer` - see why [below](#limitations)). +`async-io-mini` is a drop-in, API-compatible replacement for the `Async` and `Timer` types from `async-io`. So either: * Just replace all `use async_io` occurances in your crate with `use async_io_mini` @@ -37,13 +37,16 @@ Further, `async-io` has a non-trivial set of dependencies (again - for MCUs; for - `log` (might become optional); - `enumset` (not crucial, might remove). -## Limitations +## Enhancements + +The `Timer` type of `async_io_mini` is based on the `embassy-time` crate, and as such should offer a higher resolution on embedded operating systems like the ESP-IDF than what can be normally achieved by implementing timers using the `timeout` parameter of the `select` syscall (as `async-io` does). -### No timers +The reason for this is that on the ESP-IDF, the `timeout` parameter of `select` provides a resolution of 10ms (one FreeRTOS sys-tick), while +`embassy-time` is implemented using the [ESP-IDF Timer service](https://docs.espressif.com/projects/esp-idf/en/stable/esp32/api-reference/system/esp_timer.html), which provides resolutions up to 1 microsecond. -`async-io-mini` does NOT have an equivalent of `async_io::Timer`. On ESP-IDF at least, timers based on OS systicks are often not very useful, as the OS systick is low-res (10ms). +With that said, for greenfield code that does not need to be compatible with `async-io`, use the native `embassy_time::Timer` and `embassy_time::Ticker` rather than `async_io_mini::Timer`, because the latter has a larger memory footprint (40 bytes on 32bit archs) compared to the `embassy-time` types (8 and 16 bytes each). -Workaround: use the `Timer` struct from the [`embassy-time`](https://crates.io/crates/embassy-time) crate, which provides a very similar API and is highly optimized for embedded environments. On the ESP-IDF, the `embassy-time-driver` implementation is backed by the ESP-IDF Timer service, which runs off from a high priority thread by default and thus has good res. +## Limitations ### No equivalent of `async_io::block_on` @@ -51,18 +54,24 @@ Implementing socket polling as a shared task between the hidden `async-io-mini` ## Implementation +### Async + The first time `Async` is used, a thread named `async-io-mini` will be spawned. The purpose of this thread is to wait for I/O events reported by the operating system, and then wake appropriate futures blocked on I/O when they can be resumed. To wait for the next I/O event, the "async-io-mini" thread uses the [select](https://en.wikipedia.org/wiki/Select_(Unix)) syscall, and **is thus only useful for MCUs (might just be the ESP-IDF) where the number of file or socket handles is very small anyway**. +### Timer + +As per above, the `Timer` type is a wrapper around the functionality provided by the `embassy-time` crate. + ## Examples -Connect to `example.com:80`. +Connect to `example.com:80`, or time out after 10 seconds. ```rust -use async_io_mini::Async; +use async_io_mini::{Async, Timer}; use futures_lite::{future::FutureExt, io}; use std::net::{TcpStream, ToSocketAddrs}; @@ -70,7 +79,11 @@ use std::time::Duration; let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); -let stream = Async::::connect(addr).await?; +let stream = Async::::connect(addr).or(async { + Timer::after(Duration::from_secs(10)).await; + Err(io::ErrorKind::TimedOut.into()) +}) +.await?; ``` ## License diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..802dc47 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,1222 @@ +use core::future::{poll_fn, Future}; +use core::pin::pin; +use core::task::{Context, Poll}; + +use std::io::{self, Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; +use std::os::fd::FromRawFd; +use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}; + +use super::reactor::{Event, REACTOR}; +use super::sys; +use super::{ready, syscall, syscall_los, syscall_los_eagain}; + +/// Async adapter for I/O types. +/// +/// This type puts an I/O handle into non-blocking mode, registers it in +/// [epoll]/[kqueue]/[event ports]/[IOCP], and then provides an async interface for it. +/// +/// [epoll]: https://en.wikipedia.org/wiki/Epoll +/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue +/// [event ports]: https://illumos.org/man/port_create +/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports +/// +/// # Caveats +/// +/// [`Async`] is a low-level primitive, and as such it comes with some caveats. +/// +/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or +/// [`async-process`] (on Unix). +/// +/// The most notable caveat is that it is unsafe to access the inner I/O source mutably +/// using this primitive. Traits likes [`AsyncRead`] and [`AsyncWrite`] are not implemented by +/// default unless it is guaranteed that the resource won't be invalidated by reading or writing. +/// See the [`IoSafe`] trait for more information. +/// +/// [`async-net`]: https://github.com/smol-rs/async-net +/// [`async-process`]: https://github.com/smol-rs/async-process +/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html +/// +/// ### Supported types +/// +/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like +/// [timerfd] and [inotify]. +/// +/// However, do not use [`Async`] with types like [`File`][`std::fs::File`], +/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] +/// because all operating systems have issues with them when put in non-blocking mode. +/// +/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs +/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs +/// +/// ### Concurrent I/O +/// +/// Note that [`&Async`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T` +/// implements those traits, which means tasks can concurrently read and write using shared +/// references. +/// +/// But there is a catch: only one task can read a time, and only one task can write at a time. It +/// is okay to have two tasks where one is reading and the other is writing at the same time, but +/// it is not okay to have two tasks reading at the same time or writing at the same time. If you +/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU +/// time. +/// +/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to +/// [`poll_readable()`][`Async::poll_readable()`] and +/// [`poll_writable()`][`Async::poll_writable()`]. +/// +/// However, any number of tasks can be concurrently calling other methods like +/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`]. +/// +/// ### Closing +/// +/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`] +/// simply flushes. If you want to shutdown a TCP or Unix socket, use +/// [`Shutdown`][`std::net::Shutdown`]. +/// +/// # Examples +/// +/// Connect to a server and echo incoming messages back to the server: +/// +/// ```no_run +/// use async_io_mini::Async; +/// use futures_lite::io; +/// use std::net::TcpStream; +/// +/// # futures_lite::future::block_on(async { +/// // Connect to a local server. +/// let stream = Async::::connect(([127, 0, 0, 1], 8000)).await?; +/// +/// // Echo all messages from the read side of the stream into the write side. +/// io::copy(&stream, &stream).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +/// +/// You can use either predefined async methods or wrap blocking I/O operations in +/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and +/// [`Async::write_with_mut()`]: +/// +/// ```no_run +/// use async_io_mini::Async; +/// use std::net::TcpListener; +/// +/// # futures_lite::future::block_on(async { +/// let listener = Async::::bind(([127, 0, 0, 1], 0))?; +/// +/// // These two lines are equivalent: +/// let (stream, addr) = listener.accept().await?; +/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +#[derive(Debug)] +pub struct Async { + io: Option, +} + +impl Unpin for Async {} + +impl Async { + /// Creates an async I/O handle. + /// + /// This method will put the handle in non-blocking mode and register it in + /// [epoll]/[kqueue]/[event ports]/[IOCP]. + /// + /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement + /// `AsSocket`. + /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [event ports]: https://illumos.org/man/port_create + /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::{SocketAddr, TcpListener}; + /// + /// # futures_lite::future::block_on(async { + /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result { + // Put the file descriptor in non-blocking mode. + set_nonblocking(io.as_fd())?; + + Self::new_nonblocking(io) + } + + /// Creates an async I/O handle without setting it to non-blocking mode. + /// + /// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP]. + /// + /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement + /// `AsSocket`. + /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [event ports]: https://illumos.org/man/port_create + /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports + /// + /// # Caveats + /// + /// The caller should ensure that the handle is set to non-blocking mode or that it is okay if + /// it is not set. If not set to non-blocking mode, I/O operations may block the current thread + /// and cause a deadlock in an asynchronous context. + pub fn new_nonblocking(io: T) -> io::Result { + REACTOR.start()?; + // SAFETY: It is impossible to drop the I/O source while it is registered. + REACTOR.register(io.as_fd().as_raw_fd())?; + + Ok(Self { io: Some(io) }) + } +} + +impl AsRawFd for Async { + fn as_raw_fd(&self) -> RawFd { + self.get_ref().as_raw_fd() + } +} + +impl AsFd for Async { + fn as_fd(&self) -> BorrowedFd<'_> { + self.get_ref().as_fd() + } +} + +impl> TryFrom for Async { + type Error = io::Error; + + fn try_from(value: OwnedFd) -> Result { + Async::new(value.into()) + } +} + +impl> TryFrom> for OwnedFd { + type Error = io::Error; + + fn try_from(value: Async) -> Result { + value.into_inner().map(Into::into) + } +} + +impl Async { + /// Gets a reference to the inner I/O handle. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = listener.get_ref(); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn get_ref(&self) -> &T { + self.io.as_ref().unwrap() + } + + /// Gets a mutable reference to the inner I/O handle. + /// + /// # Safety + /// + /// The underlying I/O source must not be dropped using this function. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = unsafe { listener.get_mut() }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub unsafe fn get_mut(&mut self) -> &mut T { + self.io.as_mut().unwrap() + } + + /// Unwraps the inner I/O handle. + /// + /// This method will **not** put the I/O handle back into blocking mode. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = listener.into_inner()?; + /// + /// // Put the listener back into blocking mode. + /// inner.set_nonblocking(false)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn into_inner(mut self) -> io::Result { + REACTOR.deregister(self.as_fd().as_raw_fd())?; + Ok(self.io.take().unwrap()) + } + + /// Waits until the I/O handle is readable. + /// + /// This method completes when a read operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Wait until a client can be accepted. + /// listener.readable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn readable(&self) -> io::Result<()> { + poll_fn(|cx| self.poll_readable(cx)).await + } + + /// Waits until the I/O handle is writable. + /// + /// This method completes when a write operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// + /// // Wait until the stream is writable. + /// stream.writable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn writable(&self) -> io::Result<()> { + poll_fn(|cx| self.poll_writable(cx)).await + } + + /// Polls the I/O handle for readability. + /// + /// When this method returns [`Poll::Ready`], that means the OS has delivered an event + /// indicating readability since the last time this task has called the method and received + /// [`Poll::Pending`]. + /// + /// # Caveats + /// + /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks + /// will just keep waking each other in turn, thus wasting CPU time. + /// + /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use futures_lite::future; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Wait until a client can be accepted. + /// future::poll_fn(|cx| listener.poll_readable(cx)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { + if REACTOR.fetch_or_set(self.as_fd().as_raw_fd(), Event::Read, cx.waker())? { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + /// Polls the I/O handle for writability. + /// + /// When this method returns [`Poll::Ready`], that means the OS has delivered an event + /// indicating writability since the last time this task has called the method and received + /// [`Poll::Pending`]. + /// + /// # Caveats + /// + /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks + /// will just keep waking each other in turn, thus wasting CPU time. + /// + /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use futures_lite::future; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// + /// // Wait until the stream is writable. + /// future::poll_fn(|cx| stream.poll_writable(cx)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { + if REACTOR.fetch_or_set(self.as_fd().as_raw_fd(), Event::Write, cx.waker())? { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Read)?; + + let mut op = op; + loop { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.readable()).await?; + } + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a mutable reference to the I/O handle. + /// + /// # Safety + /// + /// In the closure, the underlying I/O source must not be dropped. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async unsafe fn read_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Read)?; + + let mut op = op; + loop { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.readable()).await?; + } + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = socket.write_with(|s| s.send(msg)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Write)?; + + let mut op = op; + loop { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.writable()).await?; + } + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// # Safety + /// + /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying + /// I/O source must not be dropped. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let mut socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async unsafe fn write_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Write)?; + + let mut op = op; + loop { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.writable()).await?; + } + } +} + +impl AsRef for Async { + fn as_ref(&self) -> &T { + self.io.as_ref().unwrap() + } +} + +impl Drop for Async { + fn drop(&mut self) { + if let Some(io) = &self.io { + REACTOR.deregister(io.as_fd().as_raw_fd()).ok(); + } + } +} + +/// Types whose I/O trait implementations do not drop the underlying I/O source. +/// +/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can +/// happen if the inner resource (the [`TcpStream`], [`UnixListener`] or other `T`) is moved out +/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to +/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped +/// and a dangling handle won't be left behind. +/// +/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those +/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the +/// source out while the method is being run. +/// +/// This trait is an antidote to this predicament. By implementing this trait, the user pledges +/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the +/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`]. +/// +/// # Safety +/// +/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits +/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`]. +/// +/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented +/// for immutable reference types, as it is impossible to invalidate any outstanding references +/// while holding an immutable reference, even with interior mutability. As Rust's current pinning +/// system relies on similar guarantees, I believe that this approach is robust. +/// +/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html +/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html +/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html +/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html +/// +/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html +pub unsafe trait IoSafe {} + +/// Reference types can't be mutated. +/// +/// The worst thing that can happen is that external state is used to change what kind of pointer +/// `as_fd()` returns. For instance: +/// +/// ``` +/// # #[cfg(unix)] { +/// use std::cell::Cell; +/// use std::net::TcpStream; +/// use std::os::unix::io::{AsFd, BorrowedFd}; +/// +/// struct Bar { +/// flag: Cell, +/// a: TcpStream, +/// b: TcpStream +/// } +/// +/// impl AsFd for Bar { +/// fn as_fd(&self) -> BorrowedFd<'_> { +/// if self.flag.replace(!self.flag.get()) { +/// self.a.as_fd() +/// } else { +/// self.b.as_fd() +/// } +/// } +/// } +/// # } +/// ``` +/// +/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations +/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`. +unsafe impl IoSafe for &T {} + +// Can be implemented on top of libstd types. +unsafe impl IoSafe for std::fs::File {} +unsafe impl IoSafe for std::io::Stderr {} +unsafe impl IoSafe for std::io::Stdin {} +unsafe impl IoSafe for std::io::Stdout {} +unsafe impl IoSafe for std::io::StderrLock<'_> {} +unsafe impl IoSafe for std::io::StdinLock<'_> {} +unsafe impl IoSafe for std::io::StdoutLock<'_> {} +unsafe impl IoSafe for std::net::TcpStream {} +unsafe impl IoSafe for std::process::ChildStdin {} +unsafe impl IoSafe for std::process::ChildStdout {} +unsafe impl IoSafe for std::process::ChildStderr {} + +unsafe impl IoSafe for std::io::BufReader {} +unsafe impl IoSafe for std::io::BufWriter {} +unsafe impl IoSafe for std::io::LineWriter {} +unsafe impl IoSafe for &mut T {} +//unsafe impl IoSafe for alloc::boxed::Box {} +unsafe impl IoSafe for std::borrow::Cow<'_, T> {} + +#[cfg(feature = "futures-io")] +impl futures_io::AsyncRead for Async { + fn poll_read( + mut self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.read(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } + + fn poll_read_vectored( + mut self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.read_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } +} + +// Since this is through a reference, we can't mutate the inner I/O source. +// Therefore this is safe! +#[cfg(feature = "futures-io")] +impl futures_io::AsyncRead for &Async +where + for<'a> &'a T: Read, +{ + fn poll_read( + self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match (*self).get_ref().read(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } + + fn poll_read_vectored( + self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + loop { + match (*self).get_ref().read_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } +} + +#[cfg(feature = "futures-io")] +impl futures_io::AsyncWrite for Async { + fn poll_write( + mut self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.write(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_write_vectored( + mut self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.write_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_flush( + mut self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.flush() { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +#[cfg(feature = "futures-io")] +impl futures_io::AsyncWrite for &Async +where + for<'a> &'a T: Write, +{ + fn poll_write( + self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match (*self).get_ref().write(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_write_vectored( + self: core::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + loop { + match (*self).get_ref().write_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_flush(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match (*self).get_ref().flush() { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl Async { + /// Creates a TCP listener bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// println!("Listening on {}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(addr: A) -> io::Result> { + let addr = addr.into(); + Async::new(TcpListener::bind(addr)?) + } + + /// Accepts a new incoming TCP connection. + /// + /// When a connection is established, it will be returned as a TCP stream together with its + /// remote address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming TCP connections. + /// + /// The stream is infinite, i.e. it never stops with a [`None`]. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use futures_lite::{pin, stream::StreamExt}; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; + /// let incoming = listener.incoming(); + /// pin!(incoming); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + #[cfg(feature = "futures-lite")] + pub fn incoming( + &self, + ) -> impl futures_lite::Stream>> + Send + '_ { + futures_lite::stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + }) + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(listener: std::net::TcpListener) -> io::Result { + Async::new(listener) + } +} + +impl Async { + /// Creates a TCP connection to the specified address. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect>(addr: A) -> io::Result> { + // Figure out how to handle this address. + let addr = addr.into(); + + let socket = match addr { + SocketAddr::V4(v4) => { + let addr = sys::sockaddr_in { + sin_family: sys::AF_INET as _, + sin_port: u16::to_be(v4.port()), + sin_addr: sys::in_addr { + s_addr: u32::from_ne_bytes(v4.ip().octets()), + }, + #[cfg(target_os = "espidf")] + sin_len: Default::default(), + sin_zero: Default::default(), + }; + + connect( + &addr as *const _ as *const _, + core::mem::size_of_val(&addr), + sys::AF_INET, + sys::SOCK_STREAM, + 0, + ) + } + SocketAddr::V6(v6) => { + let addr = sys::sockaddr_in6 { + sin6_family: sys::AF_INET6 as _, + sin6_port: u16::to_be(v6.port()), + sin6_flowinfo: 0, + sin6_addr: sys::in6_addr { + s6_addr: v6.ip().octets(), + }, + sin6_scope_id: 0, + #[cfg(target_os = "espidf")] + sin6_len: Default::default(), + }; + + connect( + &addr as *const _ as *const _, + core::mem::size_of_val(&addr), + sys::AF_INET6, + sys::SOCK_STREAM, + 6, + ) + } + }?; + + // Use new_nonblocking because connect already sets socket to non-blocking mode. + let stream = Async::new_nonblocking(TcpStream::from(socket))?; + + // The stream becomes writable when connected. + stream.writable().await?; + + // Check if there was an error while connecting. + match stream.get_ref().take_error()? { + None => Ok(stream), + Some(err) => Err(err), + } + } + + /// Reads data from the stream without removing it from the buffer. + /// + /// Returns the number of bytes read. Successive calls of this method read the same data. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt}; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let mut stream = Async::::connect(addr).await?; + /// + /// stream + /// .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + /// .await?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = stream.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(stream: std::net::TcpStream) -> io::Result { + Async::new(stream) + } +} + +impl Async { + /// Creates a UDP socket bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; + /// println!("Bound to {}", socket.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(addr: A) -> io::Result> { + let addr = addr.into(); + Async::new(UdpSocket::bind(addr)?) + } + + /// Receives a single datagram message. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Receives a single datagram message without removing it from the queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.peek_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.peek_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes writen. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; + /// let addr = socket.get_ref().local_addr()?; + /// + /// let msg = b"hello"; + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { + let addr = addr.into(); + self.write_with(|io| io.send_to(buf, addr)).await + } + + /// Receives a single datagram message from the connected peer. + /// + /// Returns the number of bytes read. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Receives a single datagram message from the connected peer without removing it from the + /// queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io_mini::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(socket: std::net::UdpSocket) -> io::Result { + Async::new(socket) + } +} + +/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready. +async fn optimistic(fut: impl Future>) -> io::Result<()> { + let mut polled = false; + let mut fut = pin!(fut); + + poll_fn(move |cx| { + if !polled { + polled = true; + fut.as_mut().poll(cx) + } else { + Poll::Ready(Ok(())) + } + }) + .await +} + +fn connect( + addr: *const sys::sockaddr, + addr_len: usize, + domain: sys::c_int, + ty: sys::c_int, + protocol: sys::c_int, +) -> io::Result { + // Create the socket. + let socket = unsafe { OwnedFd::from_raw_fd(syscall_los!(sys::socket(domain, ty, protocol))?) }; + + // Set non-blocking mode. + set_nonblocking(socket.as_fd())?; + + syscall_los_eagain!(unsafe { sys::connect(socket.as_raw_fd(), addr, addr_len as _) })?; + + Ok(socket) +} + +fn set_nonblocking(fd: BorrowedFd) -> io::Result<()> { + let previous = unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_GETFL) }; + let new = previous | sys::O_NONBLOCK; + if new != previous { + syscall!(unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_SETFL, new) })?; + } + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index fdbc427..2f1f006 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,1226 +30,12 @@ #![allow(unknown_lints)] #![allow(clippy::needless_maybe_sized)] -use core::future::{poll_fn, Future}; -use core::pin::pin; -use core::task::{Context, Poll}; - -use std::io::{self, Read, Write}; -use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; -use std::os::fd::FromRawFd; -use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}; - -use reactor::{Event, REACTOR}; +pub use io::*; +#[cfg(feature = "embassy-time")] +pub use timer::*; +mod io; mod reactor; mod sys; - -/// Async adapter for I/O types. -/// -/// This type puts an I/O handle into non-blocking mode, registers it in -/// [epoll]/[kqueue]/[event ports]/[IOCP], and then provides an async interface for it. -/// -/// [epoll]: https://en.wikipedia.org/wiki/Epoll -/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue -/// [event ports]: https://illumos.org/man/port_create -/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports -/// -/// # Caveats -/// -/// [`Async`] is a low-level primitive, and as such it comes with some caveats. -/// -/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or -/// [`async-process`] (on Unix). -/// -/// The most notable caveat is that it is unsafe to access the inner I/O source mutably -/// using this primitive. Traits likes [`AsyncRead`] and [`AsyncWrite`] are not implemented by -/// default unless it is guaranteed that the resource won't be invalidated by reading or writing. -/// See the [`IoSafe`] trait for more information. -/// -/// [`async-net`]: https://github.com/smol-rs/async-net -/// [`async-process`]: https://github.com/smol-rs/async-process -/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html -/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html -/// -/// ### Supported types -/// -/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like -/// [timerfd] and [inotify]. -/// -/// However, do not use [`Async`] with types like [`File`][`std::fs::File`], -/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] -/// because all operating systems have issues with them when put in non-blocking mode. -/// -/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs -/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs -/// -/// ### Concurrent I/O -/// -/// Note that [`&Async`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T` -/// implements those traits, which means tasks can concurrently read and write using shared -/// references. -/// -/// But there is a catch: only one task can read a time, and only one task can write at a time. It -/// is okay to have two tasks where one is reading and the other is writing at the same time, but -/// it is not okay to have two tasks reading at the same time or writing at the same time. If you -/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU -/// time. -/// -/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to -/// [`poll_readable()`][`Async::poll_readable()`] and -/// [`poll_writable()`][`Async::poll_writable()`]. -/// -/// However, any number of tasks can be concurrently calling other methods like -/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`]. -/// -/// ### Closing -/// -/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`] -/// simply flushes. If you want to shutdown a TCP or Unix socket, use -/// [`Shutdown`][`std::net::Shutdown`]. -/// -/// # Examples -/// -/// Connect to a server and echo incoming messages back to the server: -/// -/// ```no_run -/// use async_io_mini::Async; -/// use futures_lite::io; -/// use std::net::TcpStream; -/// -/// # futures_lite::future::block_on(async { -/// // Connect to a local server. -/// let stream = Async::::connect(([127, 0, 0, 1], 8000)).await?; -/// -/// // Echo all messages from the read side of the stream into the write side. -/// io::copy(&stream, &stream).await?; -/// # std::io::Result::Ok(()) }); -/// ``` -/// -/// You can use either predefined async methods or wrap blocking I/O operations in -/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and -/// [`Async::write_with_mut()`]: -/// -/// ```no_run -/// use async_io_mini::Async; -/// use std::net::TcpListener; -/// -/// # futures_lite::future::block_on(async { -/// let listener = Async::::bind(([127, 0, 0, 1], 0))?; -/// -/// // These two lines are equivalent: -/// let (stream, addr) = listener.accept().await?; -/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?; -/// # std::io::Result::Ok(()) }); -/// ``` -#[derive(Debug)] -pub struct Async { - io: Option, -} - -impl Unpin for Async {} - -impl Async { - /// Creates an async I/O handle. - /// - /// This method will put the handle in non-blocking mode and register it in - /// [epoll]/[kqueue]/[event ports]/[IOCP]. - /// - /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement - /// `AsSocket`. - /// - /// [epoll]: https://en.wikipedia.org/wiki/Epoll - /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue - /// [event ports]: https://illumos.org/man/port_create - /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::{SocketAddr, TcpListener}; - /// - /// # futures_lite::future::block_on(async { - /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - /// let listener = Async::new(listener)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn new(io: T) -> io::Result { - // Put the file descriptor in non-blocking mode. - set_nonblocking(io.as_fd())?; - - Self::new_nonblocking(io) - } - - /// Creates an async I/O handle without setting it to non-blocking mode. - /// - /// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP]. - /// - /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement - /// `AsSocket`. - /// - /// [epoll]: https://en.wikipedia.org/wiki/Epoll - /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue - /// [event ports]: https://illumos.org/man/port_create - /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports - /// - /// # Caveats - /// - /// The caller should ensure that the handle is set to non-blocking mode or that it is okay if - /// it is not set. If not set to non-blocking mode, I/O operations may block the current thread - /// and cause a deadlock in an asynchronous context. - pub fn new_nonblocking(io: T) -> io::Result { - REACTOR.start()?; - // SAFETY: It is impossible to drop the I/O source while it is registered. - REACTOR.register(io.as_fd().as_raw_fd())?; - - Ok(Self { io: Some(io) }) - } -} - -impl AsRawFd for Async { - fn as_raw_fd(&self) -> RawFd { - self.get_ref().as_raw_fd() - } -} - -impl AsFd for Async { - fn as_fd(&self) -> BorrowedFd<'_> { - self.get_ref().as_fd() - } -} - -impl> TryFrom for Async { - type Error = io::Error; - - fn try_from(value: OwnedFd) -> Result { - Async::new(value.into()) - } -} - -impl> TryFrom> for OwnedFd { - type Error = io::Error; - - fn try_from(value: Async) -> Result { - value.into_inner().map(Into::into) - } -} - -impl Async { - /// Gets a reference to the inner I/O handle. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = listener.get_ref(); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn get_ref(&self) -> &T { - self.io.as_ref().unwrap() - } - - /// Gets a mutable reference to the inner I/O handle. - /// - /// # Safety - /// - /// The underlying I/O source must not be dropped using this function. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = unsafe { listener.get_mut() }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub unsafe fn get_mut(&mut self) -> &mut T { - self.io.as_mut().unwrap() - } - - /// Unwraps the inner I/O handle. - /// - /// This method will **not** put the I/O handle back into blocking mode. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = listener.into_inner()?; - /// - /// // Put the listener back into blocking mode. - /// inner.set_nonblocking(false)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn into_inner(mut self) -> io::Result { - REACTOR.deregister(self.as_fd().as_raw_fd())?; - Ok(self.io.take().unwrap()) - } - - /// Waits until the I/O handle is readable. - /// - /// This method completes when a read operation on this I/O handle wouldn't block. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Wait until a client can be accepted. - /// listener.readable().await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn readable(&self) -> io::Result<()> { - poll_fn(|cx| self.poll_readable(cx)).await - } - - /// Waits until the I/O handle is writable. - /// - /// This method completes when a write operation on this I/O handle wouldn't block. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// - /// // Wait until the stream is writable. - /// stream.writable().await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn writable(&self) -> io::Result<()> { - poll_fn(|cx| self.poll_writable(cx)).await - } - - /// Polls the I/O handle for readability. - /// - /// When this method returns [`Poll::Ready`], that means the OS has delivered an event - /// indicating readability since the last time this task has called the method and received - /// [`Poll::Pending`]. - /// - /// # Caveats - /// - /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks - /// will just keep waking each other in turn, thus wasting CPU time. - /// - /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use futures_lite::future; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Wait until a client can be accepted. - /// future::poll_fn(|cx| listener.poll_readable(cx)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { - if REACTOR.fetch_or_set(self.as_fd().as_raw_fd(), Event::Read, cx.waker())? { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - - /// Polls the I/O handle for writability. - /// - /// When this method returns [`Poll::Ready`], that means the OS has delivered an event - /// indicating writability since the last time this task has called the method and received - /// [`Poll::Pending`]. - /// - /// # Caveats - /// - /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks - /// will just keep waking each other in turn, thus wasting CPU time. - /// - /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use futures_lite::future; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// - /// // Wait until the stream is writable. - /// future::poll_fn(|cx| stream.poll_writable(cx)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { - if REACTOR.fetch_or_set(self.as_fd().as_raw_fd(), Event::Write, cx.waker())? { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - - /// Performs a read operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is readable. - /// - /// The closure receives a shared reference to the I/O handle. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Accept a new client asynchronously. - /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { - REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Read)?; - - let mut op = op; - loop { - match op(self.get_ref()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.readable()).await?; - } - } - - /// Performs a read operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is readable. - /// - /// The closure receives a mutable reference to the I/O handle. - /// - /// # Safety - /// - /// In the closure, the underlying I/O source must not be dropped. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Accept a new client asynchronously. - /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async unsafe fn read_with_mut( - &mut self, - op: impl FnMut(&mut T) -> io::Result, - ) -> io::Result { - REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Read)?; - - let mut op = op; - loop { - match op(self.get_mut()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.readable()).await?; - } - } - - /// Performs a write operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is writable. - /// - /// The closure receives a shared reference to the I/O handle. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = socket.write_with(|s| s.send(msg)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { - REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Write)?; - - let mut op = op; - loop { - match op(self.get_ref()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.writable()).await?; - } - } - - /// Performs a write operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is writable. - /// - /// # Safety - /// - /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying - /// I/O source must not be dropped. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let mut socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async unsafe fn write_with_mut( - &mut self, - op: impl FnMut(&mut T) -> io::Result, - ) -> io::Result { - REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Write)?; - - let mut op = op; - loop { - match op(self.get_mut()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.writable()).await?; - } - } -} - -impl AsRef for Async { - fn as_ref(&self) -> &T { - self.io.as_ref().unwrap() - } -} - -impl Drop for Async { - fn drop(&mut self) { - if let Some(io) = &self.io { - REACTOR.deregister(io.as_fd().as_raw_fd()).ok(); - } - } -} - -/// Types whose I/O trait implementations do not drop the underlying I/O source. -/// -/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can -/// happen if the inner resource (the [`TcpStream`], [`UnixListener`] or other `T`) is moved out -/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to -/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped -/// and a dangling handle won't be left behind. -/// -/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those -/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the -/// source out while the method is being run. -/// -/// This trait is an antidote to this predicament. By implementing this trait, the user pledges -/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the -/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`]. -/// -/// # Safety -/// -/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits -/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`]. -/// -/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented -/// for immutable reference types, as it is impossible to invalidate any outstanding references -/// while holding an immutable reference, even with interior mutability. As Rust's current pinning -/// system relies on similar guarantees, I believe that this approach is robust. -/// -/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html -/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html -/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html -/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html -/// -/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html -/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html -pub unsafe trait IoSafe {} - -/// Reference types can't be mutated. -/// -/// The worst thing that can happen is that external state is used to change what kind of pointer -/// `as_fd()` returns. For instance: -/// -/// ``` -/// # #[cfg(unix)] { -/// use std::cell::Cell; -/// use std::net::TcpStream; -/// use std::os::unix::io::{AsFd, BorrowedFd}; -/// -/// struct Bar { -/// flag: Cell, -/// a: TcpStream, -/// b: TcpStream -/// } -/// -/// impl AsFd for Bar { -/// fn as_fd(&self) -> BorrowedFd<'_> { -/// if self.flag.replace(!self.flag.get()) { -/// self.a.as_fd() -/// } else { -/// self.b.as_fd() -/// } -/// } -/// } -/// # } -/// ``` -/// -/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations -/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`. -unsafe impl IoSafe for &T {} - -// Can be implemented on top of libstd types. -unsafe impl IoSafe for std::fs::File {} -unsafe impl IoSafe for std::io::Stderr {} -unsafe impl IoSafe for std::io::Stdin {} -unsafe impl IoSafe for std::io::Stdout {} -unsafe impl IoSafe for std::io::StderrLock<'_> {} -unsafe impl IoSafe for std::io::StdinLock<'_> {} -unsafe impl IoSafe for std::io::StdoutLock<'_> {} -unsafe impl IoSafe for std::net::TcpStream {} -unsafe impl IoSafe for std::process::ChildStdin {} -unsafe impl IoSafe for std::process::ChildStdout {} -unsafe impl IoSafe for std::process::ChildStderr {} - -unsafe impl IoSafe for std::io::BufReader {} -unsafe impl IoSafe for std::io::BufWriter {} -unsafe impl IoSafe for std::io::LineWriter {} -unsafe impl IoSafe for &mut T {} -//unsafe impl IoSafe for alloc::boxed::Box {} -unsafe impl IoSafe for std::borrow::Cow<'_, T> {} - -#[cfg(feature = "futures-io")] -impl futures_io::AsyncRead for Async { - fn poll_read( - mut self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.read(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } - - fn poll_read_vectored( - mut self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.read_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } -} - -// Since this is through a reference, we can't mutate the inner I/O source. -// Therefore this is safe! -#[cfg(feature = "futures-io")] -impl futures_io::AsyncRead for &Async -where - for<'a> &'a T: Read, -{ - fn poll_read( - self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - match (*self).get_ref().read(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } - - fn poll_read_vectored( - self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - loop { - match (*self).get_ref().read_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } -} - -#[cfg(feature = "futures-io")] -impl futures_io::AsyncWrite for Async { - fn poll_write( - mut self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.write(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_write_vectored( - mut self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.write_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_flush( - mut self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.flush() { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } -} - -#[cfg(feature = "futures-io")] -impl futures_io::AsyncWrite for &Async -where - for<'a> &'a T: Write, -{ - fn poll_write( - self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match (*self).get_ref().write(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_write_vectored( - self: core::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - loop { - match (*self).get_ref().write_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_flush(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match (*self).get_ref().flush() { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } -} - -impl Async { - /// Creates a TCP listener bound to the specified address. - /// - /// Binding with port number 0 will request an available port from the OS. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// println!("Listening on {}", listener.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(addr: A) -> io::Result> { - let addr = addr.into(); - Async::new(TcpListener::bind(addr)?) - } - - /// Accepts a new incoming TCP connection. - /// - /// When a connection is established, it will be returned as a TCP stream together with its - /// remote address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; - /// let (stream, addr) = listener.accept().await?; - /// println!("Accepted client: {}", addr); - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { - let (stream, addr) = self.read_with(|io| io.accept()).await?; - Ok((Async::new(stream)?, addr)) - } - - /// Returns a stream of incoming TCP connections. - /// - /// The stream is infinite, i.e. it never stops with a [`None`]. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use futures_lite::{pin, stream::StreamExt}; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; - /// let incoming = listener.incoming(); - /// pin!(incoming); - /// - /// while let Some(stream) = incoming.next().await { - /// let stream = stream?; - /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); - /// } - /// # std::io::Result::Ok(()) }); - /// ``` - #[cfg(feature = "futures-lite")] - pub fn incoming( - &self, - ) -> impl futures_lite::Stream>> + Send + '_ { - futures_lite::stream::unfold(self, |listener| async move { - let res = listener.accept().await.map(|(stream, _)| stream); - Some((res, listener)) - }) - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(listener: std::net::TcpListener) -> io::Result { - Async::new(listener) - } -} - -impl Async { - /// Creates a TCP connection to the specified address. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn connect>(addr: A) -> io::Result> { - // Figure out how to handle this address. - let addr = addr.into(); - - let socket = match addr { - SocketAddr::V4(v4) => { - let addr = sys::sockaddr_in { - sin_family: sys::AF_INET as _, - sin_port: u16::to_be(v4.port()), - sin_addr: sys::in_addr { - s_addr: u32::from_ne_bytes(v4.ip().octets()), - }, - #[cfg(target_os = "espidf")] - sin_len: Default::default(), - sin_zero: Default::default(), - }; - - connect( - &addr as *const _ as *const _, - core::mem::size_of_val(&addr), - sys::AF_INET, - sys::SOCK_STREAM, - 0, - ) - } - SocketAddr::V6(v6) => { - let addr = sys::sockaddr_in6 { - sin6_family: sys::AF_INET6 as _, - sin6_port: u16::to_be(v6.port()), - sin6_flowinfo: 0, - sin6_addr: sys::in6_addr { - s6_addr: v6.ip().octets(), - }, - sin6_scope_id: 0, - #[cfg(target_os = "espidf")] - sin6_len: Default::default(), - }; - - connect( - &addr as *const _ as *const _, - core::mem::size_of_val(&addr), - sys::AF_INET6, - sys::SOCK_STREAM, - 6, - ) - } - }?; - - // Use new_nonblocking because connect already sets socket to non-blocking mode. - let stream = Async::new_nonblocking(TcpStream::from(socket))?; - - // The stream becomes writable when connected. - stream.writable().await?; - - // Check if there was an error while connecting. - match stream.get_ref().take_error()? { - None => Ok(stream), - Some(err) => Err(err), - } - } - - /// Reads data from the stream without removing it from the buffer. - /// - /// Returns the number of bytes read. Successive calls of this method read the same data. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt}; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let mut stream = Async::::connect(addr).await?; - /// - /// stream - /// .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") - /// .await?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = stream.peek(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.peek(buf)).await - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(stream: std::net::TcpStream) -> io::Result { - Async::new(stream) - } -} - -impl Async { - /// Creates a UDP socket bound to the specified address. - /// - /// Binding with port number 0 will request an available port from the OS. - /// - /// # Examples - /// - /// ``` - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; - /// println!("Bound to {}", socket.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(addr: A) -> io::Result> { - let addr = addr.into(); - Async::new(UdpSocket::bind(addr)?) - } - - /// Receives a single datagram message. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.recv_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.read_with(|io| io.recv_from(buf)).await - } - - /// Receives a single datagram message without removing it from the queue. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.peek_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.read_with(|io| io.peek_from(buf)).await - } - - /// Sends data to the specified address. - /// - /// Returns the number of bytes writen. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; - /// let addr = socket.get_ref().local_addr()?; - /// - /// let msg = b"hello"; - /// let len = socket.send_to(msg, addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { - let addr = addr.into(); - self.write_with(|io| io.send_to(buf, addr)).await - } - - /// Receives a single datagram message from the connected peer. - /// - /// Returns the number of bytes read. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.recv(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.recv(buf)).await - } - - /// Receives a single datagram message from the connected peer without removing it from the - /// queue. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.peek(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.peek(buf)).await - } - - /// Sends data to the connected peer. - /// - /// Returns the number of bytes written. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io_mini::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = socket.send(msg).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send(&self, buf: &[u8]) -> io::Result { - self.write_with(|io| io.send(buf)).await - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(socket: std::net::UdpSocket) -> io::Result { - Async::new(socket) - } -} - -/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready. -async fn optimistic(fut: impl Future>) -> io::Result<()> { - let mut polled = false; - let mut fut = pin!(fut); - - poll_fn(move |cx| { - if !polled { - polled = true; - fut.as_mut().poll(cx) - } else { - Poll::Ready(Ok(())) - } - }) - .await -} - -fn connect( - addr: *const sys::sockaddr, - addr_len: usize, - domain: sys::c_int, - ty: sys::c_int, - protocol: sys::c_int, -) -> io::Result { - // Create the socket. - let socket = unsafe { OwnedFd::from_raw_fd(syscall_los!(sys::socket(domain, ty, protocol))?) }; - - // Set non-blocking mode. - set_nonblocking(socket.as_fd())?; - - syscall_los_eagain!(unsafe { sys::connect(socket.as_raw_fd(), addr, addr_len as _) })?; - - Ok(socket) -} - -fn set_nonblocking(fd: BorrowedFd) -> io::Result<()> { - let previous = unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_GETFL) }; - let new = previous | sys::O_NONBLOCK; - if new != previous { - syscall!(unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_SETFL, new) })?; - } - - Ok(()) -} +#[cfg(feature = "embassy-time")] +mod timer; diff --git a/src/timer.rs b/src/timer.rs new file mode 100644 index 0000000..4372e8a --- /dev/null +++ b/src/timer.rs @@ -0,0 +1,458 @@ +use core::fmt::{self, Debug}; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll, Waker}; +use core::time::Duration; + +use std::time::Instant; + +/// A future or stream that emits timed events. +/// +/// Timers are futures that output a single [`Instant`] when they fire. +/// +/// Timers are also streams that can output [`Instant`]s periodically. +/// +/// # Precision +/// +/// There is a limit on the maximum precision that a `Timer` can provide. This limit is +/// dependent on the current platform and follows the precision provided by the `embassy-time` +/// crate for that platform; for instance, on Windows, the maximum precision is +/// about 16 milliseconds. Because of this limit, the timer may sleep for longer than the +/// requested duration. It will never sleep for less. +/// +/// On embedded platforms like ESP-IDF, the precision is much higer (up to 1 microsecond), +/// because the `embassy-time` crate for ESP-IDF uses the ESP-IDF Timer service. +/// +/// # Examples +/// +/// Sleep for 1 second: +/// +/// ``` +/// use async_io_mini::Timer; +/// use std::time::Duration; +/// +/// # futures_lite::future::block_on(async { +/// Timer::after(Duration::from_secs(1)).await; +/// # }); +/// ``` +/// +/// Timeout after 1 second: +/// +/// ``` +/// use async_io_mini::Timer; +/// use futures_lite::FutureExt; +/// use std::time::Duration; +/// +/// # futures_lite::future::block_on(async { +/// let wait = core::future::pending::>() +/// .or(async { +/// Timer::after(Duration::from_secs(1)).await; +/// Err(std::io::ErrorKind::TimedOut.into()) +/// }) +/// .await?; +/// # std::io::Result::Ok(()) }); +/// ``` +pub struct Timer { + when: Option, + period: Duration, + waker: Option, +} + +impl Timer { + /// Creates a timer that will never fire. + /// + /// # Examples + /// + /// This function may also be useful for creating a function with an optional timeout. + /// + /// ``` + /// # futures_lite::future::block_on(async { + /// use async_io_mini::Timer; + /// use futures_lite::prelude::*; + /// use std::time::Duration; + /// + /// async fn run_with_timeout(timeout: Option) { + /// let timer = timeout + /// .map(|timeout| Timer::after(timeout)) + /// .unwrap_or_else(Timer::never); + /// + /// run_lengthy_operation().or(timer).await; + /// } + /// # // Note that since a Timer as a Future returns an Instant, + /// # // this function needs to return an Instant to be used + /// # // in "or". + /// # async fn run_lengthy_operation() -> std::time::Instant { + /// # std::time::Instant::now() + /// # } + /// + /// // Times out after 5 seconds. + /// run_with_timeout(Some(Duration::from_secs(5))).await; + /// // Does not time out. + /// run_with_timeout(None).await; + /// # }); + /// ``` + pub fn never() -> Timer { + let _fix_linking = embassy_time::Timer::after(embassy_time::Duration::from_secs(1)); + + Timer { + when: None, + period: Duration::MAX, + waker: None, + } + } + + /// Creates a timer that emits an event once after the given duration of time. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use std::time::Duration; + /// + /// # futures_lite::future::block_on(async { + /// Timer::after(Duration::from_secs(1)).await; + /// # }); + /// ``` + pub fn after(duration: Duration) -> Timer { + let Some(start) = Instant::now().checked_add(duration) else { + return Timer::never(); + }; + + Timer::interval_at(start, Duration::MAX) + } + + /// Creates a timer that emits an event once at the given time instant. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use std::time::{Duration, Instant}; + /// + /// # futures_lite::future::block_on(async { + /// let now = Instant::now(); + /// let when = now + Duration::from_secs(1); + /// Timer::at(when).await; + /// # }); + /// ``` + pub fn at(instant: Instant) -> Timer { + Timer::interval_at(instant, Duration::MAX) + } + + /// Creates a timer that emits events periodically. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use futures_lite::StreamExt; + /// use std::time::{Duration, Instant}; + /// + /// # futures_lite::future::block_on(async { + /// let period = Duration::from_secs(1); + /// Timer::interval(period).next().await; + /// # }); + /// ``` + pub fn interval(period: Duration) -> Timer { + let Some(start) = Instant::now().checked_add(period) else { + return Timer::never(); + }; + + Timer::interval_at(start, period) + } + + /// Creates a timer that emits events periodically, starting at `start`. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use futures_lite::StreamExt; + /// use std::time::{Duration, Instant}; + /// + /// # futures_lite::future::block_on(async { + /// let start = Instant::now(); + /// let period = Duration::from_secs(1); + /// Timer::interval_at(start, period).next().await; + /// # }); + /// ``` + pub fn interval_at(start: Instant, period: Duration) -> Timer { + if Self::ticks(&start).is_some() { + Timer { + when: Some(start), + period, + waker: None, + } + } else { + Timer::never() + } + } + + /// Indicates whether or not this timer will ever fire. + /// + /// [`never()`] will never fire, and timers created with [`after()`] or [`at()`] will fire + /// if the duration is not too large. + /// + /// [`never()`]: Timer::never() + /// [`after()`]: Timer::after() + /// [`at()`]: Timer::at() + /// + /// # Examples + /// + /// ``` + /// # futures_lite::future::block_on(async { + /// use async_io_mini::Timer; + /// use futures_lite::prelude::*; + /// use std::time::Duration; + /// + /// // `never` will never fire. + /// assert!(!Timer::never().will_fire()); + /// + /// // `after` will fire if the duration is not too large. + /// assert!(Timer::after(Duration::from_secs(1)).will_fire()); + /// assert!(!Timer::after(Duration::MAX).will_fire()); + /// + /// // However, once an `after` timer has fired, it will never fire again. + /// let mut t = Timer::after(Duration::from_secs(1)); + /// assert!(t.will_fire()); + /// (&mut t).await; + /// assert!(!t.will_fire()); + /// + /// // Interval timers will fire periodically. + /// let mut t = Timer::interval(Duration::from_secs(1)); + /// assert!(t.will_fire()); + /// t.next().await; + /// assert!(t.will_fire()); + /// # }); + /// ``` + #[inline] + pub fn will_fire(&self) -> bool { + self.when.is_some() + } + + /// Sets the timer to emit an en event once after the given duration of time. + /// + /// Note that resetting a timer is different from creating a new timer because + /// [`set_after()`][`Timer::set_after()`] does not remove the waker associated with the task + /// that is polling the timer. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use std::time::Duration; + /// + /// # futures_lite::future::block_on(async { + /// let mut t = Timer::after(Duration::from_secs(1)); + /// t.set_after(Duration::from_millis(100)); + /// # }); + /// ``` + pub fn set_after(&mut self, duration: Duration) { + match Instant::now().checked_add(duration) { + Some(instant) => self.set_at(instant), + // Overflow to never going off. + None => self.set_never(), + } + } + + /// Sets the timer to emit an event once at the given time instant. + /// + /// Note that resetting a timer is different from creating a new timer because + /// [`set_at()`][`Timer::set_at()`] does not remove the waker associated with the task + /// that is polling the timer. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use std::time::{Duration, Instant}; + /// + /// # futures_lite::future::block_on(async { + /// let mut t = Timer::after(Duration::from_secs(1)); + /// + /// let now = Instant::now(); + /// let when = now + Duration::from_secs(1); + /// t.set_at(when); + /// # }); + /// ``` + pub fn set_at(&mut self, instant: Instant) { + let ticks = Self::ticks(&instant); + + if let Some(ticks) = ticks { + self.when = Some(instant); + self.period = Duration::MAX; + + if let Some(waker) = self.waker.as_ref() { + embassy_time_queue_driver::schedule_wake(ticks, waker); + } + } else { + self.set_never(); + } + } + + /// Sets the timer to emit events periodically. + /// + /// Note that resetting a timer is different from creating a new timer because + /// [`set_interval()`][`Timer::set_interval()`] does not remove the waker associated with the + /// task that is polling the timer. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use futures_lite::StreamExt; + /// use std::time::{Duration, Instant}; + /// + /// # futures_lite::future::block_on(async { + /// let mut t = Timer::after(Duration::from_secs(1)); + /// + /// let period = Duration::from_secs(2); + /// t.set_interval(period); + /// # }); + /// ``` + pub fn set_interval(&mut self, period: Duration) { + match Instant::now().checked_add(period) { + Some(instant) => self.set_interval_at(instant, period), + // Overflow to never going off. + None => self.set_never(), + } + } + + /// Sets the timer to emit events periodically, starting at `start`. + /// + /// Note that resetting a timer is different from creating a new timer because + /// [`set_interval_at()`][`Timer::set_interval_at()`] does not remove the waker associated with + /// the task that is polling the timer. + /// + /// # Examples + /// + /// ``` + /// use async_io_mini::Timer; + /// use futures_lite::StreamExt; + /// use std::time::{Duration, Instant}; + /// + /// # futures_lite::future::block_on(async { + /// let mut t = Timer::after(Duration::from_secs(1)); + /// + /// let start = Instant::now(); + /// let period = Duration::from_secs(2); + /// t.set_interval_at(start, period); + /// # }); + /// ``` + pub fn set_interval_at(&mut self, start: Instant, period: Duration) { + let ticks = Self::ticks(&start); + + if let Some(ticks) = ticks { + self.when = Some(start); + self.period = period; + + if let Some(waker) = self.waker.as_ref() { + embassy_time_queue_driver::schedule_wake(ticks, waker); + } + } else { + // Overflow to never going off. + self.set_never(); + } + } + + fn set_never(&mut self) { + self.when = None; + self.waker = None; + self.period = Duration::MAX; + } + + fn fired_at(&mut self, cx: &mut Context<'_>) -> Option { + let when = self.when?; + + if when > Instant::now() { + let ticks = Self::ticks(&when); + + if let Some(ticks) = ticks { + if self + .waker + .as_ref() + .map(|waker| !waker.will_wake(cx.waker())) + .unwrap_or(true) + { + self.waker = Some(cx.waker().clone()); + embassy_time_queue_driver::schedule_wake(ticks, cx.waker()); + } + } else { + self.set_never(); + } + + None + } else { + Some(when) + } + } + + fn ticks(instant: &Instant) -> Option { + fn duration_ticks(duration: &Duration) -> Option { + let ticks = duration.as_secs() as u128 * embassy_time_driver::TICK_HZ as u128 + + duration.subsec_nanos() as u128 * embassy_time_driver::TICK_HZ as u128 + / 1_000_000_000; + + u64::try_from(ticks).ok() + } + + let now = Instant::now(); + let now_ticks = embassy_time_driver::now(); + + if *instant >= now { + let dur_ticks = duration_ticks(&instant.duration_since(now)); + + dur_ticks.and_then(|dur_ticks| now_ticks.checked_add(dur_ticks)) + } else { + let dur_ticks = duration_ticks(&now.duration_since(*instant)); + + dur_ticks.map(|dur_ticks| now_ticks.saturating_sub(dur_ticks)) + } + } +} + +impl Debug for Timer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Timer") + .field("start", &self.when.as_ref()) + .field("period", &self.period) + .finish() + } +} + +impl Future for Timer { + type Output = Instant; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let Some(when) = self.fired_at(cx) else { + return Poll::Pending; + }; + + self.set_never(); + + Poll::Ready(when) + } +} + +#[cfg(feature = "futures-lite")] +impl futures_lite::Stream for Timer { + type Item = Instant; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Some(when) = self.fired_at(cx) else { + return Poll::Pending; + }; + + let next_when = when.checked_add(self.period); + + if let Some(next_when) = next_when { + let period = self.period; + + self.set_interval_at(next_when, period); + } else { + self.set_never(); + } + + Poll::Ready(Some(when)) + } +} diff --git a/tests/timer.rs b/tests/timer.rs new file mode 100644 index 0000000..baf6d00 --- /dev/null +++ b/tests/timer.rs @@ -0,0 +1,99 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::{Duration, Instant}; + +use async_io_mini::Timer; +use futures_lite::{future, FutureExt, StreamExt}; + +fn spawn( + f: impl Future + Send + 'static, +) -> impl Future + Send + 'static { + let (s, r) = async_channel::bounded(1); + + thread::spawn(move || { + future::block_on(async { + s.send(f.await).await.ok(); + }) + }); + + Box::pin(async move { r.recv().await.unwrap() }) +} + +#[test] +fn smoke() { + future::block_on(async { + let start = Instant::now(); + Timer::after(Duration::from_secs(1)).await; + assert!(start.elapsed() >= Duration::from_secs(1)); + }); +} + +#[test] +fn interval() { + future::block_on(async { + let period = Duration::from_secs(1); + let jitter = Duration::from_millis(500); + let start = Instant::now(); + let mut timer = Timer::interval(period); + timer.next().await; + let elapsed = start.elapsed(); + assert!(elapsed >= period && elapsed - period < jitter); + timer.next().await; + let elapsed = start.elapsed(); + assert!(elapsed >= period * 2 && elapsed - period * 2 < jitter); + }); +} + +#[test] +fn poll_across_tasks() { + future::block_on(async { + let start = Instant::now(); + let (sender, receiver) = async_channel::bounded(1); + + let task1 = spawn(async move { + let mut timer = Timer::after(Duration::from_secs(1)); + + async { + (&mut timer).await; + panic!("timer should not be ready") + } + .or(async {}) + .await; + + sender.send(timer).await.ok(); + }); + + let task2 = spawn(async move { + let timer = receiver.recv().await.unwrap(); + timer.await; + }); + + task1.await; + task2.await; + + assert!(start.elapsed() >= Duration::from_secs(1)); + }); +} + +#[test] +fn set() { + future::block_on(async { + let start = Instant::now(); + let timer = Arc::new(Mutex::new(Timer::after(Duration::from_secs(10)))); + + thread::spawn({ + let timer = timer.clone(); + move || { + thread::sleep(Duration::from_secs(1)); + timer.lock().unwrap().set_after(Duration::from_secs(2)); + } + }); + + future::poll_fn(|cx| Pin::new(&mut *timer.lock().unwrap()).poll(cx)).await; + + assert!(start.elapsed() >= Duration::from_secs(2)); + assert!(start.elapsed() < Duration::from_secs(10)); + }); +}