Skip to content

Commit

Permalink
refactor and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Nov 20, 2024
1 parent 4ae1789 commit d7a2a74
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 259 deletions.
1 change: 0 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"rust-analyzer.cargo.features": [
"stats",
"db_extension"
]
}
2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ license = "MIT"

[dependencies]
crossbeam-epoch = "0.9.18"
rand = { version = "0.8.5", optional = true }
serde = { version = "1.0.214", features = ["derive"], optional = true }

[dev-dependencies]
Expand Down Expand Up @@ -42,7 +41,6 @@ harness = false
flamegraph = ["shumai/flamegraph"]
perf = ["shumai/perf"]
stats = ["serde"]
db_extension = ["rand"]
shuttle = []

[profile.bench]
Expand Down
13 changes: 2 additions & 11 deletions src/base_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ pub(crate) trait Node {
fn remove(&mut self, k: u8);
fn copy_to<N: Node>(&self, dst: &mut N);
fn get_type() -> NodeType;

#[cfg(feature = "db_extension")]
fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)>;
}

pub(crate) enum NodeIter<'a> {
Expand Down Expand Up @@ -174,12 +171,6 @@ macro_rules! gen_method_mut {
gen_method!(get_child, (k: u8), Option<NodePtr>);
gen_method!(get_children, (start: u8, end: u8), NodeIter<'_>);

#[cfg(feature = "db_extension")]
gen_method!(
get_random_child,
(rng: &mut impl rand::Rng),
Option<(u8, NodePtr)>
);
gen_method_mut!(change, (key: u8, val: NodePtr), NodePtr);
gen_method_mut!(remove, (key: u8), ());

Expand Down Expand Up @@ -252,11 +243,11 @@ impl BaseNode {
Ok(ReadGuard::new(version, node))
}

pub(crate) fn read_lock2<'a>(node: NonNull<BaseNode>) -> Result<ReadGuard<'a>, ArtError> {
pub(crate) fn read_lock<'a>(node: NonNull<BaseNode>) -> Result<ReadGuard<'a>, ArtError> {
Self::read_lock_inner(node)
}

pub(crate) fn read_lock<'a, const MAX_LEVEL: usize>(
pub(crate) fn read_lock_deprecated<'a, const MAX_LEVEL: usize>(
node: NodePtr,
current_level: usize,
) -> Result<ReadGuard<'a>, ArtError> {
Expand Down
42 changes: 0 additions & 42 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,48 +381,6 @@ where
self.inner.stats()
}

/// Picks a random entry from the tree and applies the transformation `f` to its value.
/// This is useful for randomized algorithms.
///
/// `f` receives the key and the current value and returns the new value,
/// i.e. |key: usize, value: usize| -> usize.
///
/// Returns (key, old_value, new_value).
///
/// Note that `f` is an `FnMut` and it must be safe to execute multiple times.
/// `f` is expected to be short and fast, as it runs while an exclusive lock is held on the leaf node.
///
/// # Examples:
/// ```
/// use congee::Congee;
/// let tree = Congee::default();
/// let guard = tree.pin();
/// tree.insert(1, 42, &guard);
/// let mut rng = rand::thread_rng();
/// let (key, old_v, new_v) = tree.compute_on_random(&mut rng, |k, v| {
/// assert_eq!(k, 1);
/// assert_eq!(v, 42);
/// v + 1
/// }, &guard).unwrap();
/// assert_eq!(key, 1);
/// assert_eq!(old_v, 42);
/// assert_eq!(new_v, 43);
/// ```
#[cfg(feature = "db_extension")]
#[cfg_attr(docsrs, doc(cfg(feature = "db_extension")))]
pub fn compute_on_random(
    &self,
    rng: &mut impl rand::Rng,
    mut f: impl FnMut(K, V) -> V,
    guard: &epoch::Guard,
) -> Option<(K, V, V)> {
    // The inner tree works on raw `usize` keys and values; adapt the typed
    // callback by converting on the way in and on the way out.
    let mut raw_callback =
        |raw_key: usize, raw_value: usize| -> usize { usize::from(f(K::from(raw_key), V::from(raw_value))) };

    self.inner
        .compute_on_random(rng, &mut raw_callback, guard)
        .map(|(raw_key, raw_old, raw_new)| (K::from(raw_key), V::from(raw_old), V::from(raw_new)))
}

/// Update the value if the old value matches with the new one.
/// Returns the current value.
///
Expand Down
50 changes: 6 additions & 44 deletions src/node_16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,45 +22,17 @@ impl Node16 {
val ^ 128
}

/// Counts the trailing zero bits of `val`.
///
/// Returns 16 when `val` is zero, since every bit of a `u16` is then a trailing zero.
#[cfg(all(target_feature = "sse2", not(miri)))]
fn ctz(val: u16) -> u16 {
    let zero_bits: u32 = val.trailing_zeros();
    // A `u16` has at most 16 trailing zeros, so narrowing to `u16` is lossless.
    zero_bits as u16
}

fn get_insert_pos(&self, key: u8) -> usize {
let flipped = Self::flip_sign(key);

#[cfg(all(target_feature = "sse2", not(miri)))]
{
unsafe {
use std::arch::x86_64::{
__m128i, _mm_cmplt_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
};
let cmp = _mm_cmplt_epi8(
_mm_set1_epi8(flipped as i8),
_mm_loadu_si128(&self.keys as *const [u8; 16] as *const __m128i),
);
let bit_field = _mm_movemask_epi8(cmp) & (0xFFFF >> (16 - self.base.meta.count));
let pos = if bit_field > 0 {
Self::ctz(bit_field as u16)
} else {
self.base.meta.count
};
pos as usize
}
}

#[cfg(any(not(target_feature = "sse2"), miri))]
{
let mut pos = 0;
while pos < self.base.meta.count {
if self.keys[pos as usize] >= flipped {
return pos as usize;
}
pos += 1;
let mut pos = 0;
while pos < self.base.meta.count {
if self.keys[pos as usize] >= flipped {
return pos as usize;
}
pos as usize
pos += 1;
}
pos as usize
}

fn get_child_pos(&self, key: u8) -> Option<usize> {
Expand Down Expand Up @@ -199,14 +171,4 @@ impl Node for Node16 {
let child = unsafe { self.children.get_unchecked(pos) };
Some(*child)
}

/// Picks one of this node's children at random.
///
/// Returns `None` when the node has no children, otherwise the selected
/// child's key byte together with its pointer.
///
/// `Node16` keeps its keys in sign-flipped form: `get_insert_pos` compares
/// `self.keys` against `Self::flip_sign(key)`. The stored byte is therefore
/// flipped back before being returned, so callers get the original key byte
/// rather than the internal representation.
#[cfg(feature = "db_extension")]
fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> {
    let count = self.base.meta.count;
    if count == 0 {
        return None;
    }

    let idx = rng.gen_range(0..count) as usize;
    // NOTE(review): `flip_sign` appears to be `val ^ 128`, which is its own
    // inverse, so applying it to the stored byte recovers the caller's key.
    Some((Self::flip_sign(self.keys[idx]), self.children[idx]))
}
}
16 changes: 0 additions & 16 deletions src/node_256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,4 @@ impl Node for Node256 {
None
}
}

/// Picks one of this node's children at random.
///
/// Returns `None` when the node has no children. Otherwise a rank `n` in
/// `1..=count` is drawn and the slots are scanned from 0 upwards until the
/// `n`-th slot whose mask bit is set is found; that slot index is the key byte.
#[cfg(feature = "db_extension")]
fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> {
    let count = self.base.meta.count;
    if count == 0 {
        return None;
    }

    // Number of occupied slots still to be passed, including the target one.
    let mut remaining = rng.gen_range(1..=count);
    let mut slot = 0;
    loop {
        if self.get_mask(slot) {
            remaining -= 1;
            if remaining == 0 {
                return Some((slot as u8, self.children[slot]));
            }
        }
        slot += 1;
    }
}
}
9 changes: 0 additions & 9 deletions src/node_4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,4 @@ impl Node for Node4 {
.find(|(k, _)| **k == key)
.map(|(_, c)| *c)
}

/// Picks one of this node's children at random.
///
/// Returns `None` when the node has no children, otherwise the key byte and
/// child pointer stored at a randomly drawn position in `0..count`.
#[cfg(feature = "db_extension")]
fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> {
    let count = self.base.meta.count;
    if count == 0 {
        return None;
    }

    let slot = rng.gen_range(0..count) as usize;
    let key = self.keys[slot];
    let child = self.children[slot];
    Some((key, child))
}
}
21 changes: 0 additions & 21 deletions src/node_48.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,25 +123,4 @@ impl Node for Node48 {
Some(*child)
}
}

