Skip to content

Commit

Permalink
skiplist: Add compare_insert (#976)
Browse files Browse the repository at this point in the history
Signed-off-by: Tianion <[email protected]>
  • Loading branch information
Tianion authored Aug 26, 2023
1 parent 518c0b9 commit 0720460
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 7 deletions.
33 changes: 26 additions & 7 deletions crossbeam-skiplist/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ where

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist.
pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
self.insert_internal(key, || value, false, guard)
self.insert_internal(key, || value, |_| false, guard)
}

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
Expand All @@ -486,7 +486,7 @@ where
where
F: FnOnce() -> V,
{
self.insert_internal(key, value, false, guard)
self.insert_internal(key, value, |_| false, guard)
}

/// Returns an iterator over all entries in the skip list.
Expand Down Expand Up @@ -846,15 +846,16 @@ where
/// Inserts an entry with the specified `key` and `value`.
///
/// If `replace` is `true`, then any existing entry with this key will first be removed.
fn insert_internal<F>(
fn insert_internal<F, CompareF>(
&self,
key: K,
value: F,
replace: bool,
replace: CompareF,
guard: &Guard,
) -> RefEntry<'_, K, V>
where
F: FnOnce() -> V,
CompareF: Fn(&V) -> bool,
{
self.check_guard(guard);

Expand All @@ -874,7 +875,7 @@ where
Some(r) => r,
None => break,
};

let replace = replace(&r.value);
if replace {
// If a node with the key was found and we should replace it, mark its tower
// and then repeat the search.
Expand All @@ -896,7 +897,6 @@ where

// create value before creating node, so extra allocation doesn't happen if value() function panics
let value = value();

// Create a new node.
let height = self.random_height();
let (node, n) = {
Expand Down Expand Up @@ -945,6 +945,7 @@ where
}

if let Some(r) = search.found {
let replace = replace(&r.value);
if replace {
// If a node with the key was found and we should replace it, mark its
// tower and then repeat the search.
Expand Down Expand Up @@ -1082,7 +1083,25 @@ where
/// If there is an existing entry with this key, it will be removed before inserting the new
/// one.
pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
self.insert_internal(key, || value, true, guard)
self.insert_internal(key, || value, |_| true, guard)
}

/// Inserts a `key`-`value` pair into the skip list and returns the new entry.
///
/// If there is an existing entry with this key and compare(entry.value) returns true,
/// it will be removed before inserting the new one.
/// The closure will not be called if the key is not present.
pub fn compare_insert<F>(
&self,
key: K,
value: V,
compare_fn: F,
guard: &Guard,
) -> RefEntry<'_, K, V>
where
F: Fn(&V) -> bool,
{
self.insert_internal(key, || value, compare_fn, guard)
}

/// Removes an entry with the specified `key` from the map and returns it.
Expand Down
30 changes: 30 additions & 0 deletions crossbeam-skiplist/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,36 @@ where
Entry::new(self.inner.insert(key, value, guard))
}

/// Inserts a `key`-`value` pair into the skip list and returns the new entry.
///
/// If there is an existing entry with this key and compare(entry.value) returns true,
/// it will be removed before inserting the new one.
/// The closure will not be called if the key is not present.
///
/// This function returns an [`Entry`] which
/// can be used to access the inserted key's associated value.
///
/// # Example
/// ```
/// use crossbeam_skiplist::SkipMap;
///
/// let map = SkipMap::new();
/// map.insert("key", 1);
/// map.compare_insert("key", 0, |x| x < &0);
/// assert_eq!(*map.get("key").unwrap().value(), 1);
/// map.compare_insert("key", 2, |x| x < &2);
/// assert_eq!(*map.get("key").unwrap().value(), 2);
/// map.compare_insert("absent_key", 0, |_| false);
/// assert_eq!(*map.get("absent_key").unwrap().value(), 0);
/// ```
pub fn compare_insert<F>(&self, key: K, value: V, compare_fn: F) -> Entry<'_, K, V>
where
F: Fn(&V) -> bool,
{
let guard = &epoch::pin();
Entry::new(self.inner.compare_insert(key, value, compare_fn, guard))
}

/// Removes an entry with the specified `key` from the map and returns it.
///
/// The value will not actually be dropped until all references to it have gone
Expand Down
61 changes: 61 additions & 0 deletions crossbeam-skiplist/tests/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,48 @@ fn insert() {
}
}

#[test]
fn compare_and_insert() {
let insert = [0, 4, 2, 12, 8, 7, 11, 5];
let not_present = [1, 3, 6, 9, 10];
let s = SkipMap::new();

for &x in &insert {
s.insert(x, x * 10);
}

for &x in &insert {
let value = x * 5;
let old_value = *s.get(&x).unwrap().value();
s.compare_insert(x, value, |x| x < &value);
assert_eq!(*s.get(&x).unwrap().value(), old_value);
}

for &x in &insert {
let value = x * 15;
s.compare_insert(x, value, |x| x < &value);
assert_eq!(*s.get(&x).unwrap().value(), value);
}

for &x in &not_present {
assert!(s.get(&x).is_none());
}
}

#[test]
fn compare_insert_with_absent_key() {
let insert = [0, 4, 2, 12, 8, 7, 11, 5];
let s = SkipMap::new();

// The closure will not be called if the key is not present,
// so the key-value will be inserted into the map
for &x in &insert {
let value = x * 15;
s.compare_insert(x, value, |_| false);
assert_eq!(*s.get(&x).unwrap().value(), value);
}
}

#[test]
fn remove() {
let insert = [0, 4, 2, 12, 8, 7, 11, 5];
Expand Down Expand Up @@ -103,6 +145,25 @@ fn concurrent_insert() {
}
}

#[test]
fn concurrent_compare_and_insert() {
let set: SkipMap<i32, i32> = SkipMap::new();
let set = std::sync::Arc::new(set);
let len = 100;
let mut handlers = Vec::with_capacity(len as usize);
for i in 0..len {
let set = set.clone();
let handler = std::thread::spawn(move || {
set.compare_insert(1, i, |j| j < &i);
});
handlers.push(handler);
}
for handler in handlers {
handler.join().unwrap();
}
assert_eq!(*set.get(&1).unwrap().value(), len - 1);
}

// https://github.com/crossbeam-rs/crossbeam/issues/672
#[test]
fn concurrent_remove() {
Expand Down

0 comments on commit 0720460

Please sign in to comment.