diff --git a/crates/rune/src/runtime/future.rs b/crates/rune/src/runtime/future.rs index 05a3d7d92..4cf9fd2c3 100644 --- a/crates/rune/src/runtime/future.rs +++ b/crates/rune/src/runtime/future.rs @@ -1,16 +1,21 @@ use core::fmt; use core::future; use core::pin::Pin; +use core::ptr::NonNull; use core::task::{Context, Poll}; +use crate::alloc::alloc::Global; use crate::alloc::{self, Box}; use crate::runtime::{ToValue, Value, VmErrorKind, VmResult}; use crate::Any; use pin_project::pin_project; -/// dyn future alias. -type DynFuture = dyn future::Future> + 'static; +/// A virtual table for a type-erased future. +struct Vtable { + poll: unsafe fn(*mut (), cx: &mut Context<'_>) -> Poll>, + drop: unsafe fn(*mut ()), +} /// A type-erased future that can only be unsafely polled in combination with /// the virtual machine that created it. @@ -19,7 +24,8 @@ type DynFuture = dyn future::Future> + 'static; #[rune(builtin, static_type = FUTURE)] #[rune(from_value = Value::into_future)] pub struct Future { - future: Option>>, + future: Option>, + vtable: &'static Vtable, } impl Future { @@ -29,27 +35,25 @@ impl Future { T: 'static + future::Future>, O: ToValue, { - // First construct a normal box, then coerce unsized. - let b = Box::try_new(async move { - let value = vm_try!(future.await); - value.to_value() - })?; - - // SAFETY: We know that the allocator the boxed used is `Global`, which - // is compatible with the allocator used by the `std` box. - unsafe { - let (ptr, alloc) = Box::into_raw_with_allocator(b); - // Our janky coerce unsized. - let b: ::rust_alloc::boxed::Box = ::rust_alloc::boxed::Box::from_raw(ptr); - let b = ::rust_alloc::boxed::Box::into_raw(b); - let b = Box::from_raw_in(b, alloc); - - // Second convert into one of our boxes, which ensures that memory is - // being accounted for. - Ok(Self { - future: Some(Box::into_pin(b)), - }) - } + let (future, Global) = Box::into_raw_with_allocator(Box::try_new(future)?); + + let future = unsafe { NonNull::new_unchecked(future).cast() }; + + Ok(Self { + future: Some(future), + vtable: &Vtable { + poll: |future, cx| unsafe { + match Pin::new_unchecked(&mut *future.cast::()).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(VmResult::Ok(result)) => Poll::Ready(result.to_value()), + Poll::Ready(VmResult::Err(err)) => Poll::Ready(VmResult::Err(err)), + } + }, + drop: |future| unsafe { + _ = Box::from_raw_in(future.cast::(), Global); + }, + }, + }) } /// Check if future is completed. @@ -64,21 +68,31 @@ impl future::Future for Future { type Output = VmResult; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); + unsafe { + let this = self.get_unchecked_mut(); - let future = match &mut this.future { - Some(future) => future, - None => { + let Some(future) = this.future else { return Poll::Ready(VmResult::err(VmErrorKind::FutureCompleted)); + }; + + match (this.vtable.poll)(future.as_ptr(), cx) { + Poll::Ready(result) => { + this.future = None; + (this.vtable.drop)(future.as_ptr()); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, } - }; + } + } +} - match future.as_mut().poll(cx) { - Poll::Ready(result) => { - this.future = None; - Poll::Ready(result) +impl Drop for Future { + fn drop(&mut self) { + unsafe { + if let Some(future) = self.future.take() { + (self.vtable.drop)(future.as_ptr()); } - Poll::Pending => Poll::Pending, } } } diff --git a/crates/rune/src/runtime/tests.rs b/crates/rune/src/runtime/tests.rs index 9910f7493..62d495bbe 100644 --- a/crates/rune/src/runtime/tests.rs +++ b/crates/rune/src/runtime/tests.rs @@ -1,9 +1,25 @@ +use core::future::Future as _; +use core::pin::pin; +use core::task::{Context, Poll}; + +use std::sync::Arc; +use std::task::Wake; + use crate as rune; -use crate::runtime::{AnyObj, Shared, Value}; + +use crate::runtime::{AnyObj, Shared, Value, VmResult}; use crate::Any; use crate::support::Result; +struct NoopWaker; + +impl Wake for NoopWaker { + fn wake(self: Arc) { + // nothing + } +} + #[derive(Any, Debug, PartialEq, Eq)] struct Foo(isize); @@ -224,3 +240,46 @@ fn shared_is_writable() -> crate::support::Result<()> { assert!(shared.is_writable()); Ok(()) } + +#[test] +fn ensure_future_dropped_poll() -> crate::support::Result<()> { + use crate::runtime::Future; + + let mut future = pin!(Future::new(async { VmResult::Ok(10) })?); + + let waker = Arc::new(NoopWaker).into(); + let mut cx = Context::from_waker(&waker); + + assert!(!future.is_completed()); + + // NB: By polling the future to completion we are causing it to be dropped when polling is completed. + let Poll::Ready(ok) = future.as_mut().poll(&mut cx) else { + panic!("expected ready"); + }; + + assert_eq!(ok.unwrap().as_integer().unwrap(), 10); + assert!(future.is_completed()); + Ok(()) +} + +#[test] +fn ensure_future_dropped_explicitly() -> crate::support::Result<()> { + use crate::runtime::Future; + + let mut future = pin!(Future::new(async { VmResult::Ok(10) })?); + // NB: We cause the future to be dropped explicitly through it's Drop destructor here by replacing it. + future.set(Future::new(async { VmResult::Ok(0) })?); + + let waker = Arc::new(NoopWaker).into(); + let mut cx = Context::from_waker(&waker); + + assert!(!future.is_completed()); + + let Poll::Ready(ok) = future.as_mut().poll(&mut cx) else { + panic!("expected ready"); + }; + + assert_eq!(ok.unwrap().as_integer().unwrap(), 0); + assert!(future.is_completed()); + Ok(()) +}