diff --git a/src/cascading_merkle_tree.rs b/src/cascading_merkle_tree.rs index 0ee7d05..848ae25 100644 --- a/src/cascading_merkle_tree.rs +++ b/src/cascading_merkle_tree.rs @@ -85,13 +85,15 @@ where let mut tree = CascadingMerkleTree { depth, - root: storage.storage_root(), + root: *empty_value, empty_value: *empty_value, sparse_column, storage, _marker: std::marker::PhantomData, }; + tree.recompute_root(); + Ok(tree) } @@ -382,10 +384,15 @@ where #[cfg(test)] mod tests { + use rand::Rng; use serial_test::serial; use super::*; - use crate::generic_storage::{GenericStorage, MmapVec}; + use crate::{ + generic_storage::{GenericStorage, MmapVec}, + poseidon_tree::PoseidonHash, + Field, + }; #[derive(Debug, Clone, PartialEq, Eq)] struct TestHasher; @@ -881,8 +888,15 @@ mod tests { #[test] #[serial] fn test_restore_from_cache() -> color_eyre::Result<()> { - let empty = 0; - let leaves = vec![1; 1 << 20]; + let mut rng = rand::thread_rng(); + + let leaves: Vec = (0..1 << 2) + .map(|_| { + let val = rng.gen::(); + + Field::from(val) + }) + .collect::>(); // Create a new tmp file for mmap storage let tempfile = tempfile::NamedTempFile::new()?; @@ -890,22 +904,26 @@ mod tests { // Initialize the expected tree let mmap_vec: MmapVec<_> = unsafe { MmapVec::new(tempfile.reopen()?).unwrap() }; - let expected_tree = CascadingMerkleTree::>::new_with_leaves( - mmap_vec, 30, &empty, &leaves, + let expected_tree = CascadingMerkleTree::>::new_with_leaves( + mmap_vec, + 3, + &Field::ZERO, + &leaves, ); let expected_root = expected_tree.root(); - let expected_leaves = expected_tree.leaves().collect::>(); + let expected_leaves = expected_tree.leaves().collect::>(); drop(expected_tree); // Restore the tree let mmap_vec: MmapVec<_> = unsafe { MmapVec::restore(file_path).unwrap() }; - let tree = CascadingMerkleTree::>::restore(mmap_vec, 30, &empty)?; + let tree = + CascadingMerkleTree::>::restore(mmap_vec, 3, &Field::ZERO)?; // Assert that the root and the leaves are as expected assert_eq!(tree.root(), expected_root); - assert_eq!(tree.leaves().collect::>(), expected_leaves); + assert_eq!(tree.leaves().collect::>(), expected_leaves); Ok(()) }