Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Nov 21, 2024
1 parent 72c5fc2 commit ae058af
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 56 deletions.
8 changes: 8 additions & 0 deletions src/node_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ptr::NonNull;
use crate::{
base_node::{BaseNode, Node},
node_256::Node256,
utils::KeyTracker,
};

pub(crate) struct ChildIsPayload<'a> {
Expand Down Expand Up @@ -81,6 +82,13 @@ impl NodePtr {
PtrType::SubNode(unsafe { self.sub_node })
}
}

pub(crate) fn downcast_key_tracker<const K_LEN: usize>(
&self,
key_tracker: &KeyTracker<K_LEN>,
) -> PtrType {
self.downcast::<K_LEN>(key_tracker.len() - 1)
}
}

pub(crate) struct AllocatedNode<N: Node> {
Expand Down
131 changes: 76 additions & 55 deletions src/range_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub(crate) struct RangeScan<'a, const K_LEN: usize> {
end: &'a [u8; K_LEN],
result: &'a mut [([u8; K_LEN], usize)],
root: NonNull<Node256>,
to_continue: usize,
to_continue: bool,
result_found: usize,
}

Expand All @@ -33,7 +33,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
end,
result,
root,
to_continue: 0,
to_continue: false,
result_found: 0,
}
}
Expand All @@ -49,7 +49,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
pub(crate) fn scan(&mut self) -> Result<usize, ArtError> {
let mut node = BaseNode::read_lock_root(self.root)?;
let mut parent_node: Option<ReadGuard> = None;
self.to_continue = 0;
self.to_continue = false;
self.result_found = 0;

let mut key_tracker = KeyTracker::empty();
Expand Down Expand Up @@ -85,19 +85,24 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {

key_tracker.push(k);

if key_tracker.len() == K_LEN {
self.copy_node(n, &key_tracker)?;
} else if k == start_level {
self.find_start(n, &node, key_tracker.clone())?;
} else if k > start_level && k < end_level {
let cur_key = KeyTracker::append_prefix(n, &key_tracker);
self.copy_node(n, &cur_key)?;
} else if k == end_level {
self.find_end(n, &node, key_tracker.clone())?;
match n.downcast_key_tracker::<K_LEN>(&key_tracker) {
PtrType::Payload(payload) => {
self.write_result(payload, &key_tracker);
}
PtrType::SubNode(sub_node) => {
if k == start_level {
self.find_start(sub_node, &node, key_tracker.clone())?;
} else if k > start_level && k < end_level {
let cur_key = KeyTracker::append_prefix(n, &key_tracker);
self.copy_node_recursive(n, &cur_key)?;
} else if k == end_level {
self.find_end(sub_node, &node, key_tracker.clone())?;
}
}
}
key_tracker.pop();

if self.to_continue > 0 {
if self.to_continue {
return Ok(self.result_found);
}
}
Expand All @@ -110,7 +115,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
node.check_version()?;

if key_tracker.len() == (K_LEN - 1) {
self.copy_node(next_node_tmp, &key_tracker)?;
self.copy_node_recursive(next_node_tmp, &key_tracker)?;
return Ok(self.result_found);
}
key_tracker.push(start_level);
Expand All @@ -124,7 +129,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
return Ok(self.result_found);
}
PrefixCheckEqualsResult::Contained => {
self.copy_node(NodePtr::from_node(node.as_ref()), &key_tracker)?;
self.copy_node_recursive(NodePtr::from_node(node.as_ref()), &key_tracker)?;
return Ok(self.result_found);
}
PrefixCheckEqualsResult::NotMatch => {
Expand All @@ -136,13 +141,11 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {

fn find_end(
&mut self,
node: NodePtr,
node: NonNull<BaseNode>,
parent_node: &ReadGuard,
mut key_tracker: KeyTracker<K_LEN>,
) -> Result<(), ArtError> {
debug_assert!(key_tracker.len() != K_LEN);

let node = BaseNode::read_lock_deprecated::<K_LEN>(node, key_tracker.len())?;
let node = BaseNode::read_lock(node)?;
let prefix_result =
self.check_prefix_compare(node.as_ref(), self.end, 255, &mut key_tracker);
let level = key_tracker.len();
Expand All @@ -165,34 +168,40 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {

key_tracker.push(k);

if key_tracker.len() == K_LEN {
self.copy_node(n, &key_tracker)?;
} else if k == end_level {
self.find_end(n, &node, key_tracker.clone())?;
} else if k < end_level {
let cur_key = KeyTracker::append_prefix(n, &key_tracker);
self.copy_node(n, &cur_key)?;
}
match n.downcast_key_tracker::<K_LEN>(&key_tracker) {
PtrType::Payload(payload) => {
self.write_result(payload, &key_tracker);
}
PtrType::SubNode(sub_node) => {
if k == end_level {
self.find_end(sub_node, &node, key_tracker.clone())?;
} else if k < end_level {
let cur_key = KeyTracker::append_prefix(n, &key_tracker);
self.copy_node_recursive(n, &cur_key)?;
}
}
};

key_tracker.pop();
if self.to_continue != 0 {
if self.to_continue {
break;
}
}
Ok(())
}
cmp::Ordering::Less => self.copy_node(NodePtr::from_node(node.as_ref()), &key_tracker),
cmp::Ordering::Less => {
self.copy_node_recursive(NodePtr::from_node(node.as_ref()), &key_tracker)
}
}
}

fn find_start(
&mut self,
node: NodePtr,
node: NonNull<BaseNode>,
parent_node: &ReadGuard,
mut key_tracker: KeyTracker<K_LEN>,
) -> Result<(), ArtError> {
debug_assert!(key_tracker.len() != K_LEN);

let node = BaseNode::read_lock_deprecated::<K_LEN>(node, key_tracker.len())?;
let node = BaseNode::read_lock(node)?;
let prefix_result =
self.check_prefix_compare(node.as_ref(), self.start, 0, &mut key_tracker);

Expand All @@ -201,7 +210,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {

match prefix_result {
cmp::Ordering::Greater => {
self.copy_node(NodePtr::from_node(node.as_ref()), &key_tracker)
self.copy_node_recursive(NodePtr::from_node(node.as_ref()), &key_tracker)
}
cmp::Ordering::Equal => {
let start_level = if self.start.len() > key_tracker.len() {
Expand All @@ -216,16 +225,23 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
node.check_version()?;

key_tracker.push(k);
if key_tracker.len() == K_LEN {
self.copy_node(n, &key_tracker)?;
} else if k == start_level {
self.find_start(n, &node, key_tracker.clone())?;
} else if k > start_level {
let cur_key = KeyTracker::append_prefix(n, &key_tracker);
self.copy_node(n, &cur_key)?;

match n.downcast_key_tracker::<K_LEN>(&key_tracker) {
PtrType::Payload(payload) => {
self.write_result(payload, &key_tracker);
}
PtrType::SubNode(sub_node) => {
if k == start_level {
self.find_start(sub_node, &node, key_tracker.clone())?;
} else if k > start_level {
let cur_key = KeyTracker::append_prefix(n, &key_tracker);
self.copy_node_recursive(n, &cur_key)?;
}
}
}

key_tracker.pop();
if self.to_continue != 0 {
if self.to_continue {
break;
}
}
Expand All @@ -235,22 +251,27 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
}
}

fn copy_node(
fn write_result(&mut self, payload: usize, key_tracker: &KeyTracker<K_LEN>) {
let last_level_key = unsafe { key_tracker.as_last_level_unchecked() };
if self.key_in_range(&last_level_key) {
if self.result_found == self.result.len() {
self.to_continue = true;
return;
}
self.result[self.result_found] = (key_tracker.get_key(), payload);
self.result_found += 1;
}
}

/// Copy this node and all its children recursively.
fn copy_node_recursive(
&mut self,
node: NodePtr,
key_tracker: &KeyTracker<K_LEN>,
) -> Result<(), ArtError> {
match node.downcast::<K_LEN>(key_tracker.len() - 1) {
match node.downcast_key_tracker::<K_LEN>(key_tracker) {
PtrType::Payload(payload) => {
let last_level_key = unsafe { key_tracker.as_last_level_unchecked() };
if self.key_in_range(&last_level_key) {
if self.result_found == self.result.len() {
self.to_continue = payload;
return Ok(());
}
self.result[self.result_found] = (key_tracker.get_key(), payload);
self.result_found += 1;
};
self.write_result(payload, key_tracker);
}
PtrType::SubNode(sub_node) => {
let node = BaseNode::read_lock(sub_node)?;
Expand All @@ -264,9 +285,9 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
key_tracker.push(k);

let cur_key = KeyTracker::append_prefix(c, &key_tracker);
self.copy_node(c, &cur_key)?;
self.copy_node_recursive(c, &cur_key)?;

if self.to_continue != 0 {
if self.to_continue {
break;
}

Expand Down
1 change: 0 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ impl<const K_LEN: usize> KeyTracker<K_LEN> {
LastLevelKey { key: self }
}


pub(crate) fn get_key<const N: usize>(&self) -> [u8; N] {
assert!(self.len == N);
self.data[..N].try_into().unwrap()
Expand Down

0 comments on commit ae058af

Please sign in to comment.