From f4ac8c500799000f89275ad549e70872ef07ac16 Mon Sep 17 00:00:00 2001 From: Troy Benson Date: Mon, 16 Oct 2023 18:38:41 +0000 Subject: [PATCH] fix: unsound cache upcast Removes the unsafe code around cache upcasting by introducing traited functions `as_ref` and `as_mut` on the AutoImplCacheRef trait. Also no longer requires the `unsafe` attribute. --- common/src/dataloader/cache.rs | 108 ++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 29 deletions(-) diff --git a/common/src/dataloader/cache.rs b/common/src/dataloader/cache.rs index e1c447a94..f466d036f 100644 --- a/common/src/dataloader/cache.rs +++ b/common/src/dataloader/cache.rs @@ -1,4 +1,7 @@ -use std::collections::hash_map::RandomState; +use std::{ + collections::hash_map::RandomState, + ops::{Deref, DerefMut}, +}; use crate::dataloader::LoaderOutput; @@ -16,62 +19,70 @@ pub trait Cache, S = RandomState> { } } -/// # Safety -/// -/// This trait is marked as unsafe because the implementor must ensure that the Cache is safe for concurrent access. -/// This will almost always be with some kind of interior mutability. Such as a `RwLock` or `Mutex`. Or if the cache performs no-ops on mutation. -/// Look at `SharedCache` for an example. -pub unsafe trait AutoImplCacheRef, S = RandomState>: - AutoImplCacheMutRef -{ +#[repr(transparent)] +pub struct EmptyDerefMut(T); + +impl Deref for EmptyDerefMut { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for EmptyDerefMut { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } -#[inline(always)] -#[allow(clippy::mut_from_ref)] -#[allow(invalid_reference_casting)] -fn upcast, L: Loader, S>(cache: &T) -> impl Cache { - // Safety: - // This is safe because the trait `AutoImplCacheRef` is marked as unsafe and therefore the implementor must ensure that the - // Cache is safe for concurrent access. This is used to implement cache for &T where T is a Cache. - // The issue is we need to upcast the reference to a mutable reference, even though the implementor only ever requires a reference. - // This is not safe unless T has some kind of interior mutability. - unsafe { &mut *(cache as *const T as *mut T) } +pub trait AutoImplCacheRef, S = RandomState> { + type Cache: Cache; + type Ref<'a>: Deref + 'a + where + Self: 'a; + type MutRef<'a>: DerefMut + 'a + where + Self: 'a; + + fn as_ref(&self) -> Self::Ref<'_>; + fn as_mut(&self) -> Self::MutRef<'_>; } impl, S, T: AutoImplCacheRef> Cache for &T { #[inline(always)] fn contains_key(&self, key: &L::Key) -> bool { - (**self).contains_key(key) + self.as_ref().contains_key(key) } #[inline(always)] fn get(&self, key: &L::Key) -> Option { - (**self).get(key) + self.as_ref().get(key) } #[inline(always)] fn insert(&mut self, key: &L::Key, value: &L::Value) { - upcast(*self).insert(key, value) + self.as_mut().insert(key, value) } #[inline(always)] fn clear(&mut self) { - upcast(*self).clear() + self.as_mut().clear() } #[inline(always)] fn len(&self) -> usize { - (**self).len() + self.as_ref().len() } #[inline(always)] fn delete(&mut self, key: &L::Key) -> Option { - upcast(*self).delete(key) + self.as_mut().delete(key) } #[inline(always)] fn is_empty(&self) -> bool { - (**self).is_empty() + self.as_ref().is_empty() } } @@ -118,7 +129,21 @@ impl, S, T: AutoImplCacheMutRef> Cache for &mut T { pub struct NoCache; /// Safety: The no cache is always for safe for concurrent access because it is a no-op cache and therefore has no internal state. -unsafe impl, S> AutoImplCacheRef for NoCache {} +impl, S> AutoImplCacheRef for NoCache { + type Cache = Self; + type Ref<'a> = EmptyDerefMut; + type MutRef<'a> = EmptyDerefMut; + + #[inline(always)] + fn as_mut(&self) -> Self::MutRef<'_> { + EmptyDerefMut(*self) + } + + #[inline(always)] + fn as_ref(&self) -> Self::Ref<'_> { + EmptyDerefMut(*self) + } +} impl, S> AutoImplCacheMutRef for NoCache {} @@ -234,8 +259,23 @@ impl Default for SharedCache { } impl, L: Loader, S> AutoImplCacheMutRef for SharedCache {} -/// Safety: The shared cache is always for concurrent access because it contains a RwLock, which is a safe concurrent access primitive. -unsafe impl<'a, C: Cache + 'a, L: Loader, S> AutoImplCacheRef for SharedCache {} + +impl + 'static, L: Loader, S> AutoImplCacheRef for SharedCache { + type Cache = C; + + type Ref<'a> = std::sync::RwLockReadGuard<'a, C>; + type MutRef<'a> = std::sync::RwLockWriteGuard<'a, C>; + + #[inline(always)] + fn as_ref(&self) -> Self::Ref<'_> { + self.0.read().unwrap() + } + + #[inline(always)] + fn as_mut(&self) -> Self::MutRef<'_> { + self.0.write().unwrap() + } +} impl, L: Loader, S> Cache for SharedCache { #[inline(always)] @@ -289,6 +329,16 @@ const _: () = { } } + const fn assert_size_of() { + if std::mem::size_of::() != std::mem::size_of::>() { + panic!("T and EmptyDerefMut have different sizes") + } + } + + assert_size_of::(); + assert_size_of::>(); + assert_size_of::>(); + const fn assert_auto_impl_cache_ref, L: Loader, S>() {} assert_auto_impl_cache_ref::();