Skip to content

Commit

Permalink
derive: generalize require_self_sized (#883)
Browse files Browse the repository at this point in the history
Generalize `require_self_sized` to require other `Trait`s.
  • Loading branch information
jswrenn authored Feb 15, 2024
1 parent b0edd98 commit d39a5d4
Showing 1 changed file with 112 additions and 46 deletions.
158 changes: 112 additions & 46 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@
mod ext;
mod repr;

use quote::quote_spanned;

use {
proc_macro2::Span,
quote::quote,
syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
GenericParam, Ident, Lit, Type, WherePredicate,
parse_quote, parse_quote_spanned, Data, DataEnum, DataStruct, DataUnion, DeriveInput,
Error, Expr, ExprLit, GenericParam, Ident, Lit, Path, Type, WherePredicate,
},
};

Expand Down Expand Up @@ -88,10 +90,8 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre

let fields = ast.data.fields();

let (require_self_sized, extras) = if let (
Some(reprs),
Some((trailing_field, leading_fields)),
) = (is_repr_c_struct, fields.split_last())
let (self_bounds, extras) = if let (Some(reprs), Some((trailing_field, leading_fields))) =
(is_repr_c_struct, fields.split_last())
{
let (_name, trailing_field_ty) = trailing_field;
let leading_fields_tys = leading_fields.iter().map(|(_name, ty)| ty);
Expand Down Expand Up @@ -121,7 +121,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
.unwrap_or(quote!(None));

(
false,
SelfBounds::None,
quote!(
// SAFETY: `LAYOUT` accurately describes the layout of `Self`.
// The layout of `Self` is reflected using a sequence of
Expand Down Expand Up @@ -172,7 +172,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
// `Self` is sized, and as a result don't need to reason about the
// internals of the type.
(
true,
SelfBounds::SIZED,
quote!(
// SAFETY: `LAYOUT` is guaranteed to accurately describe the
// layout of `Self`, because that is the documented safety
Expand All @@ -196,8 +196,11 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre

match &ast.data {
Data::Struct(strct) => {
let require_trait_bound_on_field_types =
if require_self_sized { FieldBounds::None } else { FieldBounds::TRAILING_SELF };
let require_trait_bound_on_field_types = if self_bounds == SelfBounds::SIZED {
FieldBounds::None
} else {
FieldBounds::TRAILING_SELF
};

// A bound on the trailing field is required, since structs are
// unsized if their trailing field is unsized. Reflecting the layout
Expand All @@ -208,20 +211,36 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
strct,
Trait::KnownLayout,
require_trait_bound_on_field_types,
require_self_sized,
self_bounds,
None,
Some(extras),
)
}
Data::Enum(enm) => {
// A bound on the trailing field is not required, since enums cannot
// currently be unsized.
impl_block(&ast, enm, Trait::KnownLayout, FieldBounds::None, true, None, Some(extras))
impl_block(
&ast,
enm,
Trait::KnownLayout,
FieldBounds::None,
SelfBounds::SIZED,
None,
Some(extras),
)
}
Data::Union(unn) => {
// A bound on the trailing field is not required, since unions
// cannot currently be unsized.
impl_block(&ast, unn, Trait::KnownLayout, FieldBounds::None, true, None, Some(extras))
impl_block(
&ast,
unn,
Trait::KnownLayout,
FieldBounds::None,
SelfBounds::SIZED,
None,
Some(extras),
)
}
}
.into()
Expand All @@ -231,15 +250,33 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
pub fn derive_no_cell(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
Data::Struct(strct) => {
impl_block(&ast, strct, Trait::NoCell, FieldBounds::ALL_SELF, false, None, None)
}
Data::Enum(enm) => {
impl_block(&ast, enm, Trait::NoCell, FieldBounds::ALL_SELF, false, None, None)
}
Data::Union(unn) => {
impl_block(&ast, unn, Trait::NoCell, FieldBounds::ALL_SELF, false, None, None)
}
Data::Struct(strct) => impl_block(
&ast,
strct,
Trait::NoCell,
FieldBounds::ALL_SELF,
SelfBounds::None,
None,
None,
),
Data::Enum(enm) => impl_block(
&ast,
enm,
Trait::NoCell,
FieldBounds::ALL_SELF,
SelfBounds::None,
None,
None,
),
Data::Union(unn) => impl_block(
&ast,
unn,
Trait::NoCell,
FieldBounds::ALL_SELF,
SelfBounds::None,
None,
None,
),
}
.into()
}
Expand Down Expand Up @@ -337,7 +374,15 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m
}
)
});
impl_block(ast, strct, Trait::TryFromBytes, FieldBounds::ALL_SELF, false, None, extras)
impl_block(
ast,
strct,
Trait::TryFromBytes,
FieldBounds::ALL_SELF,
SelfBounds::None,
None,
extras,
)
}

