Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa): don't deduplicate constraints in blocks that are not dominated #6627

Merged
merged 23 commits into from
Nov 28, 2024
Merged
Changes from 10 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
2d19468
Add a test that fails
asterite Nov 26, 2024
98d246b
Add simplify function
aakoshh Nov 26, 2024
2dec436
Add SimplificationCache
aakoshh Nov 26, 2024
f27b1d2
Add docs
aakoshh Nov 26, 2024
4e9c140
Fix bug by restriting view
aakoshh Nov 26, 2024
2c95130
Allow multiple blocks
aakoshh Nov 26, 2024
fd19fb0
Add a regression test for instruction hoisting
asterite Nov 26, 2024
cd26a1b
Don't hoist instructions that have side effects
asterite Nov 26, 2024
85fbebd
Move Dom up
aakoshh Nov 26, 2024
7b28c59
Merge branch 'ab/constant_folding_bug' of github.com:noir-lang/noir i…
aakoshh Nov 26, 2024
e17a570
Use BTreeMap to store block IDs
aakoshh Nov 27, 2024
069f260
Walk up the dominator tree to find the first simplification
aakoshh Nov 27, 2024
3db0696
Do not find anything for unreachable nodes
aakoshh Nov 27, 2024
5b196d5
Add Instruction::has_side_effect and use it to disable hoisting
aakoshh Nov 27, 2024
911aedb
Update compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
aakoshh Nov 27, 2024
c07d760
Remove Predicate
aakoshh Nov 27, 2024
41502f6
No need for mutable any more. No need to check empty either
aakoshh Nov 27, 2024
8572b40
Merge branch 'master' into ab/constant_folding_bug
aakoshh Nov 27, 2024
05cf0a9
Fix comment
aakoshh Nov 27, 2024
1a9bb76
Merge branch 'ab/constant_folding_bug' of github.com:noir-lang/noir i…
aakoshh Nov 27, 2024
ae1da8d
Merge remote-tracking branch 'origin/master' into ab/constant_folding…
aakoshh Nov 27, 2024
1e835b5
Use requires_acir_gen_predicate in has_side_effects
aakoshh Nov 27, 2024
75090df
Merge branch 'master' into ab/constant_folding_bug
TomAFrench Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 185 additions & 53 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ impl Function {
use_constraint_info: bool,
brillig_info: Option<BrilligInfo>,
) {
let mut context = Context::new(self, use_constraint_info, brillig_info);
let mut context = Context::new(use_constraint_info, brillig_info);
let mut dom = DominatorTree::with_function(self);
context.block_queue.push_back(self.entry_block());

while let Some(block) = context.block_queue.pop_front() {
Expand All @@ -155,11 +156,14 @@ impl Function {
}

context.visited_blocks.insert(block);
context.fold_constants_in_block(self, block);
context.fold_constants_in_block(&mut self.dfg, &mut dom, block);
}
}
}

/// `ValueId` in an `EnableConstraintsIf` instruction.
type Predicate = ValueId;
aakoshh marked this conversation as resolved.
Show resolved Hide resolved

struct Context<'a> {
use_constraint_info: bool,
brillig_info: Option<BrilligInfo<'a>>,
Expand All @@ -174,12 +178,10 @@ struct Context<'a> {
/// We partition the maps of constrained values according to the side-effects flag at the point
/// at which the values are constrained. This prevents constraints which are only sometimes enforced
/// being used to modify the rest of the program.
constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, ValueId>>,
constraint_simplification_mappings: ConstraintSimplificationCache,

// Cache of instructions without any side-effects along with their outputs.
cached_instruction_results: InstructionResultCache,

dom: DominatorTree,
}

#[derive(Copy, Clone)]
Expand All @@ -188,12 +190,60 @@ pub(crate) struct BrilligInfo<'a> {
brillig_functions: &'a BTreeMap<FunctionId, Function>,
}

