Skip to content

Commit

Permalink
pass inputs to cycle recovery functions
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Nov 14, 2024
1 parent 84f5eab commit f44f2f7
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 20 deletions.
7 changes: 4 additions & 3 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,17 @@ macro_rules! setup_tracked_fn {
$inner($db, $($input_id),*)
}

fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db) -> Self::Output<$db_lt> {
$($cycle_recovery_initial)*(db)
fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> {
$($cycle_recovery_initial)*(db, $($input_id),*)
}

fn recover_from_cycle<$db_lt>(
db: &$db_lt dyn $Db,
value: &Self::Output<$db_lt>,
count: u32,
($($input_id),*): ($($input_ty),*)
) -> $zalsa::CycleRecoveryAction<Self::Output<$db_lt>> {
$($cycle_recovery_fn)*(db, value, count)
$($cycle_recovery_fn)*(db, value, count, $($input_id),*)
}

fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> {
Expand Down
10 changes: 6 additions & 4 deletions components/salsa-macro-rules/src/unexpected_cycle_recovery.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
// Macro that generates the body of the cycle recovery function
// for the case where no cycle recovery is possible. Must be a macro
// because the signature types must match the particular tracked function.
// for the case where no cycle recovery is possible. This has to be
// a macro because it can take a variadic number of arguments.
#[macro_export]
macro_rules! unexpected_cycle_recovery {
($db:ident, $value:ident, $count:ident) => {{
($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{
std::mem::drop($db);
std::mem::drop(($($other_inputs),*));
panic!("cannot recover from cycle")
}};
}

#[macro_export]
macro_rules! unexpected_cycle_initial {
($db:ident) => {{
($db:ident, $($other_inputs:ident),*) => {{
std::mem::drop($db);
std::mem::drop(($($other_inputs),*));
panic!("no cycle initial value")
}};
}
3 changes: 2 additions & 1 deletion src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ pub trait Configuration: Any {
fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;

/// Get the cycle recovery initial value.
fn cycle_initial(db: &Self::DbView) -> Self::Output<'_>;
fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;

/// Decide whether to iterate a cycle again or fallback.
fn recover_from_cycle<'db>(
db: &'db Self::DbView,
value: &Self::Output<'db>,
count: u32,
input: Self::Input<'db>,
) -> CycleRecoveryAction<Self::Output<'db>>;
}

Expand Down
7 changes: 6 additions & 1 deletion src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ where
if !C::values_equal(&new_value, last_provisional_value) {
// We are in a cycle that hasn't converged; ask the user's
// cycle-recovery function what to do:
match C::recover_from_cycle(db, &new_value, iteration_count) {
match C::recover_from_cycle(
db,
&new_value,
iteration_count,
C::id_to_input(db, id),
) {
crate::CycleRecoveryAction::Iterate => {
tracing::debug!("{database_key_index:?}: execute: iterate again");
iteration_count = iteration_count.checked_add(1).expect(
Expand Down
2 changes: 1 addition & 1 deletion src/function/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ where
ClaimResult::Retry => return None,
ClaimResult::Cycle => {
return self
.initial_value(db)
.initial_value(db, database_key_index.key_index)
.map(|initial_value| {
tracing::debug!(
"hit cycle at {database_key_index:#?}, \
Expand Down
8 changes: 6 additions & 2 deletions src/function/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,13 @@ impl<C: Configuration> IngredientImpl<C> {
}
}

pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option<C::Output<'db>> {
pub(super) fn initial_value<'db>(
&'db self,
db: &'db C::DbView,
key: Id,
) -> Option<C::Output<'db>> {
match C::CYCLE_STRATEGY {
CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db)),
CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))),
CycleRecoveryStrategy::Panic => None,
}
}
Expand Down
30 changes: 26 additions & 4 deletions tests/cycle/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Type {
}
}

#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)]
fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type {
let defs = u.reaching_definitions(db);
match defs[..] {
Expand All @@ -57,7 +57,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type {
}
}

#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)]
fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type {
let increment_ty = Type::Values(Box::from([def.increment(db)]));
if let Some(base) = def.base(db) {
Expand All @@ -68,11 +68,33 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type {
}
}

fn cycle_initial(_db: &dyn Db) -> Type {
fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type {
Type::Bottom
}

fn cycle_recover(_db: &dyn Db, value: &Type, count: u32) -> CycleRecoveryAction<Type> {
fn def_cycle_recover(
_db: &dyn Db,
value: &Type,
count: u32,
_def: Definition,
) -> CycleRecoveryAction<Type> {
cycle_recover(value, count)
}

fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type {
Type::Bottom
}

fn use_cycle_recover(
_db: &dyn Db,
value: &Type,
count: u32,
_use: Use,
) -> CycleRecoveryAction<Type> {
cycle_recover(value, count)
}

fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction<Type> {
match value {
Type::Bottom => CycleRecoveryAction::Iterate,
Type::Values(_) => {
Expand Down
8 changes: 4 additions & 4 deletions tests/cycle/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ const MIN_COUNT_FALLBACK: u8 = 100;
const MIN_VALUE_FALLBACK: u8 = 5;
const MIN_VALUE: u8 = 10;

fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8> {
fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction<u8> {
if *value < MIN_VALUE {
CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK)
} else if count > 10 {
Expand All @@ -86,7 +86,7 @@ fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8>
}
}

fn min_initial(_db: &dyn Db) -> u8 {
fn min_initial(_db: &dyn Db, _inputs: Inputs) -> u8 {
255
}

Expand All @@ -99,7 +99,7 @@ const MAX_COUNT_FALLBACK: u8 = 200;
const MAX_VALUE_FALLBACK: u8 = 250;
const MAX_VALUE: u8 = 245;

fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8> {
fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction<u8> {
if *value > MAX_VALUE {
CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK)
} else if count > 10 {
Expand All @@ -109,7 +109,7 @@ fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8>
}
}

fn max_initial(_db: &dyn Db) -> u8 {
fn max_initial(_db: &dyn Db, _inputs: Inputs) -> u8 {
0
}

Expand Down

0 comments on commit f44f2f7

Please sign in to comment.