From 8932dac4847c643341320c2893f7e4297c78c621 Mon Sep 17 00:00:00 2001 From: jfecher Date: Fri, 15 Nov 2024 10:45:11 -0600 Subject: [PATCH] fix: Fix poor handling of aliased references in flattening pass causing some values to be zeroed (#6434) Co-authored-by: Tom French Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- .../noirc_evaluator/src/ssa/ir/instruction.rs | 22 +- compiler/noirc_evaluator/src/ssa/opt/die.rs | 2 +- .../src/ssa/opt/flatten_cfg.rs | 269 +++++------------- .../src/ssa/opt/flatten_cfg/value_merger.rs | 43 +-- .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 17 +- 5 files changed, 133 insertions(+), 220 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index f187a279b9b..254a0afe88b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -14,11 +14,12 @@ use fxhash::FxHasher; use iter_extended::vecmap; use noirc_frontend::hir_def::types::Type as HirType; -use crate::ssa::opt::flatten_cfg::value_merger::ValueMerger; +use crate::ssa::{ir::function::RuntimeType, opt::flatten_cfg::value_merger::ValueMerger}; use super::{ basic_block::BasicBlockId, dfg::{CallStack, DataFlowGraph}, + function::Function, map::Id, types::{NumericType, Type}, value::{Value, ValueId}, @@ -363,12 +364,12 @@ impl Instruction { } } - pub(crate) fn can_eliminate_if_unused(&self, dfg: &DataFlowGraph) -> bool { + pub(crate) fn can_eliminate_if_unused(&self, function: &Function) -> bool { use Instruction::*; match self { Binary(binary) => { if matches!(binary.operator, BinaryOp::Div | BinaryOp::Mod) { - if let Some(rhs) = dfg.get_numeric_constant(binary.rhs) { + if let Some(rhs) = function.dfg.get_numeric_constant(binary.rhs) { rhs != FieldElement::zero() } else { false @@ -386,15 +387,26 @@ impl Instruction { | IfElse { .. } | ArraySet { .. } => true, + // Store instructions must be removed by DIE in acir code, any load + // instructions should already be unused by that point. + // + // Note that this check assumes that it is being performed after the flattening + // pass and after the last mem2reg pass. This is currently the case for the DIE + // pass where this check is done, but does mean that we cannot perform mem2reg + // after the DIE pass. + Store { .. } => { + matches!(function.runtime(), RuntimeType::Acir(_)) + && function.reachable_blocks().len() == 1 + } + Constrain(..) - | Store { .. } | EnableSideEffectsIf { .. } | IncrementRc { .. } | DecrementRc { .. } | RangeCheck { .. } => false, // Some `Intrinsic`s have side effects so we must check what kind of `Call` this is. - Call { func, .. } => match dfg[*func] { + Call { func, .. } => match function.dfg[*func] { // Explicitly allows removal of unused ec operations, even if they can fail Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul)) | Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::EmbeddedCurveAdd)) => true, diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index 57af27e8dcd..a7b7af91a18 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -172,7 +172,7 @@ impl Context { fn is_unused(&self, instruction_id: InstructionId, function: &Function) -> bool { let instruction = &function.dfg[instruction_id]; - if instruction.can_eliminate_if_unused(&function.dfg) { + if instruction.can_eliminate_if_unused(function) { let results = function.dfg.instruction_results(instruction_id); results.iter().all(|result| !self.used_values.contains(result)) } else if let Instruction::Call { func, arguments } = instruction { diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index db2d96aac81..54c21a68ea2 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -132,7 +132,6 @@ //! v12 = add v10, v11 //! store v12 at v5 (new store) use fxhash::FxHashMap as HashMap; -use std::collections::{BTreeMap, HashSet}; use acvm::{acir::AcirField, acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; @@ -186,18 +185,6 @@ struct Context<'f> { /// Maps start of branch -> end of branch branch_ends: HashMap, - /// Maps an address to the old and new value of the element at that address - /// These only hold stores for one block at a time and is cleared - /// between inlining of branches. - store_values: HashMap, - - /// Stores all allocations local to the current branch. - /// Since these branches are local to the current branch (ie. only defined within one branch of - /// an if expression), they should not be merged with their previous value or stored value in - /// the other branch since there is no such value. The ValueId here is that which is returned - /// by the allocate instruction. - local_allocations: HashSet, - /// A stack of each jmpif condition that was taken to reach a particular point in the program. /// When two branches are merged back into one, this constitutes a join point, and is analogous /// to the rest of the program after an if statement. When such a join point / end block is @@ -216,13 +203,6 @@ struct Context<'f> { arguments_stack: Vec>, } -#[derive(Clone)] -pub(crate) struct Store { - old_value: ValueId, - new_value: ValueId, - call_stack: CallStack, -} - #[derive(Clone)] struct ConditionalBranch { // Contains the last processed block during the processing of the branch. @@ -231,10 +211,6 @@ struct ConditionalBranch { old_condition: ValueId, // The condition of the branch condition: ValueId, - // The store values accumulated when processing the branch - store_values: HashMap, - // The allocations accumulated when processing the branch - local_allocations: HashSet, } struct ConditionalContext { @@ -263,8 +239,6 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap Context<'f> { // If this is not a separate variable, clippy gets confused and says the to_vec is // unnecessary, when removing it actually causes an aliasing/mutability error. let instructions = self.inserter.function.dfg[block].instructions().to_vec(); + let mut previous_allocate_result = None; + for instruction in instructions.iter() { if self.is_no_predicate(no_predicates, instruction) { // disable side effect for no_predicate functions @@ -356,10 +332,10 @@ impl<'f> Context<'f> { None, im::Vector::new(), ); - self.push_instruction(*instruction); + self.push_instruction(*instruction, &mut previous_allocate_result); self.insert_current_side_effects_enabled(); } else { - self.push_instruction(*instruction); + self.push_instruction(*instruction, &mut previous_allocate_result); } } } @@ -429,13 +405,9 @@ impl<'f> Context<'f> { let old_condition = *condition; let then_condition = self.inserter.resolve(old_condition); - let old_stores = std::mem::take(&mut self.store_values); - let old_allocations = std::mem::take(&mut self.local_allocations); let branch = ConditionalBranch { old_condition, condition: self.link_condition(then_condition), - store_values: old_stores, - local_allocations: old_allocations, last_block: *then_destination, }; let cond_context = ConditionalContext { @@ -463,21 +435,11 @@ impl<'f> Context<'f> { ); let else_condition = self.link_condition(else_condition); - // Make sure the else branch sees the previous values of each store - // rather than any values created in the 'then' branch. - let old_stores = std::mem::take(&mut cond_context.then_branch.store_values); - cond_context.then_branch.store_values = std::mem::take(&mut self.store_values); - self.undo_stores_in_then_branch(&cond_context.then_branch.store_values); - - let old_allocations = std::mem::take(&mut self.local_allocations); let else_branch = ConditionalBranch { old_condition: cond_context.then_branch.old_condition, condition: else_condition, - store_values: old_stores, - local_allocations: old_allocations, last_block: *block, }; - cond_context.then_branch.local_allocations.clear(); cond_context.else_branch = Some(else_branch); self.condition_stack.push(cond_context); @@ -499,10 +461,7 @@ impl<'f> Context<'f> { } let mut else_branch = cond_context.else_branch.unwrap(); - let stores_in_branch = std::mem::replace(&mut self.store_values, else_branch.store_values); - self.local_allocations = std::mem::take(&mut else_branch.local_allocations); else_branch.last_block = *block; - else_branch.store_values = stores_in_branch; cond_context.else_branch = Some(else_branch); // We must remember to reset whether side effects are enabled when both branches @@ -571,8 +530,6 @@ impl<'f> Context<'f> { .first() }); - let call_stack = cond_context.call_stack; - self.merge_stores(cond_context.then_branch, cond_context.else_branch, call_stack); self.arguments_stack.pop(); self.arguments_stack.pop(); self.arguments_stack.push(args); @@ -627,130 +584,47 @@ impl<'f> Context<'f> { self.insert_instruction_with_typevars(enable_side_effects, None, call_stack); } - /// Merge any store instructions found in each branch. - /// - /// This function relies on the 'then' branch being merged before the 'else' branch of a jmpif - /// instruction. If this ordering is changed, the ordering that store values are merged within - /// this function also needs to be changed to reflect that. - fn merge_stores( - &mut self, - then_branch: ConditionalBranch, - else_branch: Option, - call_stack: CallStack, - ) { - // Address -> (then_value, else_value, value_before_the_if) - let mut new_map = BTreeMap::new(); - - for (address, store) in then_branch.store_values { - new_map.insert(address, (store.new_value, store.old_value, store.old_value)); - } - - if else_branch.is_some() { - for (address, store) in else_branch.clone().unwrap().store_values { - if let Some(entry) = new_map.get_mut(&address) { - entry.1 = store.new_value; - } else { - new_map.insert(address, (store.old_value, store.new_value, store.old_value)); - } - } - } - - let then_condition = then_branch.condition; - let else_condition = if let Some(branch) = else_branch { - branch.condition - } else { - self.inserter.function.dfg.make_constant(FieldElement::zero(), Type::bool()) - }; - let block = self.inserter.function.entry_block(); - - // Merging must occur in a separate loop as we cannot borrow `self` as mutable while `value_merger` does - let mut new_values = HashMap::default(); - for (address, (then_case, else_case, _)) in &new_map { - let instruction = Instruction::IfElse { - then_condition, - then_value: *then_case, - else_condition, - else_value: *else_case, - }; - let dfg = &mut self.inserter.function.dfg; - let value = dfg - .insert_instruction_and_results(instruction, block, None, call_stack.clone()) - .first(); - - new_values.insert(address, value); - } - - // Replace stores with new merged values - for (address, (_, _, old_value)) in &new_map { - let value = new_values[address]; - let address = *address; - self.insert_instruction_with_typevars( - Instruction::Store { address, value }, - None, - call_stack.clone(), - ); - - if let Some(store) = self.store_values.get_mut(&address) { - store.new_value = value; - } else { - self.store_values.insert( - address, - Store { - old_value: *old_value, - new_value: value, - call_stack: call_stack.clone(), - }, - ); - } - } - } - - fn remember_store(&mut self, address: ValueId, new_value: ValueId, call_stack: CallStack) { - if !self.local_allocations.contains(&address) { - if let Some(store_value) = self.store_values.get_mut(&address) { - store_value.new_value = new_value; - } else { - let load = Instruction::Load { address }; - - let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]); - let old_value = self - .insert_instruction_with_typevars(load.clone(), load_type, call_stack.clone()) - .first(); - - self.store_values.insert(address, Store { old_value, new_value, call_stack }); - } - } - } - /// Push the given instruction to the end of the entry block of the current function. /// /// Note that each ValueId of the instruction will be mapped via self.inserter.resolve. /// As a result, the instruction that will be pushed will actually be a new instruction /// with a different InstructionId from the original. The results of the given instruction /// will also be mapped to the results of the new instruction. - fn push_instruction(&mut self, id: InstructionId) -> Vec { + /// + /// `previous_allocate_result` should only be set to the result of an allocate instruction + /// if that instruction was the instruction immediately previous to this one - if there are + /// any instructions in between it should be None. + fn push_instruction( + &mut self, + id: InstructionId, + previous_allocate_result: &mut Option, + ) -> Vec { let (instruction, call_stack) = self.inserter.map_instruction(id); - let instruction = self.handle_instruction_side_effects(instruction, call_stack.clone()); - let is_allocate = matches!(instruction, Instruction::Allocate); + let instruction = self.handle_instruction_side_effects( + instruction, + call_stack.clone(), + *previous_allocate_result, + ); + let instruction_is_allocate = matches!(&instruction, Instruction::Allocate); let entry = self.inserter.function.entry_block(); let results = self.inserter.push_instruction_value(instruction, id, entry, call_stack); - // Remember an allocate was created local to this branch so that we do not try to merge store - // values across branches for it later. - if is_allocate { - self.local_allocations.insert(results.first()); - } - + *previous_allocate_result = instruction_is_allocate.then(|| results.first()); results.results().into_owned() } /// If we are currently in a branch, we need to modify constrain instructions /// to multiply them by the branch's condition (see optimization #1 in the module comment). + /// + /// `previous_allocate_result` should only be set to the result of an allocate instruction + /// if that instruction was the instruction immediately previous to this one - if there are + /// any instructions in between it should be None. fn handle_instruction_side_effects( &mut self, instruction: Instruction, call_stack: CallStack, + previous_allocate_result: Option, ) -> Instruction { if let Some(condition) = self.get_last_condition() { match instruction { @@ -779,8 +653,35 @@ impl<'f> Context<'f> { Instruction::Constrain(lhs, rhs, message) } Instruction::Store { address, value } => { - self.remember_store(address, value, call_stack); - Instruction::Store { address, value } + // If this instruction immediately follows an allocate, and stores to that + // address there is no previous value to load and we don't need a merge anyway. + if Some(address) == previous_allocate_result { + Instruction::Store { address, value } + } else { + // Instead of storing `value`, store `if condition { value } else { previous_value }` + let typ = self.inserter.function.dfg.type_of_value(value); + let load = Instruction::Load { address }; + let previous_value = self + .insert_instruction_with_typevars( + load, + Some(vec![typ]), + call_stack.clone(), + ) + .first(); + + let not = Instruction::Not(condition); + let else_condition = self.insert_instruction(not, call_stack.clone()); + + let instruction = Instruction::IfElse { + then_condition: condition, + then_value: value, + else_condition, + else_value: previous_value, + }; + + let updated_value = self.insert_instruction(instruction, call_stack); + Instruction::Store { address, value: updated_value } + } } Instruction::RangeCheck { value, max_bit_size, assert_message } => { // Replace value with `value * predicate` to zero out value when predicate is inactive. @@ -902,16 +803,6 @@ impl<'f> Context<'f> { call_stack, ) } - - fn undo_stores_in_then_branch(&mut self, store_values: &HashMap) { - for (address, store) in store_values { - let address = *address; - let value = store.old_value; - let instruction = Instruction::Store { address, value }; - // Considering the location of undoing a store to be the same as the original store. - self.insert_instruction_with_typevars(instruction, None, store.call_stack.clone()); - } - } } #[cfg(test)] @@ -958,11 +849,9 @@ mod test { v1 = not v0 enable_side_effects u1 1 v3 = cast v0 as Field - v4 = cast v1 as Field - v6 = mul v3, Field 3 - v8 = mul v4, Field 4 - v9 = add v6, v8 - return v9 + v5 = mul v3, Field -1 + v7 = add Field 4, v5 + return v7 } "; @@ -1022,16 +911,14 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - store Field 5 at v1 - v4 = not v0 - store v2 at v1 + v3 = not v0 + v4 = cast v0 as Field + v6 = sub Field 5, v2 + v7 = mul v4, v6 + v8 = add v2, v7 + store v8 at v1 + v9 = not v0 enable_side_effects u1 1 - v6 = cast v0 as Field - v7 = cast v4 as Field - v8 = mul v6, Field 5 - v9 = mul v7, v2 - v10 = add v8, v9 - store v10 at v1 return } "; @@ -1062,19 +949,21 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - store Field 5 at v1 - v4 = not v0 - store v2 at v1 - enable_side_effects v4 - v5 = load v1 -> Field - store Field 6 at v1 + v3 = not v0 + v4 = cast v0 as Field + v6 = sub Field 5, v2 + v7 = mul v4, v6 + v8 = add v2, v7 + store v8 at v1 + v9 = not v0 + enable_side_effects v9 + v10 = load v1 -> Field + v11 = cast v9 as Field + v13 = sub Field 6, v10 + v14 = mul v11, v13 + v15 = add v10, v14 + store v15 at v1 enable_side_effects u1 1 - v8 = cast v0 as Field - v9 = cast v4 as Field - v10 = mul v8, Field 5 - v11 = mul v9, Field 6 - v12 = add v10, v11 - store v12 at v1 return } "; @@ -1242,7 +1131,7 @@ mod test { }; let merged_values = get_all_constants_reachable_from_instruction(&main.dfg, ret); - assert_eq!(merged_values, vec![3, 5, 6]); + assert_eq!(merged_values, vec![1, 3, 5, 6]); } #[test] diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index 75ee57dd4fa..799378b1678 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -91,7 +91,7 @@ impl<'a> ValueMerger<'a> { dfg: &mut DataFlowGraph, block: BasicBlockId, then_condition: ValueId, - else_condition: ValueId, + _else_condition: ValueId, then_value: ValueId, else_value: ValueId, ) -> ValueId { @@ -114,31 +114,38 @@ impl<'a> ValueMerger<'a> { // We must cast the bool conditions to the actual numeric type used by each value. let then_condition = dfg .insert_instruction_and_results( - Instruction::Cast(then_condition, then_type), - block, - None, - call_stack.clone(), - ) - .first(); - let else_condition = dfg - .insert_instruction_and_results( - Instruction::Cast(else_condition, else_type), + Instruction::Cast(then_condition, Type::field()), block, None, call_stack.clone(), ) .first(); - let mul = Instruction::binary(BinaryOp::Mul, then_condition, then_value); - let then_value = - dfg.insert_instruction_and_results(mul, block, None, call_stack.clone()).first(); + let then_field = Instruction::Cast(then_value, Type::field()); + let then_field_value = + dfg.insert_instruction_and_results(then_field, block, None, call_stack.clone()).first(); + + let else_field = Instruction::Cast(else_value, Type::field()); + let else_field_value = + dfg.insert_instruction_and_results(else_field, block, None, call_stack.clone()).first(); + + let diff = Instruction::binary(BinaryOp::Sub, then_field_value, else_field_value); + let diff_value = + dfg.insert_instruction_and_results(diff, block, None, call_stack.clone()).first(); - let mul = Instruction::binary(BinaryOp::Mul, else_condition, else_value); - let else_value = - dfg.insert_instruction_and_results(mul, block, None, call_stack.clone()).first(); + let conditional_diff = Instruction::binary(BinaryOp::Mul, then_condition, diff_value); + let conditional_diff_value = dfg + .insert_instruction_and_results(conditional_diff, block, None, call_stack.clone()) + .first(); + + let merged_field = + Instruction::binary(BinaryOp::Add, else_field_value, conditional_diff_value); + let merged_field_value = dfg + .insert_instruction_and_results(merged_field, block, None, call_stack.clone()) + .first(); - let add = Instruction::binary(BinaryOp::Add, then_value, else_value); - dfg.insert_instruction_and_results(add, block, None, call_stack).first() + let merged = Instruction::Cast(merged_field_value, then_type); + dfg.insert_instruction_and_results(merged, block, None, call_stack).first() } /// Given an if expression that returns an array: `if c { array1 } else { array2 }`, diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index a052abc5e16..38d73e3dca8 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -428,11 +428,13 @@ impl<'f> PerFunctionContext<'f> { self.check_array_aliasing(references, value); + // FIXME: This causes errors in the sha256 tests + // // If there was another store to this instruction without any (unremoved) loads or // function calls in-between, we can remove the previous store. - if let Some(last_store) = references.last_stores.get(&address) { - self.instructions_to_remove.insert(*last_store); - } + // if let Some(last_store) = references.last_stores.get(&address) { + // self.instructions_to_remove.insert(*last_store); + // } if self.inserter.function.dfg.value_is_reference(value) { if let Some(expression) = references.expressions.get(&value) { @@ -908,16 +910,19 @@ mod tests { // We would need to track whether the store where `v9` is the store value gets removed to know whether // to remove it. assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); + // The first store in b1 is removed since there is another store to the same reference // in the same block, and the store is not needed before the later store. // The rest of the stores are also removed as no loads are done within any blocks // to the stored values. - assert_eq!(count_stores(b1, &main.dfg), 0); + // + // NOTE: This store is not removed due to the FIXME when handling Instruction::Store. + assert_eq!(count_stores(b1, &main.dfg), 1); let b1_instructions = main.dfg[b1].instructions(); - // We expect the last eq to be optimized out - assert_eq!(b1_instructions.len(), 0); + // We expect the last eq to be optimized out, only the store from above remains + assert_eq!(b1_instructions.len(), 1); } #[test]