diff --git a/src/runtime.rs b/src/runtime.rs index fe5605c9..e740f9c2 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,5 +1,4 @@ use std::{ - mem, panic::panic_any, sync::{ atomic::{AtomicBool, Ordering}, @@ -200,16 +199,14 @@ impl Runtime { }); let result = local_state.with_query_stack(|stack| { - let (new_stack, result) = DependencyGraph::block_on( + DependencyGraph::block_on( dg, thread_id, database_key, other_id, - mem::take(stack), + stack, query_mutex_guard, - ); - *stack = new_stack; - result + ) }); match result { diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 84c5327f..69a71559 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -1,22 +1,19 @@ -use std::sync::Arc; use std::thread::ThreadId; use crate::active_query::ActiveQuery; use crate::key::DatabaseKeyIndex; use crate::runtime::WaitResult; -use parking_lot::{Condvar, MutexGuard}; +use parking_lot::MutexGuard; use rustc_hash::FxHashMap; use smallvec::SmallVec; -type QueryStack = Vec; - #[derive(Debug, Default)] pub(super) struct DependencyGraph { /// A `(K -> V)` pair in this map indicates that the the runtime /// `K` is blocked on some query executing in the runtime `V`. /// This encodes a graph that must be acyclic (or else deadlock /// will result). - edges: FxHashMap, + edges: FxHashMap, /// Encodes the `ThreadId` that are blocked waiting for the result /// of a given query. @@ -25,18 +22,7 @@ pub(super) struct DependencyGraph { /// When a key K completes which had dependent queries Qs blocked on it, /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will /// come here to fetch their results. - wait_results: FxHashMap, -} - -#[derive(Debug)] -struct Edge { - blocked_on_id: ThreadId, - blocked_on_key: DatabaseKeyIndex, - stack: QueryStack, - - /// Signalled whenever a query with dependents completes. - /// Allows those dependents to check if they are ready to unblock. - condvar: Arc, + wait_results: FxHashMap, } impl DependencyGraph { @@ -64,7 +50,7 @@ impl DependencyGraph { pub(super) fn for_each_cycle_participant( &mut self, from_id: ThreadId, - from_stack: &mut QueryStack, + from_stack: &mut [ActiveQuery], database_key: DatabaseKeyIndex, to_id: ThreadId, mut closure: impl FnMut(&mut [ActiveQuery]), @@ -104,7 +90,7 @@ impl DependencyGraph { // load up the next thread (i.e., we start at B/QB2, // and then load up the dependency on C/QC2). let edge = self.edges.get_mut(&id).unwrap(); - closure(strip_prefix_query_stack_mut(&mut edge.stack, key)); + closure(strip_prefix_query_stack_mut(edge.stack_mut(), key)); id = edge.blocked_on_id; key = edge.blocked_on_key; } @@ -123,7 +109,7 @@ impl DependencyGraph { pub(super) fn maybe_unblock_runtimes_in_cycle( &mut self, from_id: ThreadId, - from_stack: &QueryStack, + from_stack: &[ActiveQuery], database_key: DatabaseKeyIndex, to_id: ThreadId, ) -> (bool, bool) { @@ -136,7 +122,7 @@ impl DependencyGraph { let next_id = edge.blocked_on_id; let next_key = edge.blocked_on_key; - if let Some(cycle) = strip_prefix_query_stack(&edge.stack, key) + if let Some(cycle) = strip_prefix_query_stack(edge.stack(), key) .iter() .rev() .find_map(|aq| aq.cycle.clone()) @@ -182,19 +168,21 @@ impl DependencyGraph { from_id: ThreadId, database_key: DatabaseKeyIndex, to_id: ThreadId, - from_stack: QueryStack, + from_stack: &mut [ActiveQuery], query_mutex_guard: QueryMutexGuard, - ) -> (QueryStack, WaitResult) { - let condvar = me.add_edge(from_id, database_key, to_id, from_stack); + ) -> WaitResult { + // SAFETY: We are blocking until the result is removed from `DependencyGraph::wait_results` + // and as such we are keeping `from_stack` alive. + let condvar = unsafe { me.add_edge(from_id, database_key, to_id, from_stack) }; // Release the mutex that prevents `database_key` // from completing, now that the edge has been added. drop(query_mutex_guard); loop { - if let Some(stack_and_result) = me.wait_results.remove(&from_id) { + if let Some(result) = me.wait_results.remove(&from_id) { debug_assert!(!me.edges.contains_key(&from_id)); - return stack_and_result; + return result; } condvar.wait(&mut me); } @@ -203,32 +191,29 @@ impl DependencyGraph { /// Helper for `block_on`: performs actual graph modification /// to add a dependency edge from `from_id` to `to_id`, which is /// computing `database_key`. - fn add_edge( + /// + /// # Safety + /// + /// The caller needs to keep `from_stack`/`'aq`` alive until `from_id` has been removed from the `wait_results`. + // This safety invariant is consumed by the `Edge` struct + unsafe fn add_edge<'aq>( &mut self, from_id: ThreadId, database_key: DatabaseKeyIndex, to_id: ThreadId, - from_stack: QueryStack, - ) -> Arc { + from_stack: &'aq mut [ActiveQuery], + ) -> edge::EdgeGuard<'aq> { assert_ne!(from_id, to_id); debug_assert!(!self.edges.contains_key(&from_id)); debug_assert!(!self.depends_on(to_id, from_id)); - - let condvar = Arc::new(Condvar::new()); - self.edges.insert( - from_id, - Edge { - blocked_on_id: to_id, - blocked_on_key: database_key, - stack: from_stack, - condvar: condvar.clone(), - }, - ); + // SAFETY: The caller is responsible for ensuring that the `EdgeGuard` outlives the `Edge`. + let (edge, guard) = unsafe { edge::Edge::new(to_id, database_key, from_stack) }; + self.edges.insert(from_id, edge); self.query_dependents .entry(database_key) .or_default() .push(from_id); - condvar + guard } /// Invoked when runtime `to_id` completes executing @@ -253,11 +238,85 @@ impl DependencyGraph { /// the lock on this data structure first, to recover the wait result). fn unblock_runtime(&mut self, id: ThreadId, wait_result: WaitResult) { let edge = self.edges.remove(&id).expect("not blocked"); - self.wait_results.insert(id, (edge.stack, wait_result)); + self.wait_results.insert(id, wait_result); // Now that we have inserted the `wait_results`, // notify the thread. - edge.condvar.notify_one(); + edge.notify(); + } +} + +mod edge { + use std::{marker::PhantomData, mem, sync::Arc, thread::ThreadId}; + + use parking_lot::MutexGuard; + + use crate::{ + runtime::{dependency_graph::DependencyGraph, ActiveQuery}, + DatabaseKeyIndex, + }; + + #[derive(Debug)] + pub(super) struct Edge { + pub(super) blocked_on_id: ThreadId, + pub(super) blocked_on_key: DatabaseKeyIndex, + // the 'static is a lie, we erased the actual lifetime here + stack: &'static mut [ActiveQuery], + + /// Signalled whenever a query with dependents completes. + /// Allows those dependents to check if they are ready to unblock. + condvar: Arc, + } + + pub struct EdgeGuard<'aq> { + condvar: Arc, + // Inform the borrow checker that the edge stack is borrowed until the guard is released. + // This is necessary to ensure that the stack is not modified by the caller of + // `DependencyGraph::add_edge` after the call returns. + _pd: PhantomData<&'aq mut ()>, + } + + impl EdgeGuard<'_> { + pub fn wait(&self, mutex_guard: &mut MutexGuard<'_, DependencyGraph>) { + self.condvar.wait(mutex_guard) + } + } + + impl Edge { + pub(super) unsafe fn new<'aq>( + blocked_on_id: ThreadId, + blocked_on_key: DatabaseKeyIndex, + stack: &'aq mut [ActiveQuery], + ) -> (Self, EdgeGuard<'aq>) { + let condvar = Arc::new(parking_lot::Condvar::new()); + let edge = Self { + blocked_on_id, + blocked_on_key, + // SAFETY: We erase the lifetime here, the caller is responsible for ensuring that + // the `EdgeGuard` outlives this `Edge`. + stack: unsafe { + mem::transmute::<&'aq mut [ActiveQuery], &'static mut [ActiveQuery]>(stack) + }, + condvar: condvar.clone(), + }; + let edge_guard = EdgeGuard { + condvar, + _pd: PhantomData, + }; + (edge, edge_guard) + } + + pub(super) fn stack_mut(&mut self) -> &mut [ActiveQuery] { + self.stack + } + + pub(super) fn stack(&self) -> &[ActiveQuery] { + self.stack + } + + pub(super) fn notify(self) { + self.condvar.notify_one(); + } } }