From b3e745b74f92d302de4c08dddcf9cae0649c62b8 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 13 Dec 2024 12:22:55 +0100 Subject: [PATCH] Do not pass ownership of the `QueryStack` in `Runtime::block_on_or_unwind` This commit changes `Edge` such that it no longer takes direct ownership of the query stack and instead keeps a lifetime erased mutable slice, an exclusive borrow and as such the ownership model does not change. The caller now does have to uphold the safety invariant that the query stack borrow is life for the entire computation which is trivially achievable as the caller will block until the computation as done. --- src/runtime.rs | 9 +- src/runtime/dependency_graph.rs | 159 +++++++++++++++++++++++--------- 2 files changed, 119 insertions(+), 49 deletions(-) diff --git a/src/runtime.rs b/src/runtime.rs index fe5605c9..e740f9c2 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,5 +1,4 @@ use std::{ - mem, panic::panic_any, sync::{ atomic::{AtomicBool, Ordering}, @@ -200,16 +199,14 @@ impl Runtime { }); let result = local_state.with_query_stack(|stack| { - let (new_stack, result) = DependencyGraph::block_on( + DependencyGraph::block_on( dg, thread_id, database_key, other_id, - mem::take(stack), + stack, query_mutex_guard, - ); - *stack = new_stack; - result + ) }); match result { diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 84c5327f..e6a99e73 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -1,22 +1,19 @@ -use std::sync::Arc; use std::thread::ThreadId; use crate::active_query::ActiveQuery; use crate::key::DatabaseKeyIndex; use crate::runtime::WaitResult; -use parking_lot::{Condvar, MutexGuard}; +use parking_lot::MutexGuard; use rustc_hash::FxHashMap; use smallvec::SmallVec; -type QueryStack = Vec; - #[derive(Debug, Default)] pub(super) struct DependencyGraph { /// A `(K -> V)` pair in this map indicates that the the runtime /// `K` is blocked on some query executing in the runtime `V`. /// This encodes a graph that must be acyclic (or else deadlock /// will result). - edges: FxHashMap, + edges: FxHashMap, /// Encodes the `ThreadId` that are blocked waiting for the result /// of a given query. @@ -25,18 +22,7 @@ pub(super) struct DependencyGraph { /// When a key K completes which had dependent queries Qs blocked on it, /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will /// come here to fetch their results. - wait_results: FxHashMap, -} - -#[derive(Debug)] -struct Edge { - blocked_on_id: ThreadId, - blocked_on_key: DatabaseKeyIndex, - stack: QueryStack, - - /// Signalled whenever a query with dependents completes. - /// Allows those dependents to check if they are ready to unblock. - condvar: Arc, + wait_results: FxHashMap, } impl DependencyGraph { @@ -64,7 +50,7 @@ impl DependencyGraph { pub(super) fn for_each_cycle_participant( &mut self, from_id: ThreadId, - from_stack: &mut QueryStack, + from_stack: &mut [ActiveQuery], database_key: DatabaseKeyIndex, to_id: ThreadId, mut closure: impl FnMut(&mut [ActiveQuery]), @@ -104,7 +90,7 @@ impl DependencyGraph { // load up the next thread (i.e., we start at B/QB2, // and then load up the dependency on C/QC2). let edge = self.edges.get_mut(&id).unwrap(); - closure(strip_prefix_query_stack_mut(&mut edge.stack, key)); + closure(strip_prefix_query_stack_mut(edge.stack_mut(), key)); id = edge.blocked_on_id; key = edge.blocked_on_key; } @@ -123,7 +109,7 @@ impl DependencyGraph { pub(super) fn maybe_unblock_runtimes_in_cycle( &mut self, from_id: ThreadId, - from_stack: &QueryStack, + from_stack: &[ActiveQuery], database_key: DatabaseKeyIndex, to_id: ThreadId, ) -> (bool, bool) { @@ -136,7 +122,7 @@ impl DependencyGraph { let next_id = edge.blocked_on_id; let next_key = edge.blocked_on_key; - if let Some(cycle) = strip_prefix_query_stack(&edge.stack, key) + if let Some(cycle) = strip_prefix_query_stack(edge.stack(), key) .iter() .rev() .find_map(|aq| aq.cycle.clone()) @@ -182,19 +168,21 @@ impl DependencyGraph { from_id: ThreadId, database_key: DatabaseKeyIndex, to_id: ThreadId, - from_stack: QueryStack, + from_stack: &mut [ActiveQuery], query_mutex_guard: QueryMutexGuard, - ) -> (QueryStack, WaitResult) { - let condvar = me.add_edge(from_id, database_key, to_id, from_stack); + ) -> WaitResult { + // SAFETY: We are blocking until the result is removed from `DependencyGraph::wait_results` + // and as such we are keeping `from_stack` alive. + let condvar = unsafe { me.add_edge(from_id, database_key, to_id, from_stack) }; // Release the mutex that prevents `database_key` // from completing, now that the edge has been added. drop(query_mutex_guard); loop { - if let Some(stack_and_result) = me.wait_results.remove(&from_id) { + if let Some(result) = me.wait_results.remove(&from_id) { debug_assert!(!me.edges.contains_key(&from_id)); - return stack_and_result; + return result; } condvar.wait(&mut me); } @@ -203,32 +191,28 @@ impl DependencyGraph { /// Helper for `block_on`: performs actual graph modification /// to add a dependency edge from `from_id` to `to_id`, which is /// computing `database_key`. - fn add_edge( + /// + /// # Safety + /// + /// The caller needs to keep `from_stack`/`'aq`` alive until `from_id` has been removed from the `wait_results`. + // This safety invariant is consumed by the `Edge` struct + unsafe fn add_edge<'aq>( &mut self, from_id: ThreadId, database_key: DatabaseKeyIndex, to_id: ThreadId, - from_stack: QueryStack, - ) -> Arc { + from_stack: &'aq mut [ActiveQuery], + ) -> edge::EdgeGuard<'aq> { assert_ne!(from_id, to_id); debug_assert!(!self.edges.contains_key(&from_id)); debug_assert!(!self.depends_on(to_id, from_id)); - - let condvar = Arc::new(Condvar::new()); - self.edges.insert( - from_id, - Edge { - blocked_on_id: to_id, - blocked_on_key: database_key, - stack: from_stack, - condvar: condvar.clone(), - }, - ); + let (edge, guard) = edge::Edge::new(to_id, database_key, from_stack); + self.edges.insert(from_id, edge); self.query_dependents .entry(database_key) .or_default() .push(from_id); - condvar + guard } /// Invoked when runtime `to_id` completes executing @@ -253,11 +237,100 @@ impl DependencyGraph { /// the lock on this data structure first, to recover the wait result). fn unblock_runtime(&mut self, id: ThreadId, wait_result: WaitResult) { let edge = self.edges.remove(&id).expect("not blocked"); - self.wait_results.insert(id, (edge.stack, wait_result)); + self.wait_results.insert(id, wait_result); // Now that we have inserted the `wait_results`, // notify the thread. - edge.condvar.notify_one(); + edge.notify(); + } +} + +mod edge { + use std::{marker::PhantomData, ptr::NonNull, sync::Arc, thread::ThreadId}; + + use parking_lot::MutexGuard; + + use crate::{ + runtime::{dependency_graph::DependencyGraph, ActiveQuery}, + DatabaseKeyIndex, + }; + + #[derive(Debug)] + pub(super) struct Edge { + pub(super) blocked_on_id: ThreadId, + pub(super) blocked_on_key: DatabaseKeyIndex, + stack: SendNonNull<[ActiveQuery]>, + + /// Signalled whenever a query with dependents completes. + /// Allows those dependents to check if they are ready to unblock. + condvar: Arc, + } + + pub struct EdgeGuard<'aq> { + condvar: Arc, + // Inform the borrow checker that the edge stack is borrowed until the guard is released. + // This is necessary to ensure that the stack is not modified by the caller of + // `DependencyGraph::add_edge` after the call returns. + _pd: PhantomData<&'aq mut ()>, + } + + impl EdgeGuard<'_> { + pub fn wait(&self, mutex_guard: &mut MutexGuard<'_, DependencyGraph>) { + self.condvar.wait(mutex_guard) + } + } + + // Wrapper type to allow `Edge` to be `Send` without disregarding its other fields. + struct SendNonNull(NonNull); + + // SAFETY: `Edge` is `Send` as its `stack: NonNull<[ActiveQuery]>,` field is a lifetime erased + // mutable reference to a `Send` type (`ActiveQuery`) that is subject to the owner of `Edge` and is + // guaranteed to be live according to the safety invariants of `DependencyGraph::add_edge`.` + unsafe impl Send for SendNonNull where for<'a> &'a mut T: Send {} + // unsafe impl Sync for SendNonNull where for<'a> &'a mut T: Sync {} + + impl std::fmt::Debug for SendNonNull { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl Edge { + pub(super) fn new( + blocked_on_id: ThreadId, + blocked_on_key: DatabaseKeyIndex, + stack: &mut [ActiveQuery], + ) -> (Self, EdgeGuard<'_>) { + let condvar = Arc::new(parking_lot::Condvar::new()); + let stack = SendNonNull(NonNull::from(stack)); + let edge = Self { + blocked_on_id, + blocked_on_key, + stack, + condvar: condvar.clone(), + }; + let edge_guard = EdgeGuard { + condvar, + _pd: PhantomData, + }; + (edge, edge_guard) + } + + // unerase the lifetime of the stack + pub(super) fn stack_mut(&mut self) -> &mut [ActiveQuery] { + // SAFETY: This is safe due to the invariants upheld by DependencyGraph::add_edge. + unsafe { self.stack.0.as_mut() } + } + + // unerase the lifetime of the stack + pub(super) fn stack(&self) -> &[ActiveQuery] { + // SAFETY: This is safe due to the invariants upheld by DependencyGraph::add_edge. + unsafe { self.stack.0.as_ref() } + } + + pub(super) fn notify(self) { + self.condvar.notify_one(); + } } }