Skip to content

Commit

Permalink
Merge pull request #886 from toposware/poseidon-native
Browse files Browse the repository at this point in the history
Add FFT-based specification for Poseidon MDS layer on x86 targets
  • Loading branch information
nbgl authored Mar 16, 2023
2 parents 1576a30 + bb2233c commit 2d7d2ac
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 1 deletion.
15 changes: 14 additions & 1 deletion field/src/field_testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,22 @@ macro_rules! test_field_arithmetic {

use num::bigint::BigUint;
use rand::rngs::OsRng;
use rand::Rng;
use rand::{Rng, RngCore};
use $crate::types::{Field, Sample};

/// Checks on random inputs that the 96-bit reduction agrees with the 128-bit reduction
/// when both are fed the same value (a 64-bit low limb plus a 32-bit high limb).
#[test]
fn modular_reduction() {
    let mut rng = OsRng;
    for _ in 0..10 {
        // Low 64 bits are drawn first, then the high 32 bits.
        let lo = rng.next_u64();
        let hi = rng.next_u32();
        // The limbs occupy disjoint bit ranges, so OR-ing them assembles the full value.
        let wide = (u128::from(hi) << 64) | u128::from(lo);
        let via_u128 = <$field>::from_noncanonical_u128(wide);
        let via_u96 = <$field>::from_noncanonical_u96((lo, hi));
        assert_eq!(via_u128, via_u96);
    }
}

#[test]
fn batch_inversion() {
for n in 0..20 {
Expand Down
13 changes: 13 additions & 0 deletions field/src/goldilocks_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl Field for GoldilocksField {
Self(n)
}

/// Builds a field element from a 96-bit integer supplied as `(low 64 bits, high 32 bits)`,
/// delegating the reduction to `reduce96`.
fn from_noncanonical_u96((n_lo, n_hi): (u64, u32)) -> Self {
    let limbs = (n_lo, n_hi);
    reduce96(limbs)
}

/// Builds a field element from a `u128` by delegating the reduction to `reduce128`.
fn from_noncanonical_u128(n: u128) -> Self {
reduce128(n)
}
Expand Down Expand Up @@ -337,6 +341,15 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
res_wrapped + EPSILON * (carry as u64)
}

/// Reduces a 96-bit integer, given as `(low 64 bits, high 32 bits)`, to a 64-bit value. The
/// result might not be in canonical form; it could be in between the field order and `2^64`.
#[inline]
fn reduce96((x_lo, x_hi): (u64, u32)) -> GoldilocksField {
    // Fold the high limb into the 64-bit range by scaling it with `EPSILON`.
    // NOTE(review): presumably `EPSILON` equals `2^64` modulo the field order -- confirm at its
    // definition.
    let hi_folded = u64::from(x_hi) * EPSILON;
    // SAFETY: NOTE(review): the callee is an `unsafe fn` whose contract is not visible from
    // here -- confirm that its requirements hold for these operands.
    GoldilocksField(unsafe { add_no_canonicalize_trashing_input(x_lo, hi_folded) })
}

/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the
/// field order and `2^64`.
#[inline]
Expand Down
172 changes: 172 additions & 0 deletions plonky2/src/hash/poseidon_goldilocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
//! `poseidon_constants.sage` script in the `mir-protocol/hash-constants`
//! repository.
use plonky2_field::types::Field;
use unroll::unroll_for_loops;

use crate::field::goldilocks_field::GoldilocksField;
use crate::hash::poseidon::{Poseidon, N_PARTIAL_ROUNDS};

Expand Down Expand Up @@ -211,6 +214,39 @@ impl Poseidon for GoldilocksField {
0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b, ],
];

/// MDS layer used on x86-64 targets: applies the MDS matrix to `state` by running the
/// FFT-based helper `mds_multiply_freq` on the 32-bit halves of every element.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
#[unroll_for_loops]
fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
    // The operation is linear, so every element is split into a low and a high 32-bit half;
    // each half-vector is transformed on its own with no overflow, and the two results are
    // recombined and reduced to field elements afterwards.
    let mut lo_halves = [0u64; 12];
    let mut hi_halves = [0u64; 12];

    for r in 0..12 {
        let word = state[r].0;
        lo_halves[r] = word & 0xffff_ffff;
        hi_halves[r] = word >> 32;
    }

    let hi_halves = mds_multiply_freq(hi_halves);
    let lo_halves = mds_multiply_freq(lo_halves);

    let mut result = [GoldilocksField::ZERO; 12];
    for r in 0..12 {
        // Recombine as `lo + hi * 2^32`, then hand the two limbs to the 96-bit reduction.
        let wide = ((hi_halves[r] as u128) << 32) + lo_halves[r] as u128;
        result[r] = GoldilocksField::from_noncanonical_u96((wide as u64, (wide >> 64) as u32));
    }

    // Add first element with the only non-zero diagonal matrix coefficient.
    let diag = (state[0].0 as u128) * (Self::MDS_MATRIX_DIAG[0] as u128);
    result[0] += GoldilocksField::from_noncanonical_u96((diag as u64, (diag >> 64) as u32));

    result
}

