From ae058af2259996498d23b947f63de314d57f3a9b Mon Sep 17 00:00:00 2001 From: Xiangpeng Hao Date: Wed, 20 Nov 2024 18:08:04 -0600 Subject: [PATCH] refactor --- src/node_ptr.rs | 8 +++ src/range_scan.rs | 131 +++++++++++++++++++++++++++------------------- src/utils.rs | 1 - 3 files changed, 84 insertions(+), 56 deletions(-) diff --git a/src/node_ptr.rs b/src/node_ptr.rs index 3987882..4ca9c93 100644 --- a/src/node_ptr.rs +++ b/src/node_ptr.rs @@ -3,6 +3,7 @@ use std::ptr::NonNull; use crate::{ base_node::{BaseNode, Node}, node_256::Node256, + utils::KeyTracker, }; pub(crate) struct ChildIsPayload<'a> { @@ -81,6 +82,13 @@ impl NodePtr { PtrType::SubNode(unsafe { self.sub_node }) } } + + pub(crate) fn downcast_key_tracker( + &self, + key_tracker: &KeyTracker, + ) -> PtrType { + self.downcast::(key_tracker.len() - 1) + } } pub(crate) struct AllocatedNode { diff --git a/src/range_scan.rs b/src/range_scan.rs index beeeffb..c0301a7 100644 --- a/src/range_scan.rs +++ b/src/range_scan.rs @@ -17,7 +17,7 @@ pub(crate) struct RangeScan<'a, const K_LEN: usize> { end: &'a [u8; K_LEN], result: &'a mut [([u8; K_LEN], usize)], root: NonNull, - to_continue: usize, + to_continue: bool, result_found: usize, } @@ -33,7 +33,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { end, result, root, - to_continue: 0, + to_continue: false, result_found: 0, } } @@ -49,7 +49,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { pub(crate) fn scan(&mut self) -> Result { let mut node = BaseNode::read_lock_root(self.root)?; let mut parent_node: Option = None; - self.to_continue = 0; + self.to_continue = false; self.result_found = 0; let mut key_tracker = KeyTracker::empty(); @@ -85,19 +85,24 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { key_tracker.push(k); - if key_tracker.len() == K_LEN { - self.copy_node(n, &key_tracker)?; - } else if k == start_level { - self.find_start(n, &node, key_tracker.clone())?; - } else if k > start_level && k < end_level { - let cur_key = KeyTracker::append_prefix(n, &key_tracker); - self.copy_node(n, &cur_key)?; - } else if k == end_level { - self.find_end(n, &node, key_tracker.clone())?; + match n.downcast_key_tracker::(&key_tracker) { + PtrType::Payload(payload) => { + self.write_result(payload, &key_tracker); + } + PtrType::SubNode(sub_node) => { + if k == start_level { + self.find_start(sub_node, &node, key_tracker.clone())?; + } else if k > start_level && k < end_level { + let cur_key = KeyTracker::append_prefix(n, &key_tracker); + self.copy_node_recursive(n, &cur_key)?; + } else if k == end_level { + self.find_end(sub_node, &node, key_tracker.clone())?; + } + } } key_tracker.pop(); - if self.to_continue > 0 { + if self.to_continue { return Ok(self.result_found); } } @@ -110,7 +115,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { node.check_version()?; if key_tracker.len() == (K_LEN - 1) { - self.copy_node(next_node_tmp, &key_tracker)?; + self.copy_node_recursive(next_node_tmp, &key_tracker)?; return Ok(self.result_found); } key_tracker.push(start_level); @@ -124,7 +129,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { return Ok(self.result_found); } PrefixCheckEqualsResult::Contained => { - self.copy_node(NodePtr::from_node(node.as_ref()), &key_tracker)?; + self.copy_node_recursive(NodePtr::from_node(node.as_ref()), &key_tracker)?; return Ok(self.result_found); } PrefixCheckEqualsResult::NotMatch => { @@ -136,13 +141,11 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { fn find_end( &mut self, - node: NodePtr, + node: NonNull, parent_node: &ReadGuard, mut key_tracker: KeyTracker, ) -> Result<(), ArtError> { - debug_assert!(key_tracker.len() != K_LEN); - - let node = BaseNode::read_lock_deprecated::(node, key_tracker.len())?; + let node = BaseNode::read_lock(node)?; let prefix_result = self.check_prefix_compare(node.as_ref(), self.end, 255, &mut key_tracker); let level = key_tracker.len(); @@ -165,34 +168,40 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { key_tracker.push(k); - if key_tracker.len() == K_LEN { - self.copy_node(n, &key_tracker)?; - } else if k == end_level { - self.find_end(n, &node, key_tracker.clone())?; - } else if k < end_level { - let cur_key = KeyTracker::append_prefix(n, &key_tracker); - self.copy_node(n, &cur_key)?; - } + match n.downcast_key_tracker::(&key_tracker) { + PtrType::Payload(payload) => { + self.write_result(payload, &key_tracker); + } + PtrType::SubNode(sub_node) => { + if k == end_level { + self.find_end(sub_node, &node, key_tracker.clone())?; + } else if k < end_level { + let cur_key = KeyTracker::append_prefix(n, &key_tracker); + self.copy_node_recursive(n, &cur_key)?; + } + } + }; + key_tracker.pop(); - if self.to_continue != 0 { + if self.to_continue { break; } } Ok(()) } - cmp::Ordering::Less => self.copy_node(NodePtr::from_node(node.as_ref()), &key_tracker), + cmp::Ordering::Less => { + self.copy_node_recursive(NodePtr::from_node(node.as_ref()), &key_tracker) + } } } fn find_start( &mut self, - node: NodePtr, + node: NonNull, parent_node: &ReadGuard, mut key_tracker: KeyTracker, ) -> Result<(), ArtError> { - debug_assert!(key_tracker.len() != K_LEN); - - let node = BaseNode::read_lock_deprecated::(node, key_tracker.len())?; + let node = BaseNode::read_lock(node)?; let prefix_result = self.check_prefix_compare(node.as_ref(), self.start, 0, &mut key_tracker); @@ -201,7 +210,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { match prefix_result { cmp::Ordering::Greater => { - self.copy_node(NodePtr::from_node(node.as_ref()), &key_tracker) + self.copy_node_recursive(NodePtr::from_node(node.as_ref()), &key_tracker) } cmp::Ordering::Equal => { let start_level = if self.start.len() > key_tracker.len() { @@ -216,16 +225,23 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { node.check_version()?; key_tracker.push(k); - if key_tracker.len() == K_LEN { - self.copy_node(n, &key_tracker)?; - } else if k == start_level { - self.find_start(n, &node, key_tracker.clone())?; - } else if k > start_level { - let cur_key = KeyTracker::append_prefix(n, &key_tracker); - self.copy_node(n, &cur_key)?; + + match n.downcast_key_tracker::(&key_tracker) { + PtrType::Payload(payload) => { + self.write_result(payload, &key_tracker); + } + PtrType::SubNode(sub_node) => { + if k == start_level { + self.find_start(sub_node, &node, key_tracker.clone())?; + } else if k > start_level { + let cur_key = KeyTracker::append_prefix(n, &key_tracker); + self.copy_node_recursive(n, &cur_key)?; + } + } } + key_tracker.pop(); - if self.to_continue != 0 { + if self.to_continue { break; } } @@ -235,22 +251,27 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { } } - fn copy_node( + fn write_result(&mut self, payload: usize, key_tracker: &KeyTracker) { + let last_level_key = unsafe { key_tracker.as_last_level_unchecked() }; + if self.key_in_range(&last_level_key) { + if self.result_found == self.result.len() { + self.to_continue = true; + return; + } + self.result[self.result_found] = (key_tracker.get_key(), payload); + self.result_found += 1; + } + } + + /// Copy this node and all its children recursively. + fn copy_node_recursive( &mut self, node: NodePtr, key_tracker: &KeyTracker, ) -> Result<(), ArtError> { - match node.downcast::(key_tracker.len() - 1) { + match node.downcast_key_tracker::(key_tracker) { PtrType::Payload(payload) => { - let last_level_key = unsafe { key_tracker.as_last_level_unchecked() }; - if self.key_in_range(&last_level_key) { - if self.result_found == self.result.len() { - self.to_continue = payload; - return Ok(()); - } - self.result[self.result_found] = (key_tracker.get_key(), payload); - self.result_found += 1; - }; + self.write_result(payload, key_tracker); } PtrType::SubNode(sub_node) => { let node = BaseNode::read_lock(sub_node)?; @@ -264,9 +285,9 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { key_tracker.push(k); let cur_key = KeyTracker::append_prefix(c, &key_tracker); - self.copy_node(c, &cur_key)?; + self.copy_node_recursive(c, &cur_key)?; - if self.to_continue != 0 { + if self.to_continue { break; } diff --git a/src/utils.rs b/src/utils.rs index 5863d7b..a68a86e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -122,7 +122,6 @@ impl KeyTracker { LastLevelKey { key: self } } - pub(crate) fn get_key(&self) -> [u8; N] { assert!(self.len == N); self.data[..N].try_into().unwrap()