From c9a57300adbf69975ffd60f157db7f15aae342f8 Mon Sep 17 00:00:00 2001 From: Jinank Jain Date: Sat, 18 Nov 2023 05:36:47 +0530 Subject: [PATCH] Fix the problem with deriving Debug with other attributes (#15) * derive: Remove Debug from derive list ... in case we are emitting custom Debug impl trait implementation. Otherwise it would conflict with the default Debug implementation. Signed-off-by: Jinank Jain * tests: Add new tests for derive Debug There are two new things added: 1. OpenEnum which mimics the failing build scenario 2. Struct which embeds that enum inside it and also derive Debug on top of it. By doing so we verify first of all the code compiles with multiple derive attributes and secondly the embedded debug derives works as expected. Signed-off-by: Jinank Jain * Improve stdlib-derive detection logic It will now check against a (module, Name) pair to determine how to intercept PartialEq, Eq, and Debug. Fixes #14. --------- Signed-off-by: Jinank Jain Co-authored-by: Alyssa Haroldsen --- Cargo.toml | 1 + derive/src/lib.rs | 121 +++++++++++++++++++++++++++++++++++++--------- tests/derives.rs | 89 ++++++++++++++++++++++++++++++++++ 3 files changed, 187 insertions(+), 24 deletions(-) create mode 100644 tests/derives.rs diff --git a/Cargo.toml b/Cargo.toml index e3cc96f..e3e085b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ libc = { version = "0.2", optional = true } [dev-dependencies] trybuild = "1" +zerocopy = { version = "0.7.11", features = ["derive"] } [workspace] members = ["derive"] diff --git a/derive/src/lib.rs b/derive/src/lib.rs index c499911..a7bf107 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -26,8 +26,7 @@ use quote::{format_ident, quote, ToTokens}; use repr::Repr; use std::collections::HashSet; use syn::{ - parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Path, - Token, Visibility, + parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility, }; /// Sets the span for every token tree in the token stream @@ -96,6 +95,37 @@ fn emit_debug_impl<'a>( }) } +fn path_matches_prelude_derive( + got_path: &syn::Path, + expected_path_after_std: &[&'static str], +) -> bool { + let &[a, b] = expected_path_after_std else { + unimplemented!("checking against stdlib paths with != 2 parts"); + }; + let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect(); + if segments + .iter() + .any(|segment| !matches!(segment.arguments, syn::PathArguments::None)) + { + return false; + } + match &segments[..] { + // `core::fmt::Debug` or `some_crate::module::Name` + &[maybe_core_or_std, maybe_a, maybe_b] => { + (maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std") + && maybe_a.ident == a + && maybe_b.ident == b + } + // `fmt::Debug` or `module::Name` + &[maybe_a, maybe_b] => { + maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none() + } + // `Debug` or `Name`` + &[maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(), + _ => false, + } +} + fn open_enum_impl( enum_: ItemEnum, Config { @@ -134,31 +164,39 @@ fn open_enum_impl( let mut explicit_repr: Option = None; // To make `match` seamless, derive(PartialEq, Eq) if they aren't already. - let mut our_derives = HashSet::new(); - our_derives.insert("PartialEq"); - our_derives.insert("Eq"); + let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)]; + let mut make_custom_debug_impl = false; for attr in &enum_.attrs { let mut include_in_struct = true; // Turns out `is_ident` does a `to_string` every time match attr.path.to_token_stream().to_string().as_str() { "derive" => { - let derives = - attr.parse_args_with(Punctuated::::parse_terminated)?; - for derive in derives { - if derive.is_ident("PartialEq") { - our_derives.remove("PartialEq"); - } else if derive.is_ident("Eq") { - our_derives.remove("Eq"); - } + if let Ok(derive_paths) = + attr.parse_args_with(Punctuated::::parse_terminated) + { + for derive in &derive_paths { + // These derives are treated specially + const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"]; + const EQ_PATH: &[&str] = &["cmp", "Eq"]; + const DEBUG_PATH: &[&str] = &["fmt", "Debug"]; - // If we allow aliasing, then don't bother with a custom - // debug impl. There's no way to tell which alias we should - // print. - if derive.is_ident("Debug") && !allow_alias { - make_custom_debug_impl = true; - include_in_struct = false; + if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH) + || path_matches_prelude_derive(derive, EQ_PATH) + { + // This derive is always included, exclude it. + continue; + } + if path_matches_prelude_derive(derive, DEBUG_PATH) { + if !allow_alias { + make_custom_debug_impl = true; + // Don't include this derive since we're generating a special one. + continue; + } + } + extra_derives.push(derive.to_token_stream()); } + include_in_struct = false; } } // Copy linting attribute to the impl. @@ -196,11 +234,8 @@ fn open_enum_impl( } }; - if !our_derives.is_empty() { - let our_derives = our_derives - .into_iter() - .map(|d| Ident::new(d, Span::call_site())); - struct_attrs.push(quote!(#[derive(#(#our_derives),*)])); + if !extra_derives.is_empty() { + struct_attrs.push(quote!(#[derive(#(#extra_derives),*)])); } let alias_check = if allow_alias { @@ -259,3 +294,41 @@ pub fn open_enum( .unwrap_or_else(Error::into_compile_error) .into() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_matches_stdlib_derive() { + const DEBUG_PATH: &[&str] = &["fmt", "Debug"]; + + for success_case in [ + "::core::fmt::Debug", + "::std::fmt::Debug", + "core::fmt::Debug", + "std::fmt::Debug", + "fmt::Debug", + "Debug", + ] { + assert!( + path_matches_prelude_derive(&syn::parse_str(success_case).unwrap(), DEBUG_PATH), + "{success_case}" + ); + } + + for fail_case in [ + "::fmt::Debug", + "::Debug", + "zerocopy::AsBytes", + "::zerocopy::AsBytes", + "PartialEq", + "core::cmp::Eq", + ] { + assert!( + !path_matches_prelude_derive(&syn::parse_str(fail_case).unwrap(), DEBUG_PATH), + "{fail_case}" + ); + } + } +} diff --git a/tests/derives.rs b/tests/derives.rs new file mode 100644 index 0000000..2cf80b4 --- /dev/null +++ b/tests/derives.rs @@ -0,0 +1,89 @@ +// Copyright © 2023 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate open_enum; +use open_enum::*; + +#[open_enum] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes, +)] +#[repr(u32)] +pub enum Color { + Red = 1, + Blue = 2, +} + +#[open_enum] +#[derive( + core::fmt::Debug, + std::clone::Clone, + ::core::marker::Copy, + std::cmp::PartialEq, + ::core::cmp::Eq, + zerocopy::AsBytes, + ::zerocopy::FromBytes, + zerocopy::FromZeroes, +)] +#[repr(u32)] +pub enum ColorWithNonPreludeDerives { + Red = 1, + Blue = 2, +} + +// Ensure that `Color` actually implements the `derive`d traits. +#[derive( + Debug, Copy, Clone, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes, +)] +#[repr(C)] +pub struct EmbedColor { + pub color: Color, +} + +#[derive( + Debug, Copy, Clone, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes, +)] +#[repr(C)] +pub struct EmbedColorWithNonPreludeDerives { + pub color: ColorWithNonPreludeDerives, +} + +#[test] +fn embedded_enum_struct_partialeq() { + assert_eq!( + EmbedColor { color: Color::Red }, + EmbedColor { color: Color::Red } + ); + assert_ne!( + EmbedColor { color: Color::Red }, + EmbedColor { color: Color::Blue } + ); +} + +#[test] +fn embedded_enum_struct_debug() { + let debug_str = format!("{:?}", EmbedColor { color: Color::Red }); + assert!(debug_str.contains("Red"), "{debug_str}"); +} + +#[test] +fn extended_embedded_enum_struct_debug() { + let debug_str = format!( + "{:?}", + EmbedColorWithNonPreludeDerives { + color: ColorWithNonPreludeDerives::Red + } + ); + assert!(debug_str.contains("Red"), "{debug_str}"); +}