Skip to content

Commit

Permalink
Do not pass ownership of the QueryStack in `Runtime::block_on_or_un…
Browse files Browse the repository at this point in the history
…wind`

This commit changes `Edge` such that it no longer takes direct ownership of the query stack and instead keeps a lifetime erased mutable slice, an exclusive borrow and as such the ownership model does not change.
The caller now does have to uphold the safety invariant that the query stack borrow is life for the entire computation which is trivially achievable as the caller will block until the computation as done.
  • Loading branch information
Veykril committed Dec 13, 2024
1 parent 1925bf2 commit d1a944e
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 49 deletions.
9 changes: 3 additions & 6 deletions src/runtime.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
mem,
panic::panic_any,
sync::{atomic::AtomicUsize, Arc},
thread::ThreadId,
Expand Down Expand Up @@ -205,16 +204,14 @@ impl Runtime {
});

let result = local_state.with_query_stack(|stack| {
let (new_stack, result) = DependencyGraph::block_on(
DependencyGraph::block_on(
dg,
thread_id,
database_key,
other_id,
mem::take(stack),
stack,
query_mutex_guard,
);
*stack = new_stack;
result
)
});

match result {
Expand Down
159 changes: 116 additions & 43 deletions src/runtime/dependency_graph.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
use std::sync::Arc;
use std::thread::ThreadId;

use crate::active_query::ActiveQuery;
use crate::key::DatabaseKeyIndex;
use crate::runtime::WaitResult;
use parking_lot::{Condvar, MutexGuard};
use parking_lot::MutexGuard;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;

type QueryStack = Vec<ActiveQuery>;

#[derive(Debug, Default)]
pub(super) struct DependencyGraph {
/// A `(K -> V)` pair in this map indicates that the the runtime
/// `K` is blocked on some query executing in the runtime `V`.
/// This encodes a graph that must be acyclic (or else deadlock
/// will result).
edges: FxHashMap<ThreadId, Edge>,
edges: FxHashMap<ThreadId, edge::Edge>,

/// Encodes the `ThreadId` that are blocked waiting for the result
/// of a given query.
Expand All @@ -25,18 +22,7 @@ pub(super) struct DependencyGraph {
/// When a key K completes which had dependent queries Qs blocked on it,
/// it stores its `WaitResult` here. As they wake up, each query Q in Qs will
/// come here to fetch their results.
wait_results: FxHashMap<ThreadId, (QueryStack, WaitResult)>,
}

#[derive(Debug)]
struct Edge {
blocked_on_id: ThreadId,
blocked_on_key: DatabaseKeyIndex,
stack: QueryStack,

/// Signalled whenever a query with dependents completes.
/// Allows those dependents to check if they are ready to unblock.
condvar: Arc<parking_lot::Condvar>,
wait_results: FxHashMap<ThreadId, WaitResult>,
}

impl DependencyGraph {
Expand Down Expand Up @@ -64,7 +50,7 @@ impl DependencyGraph {
pub(super) fn for_each_cycle_participant(
&mut self,
from_id: ThreadId,
from_stack: &mut QueryStack,
from_stack: &mut [ActiveQuery],
database_key: DatabaseKeyIndex,
to_id: ThreadId,
mut closure: impl FnMut(&mut [ActiveQuery]),
Expand Down Expand Up @@ -104,7 +90,7 @@ impl DependencyGraph {
// load up the next thread (i.e., we start at B/QB2,
// and then load up the dependency on C/QC2).
let edge = self.edges.get_mut(&id).unwrap();
closure(strip_prefix_query_stack_mut(&mut edge.stack, key));
closure(strip_prefix_query_stack_mut(edge.stack_mut(), key));
id = edge.blocked_on_id;
key = edge.blocked_on_key;
}
Expand All @@ -123,7 +109,7 @@ impl DependencyGraph {
pub(super) fn maybe_unblock_runtimes_in_cycle(
&mut self,
from_id: ThreadId,
from_stack: &QueryStack,
from_stack: &[ActiveQuery],
database_key: DatabaseKeyIndex,
to_id: ThreadId,
) -> (bool, bool) {
Expand All @@ -136,7 +122,7 @@ impl DependencyGraph {
let next_id = edge.blocked_on_id;
let next_key = edge.blocked_on_key;

if let Some(cycle) = strip_prefix_query_stack(&edge.stack, key)
if let Some(cycle) = strip_prefix_query_stack(edge.stack(), key)
.iter()
.rev()
.find_map(|aq| aq.cycle.clone())
Expand Down Expand Up @@ -182,19 +168,21 @@ impl DependencyGraph {
from_id: ThreadId,
database_key: DatabaseKeyIndex,
to_id: ThreadId,
from_stack: QueryStack,
from_stack: &mut [ActiveQuery],
query_mutex_guard: QueryMutexGuard,
) -> (QueryStack, WaitResult) {
let condvar = me.add_edge(from_id, database_key, to_id, from_stack);
) -> WaitResult {
// SAFETY: We are blocking until the result is removed from `DependencyGraph::wait_results`
// and as such we are keeping `from_stack` alive.
let condvar = unsafe { me.add_edge(from_id, database_key, to_id, from_stack) };

// Release the mutex that prevents `database_key`
// from completing, now that the edge has been added.
drop(query_mutex_guard);

loop {
if let Some(stack_and_result) = me.wait_results.remove(&from_id) {
if let Some(result) = me.wait_results.remove(&from_id) {
debug_assert!(!me.edges.contains_key(&from_id));
return stack_and_result;
return result;
}
condvar.wait(&mut me);
}
Expand All @@ -203,32 +191,28 @@ impl DependencyGraph {
/// Helper for `block_on`: performs actual graph modification
/// to add a dependency edge from `from_id` to `to_id`, which is
/// computing `database_key`.
fn add_edge(
///
/// # Safety
///
/// The caller needs to keep `from_stack`/`'aq`` alive until `from_id` has been removed from the `wait_results`.
// This safety invariant is consumed by the `Edge` struct
unsafe fn add_edge<'aq>(
&mut self,
from_id: ThreadId,
database_key: DatabaseKeyIndex,
to_id: ThreadId,
from_stack: QueryStack,
) -> Arc<parking_lot::Condvar> {
from_stack: &'aq mut [ActiveQuery],
) -> edge::EdgeGuard<'aq> {
assert_ne!(from_id, to_id);
debug_assert!(!self.edges.contains_key(&from_id));
debug_assert!(!self.depends_on(to_id, from_id));

let condvar = Arc::new(Condvar::new());
self.edges.insert(
from_id,
Edge {
blocked_on_id: to_id,
blocked_on_key: database_key,
stack: from_stack,
condvar: condvar.clone(),
},
);
let (edge, guard) = edge::Edge::new(to_id, database_key, from_stack);
self.edges.insert(from_id, edge);
self.query_dependents
.entry(database_key)
.or_default()
.push(from_id);
condvar
guard
}

/// Invoked when runtime `to_id` completes executing
Expand All @@ -253,11 +237,100 @@ impl DependencyGraph {
/// the lock on this data structure first, to recover the wait result).
fn unblock_runtime(&mut self, id: ThreadId, wait_result: WaitResult) {
let edge = self.edges.remove(&id).expect("not blocked");
self.wait_results.insert(id, (edge.stack, wait_result));
self.wait_results.insert(id, wait_result);

// Now that we have inserted the `wait_results`,
// notify the thread.
edge.condvar.notify_one();
edge.notify();
}
}

mod edge {
use std::{marker::PhantomData, ptr::NonNull, sync::Arc, thread::ThreadId};

use parking_lot::MutexGuard;

use crate::{
runtime::{dependency_graph::DependencyGraph, ActiveQuery},
DatabaseKeyIndex,
};

#[derive(Debug)]
pub(super) struct Edge {
pub(super) blocked_on_id: ThreadId,
pub(super) blocked_on_key: DatabaseKeyIndex,
stack: SendNonNull<[ActiveQuery]>,

/// Signalled whenever a query with dependents completes.
/// Allows those dependents to check if they are ready to unblock.
condvar: Arc<parking_lot::Condvar>,
}

pub struct EdgeGuard<'aq> {
condvar: Arc<parking_lot::Condvar>,
// Inform the borrow checker that the edge stack is borrowed until the guard is released.
// This is necessary to ensure that the stack is not modified by the caller of
// `DependencyGraph::add_edge` after the call returns.
_pd: PhantomData<&'aq mut ()>,
}

impl EdgeGuard<'_> {
pub fn wait(&self, mutex_guard: &mut MutexGuard<'_, DependencyGraph>) {
self.condvar.wait(mutex_guard)
}
}

// Wrapper type to allow `Edge` to be `Send` without disregarding its other fields.
struct SendNonNull<T: ?Sized>(NonNull<T>);

// SAFETY: `Edge` is `Send` as its `stack: NonNull<[ActiveQuery]>,` field is a lifetime erased
// mutable reference to a `Send` type (`ActiveQuery`) that is subject to the owner of `Edge` and is
// guaranteed to be live according to the safety invariants of `DependencyGraph::add_edge`.`
unsafe impl<T: ?Sized> Send for SendNonNull<T> where for<'a> &'a mut T: Send {}
// unsafe impl<T> Sync for SendNonNull<T> where for<'a> &'a mut T: Sync {}

impl<T: ?Sized + std::fmt::Debug> std::fmt::Debug for SendNonNull<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

impl Edge {
pub(super) fn new(
blocked_on_id: ThreadId,
blocked_on_key: DatabaseKeyIndex,
stack: &mut [ActiveQuery],
) -> (Self, EdgeGuard<'_>) {
let condvar = Arc::new(parking_lot::Condvar::new());
let stack = SendNonNull(NonNull::from(stack));
let edge = Self {
blocked_on_id,
blocked_on_key,
stack,
condvar: condvar.clone(),
};
let edge_guard = EdgeGuard {
condvar,
_pd: PhantomData,
};
(edge, edge_guard)
}

// unerase the lifetime of the stack
pub(super) fn stack_mut(&mut self) -> &mut [ActiveQuery] {
// SAFETY: This is safe due to the invariants upheld by DependencyGraph::add_edge.
unsafe { self.stack.0.as_mut() }
}

// unerase the lifetime of the stack
pub(super) fn stack(&self) -> &[ActiveQuery] {
// SAFETY: This is safe due to the invariants upheld by DependencyGraph::add_edge.
unsafe { self.stack.0.as_ref() }
}

pub(super) fn notify(self) {
self.condvar.notify_one();
}
}
}

Expand Down

0 comments on commit d1a944e

Please sign in to comment.