// #[cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))]
// #[inline]
// fn poseidon(input: [Self; 12]) -> [Self; 12] {
Expand Down Expand Up @@ -268,6 +304,142 @@ impl Poseidon for GoldilocksField {
}
}

// MDS layer helper methods
// The following code has been adapted from winterfell/crypto/src/hash/mds/mds_f64_12x12.rs
// located at https://github.com/facebook/winterfell.

/// Coefficients passed as the second operand of `block1` in `mds_multiply_freq`.
const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 32, 16];
/// Coefficients passed as the second operand of `block2` in `mds_multiply_freq`; each pair is
/// destructured there as `(real part, imaginary part)`.
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(2, -1), (-4, 1), (16, 1)];
/// Coefficients passed as the second operand of `block3` in `mds_multiply_freq`.
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-1, -8, 2];

/// Split 3 x 4 FFT-based MDS vector-multiplication with the Poseidon circulant MDS matrix.
#[inline(always)]
fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
    // Forward 4-point transforms, one per group of every third element (offsets 0, 1 and 2).
    let (a0, a1, a2) = fft4_real([state[0], state[3], state[6], state[9]]);
    let (b0, b1, b2) = fft4_real([state[1], state[4], state[7], state[10]]);
    let (c0, c1, c2) = fft4_real([state[2], state[5], state[8], state[11]]);

    // This is where the multiplication in the frequency domain is done. More precisely, and
    // with the appropriate permutations in between, the sequence of
    // 3-point FFTs --> multiplication by twiddle factors --> Hadamard multiplication -->
    // 3-point iFFTs --> multiplication by (inverse) twiddle factors
    // is "squashed" into one step composed of the functions "block1", "block2" and "block3".
    // The expressions in the aforementioned functions are the result of explicit computations
    // combined with the Karatsuba trick for the multiplication of complex numbers.
    let [p0, p1, p2] = block1([a0, b0, c0], MDS_FREQ_BLOCK_ONE);
    let [q0, q1, q2] = block2([a1, b1, c1], MDS_FREQ_BLOCK_TWO);
    let [r0, r1, r2] = block3([a2, b2, c2], MDS_FREQ_BLOCK_THREE);
    // The 4th block is not computed as it is similar to the 2nd one, up to complex conjugation.

    // Inverse 4-point transforms; outputs are interleaved back into the original positions.
    let [o0, o3, o6, o9] = ifft4_real_unreduced((p0, q0, r0));
    let [o1, o4, o7, o10] = ifft4_real_unreduced((p1, q1, r1));
    let [o2, o5, o8, o11] = ifft4_real_unreduced((p2, q2, r2));

    [o0, o1, o2, o3, o4, o5, o6, o7, o8, o9, o10, o11]
}

/// Real 2-FFT over u64 integers: returns `[x[0] + x[1], x[0] - x[1]]`, with both inputs
/// reinterpreted as `i64` before the arithmetic.
#[inline(always)]
fn fft2_real(x: [u64; 2]) -> [i64; 2] {
    let [a, b] = x.map(|v| v as i64);
    [a + b, a - b]
}

/// Real 2-iFFT over u64 integers: returns `[y[0] + y[1], y[0] - y[1]]` cast to `u64`.
/// Division by two to complete the inverse FFT is not performed here.
#[inline(always)]
fn ifft2_real_unreduced(y: [i64; 2]) -> [u64; 2] {
    let [a, b] = y;
    let sum = a + b;
    let diff = a - b;
    [sum as u64, diff as u64]
}

