Skip to content

Commit

Permalink
refactor locking
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Nov 19, 2024
1 parent 6cf0e02 commit 0e214c7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 85 deletions.
30 changes: 19 additions & 11 deletions src/base_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ use crossbeam_epoch::Guard;

use crate::{
error::ArtError,
lock::{ConcreteReadGuard, ReadGuard},
lock::{ReadGuard, TypedReadGuard},
node_16::{Node16, Node16Iter},
node_256::{Node256, Node256Iter},
node_4::{Node4, Node4Iter},
node_48::{Node48, Node48Iter},
node_ptr::NodePtr,
utils::AllocatedNode,
node_ptr::{AllocatedNode, NodePtr},
Allocator,
};

Expand Down Expand Up @@ -241,18 +240,27 @@ impl BaseNode {
self.meta.node_type
}

#[inline]
pub(crate) fn read_lock(&self) -> Result<ReadGuard, ArtError> {
let version = self.type_version_lock_obsolete.load(Ordering::Acquire);

// #[cfg(test)]
// crate::utils::fail_point(ArtError::Locked(version))?;
pub(crate) fn read_lock<'a>(node: *const BaseNode) -> Result<ReadGuard<'a>, ArtError> {
let version = unsafe { &*node }
.type_version_lock_obsolete
.load(Ordering::Acquire);

if Self::is_locked(version) || Self::is_obsolete(version) {
return Err(ArtError::Locked);
}

Ok(ReadGuard::new(version, self))
Ok(ReadGuard::new(version, node))
}

pub(crate) fn read_lock_node_ptr<'a, const MAX_LEVEL: usize>(
node: NodePtr,
current_level: u32,
) -> Result<ReadGuard<'a>, ArtError> {
Self::read_lock(node.as_ptr_safe::<MAX_LEVEL>(current_level as usize))
}

pub(crate) fn read_lock_typed<'a, T: Node>(node: *const T) -> Result<ReadGuard<'a>, ArtError> {
Self::read_lock(node as *const BaseNode)
}

fn is_locked(version: usize) -> bool {
Expand All @@ -276,7 +284,7 @@ impl BaseNode {
}

pub(crate) fn insert_grow<CurT: Node, BiggerT: Node, A: Allocator + Send + Clone + 'static>(
n: ConcreteReadGuard<CurT>,
n: TypedReadGuard<CurT>,
parent: (u8, Option<ReadGuard>),
val: (u8, NodePtr),
allocator: &A,
Expand Down
43 changes: 19 additions & 24 deletions src/lock.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
use std::{cell::UnsafeCell, sync::atomic::Ordering};
use std::{marker::PhantomData, sync::atomic::Ordering};

use crate::{
base_node::{BaseNode, Node},
error::ArtError,
};

pub(crate) struct ConcreteReadGuard<'a, T: Node> {
pub(crate) struct TypedReadGuard<'a, T: Node> {
version: usize,
node: &'a UnsafeCell<T>,
node: *const T,
_pt_node: PhantomData<&'a T>,
}

