diff --git a/Cargo.toml b/Cargo.toml index e3cc96f..e3e085b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ libc = { version = "0.2", optional = true } [dev-dependencies] trybuild = "1" +zerocopy = { version = "0.7.11", features = ["derive"] } [workspace] members = ["derive"] diff --git a/derive/src/lib.rs b/derive/src/lib.rs index c499911..a7bf107 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -26,8 +26,7 @@ use quote::{format_ident, quote, ToTokens}; use repr::Repr; use std::collections::HashSet; use syn::{ - parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Path, - Token, Visibility, + parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility, }; /// Sets the span for every token tree in the token stream @@ -96,6 +95,37 @@ fn emit_debug_impl<'a>( }) } +fn path_matches_prelude_derive( + got_path: &syn::Path, + expected_path_after_std: &[&'static str], +) -> bool { + let &[a, b] = expected_path_after_std else { + unimplemented!("checking against stdlib paths with != 2 parts"); + }; + let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect(); + if segments + .iter() + .any(|segment| !matches!(segment.arguments, syn::PathArguments::None)) + { + return false; + } + match &segments[..] { + // `core::fmt::Debug` or `some_crate::module::Name` + &[maybe_core_or_std, maybe_a, maybe_b] => { + (maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std") + && maybe_a.ident == a + && maybe_b.ident == b + } + // `fmt::Debug` or `module::Name` + &[maybe_a, maybe_b] => { + maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none() + } + // `Debug` or `Name`` + &[maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(), + _ => false, + } +} + fn open_enum_impl( enum_: ItemEnum, Config { @@ -134,31 +164,39 @@ fn open_enum_impl( let mut explicit_repr: Option = None; // To make `match` seamless, derive(PartialEq, Eq) if they aren't already. - let mut our_derives = HashSet::new(); - our_derives.insert("PartialEq"); - our_derives.insert("Eq"); + let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)]; + let mut make_custom_debug_impl = false; for attr in &enum_.attrs { let mut include_in_struct = true; // Turns out `is_ident` does a `to_string` every time match attr.path.to_token_stream().to_string().as_str() { "derive" => { - let derives = - attr.parse_args_with(Punctuated::::parse_terminated)?; - for derive in derives { - if derive.is_ident("PartialEq") { - our_derives.remove("PartialEq"); - } else if derive.is_ident("Eq") { - our_derives.remove("Eq"); - } + if let Ok(derive_paths) = + attr.parse_args_with(Punctuated::::parse_terminated) + { + for derive in &derive_paths { + // These derives are treated specially + const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"]; + const EQ_PATH: &[&str] = &["cmp", "Eq"]; + const DEBUG_PATH: &[&str] = &["fmt", "Debug"]; - // If we allow aliasing, then don't bother with a custom - // debug impl. There's no way to tell which alias we should - // print. - if derive.is_ident("Debug") && !allow_alias { - make_custom_debug_impl = true; - include_in_struct = false; + if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH) + || path_matches_prelude_derive(derive, EQ_PATH) + { + // This derive is always included, exclude it. + continue; + } + if path_matches_prelude_derive(derive, DEBUG_PATH) { + if !allow_alias { + make_custom_debug_impl = true; + // Don't include this derive since we're generating a special one. + continue; + } + } + extra_derives.push(derive.to_token_stream()); } + include_in_struct = false; } } // Copy linting attribute to the impl. @@ -196,11 +234,8 @@ fn open_enum_impl( } }; - if !our_derives.is_empty() { - let our_derives = our_derives - .into_iter() - .map(|d| Ident::new(d, Span::call_site())); - struct_attrs.push(quote!(#[derive(#(#our_derives),*)])); + if !extra_derives.is_empty() { + struct_attrs.push(quote!(#[derive(#(#extra_derives),*)])); } let alias_check = if allow_alias { @@ -259,3 +294,41 @@ pub fn open_enum( .unwrap_or_else(Error::into_compile_error) .into() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_matches_stdlib_derive() { + const DEBUG_PATH: &[&str] = &["fmt", "Debug"]; + + for success_case in [ + "::core::fmt::Debug", + "::std::fmt::Debug", + "core::fmt::Debug", + "std::fmt::Debug", + "fmt::Debug", + "Debug", + ] { + assert!( + path_matches_prelude_derive(&syn::parse_str(success_case).unwrap(), DEBUG_PATH), + "{success_case}" + ); + } + + for fail_case in [ + "::fmt::Debug", + "::Debug", + "zerocopy::AsBytes", + "::zerocopy::AsBytes", + "PartialEq", + "core::cmp::Eq", + ] { + assert!( + !path_matches_prelude_derive(&syn::parse_str(fail_case).unwrap(), DEBUG_PATH), + "{fail_case}" + ); + } + } +} diff --git a/tests/derives.rs b/tests/derives.rs new file mode 100644 index 0000000..2cf80b4 --- /dev/null +++ b/tests/derives.rs @@ -0,0 +1,89 @@ +// Copyright © 2023 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate open_enum; +use open_enum::*; + +#[open_enum] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes, +)] +#[repr(u32)] +pub enum Color { + Red = 1, + Blue = 2, +} + +#[open_enum] +#[derive( + core::fmt::Debug, + std::clone::Clone, + ::core::marker::Copy, + std::cmp::PartialEq, + ::core::cmp::Eq, + zerocopy::AsBytes, + ::zerocopy::FromBytes, + zerocopy::FromZeroes, +)] +#[repr(u32)] +pub enum ColorWithNonPreludeDerives { + Red = 1, + Blue = 2, +} + +// Ensure that `Color` actually implements the `derive`d traits. +#[derive( + Debug, Copy, Clone, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes, +)] +#[repr(C)] +pub struct EmbedColor { + pub color: Color, +} + +#[derive( + Debug, Copy, Clone, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes, +)] +#[repr(C)] +pub struct EmbedColorWithNonPreludeDerives { + pub color: ColorWithNonPreludeDerives, +} + +#[test] +fn embedded_enum_struct_partialeq() { + assert_eq!( + EmbedColor { color: Color::Red }, + EmbedColor { color: Color::Red } + ); + assert_ne!( + EmbedColor { color: Color::Red }, + EmbedColor { color: Color::Blue } + ); +} + +#[test] +fn embedded_enum_struct_debug() { + let debug_str = format!("{:?}", EmbedColor { color: Color::Red }); + assert!(debug_str.contains("Red"), "{debug_str}"); +} + +#[test] +fn extended_embedded_enum_struct_debug() { + let debug_str = format!( + "{:?}", + EmbedColorWithNonPreludeDerives { + color: ColorWithNonPreludeDerives::Red + } + ); + assert!(debug_str.contains("Red"), "{debug_str}"); +}