// A union is `TryFromBytes` if:
Expand Down Expand Up @@ -370,7 +415,7 @@ fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro
}
)
});
impl_block(ast, unn, Trait::TryFromBytes, FieldBounds::ALL_SELF, false, None, extras)
impl_block(ast, unn, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras)
}

const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
Expand Down Expand Up @@ -438,7 +483,7 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
})*
}
));
impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, false, None, extras)
impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -468,7 +513,7 @@ const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
// - all fields are `FromZeros`

fn derive_from_zeros_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, Trait::FromZeros, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, strct, Trait::FromZeros, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

// An enum is `FromZeros` if:
Expand Down Expand Up @@ -506,21 +551,21 @@ fn derive_from_zeros_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok
.to_compile_error();
}

impl_block(ast, enm, Trait::FromZeros, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, enm, Trait::FromZeros, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

// Like structs, unions are `FromZeros` if
// - all fields are `FromZeros`

fn derive_from_zeros_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, Trait::FromZeros, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, unn, Trait::FromZeros, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

// A struct is `FromBytes` if:
// - all fields are `FromBytes`

fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, Trait::FromBytes, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, strct, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

// An enum is `FromBytes` if:
Expand Down Expand Up @@ -563,7 +608,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok
.to_compile_error();
}

impl_block(ast, enm, Trait::FromBytes, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, enm, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -594,7 +639,7 @@ const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = {
// - all fields are `FromBytes`

fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, Trait::FromBytes, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, unn, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
Expand Down Expand Up @@ -644,7 +689,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2:
FieldBounds::ALL_SELF
};

impl_block(ast, strct, Trait::IntoBytes, field_bounds, false, padding_check, None)
impl_block(ast, strct, Trait::IntoBytes, field_bounds, SelfBounds::None, padding_check, None)
}

const STRUCT_UNION_AS_BYTES_CFG: Config<StructRepr> = Config {
Expand All @@ -667,7 +712,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token
// We don't care what the repr is; we only care that it is one of the
// allowed ones.
try_or_print!(ENUM_FROM_ZEROS_AS_BYTES_CFG.validate_reprs(ast));
impl_block(ast, enm, Trait::IntoBytes, FieldBounds::None, false, None, None)
impl_block(ast, enm, Trait::IntoBytes, FieldBounds::None, SelfBounds::None, None, None)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -714,7 +759,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok
unn,
Trait::IntoBytes,
FieldBounds::ALL_SELF,
false,
SelfBounds::None,
Some(PaddingCheck::Union),
None,
)
Expand All @@ -734,7 +779,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2
FieldBounds::None
};

impl_block(ast, strct, Trait::Unaligned, field_bounds, false, None, None)
impl_block(ast, strct, Trait::Unaligned, field_bounds, SelfBounds::None, None, None)
}

const STRUCT_UNION_UNALIGNED_CFG: Config<StructRepr> = Config {
Expand Down Expand Up @@ -765,7 +810,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke
// true for `require_trait_bound_on_field_types` doesn't really do anything.
// But it's marginally more future-proof in case that restriction is lifted
// in the future.
impl_block(ast, enm, Trait::Unaligned, FieldBounds::ALL_SELF, false, None, None)
impl_block(ast, enm, Trait::Unaligned, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -807,7 +852,7 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To
FieldBounds::None
};

impl_block(ast, unn, Trait::Unaligned, field_type_trait_bounds, false, None, None)
impl_block(ast, unn, Trait::Unaligned, field_type_trait_bounds, SelfBounds::None, None, None)
}

