Merge pull request #67 from worldcoin/ewoolsey/multithread
Multithreading Dense Merkle Tree Initialization
0xForerunner authored Mar 13, 2024
2 parents 0af6b37 + d62e043 commit 4f676a7
Showing 8 changed files with 90 additions and 55 deletions.
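In essence: dense-tree initialization previously filled the storage vector and then hashed all internal nodes in one sequential loop; after this change, the flat storage is assembled up front and each layer of the tree is hashed in parallel with rayon, using `split_at_mut` to write a parent layer while reading the child layer below it. A minimal, self-contained sketch of that strategy follows, with a toy XOR-based stand-in for the real `Hasher` (the toy hasher and the `main` harness are illustrative assumptions, not code from this PR):

```rust
use rayon::prelude::*; // assumes `rayon` is a dependency in Cargo.toml

/// Toy stand-in for `H::hash_node`; the real tree uses a cryptographic hash.
fn hash_node(left: &u64, right: &u64) -> u64 {
    left.rotate_left(1) ^ right
}

/// Builds the flat, 1-indexed heap storage of a dense tree of `depth`,
/// hashing one layer at a time in parallel (mirrors `vec_from_values`).
fn vec_from_values(values: &[u64], empty_value: &u64, depth: usize) -> Vec<u64> {
    let leaf_count = 1 << depth;
    assert!(values.len() <= leaf_count);

    // Indices [0, leaf_count) are placeholders for the internal nodes
    // (index 0 is unused); leaves occupy [leaf_count, 2 * leaf_count).
    let mut storage = vec![*empty_value; leaf_count];
    storage.extend_from_slice(values);
    storage.resize(2 * leaf_count, *empty_value);

    // Hash bottom-up. `split_at_mut` lets the parent layer be written
    // while the child layer below it is read concurrently.
    for current_depth in (1..=depth).rev() {
        let (top, child_layer) = storage.split_at_mut(1 << current_depth);
        let parent_layer = &mut top[(1 << (current_depth - 1))..];

        parent_layer
            .par_iter_mut()
            .enumerate()
            .for_each(|(i, value)| {
                *value = hash_node(&child_layer[2 * i], &child_layer[2 * i + 1]);
            });
    }

    storage
}

fn main() {
    let storage = vec_from_values(&[1, 2, 3], &0, 3);
    println!("root = {}", storage[1]); // node 1 is the root
}
```

Each level halves the number of nodes, so the work is dominated by the wide bottom layers, which parallelize well; the `Send + Sync` bounds added to `Hasher::Hash` below are what let rayon share hashes across worker threads.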
2 changes: 2 additions & 0 deletions crates/semaphore-depth-macros/src/lib.rs
@@ -27,11 +27,13 @@ use syn::{
/// assert!(depth > 0);
/// }
///
/// # #[no_run]
/// #[test]
/// fn test_depth_non_zero_depth_16() {
/// test_depth_non_zero(16);
/// }
///
/// # #[no_run]
/// #[test]
/// fn test_depth_non_zero_depth_30() {
/// test_depth_non_zero(30);
2 changes: 2 additions & 0 deletions rust-toolchain.toml
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly-2024-02-05"
97 changes: 48 additions & 49 deletions src/lazy_merkle_tree.rs
@@ -1,4 +1,7 @@
use crate::merkle_tree::{Branch, Hasher, Proof};
use crate::{
merkle_tree::{Branch, Hasher, Proof},
util::as_bytes,
};
use std::{
fs::OpenOptions,
io::Write,
@@ -10,6 +13,7 @@ use std::{
};

use mmap_rs::{MmapMut, MmapOptions};
use rayon::prelude::*;
use thiserror::Error;

pub trait VersionMarker {}
@@ -82,7 +86,6 @@ impl<H: Hasher, Version: VersionMarker> LazyMerkleTree<H, Version> {

/// Creates a new memory mapped file specified by path and creates a tree
/// with dense prefix of the given depth with initial values
#[must_use]
pub fn new_mmapped_with_dense_prefix_with_init_values(
depth: usize,
prefix_depth: usize,
@@ -685,18 +688,40 @@ impl<H: Hasher> Clone for DenseTree<H> {
}

impl<H: Hasher> DenseTree<H> {
fn new_with_values(values: &[H::Hash], empty_leaf: &H::Hash, depth: usize) -> Self {
fn vec_from_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Vec<H::Hash> {
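// Storage is a 1-indexed binary heap: node `i` has children `2 * i` and
// `2 * i + 1`; index 0 is unused and leaves occupy [1 << depth, 1 << (depth + 1)).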
let leaf_count = 1 << depth;
let first_leaf_index = 1 << depth;
let storage_size = 1 << (depth + 1);
assert!(values.len() <= leaf_count);
let mut storage = vec![empty_leaf.clone(); storage_size];
storage[first_leaf_index..(first_leaf_index + values.len())].clone_from_slice(values);
for i in (1..first_leaf_index).rev() {
let left = &storage[2 * i];
let right = &storage[2 * i + 1];
storage[i] = H::hash_node(left, right);
let mut storage = Vec::with_capacity(storage_size);

let empties = repeat(empty_value.clone()).take(leaf_count);
storage.extend(empties);
storage.extend_from_slice(values);
if values.len() < leaf_count {
let empties = repeat(empty_value.clone()).take(leaf_count - values.len());
storage.extend(empties);
}

// Hash the tree layer by layer, bottom-up: each parent layer is computed
// from the child layer directly below it.
for current_depth in (1..=depth).rev() {
let (top, child_layer) = storage.split_at_mut(1 << current_depth);
let parent_layer = &mut top[(1 << (current_depth - 1))..];
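// `split_at_mut` divides the storage so the parent layer can be borrowed
// mutably while the child layer below it is read in parallel.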

parent_layer
.par_iter_mut()
.enumerate()
.for_each(|(i, value)| {
let left = &child_layer[2 * i];
let right = &child_layer[2 * i + 1];
*value = H::hash_node(left, right);
});
}

storage
}

fn new_with_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Self {
let storage = Self::vec_from_values(values, empty_value, depth);

Self {
depth,
root_index: 1,
@@ -881,7 +906,7 @@ impl<H: Hasher> DenseMMapTree<H> {
/// - returns Err if mmap creation fails
fn new_with_values(
values: &[H::Hash],
empty_leaf: &H::Hash,
empty_value: &H::Hash,
depth: usize,
mmap_file_path: &str,
) -> Result<Self, DenseMMapError> {
@@ -890,19 +915,9 @@
Err(_e) => return Err(DenseMMapError::FailedToCreatePathBuf),
};

let leaf_count = 1 << depth;
let first_leaf_index = 1 << depth;
let storage_size = 1 << (depth + 1);
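// Storage is now assembled in memory first (hashing layers in parallel),
// then persisted to the memory-mapped file in a single write below.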
let storage = DenseTree::<H>::vec_from_values(values, empty_value, depth);

assert!(values.len() <= leaf_count);

let mut mmap = MmapMutWrapper::new_with_initial_values(path_buf, empty_leaf, storage_size)?;
mmap[first_leaf_index..(first_leaf_index + values.len())].clone_from_slice(values);
for i in (1..first_leaf_index).rev() {
let left = &mmap[2 * i];
let right = &mmap[2 * i + 1];
mmap[i] = H::hash_node(left, right);
}
let mmap = MmapMutWrapper::new_from_storage(path_buf, &storage)?;

Ok(Self {
depth,
@@ -1113,29 +1128,14 @@ impl<H: Hasher> MmapMutWrapper<H> {
/// - file size cannot be set
/// - file is too large, possible truncation can occur
/// - cannot build memory map
pub fn new_with_initial_values(
pub fn new_from_storage(
file_path: PathBuf,
initial_value: &H::Hash,
storage_size: usize,
storage: &[H::Hash],
) -> Result<Self, DenseMMapError> {
let size_of_val = std::mem::size_of_val(initial_value);
let initial_vals: Vec<H::Hash> = vec![initial_value.clone(); storage_size];

// cast Hash pointer to u8 pointer
let ptr = initial_vals.as_ptr().cast::<u8>();

let size_of_buffer: usize = storage_size * size_of_val;

let buf: &[u8] = unsafe {
// moving pointer by u8 for storage_size * size of hash would get us the full
// buffer
std::slice::from_raw_parts(ptr, size_of_buffer)
};

// assure that buffer is correct length
assert_eq!(buf.len(), size_of_buffer);

let file_size: u64 = storage_size as u64 * size_of_val as u64;
// Safety: potential uninitialized padding from `H::Hash` is safe to use if
// we're casting back to the same type.
let buf = unsafe { as_bytes(storage) };
let buf_len = buf.len();

let mut file = match OpenOptions::new()
.read(true)
@@ -1148,13 +1148,13 @@ impl<H: Hasher> MmapMutWrapper<H> {
Err(_e) => return Err(DenseMMapError::FileCreationFailed),
};

file.set_len(file_size).expect("cannot set file size");
file.set_len(buf_len as u64).expect("cannot set file size");
if file.write_all(buf).is_err() {
return Err(DenseMMapError::FileCannotWriteBytes);
}

let mmap = unsafe {
MmapOptions::new(usize::try_from(file_size).expect("file size truncated"))
MmapOptions::new(usize::try_from(buf_len as u64).expect("file size truncated"))
.expect("cannot create memory map")
.with_file(file, 0)
.map_mut()
@@ -1735,8 +1735,7 @@ pub mod bench {
let prefix_depth = depth;
let empty_value = Field::from(0);

let initial_values: Vec<ruint::Uint<256, 4>> =
(0..(1 << depth)).map(|value| Field::from(value)).collect();
let initial_values: Vec<ruint::Uint<256, 4>> = (0..(1 << depth)).map(Field::from).collect();

TreeValues {
depth,
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -1,5 +1,5 @@
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic, clippy::cargo, clippy::nursery)]
#![warn(clippy::all, clippy::cargo)]
// TODO: ark-circom and ethers-core pull in a lot of dependencies, some duplicate.
#![allow(clippy::multiple_crate_versions)]

2 changes: 1 addition & 1 deletion src/merkle_tree.rs
@@ -14,7 +14,7 @@ use std::{
/// Hash types, values and algorithms for a Merkle tree
pub trait Hasher {
/// Type of the leaf and node hashes
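/// (`Send` and `Sync` are required so nodes can be hashed across rayon's
/// worker threads during parallel tree initialization)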
type Hash: Clone + Eq + Serialize + Debug;
type Hash: Clone + Eq + Serialize + Debug + Send + Sync;

/// Compute the hash of an intermediate node
fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash;
6 changes: 3 additions & 3 deletions src/protocol/authentication.rs
@@ -14,7 +14,7 @@ pub fn generate_proof(
let merkle_proof = LazyPoseidonTree::new(depth, Field::from(0))
.update(0, &identity.commitment())
.proof(0);
return super::generate_proof(identity, &merkle_proof, ext_nullifier_hash, signal_hash);
super::generate_proof(identity, &merkle_proof, ext_nullifier_hash, signal_hash)
}

pub fn verify_proof(
@@ -28,12 +28,12 @@ pub fn verify_proof(
let root = LazyPoseidonTree::new(depth, Field::from(0))
.update(0, &id_commitment)
.root();
return super::verify_proof(
super::verify_proof(
root,
nullifier_hash,
signal_hash,
ext_nullifier_hash,
proof,
depth,
);
)
}
2 changes: 1 addition & 1 deletion src/protocol/mod.rs
@@ -269,7 +269,7 @@ mod test {
#[test_all_depths]
fn test_proof_serialize(depth: usize) {
let proof = arb_proof(456, depth);
let json = serde_json::to_value(&proof).unwrap();
let json = serde_json::to_value(proof).unwrap();
let valid_values = match depth {
16 => json!([
[
32 changes: 32 additions & 0 deletions src/util.rs
@@ -16,6 +16,18 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] {
output
}

/// Converts a reference to any type into a slice of bytes.
///
/// # Safety
/// This function is unsafe because it reinterprets a reference as a slice of
/// bytes. If `T` is not densely packed, the padding bytes will be
/// uninitialized.
pub(crate) unsafe fn as_bytes<T: ?Sized>(p: &T) -> &[u8] {
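// `size_of_val` takes the value's size at runtime, so this also works for
// unsized `T` such as slices.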
::core::slice::from_raw_parts(
(p as *const T).cast::<u8>(),
::core::mem::size_of_val::<T>(p),
)
}

pub(crate) fn bytes_to_hex<const N: usize, const M: usize>(bytes: &[u8; N]) -> [u8; M] {
// TODO: Replace `M` with a const expression once it's stable.
debug_assert_eq!(M, 2 * N + 2);
@@ -113,6 +125,26 @@ pub(crate) fn deserialize_bytes<'de, const N: usize, D: Deserializer<'de>>(
mod test {
use super::*;

#[test]
fn test_as_bytes() {
// test byte array
let bytes = [0, 1, 2, 3u8];
let converted: [u8; 4] = unsafe { as_bytes(&bytes).try_into().unwrap() };
assert_eq!(bytes, converted);

// test u64 array
let array = [1u64, 1];
let converted: [u8; 16] = unsafe { as_bytes(&array).try_into().unwrap() };
let expected = [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0];
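// note: the expected byte order assumes a little-endian target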
assert_eq!(converted, expected);

// test u64
let value = 1u64;
let converted: [u8; 8] = unsafe { as_bytes(&value).try_into().unwrap() };
let expected = [1, 0, 0, 0, 0, 0, 0, 0];
assert_eq!(converted, expected);
}

#[test]
fn test_serialize_bytes_hex() {
let bytes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
