diff --git a/README.md b/README.md index 494318a..33f3f89 100644 --- a/README.md +++ b/README.md @@ -401,6 +401,7 @@ error_set! { InvalidCredentials }; LoginError = { + #[display("Io Error: {}")] // equivalent to `#[display("Io Error: {}", source)]` IoError(std::io::Error), } || AuthError; } diff --git a/impl/Cargo.toml b/impl/Cargo.toml index ee4f669..52e01df 100644 --- a/impl/Cargo.toml +++ b/impl/Cargo.toml @@ -15,6 +15,7 @@ syn = { version = "2", default-features = false, features = ["parsing","derive", proc-macro2 = "1" quote = "1" indices = "0.3" +dyn-fmt = "0.4.3" [features] default = [] diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 79c8484..f2a5abc 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -166,11 +166,20 @@ fn impl_display(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenSt let name = &variant.name; if let Some(display) = &variant.display { let tokens = &display.tokens; - if is_str_literal(tokens.clone()) { - error_variant_tokens.append_all(quote::quote! { - #enum_name::#name(ref source) => #tokens, - }); + if let Some(string) = extract_string_from_str_literal(tokens.clone()) { + // e.g. `"{}"` + if is_format_str(&string) { + error_variant_tokens.append_all(quote::quote! { + #enum_name::#name(ref source) => &*format!(#tokens, source), + }); + } else { + // e.g. `"literal str"` + error_variant_tokens.append_all(quote::quote! { + #enum_name::#name(_) => #tokens, + }); + } } else { + // e.g. `"field: {}", source.field` error_variant_tokens.append_all(quote::quote! { #enum_name::#name(ref source) => &*format!(#tokens), }); @@ -383,6 +392,19 @@ fn is_str_literal(input: TokenStream) -> bool { false } +fn extract_string_from_str_literal(input: TokenStream) -> Option { + if let Ok(expr) = syn::parse2::(input) { + if let Lit::Str(lit) = expr { + return Some(lit.value()); + } + } + None +} + +fn is_format_str(str: &str) -> bool { + dyn_fmt::AsStrFormatExt::format(&str, ["t"]) != str +} + //************************************************************************// #[cfg(feature = "coerce_macro")] diff --git a/tests/mod.rs b/tests/mod.rs index 8c636ce..f7381d9 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -405,6 +405,14 @@ pub mod display_ref_error { #[display("Y io error: {}", source)] IoError(std::io::Error), }; + Y2 = { + #[display("Y2 io error type: {}", source.kind())] + IoError(std::io::Error), + }; + Z = { + #[display("Z io error: {}")] + IoError(std::io::Error), + }; } #[test] @@ -413,19 +421,35 @@ pub mod display_ref_error { std::io::ErrorKind::OutOfMemory, "oops out of memory 1", )); - assert_eq!(x.to_string(), "X io error".to_string()); let y = Y::IoError(std::io::Error::new( std::io::ErrorKind::OutOfMemory, "oops out of memory 2", )); - assert_eq!(y.to_string(), "Y io error: oops out of memory 2".to_string()); + + let y2 = Y2::IoError(std::io::Error::new( + std::io::ErrorKind::OutOfMemory, + "oops out of memory 3", + )); + assert_eq!(y2.to_string(), "Y2 io error type: out of memory".to_string()); + + let z = Z::IoError(std::io::Error::new( + std::io::ErrorKind::OutOfMemory, + "oops out of memory 4", + )); + assert_eq!(z.to_string(), "Z io error: oops out of memory 4".to_string()); + let y_to_x: X = y.into(); let x_to_y: Y = x.into(); assert_eq!(y_to_x.to_string(), "X io error".to_string()); assert_eq!(x_to_y.to_string(), "Y io error: oops out of memory 1".to_string()); + + let z_to_y2: Y2 = z.into(); + let y2_to_z: Z = y2.into(); + assert_eq!(z_to_y2.to_string(), "Y2 io error type: out of memory".to_string()); + assert_eq!(y2_to_z.to_string(), "Z io error: oops out of memory 3".to_string()); } }