From 1fac7a3ace7fe564b2695c9aeecd844dabb9b8b9 Mon Sep 17 00:00:00 2001 From: Timo Glane Date: Mon, 13 May 2024 18:33:50 +0200 Subject: [PATCH] Drop the join waker of a task eagerly when the task completes and there is no join interest --- tokio/src/runtime/task/harness.rs | 66 ++++++++++++---- tokio/src/runtime/task/state.rs | 83 +++++++++++++++++--- tokio/src/runtime/tests/loom_multi_thread.rs | 33 +++++++- 3 files changed, 154 insertions(+), 28 deletions(-) diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 996f0f2d9b4..5ac3aabb1dd 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -284,21 +284,33 @@ where } pub(super) fn drop_join_handle_slow(self) { + use super::state::TransitionToJoinHandleDrop; // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. - if self.state().unset_join_interested().is_err() { - // It is our responsibility to drop the output. This is critical as - // the task output may not be `Send` and as such must remain with - // the scheduler or `JoinHandle`. i.e. if the output remains in the - // task structure until the task is deallocated, it may be dropped - // by a Waker on any arbitrary thread. - // - // Panics are delivered to the user via the `JoinHandle`. Given that - // they are dropping the `JoinHandle`, we assume they are not - // interested in the panic and swallow it. - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { - self.core().drop_future_or_output(); - })); + // + // TODO Create new bit/flag in state -> Set WantToDropJoinWaker in transition when failing + let transition = self.state().transition_to_join_handle_drop(); + match transition { + TransitionToJoinHandleDrop::Failed => { + // It is our responsibility to drop the output. This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().drop_future_or_output(); + })); + } + TransitionToJoinHandleDrop::OkDropJoinWaker => unsafe { + // If there is a waker associated with this task when the `JoinHandle` is about to get + // dropped we want to also drop this waker if the task is already completed. + self.trailer().set_waker(None); + }, + TransitionToJoinHandleDrop::OkDoNothing => (), } // Drop the `JoinHandle` reference, possibly deallocating the task @@ -309,6 +321,7 @@ where /// Completes the task. This method assumes that the state is RUNNING. fn complete(self) { + use super::state::TransitionToTerminal; // The future has completed and its output has been written to the task // stage. We transition from running to complete. @@ -346,8 +359,29 @@ where // The task has completed execution and will no longer be scheduled. let num_release = self.release(); - if self.state().transition_to_terminal(num_release) { - self.dealloc(); + match self.state().transition_to_terminal(num_release) { + TransitionToTerminal::OkDoNothing => (), + TransitionToTerminal::OkDealloc => { + self.dealloc(); + } + TransitionToTerminal::FailedDropJoinWaker => { + // Safety: In this case we are the only one referencing the task and the active + // waker is the only one preventing the task from being deallocated so noone else + // will try to access the waker here. + unsafe { + self.trailer().set_waker(None); + } + + // We do not expect this to happen since `TransitionToTerminal::DropJoinWaker` + // will only be returned when after dropping the JoinWaker the task can be + // safely. Because after this failed transition the COMPLETE bit is still set + // its fine to transition to terminal in two steps here + if let TransitionToTerminal::OkDealloc = + self.state().transition_to_terminal(num_release) + { + self.dealloc(); + } + } } } @@ -387,7 +421,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { debug_assert!(snapshot.is_join_interested()); - if !snapshot.is_complete() { + if !snapshot.is_complete() && !snapshot.is_terminal() { // If the task is not complete, try storing the provided waker in the // task's waker field. diff --git a/tokio/src/runtime/task/state.rs b/tokio/src/runtime/task/state.rs index 0fc7bb0329b..bf15892ca6e 100644 --- a/tokio/src/runtime/task/state.rs +++ b/tokio/src/runtime/task/state.rs @@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000; /// The task has been forcibly cancelled. const CANCELLED: usize = 0b100_000; +const TERMINAL: usize = 0b1_000_000; + /// All bits. -const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; +// const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; +const STATE_MASK: usize = + LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL; /// Bits used by the ref count portion of the state. const REF_COUNT_MASK: usize = !STATE_MASK; @@ -89,6 +93,20 @@ pub(crate) enum TransitionToNotifiedByRef { Submit, } +#[must_use] +pub(crate) enum TransitionToJoinHandleDrop { + Failed, + OkDoNothing, + OkDropJoinWaker, +} + +#[must_use] +pub(crate) enum TransitionToTerminal { + OkDoNothing, + OkDealloc, + FailedDropJoinWaker, +} + /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. impl State { @@ -174,6 +192,23 @@ impl State { }) } + pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_join_interested() { + snapshot.unset_join_interested() + } + + if snapshot.is_complete() && !snapshot.is_terminal() { + (TransitionToJoinHandleDrop::Failed, None) + } else if snapshot.is_join_waker_set() { + snapshot.unset_join_waker(); + (TransitionToJoinHandleDrop::OkDropJoinWaker, Some(snapshot)) + } else { + (TransitionToJoinHandleDrop::OkDoNothing, Some(snapshot)) + } + }) + } + /// Transitions the task from `Running` -> `Complete`. pub(super) fn transition_to_complete(&self) -> Snapshot { const DELTA: usize = RUNNING | COMPLETE; @@ -181,6 +216,7 @@ impl State { let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel)); assert!(prev.is_running()); assert!(!prev.is_complete()); + assert!(!prev.is_terminal()); Snapshot(prev.0 ^ DELTA) } @@ -188,16 +224,37 @@ impl State { /// Transitions from `Complete` -> `Terminal`, decrementing the reference /// count the specified number of times. /// - /// Returns true if the task should be deallocated. - pub(super) fn transition_to_terminal(&self, count: usize) -> bool { - let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); - assert!( - prev.ref_count() >= count, - "current: {}, sub: {}", - prev.ref_count(), - count - ); - prev.ref_count() == count + /// Returns `TransitionToTerminal::OkDoNothing` if transition was successful but the task can + /// not already be deallocated. + /// Returns `TransitionToTerminal::OkDealloc` if the task should be deallocated. + /// Returns `TransitionToTerminal::FailedDropJoinWaker` if the transition failed because of a + /// the join waker being the only last. In this case the reference count will not be decremented + /// but the `JOIN_WAKER` bit will be unset. + pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal { + self.fetch_update_action(|mut snapshot| { + assert!(!snapshot.is_running()); + assert!(snapshot.is_complete()); + assert!(!snapshot.is_terminal()); + assert!( + snapshot.ref_count() >= count, + "current: {}, sub: {}", + snapshot.ref_count(), + count + ); + + if snapshot.ref_count() == count { + snapshot.0 -= count * REF_ONE; + snapshot.0 |= TERMINAL; + (TransitionToTerminal::OkDealloc, Some(snapshot)) + } else if !snapshot.is_join_interested() && snapshot.is_join_waker_set() { + snapshot.unset_join_waker(); + (TransitionToTerminal::FailedDropJoinWaker, Some(snapshot)) + } else { + snapshot.0 -= count * REF_ONE; + snapshot.0 |= TERMINAL; + (TransitionToTerminal::OkDoNothing, Some(snapshot)) + } + }) } /// Transitions the state to `NOTIFIED`. @@ -557,6 +614,10 @@ impl Snapshot { self.0 & COMPLETE == COMPLETE } + pub(super) fn is_terminal(self) -> bool { + self.0 & TERMINAL == TERMINAL + } + pub(super) fn is_join_interested(self) -> bool { self.0 & JOIN_INTEREST == JOIN_INTEREST } diff --git a/tokio/src/runtime/tests/loom_multi_thread.rs b/tokio/src/runtime/tests/loom_multi_thread.rs index ddd14b7fb3f..9f261b6ad99 100644 --- a/tokio/src/runtime/tests/loom_multi_thread.rs +++ b/tokio/src/runtime/tests/loom_multi_thread.rs @@ -18,7 +18,7 @@ use loom::sync::Arc; use pin_project_lite::pin_project; use std::future::{poll_fn, Future}; -use std::pin::Pin; +use std::pin::{pin, Pin}; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; use std::task::{ready, Context, Poll}; @@ -459,3 +459,34 @@ impl Future for Track { }) } } + +#[test] +fn timo_test() { + use crate::sync::mpsc::channel; + + loom::model(|| { + let pool = mk_pool(2); + + pool.block_on(async move { + let (tx, mut rx) = channel(1); + + let (a_closer, mut wait_for_close_a) = channel::<()>(1); + let (b_closer, mut wait_for_close_b) = channel::<()>(1); + + let a = spawn(async move { + let b = rx.recv().await.unwrap(); + + futures::future::select(pin!(b), pin!(a_closer.send(()))).await; + }); + + let b = spawn(async move { + let _ = a.await; + let _ = b_closer.send(()).await; + }); + + tx.send(b).await.unwrap(); + + futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await; + }); + }); +}