
Commit

fix an old bug
XiangpengHao committed Nov 21, 2024
1 parent f095e5d commit e2c8bdf
Showing 7 changed files with 126 additions and 52 deletions.
2 changes: 2 additions & 0 deletions src/base_node.rs
@@ -92,12 +92,14 @@ impl Iterator for NodeIter<'_> {
}

#[repr(C)]
#[derive(Debug)]
pub(crate) struct BaseNode {
// 2b type | 60b version | 1b lock | 1b obsolete
pub(crate) type_version_lock_obsolete: AtomicUsize,
pub(crate) meta: NodeMeta,
}

#[derive(Debug)]
pub(crate) struct NodeMeta {
prefix_cnt: u32,
pub(crate) count: u16,
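Aside: the "2b type | 60b version | 1b lock | 1b obsolete" comment above describes how type_version_lock_obsolete packs four fields into a single word. A minimal decoding sketch, assuming a 64-bit usize with the fields laid out from the most significant bits down; the bit positions and the decode_lock_word helper are illustrative only, not part of the crate:

fn decode_lock_word(word: usize) -> (u8, usize, bool, bool) {
    let node_type = (word >> 62) as u8;               // top 2 bits: node type
    let version = (word >> 2) & ((1usize << 60) - 1); // next 60 bits: version counter
    let locked = word & 0b10 != 0;                    // bit 1: lock flag
    let obsolete = word & 0b01 != 0;                  // bit 0: obsolete flag
    (node_type, version, locked, obsolete)
}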
22 changes: 5 additions & 17 deletions src/node_16.rs
@@ -18,16 +18,10 @@ const _: () = assert!(std::mem::size_of::<Node16>() == 168);
const _: () = assert!(std::mem::align_of::<Node16>() == 8);

impl Node16 {
fn flip_sign(val: u8) -> u8 {
val ^ 128
}

fn get_insert_pos(&self, key: u8) -> usize {
let flipped = Self::flip_sign(key);

let mut pos = 0;
while pos < self.base.meta.count {
if self.keys[pos as usize] >= flipped {
if self.keys[pos as usize] >= key {
return pos as usize;
}
pos += 1;
@@ -37,12 +31,11 @@ impl Node16 {

fn get_child_pos(&self, key: u8) -> Option<usize> {
// TODO: xiangpeng check this code is being auto-vectorized
let target = Self::flip_sign(key);

self.keys
.iter()
.take(self.base.meta.count as usize)
.position(|k| *k == target)
.position(|k| *k == key)
}
}

@@ -59,7 +52,7 @@ impl Iterator for Node16Iter<'_> {
if self.start_pos > self.end_pos {
return None;
}
let key = Node16::flip_sign(self.node.keys[self.start_pos]);
let key = self.node.keys[self.start_pos];
let child = self.node.children[self.start_pos];
self.start_pos += 1;
Some((key, child))
@@ -110,10 +103,7 @@ impl Node for Node16 {

fn copy_to<N: Node>(&self, dst: &mut N) {
for i in 0..self.base.meta.count {
dst.insert(
Self::flip_sign(self.keys[i as usize]),
self.children[i as usize],
);
dst.insert(self.keys[i as usize], self.children[i as usize]);
}
}

@@ -127,8 +117,6 @@ impl Node for Node16 {

// Insert must keep keys sorted, is this necessary?
fn insert(&mut self, key: u8, node: NodePtr) {
let key_flipped = Self::flip_sign(key);

let pos = self.get_insert_pos(key);

if pos < self.base.meta.count as usize {
@@ -138,7 +126,7 @@
.copy_within(pos..self.base.meta.count as usize, pos + 1);
}

self.keys[pos] = key_flipped;
self.keys[pos] = key;
self.children[pos] = node;
self.base.meta.count += 1;

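For context, Node16::insert above keeps the key array sorted: get_insert_pos finds the first slot whose key is greater than or equal to the new key, and copy_within shifts the tail right before the key and child are written. The same pattern on a plain Vec<u8>, as a standalone sketch (the sorted_insert helper is illustrative only, not part of the crate):

fn sorted_insert(keys: &mut Vec<u8>, key: u8) {
    // First position whose key is >= the new key, or the end if none is larger.
    let pos = keys.iter().position(|k| *k >= key).unwrap_or(keys.len());
    // Vec::insert shifts the tail right, like copy_within on the fixed-size array.
    keys.insert(pos, key);
}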
14 changes: 13 additions & 1 deletion src/node_ptr.rs
@@ -1,4 +1,4 @@
use std::ptr::NonNull;
use std::{fmt::Debug, ptr::NonNull};

use crate::{
base_node::{BaseNode, Node},
@@ -40,6 +40,18 @@ pub(crate) union NodePtr {
sub_node: NonNull<BaseNode>,
}

impl Debug for NodePtr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
unsafe {
write!(
f,
"payload: {:?} or sub_node: {:?}",
self.payload, self.sub_node
)
}
}
}

impl NodePtr {
#[inline]
pub(crate) fn from_node(ptr: &BaseNode) -> Self {
29 changes: 12 additions & 17 deletions src/stats.rs
@@ -10,6 +10,7 @@ use crate::{
#[derive(Default, Debug, Clone)]
pub struct NodeStats {
levels: HashMap<usize, LevelStats>,
kv_pairs: usize,
}

impl Display for NodeStats {
@@ -56,18 +57,10 @@

writeln!(
f,
"Overall node count: {}, value count: {}",
"Overall node count: {}, entry count: {}",
node_count, value_count
)?;

let last_level = levels.last().unwrap();
writeln!(
f,
"Last level node: {}, value count: {}",
last_level.node_count(),
last_level.value_count(),
)?;

let load_factor = total_f / (node_count as f64);
if load_factor < 0.5 {
writeln!(f, "Load factor: {load_factor:.2} (too low)")?;
@@ -77,6 +70,8 @@

writeln!(f, "Active memory usage: {} Mb", memory_size / 1024 / 1024)?;

writeln!(f, "KV count: {}", self.kv_pairs)?;

Ok(())
}
}
@@ -130,15 +125,13 @@ struct StatsVisitor {
}

impl<const K_LEN: usize> CongeeVisitor<K_LEN> for StatsVisitor {
fn pre_visit_sub_node(&mut self, node: NonNull<BaseNode>) {
fn pre_visit_sub_node(&mut self, node: NonNull<BaseNode>, tree_level: usize) {
let node = BaseNode::read_lock(node).unwrap();
let tree_level = node.as_ref().prefix().len();

if !self.node_stats.levels.contains_key(&tree_level) {
self.node_stats
.levels
.insert(tree_level, LevelStats::new_level(tree_level));
}
self.node_stats
.levels
.entry(tree_level)
.or_insert_with(|| LevelStats::new_level(tree_level));

match node.as_ref().get_type() {
crate::base_node::NodeType::N4 => {
@@ -209,7 +202,9 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
};

self.dfs_visitor_slow(&mut visitor).unwrap();
let pin = crossbeam_epoch::pin();
visitor.node_stats.kv_pairs = self.value_count(&pin);

return visitor.node_stats;
visitor.node_stats
}
}
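The StatsVisitor change above swaps a contains_key check followed by insert for the HashMap entry API, which performs the lookup once. A minimal standalone example of that pattern (the bump_level helper and the u32 counter are simplifications for illustration):

use std::collections::HashMap;

fn bump_level(levels: &mut HashMap<usize, u32>, tree_level: usize) {
    // Insert a zero counter if this level is unseen, then increment it in place.
    *levels.entry(tree_level).or_insert(0) += 1;
}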
5 changes: 3 additions & 2 deletions src/tests/tree.rs
@@ -23,8 +23,9 @@ fn small_insert() {

#[test]
fn test_sparse_keys() {
let key_cnt = 30_000;
let tree = RawCongee::default();
use crate::utils::leak_check::LeakCheckAllocator;
let key_cnt = 100_000;
let tree = RawCongee::new(LeakCheckAllocator::new());
let mut keys = Vec::<usize>::with_capacity(key_cnt);

let guard = crossbeam_epoch::pin();
24 changes: 9 additions & 15 deletions src/tree.rs
@@ -36,8 +36,8 @@ impl<const K_LEN: usize> Default for RawCongee<K_LEN> {

pub(crate) trait CongeeVisitor<const K_LEN: usize> {
fn visit_payload(&mut self, _payload: usize) {}
fn pre_visit_sub_node(&mut self, _node: NonNull<BaseNode>) {}
fn post_visit_sub_node(&mut self, _node: NonNull<BaseNode>) {}
fn pre_visit_sub_node(&mut self, _node: NonNull<BaseNode>, _tree_level: usize) {}
fn post_visit_sub_node(&mut self, _node: NonNull<BaseNode>, _tree_level: usize) {}
}

struct DropVisitor<const K_LEN: usize, A: Allocator + Clone + Send> {
@@ -47,21 +47,17 @@ struct DropVisitor<const K_LEN: usize, A: Allocator + Clone + Send> {
impl<const K_LEN: usize, A: Allocator + Clone + Send> CongeeVisitor<K_LEN>
for DropVisitor<K_LEN, A>
{
fn visit_payload(&mut self, _payload: usize) {}
fn pre_visit_sub_node(&mut self, _node: NonNull<BaseNode>) {}
fn post_visit_sub_node(&mut self, node: NonNull<BaseNode>) {
fn post_visit_sub_node(&mut self, node: NonNull<BaseNode>, _tree_level: usize) {
unsafe {
BaseNode::drop_node(node, self.allocator.clone());
}
}
}

#[cfg(test)]
struct ValueCountVisitor<const K_LEN: usize> {
value_count: usize,
}

#[cfg(test)]
impl<const K_LEN: usize> CongeeVisitor<K_LEN> for ValueCountVisitor<K_LEN> {
fn visit_payload(&mut self, _payload: usize) {
self.value_count += 1;
@@ -164,40 +160,39 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
let first = PtrType::SubNode(unsafe {
std::mem::transmute::<NonNull<Node256>, NonNull<BaseNode>>(self.root)
});
self.recursive_dfs(first, visitor)?;
Self::recursive_dfs(first, 0, visitor)?;
Ok(())
}

fn recursive_dfs<V: CongeeVisitor<K_LEN>>(
&self,
node: PtrType,
tree_level: usize,
visitor: &mut V,
) -> Result<(), ArtError> {
match node {
PtrType::Payload(v) => {
visitor.visit_payload(v);
}
PtrType::SubNode(node_ptr) => {
visitor.pre_visit_sub_node(node_ptr);
visitor.pre_visit_sub_node(node_ptr, tree_level);
let node_lock = BaseNode::read_lock(node_ptr)?;
let children = node_lock.as_ref().get_children(0, 255);
for (_k, child_ptr) in children {
let next = child_ptr.downcast::<K_LEN>(node_lock.as_ref().prefix().len());
self.recursive_dfs(next, visitor)?;
Self::recursive_dfs(next, tree_level + 1, visitor)?;
}
visitor.post_visit_sub_node(node_ptr);
node_lock.check_version()?;
visitor.post_visit_sub_node(node_ptr, tree_level);
}
}
Ok(())
}

/// Returns the number of values in the tree.
#[cfg(test)]
pub(crate) fn value_count(&self, _guard: &Guard) -> usize {
loop {
let mut visitor = ValueCountVisitor::<K_LEN> { value_count: 0 };
if let Ok(_) = self.dfs_visitor_slow(&mut visitor) {
if self.dfs_visitor_slow(&mut visitor).is_ok() {
return visitor.value_count;
}
}
@@ -307,7 +302,6 @@ impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
let mut write_n = node.upgrade().map_err(|(_n, v)| v)?;

// 1) Create new node which will be parent of node, Set common prefix, level to this node
// let prefix_len = write_n.as_ref().prefix().len();
let mut new_middle_node = BaseNode::make_node::<Node4>(
write_n.as_ref().prefix()[0..next_level].as_ref(),
&self.allocator,
82 changes: 82 additions & 0 deletions src/utils.rs
@@ -147,3 +147,85 @@ impl<const K_LEN: usize> KeyTracker<K_LEN> {
self.len
}
}

#[cfg(test)]
pub(crate) mod leak_check {
use super::*;

use crate::error::OOMError;
use crate::{Allocator, DefaultAllocator};
use std::collections::HashSet;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};

struct LeakCheckAllocatorInner {
allocated: Mutex<HashSet<NonNull<BaseNode>>>,
inner: DefaultAllocator,
}

unsafe impl Send for LeakCheckAllocatorInner {}
unsafe impl Sync for LeakCheckAllocatorInner {}

impl LeakCheckAllocatorInner {
pub fn new() -> Self {
Self {
allocated: Mutex::new(HashSet::new()),
inner: DefaultAllocator {},
}
}
}

impl Drop for LeakCheckAllocatorInner {
fn drop(&mut self) {
let allocated = self.allocated.lock().unwrap();

println!("Memory leak detected, leaked: {:?}", allocated.len());
for ptr in allocated.iter() {
let node = BaseNode::read_lock(ptr.clone()).unwrap();
println!("Ptr address: {:?}", ptr);
println!("{:?}", node.as_ref());
for (k, v) in node.as_ref().get_children(0, 255) {
println!("{:?} {:?}", k, v);
}
}
panic!("Memory leak detected, see above for details!");
}
}

#[derive(Clone)]
pub(crate) struct LeakCheckAllocator {
inner: Arc<LeakCheckAllocatorInner>,
}

impl LeakCheckAllocator {
pub fn new() -> Self {
Self {
inner: Arc::new(LeakCheckAllocatorInner::new()),
}
}
}

impl Allocator for LeakCheckAllocator {
fn allocate(
&self,
layout: std::alloc::Layout,
) -> Result<std::ptr::NonNull<[u8]>, OOMError> {
let ptr = self.inner.inner.allocate(layout)?;
self.inner
.allocated
.lock()
.unwrap()
.insert(NonNull::new(ptr.as_ptr() as *mut BaseNode).unwrap());
Ok(ptr)
}

unsafe fn deallocate(&self, ptr: std::ptr::NonNull<u8>, layout: std::alloc::Layout) {
self.inner
.allocated
.lock()
.unwrap()
.remove(&NonNull::new(ptr.as_ptr() as *mut BaseNode).unwrap());
self.inner.inner.deallocate(ptr, layout);
}
}
}
