Skip to content

Commit

Permalink
Fix the problem with deriving Debug with other attributes (#15)
Browse files Browse the repository at this point in the history
* derive: Remove Debug from derive list

... in case we are emitting custom Debug impl trait implementation.
Otherwise it would conflict with the default Debug implementation.

Signed-off-by: Jinank Jain <[email protected]>

* tests: Add new tests for derive Debug

There are two new things added:
1. OpenEnum which mimics the failing build scenario
2. Struct which embeds that enum inside it and also derive Debug on top
   of it.

By doing so we verify first of all the code compiles with multiple
derive attributes and secondly the embedded debug derives works as
expected.

Signed-off-by: Jinank Jain <[email protected]>

* Improve stdlib-derive detection logic

It will now check against a (module, Name) pair to determine how to
intercept PartialEq, Eq, and Debug.

Fixes #14.

---------

Signed-off-by: Jinank Jain <[email protected]>
Co-authored-by: Alyssa Haroldsen <[email protected]>
  • Loading branch information
jinankjain and kupiakos authored Nov 18, 2023
1 parent 9c1f1b3 commit c9a5730
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 24 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ libc = { version = "0.2", optional = true }

[dev-dependencies]
trybuild = "1"
zerocopy = { version = "0.7.11", features = ["derive"] }

[workspace]
members = ["derive"]
121 changes: 97 additions & 24 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ use quote::{format_ident, quote, ToTokens};
use repr::Repr;
use std::collections::HashSet;
use syn::{
parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Path,
Token, Visibility,
parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility,
};

/// Sets the span for every token tree in the token stream
Expand Down Expand Up @@ -96,6 +95,37 @@ fn emit_debug_impl<'a>(
})
}

