Skip to content

Commit

Permalink
fix: unsound cache upcast
Browse files Browse the repository at this point in the history
Removes the unsafe code around cache upcasting by introducing the trait
methods `as_ref` and `as_mut` on the `AutoImplCacheRef` trait. The trait
itself also no longer needs to be declared `unsafe`.
  • Loading branch information
TroyKomodo committed Oct 16, 2023
1 parent d653e4a commit 2f1afdb
Showing 1 changed file with 75 additions and 29 deletions.
104 changes: 75 additions & 29 deletions common/src/dataloader/cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::hash_map::RandomState;
use std::{
collections::hash_map::RandomState,
ops::{Deref, DerefMut},
};

use crate::dataloader::LoaderOutput;

Expand All @@ -16,62 +19,70 @@ pub trait Cache<L: Loader<S>, S = RandomState> {
}
}

/// # Safety
///
/// This trait is marked as unsafe because the implementor must ensure that the Cache is safe for concurrent access.
/// This will almost always be with some kind of interior mutability. Such as a `RwLock` or `Mutex`. Or if the cache performs no-ops on mutation.
/// Look at `SharedCache` for an example.
pub unsafe trait AutoImplCacheRef<L: Loader<S>, S = RandomState>:
AutoImplCacheMutRef<L, S>
{
/// Transparent single-field wrapper that exposes the wrapped value through
/// both [`Deref`] and [`DerefMut`].
///
/// It lets a plain owned value stand in wherever a guard-like type that
/// dereferences to `T` is expected. `#[repr(transparent)]` keeps the layout
/// identical to that of `T`.
#[repr(transparent)]
pub struct EmptyDerefMut<T>(T);

impl<T> Deref for EmptyDerefMut<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> DerefMut for EmptyDerefMut<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }
}

#[inline(always)]
#[allow(clippy::mut_from_ref)]
#[allow(invalid_reference_casting)]
fn upcast<T: AutoImplCacheRef<L, S>, L: Loader<S>, S>(cache: &T) -> impl Cache<L, S> {
// Safety:
// This is safe because the trait `AutoImplCacheRef` is marked as unsafe and therefore the implementor must ensure that the
// Cache is safe for concurrent access. This is used to implement cache for &T where T is a Cache.
// The issue is we need to upcast the reference to a mutable reference, even though the implementor only ever requires a reference.
// This is not safe unless T has some kind of interior mutability.
unsafe { &mut *(cache as *const T as *mut T) }
/// Opt-in trait that gives `&T` a [`Cache`] implementation.
///
/// The implementor hands out guard-like values that dereference to the real
/// cache. Both accessors take `&self`, so any mutable access has to be produced
/// from a shared reference by the implementor itself (for example by locking,
/// or by returning a fresh value that carries no shared state).
pub trait AutoImplCacheRef<L: Loader<S>, S = RandomState> {
/// The cache that the guards dereference to.
type Cache: Cache<L, S>;
/// Guard returned by [`as_ref`](Self::as_ref); gives shared access to `Self::Cache`.
type Ref<'a>: Deref<Target = Self::Cache> + 'a
where
Self: 'a;
/// Guard returned by [`as_mut`](Self::as_mut); gives mutable access to `Self::Cache`.
type MutRef<'a>: DerefMut<Target = Self::Cache> + 'a
where
Self: 'a;

/// Returns a guard for read-only cache operations.
fn as_ref(&self) -> Self::Ref<'_>;
/// Returns a guard for mutating cache operations. Note the `&self` receiver.
fn as_mut(&self) -> Self::MutRef<'_>;
}

// `Cache` for shared references to any `T: AutoImplCacheRef`.
//
// Read-only methods (`contains_key`, `get`, `len`, `is_empty`) forward to the
// guard from `as_ref()`; mutating methods (`insert`, `clear`, `delete`) forward
// to the guard from `as_mut()`.
//
// NOTE(review): this listing is a rendered diff without +/- markers. In every
// method body the first line (`(**self)…` or `upcast(*self)…`) is the removed
// pre-commit expression and the second line (`self.as_ref()…` or
// `self.as_mut()…`) is its replacement; only the second belongs to the new code.
impl<L: Loader<S>, S, T: AutoImplCacheRef<L, S>> Cache<L, S> for &T {
#[inline(always)]
fn contains_key(&self, key: &L::Key) -> bool {
(**self).contains_key(key)
self.as_ref().contains_key(key)
}

#[inline(always)]
fn get(&self, key: &L::Key) -> Option<L::Value> {
(**self).get(key)
self.as_ref().get(key)
}

#[inline(always)]
fn insert(&mut self, key: &L::Key, value: &L::Value) {
upcast(*self).insert(key, value)
self.as_mut().insert(key, value)
}

#[inline(always)]
fn clear(&mut self) {
upcast(*self).clear()
self.as_mut().clear()
}

#[inline(always)]
fn len(&self) -> usize {
(**self).len()
self.as_ref().len()
}

#[inline(always)]
fn delete(&mut self, key: &L::Key) -> Option<L::Value> {
upcast(*self).delete(key)
self.as_mut().delete(key)
}

#[inline(always)]
fn is_empty(&self) -> bool {
(**self).is_empty()
self.as_ref().is_empty()
}
}

Expand Down Expand Up @@ -118,7 +129,19 @@ impl<L: Loader<S>, S, T: AutoImplCacheMutRef<L, S>> Cache<L, S> for &mut T {
pub struct NoCache;

/// Safety: The no cache is always safe for concurrent access because it is a no-op cache and therefore has no internal state.
unsafe impl<L: Loader<S>, S> AutoImplCacheRef<L, S> for NoCache {}
/// `NoCache` acts as its own cache: both accessors return a fresh copy of the
/// value wrapped in [`EmptyDerefMut`], so nothing is shared between the
/// returned guards.
impl<L: Loader<S>, S> AutoImplCacheRef<L, S> for NoCache {
    type Cache = Self;
    type Ref<'a> = EmptyDerefMut<Self>;
    type MutRef<'a> = EmptyDerefMut<Self>;

    /// Wraps a copy of `self` for read access.
    fn as_ref(&self) -> Self::Ref<'_> {
        let copy = *self;
        EmptyDerefMut(copy)
    }

    /// Wraps a copy of `self` for mutable access.
    fn as_mut(&self) -> Self::MutRef<'_> {
        let copy = *self;
        EmptyDerefMut(copy)
    }
}