/// Records a simplified equivalent of an [`Instruction`]s along with the blocks in which the
/// constraint that advised the simplification has been encountered.
///
/// For more information see [`ConstraintSimplificationCache`].
#[derive(Default)]
struct SimplificationCache {
simplifications: HashMap<BasicBlockId, ValueId>,
}

impl SimplificationCache {
/// Called with a newly encountered simplification.
fn add(&mut self, dfg: &DataFlowGraph, simple: ValueId, block: BasicBlockId) {
let existing = self.simplifications.entry(block).or_insert(simple);
// Keep the simpler expression in this block.
if *existing != simple {
match simplify(dfg, *existing, simple) {
Some((complex, simple)) if *existing == complex => {
*existing = simple;
}
_ => {}
}
}
aakoshh marked this conversation as resolved.
Show resolved Hide resolved
}

/// Try to find a simplification in a visible block.
fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<ValueId> {
aakoshh marked this conversation as resolved.
Show resolved Hide resolved
// See if we have a direct simplification in this block.
if let Some(value) = self.simplifications.get(&block) {
return Some(*value);
}
// Check if there is a dominating block we can take a simplification from.
for (constraining_block, value) in self.simplifications.iter() {
if dom.dominates(*constraining_block, block) {
return Some(*value);
}
}
None
}
}

/// HashMap from Instruction to a simplified expression that it can be replaced with based on
/// constraints that testify to their equivalence, stored together with the set of blocks at which
/// this constraint has been observed. Only blocks dominated by one in the cache should have
/// access to this information, otherwise we create a sort of time paradox where we replace
/// an instruction with a constant we believe _should_ be true about it, without ever actually
/// producing and asserting the value.
type ConstraintSimplificationCache = HashMap<Predicate, HashMap<ValueId, SimplificationCache>>;

/// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction.
/// Stored as a two-level map to avoid cloning Instructions during the `.get` call.
///
/// In addition to each result, the original BasicBlockId is stored as well. This allows us
/// to deduplicate instructions across blocks as long as the new block dominates the original.
type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, ResultCache>>;
type InstructionResultCache = HashMap<Instruction, HashMap<Option<Predicate>, ResultCache>>;

/// Records the results of all duplicate [`Instruction`]s along with the blocks in which they sit.
///
Expand All @@ -204,65 +254,71 @@ struct ResultCache {
}