// This enum describes what kind of padding check needs to be generated for the
Expand Down Expand Up @@ -841,11 +886,19 @@ enum Trait {
FromBytes,
IntoBytes,
Unaligned,
Sized,
}

impl Trait {
fn ident(&self) -> Ident {
Ident::new(format!("{:?}", self).as_str(), Span::call_site())
fn path(&self) -> Path {
let span = Span::call_site();
let root = if *self == Self::Sized {
quote_spanned!(span=> ::zerocopy::macro_util::core_reexport::marker)
} else {
quote_spanned!(span=> ::zerocopy)
};
let ident = Ident::new(&format!("{:?}", self), span);
parse_quote_spanned! {span=> #root::#ident}
}
}

Expand All @@ -867,6 +920,16 @@ impl<'a> FieldBounds<'a> {
const TRAILING_SELF: FieldBounds<'a> = FieldBounds::Trailing(&[TraitBound::Slf]);
}

#[derive(Debug, Eq, PartialEq)]
enum SelfBounds<'a> {
None,
All(&'a [Trait]),
}

impl<'a> SelfBounds<'a> {
const SIZED: Self = Self::All(&[Trait::Sized]);
}

/// Normalizes a slice of bounds by replacing [`TraitBound::Slf`] with `slf`.
fn normalize_bounds(slf: Trait, bounds: &[TraitBound]) -> impl '_ + Iterator<Item = Trait> {
bounds.iter().map(move |bound| match bound {
Expand All @@ -880,7 +943,7 @@ fn impl_block<D: DataExt>(
data: &D,
trt: Trait,
field_type_trait_bounds: FieldBounds,
require_self_sized: bool,
self_type_trait_bounds: SelfBounds,
padding_check: Option<PaddingCheck>,
extras: Option<proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
Expand Down Expand Up @@ -943,12 +1006,12 @@ fn impl_block<D: DataExt>(
// = note: required by `zerocopy::Unaligned`

let type_ident = &input.ident;
let trait_ident = trt.ident();
let trait_path = trt.path();
let fields = data.fields();

fn bound_tt(ty: &Type, traits: impl Iterator<Item = Trait>) -> WherePredicate {
let traits = traits.map(|t| t.ident());
parse_quote!(#ty: #(::zerocopy::#traits)+*)
let traits = traits.map(|t| t.path());
parse_quote!(#ty: #(#traits)+*)
}
let field_type_bounds: Vec<_> = match (field_type_trait_bounds, &fields[..]) {
(FieldBounds::All(traits), _) => {
Expand All @@ -972,7 +1035,10 @@ fn impl_block<D: DataExt>(
)
});

let self_sized_bound = if require_self_sized { Some(parse_quote!(Self: Sized)) } else { None };
let self_bounds: Option<WherePredicate> = match self_type_trait_bounds {
SelfBounds::None => None,
SelfBounds::All(traits) => Some(bound_tt(&parse_quote!(Self), traits.iter().copied())),
};

let bounds = input
.generics
Expand All @@ -983,7 +1049,7 @@ fn impl_block<D: DataExt>(
.flatten()
.chain(field_type_bounds.iter())
.chain(padding_check_bound.iter())
.chain(self_sized_bound.iter());
.chain(self_bounds.iter());

// The parameters with trait bounds, but without type defaults.
let params = input.generics.params.clone().into_iter().map(|mut param| {
Expand Down Expand Up @@ -1015,7 +1081,7 @@ fn impl_block<D: DataExt>(
// TODO(#553): Add a test that generates a warning when
// `#[allow(deprecated)]` isn't present.
#[allow(deprecated)]
unsafe impl < #(#params),* > ::zerocopy::#trait_ident for #type_ident < #(#param_idents),* >
unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* >
where
#(#bounds,)*
{
Expand Down

0 comments on commit d39a5d4

Please sign in to comment.