impl<L: Loader<S>, S> AutoImplCacheMutRef<L, S> for NoCache {}

Expand Down Expand Up @@ -234,8 +257,21 @@ impl<C: Default> Default for SharedCache<C> {
}

impl<C: Cache<L, S>, L: Loader<S>, S> AutoImplCacheMutRef<L, S> for SharedCache<C> {}
/// Safety: The shared cache is always for concurrent access because it contains a RwLock, which is a safe concurrent access primitive.
unsafe impl<'a, C: Cache<L, S> + 'a, L: Loader<S>, S> AutoImplCacheRef<L, S> for SharedCache<C> {}

/// Exposes the cache held by a `SharedCache` through lock guards, so that
/// `&SharedCache<C>` can be used as a cache.
///
/// The guard types borrow from `self`. Stating `where Self: 'a` on each
/// associated type (the same clause the trait declares) is what makes
/// `RwLockReadGuard<'a, C>` / `RwLockWriteGuard<'a, C>` well-formed, so `C`
/// does not need to be `'static`.
///
/// # Panics
///
/// Both accessors unwrap the lock result and therefore panic if the lock is
/// poisoned.
impl<C: Cache<L, S>, L: Loader<S>, S> AutoImplCacheRef<L, S> for SharedCache<C> {
    type Cache = C;

    type Ref<'a> = std::sync::RwLockReadGuard<'a, C>
    where
        Self: 'a;
    type MutRef<'a> = std::sync::RwLockWriteGuard<'a, C>
    where
        Self: 'a;

    /// Takes the read lock and returns its guard.
    fn as_ref(&self) -> Self::Ref<'_> {
        self.0.read().unwrap()
    }

    /// Takes the write lock and returns its guard.
    fn as_mut(&self) -> Self::MutRef<'_> {
        self.0.write().unwrap()
    }
}

impl<C: Cache<L, S>, L: Loader<S>, S> Cache<L, S> for SharedCache<C> {
#[inline(always)]
Expand Down Expand Up @@ -289,6 +325,16 @@ const _: () = {
}
}

/// Const check that wrapping `T` in `EmptyDerefMut` does not change its size.
///
/// Panics with a fixed message when the two sizes differ; in a const context
/// that becomes a compile error.
const fn assert_size_of<T: Sized>() {
    let bare = std::mem::size_of::<T>();
    let wrapped = std::mem::size_of::<EmptyDerefMut<T>>();
    assert!(bare == wrapped, "T and EmptyDerefMut<T> have different sizes");
}

// Each call panics (via `assert_size_of`) if wrapping the named type in
// `EmptyDerefMut` changes its size.
// NOTE(review): presumably evaluated at compile time inside the enclosing
// `const _: () = { … }` block shown in the hunk header — confirm.
assert_size_of::<NoCache>();
assert_size_of::<HashMapCache<DummyLoader>>();
assert_size_of::<SharedCache<NoCache>>();

// No-op whose generic bounds are the whole check: a call with concrete types
// only type-checks when `C: Cache<L, S>` and `L: Loader<S>` hold.
const fn assert_auto_impl_cache_ref<C: Cache<L, S>, L: Loader<S>, S>() {}

assert_auto_impl_cache_ref::<NoCache, DummyLoader, RandomState>();
Expand Down

0 comments on commit 2f1afdb

Please sign in to comment.