From 0e214c799c2e4f7be4fb1262b2279bca334cd743 Mon Sep 17 00:00:00 2001 From: Xiangpeng Hao Date: Mon, 18 Nov 2024 21:20:53 -0600 Subject: [PATCH] refactor locking --- src/base_node.rs | 30 +++++++++++++++++++----------- src/lock.rs | 43 +++++++++++++++++++------------------------ src/node_ptr.rs | 42 +++++++++++++++++++++++++++++++++++++++--- src/range_scan.rs | 8 ++++---- src/stats.rs | 4 ++-- src/tree.rs | 26 +++++++++++++++----------- src/utils.rs | 31 +------------------------------ 7 files changed, 99 insertions(+), 85 deletions(-) diff --git a/src/base_node.rs b/src/base_node.rs index 4f6bbcd..1519dca 100644 --- a/src/base_node.rs +++ b/src/base_node.rs @@ -8,13 +8,12 @@ use crossbeam_epoch::Guard; use crate::{ error::ArtError, - lock::{ConcreteReadGuard, ReadGuard}, + lock::{ReadGuard, TypedReadGuard}, node_16::{Node16, Node16Iter}, node_256::{Node256, Node256Iter}, node_4::{Node4, Node4Iter}, node_48::{Node48, Node48Iter}, - node_ptr::NodePtr, - utils::AllocatedNode, + node_ptr::{AllocatedNode, NodePtr}, Allocator, }; @@ -241,18 +240,27 @@ impl BaseNode { self.meta.node_type } - #[inline] - pub(crate) fn read_lock(&self) -> Result { - let version = self.type_version_lock_obsolete.load(Ordering::Acquire); - - // #[cfg(test)] - // crate::utils::fail_point(ArtError::Locked(version))?; + pub(crate) fn read_lock<'a>(node: *const BaseNode) -> Result, ArtError> { + let version = unsafe { &*node } + .type_version_lock_obsolete + .load(Ordering::Acquire); if Self::is_locked(version) || Self::is_obsolete(version) { return Err(ArtError::Locked); } - Ok(ReadGuard::new(version, self)) + Ok(ReadGuard::new(version, node)) + } + + pub(crate) fn read_lock_node_ptr<'a, const MAX_LEVEL: usize>( + node: NodePtr, + current_level: u32, + ) -> Result, ArtError> { + Self::read_lock(node.as_ptr_safe::(current_level as usize)) + } + + pub(crate) fn read_lock_typed<'a, T: Node>(node: *const T) -> Result, ArtError> { + Self::read_lock(node as *const BaseNode) } fn is_locked(version: usize) -> bool { @@ -276,7 +284,7 @@ impl BaseNode { } pub(crate) fn insert_grow( - n: ConcreteReadGuard, + n: TypedReadGuard, parent: (u8, Option), val: (u8, NodePtr), allocator: &A, diff --git a/src/lock.rs b/src/lock.rs index 84cc823..8a71a5a 100644 --- a/src/lock.rs +++ b/src/lock.rs @@ -1,28 +1,22 @@ -use std::{cell::UnsafeCell, sync::atomic::Ordering}; +use std::{marker::PhantomData, sync::atomic::Ordering}; use crate::{ base_node::{BaseNode, Node}, error::ArtError, }; -pub(crate) struct ConcreteReadGuard<'a, T: Node> { +pub(crate) struct TypedReadGuard<'a, T: Node> { version: usize, - node: &'a UnsafeCell, + node: *const T, + _pt_node: PhantomData<&'a T>, } -impl<'a, T: Node> ConcreteReadGuard<'a, T> { +impl<'a, T: Node> TypedReadGuard<'a, T> { pub(crate) fn as_ref(&self) -> &T { - unsafe { &*self.node.get() } + unsafe { &*self.node } } pub(crate) fn upgrade(self) -> Result, (Self, ArtError)> { - #[cfg(test)] - { - if crate::utils::fail_point(ArtError::VersionNotMatch).is_err() { - return Err((self, ArtError::VersionNotMatch)); - }; - } - let new_version = self.version + 0b10; match self .as_ref() @@ -35,7 +29,8 @@ impl<'a, T: Node> ConcreteReadGuard<'a, T> { Ordering::Relaxed, ) { Ok(_) => Ok(ConcreteWriteGuard { - node: unsafe { &mut *self.node.get() }, + // SAFETY: this is seems to be unsound, but we (1) acquired write lock, (2) has the right memory ordering. + node: unsafe { &mut *(self.node as *const T as *mut T) }, }), Err(_v) => Err((self, ArtError::VersionNotMatch)), } @@ -74,14 +69,16 @@ impl<'a, T: Node> Drop for ConcreteWriteGuard<'a, T> { pub(crate) struct ReadGuard<'a> { version: usize, - node: &'a UnsafeCell, + node: *const BaseNode, + _pt_node: PhantomData<&'a BaseNode>, } impl<'a> ReadGuard<'a> { - pub(crate) fn new(v: usize, node: &'a BaseNode) -> Self { + pub(crate) fn new(v: usize, node: *const BaseNode) -> Self { Self { version: v, - node: unsafe { &*(node as *const BaseNode as *const UnsafeCell) }, // todo: the caller should pass UnsafeCell instead + node, + _pt_node: PhantomData, } } @@ -91,9 +88,6 @@ impl<'a> ReadGuard<'a> { .type_version_lock_obsolete .load(Ordering::Acquire); - #[cfg(test)] - crate::utils::fail_point(ArtError::VersionNotMatch)?; - if v == self.version { Ok(v) } else { @@ -106,17 +100,18 @@ impl<'a> ReadGuard<'a> { } #[must_use] - pub(crate) fn into_concrete(self) -> ConcreteReadGuard<'a, T> { + pub(crate) fn into_concrete(self) -> TypedReadGuard<'a, T> { assert_eq!(self.as_ref().get_type(), T::get_type()); - ConcreteReadGuard { + TypedReadGuard { version: self.version, - node: unsafe { &*(self.node as *const UnsafeCell as *const UnsafeCell) }, + node: unsafe { &*(self.node as *const BaseNode as *const T) }, + _pt_node: PhantomData, } } pub(crate) fn as_ref(&self) -> &BaseNode { - unsafe { &*self.node.get() } + unsafe { &*self.node } } pub(crate) fn upgrade(self) -> Result, (Self, ArtError)> { @@ -138,7 +133,7 @@ impl<'a> ReadGuard<'a> { Ordering::Relaxed, ) { Ok(_) => Ok(WriteGuard { - node: unsafe { &mut *self.node.get() }, + node: unsafe { &mut *(self.node as *mut BaseNode) }, }), Err(_v) => Err((self, ArtError::VersionNotMatch)), } diff --git a/src/node_ptr.rs b/src/node_ptr.rs index aab5a7e..e25f13c 100644 --- a/src/node_ptr.rs +++ b/src/node_ptr.rs @@ -1,6 +1,6 @@ use std::ptr::NonNull; -use crate::base_node::BaseNode; +use crate::base_node::{BaseNode, Node}; #[derive(Clone, Copy)] pub(crate) union NodePtr { @@ -10,11 +10,11 @@ pub(crate) union NodePtr { impl NodePtr { #[inline] - pub(crate) fn from_node(ptr: *const BaseNode) -> Self { + pub(crate) fn from_node(ptr: &BaseNode) -> Self { Self { sub_node: ptr } } - pub(crate) fn from_node_new(ptr: NonNull) -> Self { + fn from_node_new(ptr: NonNull) -> Self { Self { sub_node: ptr.as_ptr(), } @@ -34,4 +34,40 @@ impl NodePtr { pub(crate) fn as_ptr(&self) -> *const BaseNode { unsafe { self.sub_node } } + + pub(crate) fn as_ptr_safe( + &self, + current_level: usize, + ) -> *const BaseNode { + debug_assert!(current_level < MAX_LEVEL); + unsafe { self.sub_node } + } +} + +pub(crate) struct AllocatedNode { + ptr: NonNull, +} + +impl AllocatedNode { + pub(crate) fn new(ptr: NonNull) -> Self { + Self { ptr } + } + + pub(crate) fn as_mut(&mut self) -> &mut N { + unsafe { self.ptr.as_mut() } + } + + pub(crate) fn into_note_ptr(self) -> NodePtr { + let ptr = self.ptr; + std::mem::forget(self); + unsafe { NodePtr::from_node_new(std::mem::transmute(ptr)) } + } +} + +impl Drop for AllocatedNode { + fn drop(&mut self) { + unsafe { + std::ptr::drop_in_place(self.ptr.as_mut()); + } + } } diff --git a/src/range_scan.rs b/src/range_scan.rs index e8de283..c2cbfbe 100644 --- a/src/range_scan.rs +++ b/src/range_scan.rs @@ -61,7 +61,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { let mut key_tracker = KeyTracker::default(); loop { - node = unsafe { &*next_node }.read_lock()?; + node = BaseNode::read_lock(next_node)?; let prefix_check_result = self.check_prefix_equals(node.as_ref(), &mut key_tracker); @@ -148,7 +148,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { ) -> Result<(), ArtError> { debug_assert!(key_tracker.len() != 8); - let node = unsafe { &*node.as_ptr() }.read_lock()?; + let node = BaseNode::read_lock_node_ptr::(node, key_tracker.len() as u32)?; let prefix_result = self.check_prefix_compare(node.as_ref(), self.end, 255, &mut key_tracker); let level = key_tracker.len(); @@ -198,7 +198,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { ) -> Result<(), ArtError> { debug_assert!(key_tracker.len() != 8); - let node = unsafe { &*node.as_ptr() }.read_lock()?; + let node = BaseNode::read_lock_node_ptr::(node, key_tracker.len() as u32)?; let prefix_result = self.check_prefix_compare(node.as_ref(), self.start, 0, &mut key_tracker); @@ -252,7 +252,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> { self.result_found += 1; }; } else { - let node = unsafe { &*node.as_ptr() }.read_lock()?; + let node = BaseNode::read_lock_node_ptr::(node, key_tracker.len() as u32)?; let mut key_tracker = key_tracker.clone(); let children = node.as_ref().get_children(0, 255); diff --git a/src/stats.rs b/src/stats.rs index 084b53a..d69cfe9 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -129,8 +129,8 @@ impl RawCongee { if key_level != (MAX_KEY_LEN - 1) { sub_nodes.push(( level + 1, - key_level + 1 + unsafe { &*n.as_ptr() }.prefix().len(), - n.as_ptr(), + key_level + 1 + unsafe { &*n.as_ptr_safe::(level) }.prefix().len(), + n.as_ptr_safe::(level), )); } } diff --git a/src/tree.rs b/src/tree.rs index 9030bf0..a2ca674 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -39,7 +39,10 @@ impl Drop for RawCongee { let children = unsafe { &*node }.get_children(0, 255); for (_k, n) in children { if level != (K_LEN - 1) { - sub_nodes.push((n.as_ptr(), unsafe { &*n.as_ptr() }.prefix().len())); + sub_nodes.push(( + n.as_ptr_safe::(level), + unsafe { &*n.as_ptr_safe::(level) }.prefix().len(), + )); } } unsafe { @@ -55,7 +58,7 @@ impl RawCongee { .expect("Can't allocate memory for root node!"); let root_ptr = root.into_note_ptr(); RawCongee { - root: root_ptr.as_ptr() as *const Node256, + root: root_ptr.as_ptr_safe::(0) as *const Node256, allocator, _pt_key: PhantomData, } @@ -68,7 +71,7 @@ impl RawCongee { 'outer: loop { let mut level = 0; - let mut node = if let Ok(v) = unsafe { &*self.root }.base().read_lock() { + let mut node = if let Ok(v) = BaseNode::read_lock_typed(self.root) { v } else { continue; @@ -94,7 +97,7 @@ impl RawCongee { level += 1; - node = if let Ok(n) = unsafe { &*child_node.as_ptr() }.read_lock() { + node = if let Ok(n) = BaseNode::read_lock_node_ptr::(child_node, level) { n } else { continue 'outer; @@ -123,7 +126,7 @@ impl RawCongee { loop { parent_key = node_key; - node = unsafe { &*next_node }.read_lock()?; + node = BaseNode::read_lock(next_node)?; let mut next_level = level; let res = self.check_prefix_not_match(node.as_ref(), k, &mut next_level); @@ -165,7 +168,8 @@ impl RawCongee { if level != (K_LEN - 1) as u32 { unsafe { BaseNode::drop_node( - new_leaf.as_ptr() as *mut BaseNode, + new_leaf.as_ptr_safe::(level as usize) + as *mut BaseNode, self.allocator.clone(), ); } @@ -196,7 +200,7 @@ impl RawCongee { let old = write_n.as_mut().change(node_key, NodePtr::from_tid(new)); return Ok(Some(old.as_tid())); } - next_node = next_node_tmp.as_ptr(); + next_node = next_node_tmp.as_ptr_safe::(level as usize); level += 1; } @@ -376,7 +380,7 @@ impl RawCongee { let mut parent: Option<(ReadGuard, u8)> = None; let mut node_key: u8; let mut level = 0; - let mut node = unsafe { &*self.root }.base().read_lock()?; + let mut node = BaseNode::read_lock_typed(self.root)?; loop { level = if let Some(v) = Self::check_prefix(node.as_ref(), k, level) { @@ -443,7 +447,7 @@ impl RawCongee { level += 1; parent = Some((node, node_key)); - node = unsafe { &*child_node.as_ptr() }.read_lock()?; + node = BaseNode::read_lock_node_ptr::(child_node, level)?; } } @@ -491,7 +495,7 @@ impl RawCongee { f: &mut impl FnMut(usize, usize) -> usize, _guard: &Guard, ) -> Result, ArtError> { - let mut node = unsafe { &*self.root }.base().read_lock()?; + let mut node = BaseNode::read_lock_typed(self.root)?; let mut key_tracker = crate::utils::KeyTracker::default(); @@ -530,7 +534,7 @@ impl RawCongee { ))); } - node = unsafe { &*child_node.as_ptr() }.read_lock()?; + node = BaseNode::read_lock_node_ptr::(child_node, key_tracker.len() as u32)?; } } } diff --git a/src/utils.rs b/src/utils.rs index 230dad1..dc43f75 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,8 +1,7 @@ -use crate::base_node::{Node, MAX_KEY_LEN}; +use crate::base_node::MAX_KEY_LEN; use crate::node_ptr::NodePtr; use core::cell::Cell; use core::fmt; -use std::ptr::NonNull; const SPIN_LIMIT: u32 = 6; const YIELD_LIMIT: u32 = 10; @@ -173,31 +172,3 @@ fn random(n: u32) -> u32 { }) .unwrap_or(0) } - -pub(crate) struct AllocatedNode { - ptr: NonNull, -} - -impl AllocatedNode { - pub(crate) fn new(ptr: NonNull) -> Self { - Self { ptr } - } - - pub(crate) fn as_mut(&mut self) -> &mut N { - unsafe { self.ptr.as_mut() } - } - - pub(crate) fn into_note_ptr(self) -> NodePtr { - let ptr = self.ptr; - std::mem::forget(self); - unsafe { NodePtr::from_node_new(std::mem::transmute(ptr)) } - } -} - -impl Drop for AllocatedNode { - fn drop(&mut self) { - unsafe { - std::ptr::drop_in_place(self.ptr.as_mut()); - } - } -}