/// Real 4-FFT over u64 integers.
///
/// Returns `(y0, y1, y2)` where `y0` and `y2` are single integers and `y1` is a pair that
/// the callers treat as `(real part, imaginary part)`.
#[inline(always)]
fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
    let [x0, x1, x2, x3] = x;
    // 2-point transforms of the even-indexed and odd-indexed inputs.
    let [even_sum, even_diff] = fft2_real([x0, x2]);
    let [odd_sum, odd_diff] = fft2_real([x1, x3]);
    (
        even_sum + odd_sum,
        (even_diff, -odd_diff),
        even_sum - odd_sum,
    )
}

/// Real 4-iFFT over u64 integers, taking the `(y0, y1, y2)` triple shape produced by
/// `fft4_real`.
/// Division by four to complete the inverse FFT is not performed here.
#[inline(always)]
fn ifft4_real_unreduced(y: (i64, (i64, i64), i64)) -> [u64; 4] {
    let (first, (mid_re, mid_im), last) = y;

    // Two 2-point inverse transforms; their outputs land on even and odd positions.
    let [out0, out2] = ifft2_real_unreduced([first + last, mid_re]);
    let [out1, out3] = ifft2_real_unreduced([first - last, -mid_im]);

    [out0, out1, out2, out3]
}

/// Computes the 3-term cyclic convolution of `x` with `y`:
/// output `k` sums the products `x[i] * y[j]` with `i + j == k (mod 3)`.
#[inline(always)]
fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
    let [a, b, c] = x;
    let [k0, k1, k2] = y;
    // Weighted sum of the three inputs against a chosen ordering of the coefficients.
    let dot = |p: i64, q: i64, r: i64| a * p + b * q + c * r;

    [dot(k0, k2, k1), dot(k1, k0, k2), dot(k2, k1, k0)]
}

/// Combines three complex inputs `x` with three complex coefficients `y`, every value being
/// a `(real, imaginary)` pair of `i64`. The outputs are
///   z0 = x0*y0 - i*x1*y2 - i*x2*y1
///   z1 = x0*y1 +   x1*y0 - i*x2*y2
///   z2 = x0*y2 +   x1*y1 +   x2*y0
/// with each complex product evaluated via Karatsuba (three real multiplications).
#[inline(always)]
fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
    let [x0, x1, x2] = x;
    let [y0, y1, y2] = y;

    // Karatsuba complex product `a * b`.
    let mul = |(ar, ai): (i64, i64), (br, bi): (i64, i64)| -> (i64, i64) {
        let rr = ar * br;
        let ii = ai * bi;
        (rr - ii, (ar + ai) * (br + bi) - rr - ii)
    };
    // Karatsuba complex product `-i * a * b`: the plain product with its parts swapped and
    // the new imaginary part negated.
    let mul_neg_i = |(ar, ai): (i64, i64), (br, bi): (i64, i64)| -> (i64, i64) {
        let rr = ar * br;
        let ii = ai * bi;
        ((ar + ai) * (br + bi) - rr - ii, -rr + ii)
    };
    // Component-wise sum of three complex values.
    let add3 = |p: (i64, i64), q: (i64, i64), r: (i64, i64)| -> (i64, i64) {
        (p.0 + q.0 + r.0, p.1 + q.1 + r.1)
    };

    [
        add3(mul(x0, y0), mul_neg_i(x1, y2), mul_neg_i(x2, y1)),
        add3(mul(x0, y1), mul(x1, y0), mul_neg_i(x2, y2)),
        add3(mul(x0, y2), mul(x1, y1), mul(x2, y0)),
    ]
}

/// Computes the 3-term negacyclic convolution of `x` with `y`: products `x[i] * y[j]` with
/// `i + j == k` are added to output `k`, while those with `i + j == k + 3` are subtracted.
#[inline(always)]
fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
    let [a, b, c] = x;
    let [k0, k1, k2] = y;

    [
        a * k0 - b * k2 - c * k1,
        a * k1 + b * k0 - c * k2,
        a * k2 + b * k1 + c * k0,
    ]
}

#[cfg(test)]
mod tests {
use crate::field::goldilocks_field::GoldilocksField as F;
Expand Down

0 comments on commit 2d7d2ac

Please sign in to comment.