fn path_matches_prelude_derive(
got_path: &syn::Path,
expected_path_after_std: &[&'static str],
) -> bool {
let &[a, b] = expected_path_after_std else {
unimplemented!("checking against stdlib paths with != 2 parts");
};
let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect();
if segments
.iter()
.any(|segment| !matches!(segment.arguments, syn::PathArguments::None))
{
return false;
}
match &segments[..] {
// `core::fmt::Debug` or `some_crate::module::Name`
&[maybe_core_or_std, maybe_a, maybe_b] => {
(maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std")
&& maybe_a.ident == a
&& maybe_b.ident == b
}
// `fmt::Debug` or `module::Name`
&[maybe_a, maybe_b] => {
maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none()
}
// `Debug` or `Name``
&[maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(),
_ => false,
}
}

fn open_enum_impl(
enum_: ItemEnum,
Config {
Expand Down Expand Up @@ -134,31 +164,39 @@ fn open_enum_impl(
let mut explicit_repr: Option<Repr> = None;

// To make `match` seamless, derive(PartialEq, Eq) if they aren't already.
let mut our_derives = HashSet::new();
our_derives.insert("PartialEq");
our_derives.insert("Eq");
let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)];

let mut make_custom_debug_impl = false;
for attr in &enum_.attrs {
let mut include_in_struct = true;
// Turns out `is_ident` does a `to_string` every time
match attr.path.to_token_stream().to_string().as_str() {
"derive" => {
let derives =
attr.parse_args_with(Punctuated::<Path, Token![,]>::parse_terminated)?;
for derive in derives {
if derive.is_ident("PartialEq") {
our_derives.remove("PartialEq");
} else if derive.is_ident("Eq") {
our_derives.remove("Eq");
}
if let Ok(derive_paths) =
attr.parse_args_with(Punctuated::<syn::Path, syn::Token![,]>::parse_terminated)
{
for derive in &derive_paths {
// These derives are treated specially
const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
const EQ_PATH: &[&str] = &["cmp", "Eq"];
const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

// If we allow aliasing, then don't bother with a custom
// debug impl. There's no way to tell which alias we should
// print.
if derive.is_ident("Debug") && !allow_alias {
make_custom_debug_impl = true;
include_in_struct = false;
if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH)
|| path_matches_prelude_derive(derive, EQ_PATH)
{
// This derive is always included, exclude it.
continue;
}
if path_matches_prelude_derive(derive, DEBUG_PATH) {
if !allow_alias {
make_custom_debug_impl = true;
// Don't include this derive since we're generating a special one.
continue;
}
}
extra_derives.push(derive.to_token_stream());
}
include_in_struct = false;
}
}
// Copy linting attribute to the impl.
Expand Down Expand Up @@ -196,11 +234,8 @@ fn open_enum_impl(
}
};

if !our_derives.is_empty() {
let our_derives = our_derives
.into_iter()
.map(|d| Ident::new(d, Span::call_site()));
struct_attrs.push(quote!(#[derive(#(#our_derives),*)]));
if !extra_derives.is_empty() {
struct_attrs.push(quote!(#[derive(#(#extra_derives),*)]));
}

let alias_check = if allow_alias {
Expand Down Expand Up @@ -259,3 +294,41 @@ pub fn open_enum(
.unwrap_or_else(Error::into_compile_error)
.into()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_path_matches_stdlib_derive() {
const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

for success_case in [
"::core::fmt::Debug",
"::std::fmt::Debug",
"core::fmt::Debug",
"std::fmt::Debug",
"fmt::Debug",
"Debug",
] {
assert!(
path_matches_prelude_derive(&syn::parse_str(success_case).unwrap(), DEBUG_PATH),
"{success_case}"
);
}

for fail_case in [
"::fmt::Debug",
"::Debug",
"zerocopy::AsBytes",
"::zerocopy::AsBytes",
"PartialEq",
"core::cmp::Eq",
] {
assert!(
!path_matches_prelude_derive(&syn::parse_str(fail_case).unwrap(), DEBUG_PATH),
"{fail_case}"
);
}
}
}
89 changes: 89 additions & 0 deletions tests/derives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright © 2023 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate open_enum;
use open_enum::*;

#[open_enum]
#[derive(
Debug, Clone, Copy, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes,
)]
#[repr(u32)]
pub enum Color {
Red = 1,
Blue = 2,
}

#[open_enum]
#[derive(
core::fmt::Debug,
std::clone::Clone,
::core::marker::Copy,
std::cmp::PartialEq,
::core::cmp::Eq,
zerocopy::AsBytes,
::zerocopy::FromBytes,
zerocopy::FromZeroes,
)]
#[repr(u32)]
pub enum ColorWithNonPreludeDerives {
Red = 1,
Blue = 2,
}

// Ensure that `Color` actually implements the `derive`d traits.
#[derive(
Debug, Copy, Clone, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes,
)]
#[repr(C)]
pub struct EmbedColor {
pub color: Color,
}

#[derive(
Debug, Copy, Clone, PartialEq, Eq, zerocopy::AsBytes, zerocopy::FromBytes, zerocopy::FromZeroes,
)]
#[repr(C)]
pub struct EmbedColorWithNonPreludeDerives {
pub color: ColorWithNonPreludeDerives,
}

#[test]
fn embedded_enum_struct_partialeq() {
assert_eq!(
EmbedColor { color: Color::Red },
EmbedColor { color: Color::Red }
);
assert_ne!(
EmbedColor { color: Color::Red },
EmbedColor { color: Color::Blue }
);
}

#[test]
fn embedded_enum_struct_debug() {
let debug_str = format!("{:?}", EmbedColor { color: Color::Red });
assert!(debug_str.contains("Red"), "{debug_str}");
}

#[test]
fn extended_embedded_enum_struct_debug() {
let debug_str = format!(
"{:?}",
EmbedColorWithNonPreludeDerives {
color: ColorWithNonPreludeDerives::Red
}
);
assert!(debug_str.contains("Red"), "{debug_str}");
}

0 comments on commit c9a5730

Please sign in to comment.