impl<'brillig> Context<'brillig> {
fn new(
function: &Function,
use_constraint_info: bool,
brillig_info: Option<BrilligInfo<'brillig>>,
) -> Self {
fn new(use_constraint_info: bool, brillig_info: Option<BrilligInfo<'brillig>>) -> Self {
Self {
use_constraint_info,
brillig_info,
visited_blocks: Default::default(),
block_queue: Default::default(),
constraint_simplification_mappings: Default::default(),
cached_instruction_results: Default::default(),
dom: DominatorTree::with_function(function),
}
}

fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) {
let instructions = function.dfg[block].take_instructions();
fn fold_constants_in_block(
&mut self,
dfg: &mut DataFlowGraph,
dom: &mut DominatorTree,
block: BasicBlockId,
) {
let instructions = dfg[block].take_instructions();

let mut side_effects_enabled_var =
function.dfg.make_constant(FieldElement::one(), Type::bool());
let mut side_effects_enabled_var = dfg.make_constant(FieldElement::one(), Type::bool());

for instruction_id in instructions {
self.fold_constants_into_instruction(
&mut function.dfg,
dfg,
dom,
block,
instruction_id,
&mut side_effects_enabled_var,
);
}
self.block_queue.extend(function.dfg[block].successors());
self.block_queue.extend(dfg[block].successors());
}

fn fold_constants_into_instruction(
&mut self,
dfg: &mut DataFlowGraph,
dom: &mut DominatorTree,
mut block: BasicBlockId,
id: InstructionId,
side_effects_enabled_var: &mut ValueId,
) {
let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var);
let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping);

let instruction =
Self::resolve_instruction(id, block, dfg, dom, constraint_simplification_mapping);

let old_results = dfg.instruction_results(id).to_vec();

// If a copy of this instruction exists earlier in the block, then reuse the previous results.
if let Some(cache_result) =
self.get_cached(dfg, &instruction, *side_effects_enabled_var, block)
self.get_cached(dfg, dom, &instruction, *side_effects_enabled_var, block)
{
match cache_result {
CacheResult::Cached(cached) => {
Self::replace_result_ids(dfg, &old_results, cached);
return;
}
CacheResult::NeedToHoistToCommonBlock(dominator, _cached) => {
// Just change the block to insert in the common dominator instead.
// This will only move the current instance of the instruction right now.
// When constant folding is run a second time later on, it'll catch
// that the previous instance can be deduplicated to this instance.
block = dominator;
if instruction_can_be_hoisted(&instruction, dfg, self.use_constraint_info) {
// Just change the block to insert in the common dominator instead.
// This will only move the current instance of the instruction right now.
// When constant folding is run a second time later on, it'll catch
// that the previous instance can be deduplicated to this instance.
block = dominator;
}
}
}
}
Expand Down Expand Up @@ -307,8 +363,10 @@ impl<'brillig> Context<'brillig> {
/// Fetches an [`Instruction`] by its [`InstructionId`] and fully resolves its inputs.
fn resolve_instruction(
instruction_id: InstructionId,
block: BasicBlockId,
dfg: &DataFlowGraph,
constraint_simplification_mapping: &HashMap<ValueId, ValueId>,
dom: &mut DominatorTree,
constraint_simplification_mapping: &HashMap<ValueId, SimplificationCache>,
) -> Instruction {
let instruction = dfg[instruction_id].clone();

Expand All @@ -318,20 +376,29 @@ impl<'brillig> Context<'brillig> {
// This allows us to reach a stable final `ValueId` for each instruction input as we add more
// constraints to the cache.
fn resolve_cache(
block: BasicBlockId,
dfg: &DataFlowGraph,
cache: &HashMap<ValueId, ValueId>,
dom: &mut DominatorTree,
cache: &HashMap<ValueId, SimplificationCache>,
value_id: ValueId,
) -> ValueId {
let resolved_id = dfg.resolve(value_id);
match cache.get(&resolved_id) {
Some(cached_value) => resolve_cache(dfg, cache, *cached_value),
Some(simplification_cache) => {
if let Some(simplified) = simplification_cache.get(block, dom) {
resolve_cache(block, dfg, dom, cache, simplified)
} else {
resolved_id
}
}
None => resolved_id,
}
}

// Resolve any inputs to ensure that we're comparing like-for-like instructions.
instruction
.map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id))
instruction.map_values(|value_id| {
resolve_cache(block, dfg, dom, constraint_simplification_mapping, value_id)
})
}

/// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations
Expand Down Expand Up @@ -377,26 +444,11 @@ impl<'brillig> Context<'brillig> {
// to map from the more complex to the simpler value.
if let Instruction::Constrain(lhs, rhs, _) = instruction {
// These `ValueId`s should be fully resolved now.
match (&dfg[lhs], &dfg[rhs]) {
// Ignore trivial constraints
(Value::NumericConstant { .. }, Value::NumericConstant { .. }) => (),

// Prefer replacing with constants where possible.
(Value::NumericConstant { .. }, _) => {
self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
}
(_, Value::NumericConstant { .. }) => {
self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
}
// Otherwise prefer block parameters over instruction results.
// This is as block parameters are more likely to be a single witness rather than a full expression.
(Value::Param { .. }, Value::Instruction { .. }) => {
self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
}
(Value::Instruction { .. }, Value::Param { .. }) => {
self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
}
(_, _) => (),
if let Some((complex, simple)) = simplify(dfg, lhs, rhs) {
self.get_constraint_map(side_effects_enabled_var)
.entry(complex)
.or_default()
.add(dfg, simple, block);
}
}
}
Expand All @@ -420,7 +472,7 @@ impl<'brillig> Context<'brillig> {
fn get_constraint_map(
&mut self,
side_effects_enabled_var: ValueId,
) -> &mut HashMap<ValueId, ValueId> {
) -> &mut HashMap<ValueId, SimplificationCache> {
self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default()
}

Expand All @@ -436,8 +488,9 @@ impl<'brillig> Context<'brillig> {
}

fn get_cached(
&mut self,
&self,
dfg: &DataFlowGraph,
dom: &mut DominatorTree,
instruction: &Instruction,
side_effects_enabled_var: ValueId,
block: BasicBlockId,
Expand All @@ -447,7 +500,7 @@ impl<'brillig> Context<'brillig> {
let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
let predicate = predicate.then_some(side_effects_enabled_var);

results_for_instruction.get(&predicate)?.get(block, &mut self.dom)
results_for_instruction.get(&predicate)?.get(block, dom)
}

/// Checks if the given instruction is a call to a brillig function with all constant arguments.
Expand Down Expand Up @@ -611,6 +664,20 @@ impl<'brillig> Context<'brillig> {
}
}

fn instruction_can_be_hoisted(
instruction: &Instruction,
dfg: &mut DataFlowGraph,
deduplicate_with_predicate: bool,
) -> bool {
// These two can never be hoisted as they have a side-effect
// (though it's fine to de-duplicate them, just not fine to hoist them)
if matches!(instruction, Instruction::Constrain(..) | Instruction::RangeCheck { .. }) {
return false;
}

instruction.can_be_deduplicated(dfg, deduplicate_with_predicate)
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
}

impl ResultCache {
/// Records that an `Instruction` in block `block` produced the result values `results`.
fn cache(&mut self, block: BasicBlockId, results: Vec<ValueId>) {
Expand Down Expand Up @@ -687,6 +754,25 @@ fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut V
panic!("Expected ValueId to be numeric constant or array constant");
}

/// Check if one expression is simpler than the other.
/// Returns `Some((complex, simple))` if a simplification was found, otherwise `None`.
/// Expects the `ValueId`s to be fully resolved.
fn simplify(dfg: &DataFlowGraph, lhs: ValueId, rhs: ValueId) -> Option<(ValueId, ValueId)> {
match (&dfg[lhs], &dfg[rhs]) {
// Ignore trivial constraints
(Value::NumericConstant { .. }, Value::NumericConstant { .. }) => None,

// Prefer replacing with constants where possible.
(Value::NumericConstant { .. }, _) => Some((rhs, lhs)),
(_, Value::NumericConstant { .. }) => Some((lhs, rhs)),
// Otherwise prefer block parameters over instruction results.
// This is as block parameters are more likely to be a single witness rather than a full expression.
(Value::Param { .. }, Value::Instruction { .. }) => Some((rhs, lhs)),
(Value::Instruction { .. }, Value::Param { .. }) => Some((lhs, rhs)),
(_, _) => None,
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;
Expand Down Expand Up @@ -1341,4 +1427,50 @@ mod test {
let ssa = ssa.fold_constants_with_brillig(&brillig);
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn does_not_use_cached_constrain_in_block_that_is_not_dominated() {
let src = "
brillig(inline) fn main f0 {
b0(v0: Field, v1: Field):
v3 = eq v0, Field 0
jmpif v3 then: b1, else: b2
b1():
v5 = eq v1, Field 1
constrain v1 == Field 1
jmp b2()
b2():
v6 = eq v1, Field 0
constrain v1 == Field 0
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.fold_constants_using_constraints();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn does_not_hoist_constrain_to_common_ancestor() {
let src = "
brillig(inline) fn main f0 {
b0(v0: Field, v1: Field):
v3 = eq v0, Field 0
jmpif v3 then: b1, else: b2
b1():
constrain v1 == Field 1
jmp b2()
b2():
jmpif v0 then: b3, else: b4
b3():
constrain v1 == Field 1 // This was incorrectly hoisted to b0 but this condition is not valid when going b0 -> b2 -> b4
jmp b4()
b4():
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.fold_constants_using_constraints();
assert_normalized_ssa_equals(ssa, src);
}
}