impl<'a, T: Node> ConcreteReadGuard<'a, T> {
impl<'a, T: Node> TypedReadGuard<'a, T> {
pub(crate) fn as_ref(&self) -> &T {
unsafe { &*self.node.get() }
unsafe { &*self.node }
}

pub(crate) fn upgrade(self) -> Result<ConcreteWriteGuard<'a, T>, (Self, ArtError)> {
#[cfg(test)]
{
if crate::utils::fail_point(ArtError::VersionNotMatch).is_err() {
return Err((self, ArtError::VersionNotMatch));
};
}

let new_version = self.version + 0b10;
match self
.as_ref()
Expand All @@ -35,7 +29,8 @@ impl<'a, T: Node> ConcreteReadGuard<'a, T> {
Ordering::Relaxed,
) {
Ok(_) => Ok(ConcreteWriteGuard {
node: unsafe { &mut *self.node.get() },
// SAFETY: this is seems to be unsound, but we (1) acquired write lock, (2) has the right memory ordering.
node: unsafe { &mut *(self.node as *const T as *mut T) },
}),
Err(_v) => Err((self, ArtError::VersionNotMatch)),
}
Expand Down Expand Up @@ -74,14 +69,16 @@ impl<'a, T: Node> Drop for ConcreteWriteGuard<'a, T> {

pub(crate) struct ReadGuard<'a> {
version: usize,
node: &'a UnsafeCell<BaseNode>,
node: *const BaseNode,
_pt_node: PhantomData<&'a BaseNode>,
}

impl<'a> ReadGuard<'a> {
pub(crate) fn new(v: usize, node: &'a BaseNode) -> Self {
pub(crate) fn new(v: usize, node: *const BaseNode) -> Self {
Self {
version: v,
node: unsafe { &*(node as *const BaseNode as *const UnsafeCell<BaseNode>) }, // todo: the caller should pass UnsafeCell<BaseNode> instead
node,
_pt_node: PhantomData,
}
}

Expand All @@ -91,9 +88,6 @@ impl<'a> ReadGuard<'a> {
.type_version_lock_obsolete
.load(Ordering::Acquire);

#[cfg(test)]
crate::utils::fail_point(ArtError::VersionNotMatch)?;

if v == self.version {
Ok(v)
} else {
Expand All @@ -106,17 +100,18 @@ impl<'a> ReadGuard<'a> {
}

#[must_use]
pub(crate) fn into_concrete<T: Node>(self) -> ConcreteReadGuard<'a, T> {
pub(crate) fn into_concrete<T: Node>(self) -> TypedReadGuard<'a, T> {
assert_eq!(self.as_ref().get_type(), T::get_type());

ConcreteReadGuard {
TypedReadGuard {
version: self.version,
node: unsafe { &*(self.node as *const UnsafeCell<BaseNode> as *const UnsafeCell<T>) },
node: unsafe { &*(self.node as *const BaseNode as *const T) },
_pt_node: PhantomData,
}
}

pub(crate) fn as_ref(&self) -> &BaseNode {
unsafe { &*self.node.get() }
unsafe { &*self.node }
}

pub(crate) fn upgrade(self) -> Result<WriteGuard<'a>, (Self, ArtError)> {
Expand All @@ -138,7 +133,7 @@ impl<'a> ReadGuard<'a> {
Ordering::Relaxed,
) {
Ok(_) => Ok(WriteGuard {
node: unsafe { &mut *self.node.get() },
node: unsafe { &mut *(self.node as *mut BaseNode) },
}),
Err(_v) => Err((self, ArtError::VersionNotMatch)),
}
Expand Down
42 changes: 39 additions & 3 deletions src/node_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ptr::NonNull;

use crate::base_node::BaseNode;
use crate::base_node::{BaseNode, Node};

#[derive(Clone, Copy)]
pub(crate) union NodePtr {
Expand All @@ -10,11 +10,11 @@ pub(crate) union NodePtr {

impl NodePtr {
#[inline]
pub(crate) fn from_node(ptr: *const BaseNode) -> Self {
pub(crate) fn from_node(ptr: &BaseNode) -> Self {
Self { sub_node: ptr }
}

pub(crate) fn from_node_new(ptr: NonNull<BaseNode>) -> Self {
fn from_node_new(ptr: NonNull<BaseNode>) -> Self {
Self {
sub_node: ptr.as_ptr(),
}
Expand All @@ -34,4 +34,40 @@ impl NodePtr {
pub(crate) fn as_ptr(&self) -> *const BaseNode {
unsafe { self.sub_node }
}

pub(crate) fn as_ptr_safe<const MAX_LEVEL: usize>(
&self,
current_level: usize,
) -> *const BaseNode {
debug_assert!(current_level < MAX_LEVEL);
unsafe { self.sub_node }
}
}

pub(crate) struct AllocatedNode<N: Node> {
ptr: NonNull<N>,
}

impl<N: Node> AllocatedNode<N> {
pub(crate) fn new(ptr: NonNull<N>) -> Self {
Self { ptr }
}

pub(crate) fn as_mut(&mut self) -> &mut N {
unsafe { self.ptr.as_mut() }
}

pub(crate) fn into_note_ptr(self) -> NodePtr {
let ptr = self.ptr;
std::mem::forget(self);
unsafe { NodePtr::from_node_new(std::mem::transmute(ptr)) }
}
}

impl<N: Node> Drop for AllocatedNode<N> {
fn drop(&mut self) {
unsafe {
std::ptr::drop_in_place(self.ptr.as_mut());
}
}
}
8 changes: 4 additions & 4 deletions src/range_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
let mut key_tracker = KeyTracker::default();

loop {
node = unsafe { &*next_node }.read_lock()?;
node = BaseNode::read_lock(next_node)?;

let prefix_check_result = self.check_prefix_equals(node.as_ref(), &mut key_tracker);

Expand Down Expand Up @@ -148,7 +148,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
) -> Result<(), ArtError> {
debug_assert!(key_tracker.len() != 8);

let node = unsafe { &*node.as_ptr() }.read_lock()?;
let node = BaseNode::read_lock_node_ptr::<K_LEN>(node, key_tracker.len() as u32)?;
let prefix_result =
self.check_prefix_compare(node.as_ref(), self.end, 255, &mut key_tracker);
let level = key_tracker.len();
Expand Down Expand Up @@ -198,7 +198,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
) -> Result<(), ArtError> {
debug_assert!(key_tracker.len() != 8);

let node = unsafe { &*node.as_ptr() }.read_lock()?;
let node = BaseNode::read_lock_node_ptr::<K_LEN>(node, key_tracker.len() as u32)?;
let prefix_result =
self.check_prefix_compare(node.as_ref(), self.start, 0, &mut key_tracker);

Expand Down Expand Up @@ -252,7 +252,7 @@ impl<'a, const K_LEN: usize> RangeScan<'a, K_LEN> {
self.result_found += 1;
};
} else {
let node = unsafe { &*node.as_ptr() }.read_lock()?;
let node = BaseNode::read_lock_node_ptr::<K_LEN>(node, key_tracker.len() as u32)?;
let mut key_tracker = key_tracker.clone();

let children = node.as_ref().get_children(0, 255);
Expand Down
4 changes: 2 additions & 2 deletions src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ impl<const K_LEN: usize, A: Allocator + Clone> RawCongee<K_LEN, A> {
if key_level != (MAX_KEY_LEN - 1) {
sub_nodes.push((
level + 1,
key_level + 1 + unsafe { &*n.as_ptr() }.prefix().len(),
n.as_ptr(),
key_level + 1 + unsafe { &*n.as_ptr_safe::<K_LEN>(level) }.prefix().len(),
n.as_ptr_safe::<K_LEN>(level),
));
}
}
Expand Down
26 changes: 15 additions & 11 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ impl<const K_LEN: usize, A: Allocator + Clone> Drop for RawCongee<K_LEN, A> {
let children = unsafe { &*node }.get_children(0, 255);
for (_k, n) in children {
if level != (K_LEN - 1) {
sub_nodes.push((n.as_ptr(), unsafe { &*n.as_ptr() }.prefix().len()));
sub_nodes.push((
n.as_ptr_safe::<K_LEN>(level),
unsafe { &*n.as_ptr_safe::<K_LEN>(level) }.prefix().len(),
));
}
}
unsafe {
Expand All @@ -55,7 +58,7 @@ impl<const K_LEN: usize, A: Allocator + Clone> RawCongee<K_LEN, A> {
.expect("Can't allocate memory for root node!");
let root_ptr = root.into_note_ptr();
RawCongee {
root: root_ptr.as_ptr() as *const Node256,
root: root_ptr.as_ptr_safe::<K_LEN>(0) as *const Node256,
allocator,
_pt_key: PhantomData,
}
Expand All @@ -68,7 +71,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
'outer: loop {
let mut level = 0;

let mut node = if let Ok(v) = unsafe { &*self.root }.base().read_lock() {
let mut node = if let Ok(v) = BaseNode::read_lock_typed(self.root) {
v
} else {
continue;
Expand All @@ -94,7 +97,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {

level += 1;

node = if let Ok(n) = unsafe { &*child_node.as_ptr() }.read_lock() {
node = if let Ok(n) = BaseNode::read_lock_node_ptr::<K_LEN>(child_node, level) {
n
} else {
continue 'outer;
Expand Down Expand Up @@ -123,7 +126,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {

loop {
parent_key = node_key;
node = unsafe { &*next_node }.read_lock()?;
node = BaseNode::read_lock(next_node)?;

let mut next_level = level;
let res = self.check_prefix_not_match(node.as_ref(), k, &mut next_level);
Expand Down Expand Up @@ -165,7 +168,8 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
if level != (K_LEN - 1) as u32 {
unsafe {
BaseNode::drop_node(
new_leaf.as_ptr() as *mut BaseNode,
new_leaf.as_ptr_safe::<K_LEN>(level as usize)
as *mut BaseNode,
self.allocator.clone(),
);
}
Expand Down Expand Up @@ -196,7 +200,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
let old = write_n.as_mut().change(node_key, NodePtr::from_tid(new));
return Ok(Some(old.as_tid()));
}
next_node = next_node_tmp.as_ptr();
next_node = next_node_tmp.as_ptr_safe::<K_LEN>(level as usize);
level += 1;
}

Expand Down Expand Up @@ -376,7 +380,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
let mut parent: Option<(ReadGuard, u8)> = None;
let mut node_key: u8;
let mut level = 0;
let mut node = unsafe { &*self.root }.base().read_lock()?;
let mut node = BaseNode::read_lock_typed(self.root)?;

loop {
level = if let Some(v) = Self::check_prefix(node.as_ref(), k, level) {
Expand Down Expand Up @@ -443,7 +447,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {

level += 1;
parent = Some((node, node_key));
node = unsafe { &*child_node.as_ptr() }.read_lock()?;
node = BaseNode::read_lock_node_ptr::<K_LEN>(child_node, level)?;
}
}

Expand Down Expand Up @@ -491,7 +495,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
f: &mut impl FnMut(usize, usize) -> usize,
_guard: &Guard,
) -> Result<Option<(usize, usize, usize)>, ArtError> {
let mut node = unsafe { &*self.root }.base().read_lock()?;
let mut node = BaseNode::read_lock_typed(self.root)?;

let mut key_tracker = crate::utils::KeyTracker::default();

Expand Down Expand Up @@ -530,7 +534,7 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
)));
}

node = unsafe { &*child_node.as_ptr() }.read_lock()?;
node = BaseNode::read_lock_node_ptr::<K_LEN>(child_node, key_tracker.len() as u32)?;
}
}
}
Loading

0 comments on commit 0e214c7

Please sign in to comment.