diff --git a/src/node_48.rs b/src/node_48.rs index 5720c57..4ea37b6 100644 --- a/src/node_48.rs +++ b/src/node_48.rs @@ -99,7 +99,7 @@ impl Node for Node48 { fn insert(&mut self, key: u8, node: NodePtr) { let pos = self.next_empty as usize; - self.next_empty = self.children[pos].as_payload() as u8; + self.next_empty = unsafe { self.children[pos].as_payload_unchecked() } as u8; debug_assert!(pos < 48); diff --git a/src/node_ptr.rs b/src/node_ptr.rs index 133fdd1..221dcfc 100644 --- a/src/node_ptr.rs +++ b/src/node_ptr.rs @@ -31,18 +31,17 @@ impl NodePtr { Self { sub_node: ptr } } - #[inline] pub(crate) fn from_payload(payload: usize) -> Self { Self { payload } } - #[inline] - pub(crate) fn as_payload(&self) -> usize { + pub(crate) unsafe fn as_payload_unchecked(&self) -> usize { unsafe { self.payload } } - pub(crate) fn as_payload_checked(&self, _proof: &mut LastLevelProof) -> usize { - unsafe { self.payload } + pub(crate) fn as_payload(&self, _proof: &LastLevelProof) -> usize { + // Safety: We have a proof that the node is at the last level + unsafe { self.as_payload_unchecked() } } pub(crate) fn as_ptr_safe( diff --git a/src/range_scan.rs b/src/range_scan.rs index ad22fb4..37bd659 100644 --- a/src/range_scan.rs +++ b/src/range_scan.rs @@ -1,5 +1,6 @@ use crate::error::ArtError; use crate::node_256::Node256; +use crate::node_ptr::LastLevelProof; use crate::{base_node::BaseNode, lock::ReadGuard, node_ptr::NodePtr, utils::KeyTracker}; use std::cmp; use std::ptr::NonNull; @@ -40,9 +41,9 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { self.start < self.end } - fn key_in_range(&self, key: &KeyTracker) -> bool { + fn key_in_range(&self, key: &KeyTracker, last_level_proof: &LastLevelProof) -> bool { debug_assert_eq!(key.len(), 8); - let cur_key = key.to_usize_key(); + let cur_key = key.to_usize_key(last_level_proof); let start_key = unsafe { *(self.start.as_ptr() as *const usize) }.swap_bytes(); let end_key = unsafe { *(self.end.as_ptr() as *const usize) }.swap_bytes(); @@ -242,13 +243,13 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { } fn copy_node(&mut self, node: NodePtr, key_tracker: &KeyTracker) -> Result<(), ArtError> { - if key_tracker.len() == K_LEN { - if self.key_in_range(key_tracker) { + if let Some(proof) = key_tracker.is_last_level::() { + if self.key_in_range(key_tracker, &proof) { if self.result_found == self.result.len() { - self.to_continue = node.as_payload(); + self.to_continue = node.as_payload(&proof); return Ok(()); } - self.result[self.result_found] = (key_tracker.get_key(), node.as_payload()); + self.result[self.result_found] = (key_tracker.get_key(), node.as_payload(&proof)); self.result_found += 1; }; } else { diff --git a/src/tree.rs b/src/tree.rs index 8894f09..3cfcdcd 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -87,8 +87,8 @@ impl RawCongee { let child_node = child_node?; - if let Some(mut proof) = Self::is_last_level(level) { - let tid = child_node.as_payload_checked(&mut proof); + if let Some(proof) = Self::is_last_level(level) { + let tid = child_node.as_payload(&proof); return Some(tid); } @@ -145,7 +145,7 @@ impl RawCongee { n } else { let new_leaf = { - if level == (K_LEN - 1) { + if let Some(_proof) = Self::is_last_level(level) { // last key, just insert the tid NodePtr::from_payload(tid_func(None)) } else { @@ -185,11 +185,15 @@ impl RawCongee { p.unlock()?; } - if level == (K_LEN - 1) { + if let Some(proof) = Self::is_last_level(level) { // At this point, the level must point to the last u8 of the key, // meaning that we are updating an existing value. - let old = node.as_ref().get_child(node_key).unwrap().as_payload(); + let old = node + .as_ref() + .get_child(node_key) + .unwrap() + .as_payload(&proof); let new = tid_func(Some(old)); if old == new { node.check_version()?; @@ -201,7 +205,7 @@ impl RawCongee { let old = write_n .as_mut() .change(node_key, NodePtr::from_payload(new)); - return Ok(Some(old.as_payload())); + return Ok(Some(old.as_payload(&proof))); } parent_node = Some(node); node = BaseNode::read_lock::(next_node_tmp, level)?; @@ -402,8 +406,8 @@ impl RawCongee { None => return Ok(None), }; - if level == (K_LEN - 1) { - let tid = child_node.as_payload(); + if let Some(proof) = Self::is_last_level(level) { + let tid = child_node.as_payload(&proof); let new_v = remapping_function(tid); match new_v { @@ -417,9 +421,9 @@ impl RawCongee { .as_mut() .change(k[level], NodePtr::from_payload(new_v)); - debug_assert_eq!(tid, old.as_payload()); + debug_assert_eq!(tid, old.as_payload(&proof)); - return Ok(Some((old.as_payload(), Some(new_v)))); + return Ok(Some((old.as_payload(&proof), Some(new_v)))); } None => { // new value is none, we need to delete this entry @@ -444,7 +448,7 @@ impl RawCongee { write_n.as_mut().remove(node_key); } - return Ok(Some((child_node.as_payload(), None))); + return Ok(Some((child_node.as_payload(&proof), None))); } } } @@ -518,22 +522,25 @@ impl RawCongee { key_tracker.push(k); - if key_tracker.len() == K_LEN { - let new_v = f(key_tracker.to_usize_key(), child_node.as_payload()); - if new_v == child_node.as_payload() { + if let Some(proof) = key_tracker.is_last_level::() { + let new_v = f( + key_tracker.to_usize_key(&proof), + child_node.as_payload(&proof), + ); + if new_v == child_node.as_payload(&proof) { // Don't acquire the lock if the value is not changed - return Ok(Some((key_tracker.to_usize_key(), new_v, new_v))); + return Ok(Some((key_tracker.to_usize_key(&proof), new_v, new_v))); } let mut write_n = node.upgrade().map_err(|(_n, v)| v)?; let old_v = write_n.as_mut().change(k, NodePtr::from_payload(new_v)); - debug_assert_eq!(old_v.as_payload(), child_node.as_payload()); + debug_assert_eq!(old_v.as_payload(&proof), child_node.as_payload(&proof)); return Ok(Some(( - key_tracker.to_usize_key(), - child_node.as_payload(), + key_tracker.to_usize_key(&proof), + child_node.as_payload(&proof), new_v, ))); } diff --git a/src/utils.rs b/src/utils.rs index 056db18..ca05f62 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ use crate::base_node::{BaseNode, MAX_KEY_LEN}; -use crate::node_ptr::NodePtr; +use crate::node_ptr::{LastLevelProof, NodePtr}; use core::cell::Cell; use core::fmt; @@ -101,9 +101,16 @@ impl KeyTracker { v } + pub(crate) fn is_last_level(&self) -> Option { + if self.len == K_LEN { + Some(LastLevelProof {}) + } else { + None + } + } + #[inline] - pub(crate) fn to_usize_key(&self) -> usize { - assert!(self.len == 8); + pub(crate) fn to_usize_key(&self, _last_level_proof: &LastLevelProof) -> usize { let val = unsafe { *((&self.data) as *const [u8; 8] as *const usize) }; val.swap_bytes() }