From d39a5d4ca20e26a06556c32d3a064bd9e3f58f3c Mon Sep 17 00:00:00 2001 From: Jack Wrenn Date: Thu, 15 Feb 2024 18:18:52 -0500 Subject: [PATCH] derive: generalize `require_self_sized` (#883) Generalize `require_self_sized` to require other `Trait`s. --- zerocopy-derive/src/lib.rs | 158 ++++++++++++++++++++++++++----------- 1 file changed, 112 insertions(+), 46 deletions(-) diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index d7b3a331f4..75799e8cb2 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -31,12 +31,14 @@ mod ext; mod repr; +use quote::quote_spanned; + use { proc_macro2::Span, quote::quote, syn::{ - parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit, - GenericParam, Ident, Lit, Type, WherePredicate, + parse_quote, parse_quote_spanned, Data, DataEnum, DataStruct, DataUnion, DeriveInput, + Error, Expr, ExprLit, GenericParam, Ident, Lit, Path, Type, WherePredicate, }, }; @@ -88,10 +90,8 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre let fields = ast.data.fields(); - let (require_self_sized, extras) = if let ( - Some(reprs), - Some((trailing_field, leading_fields)), - ) = (is_repr_c_struct, fields.split_last()) + let (self_bounds, extras) = if let (Some(reprs), Some((trailing_field, leading_fields))) = + (is_repr_c_struct, fields.split_last()) { let (_name, trailing_field_ty) = trailing_field; let leading_fields_tys = leading_fields.iter().map(|(_name, ty)| ty); @@ -121,7 +121,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre .unwrap_or(quote!(None)); ( - false, + SelfBounds::None, quote!( // SAFETY: `LAYOUT` accurately describes the layout of `Self`. // The layout of `Self` is reflected using a sequence of @@ -172,7 +172,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre // `Self` is sized, and as a result don't need to reason about the // internals of the type. ( - true, + SelfBounds::SIZED, quote!( // SAFETY: `LAYOUT` is guaranteed to accurately describe the // layout of `Self`, because that is the documented safety @@ -196,8 +196,11 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre match &ast.data { Data::Struct(strct) => { - let require_trait_bound_on_field_types = - if require_self_sized { FieldBounds::None } else { FieldBounds::TRAILING_SELF }; + let require_trait_bound_on_field_types = if self_bounds == SelfBounds::SIZED { + FieldBounds::None + } else { + FieldBounds::TRAILING_SELF + }; // A bound on the trailing field is required, since structs are // unsized if their trailing field is unsized. Reflecting the layout @@ -208,7 +211,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre strct, Trait::KnownLayout, require_trait_bound_on_field_types, - require_self_sized, + self_bounds, None, Some(extras), ) @@ -216,12 +219,28 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre Data::Enum(enm) => { // A bound on the trailing field is not required, since enums cannot // currently be unsized. - impl_block(&ast, enm, Trait::KnownLayout, FieldBounds::None, true, None, Some(extras)) + impl_block( + &ast, + enm, + Trait::KnownLayout, + FieldBounds::None, + SelfBounds::SIZED, + None, + Some(extras), + ) } Data::Union(unn) => { // A bound on the trailing field is not required, since unions // cannot currently be unsized. - impl_block(&ast, unn, Trait::KnownLayout, FieldBounds::None, true, None, Some(extras)) + impl_block( + &ast, + unn, + Trait::KnownLayout, + FieldBounds::None, + SelfBounds::SIZED, + None, + Some(extras), + ) } } .into() @@ -231,15 +250,33 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre pub fn derive_no_cell(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(ts as DeriveInput); match &ast.data { - Data::Struct(strct) => { - impl_block(&ast, strct, Trait::NoCell, FieldBounds::ALL_SELF, false, None, None) - } - Data::Enum(enm) => { - impl_block(&ast, enm, Trait::NoCell, FieldBounds::ALL_SELF, false, None, None) - } - Data::Union(unn) => { - impl_block(&ast, unn, Trait::NoCell, FieldBounds::ALL_SELF, false, None, None) - } + Data::Struct(strct) => impl_block( + &ast, + strct, + Trait::NoCell, + FieldBounds::ALL_SELF, + SelfBounds::None, + None, + None, + ), + Data::Enum(enm) => impl_block( + &ast, + enm, + Trait::NoCell, + FieldBounds::ALL_SELF, + SelfBounds::None, + None, + None, + ), + Data::Union(unn) => impl_block( + &ast, + unn, + Trait::NoCell, + FieldBounds::ALL_SELF, + SelfBounds::None, + None, + None, + ), } .into() } @@ -337,7 +374,15 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m } ) }); - impl_block(ast, strct, Trait::TryFromBytes, FieldBounds::ALL_SELF, false, None, extras) + impl_block( + ast, + strct, + Trait::TryFromBytes, + FieldBounds::ALL_SELF, + SelfBounds::None, + None, + extras, + ) } // A union is `TryFromBytes` if: @@ -370,7 +415,7 @@ fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro } ) }); - impl_block(ast, unn, Trait::TryFromBytes, FieldBounds::ALL_SELF, false, None, extras) + impl_block(ast, unn, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras) } const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ @@ -438,7 +483,7 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2: })* } )); - impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, false, None, extras) + impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras) } #[rustfmt::skip] @@ -468,7 +513,7 @@ const ENUM_TRY_FROM_BYTES_CFG: Config = { // - all fields are `FromZeros` fn derive_from_zeros_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, Trait::FromZeros, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, strct, Trait::FromZeros, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } // An enum is `FromZeros` if: @@ -506,21 +551,21 @@ fn derive_from_zeros_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, Trait::FromZeros, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, enm, Trait::FromZeros, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } // Like structs, unions are `FromZeros` if // - all fields are `FromZeros` fn derive_from_zeros_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, Trait::FromZeros, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, unn, Trait::FromZeros, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } // A struct is `FromBytes` if: // - all fields are `FromBytes` fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, Trait::FromBytes, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, strct, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } // An enum is `FromBytes` if: @@ -563,7 +608,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, Trait::FromBytes, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, enm, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } #[rustfmt::skip] @@ -594,7 +639,7 @@ const ENUM_FROM_BYTES_CFG: Config = { // - all fields are `FromBytes` fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, Trait::FromBytes, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, unn, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { @@ -644,7 +689,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2: FieldBounds::ALL_SELF }; - impl_block(ast, strct, Trait::IntoBytes, field_bounds, false, padding_check, None) + impl_block(ast, strct, Trait::IntoBytes, field_bounds, SelfBounds::None, padding_check, None) } const STRUCT_UNION_AS_BYTES_CFG: Config = Config { @@ -667,7 +712,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. try_or_print!(ENUM_FROM_ZEROS_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, Trait::IntoBytes, FieldBounds::None, false, None, None) + impl_block(ast, enm, Trait::IntoBytes, FieldBounds::None, SelfBounds::None, None, None) } #[rustfmt::skip] @@ -714,7 +759,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok unn, Trait::IntoBytes, FieldBounds::ALL_SELF, - false, + SelfBounds::None, Some(PaddingCheck::Union), None, ) @@ -734,7 +779,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2 FieldBounds::None }; - impl_block(ast, strct, Trait::Unaligned, field_bounds, false, None, None) + impl_block(ast, strct, Trait::Unaligned, field_bounds, SelfBounds::None, None, None) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -765,7 +810,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // true for `require_trait_bound_on_field_types` doesn't really do anything. // But it's marginally more future-proof in case that restriction is lifted // in the future. - impl_block(ast, enm, Trait::Unaligned, FieldBounds::ALL_SELF, false, None, None) + impl_block(ast, enm, Trait::Unaligned, FieldBounds::ALL_SELF, SelfBounds::None, None, None) } #[rustfmt::skip] @@ -807,7 +852,7 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To FieldBounds::None }; - impl_block(ast, unn, Trait::Unaligned, field_type_trait_bounds, false, None, None) + impl_block(ast, unn, Trait::Unaligned, field_type_trait_bounds, SelfBounds::None, None, None) } // This enum describes what kind of padding check needs to be generated for the @@ -841,11 +886,19 @@ enum Trait { FromBytes, IntoBytes, Unaligned, + Sized, } impl Trait { - fn ident(&self) -> Ident { - Ident::new(format!("{:?}", self).as_str(), Span::call_site()) + fn path(&self) -> Path { + let span = Span::call_site(); + let root = if *self == Self::Sized { + quote_spanned!(span=> ::zerocopy::macro_util::core_reexport::marker) + } else { + quote_spanned!(span=> ::zerocopy) + }; + let ident = Ident::new(&format!("{:?}", self), span); + parse_quote_spanned! {span=> #root::#ident} } } @@ -867,6 +920,16 @@ impl<'a> FieldBounds<'a> { const TRAILING_SELF: FieldBounds<'a> = FieldBounds::Trailing(&[TraitBound::Slf]); } +#[derive(Debug, Eq, PartialEq)] +enum SelfBounds<'a> { + None, + All(&'a [Trait]), +} + +impl<'a> SelfBounds<'a> { + const SIZED: Self = Self::All(&[Trait::Sized]); +} + /// Normalizes a slice of bounds by replacing [`TraitBound::Slf`] with `slf`. fn normalize_bounds(slf: Trait, bounds: &[TraitBound]) -> impl '_ + Iterator { bounds.iter().map(move |bound| match bound { @@ -880,7 +943,7 @@ fn impl_block( data: &D, trt: Trait, field_type_trait_bounds: FieldBounds, - require_self_sized: bool, + self_type_trait_bounds: SelfBounds, padding_check: Option, extras: Option, ) -> proc_macro2::TokenStream { @@ -943,12 +1006,12 @@ fn impl_block( // = note: required by `zerocopy::Unaligned` let type_ident = &input.ident; - let trait_ident = trt.ident(); + let trait_path = trt.path(); let fields = data.fields(); fn bound_tt(ty: &Type, traits: impl Iterator) -> WherePredicate { - let traits = traits.map(|t| t.ident()); - parse_quote!(#ty: #(::zerocopy::#traits)+*) + let traits = traits.map(|t| t.path()); + parse_quote!(#ty: #(#traits)+*) } let field_type_bounds: Vec<_> = match (field_type_trait_bounds, &fields[..]) { (FieldBounds::All(traits), _) => { @@ -972,7 +1035,10 @@ fn impl_block( ) }); - let self_sized_bound = if require_self_sized { Some(parse_quote!(Self: Sized)) } else { None }; + let self_bounds: Option = match self_type_trait_bounds { + SelfBounds::None => None, + SelfBounds::All(traits) => Some(bound_tt(&parse_quote!(Self), traits.iter().copied())), + }; let bounds = input .generics @@ -983,7 +1049,7 @@ fn impl_block( .flatten() .chain(field_type_bounds.iter()) .chain(padding_check_bound.iter()) - .chain(self_sized_bound.iter()); + .chain(self_bounds.iter()); // The parameters with trait bounds, but without type defaults. let params = input.generics.params.clone().into_iter().map(|mut param| { @@ -1015,7 +1081,7 @@ fn impl_block( // TODO(#553): Add a test that generates a warning when // `#[allow(deprecated)]` isn't present. #[allow(deprecated)] - unsafe impl < #(#params),* > ::zerocopy::#trait_ident for #type_ident < #(#param_idents),* > + unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* > where #(#bounds,)* {