diff --git a/src/lib.rs b/src/lib.rs index c6d31d2..a723dca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,6 +70,62 @@ impl Allocator for DefaultAllocator { } } +pub struct U64Congee< + V: Clone + From + Into, + A: Allocator + Clone + 'static = DefaultAllocator, +> { + inner: RawCongee<8, A>, + pt_val: PhantomData, +} + +impl + Into> Default for U64Congee { + fn default() -> Self { + Self::new() + } +} + +impl + Into> U64Congee { + pub fn new() -> Self { + Self { + inner: RawCongee::new(DefaultAllocator {}), + pt_val: PhantomData, + } + } + + pub fn get(&self, key: u64, guard: &epoch::Guard) -> Option { + let key: [u8; 8] = key.to_be_bytes(); + let v = self.inner.get(&key, guard)?; + Some(V::from(v)) + } + + pub fn insert(&self, key: u64, val: V, guard: &epoch::Guard) -> Result, OOMError> { + let key: [u8; 8] = key.to_be_bytes(); + let val = val.into(); + self.inner + .insert(&key, val, guard) + .map(|v| v.map(|v| V::from(v))) + } + + pub fn remove(&self, key: u64, guard: &epoch::Guard) -> Option { + let key: [u8; 8] = key.to_be_bytes(); + let (old, new) = self.inner.compute_if_present(&key, &mut |_v| None, guard)?; + debug_assert!(new.is_none()); + Some(V::from(old)) + } + + pub fn range( + &self, + start: u64, + end: u64, + result: &mut [([u8; 8], usize)], + guard: &epoch::Guard, + ) -> usize { + let start: [u8; 8] = start.to_be_bytes(); + let end: [u8; 8] = end.to_be_bytes(); + self.inner.range(&start, &end, result, guard) + } +} + /// The adaptive radix tree. pub struct Congee< K: Clone + From, @@ -229,7 +285,13 @@ where let end = usize::from(end.clone()); let start: [u8; 8] = start.to_be_bytes(); let end: [u8; 8] = end.to_be_bytes(); - self.inner.range(&start, &end, result, guard) + let result_ref = unsafe { + std::slice::from_raw_parts_mut( + result.as_mut_ptr() as *mut ([u8; 8], usize), + result.len(), + ) + }; + self.inner.range(&start, &end, result_ref, guard) } /// Compute and update the value if the key presents in the tree. diff --git a/src/range_scan.rs b/src/range_scan.rs index 96fee22..e8de283 100644 --- a/src/range_scan.rs +++ b/src/range_scan.rs @@ -1,4 +1,3 @@ -use crate::base_node::MAX_KEY_LEN; use crate::error::ArtError; use crate::{base_node::BaseNode, lock::ReadGuard, node_ptr::NodePtr, utils::KeyTracker}; use std::cmp; @@ -12,7 +11,7 @@ enum PrefixCheckEqualsResult { pub(crate) struct RangeScan<'a, const K_LEN: usize> { start: &'a [u8; K_LEN], end: &'a [u8; K_LEN], - result: &'a mut [(usize, usize)], + result: &'a mut [([u8; K_LEN], usize)], root: *const BaseNode, to_continue: usize, result_found: usize, @@ -22,7 +21,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { pub(crate) fn new( start: &'a [u8; K_LEN], end: &'a [u8; K_LEN], - result: &'a mut [(usize, usize)], + result: &'a mut [([u8; K_LEN], usize)], root: *const BaseNode, ) -> Self { Self { @@ -94,7 +93,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { key_tracker.push(k); - if key_tracker.len() == MAX_KEY_LEN { + if key_tracker.len() == K_LEN { self.copy_node(n, &key_tracker)?; } else if k == start_level { self.find_start(n, &node, key_tracker.clone())?; @@ -118,7 +117,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { }; node.check_version()?; - if key_tracker.len() == (MAX_KEY_LEN - 1) { + if key_tracker.len() == (K_LEN - 1) { self.copy_node(next_node_tmp, &key_tracker)?; return Ok(self.result_found); } @@ -172,7 +171,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { key_tracker.push(k); - if key_tracker.len() == MAX_KEY_LEN { + if key_tracker.len() == K_LEN { self.copy_node(n, &key_tracker)?; } else if k == end_level { self.find_end(n, &node, key_tracker.clone())?; @@ -223,7 +222,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { node.check_version()?; key_tracker.push(k); - if key_tracker.len() == MAX_KEY_LEN { + if key_tracker.len() == K_LEN { self.copy_node(n, &key_tracker)?; } else if k == start_level { self.find_start(n, &node, key_tracker.clone())?; @@ -243,13 +242,13 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { } fn copy_node(&mut self, node: NodePtr, key_tracker: &KeyTracker) -> Result<(), ArtError> { - if key_tracker.len() == MAX_KEY_LEN { + if key_tracker.len() == K_LEN { if self.key_in_range(key_tracker) { if self.result_found == self.result.len() { self.to_continue = node.as_tid(); return Ok(()); } - self.result[self.result_found] = (key_tracker.to_usize_key(), node.as_tid()); + self.result[self.result_found] = (key_tracker.get_key(), node.as_tid()); self.result_found += 1; }; } else { diff --git a/src/tests/scan.rs b/src/tests/scan.rs index f1f18b0..18a1acd 100644 --- a/src/tests/scan.rs +++ b/src/tests/scan.rs @@ -23,7 +23,7 @@ fn small_scan() { let low_key: [u8; 8] = low_v.to_be_bytes(); let high_key: [u8; 8] = (low_v + scan_cnt).to_be_bytes(); - let mut results = [(0, 0); 20]; + let mut results = [([0; 8], 0); 20]; let scan_r = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scan_r, scan_cnt); @@ -60,7 +60,7 @@ fn large_scan() { let low_key: [u8; 8] = low_key_v.to_be_bytes(); let high_key: [u8; 8] = (low_key_v + scan_cnt).to_be_bytes(); - let mut scan_results = vec![(0, 0); *scan_cnt]; + let mut scan_results = vec![([0; 8], 0); *scan_cnt]; let r_found = tree.range(&low_key, &high_key, &mut scan_results, &guard); assert_eq!(r_found, *scan_cnt); @@ -78,7 +78,7 @@ fn large_scan() { let low_key: [u8; 8] = low_key_v.to_be_bytes(); let high_key: [u8; 8] = (low_key_v + scan_cnt).to_be_bytes(); - let mut scan_results = vec![(0, 0); *scan_cnt]; + let mut scan_results = vec![([0; 8], 0); *scan_cnt]; let r_found = tree.range(&low_key, &high_key, &mut scan_results, &guard); assert_eq!(r_found, 0); } @@ -112,7 +112,7 @@ fn large_scan_small_buffer() { let low_key: [u8; 8] = low_key_v.to_be_bytes(); let high_key: [u8; 8] = (low_key_v + scan_cnt * 5).to_be_bytes(); - let mut scan_results = vec![(0, 0); (*scan_cnt) / 2]; + let mut scan_results = vec![([0; 8], 0); (*scan_cnt) / 2]; let r_found = tree.range(&low_key, &high_key, &mut scan_results, &guard); assert_eq!(r_found, *scan_cnt / 2); @@ -126,7 +126,7 @@ fn large_scan_small_buffer() { let scan_cnt = scan_counts.choose(&mut r).unwrap(); let low_key: [u8; 8] = 0x6_0000usize.to_be_bytes(); let high_key: [u8; 8] = (0x6_ffff as usize).to_be_bytes(); - let mut scan_results = vec![(0, 0); *scan_cnt]; + let mut scan_results = vec![([0; 8], 0); *scan_cnt]; let r_found = tree.range(&low_key, &high_key, &mut scan_results, &guard); assert_eq!(r_found, *scan_cnt); @@ -169,7 +169,7 @@ fn test_insert_and_scan() { let low_key: [u8; 8] = low_key_v.to_be_bytes(); let high_key: [u8; 8] = (low_key_v + scan_cnt).to_be_bytes(); - let mut scan_results = vec![(0, 0); *scan_cnt]; + let mut scan_results = vec![([0; 8], 0); *scan_cnt]; let _v = tree.range(&low_key, &high_key, &mut scan_results, &guard); })); } @@ -212,7 +212,7 @@ fn fuzz_0() { let low_key: [u8; 8] = 0usize.to_be_bytes(); let high_key: [u8; 8] = 0usize.to_be_bytes(); - let mut results = vec![(0, 0); 255]; + let mut results = vec![([0; 8], 0); 255]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 0); } @@ -230,7 +230,7 @@ fn fuzz_1() { let low_key: [u8; 8] = scan_key.to_be_bytes(); let high_key: [u8; 8] = (scan_key + 255).to_be_bytes(); - let mut results = vec![(0, 0); 256]; + let mut results = vec![([0; 8], 0); 256]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 0); @@ -257,7 +257,7 @@ fn fuzz_2() { let low_key: [u8; 8] = scan_key.to_be_bytes(); let high_key: [u8; 8] = (scan_key + 253).to_be_bytes(); - let mut results = vec![(0, 0); 256]; + let mut results = vec![([0; 8], 0); 256]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 0); } @@ -279,7 +279,7 @@ fn fuzz_3() { let low_key: [u8; 8] = scan_key.to_be_bytes(); let high_key: [u8; 8] = (scan_key + 253).to_be_bytes(); - let mut results = vec![(0, 0); 256]; + let mut results = vec![([0; 8], 0); 256]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 2); @@ -308,7 +308,7 @@ fn fuzz_4() { let low_key: [u8; 8] = scan_key.to_be_bytes(); let high_key: [u8; 8] = (scan_key + 253).to_be_bytes(); - let mut results = vec![(0, 0); 256]; + let mut results = vec![([0; 8], 0); 256]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 0); } @@ -330,7 +330,7 @@ fn fuzz_5() { let low_key: [u8; 8] = scan_key.to_be_bytes(); let high_key: [u8; 8] = (scan_key + 255).to_be_bytes(); - let mut results = vec![(0, 0); 256]; + let mut results = vec![([0; 8], 0); 256]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 1); } @@ -352,7 +352,7 @@ fn fuzz_6() { let low_key: [u8; 8] = scan_key.to_be_bytes(); let high_key: [u8; 8] = (scan_key + 255).to_be_bytes(); - let mut results = vec![(0, 0); 256]; + let mut results = vec![([0; 8], 0); 256]; let scanned = tree.range(&low_key, &high_key, &mut results, &guard); assert_eq!(scanned, 1); } diff --git a/src/tree.rs b/src/tree.rs index 70867a1..e29ce94 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -337,7 +337,7 @@ impl RawCongee { &self, start: &[u8; K_LEN], end: &[u8; K_LEN], - result: &mut [(usize, usize)], + result: &mut [([u8; K_LEN], usize)], _guard: &Guard, ) -> usize { let mut range_scan = RangeScan::new(start, end, result, self.root as *const BaseNode); diff --git a/src/utils.rs b/src/utils.rs index 6951d4c..dc43f75 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -108,6 +108,11 @@ impl KeyTracker { val.swap_bytes() } + pub(crate) fn get_key(&self) -> [u8; N] { + assert!(self.len == N); + self.data[..N].try_into().unwrap() + } + #[inline] pub(crate) fn append_prefix(node: NodePtr, key_tracker: &KeyTracker) -> KeyTracker { let mut cur_key = key_tracker.clone();