diff --git a/src/dynamic_merkle_tree.rs b/src/dynamic_merkle_tree.rs index d68dabe..c72e89f 100644 --- a/src/dynamic_merkle_tree.rs +++ b/src/dynamic_merkle_tree.rs @@ -39,6 +39,11 @@ pub struct DynamicMerkleTree { } impl DynamicMerkleTree { + #[must_use] + pub fn new(depth: usize, empty_value: &H::Hash) -> DynamicMerkleTree { + Self::new_with_leaves(depth, empty_value, &[]) + } + /// initial leaves populated from the given slice. #[must_use] pub fn new_with_leaves( @@ -46,22 +51,24 @@ impl DynamicMerkleTree { empty_value: &H::Hash, leaves: &[H::Hash], ) -> DynamicMerkleTree { + assert!(depth > 0); let storage = Self::storage_from_leaves(empty_value, leaves); - let len = storage.len(); - let root = storage[len >> 1]; let sparse_column = Self::sparse_column(depth, empty_value); let num_leaves = leaves.len(); - DynamicMerkleTree { + let mut tree = DynamicMerkleTree { depth, num_leaves, - root, + root: *empty_value, empty_value: *empty_value, sparse_column, storage, _version: Canonical, _marker: std::marker::PhantomData, - } + }; + + tree.recompute_root(); + tree } /// Index 0 represents the bottow layer @@ -190,6 +197,48 @@ impl DynamicMerkleTree { self.propogate_up(index); } + // pub fn extend_from_slice(&mut self, leaves: &[H::Hash]) { + // let mut storage_len = self.storage.len(); + // let leaf_capacity = storage_len >> 1; + // if self.num_leaves + leaves.len() > leaf_capacity { + // self.reallocate(); + // storage_len = self.storage.len(); + // } + // + // let base_len = storage_len >> 1; + // let depth = base_len.ilog2(); + // + // let mut parents = vec![]; + // leaves.par_iter().enumerate().for_each(|(i, &val)| { + // let leaf_index = self.num_leaves + i; + // let index = index_from_leaf(leaf_index); + // self.storage[index] = val; + // parents.push(index); + // }); + + // leaves.iter().enumerate().for_each(|(i, &val)| { + // storage[base_len + i] = val; + // }); + // + // // We iterate over mutable layers of the tree + // for current_depth in (1..=depth).rev() { + // let (top, child_layer) = storage.split_at_mut(1 << + // current_depth); let parent_layer = &mut top[(1 << + // (current_depth - 1))..]; + // + // parent_layer + // .par_iter_mut() + // .enumerate() + // .for_each(|(i, value)| { + // let left = &child_layer[2 * i]; + // let right = &child_layer[2 * i + 1]; + // *value = H::hash_node(left, right); + // }); + // } + // + // storage[1] + // } + pub fn set_leaf(&mut self, leaf: usize, value: H::Hash) { let index = index_from_leaf(leaf); self.storage[index] = value; @@ -205,21 +254,22 @@ impl DynamicMerkleTree { let left_hash = self.storage.get(left)?; let right_hash = self.storage.get(right)?; let parent_index = parent(index); - self.storage[parent_index] = H::hash_node(left_hash, &right_hash); + self.storage[parent_index] = H::hash_node(left_hash, right_hash); index = parent_index; } } - // pub fn extend_from_slice(&mut self, leaves: H::Hash) { - // match self.storage.get_many_mut.get_mut(self.num_leaves + 1) { - // Some(val) => *val = value, - // None => { - // self.expand(); - // self.storage[self.num_leaves + 1] = value; - // } - // } - // self.num_leaves += 1; - // } + /// Returns the root of the tree. + pub fn recompute_root(&mut self) -> H::Hash { + let storage_root = storage_root(&self.storage); + let storage_depth = storage_depth(&self.storage) as usize; + let mut hash = storage_root; + for i in storage_depth..self.depth { + hash = H::hash_node(&hash, &self.sparse_column[i]); + } + self.root = hash; + hash + } /// Returns the depth of the tree. #[must_use] @@ -494,6 +544,14 @@ fn index_from_leaf(leaf: usize) -> usize { leaf + (leaf + 1).next_power_of_two() } +fn storage_root(storage: &[T]) -> T { + storage[storage.len() >> 1] +} + +fn storage_depth(storage: &[T]) -> u32 { + (storage.len() >> 1).ilog2() +} + fn parent(i: usize) -> usize { if i.is_power_of_two() { return i << 1; @@ -754,6 +812,45 @@ mod tests { assert_eq!(tree, expected); } + #[test] + fn test_sparse_column() { + let leaves = vec![]; + let empty = 1; + let tree = DynamicMerkleTree::::new_with_leaves(10, &empty, &leaves); + let expected = DynamicMerkleTree:: { + depth: 10, + num_leaves: 0, + root: 1024, + empty_value: 1, + sparse_column: vec![1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024], + storage: vec![1, 1], + _version: Canonical, + _marker: std::marker::PhantomData, + }; + debug_tree(&tree); + assert_eq!(tree, expected); + } + + #[test] + fn test_compute_root() { + let num_leaves = 1 << 3; + let leaves = vec![0; num_leaves]; + let empty = 1; + let tree = DynamicMerkleTree::::new_with_leaves(4, &empty, &leaves); + let expected = DynamicMerkleTree:: { + depth: 4, + num_leaves: 8, + root: 8, + empty_value: 1, + sparse_column: vec![1, 2, 4, 8, 16], + storage: vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + _version: Canonical, + _marker: std::marker::PhantomData, + }; + debug_tree(&tree); + assert_eq!(tree, expected); + } + #[test] fn test_push() { let num_leaves = 1 << 3;