Skip to content

Commit

Permalink
fix: Discard optimisation that would change execution ordering or tha…
Browse files Browse the repository at this point in the history
…t is related to call outputs (#6461)

Co-authored-by: Tom French <[email protected]>
  • Loading branch information
guipublic and TomAFrench authored Nov 6, 2024
1 parent 9ef8369 commit b8654f7
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 18 deletions.
194 changes: 176 additions & 18 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};

use acir::{
circuit::{brillig::BrilligInputs, directives::Directive, opcodes::BlockId, Circuit, Opcode},
circuit::{
brillig::{BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::BlockId,
Circuit, Opcode,
},
native_types::{Expression, Witness},
AcirField,
};
Expand Down Expand Up @@ -72,23 +77,31 @@ impl MergeExpressionsOptimizer {
if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) =
(opcode.clone(), second_gate)
{
if let Some(expr) = Self::merge(&expr_use, &expr_define, w) {
// sanity check
assert!(i < b);
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(&expr_define) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(b);
v.remove(&i);
used_witness.insert(w2, v);
// We cannot merge an expression into an earlier opcode, because this
// would break the 'execution ordering' of the opcodes
// This case can happen because a previous merge would change an opcode
// and eliminate a witness from it, giving new opportunities for this
// witness to be used in only two expressions
// TODO: the missed optimization for the i>b case can be handled by
// - doing this pass again until there is no change, or
// - merging 'b' into 'i' instead
if i < b {
if let Some(expr) = Self::merge(&expr_use, &expr_define, w) {
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(&expr_define) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(b);
v.remove(&i);
used_witness.insert(w2, v);
}
}
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
// We need to stop here and continue with the next opcode
// because the merge invalidate the current opcode
break;
}
}
}
Expand Down Expand Up @@ -125,6 +138,19 @@ impl MergeExpressionsOptimizer {
result
}

fn brillig_output_wit(&self, output: &BrilligOutputs) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
match output {
BrilligOutputs::Simple(witness) => {
result.insert(*witness);
}
BrilligOutputs::Array(witnesses) => {
result.extend(witnesses);
}
}
result
}

// Returns the input witnesses used by the opcode
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
let mut witnesses = BTreeSet::new();
Expand All @@ -146,16 +172,22 @@ impl MergeExpressionsOptimizer {
Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
init.iter().cloned().collect()
}
Opcode::BrilligCall { inputs, .. } => {
Opcode::BrilligCall { inputs, outputs, .. } => {
for i in inputs {
witnesses.extend(self.brillig_input_wit(i));
}
for i in outputs {
witnesses.extend(self.brillig_output_wit(i));
}
witnesses
}
Opcode::Call { id: _, inputs, outputs: _, predicate } => {
Opcode::Call { id: _, inputs, outputs, predicate } => {
for i in inputs {
witnesses.insert(*i);
}
for i in outputs {
witnesses.insert(*i);
}
if let Some(p) = predicate {
witnesses.extend(CircuitSimulator::expr_wit(p));
}
Expand Down Expand Up @@ -195,3 +227,129 @@ impl MergeExpressionsOptimizer {
None
}
}

#[cfg(test)]
mod tests {
use crate::compiler::{optimizers::MergeExpressionsOptimizer, CircuitSimulator};
use acir::{
acir_field::AcirField,
circuit::{
brillig::{BrilligFunctionId, BrilligOutputs},
opcodes::FunctionInput,
Circuit, ExpressionWidth, Opcode, PublicInputs,
},
native_types::{Expression, Witness},
FieldElement,
};
use std::collections::BTreeSet;

fn check_circuit(circuit: Circuit<FieldElement>) {
assert!(CircuitSimulator::default().check_circuit(&circuit));
let mut merge_optimizer = MergeExpressionsOptimizer::new();
let acir_opcode_positions = vec![0; 20];
let (opcodes, _) =
merge_optimizer.eliminate_intermediate_variable(&circuit, acir_opcode_positions);
let mut optimized_circuit = circuit;
optimized_circuit.opcodes = opcodes;
// check that the circuit is still valid after optimization
assert!(CircuitSimulator::default().check_circuit(&optimized_circuit));
}

#[test]
fn does_not_eliminate_witnesses_returned_from_brillig() {
let opcodes = vec![
Opcode::BrilligCall {
id: BrilligFunctionId::default(),
inputs: Vec::new(),
outputs: vec![BrilligOutputs::Simple(Witness(1))],
predicate: None,
},
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::from(2_u128), Witness(0)),
(FieldElement::from(3_u128), Witness(1)),
(FieldElement::from(1_u128), Witness(2)),
],
q_c: FieldElement::one(),
}),
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::from(2_u128), Witness(0)),
(FieldElement::from(2_u128), Witness(1)),
(FieldElement::from(1_u128), Witness(5)),
],
q_c: FieldElement::one(),
}),
];