/// Picks one of this node's children at random.
///
/// Returns `None` when the node has no children. Otherwise a rank `n` in
/// `1..=count` is drawn and `child_idx` is scanned from 0 upwards until the
/// `n`-th entry that is not `EMPTY_MARKER` is found. The position of that entry
/// is the key byte, and its value indexes into `children`.
#[cfg(feature = "db_extension")]
fn get_random_child(&self, rng: &mut impl rand::Rng) -> Option<(u8, NodePtr)> {
    let count = self.base.meta.count;
    if count == 0 {
        return None;
    }

    // Number of occupied entries still to be passed, including the target one.
    let mut remaining = rng.gen_range(1..=count);
    let mut key_byte = 0;
    loop {
        if self.child_idx[key_byte] != EMPTY_MARKER {
            remaining -= 1;
            if remaining == 0 {
                let child_slot = self.child_idx[key_byte] as usize;
                return Some((key_byte as u8, self.children[child_slot]));
            }
        }
        key_byte += 1;
    }
}
}
2 changes: 1 addition & 1 deletion src/node_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl NodePtr {
unsafe { self.sub_node }
}

pub(crate) fn node_type<const K_LEN: usize>(&self, current_level: usize) -> PtrType {
pub(crate) fn downcast<const K_LEN: usize>(&self, current_level: usize) -> PtrType {
if current_level == (K_LEN - 1) {
PtrType::Payload(unsafe { self.as_payload_unchecked() })
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/range_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
}
key_tracker.push(start_level);

let next_node = BaseNode::read_lock::<K_LEN>(next_node_tmp, level)?;
let next_node = BaseNode::read_lock_deprecated::<K_LEN>(next_node_tmp, level)?;
parent_node = Some(node);
node = next_node;
continue;
Expand All @@ -148,7 +148,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
) -> Result<(), ArtError> {
debug_assert!(key_tracker.len() != 8);

let node = BaseNode::read_lock::<K_LEN>(node, key_tracker.len())?;
let node = BaseNode::read_lock_deprecated::<K_LEN>(node, key_tracker.len())?;
let prefix_result =
self.check_prefix_compare(node.as_ref(), self.end, 255, &mut key_tracker);
let level = key_tracker.len();
Expand Down Expand Up @@ -198,7 +198,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
) -> Result<(), ArtError> {
debug_assert!(key_tracker.len() != 8);

let node = BaseNode::read_lock::<K_LEN>(node, key_tracker.len())?;
let node = BaseNode::read_lock_deprecated::<K_LEN>(node, key_tracker.len())?;
let prefix_result =
self.check_prefix_compare(node.as_ref(), self.start, 0, &mut key_tracker);

Expand Down Expand Up @@ -253,7 +253,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
self.result_found += 1;
};
} else {
let node = BaseNode::read_lock::<K_LEN>(node, key_tracker.len())?;
let node = BaseNode::read_lock_deprecated::<K_LEN>(node, key_tracker.len())?;
let mut key_tracker = key_tracker.clone();

let children = node.as_ref().get_children(0, 255);
Expand Down
30 changes: 18 additions & 12 deletions src/stats.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::fmt::Display;
use std::{fmt::Display, ptr::NonNull};

use crate::{
base_node::{BaseNode, NodeType, MAX_KEY_LEN},
node_ptr::NodePtr,
base_node::{BaseNode, NodeType},
node_256::Node256,
node_ptr::PtrType,
tree::RawCongee,
Allocator,
};
Expand Down Expand Up @@ -97,10 +98,12 @@ impl<const K_LEN: usize, A: Allocator + Clone> RawCongee<K_LEN, A> {
pub fn stats(&self) -> NodeStats {
let mut node_stats = NodeStats::default();

let mut sub_nodes = vec![(0, 0, NodePtr::from_root(self.root))];
let mut sub_nodes = vec![(0, 0, unsafe {
std::mem::transmute::<NonNull<Node256>, NonNull<BaseNode>>(self.root)
})];

while let Some((level, key_level, node)) = sub_nodes.pop() {
let node = BaseNode::read_lock::<K_LEN>(node, level).unwrap();
let node = BaseNode::read_lock(node).unwrap();

if node_stats.0.len() <= level {
node_stats.0.push(LevelStats::new_level(level));
Expand All @@ -127,13 +130,16 @@ impl<const K_LEN: usize, A: Allocator + Clone> RawCongee<K_LEN, A> {

let children = node.as_ref().get_children(0, 255);
for (_k, n) in children {
if key_level != (MAX_KEY_LEN - 1) {
let child_node = BaseNode::read_lock::<K_LEN>(n, level).unwrap();
sub_nodes.push((
level + 1,
key_level + 1 + child_node.as_ref().prefix().len(),
n,
));
match n.downcast::<K_LEN>(level) {
PtrType::Payload(_) => {}
PtrType::SubNode(sub_node) => {
let child_node = BaseNode::read_lock(sub_node).unwrap();
sub_nodes.push((
level + 1,
key_level + 1 + child_node.as_ref().prefix().len(),
sub_node,
));
}
}
}
}
Expand Down
Loading

0 comments on commit d7a2a74

Please sign in to comment.