From 342d6edf50da5a1a2df4d6ed75f1f91644b70a5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aurimas=20Bla=C5=BEulionis?= <0x60@pm.me> Date: Tue, 5 Dec 2023 16:09:47 +0200 Subject: [PATCH] Add futures trait compatibility --- mfio/src/futures_compat.rs | 140 +++++++++++++++++++++++++++++++++++++ mfio/src/lib.rs | 4 ++ 2 files changed, 144 insertions(+) create mode 100644 mfio/src/futures_compat.rs diff --git a/mfio/src/futures_compat.rs b/mfio/src/futures_compat.rs new file mode 100644 index 0000000..402f5d5 --- /dev/null +++ b/mfio/src/futures_compat.rs @@ -0,0 +1,140 @@ +use crate::io::*; +use crate::stdeq::{self, AsyncIoFut}; +use crate::util::PosShift; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures::io::{AsyncRead, AsyncSeek, AsyncWrite}; +use std::io::{Result, SeekFrom}; + +pub struct Compat<'a, Io: ?Sized> { + io: &'a Io, + read: Option>, + write: Option>, +} + +/// Bridges mfio with futures. +/// +/// # Examples +/// +/// Read from mfio object through futures traits. +/// +/// ```rust +/// # mod sample { +/// # include!("sample.rs"); +/// # } +/// # use sample::SampleIo; +/// # fn work() -> mfio::error::Result<()> { +/// use futures::io::{AsyncReadExt, Cursor}; +/// use mfio::backend::*; +/// use mfio::futures_compat::FuturesCompat; +/// use mfio::stdeq::SeekableRef; +/// +/// let mem = vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144]; +/// let handle = SampleIo::new(mem.clone()); +/// +/// handle.block_on(async { +/// let mut buf = Cursor::new(vec![0; mem.len()]); +/// +/// let handle = SeekableRef::from(&handle); +/// futures::io::copy(&mut handle.compat(), &mut buf).await?; +/// assert_eq!(mem, buf.into_inner()); +/// +/// Ok(()) +/// }) +/// # } +/// # work().unwrap(); +/// ``` +pub trait FuturesCompat { + fn compat(&self) -> Compat { + Compat { + io: self, + read: None, + write: None, + } + } +} + +// StreamPos is needed for all I/O traits, so we use it to make sure rust gives better diagnostics. +impl<'a, Io: ?Sized + stdeq::StreamPos> FuturesCompat for Io {} + +impl<'a, Io: ?Sized + stdeq::AsyncRead> AsyncRead for Compat<'a, Io> +where + u64: PosShift, +{ + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + + loop { + if let Some(read) = this.read.as_mut() { + // Update the sync handle. This is how we hack around the lifetimes of input buffer. + #[cfg(not(mfio_assume_linear_types))] + { + assert_eq!(read.sync.as_ref().map(|v| v.len()), Some(buf.len())); + // SAFETY: AsyncIoFut will only use the sync object if, and only if the buffer is + // to be written in this poll. + read.sync = Some(unsafe { &mut *(buf as *mut _) }); + } + + let read = unsafe { Pin::new_unchecked(read) }; + + break read.poll(cx).map(|v| { + this.read = None; + v.map_err(|_| std::io::ErrorKind::Other.into()) + }); + } else { + // SAFETY: on mfio_assume_linear_types, this is unsafe. Without the switch this is + // safe, because the buffer is stored in a sync variable that is only used whenever + // the I/O completes. That is processed in this poll function, and we update the + // sync at every iteration of the loop. + let buf = unsafe { &mut *(buf as *mut _) }; + this.read = Some(stdeq::AsyncRead::read(this.io, buf)); + } + } + } +} + +impl<'a, Io: ?Sized + stdeq::AsyncWrite> AsyncWrite for Compat<'a, Io> +where + u64: PosShift, +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + + loop { + if let Some(write) = this.write.as_mut() { + let write = unsafe { Pin::new_unchecked(write) }; + + break write.poll(cx).map(|v| { + this.write = None; + v.map_err(|_| std::io::ErrorKind::Other.into()) + }); + } else { + // SAFETY: on mfio_assume_linear_types, this is unsafe. Without the switch this is + // safe, because the buffer is transferred to an intermediate one before this + // function returns.. + let buf = unsafe { &*(buf as *const _) }; + this.write = Some(stdeq::AsyncWrite::write(this.io, buf)); + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + // Completion of every request currently implies we've flushed. + // TODO: improve semantics + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + // We currently imply that we can just close on drop. + // TODO: improve semantics + Poll::Ready(Ok(())) + } +} + +impl<'a, Io: ?Sized + stdeq::StreamPos> AsyncSeek for Compat<'a, Io> { + fn poll_seek(self: Pin<&mut Self>, _: &mut Context<'_>, pos: SeekFrom) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + Poll::Ready(stdeq::std_seek(this.io, pos)) + } +} diff --git a/mfio/src/lib.rs b/mfio/src/lib.rs index 38f7ce4..efae8dc 100644 --- a/mfio/src/lib.rs +++ b/mfio/src/lib.rs @@ -181,6 +181,8 @@ pub(crate) mod std_prelude { pub mod backend; pub mod error; +#[cfg(feature = "std")] +pub mod futures_compat; pub mod io; pub mod stdeq; pub mod traits; @@ -193,6 +195,8 @@ pub mod prelude { pub use crate::backend::integrations::tokio::Tokio; pub use crate::backend::{Integration, IoBackend, IoBackendExt, Null}; pub use crate::error::*; + #[cfg(feature = "std")] + pub use crate::futures_compat::FuturesCompat; pub use crate::io::{ FullPacket, IntoPacket, OwnedPacket, Packet, PacketIo, PacketIoExt, PacketView, Read, RefPacket, VecPacket, Write,