Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unsound cache upcast #136

Merged
merged 1 commit into from
Oct 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 79 additions & 30 deletions common/src/dataloader/cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::hash_map::RandomState;
use std::{
collections::hash_map::RandomState,
ops::{Deref, DerefMut},
};

use crate::dataloader::LoaderOutput;

Expand All @@ -16,62 +19,70 @@ pub trait Cache<L: Loader<S>, S = RandomState> {
}
}

/// # Safety
///
/// This trait is marked as unsafe because the implementor must ensure that the Cache is safe for concurrent access.
/// This will almost always be with some kind of interior mutability. Such as a `RwLock` or `Mutex`. Or if the cache performs no-ops on mutation.
/// Look at `SharedCache` for an example.
pub unsafe trait AutoImplCacheRef<L: Loader<S>, S = RandomState>:
AutoImplCacheMutRef<L, S>
{
#[repr(transparent)]
pub struct EmptyDerefMut<T>(T);

impl<T> Deref for EmptyDerefMut<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<T> DerefMut for EmptyDerefMut<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

#[inline(always)]
#[allow(clippy::mut_from_ref)]
#[allow(invalid_reference_casting)]
fn upcast<T: AutoImplCacheRef<L, S>, L: Loader<S>, S>(cache: &T) -> impl Cache<L, S> {
// Safety:
// This is safe because the trait `AutoImplCacheRef` is marked as unsafe and therefore the implementor must ensure that the
// Cache is safe for concurrent access. This is used to implement cache for &T where T is a Cache.
// The issue is we need to upcast the reference to a mutable reference, even though the implementor only ever requires a reference.
// This is not safe unless T has some kind of interior mutability.
unsafe { &mut *(cache as *const T as *mut T) }
pub trait AutoImplCacheRef<L: Loader<S>, S = RandomState> {
Comment on lines -33 to -38
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was actually nothing wrong with the way we used this. The issue would have been

struct CImpl {
   cache: Mutex<i32>
}

impl Cache {
  fn clear(&mut self) {
      *self.cache.get_mut() = 0;
  }
   ...
}

If the caller implemented the unsafe trait incorrectly it would be UB; now you might say well that is the point of the unsafe part. and you would be right, but i might as well catch it before they fuck up.

type Cache: Cache<L, S>;
type Ref<'a>: Deref<Target = Self::Cache> + 'a
where
Self: 'a;
type MutRef<'a>: DerefMut<Target = Self::Cache> + 'a
where
Self: 'a;

fn as_ref(&self) -> Self::Ref<'_>;
fn as_mut(&self) -> Self::MutRef<'_>;
}

impl<L: Loader<S>, S, T: AutoImplCacheRef<L, S>> Cache<L, S> for &T {
#[inline(always)]
fn contains_key(&self, key: &L::Key) -> bool {
(**self).contains_key(key)
self.as_ref().contains_key(key)
}

#[inline(always)]
fn get(&self, key: &L::Key) -> Option<L::Value> {
(**self).get(key)
self.as_ref().get(key)
}

#[inline(always)]
fn insert(&mut self, key: &L::Key, value: &L::Value) {
upcast(*self).insert(key, value)
self.as_mut().insert(key, value)
}

#[inline(always)]
fn clear(&mut self) {
upcast(*self).clear()
self.as_mut().clear()
}

#[inline(always)]
fn len(&self) -> usize {
(**self).len()
self.as_ref().len()
}

#[inline(always)]
fn delete(&mut self, key: &L::Key) -> Option<L::Value> {
upcast(*self).delete(key)
self.as_mut().delete(key)
}

#[inline(always)]
fn is_empty(&self) -> bool {
(**self).is_empty()
self.as_ref().is_empty()
}
}

Expand Down Expand Up @@ -117,8 +128,21 @@ impl<L: Loader<S>, S, T: AutoImplCacheMutRef<L, S>> Cache<L, S> for &mut T {
#[derive(Default, Clone, Debug, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NoCache;

/// Safety: The no cache is always for safe for concurrent access because it is a no-op cache and therefore has no internal state.
unsafe impl<L: Loader<S>, S> AutoImplCacheRef<L, S> for NoCache {}
impl<L: Loader<S>, S> AutoImplCacheRef<L, S> for NoCache {
type Cache = Self;
type Ref<'a> = EmptyDerefMut<Self>;
type MutRef<'a> = EmptyDerefMut<Self>;

#[inline(always)]
fn as_mut(&self) -> Self::MutRef<'_> {
EmptyDerefMut(*self)
}

#[inline(always)]
fn as_ref(&self) -> Self::Ref<'_> {
EmptyDerefMut(*self)
}
}

impl<L: Loader<S>, S> AutoImplCacheMutRef<L, S> for NoCache {}

Expand Down Expand Up @@ -234,8 +258,23 @@ impl<C: Default> Default for SharedCache<C> {
}

impl<C: Cache<L, S>, L: Loader<S>, S> AutoImplCacheMutRef<L, S> for SharedCache<C> {}
/// Safety: The shared cache is always for concurrent access because it contains a RwLock, which is a safe concurrent access primitive.
unsafe impl<'a, C: Cache<L, S> + 'a, L: Loader<S>, S> AutoImplCacheRef<L, S> for SharedCache<C> {}

impl<C: Cache<L, S> + 'static, L: Loader<S>, S> AutoImplCacheRef<L, S> for SharedCache<C> {
type Cache = C;

type Ref<'a> = std::sync::RwLockReadGuard<'a, C>;
type MutRef<'a> = std::sync::RwLockWriteGuard<'a, C>;

#[inline(always)]
fn as_ref(&self) -> Self::Ref<'_> {
self.0.read().unwrap()
}

#[inline(always)]
fn as_mut(&self) -> Self::MutRef<'_> {
self.0.write().unwrap()
}
}

impl<C: Cache<L, S>, L: Loader<S>, S> Cache<L, S> for SharedCache<C> {
#[inline(always)]
Expand Down Expand Up @@ -289,6 +328,16 @@ const _: () = {
}
}

const fn assert_size_of<T: Sized>() {
if std::mem::size_of::<T>() != std::mem::size_of::<EmptyDerefMut<T>>() {
panic!("T and EmptyDerefMut<T> have different sizes")
}
}

assert_size_of::<NoCache>();
assert_size_of::<HashMapCache<DummyLoader>>();
assert_size_of::<SharedCache<NoCache>>();

const fn assert_auto_impl_cache_ref<C: Cache<L, S>, L: Loader<S>, S>() {}

assert_auto_impl_cache_ref::<NoCache, DummyLoader, RandomState>();
Expand Down
Loading