Skip to content

Commit

Permalink
Implement multi_scalar_mul
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh committed Nov 21, 2024
1 parent efe9f27 commit d89d68d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,12 @@ pub(crate) fn to_byte_array(values: &[u8]) -> Value {
pub(crate) fn to_byte_slice(values: &[u8]) -> Value {
Value::Slice(values.iter().copied().map(Value::U8).collect(), byte_slice_type())
}

/// Create a `Value::Array` from fields.
pub(crate) fn to_field_array(values: &[FieldElement]) -> Value {
let typ = Type::Array(
Box::new(Type::Constant(values.len().into(), Kind::u32())),
Box::new(Type::FieldElement),
);
Value::Array(values.iter().copied().map(Value::Field).collect(), typ)
}
61 changes: 48 additions & 13 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ use crate::{
Value,
},
node_interner::NodeInterner,
Kind, Type,
Type,
};

use super::{
builtin::builtin_helpers::{
check_arguments, check_one_argument, check_three_arguments, check_two_arguments,
get_array_map, get_bool, get_field, get_fixed_array_map, get_slice_map, get_struct_field,
get_struct_fields, get_u32, get_u64, get_u8, to_byte_slice,
get_struct_fields, get_u32, get_u64, get_u8, to_byte_slice, to_field_array,
},
Interpreter,
};
Expand Down Expand Up @@ -69,6 +69,7 @@ fn call_foreign(
acvm::blackbox_solver::ecdsa_secp256r1_verify,
),
"embedded_curve_add" => embedded_curve_add(args, location),
"multi_scalar_mul" => multi_scalar_mul(interner, args, location),
"poseidon2_permutation" => poseidon2_permutation(interner, args, location),
"keccakf1600" => keccakf1600(interner, args, location),
"range" => apply_range_constraint(args, location),
Expand Down Expand Up @@ -254,7 +255,7 @@ fn ecdsa_secp256_verify(
Ok(Value::Bool(is_valid))
}

/// ```
/// ```text
/// fn embedded_curve_add(
/// point1: EmbeddedCurvePoint,
/// point2: EmbeddedCurvePoint,
Expand All @@ -266,20 +267,42 @@ fn embedded_curve_add(arguments: Vec<(Value, Location)>, location: Location) ->
let (p1x, p1y, p1inf) = get_embedded_curve_point(point1)?;
let (p2x, p2y, p2inf) = get_embedded_curve_point(point2)?;

let p1 = [p1x, p1y, p1inf.into()];
let p2 = [p2x, p2y, p2inf.into()];

let (x, y, inf) = bn254_blackbox_solver::embedded_curve_add(p1, p2)
let (x, y, inf) = Bn254BlackBoxSolver
.ec_add(&p1x, &p1y, &p1inf.into(), &p2x, &p2y, &p2inf.into())
.map_err(|e| InterpreterError::BlackBoxError(e, location))?;

let values = Vector::from_iter([x, y, inf].map(Value::Field));
Ok(to_field_array(&[x, y, inf]))
}

/// ```text
/// pub fn multi_scalar_mul<let N: u32>(
/// points: [EmbeddedCurvePoint; N],
/// scalars: [EmbeddedCurveScalar; N],
/// ) -> [Field; 3]
/// ```
fn multi_scalar_mul(
interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let (points, scalars) = check_two_arguments(arguments, location)?;

let (points, _) = get_array_map(interner, points, get_embedded_curve_point)?;
let (scalars, _) = get_array_map(interner, scalars, get_embedded_curve_scalar)?;

let points: Vec<_> = points.into_iter().flat_map(|(x, y, inf)| [x, y, inf.into()]).collect();
let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (lo, hi) in scalars {
scalars_lo.push(lo);
scalars_hi.push(hi);
}

let return_type = Type::Array(
Box::new(Type::Constant(values.len().into(), Kind::u32())),
Box::new(Type::FieldElement),
);
let (x, y, inf) = Bn254BlackBoxSolver
.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
.map_err(|e| InterpreterError::BlackBoxError(e, location))?;

Ok(Value::Array(values, return_type))
Ok(to_field_array(&[x, y, inf]))
}

/// `poseidon2_permutation<let N: u32>(_input: [Field; N], _state_length: u32) -> [Field; N]`
Expand Down Expand Up @@ -348,6 +371,18 @@ fn get_embedded_curve_point(
Ok((x, y, is_infinite))
}

/// Decode an `EmbeddedCurveScalar` struct.
///
/// Returns `(hi, lo)`.
fn get_embedded_curve_scalar(
(value, location): (Value, Location),
) -> IResult<(FieldElement, FieldElement)> {
let (fields, typ) = get_struct_fields("EmbeddedCurveScalar", (value, location))?;
let lo = get_struct_field("lo", &fields, &typ, location, get_field)?;
let hi = get_struct_field("hi", &fields, &typ, location, get_field)?;
Ok((lo, hi))
}

#[cfg(test)]
mod tests {
use acvm::acir::BlackBoxFunc;
Expand Down

0 comments on commit d89d68d

Please sign in to comment.