diff --git a/.vscode/settings.json b/.vscode/settings.json index 54155d1..9b75824 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,5 @@ { "rust-analyzer.cargo.features": [ "stats", - "db_extension" ] } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 955346e..9ae6465 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ license = "MIT" [dependencies] crossbeam-epoch = "0.9.18" -rand = { version = "0.8.5", optional = true } serde = { version = "1.0.214", features = ["derive"], optional = true } [dev-dependencies] @@ -42,7 +41,6 @@ harness = false flamegraph = ["shumai/flamegraph"] perf = ["shumai/perf"] stats = ["serde"] -db_extension = ["rand"] shuttle = [] [profile.bench] diff --git a/src/base_node.rs b/src/base_node.rs index efddd10..58da4b4 100644 --- a/src/base_node.rs +++ b/src/base_node.rs @@ -69,9 +69,6 @@ pub(crate) trait Node { fn remove(&mut self, k: u8); fn copy_to(&self, dst: &mut N); fn get_type() -> NodeType; - - #[cfg(feature = "db_extension")] - fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)>; } pub(crate) enum NodeIter<'a> { @@ -174,12 +171,6 @@ macro_rules! gen_method_mut { gen_method!(get_child, (k: u8), Option); gen_method!(get_children, (start: u8, end: u8), NodeIter<'_>); -#[cfg(feature = "db_extension")] -gen_method!( - get_random_child, - (rng: &mut impl rand::Rng), - Option<(u8, NodePtr)> -); gen_method_mut!(change, (key: u8, val: NodePtr), NodePtr); gen_method_mut!(remove, (key: u8), ()); @@ -252,11 +243,11 @@ impl BaseNode { Ok(ReadGuard::new(version, node)) } - pub(crate) fn read_lock2<'a>(node: NonNull) -> Result, ArtError> { + pub(crate) fn read_lock<'a>(node: NonNull) -> Result, ArtError> { Self::read_lock_inner(node) } - pub(crate) fn read_lock<'a, const MAX_LEVEL: usize>( + pub(crate) fn read_lock_deprecated<'a, const MAX_LEVEL: usize>( node: NodePtr, current_level: usize, ) -> Result, ArtError> { diff --git a/src/lib.rs b/src/lib.rs index a92d44b..8ee45aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -381,48 +381,6 @@ where self.inner.stats() } - /// Get a random value from the tree, perform the transformation `f`. - /// This is useful for randomized algorithms. - /// - /// `f` takes key and value as input and return the new value, |key: usize, value: usize| -> usize. - /// - /// Returns (key, old_value, new_value) - /// - /// Note that the function `f` is a FnMut and it must be safe to execute multiple times. - /// The `f` is expected to be short and fast as it will hold a exclusive lock on the leaf node. - /// - /// # Examples: - /// ``` - /// use congee::Congee; - /// let tree = Congee::default(); - /// let guard = tree.pin(); - /// tree.insert(1, 42, &guard); - /// let mut rng = rand::thread_rng(); - /// let (key, old_v, new_v) = tree.compute_on_random(&mut rng, |k, v| { - /// assert_eq!(k, 1); - /// assert_eq!(v, 42); - /// v + 1 - /// }, &guard).unwrap(); - /// assert_eq!(key, 1); - /// assert_eq!(old_v, 42); - /// assert_eq!(new_v, 43); - /// ``` - #[cfg(feature = "db_extension")] - #[cfg_attr(docsrs, doc(cfg(feature = "db_extension")))] - pub fn compute_on_random( - &self, - rng: &mut impl rand::Rng, - mut f: impl FnMut(K, V) -> V, - guard: &epoch::Guard, - ) -> Option<(K, V, V)> { - let mut remapped = |key: usize, value: usize| -> usize { - let v = f(K::from(key), V::from(value)); - usize::from(v) - }; - let (key, old_v, new_v) = self.inner.compute_on_random(rng, &mut remapped, guard)?; - Some((K::from(key), V::from(old_v), V::from(new_v))) - } - /// Update the value if the old value matches with the new one. /// Returns the current value. /// diff --git a/src/node_16.rs b/src/node_16.rs index 5b7102a..6f89112 100644 --- a/src/node_16.rs +++ b/src/node_16.rs @@ -22,45 +22,17 @@ impl Node16 { val ^ 128 } - #[cfg(all(target_feature = "sse2", not(miri)))] - fn ctz(val: u16) -> u16 { - val.trailing_zeros() as u16 - } - fn get_insert_pos(&self, key: u8) -> usize { let flipped = Self::flip_sign(key); - #[cfg(all(target_feature = "sse2", not(miri)))] - { - unsafe { - use std::arch::x86_64::{ - __m128i, _mm_cmplt_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8, - }; - let cmp = _mm_cmplt_epi8( - _mm_set1_epi8(flipped as i8), - _mm_loadu_si128(&self.keys as *const [u8; 16] as *const __m128i), - ); - let bit_field = _mm_movemask_epi8(cmp) & (0xFFFF >> (16 - self.base.meta.count)); - let pos = if bit_field > 0 { - Self::ctz(bit_field as u16) - } else { - self.base.meta.count - }; - pos as usize - } - } - - #[cfg(any(not(target_feature = "sse2"), miri))] - { - let mut pos = 0; - while pos < self.base.meta.count { - if self.keys[pos as usize] >= flipped { - return pos as usize; - } - pos += 1; + let mut pos = 0; + while pos < self.base.meta.count { + if self.keys[pos as usize] >= flipped { + return pos as usize; } - pos as usize + pos += 1; } + pos as usize } fn get_child_pos(&self, key: u8) -> Option { @@ -199,14 +171,4 @@ impl Node for Node16 { let child = unsafe { self.children.get_unchecked(pos) }; Some(*child) } - - #[cfg(feature = "db_extension")] - fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> { - if self.base.meta.count == 0 { - return None; - } - - let idx = rng.gen_range(0..self.base.meta.count); - Some((self.keys[idx as usize], self.children[idx as usize])) - } } diff --git a/src/node_256.rs b/src/node_256.rs index 9c929b6..646f3e3 100644 --- a/src/node_256.rs +++ b/src/node_256.rs @@ -124,20 +124,4 @@ impl Node for Node256 { None } } - - #[cfg(feature = "db_extension")] - fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> { - if self.base.meta.count == 0 { - return None; - } - let mut scan_cnt = rng.gen_range(1..=self.base.meta.count); - let mut idx = 0; - while scan_cnt > 0 { - if self.get_mask(idx) { - scan_cnt -= 1; - } - idx += 1; - } - Some(((idx - 1) as u8, self.children[idx - 1])) - } } diff --git a/src/node_4.rs b/src/node_4.rs index e8803d2..9ebc15d 100644 --- a/src/node_4.rs +++ b/src/node_4.rs @@ -145,13 +145,4 @@ impl Node for Node4 { .find(|(k, _)| **k == key) .map(|(_, c)| *c) } - - #[cfg(feature = "db_extension")] - fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> { - if self.base.meta.count == 0 { - return None; - } - let idx = rng.gen_range(0..self.base.meta.count); - Some((self.keys[idx as usize], self.children[idx as usize])) - } } diff --git a/src/node_48.rs b/src/node_48.rs index 4ea37b6..91b6133 100644 --- a/src/node_48.rs +++ b/src/node_48.rs @@ -123,25 +123,4 @@ impl Node for Node48 { Some(*child) } } - - #[cfg(feature = "db_extension")] - fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> { - if self.base.meta.count == 0 { - return None; - } - - let mut scan_cnt = rng.gen_range(1..=self.base.meta.count); - let mut idx = 0; - while scan_cnt > 0 { - if self.child_idx[idx] != EMPTY_MARKER { - scan_cnt -= 1; - } - idx += 1; - } - - Some(( - (idx - 1) as u8, - self.children[self.child_idx[idx - 1] as usize], - )) - } } diff --git a/src/node_ptr.rs b/src/node_ptr.rs index d9410b7..26b2071 100644 --- a/src/node_ptr.rs +++ b/src/node_ptr.rs @@ -97,7 +97,7 @@ impl NodePtr { unsafe { self.sub_node } } - pub(crate) fn node_type(&self, current_level: usize) -> PtrType { + pub(crate) fn downcast(&self, current_level: usize) -> PtrType { if current_level == (K_LEN - 1) { PtrType::Payload(unsafe { self.as_payload_unchecked() }) } else { diff --git a/src/range_scan.rs b/src/range_scan.rs index 8ed1213..c293f30 100644 --- a/src/range_scan.rs +++ b/src/range_scan.rs @@ -122,7 +122,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { } key_tracker.push(start_level); - let next_node = BaseNode::read_lock::(next_node_tmp, level)?; + let next_node = BaseNode::read_lock_deprecated::(next_node_tmp, level)?; parent_node = Some(node); node = next_node; continue; @@ -148,7 +148,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { ) -> Result<(), ArtError> { debug_assert!(key_tracker.len() != 8); - let node = BaseNode::read_lock::(node, key_tracker.len())?; + let node = BaseNode::read_lock_deprecated::(node, key_tracker.len())?; let prefix_result = self.check_prefix_compare(node.as_ref(), self.end, 255, &mut key_tracker); let level = key_tracker.len(); @@ -198,7 +198,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { ) -> Result<(), ArtError> { debug_assert!(key_tracker.len() != 8); - let node = BaseNode::read_lock::(node, key_tracker.len())?; + let node = BaseNode::read_lock_deprecated::(node, key_tracker.len())?; let prefix_result = self.check_prefix_compare(node.as_ref(), self.start, 0, &mut key_tracker); @@ -253,7 +253,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { self.result_found += 1; }; } else { - let node = BaseNode::read_lock::(node, key_tracker.len())?; + let node = BaseNode::read_lock_deprecated::(node, key_tracker.len())?; let mut key_tracker = key_tracker.clone(); let children = node.as_ref().get_children(0, 255); diff --git a/src/stats.rs b/src/stats.rs index bd74147..339869d 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -1,8 +1,9 @@ -use std::fmt::Display; +use std::{fmt::Display, ptr::NonNull}; use crate::{ - base_node::{BaseNode, NodeType, MAX_KEY_LEN}, - node_ptr::NodePtr, + base_node::{BaseNode, NodeType}, + node_256::Node256, + node_ptr::PtrType, tree::RawCongee, Allocator, }; @@ -97,10 +98,12 @@ impl RawCongee { pub fn stats(&self) -> NodeStats { let mut node_stats = NodeStats::default(); - let mut sub_nodes = vec![(0, 0, NodePtr::from_root(self.root))]; + let mut sub_nodes = vec![(0, 0, unsafe { + std::mem::transmute::, NonNull>(self.root) + })]; while let Some((level, key_level, node)) = sub_nodes.pop() { - let node = BaseNode::read_lock::(node, level).unwrap(); + let node = BaseNode::read_lock(node).unwrap(); if node_stats.0.len() <= level { node_stats.0.push(LevelStats::new_level(level)); @@ -127,13 +130,16 @@ impl RawCongee { let children = node.as_ref().get_children(0, 255); for (_k, n) in children { - if key_level != (MAX_KEY_LEN - 1) { - let child_node = BaseNode::read_lock::(n, level).unwrap(); - sub_nodes.push(( - level + 1, - key_level + 1 + child_node.as_ref().prefix().len(), - n, - )); + match n.downcast::(level) { + PtrType::Payload(_) => {} + PtrType::SubNode(sub_node) => { + let child_node = BaseNode::read_lock(sub_node).unwrap(); + sub_nodes.push(( + level + 1, + key_level + 1 + child_node.as_ref().prefix().len(), + sub_node, + )); + } } } } diff --git a/src/tree.rs b/src/tree.rs index 718d2e8..34e4c17 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -36,27 +36,24 @@ impl Drop for RawCongee { let mut sub_nodes = vec![(NodePtr::from_root(self.root), 0)]; while let Some((node, level)) = sub_nodes.pop() { - match node.node_type::(level) { + match node.downcast::(level) { PtrType::Payload(_) => { continue; } PtrType::SubNode(sub_node) => { - let node_lock = BaseNode::read_lock2(sub_node).unwrap(); + let node_lock = BaseNode::read_lock(sub_node).unwrap(); let children = node_lock.as_ref().get_children(0, 255); for (_k, n) in children { - match n.node_type::(level) { + match n.downcast::(level) { PtrType::Payload(_) => {} PtrType::SubNode(sub_sub_node) => { - let node_lock = BaseNode::read_lock2(sub_sub_node).unwrap(); + let node_lock = BaseNode::read_lock(sub_sub_node).unwrap(); sub_nodes.push((n, node_lock.as_ref().prefix().len())); } } } unsafe { - BaseNode::drop_node( - node.as_ptr_safe::(level), - self.allocator.clone(), - ); + BaseNode::drop_node(sub_node, self.allocator.clone()); } } } @@ -100,14 +97,14 @@ impl RawCongee { let child_node = child_node?; - match child_node.node_type::(level) { + match child_node.downcast::(level) { PtrType::Payload(tid) => { return Some(tid); } PtrType::SubNode(sub_node) => { level += 1; - node = if let Ok(n) = BaseNode::read_lock2(sub_node) { + node = if let Ok(n) = BaseNode::read_lock(sub_node) { n } else { continue 'outer; @@ -184,13 +181,11 @@ impl RawCongee { &self.allocator, guard, ) { - if level != (K_LEN - 1) { - unsafe { - BaseNode::drop_node( - new_leaf.as_ptr_safe::(level), - self.allocator.clone(), - ); - } + match new_leaf.downcast::(level) { + PtrType::Payload(_) => {} + PtrType::SubNode(sub_node) => unsafe { + BaseNode::drop_node(sub_node, self.allocator.clone()); + }, } return Err(e); } @@ -202,7 +197,7 @@ impl RawCongee { p.unlock()?; } - match next_node.node_type::(level) { + match next_node.downcast::(level) { PtrType::Payload(old) => { // At this point, the level must point to the last u8 of the key, // meaning that we are updating an existing value. @@ -221,7 +216,7 @@ impl RawCongee { } PtrType::SubNode(sub_node) => { parent_node = Some(node); - node = BaseNode::read_lock2(sub_node)?; + node = BaseNode::read_lock(sub_node)?; level += 1; } } @@ -421,7 +416,7 @@ impl RawCongee { None => return Ok(None), }; - match child_node.node_type::(level) { + match child_node.downcast::(level) { PtrType::Payload(tid) => { let new_v = remapping_function(tid); @@ -468,7 +463,7 @@ impl RawCongee { PtrType::SubNode(sub_node) => { level += 1; parent = Some((node, node_key)); - node = BaseNode::read_lock2(sub_node)?; + node = BaseNode::read_lock(sub_node)?; } } } @@ -492,78 +487,4 @@ impl RawCongee { } } } - - #[inline] - #[cfg(feature = "db_extension")] - pub(crate) fn compute_on_random( - &self, - rng: &mut impl rand::Rng, - f: &mut impl FnMut(usize, usize) -> usize, - guard: &Guard, - ) -> Option<(usize, usize, usize)> { - let backoff = Backoff::new(); - loop { - match self.compute_on_random_inner(rng, f, guard) { - Ok(n) => return n, - Err(_) => backoff.spin(), - } - } - } - - #[inline] - #[cfg(feature = "db_extension")] - fn compute_on_random_inner( - &self, - rng: &mut impl rand::Rng, - f: &mut impl FnMut(usize, usize) -> usize, - _guard: &Guard, - ) -> Result, ArtError> { - let mut node = BaseNode::read_lock_root(self.root)?; - - let mut key_tracker = crate::utils::KeyTracker::default(); - - loop { - for k in node.as_ref().prefix() { - key_tracker.push(*k); - } - - let child_node = node.as_ref().get_random_child(rng); - node.check_version()?; - - let (k, child_node) = match child_node { - Some(n) => n, - None => return Ok(None), - }; - - key_tracker.push(k); - - if let Some(last_level_key) = key_tracker.as_last_level::() { - let new_v = f( - last_level_key.to_usize_key(), - child_node.as_payload(&last_level_key), - ); - if new_v == child_node.as_payload(&last_level_key) { - // Don't acquire the lock if the value is not changed - return Ok(Some((last_level_key.to_usize_key(), new_v, new_v))); - } - - let mut write_n = node.upgrade().map_err(|(_n, v)| v)?; - - let old_v = write_n.as_mut().change(k, NodePtr::from_payload(new_v)); - - debug_assert_eq!( - old_v.as_payload(&last_level_key), - child_node.as_payload(&last_level_key) - ); - - return Ok(Some(( - last_level_key.to_usize_key(), - child_node.as_payload(&last_level_key), - new_v, - ))); - } - - node = BaseNode::read_lock::(child_node, key_tracker.len())?; - } - } } diff --git a/src/utils.rs b/src/utils.rs index 63c7200..788a7ca 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -131,7 +131,7 @@ impl KeyTracker { if key_tracker.len() == MAX_KEY_LEN { cur_key } else { - let node_ref = BaseNode::read_lock::(node, 0).unwrap(); + let node_ref = BaseNode::read_lock_deprecated::(node, 0).unwrap(); let n_prefix = node_ref.as_ref().prefix().iter().skip(key_tracker.len()); for i in n_prefix { cur_key.push(*i);