diff --git a/zerocopy-derive/Cargo.toml b/zerocopy-derive/Cargo.toml index 3d19d6ce16..4eae69f393 100644 --- a/zerocopy-derive/Cargo.toml +++ b/zerocopy-derive/Cargo.toml @@ -20,7 +20,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.1" quote = "1.0.10" -syn = { version = "2", features = ["visit"] } +syn = "2.0.31" [dev-dependencies] rustversion = "1.0" diff --git a/zerocopy-derive/src/ext.rs b/zerocopy-derive/src/ext.rs index 45b592ee69..ff8a3d6596 100644 --- a/zerocopy-derive/src/ext.rs +++ b/zerocopy-derive/src/ext.rs @@ -2,34 +2,39 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Fields, Type}; +use syn::{Data, DataEnum, DataStruct, DataUnion, Type}; pub trait DataExt { - fn nested_types(&self) -> Vec<&Type>; + /// Extract the types of all fields. For enums, extract the types of fields + /// from each variant. + fn field_types(&self) -> Vec<&Type>; } impl DataExt for Data { - fn nested_types(&self) -> Vec<&Type> { + fn field_types(&self) -> Vec<&Type> { match self { - Data::Struct(strc) => strc.nested_types(), - Data::Enum(enm) => enm.nested_types(), - Data::Union(un) => un.nested_types(), + Data::Struct(strc) => strc.field_types(), + Data::Enum(enm) => enm.field_types(), + Data::Union(un) => un.field_types(), } } } impl DataExt for DataStruct { - fn nested_types(&self) -> Vec<&Type> { - fields_to_types(&self.fields) + fn field_types(&self) -> Vec<&Type> { + self.fields.iter().map(|f| &f.ty).collect() } } impl DataExt for DataEnum { - fn nested_types(&self) -> Vec<&Type> { - self.variants.iter().map(|var| fields_to_types(&var.fields)).fold(Vec::new(), |mut a, b| { - a.extend(b); - a - }) + fn field_types(&self) -> Vec<&Type> { + self.variants.iter().flat_map(|var| &var.fields).map(|f| &f.ty).collect() + } +} + +impl DataExt for DataUnion { + fn field_types(&self) -> Vec<&Type> { + self.fields.named.iter().map(|f| &f.ty).collect() } } @@ -39,24 +44,6 @@ pub trait EnumExt { impl EnumExt for DataEnum { fn is_c_like(&self) -> bool { - self.nested_types().is_empty() + self.field_types().is_empty() } } - -impl DataExt for DataUnion { - fn nested_types(&self) -> Vec<&Type> { - field_iter_to_types(&self.fields.named) - } -} - -fn fields_to_types(fields: &Fields) -> Vec<&Type> { - match fields { - Fields::Named(named) => field_iter_to_types(&named.named), - Fields::Unnamed(unnamed) => field_iter_to_types(&unnamed.unnamed), - Fields::Unit => Vec::new(), - } -} - -fn field_iter_to_types<'a, I: IntoIterator>(fields: I) -> Vec<&'a Type> { - fields.into_iter().map(|f| &f.ty).collect() -} diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 8793665cb1..6b9e3a40f2 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -30,10 +30,9 @@ mod repr; use { proc_macro2::Span, quote::quote, - syn::visit::{self, Visit}, syn::{ - parse_quote, punctuated::Punctuated, token::Comma, Data, DataEnum, DataStruct, DataUnion, - DeriveInput, Error, Expr, ExprLit, GenericParam, Ident, Lifetime, Lit, Type, TypePath, + parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit, + GenericParam, Ident, Lit, }, }; @@ -122,7 +121,7 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ // - all fields are `FromZeroes` fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromZeroes", true, PaddingCheck::None) + impl_block(ast, strct, "FromZeroes", true, None) } // An enum is `FromZeroes` if: @@ -156,21 +155,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To .to_compile_error(); } - impl_block(ast, enm, "FromZeroes", true, PaddingCheck::None) + impl_block(ast, enm, "FromZeroes", true, None) } // Like structs, unions are `FromZeroes` if // - all fields are `FromZeroes` fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromZeroes", true, PaddingCheck::None) + impl_block(ast, unn, "FromZeroes", true, None) } // A struct is `FromBytes` if: // - all fields are `FromBytes` fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromBytes", true, PaddingCheck::None) + impl_block(ast, strct, "FromBytes", true, None) } // An enum is `FromBytes` if: @@ -213,7 +212,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, PaddingCheck::None) + impl_block(ast, enm, "FromBytes", true, None) } #[rustfmt::skip] @@ -244,7 +243,7 @@ const ENUM_FROM_BYTES_CFG: Config = { // - all fields are `FromBytes` fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromBytes", true, PaddingCheck::None) + impl_block(ast, unn, "FromBytes", true, None) } // A struct is `AsBytes` if: @@ -277,8 +276,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2: // - repr(packed): Any inter-field padding bytes are removed, meaning that // any padding bytes would need to come from the fields, all of which // we require to be `AsBytes` (meaning they don't have any padding). - let padding_check = - if is_transparent || is_packed { PaddingCheck::None } else { PaddingCheck::Struct }; + let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; impl_block(ast, strct, "AsBytes", true, padding_check) } @@ -302,7 +300,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, PaddingCheck::None) + impl_block(ast, enm, "AsBytes", false, None) } #[rustfmt::skip] @@ -344,7 +342,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, PaddingCheck::Union) + impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union)) } // A struct is `Unaligned` if: @@ -357,7 +355,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2 let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, PaddingCheck::None) + impl_block(ast, strct, "Unaligned", require_trait_bound, None) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -388,7 +386,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, PaddingCheck::None) + impl_block(ast, enm, "Unaligned", true, None) } #[rustfmt::skip] @@ -426,26 +424,37 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, PaddingCheck::None) + impl_block(ast, unn, "Unaligned", require_trait_bound, None) } // This enum describes what kind of padding check needs to be generated for the // associated impl. enum PaddingCheck { - // No additional padding check is required. - None, // Check that the sum of the fields' sizes exactly equals the struct's size. Struct, // Check that the size of each field exactly equals the union's size. Union, } +impl PaddingCheck { + /// Returns the ident of the macro to call in order to validate that a type + /// passes the padding check encoded by `PaddingCheck`. + fn validator_macro_ident(&self) -> Ident { + let s = match self { + PaddingCheck::Struct => "struct_has_padding", + PaddingCheck::Union => "union_has_padding", + }; + + Ident::new(s, Span::call_site()) + } +} + fn impl_block( input: &DeriveInput, data: &D, trait_name: &str, require_trait_bound: bool, - padding_check: PaddingCheck, + padding_check: Option, ) -> proc_macro2::TokenStream { // In this documentation, we will refer to this hypothetical struct: // @@ -461,22 +470,10 @@ fn impl_block( // c: I::Item, // } // - // First, we extract the field types, which in this case are `u8`, `T`, and - // `I::Item`. We use the names of the type parameters to split the field - // types into two sets - a set of types which are based on the type - // parameters, and a set of types which are not. First, we re-use the - // existing parameters and where clauses, generating an `impl` block like: - // - // impl FromBytes for Foo - // where - // T: Copy, - // I: Clone, - // I::Item: Clone, - // { - // } - // - // Then, we use the list of types which are based on the type parameters to - // generate new entries in the `where` clause: + // We extract the field types, which in this case are `u8`, `T`, and + // `I::Item`. We re-use the existing parameters and where clauses. If + // `require_trait_bound == true` (as it is for `FromBytes), we add where + // bounds for each field's type: // // impl FromBytes for Foo // where @@ -488,18 +485,6 @@ fn impl_block( // { // } // - // Finally, we use a different technique to generate the bounds for the - // types which are not based on type parameters: - // - // - // fn only_derive_is_allowed_to_implement_this_trait() where Self: Sized { - // struct ImplementsFromBytes(PhantomData); - // let _: ImplementsFromBytes; - // } - // - // It would be easier to put all types in the where clause, but that won't - // work until the trivial_bounds feature is stabilized (#48214). - // // NOTE: It is standard practice to only emit bounds for the type parameters // themselves, not for field types based on those parameters (e.g., `T` vs // `T::Foo`). For a discussion of why this is standard practice, see @@ -521,7 +506,6 @@ fn impl_block( // b: PhantomData<&'b u8>, // } // - // // error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned` // --> src/main.rs:6:10 // | @@ -530,67 +514,37 @@ fn impl_block( // | // = note: required by `zerocopy::Unaligned` - // A visitor which is used to walk a field's type and determine whether any - // of its definition is based on the type or lifetime parameters on a type. - struct FromTypeParamVisit<'a, 'b>(&'a Punctuated, &'b mut bool); - - impl<'a, 'b> Visit<'a> for FromTypeParamVisit<'a, 'b> { - fn visit_lifetime(&mut self, i: &'a Lifetime) { - visit::visit_lifetime(self, i); - if self.0.iter().any(|param| { - if let GenericParam::Lifetime(param) = param { - param.lifetime.ident == i.ident - } else { - false - } - }) { - *self.1 = true; - } - } - - fn visit_type_path(&mut self, i: &'a TypePath) { - visit::visit_type_path(self, i); - if self.0.iter().any(|param| { - if let GenericParam::Type(param) = param { - i.path.segments.first().unwrap().ident == param.ident - } else { - false - } - }) { - *self.1 = true; - } - } - } - - // Whether this type is based on one of the type parameters. E.g., given the - // type parameters ``, `T`, `T::Foo`, and `(T::Foo, String)` are all - // based on the type parameters, while `String` and `(String, Box<()>)` are - // not. - let is_from_type_param = |ty: &Type| { - let mut ret = false; - FromTypeParamVisit(&input.generics.params, &mut ret).visit_type(ty); - ret - }; - + let type_ident = &input.ident; let trait_ident = Ident::new(trait_name, Span::call_site()); + let field_types = data.field_types(); + + let field_type_bounds = require_trait_bound + .then(|| field_types.iter().map(|ty| parse_quote!(#ty: zerocopy::#trait_ident))) + .into_iter() + .flatten() + .collect::>(); + + // Don't bother emitting a padding check if there are no fields. + #[allow(unstable_name_collisions)] // See `BoolExt` below + let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| { + let fields = field_types.iter(); + let validator_macro = check.validator_macro_ident(); + parse_quote!( + zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>: + zerocopy::derive_util::ShouldBe + ) + }); - let field_types = data.nested_types(); - let type_param_field_types = field_types.iter().filter(|ty| is_from_type_param(ty)); - let non_type_param_field_types = field_types.iter().filter(|ty| !is_from_type_param(ty)); - - // Add a new set of where clause predicates of the form `T: Trait` for each - // of the types of the struct's fields (but only the ones whose types are - // based on one of the type parameters). - let mut generics = input.generics.clone(); - let where_clause = generics.make_where_clause(); - if require_trait_bound { - for ty in type_param_field_types { - let bound = parse_quote!(#ty: zerocopy::#trait_ident); - where_clause.predicates.push(bound); - } - } + let bounds = input + .generics + .where_clause + .as_ref() + .map(|where_clause| where_clause.predicates.iter()) + .into_iter() + .flatten() + .chain(field_type_bounds.iter()) + .chain(padding_check_bound.iter()); - let type_ident = &input.ident; // The parameters with trait bounds, but without type defaults. let params = input.generics.params.clone().into_iter().map(|mut param| { match &mut param { @@ -610,70 +564,13 @@ fn impl_block( GenericParam::Const(cnst) => quote!(#cnst), }); - if require_trait_bound { - for ty in non_type_param_field_types { - where_clause.predicates.push(parse_quote!(#ty: zerocopy::#trait_ident)); - } - } - - match (field_types.is_empty(), padding_check) { - (true, _) | (false, PaddingCheck::None) => (), - (false, PaddingCheck::Struct) => { - let fields = field_types.iter(); - // `parse_quote!` doesn't parse macro invocations in const generics - // properly without enabling syn's `full` feature, so the type has - // to be manually constructed as `syn::Type::Verbatim`. - // - // This where clause is equivalent to adding: - // ``` - // HasPadding: ShouldBe - // ``` - // with fully-qualified paths. - where_clause.predicates.push(syn::WherePredicate::Type(syn::PredicateType { - lifetimes: None, - bounded_ty: syn::Type::Verbatim(quote!(zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::struct_has_padding!(#type_ident, #(#fields),*)}>)), - colon_token: syn::Token![:](Span::mixed_site()), - bounds: parse_quote!(zerocopy::derive_util::ShouldBe), - })); - } - (false, PaddingCheck::Union) => { - let fields = field_types.iter(); - // `parse_quote!` doesn't parse macro invocations in const generics - // properly without enabling syn's `full` feature, so the type has - // to be manually constructed as `syn::Type::Verbatim`. - // - // This where clause is equivalent to adding: - // ``` - // HasPadding: ShouldBe - // ``` - // with fully-qualified paths. - where_clause.predicates.push(syn::WherePredicate::Type(syn::PredicateType { - lifetimes: None, - bounded_ty: syn::Type::Verbatim(quote!(zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::union_has_padding!(#type_ident, #(#fields),*)}>)), - colon_token: syn::Token![:](Span::mixed_site()), - bounds: parse_quote!(zerocopy::derive_util::ShouldBe), - })); - } - } - - // We use a constant to force the compiler to emit an error when a concrete - // type does not satisfy the where clauses on its impl. - let use_concrete = if input.generics.params.is_empty() { - Some(quote! { - const _: () = { - fn must_implement_trait() {} - let _ = must_implement_trait::<#type_ident>; - }; - }) - } else { - None - }; - quote! { - unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > #where_clause { + unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > + where + #(#bounds,)* + { fn only_derive_is_allowed_to_implement_this_trait() {} } - #use_concrete } } @@ -681,6 +578,23 @@ fn print_all_errors(errors: Vec) -> proc_macro2::TokenStream { errors.iter().map(Error::to_compile_error).collect() } +// A polyfill for `Option::then_some`, which was added after our MSRV. +// +// TODO(#67): Remove this once our MSRV is >= 1.62. +trait BoolExt { + fn then_some(self, t: T) -> Option; +} + +impl BoolExt for bool { + fn then_some(self, t: T) -> Option { + if self { + Some(t) + } else { + None + } + } +} + #[cfg(test)] mod tests { use super::*;