let mut private_parameters = BTreeSet::new();
private_parameters.insert(Witness(0));

let circuit = Circuit {
current_witness_index: 1,
expression_width: ExpressionWidth::Bounded { width: 4 },
opcodes,
private_parameters,
public_parameters: PublicInputs::default(),
return_values: PublicInputs::default(),
assert_messages: Default::default(),
recursive: false,
};
check_circuit(circuit);
}

#[test]
fn does_not_attempt_to_merge_into_previous_opcodes() {
let opcodes = vec![
Opcode::AssertZero(Expression {
mul_terms: vec![(FieldElement::one(), Witness(0), Witness(0))],
linear_combinations: vec![(-FieldElement::one(), Witness(4))],
q_c: FieldElement::zero(),
}),
Opcode::AssertZero(Expression {
mul_terms: vec![(FieldElement::one(), Witness(0), Witness(1))],
linear_combinations: vec![(FieldElement::one(), Witness(5))],
q_c: FieldElement::zero(),
}),
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(-FieldElement::one(), Witness(2)),
(FieldElement::one(), Witness(4)),
(FieldElement::one(), Witness(5)),
],
q_c: FieldElement::zero(),
}),
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::one(), Witness(2)),
(-FieldElement::one(), Witness(3)),
(FieldElement::one(), Witness(4)),
(FieldElement::one(), Witness(5)),
],
q_c: FieldElement::zero(),
}),
Opcode::BlackBoxFuncCall(acir::circuit::opcodes::BlackBoxFuncCall::RANGE {
input: FunctionInput::witness(Witness(3), 32),
}),
];

let mut private_parameters = BTreeSet::new();
private_parameters.insert(Witness(0));
private_parameters.insert(Witness(1));
let circuit = Circuit {
current_witness_index: 5,
expression_width: ExpressionWidth::Bounded { width: 4 },
opcodes,
private_parameters,
public_parameters: PublicInputs::default(),
return_values: PublicInputs::default(),
assert_messages: Default::default(),
recursive: false,
};
check_circuit(circuit);
}
}
5 changes: 5 additions & 0 deletions test_programs/execution_success/regression_6451/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
name = "regression_6451"
type = "bin"
authors = [""]
[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = 0
23 changes: 23 additions & 0 deletions test_programs/execution_success/regression_6451/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
fn main(x: Field) {
// Regression test for #6451
let y = unsafe { empty(x) };
let mut value = 0;
let term1 = x * x - x * y;
std::as_witness(term1);
value += term1;
let term2 = x * x - y * x;
value += term2;
value.assert_max_bit_size::<1>();

// Regression test for Aztec Packages issue #6451
let y = unsafe { empty(x + 1) };
let z = y + x + 1;
let z1 = z + y;
assert(z + z1 != 3);
let w = y + 2 * x + 3;
assert(w + z1 != z);
}

unconstrained fn empty(_: Field) -> Field {
0
}

0 comments on commit b8654f7

Please sign in to comment.