diff --git a/src/set.rs b/src/set.rs index 0ade464dc..66c2d4833 100644 --- a/src/set.rs +++ b/src/set.rs @@ -7,7 +7,7 @@ use core::hash::{BuildHasher, Hash}; use core::iter::{Chain, FusedIterator}; use core::ops::{BitAnd, BitOr, BitXor, Sub}; -use super::map::{self, DefaultHashBuilder, HashMap, Keys}; +use super::map::{self, make_hash, DefaultHashBuilder, HashMap, Keys, RawEntryMut}; use crate::raw::{Allocator, Global, RawExtractIf}; // Future Optimization (FIXME!) @@ -955,6 +955,12 @@ where /// Inserts a value computed from `f` into the set if the given `value` is /// not present, then returns a reference to the value in the set. /// + /// # Panics + /// + /// Panics if the value from the function and the provided lookup value + /// are not equivalent or have different hashes. See [`Equivalent`] + /// and [`Hash`] for more information. + /// /// # Examples /// /// ``` @@ -969,6 +975,7 @@ where /// assert_eq!(value, pet); /// } /// assert_eq!(set.len(), 4); // a new "fish" was inserted + /// assert!(set.contains("fish")); /// ``` #[cfg_attr(feature = "inline-more", inline)] pub fn get_or_insert_with(&mut self, value: &Q, f: F) -> &T @@ -976,13 +983,29 @@ where Q: Hash + Equivalent, F: FnOnce(&Q) -> T, { + #[cold] + #[inline(never)] + fn assert_failed() { + panic!( + "the value from the function and the lookup value \ + must be equivalent and have the same hash" + ); + } + // Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with // `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`. - self.map - .raw_entry_mut() - .from_key(value) - .or_insert_with(|| (f(value), ())) - .0 + let hash = make_hash::(&self.map.hash_builder, value); + let raw_entry_builder = self.map.raw_entry_mut(); + match raw_entry_builder.from_key_hashed_nocheck(hash, value) { + RawEntryMut::Occupied(entry) => entry.into_key(), + RawEntryMut::Vacant(entry) => { + let insert_value = f(value); + if !value.equivalent(&insert_value) { + assert_failed(); + } + entry.insert_hashed_nocheck(hash, insert_value, ()).0 + } + } } /// Gets the given value's corresponding entry in the set for in-place manipulation. @@ -2492,7 +2515,7 @@ fn assert_covariance() { #[cfg(test)] mod test_set { use super::super::map::DefaultHashBuilder; - use super::HashSet; + use super::{make_hash, Equivalent, HashSet}; use std::vec::Vec; #[test] @@ -2958,4 +2981,57 @@ mod test_set { // (and without the `map`, it does not). let mut _set: HashSet<_> = (0..3).map(|_| ()).collect(); } + + #[test] + fn duplicate_insert() { + let mut set = HashSet::new(); + set.insert(1); + set.get_or_insert_with(&1, |_| 1); + set.get_or_insert_with(&1, |_| 1); + assert!([1].iter().eq(set.iter())); + } + + #[test] + #[should_panic] + fn some_invalid_equivalent() { + use core::hash::{Hash, Hasher}; + struct Invalid { + count: u32, + other: u32, + } + + struct InvalidRef { + count: u32, + other: u32, + } + + impl PartialEq for Invalid { + fn eq(&self, other: &Self) -> bool { + self.count == other.count && self.other == other.other + } + } + impl Eq for Invalid {} + + impl Equivalent for InvalidRef { + fn equivalent(&self, key: &Invalid) -> bool { + self.count == key.count && self.other == key.other + } + } + impl Hash for Invalid { + fn hash(&self, state: &mut H) { + self.count.hash(state); + } + } + impl Hash for InvalidRef { + fn hash(&self, state: &mut H) { + self.count.hash(state); + } + } + let mut set: HashSet = HashSet::new(); + let key = InvalidRef { count: 1, other: 1 }; + let value = Invalid { count: 1, other: 2 }; + if make_hash(set.hasher(), &key) == make_hash(set.hasher(), &value) { + set.get_or_insert_with(&key, |_| value); + } + } }