diff --git a/crates/semaphore-depth-macros/src/lib.rs b/crates/semaphore-depth-macros/src/lib.rs index fa2d9e3..ddb388c 100644 --- a/crates/semaphore-depth-macros/src/lib.rs +++ b/crates/semaphore-depth-macros/src/lib.rs @@ -27,11 +27,13 @@ use syn::{ /// assert!(depth > 0); /// } /// +/// # #[no_run] /// #[test] /// fn test_depth_non_zero_depth_16() { /// test_depth_non_zero(16); /// } /// +/// # #[no_run] /// #[test] /// fn test_depth_non_zero_depth_30() { /// test_depth_non_zero(30); diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..940f6ae --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2024-02-05" diff --git a/src/lazy_merkle_tree.rs b/src/lazy_merkle_tree.rs index 98c4a74..3147316 100644 --- a/src/lazy_merkle_tree.rs +++ b/src/lazy_merkle_tree.rs @@ -1,4 +1,7 @@ -use crate::merkle_tree::{Branch, Hasher, Proof}; +use crate::{ + merkle_tree::{Branch, Hasher, Proof}, + util::as_bytes, +}; use std::{ fs::OpenOptions, io::Write, @@ -10,6 +13,7 @@ use std::{ }; use mmap_rs::{MmapMut, MmapOptions}; +use rayon::prelude::*; use thiserror::Error; pub trait VersionMarker {} @@ -82,7 +86,6 @@ impl LazyMerkleTree { /// Creates a new memory mapped file specified by path and creates a tree /// with dense prefix of the given depth with initial values - #[must_use] pub fn new_mmapped_with_dense_prefix_with_init_values( depth: usize, prefix_depth: usize, @@ -685,18 +688,40 @@ impl Clone for DenseTree { } impl DenseTree { - fn new_with_values(values: &[H::Hash], empty_leaf: &H::Hash, depth: usize) -> Self { + fn vec_from_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Vec { let leaf_count = 1 << depth; - let first_leaf_index = 1 << depth; let storage_size = 1 << (depth + 1); - assert!(values.len() <= leaf_count); - let mut storage = vec![empty_leaf.clone(); storage_size]; - storage[first_leaf_index..(first_leaf_index + values.len())].clone_from_slice(values); - for i in 
(1..first_leaf_index).rev() { - let left = &storage[2 * i]; - let right = &storage[2 * i + 1]; - storage[i] = H::hash_node(left, right); + let mut storage = Vec::with_capacity(storage_size); + + let empties = repeat(empty_value.clone()).take(leaf_count); + storage.extend(empties); + storage.extend_from_slice(values); + if values.len() < leaf_count { + let empties = repeat(empty_value.clone()).take(leaf_count - values.len()); + storage.extend(empties); + } + + // We iterate over mutable layers of the tree + for current_depth in (1..=depth).rev() { + let (top, child_layer) = storage.split_at_mut(1 << current_depth); + let parent_layer = &mut top[(1 << (current_depth - 1))..]; + + parent_layer + .par_iter_mut() + .enumerate() + .for_each(|(i, value)| { + let left = &child_layer[2 * i]; + let right = &child_layer[2 * i + 1]; + *value = H::hash_node(left, right); + }); } + + storage + } + + fn new_with_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Self { + let storage = Self::vec_from_values(values, empty_value, depth); + Self { depth, root_index: 1, @@ -881,7 +906,7 @@ impl DenseMMapTree { /// - returns Err if mmap creation fails fn new_with_values( values: &[H::Hash], - empty_leaf: &H::Hash, + empty_value: &H::Hash, depth: usize, mmap_file_path: &str, ) -> Result { @@ -890,19 +915,9 @@ impl DenseMMapTree { Err(_e) => return Err(DenseMMapError::FailedToCreatePathBuf), }; - let leaf_count = 1 << depth; - let first_leaf_index = 1 << depth; - let storage_size = 1 << (depth + 1); + let storage = DenseTree::::vec_from_values(values, empty_value, depth); - assert!(values.len() <= leaf_count); - - let mut mmap = MmapMutWrapper::new_with_initial_values(path_buf, empty_leaf, storage_size)?; - mmap[first_leaf_index..(first_leaf_index + values.len())].clone_from_slice(values); - for i in (1..first_leaf_index).rev() { - let left = &mmap[2 * i]; - let right = &mmap[2 * i + 1]; - mmap[i] = H::hash_node(left, right); - } + let mmap = 
MmapMutWrapper::new_from_storage(path_buf, &storage)?; Ok(Self { depth, @@ -1113,29 +1128,14 @@ impl MmapMutWrapper { /// - file size cannot be set /// - file is too large, possible truncation can occur /// - cannot build memory map - pub fn new_with_initial_values( + pub fn new_from_storage( file_path: PathBuf, - initial_value: &H::Hash, - storage_size: usize, + storage: &[H::Hash], ) -> Result { - let size_of_val = std::mem::size_of_val(initial_value); - let initial_vals: Vec = vec![initial_value.clone(); storage_size]; - - // cast Hash pointer to u8 pointer - let ptr = initial_vals.as_ptr().cast::(); - - let size_of_buffer: usize = storage_size * size_of_val; - - let buf: &[u8] = unsafe { - // moving pointer by u8 for storage_size * size of hash would get us the full - // buffer - std::slice::from_raw_parts(ptr, size_of_buffer) - }; - - // assure that buffer is correct length - assert_eq!(buf.len(), size_of_buffer); - - let file_size: u64 = storage_size as u64 * size_of_val as u64; + // Safety: potential uninitialized padding from `H::Hash` is safe to use if + // we're casting back to the same type. 
+ let buf = unsafe { as_bytes(storage) }; + let buf_len = buf.len(); let mut file = match OpenOptions::new() .read(true) @@ -1148,13 +1148,13 @@ impl MmapMutWrapper { Err(_e) => return Err(DenseMMapError::FileCreationFailed), }; - file.set_len(file_size).expect("cannot set file size"); + file.set_len(buf_len as u64).expect("cannot set file size"); if file.write_all(buf).is_err() { return Err(DenseMMapError::FileCannotWriteBytes); } let mmap = unsafe { - MmapOptions::new(usize::try_from(file_size).expect("file size truncated")) + MmapOptions::new(usize::try_from(buf_len as u64).expect("file size truncated")) .expect("cannot create memory map") .with_file(file, 0) .map_mut() @@ -1735,8 +1735,7 @@ pub mod bench { let prefix_depth = depth; let empty_value = Field::from(0); - let initial_values: Vec> = - (0..(1 << depth)).map(|value| Field::from(value)).collect(); + let initial_values: Vec> = (0..(1 << depth)).map(Field::from).collect(); TreeValues { depth, diff --git a/src/lib.rs b/src/lib.rs index 3eeca67..9824a47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ #![doc = include_str!("../README.md")] -#![warn(clippy::all, clippy::pedantic, clippy::cargo, clippy::nursery)] +#![warn(clippy::all, clippy::cargo)] // TODO: ark-circom and ethers-core pull in a lot of dependencies, some duplicate. 
#![allow(clippy::multiple_crate_versions)] diff --git a/src/merkle_tree.rs b/src/merkle_tree.rs index c3ddcbc..de81c24 100644 --- a/src/merkle_tree.rs +++ b/src/merkle_tree.rs @@ -14,7 +14,7 @@ use std::{ /// Hash types, values and algorithms for a Merkle tree pub trait Hasher { /// Type of the leaf and node hashes - type Hash: Clone + Eq + Serialize + Debug; + type Hash: Clone + Eq + Serialize + Debug + Send + Sync; /// Compute the hash of an intermediate node fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash; diff --git a/src/protocol/authentication.rs b/src/protocol/authentication.rs index 8ecded0..76fd338 100644 --- a/src/protocol/authentication.rs +++ b/src/protocol/authentication.rs @@ -14,7 +14,7 @@ pub fn generate_proof( let merkle_proof = LazyPoseidonTree::new(depth, Field::from(0)) .update(0, &identity.commitment()) .proof(0); - return super::generate_proof(identity, &merkle_proof, ext_nullifier_hash, signal_hash); + super::generate_proof(identity, &merkle_proof, ext_nullifier_hash, signal_hash) } pub fn verify_proof( @@ -28,12 +28,12 @@ pub fn verify_proof( let root = LazyPoseidonTree::new(depth, Field::from(0)) .update(0, &id_commitment) .root(); - return super::verify_proof( + super::verify_proof( root, nullifier_hash, signal_hash, ext_nullifier_hash, proof, depth, - ); + ) } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index e23bcff..0149a6c 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -269,7 +269,7 @@ mod test { #[test_all_depths] fn test_proof_serialize(depth: usize) { let proof = arb_proof(456, depth); - let json = serde_json::to_value(&proof).unwrap(); + let json = serde_json::to_value(proof).unwrap(); let valid_values = match depth { 16 => json!([ [ diff --git a/src/util.rs b/src/util.rs index 86bfafc..0449a08 100644 --- a/src/util.rs +++ b/src/util.rs @@ -16,6 +16,18 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] { output } +/// function to convert a reference to any type to a slice of bytes +/// 
# Safety
+/// This function is unsafe because it transmutes a reference to a slice of
+/// bytes. If `T` is not densely packed, then the padding bytes will be
+/// uninitialized.
+pub(crate) unsafe fn as_bytes<T: ?Sized>(p: &T) -> &[u8] {
+    ::core::slice::from_raw_parts(
+        (p as *const T).cast::<u8>(),
+        ::core::mem::size_of_val::<T>(p),
+    )
+}
+
 pub(crate) fn bytes_to_hex<const N: usize, const M: usize>(bytes: &[u8; N]) -> [u8; M] {
     // TODO: Replace `M` with a const expression once it's stable.
     debug_assert_eq!(M, 2 * N + 2);
@@ -113,6 +125,26 @@ pub(crate) fn deserialize_bytes<'de, const N: usize, D: Deserializer<'de>>(
 mod test {
     use super::*;
 
+    #[test]
+    fn test_as_bytes() {
+        // test byte array
+        let bytes = [0, 1, 2, 3u8];
+        let converted: [u8; 4] = unsafe { as_bytes(&bytes).try_into().unwrap() };
+        assert_eq!(bytes, converted);
+
+        // test u64 array
+        let array = [1u64, 1];
+        let converted: [u8; 16] = unsafe { as_bytes(&array).try_into().unwrap() };
+        let expected = [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0];
+        assert_eq!(converted, expected);
+
+        // test u64
+        let value = 1u64;
+        let converted: [u8; 8] = unsafe { as_bytes(&value).try_into().unwrap() };
+        let expected = [1, 0, 0, 0, 0, 0, 0, 0];
+        assert_eq!(converted, expected);
+    }
+
     #[test]
     fn test_serialize_bytes_hex() {
         let bytes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];