From 311a2599e10fb9f86c0182ae607faeb0988e6ffc Mon Sep 17 00:00:00 2001 From: Jack Wrenn Date: Thu, 17 Oct 2024 13:59:21 +0000 Subject: [PATCH] Implement `Ref::{try_as_ref,try_into_ref,try_into_mut}` Only `Ref::try_as_mut` remains missing, probably pending polonius landing in rustc. Partially fixes #1865 Supersedes #1184 --- src/error.rs | 7 ++ src/ref.rs | 311 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 314 insertions(+), 4 deletions(-) diff --git a/src/error.rs b/src/error.rs index e30b6895ce..a8c52aaf79 100644 --- a/src/error.rs +++ b/src/error.rs @@ -569,6 +569,13 @@ impl ValidityError { self.src } + pub(crate) fn with_src(self, new_src: NewSrc) -> ValidityError { + // INVARIANT: `with_src` doesn't change the type of `Dst`, so the + // invariant that `Dst`'s alignment requirement is greater than one is + // preserved. + ValidityError { src: new_src, dst: SendSyncPhantomData::default() } + } + /// Maps the source value associated with the conversion error. /// /// This can help mitigate [issues with `Send`, `Sync` and `'static` diff --git a/src/ref.rs b/src/ref.rs index 0f4ce00214..7a5f3c5d97 100644 --- a/src/ref.rs +++ b/src/ref.rs @@ -595,10 +595,85 @@ where } } +impl Ref +where + B: ByteSlice, + T: KnownLayout + ?Sized, +{ + /// Attempts to dereference this `Ref<_, T>` into a `&T` without copying. + /// + /// If the bytes of `self` are a valid instance of `T`, this method returns + /// a reference to those bytes interpreted as `T`. If those bytes are not a + /// valid instance of `T`, this returns `Err`. + /// + /// # Examples + /// + /// ``` + /// use zerocopy::Ref; + /// # use zerocopy_derive::*; + /// + /// // The only valid value of this type is the byte `0xC0` + /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[repr(u8)] + /// enum C0 { xC0 = 0xC0 } + /// + /// // The only valid value of this type is the bytes `0xC0C0`. + /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[repr(C, packed)] + /// struct C0C0(C0, C0); + /// + /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[repr(C)] + /// struct Packet { + /// magic_number: C0C0, + /// mug_size: u8, + /// temperature: u8, + /// marshmallows: [[u8; 2]], + /// } + /// + /// let bytes = &[0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5][..]; + /// + /// let r = Ref::<_, Packet>::new(bytes).unwrap(); + /// let packet = Ref::try_as_ref(&r).unwrap(); + /// + /// assert_eq!(packet.mug_size, 240); + /// assert_eq!(packet.temperature, 77); + /// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]); + /// ``` + #[must_use = "has no side effects"] + #[inline(always)] + pub fn try_as_ref(r: &Self) -> Result<&T, ValidityError<&Self, T>> + where + T: TryFromBytes + Immutable, + { + // Presumably unreachable, since we've guarded each constructor of `Ref`. + static_assert_dst_is_not_zst!(T); + + // SAFETY: We don't call any methods on `r` other than those provided by + // `ByteSlice`. + let b = unsafe { r.as_byte_slice() }; + + match Ptr::from_ref(b.deref()).try_cast_into_no_leftover::(None) { + Ok(candidate) => match candidate.try_into_valid() { + Ok(valid) => Ok(valid.as_ref()), + Err(e) => Err(e.map_src(|_| r)), + }, + Err(CastError::Validity(i)) => match i {}, + Err(CastError::Alignment(_) | CastError::Size(_)) => { + // SAFETY: By invariant on `Ref::0`, the referenced byte slice + // is aligned to `T`'s alignment and its size corresponds to a + // valid size for `T`. Since properties are checked upon + // constructing `Ref`, these failures are unreachable. + unsafe { core::hint::unreachable_unchecked() } + } + } + } +} + impl<'a, B, T> Ref where B: 'a + IntoByteSlice<'a>, - T: FromBytes + KnownLayout + Immutable + ?Sized, + T: TryFromBytes + KnownLayout + Immutable + ?Sized, { /// Converts this `Ref` into a reference. /// @@ -609,7 +684,10 @@ where /// there is no conflict with a method on the inner type. #[must_use = "has no side effects"] #[inline(always)] - pub fn into_ref(r: Self) -> &'a T { + pub fn into_ref(r: Self) -> &'a T + where + T: FromBytes, + { // Presumably unreachable, since we've guarded each constructor of `Ref`. static_assert_dst_is_not_zst!(T); @@ -627,12 +705,91 @@ where let ptr = ptr.bikeshed_recall_valid(); ptr.as_ref() } + + /// Attempts to convert this `Ref<_, T>` into a `&T` without copying. + /// + /// If the bytes of `self` are a valid instance of `T`, this method returns + /// a reference to those bytes interpreted as `T`. If those bytes are not a + /// valid instance of `T`, this returns `Err`. + /// + /// # Examples + /// + /// ``` + /// use zerocopy::Ref; + /// # use zerocopy_derive::*; + /// + /// // The only valid value of this type is the byte `0xC0` + /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[repr(u8)] + /// enum C0 { xC0 = 0xC0 } + /// + /// // The only valid value of this type is the bytes `0xC0C0`. + /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[repr(C, packed)] + /// struct C0C0(C0, C0); + /// + /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[repr(C)] + /// struct Packet { + /// magic_number: C0C0, + /// mug_size: u8, + /// temperature: u8, + /// marshmallows: [[u8; 2]], + /// } + /// + /// let bytes = &[0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5][..]; + /// + /// let r = Ref::<_, Packet>::new(bytes).unwrap(); + /// let packet = Ref::try_into_ref(r).unwrap(); + /// + /// assert_eq!(packet.mug_size, 240); + /// assert_eq!(packet.temperature, 77); + /// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]); + /// ``` + #[must_use = "has no side effects"] + #[inline(always)] + pub fn try_into_ref(r: Self) -> Result<&'a T, ValidityError> { + // Presumably unreachable, since we've guarded each constructor of `Ref`. + static_assert_dst_is_not_zst!(T); + + // SAFETY: We don't call any methods on `b` other than those provided by + // `ByteSlice`. + let bytes = unsafe { r.as_byte_slice() }; + + let bytes: &'_ [u8] = bytes.deref(); + + // Extend the lifetime of `bytes` to `'a`. This gives us a reference + // `bytes` with the same lifetime as if we had called + // `r.into_byte_slice()`, but without consuming `r`. This is valuable, + // since we will need to return `r` if validation fails. + // + // SAFETY: This is sound because `bytes` lives for `'a`. `Self` is + // `IntoByteSlice`, whose `.into_byte_slice()` method is guaranteed to + // produce a `&'a [u8]` with the same address and length as the slice + // obtained by `.deref()` (which is how `bytes` is obtained). + let bytes = unsafe { mem::transmute::<&[u8], &'a [u8]>(bytes) }; + + match Ptr::from_ref(bytes).try_cast_into_no_leftover::(None) { + Ok(candidate) => match candidate.try_into_valid() { + Ok(candidate) => Ok(candidate.as_ref()), + Err(e) => Err(e.with_src(r)), + }, + Err(CastError::Validity(i)) => match i {}, + Err(CastError::Alignment(_) | CastError::Size(_)) => { + // SAFETY: By invariant on `Ref::0`, the referenced byte slice + // is aligned to `T`'s alignment and its size corresponds to a + // valid size for `T`. Since properties are checked upon + // constructing `Ref`, these failures are unreachable. + unsafe { core::hint::unreachable_unchecked() } + } + } + } } impl<'a, B, T> Ref where B: 'a + IntoByteSliceMut<'a>, - T: FromBytes + IntoBytes + KnownLayout + ?Sized, + T: TryFromBytes + IntoBytes + KnownLayout + ?Sized, { /// Converts this `Ref` into a mutable reference. /// @@ -643,7 +800,10 @@ where /// there is no conflict with a method on the inner type. #[must_use = "has no side effects"] #[inline(always)] - pub fn into_mut(r: Self) -> &'a mut T { + pub fn into_mut(r: Self) -> &'a mut T + where + T: FromBytes, + { // Presumably unreachable, since we've guarded each constructor of `Ref`. static_assert_dst_is_not_zst!(T); @@ -661,6 +821,86 @@ where let ptr = ptr.bikeshed_recall_valid(); ptr.as_mut() } + + /// Attempts to convert this `Ref<_, T>` into a `&mut T` without copying. + /// + /// If the bytes of `self` are a valid instance of `T`, this method returns + /// a reference to those bytes interpreted as `T`. If those bytes are not a + /// valid instance of `T`, this returns `Err`. + /// + /// # Examples + /// + /// ``` + /// use zerocopy::Ref; + /// # use zerocopy_derive::*; + /// + /// // The only valid value of this type is the byte `0xC0` + /// #[derive(TryFromBytes, IntoBytes, KnownLayout, Immutable)] + /// #[repr(u8)] + /// enum C0 { xC0 = 0xC0 } + /// + /// // The only valid value of this type is the bytes `0xC0C0`. + /// #[derive(TryFromBytes, IntoBytes, KnownLayout, Immutable)] + /// #[repr(C, packed)] + /// struct C0C0(C0, C0); + /// + /// #[derive(TryFromBytes, IntoBytes, KnownLayout, Immutable)] + /// #[repr(C, packed)] + /// struct Packet { + /// magic_number: C0C0, + /// mug_size: u8, + /// temperature: u8, + /// marshmallows: [[u8; 2]], + /// } + /// + /// let bytes = &mut [0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5][..]; + /// + /// let r = Ref::<_, Packet>::new(bytes).unwrap(); + /// let packet = Ref::try_into_mut(r).unwrap(); + /// + /// assert_eq!(packet.mug_size, 240); + /// assert_eq!(packet.temperature, 77); + /// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]); + /// ``` + #[must_use = "has no side effects"] + #[inline(always)] + pub fn try_into_mut(mut r: Self) -> Result<&'a mut T, ValidityError> { + // Presumably unreachable, since we've guarded each constructor of `Ref`. + static_assert_dst_is_not_zst!(T); + + // SAFETY: We don't call any methods on `b` other than those provided by + // `ByteSliceMut`. + let bytes = unsafe { r.as_byte_slice_mut() }; + + let bytes: &'_ mut [u8] = bytes.deref_mut(); + + // Extend the lifetime of `bytes` to `'a`. This gives us a reference + // `bytes` with the same lifetime as if we had called + // `r.into_byte_slice_mut()`, but without consuming `r`. This is + // valuable, since we will need to return `r` if validation fails. + // + // SAFETY: This is sound because `bytes` lives for `'a`. `Self` is + // `IntoByteSliceMut`, whose `.into_byte_slice_mut()` method is + // guaranteed to produce a `&'a [u8]` with the same address and length + // as the slice obtained by `.deref()` (which is how `bytes` is + // obtained). + let bytes = unsafe { mem::transmute::<&mut [u8], &'a mut [u8]>(bytes) }; + + match Ptr::from_mut(bytes).try_cast_into_no_leftover::(None) { + Ok(candidate) => match candidate.try_into_valid() { + Ok(candidate) => Ok(candidate.as_mut()), + Err(e) => Err(e.with_src(r)), + }, + Err(CastError::Validity(i)) => match i {}, + Err(CastError::Alignment(_) | CastError::Size(_)) => { + // SAFETY: By invariant on `Ref::0`, the referenced byte slice + // is aligned to `T`'s alignment and its size corresponds to a + // valid size for `T`. Since properties are checked upon + // constructing `Ref`, these failures are unreachable. + unsafe { core::hint::unreachable_unchecked() } + } + } + } } impl Ref @@ -1109,6 +1349,33 @@ mod tests { assert!(Ref::<_, [AU64]>::from_suffix_with_elems(&buf.t[..], unreasonable_len).is_err()); } + #[test] + #[allow(unstable_name_collisions)] + #[allow(clippy::as_conversions)] + fn test_try_as_ref() { + #[allow(unused)] + use crate::util::AsAddress as _; + + // valid source + + let buf = Align::<[u8; 8], u64>::default(); + let buf_addr = (&buf.t as *const [u8; 8]).addr(); + + let r = Ref::<_, u64>::from_bytes(&buf.t[..]).unwrap(); + let rf = Ref::try_as_ref(&r).unwrap(); + assert_eq!(rf, &0u64); + assert_eq!((rf as *const u64).addr(), buf_addr); + + // invalid source + + let buf = Align::<[u8; 1], u64>::new([42]); + let buf_addr = (&buf.t as *const [u8; 1]).addr(); + + let r = Ref::<_, bool>::from_bytes(&buf.t[..]).unwrap(); + let re = Ref::try_as_ref(&r).unwrap_err(); + assert_eq!(Ref::bytes(re.into_src()).addr(), buf_addr); + } + #[test] #[allow(unstable_name_collisions)] #[allow(clippy::as_conversions)] @@ -1132,6 +1399,42 @@ mod tests { assert_eq!(buf.t, [0xFF; 8]); } + #[test] + #[allow(unstable_name_collisions)] + #[allow(clippy::as_conversions)] + fn test_try_into_ref_mut() { + #[allow(unused)] + use crate::util::AsAddress as _; + + // valid source + + let mut buf = Align::<[u8; 8], u64>::default(); + let buf_addr = (&buf.t as *const [u8; 8]).addr(); + + let r = Ref::<_, u64>::from_bytes(&buf.t[..]).unwrap(); + let rf = Ref::try_into_ref(r).unwrap(); + assert_eq!(rf, &0u64); + assert_eq!((rf as *const u64).addr(), buf_addr); + + let r = Ref::<_, u64>::from_bytes(&mut buf.t[..]).unwrap(); + let rf = Ref::try_into_mut(r).unwrap(); + assert_eq!(rf, &mut 0u64); + assert_eq!((rf as *mut u64).addr(), buf_addr); + + // invalid source + + let mut buf = Align::<[u8; 1], u64>::new([42]); + let buf_addr = (&buf.t as *const [u8; 1]).addr(); + + let r = Ref::<_, bool>::from_bytes(&buf.t[..]).unwrap(); + let re = Ref::try_into_ref(r).unwrap_err(); + assert_eq!(Ref::bytes(&re.into_src()).addr(), buf_addr); + + let r = Ref::<_, bool>::from_bytes(&mut buf.t[..]).unwrap(); + let re = Ref::try_into_mut(r).unwrap_err(); + assert_eq!(Ref::bytes(&re.into_src()).addr(), buf_addr); + } + #[test] fn test_display_debug() { let buf = Align::<[u8; 8], u64>::default();