From cba13d41ea58042cdfeae9cc12c9efe0a9d2047e Mon Sep 17 00:00:00 2001 From: Timo Glane Date: Mon, 13 May 2024 18:33:50 +0200 Subject: [PATCH] Drop the join waker of a task eagerly when the task completes and there is no join interest --- tokio/src/runtime/task/harness.rs | 86 +++++++++++++++---- tokio/src/runtime/task/state.rs | 83 +++++++++++++++--- .../src/runtime/tests/loom_current_thread.rs | 18 ++++ tokio/src/runtime/tests/loom_multi_thread.rs | 36 ++++++++ tokio/src/sync/tests/loom_atomic_waker.rs | 2 +- 5 files changed, 197 insertions(+), 28 deletions(-) diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 8479becd80a..38cefa90a9d 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -283,21 +283,33 @@ where } pub(super) fn drop_join_handle_slow(self) { + use super::state::TransitionToJoinHandleDrop; // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. - if self.state().unset_join_interested().is_err() { - // It is our responsibility to drop the output. This is critical as - // the task output may not be `Send` and as such must remain with - // the scheduler or `JoinHandle`. i.e. if the output remains in the - // task structure until the task is deallocated, it may be dropped - // by a Waker on any arbitrary thread. - // - // Panics are delivered to the user via the `JoinHandle`. Given that - // they are dropping the `JoinHandle`, we assume they are not - // interested in the panic and swallow it. - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { - self.core().drop_future_or_output(); - })); + // + // TODO Create new bit/flag in state -> Set WantToDropJoinWaker in transition when failing + let transition = self.state().transition_to_join_handle_drop(); + match transition { + TransitionToJoinHandleDrop::Failed => { + // It is our responsibility to drop the output. 
This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().drop_future_or_output(); + })); + } + TransitionToJoinHandleDrop::OkDropJoinWaker => unsafe { + // If there is a waker associated with this task when the `JoinHandle` is about to get + // dropped we want to also drop this waker if the task is already completed. + self.trailer().set_waker(None); + }, + TransitionToJoinHandleDrop::OkDoNothing => (), } // Drop the `JoinHandle` reference, possibly deallocating the task @@ -308,6 +320,7 @@ where /// Completes the task. This method assumes that the state is RUNNING. fn complete(self) { + use super::state::TransitionToTerminal; // The future has completed and its output has been written to the task // stage. We transition from running to complete. @@ -332,8 +345,28 @@ where // The task has completed execution and will no longer be scheduled. let num_release = self.release(); - if self.state().transition_to_terminal(num_release) { - self.dealloc(); + match self.state().transition_to_terminal(num_release) { + TransitionToTerminal::OkDoNothing => (), + TransitionToTerminal::OkDealloc => { + self.dealloc(); + } + TransitionToTerminal::FailedDropJoinWaker => { + // Safety: In this case we are the only one referencing the task and the active + // waker is the only one preventing the task from being deallocated so no one else + // will try to access the waker here. 
+ unsafe { + self.trailer().set_waker(None); + } + + match self.state().transition_to_terminal(num_release) { + TransitionToTerminal::OkDealloc => self.dealloc(), + // We do not expect this to happen since `TransitionToTerminal::FailedDropJoinWaker` + // will only be returned when, after dropping the join waker, the task can be safely + // deallocated. Because the COMPLETE bit is still set after this failed transition, + // it is fine to transition to terminal in two steps here. + _ => (), + } + } } } @@ -373,7 +406,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { debug_assert!(snapshot.is_join_interested()); - if !snapshot.is_complete() { + if !snapshot.is_complete() && !snapshot.is_terminal() { // If the task is not complete, try storing the provided waker in the // task's waker field. @@ -439,6 +472,27 @@ fn set_join_waker( res } +fn unset_join_waker( + header: &Header, + trailer: &Trailer, + snapshot: Snapshot, +) -> Result<Snapshot, Snapshot> { + assert!(snapshot.is_join_interested()); + assert!(snapshot.is_join_waker_set()); + + // Make sure the `JoinWaker` bit is unset before accessing the `waker` directly. + let res = header.state.unset_waker(); + if res.is_ok() { + // Safety: Only the `JoinHandle` may set the `waker` field. When + // `JOIN_INTEREST` is **not** set, nothing else will touch the field. + unsafe { + trailer.set_waker(None); + } + } + + res +} + enum PollFuture { Complete, Notified, diff --git a/tokio/src/runtime/task/state.rs b/tokio/src/runtime/task/state.rs index 0fc7bb0329b..bf15892ca6e 100644 --- a/tokio/src/runtime/task/state.rs +++ b/tokio/src/runtime/task/state.rs @@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000; /// The task has been forcibly cancelled. const CANCELLED: usize = 0b100_000; +const TERMINAL: usize = 0b1_000_000; + /// All bits. 
-const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; +// const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; +const STATE_MASK: usize = + LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL; /// Bits used by the ref count portion of the state. const REF_COUNT_MASK: usize = !STATE_MASK; @@ -89,6 +93,20 @@ pub(crate) enum TransitionToNotifiedByRef { Submit, } +#[must_use] +pub(crate) enum TransitionToJoinHandleDrop { + Failed, + OkDoNothing, + OkDropJoinWaker, +} + +#[must_use] +pub(crate) enum TransitionToTerminal { + OkDoNothing, + OkDealloc, + FailedDropJoinWaker, +} + /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. impl State { @@ -174,6 +192,23 @@ impl State { }) } + pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_join_interested() { + snapshot.unset_join_interested() + } + + if snapshot.is_complete() && !snapshot.is_terminal() { + (TransitionToJoinHandleDrop::Failed, None) + } else if snapshot.is_join_waker_set() { + snapshot.unset_join_waker(); + (TransitionToJoinHandleDrop::OkDropJoinWaker, Some(snapshot)) + } else { + (TransitionToJoinHandleDrop::OkDoNothing, Some(snapshot)) + } + }) + } + /// Transitions the task from `Running` -> `Complete`. pub(super) fn transition_to_complete(&self) -> Snapshot { const DELTA: usize = RUNNING | COMPLETE; @@ -181,6 +216,7 @@ impl State { let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel)); assert!(prev.is_running()); assert!(!prev.is_complete()); + assert!(!prev.is_terminal()); Snapshot(prev.0 ^ DELTA) } @@ -188,16 +224,37 @@ impl State { /// Transitions from `Complete` -> `Terminal`, decrementing the reference /// count the specified number of times. /// - /// Returns true if the task should be deallocated. 
- pub(super) fn transition_to_terminal(&self, count: usize) -> bool { - let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); - assert!( - prev.ref_count() >= count, - "current: {}, sub: {}", - prev.ref_count(), - count - ); - prev.ref_count() == count + /// Returns `TransitionToTerminal::OkDoNothing` if the transition was successful but the task + /// cannot be deallocated yet. + /// Returns `TransitionToTerminal::OkDealloc` if the task should be deallocated. + /// Returns `TransitionToTerminal::FailedDropJoinWaker` if the transition failed because the + /// join waker is still set although there is no join interest. In this case the reference + /// count is not decremented but the `JOIN_WAKER` bit is unset. + pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal { + self.fetch_update_action(|mut snapshot| { + assert!(!snapshot.is_running()); + assert!(snapshot.is_complete()); + assert!(!snapshot.is_terminal()); + assert!( + snapshot.ref_count() >= count, + "current: {}, sub: {}", + snapshot.ref_count(), + count + ); + + if snapshot.ref_count() == count { + snapshot.0 -= count * REF_ONE; + snapshot.0 |= TERMINAL; + (TransitionToTerminal::OkDealloc, Some(snapshot)) + } else if !snapshot.is_join_interested() && snapshot.is_join_waker_set() { + snapshot.unset_join_waker(); + (TransitionToTerminal::FailedDropJoinWaker, Some(snapshot)) + } else { + snapshot.0 -= count * REF_ONE; + snapshot.0 |= TERMINAL; + (TransitionToTerminal::OkDoNothing, Some(snapshot)) + } + }) } /// Transitions the state to `NOTIFIED`. 
@@ -557,6 +614,10 @@ impl Snapshot { self.0 & COMPLETE == COMPLETE } + pub(super) fn is_terminal(self) -> bool { + self.0 & TERMINAL == TERMINAL + } + pub(super) fn is_join_interested(self) -> bool { self.0 & JOIN_INTEREST == JOIN_INTEREST } diff --git a/tokio/src/runtime/tests/loom_current_thread.rs b/tokio/src/runtime/tests/loom_current_thread.rs index edda6e49954..be30fc4a042 100644 --- a/tokio/src/runtime/tests/loom_current_thread.rs +++ b/tokio/src/runtime/tests/loom_current_thread.rs @@ -106,6 +106,24 @@ fn assert_no_unnecessary_polls() { }); } +// #[test] +// fn new_test() { +// loom::model(|| { +// let rt = Builder::new_current_thread().build().unwrap(); +// +// let jh = rt.spawn(async {}); +// +// let bg = std::thread::spawn(move || { +// jh.poll(); +// }); +// +// rt.block_on(async {}); +// +// rt.shutdown(); +// bg.join(); +// }); +// } + struct BlockedFuture { rx: Receiver<()>, num_polls: Arc, diff --git a/tokio/src/runtime/tests/loom_multi_thread.rs b/tokio/src/runtime/tests/loom_multi_thread.rs index c5980c226e0..4a72f24ea96 100644 --- a/tokio/src/runtime/tests/loom_multi_thread.rs +++ b/tokio/src/runtime/tests/loom_multi_thread.rs @@ -460,3 +460,39 @@ impl Future for Track { }) } } + +#[test] +fn timo_test() { + use crate::sync::mpsc::channel; + + loom::model(|| { + let pool = mk_pool(2); + + pool.block_on(async move { + let (tx, mut rx) = channel(1); + + let (a_closer, mut wait_for_close_a) = channel::<()>(1); + let (b_closer, mut wait_for_close_b) = channel::<()>(1); + + let a = spawn(async move { + let b = rx.recv().await.unwrap(); + + futures::future::select( + std::pin::pin!(b), + // std::pin::pin!(futures::future::ready(())), + std::pin::pin!(a_closer.send(())), + ) + .await; + }); + + let b = spawn(async move { + let _ = a.await; + let _ = b_closer.send(()).await; + }); + + tx.send(b).await.unwrap(); + + futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await; + }); + }); +} diff --git 
a/tokio/src/sync/tests/loom_atomic_waker.rs b/tokio/src/sync/tests/loom_atomic_waker.rs index f8bae65d130..5461514e032 100644 --- a/tokio/src/sync/tests/loom_atomic_waker.rs +++ b/tokio/src/sync/tests/loom_atomic_waker.rs @@ -60,7 +60,7 @@ fn test_panicky_waker() { // which would otherwise log. // // We can't however leaved it uncommented, because it's global. - // panic::set_hook(Box::new(|_| ())); + panic::set_hook(Box::new(|_| ())); const NUM_NOTIFY: usize = 2;