From 50da3fef1145636167391abd2b2c3ebc99f43185 Mon Sep 17 00:00:00 2001 From: mcmah309 Date: Wed, 4 Dec 2024 23:31:01 +0000 Subject: [PATCH] feat: Allow overriding other generic names --- impl/src/ast.rs | 53 ++++++-- impl/src/expand.rs | 54 ++++---- impl/src/resolve.rs | 119 ++++++++++++++---- tests/mod.rs | 28 ++++- ...ics.rs => generic_specification_needed.rs} | 0 .../generic_specification_needed.stderr | 13 ++ .../multiple_different_generics.stderr | 13 -- 7 files changed, 204 insertions(+), 76 deletions(-) rename tests/trybuild/{multiple_different_generics.rs => generic_specification_needed.rs} (100%) create mode 100644 tests/trybuild/generic_specification_needed.stderr delete mode 100644 tests/trybuild/multiple_different_generics.stderr diff --git a/impl/src/ast.rs b/impl/src/ast.rs index ae8458f..83f8a35 100644 --- a/impl/src/ast.rs +++ b/impl/src/ast.rs @@ -4,7 +4,7 @@ use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, spanned::Spanned, - token, Attribute, Generics, Ident, Result, + token, Attribute, Ident, Result, TypeParam, }; const DISPLAY_ATTRIBUTE_NAME: &str = "display"; @@ -35,7 +35,7 @@ impl Parse for AstErrorSet { pub(crate) struct AstErrorDeclaration { pub(crate) attributes: Vec, pub(crate) error_name: Ident, - pub(crate) generics: Option, + pub(crate) generics: Vec, pub(crate) parts: Vec, } @@ -56,11 +56,7 @@ impl Parse for AstErrorDeclaration { "Expected `=` or generic `<..>` to be next next.", )); } - let generics = if input.peek(syn::Token![<]) { - Some(input.parse::()?) - } else { - None - }; + let generics = generics(&input)?; let last_position_save = input.fork(); if !input.peek(syn::Token![=]) { return Err(syn::Error::new( @@ -102,8 +98,6 @@ impl Parse for AstErrorDeclaration { } } -pub(crate) type RefError = Ident; - #[derive(Clone)] pub(crate) enum AstInlineOrRefError { Inline(AstInlineError), @@ -152,6 +146,23 @@ impl Parse for AstInlineError { } } +#[derive(Clone)] +pub(crate) struct RefError { + pub(crate) name: Ident, + pub(crate) generic_refs: Vec, +} + +impl Parse for RefError { + fn parse(input: ParseStream) -> Result { + let name = input.parse::()?; + let generics = generics(&input)?; + Ok(RefError { + name, + generic_refs: generics, + }) + } +} + //************************************************************************// /// A variant for an error @@ -233,6 +244,30 @@ impl Parse for AstErrorVariant { //************************************************************************// +fn generics(input: &ParseStream) -> Result> { + if input.peek(syn::Token![<]) { + input.parse::()?; + let mut generics = Vec::new(); + loop { + let next = input.parse::(); + match next { + Ok(next) => generics.push(next), + Err(_) => {} + } + let punc = input.parse::(); + if punc.is_err() { + break; + } + } + input.parse::]>()?; + Ok(generics) + } else { + Ok(Vec::new()) + } +} + +//************************************************************************// + /// The format string to use for display #[derive(Clone)] pub(crate) struct DisplayAttribute { diff --git a/impl/src/expand.rs b/impl/src/expand.rs index b247d84..c7a697e 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -4,8 +4,8 @@ use std::usize; use proc_macro2::TokenStream; -use quote::TokenStreamExt; -use syn::{Attribute, Generics, Ident, ImplGenerics, Lit, TypeGenerics, WhereClause}; +use quote::{quote, TokenStreamExt}; +use syn::{Attribute, Ident, Lit, TypeParam}; use crate::ast::{AstInlineErrorVariantField, DisplayAttribute}; @@ -132,11 +132,11 @@ fn add_enum(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenStream } } let attributes = &error_enum.attributes; - let generic_params = &error_enum.generics; + let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics); token_stream.append_all(quote::quote! { #(#attributes)* #[derive(Debug)] - pub enum #enum_name #generic_params { + pub enum #enum_name #impl_generics { #error_variant_tokens } }); @@ -178,10 +178,10 @@ fn impl_error(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenStre } }); } - let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics); + let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics); token_stream.append_all(quote::quote! { #[allow(unused_qualifications)] - impl #impl_generics core::error::Error for #enum_name #ty_generics #where_clause { + impl #impl_generics core::error::Error for #enum_name #ty_generics { #error_inner } }); @@ -272,9 +272,9 @@ fn impl_display(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenSt } } } - let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics); + let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics); token_stream.append_all(quote::quote! { - impl #impl_generics core::fmt::Display for #enum_name #ty_generics #where_clause { + impl #impl_generics core::fmt::Display for #enum_name #ty_generics { #[inline] fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match *self { @@ -389,10 +389,10 @@ fn impl_froms( error_branch_tokens.append_all(arm); } } - let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics); - // Dev Note: We can safely apply ty_generics to the `from` since only one level of generic inheritence is allowed. + let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics); + // todo it is not always correct to apply `ty_generics` to `from_error_enum_name`. token_stream.append_all(quote::quote! { - impl #impl_generics From<#from_error_enum_name #ty_generics> for #error_enum_name #ty_generics #where_clause { + impl #impl_generics From<#from_error_enum_name #ty_generics> for #error_enum_name #ty_generics { fn from(error: #from_error_enum_name #ty_generics) -> Self { match error { #error_branch_tokens @@ -410,10 +410,10 @@ fn impl_froms( continue; } if is_source_tuple_type(error_variant) { - let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics); + let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics); let variant_name = &error_variant.name(); token_stream.append_all(quote::quote! { - impl #impl_generics From<#source_type> for #error_enum_name #ty_generics #where_clause { + impl #impl_generics From<#source_type> for #error_enum_name #ty_generics { fn from(error: #source_type) -> Self { #error_enum_name::#variant_name(error) } @@ -421,10 +421,10 @@ fn impl_froms( }); source_errors_froms_already_implemented.push(source_type); } else if is_source_only_struct_type(error_variant) { - let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics); + let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics); let variant_name = &error_variant.name(); token_stream.append_all(quote::quote! { - impl #impl_generics From<#source_type> for #error_enum_name #ty_generics #where_clause { + impl #impl_generics From<#source_type> for #error_enum_name #ty_generics { fn from(error: #source_type) -> Self { #error_enum_name::#variant_name { source: error } } @@ -731,7 +731,7 @@ impl ErrorEnumGraphNode { pub(crate) struct ErrorEnum { pub(crate) attributes: Vec, pub(crate) error_name: Ident, - pub(crate) generics: Option, + pub(crate) generics: Vec, pub(crate) error_variants: Vec, } @@ -802,15 +802,16 @@ fn is_opaque(input: TokenStream) -> bool { //************************************************************************// -fn generic_tokens(generics: &Option) -> (Option>, Option>, Option<&WhereClause>) { - let (impl_generics, ty_generics, where_clause) = if let Some(generics) = generics { - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - (Some(impl_generics), Some(ty_generics), where_clause) +fn generic_tokens(generics: &Vec) -> (Option, Option) { + if generics.is_empty() { + return (None, None); } - else { - (None,None,None) - }; - (impl_generics, ty_generics, where_clause) + let impl_clause = quote! {<#(#generics),*>}; + + let names = generics.iter().map(|e| &e.ident); + let ty_clause = quote! {<#(#names),*>}; + + (Some(impl_clause), Some(ty_clause)) } //************************************************************************// @@ -821,7 +822,10 @@ pub(crate) fn is_source_tuple_type(error_variant: &ErrorVariant) -> bool { pub(crate) fn is_source_only_struct_type(error_variant: &ErrorVariant) -> bool { return error_variant.source_type().is_some() - && error_variant.fields().as_ref().is_some_and(|e| e.is_empty()); + && error_variant + .fields() + .as_ref() + .is_some_and(|e| e.is_empty()); } pub(crate) fn is_source_struct_type(error_variant: &ErrorVariant) -> bool { diff --git a/impl/src/resolve.rs b/impl/src/resolve.rs index 33186fd..777d57e 100644 --- a/impl/src/resolve.rs +++ b/impl/src/resolve.rs @@ -1,7 +1,11 @@ -use crate::ast::{AstErrorDeclaration, AstErrorSet, AstErrorVariant, RefError}; +use std::collections::HashMap; + +use crate::ast::{ + AstErrorDeclaration, AstErrorSet, AstErrorVariant, AstInlineErrorVariantField, RefError, +}; use crate::expand::{ErrorEnum, ErrorVariant, Named, SourceStruct, SourceTuple, Struct}; -use syn::{Attribute, Generics, Ident}; +use syn::{Attribute, Ident, TypeParam}; /// Constructs [ErrorEnum]s from the ast, resolving any references to other sets. The returned result is /// all error sets with the full expansion. @@ -81,12 +85,12 @@ fn resolve_builders_helper<'a>( for ref_part in ref_parts_to_resolve { let ref_error_enum_index = error_enum_builders .iter() - .position(|e| e.error_name == ref_part); + .position(|e| e.error_name == ref_part.name); let ref_error_enum_index = match ref_error_enum_index { Some(e) => e, None => { return Err(syn::parse::Error::new_spanned( - &ref_part, + &ref_part.name, "Not a declared error set.", )); } @@ -101,28 +105,76 @@ fn resolve_builders_helper<'a>( } let (this_error_enum_builder, ref_error_enum_builder) = indices::indices!(&mut *error_enum_builders, index, ref_error_enum_index); - match ( - &this_error_enum_builder.generics, - &ref_error_enum_builder.generics, - ) { - (Some(this_generics), Some(that_generics)) => { - if this_generics != that_generics { - // Dev Note: Merging generics may cause collisions in a combined definitions, e.g. `T: Debug` and `T`. - // or unintended sparsity, e.g. `T` and `G` when one would rather just have `T`. - return Err(syn::parse::Error::new_spanned( - &ref_part, - "Aggregating multiple different generic errors is not supported. \ - Instead redefine the error set with the desired generics and fields.", - )); - } + // Let the ref declaration override the original generic declaration name to avoid collisions - `.. || X ..` + if ref_part.generic_refs.len() != ref_error_enum_builder.generics.len() { + Err(syn::parse::Error::new_spanned( + &ref_part.name, + format!("A reference to {} was declared with {} generic param(s), but the original defintion takes {}.", ref_part.name, ref_part.generic_refs.len(), ref_error_enum_builder.generics.len()), + ))?; + } + let mut error_variants = Vec::new(); + let error_variants = if ref_part.generic_refs.is_empty() { + &ref_error_enum_builder.error_variants + } else { + fn ident_to_type(ident: Ident) -> syn::Type { + let segment = syn::PathSegment { + ident, + arguments: syn::PathArguments::None, // No generic arguments + }; + let path = syn::Path { + leading_colon: None, + segments: { + let mut punctuated = syn::punctuated::Punctuated::new(); + punctuated.push(segment); + punctuated + }, + }; + let type_path = syn::TypePath { qself: None, path }; + syn::Type::Path(type_path) } - (None, None) => {} - (None, Some(generics)) => { - this_error_enum_builder.generics = Some(generics.clone()); + // rename the generics inside the variants to the new declared name - to avoid collisions. + let mut rename = HashMap::::new(); + for (ref_part_generic, ref_error_enum_generic) in ref_part + .generic_refs + .iter() + .zip(ref_error_enum_builder.generics.iter()) + { + rename.insert( + ident_to_type(ref_error_enum_generic.ident.clone()), + ident_to_type(ref_part_generic.clone()), + ); } - (Some(_), None) => {} + // let error_variants = Vec::new(); + for error_variant in ref_error_enum_builder.error_variants.iter() { + let newfields = if let Some(fields) = &error_variant.fields { + let mut newfields = Vec::new(); + for field in fields.iter() { + if rename.contains_key(&field.r#type) { + let new_type = rename.get(&field.r#type).unwrap().clone(); + newfields.push(AstInlineErrorVariantField { + name: field.name.clone(), + r#type: new_type.clone(), + }); + } else { + newfields.push(field.clone()); + } + } + Some(newfields) + } else { + None + }; + error_variants.push(AstErrorVariant { + attributes: error_variant.attributes.clone(), + display: error_variant.display.clone(), + name: error_variant.name.clone(), + fields: newfields, + source_type: error_variant.source_type.clone(), + backtrace_type: error_variant.backtrace_type.clone(), + }); + } + &error_variants }; - for variant in ref_error_enum_builder.error_variants.iter() { + for variant in error_variants { let this_error_variants = &mut this_error_enum_builder.error_variants; let is_variant_already_in_enum = this_error_variants .iter() @@ -144,17 +196,34 @@ pub(crate) fn does_occupy_the_same_space(this: &AstErrorVariant, other: &AstErro return this.name == other.name; } +// fn merge_generics(this: &mut Generics, other: &Generics) { +// let other_params = other.params.iter().collect::>(); +// for other_param in other_params { +// if !this.params.iter().any(|param| param == other_param) { +// this.params.push(other_param.clone()); +// } +// } +// let other_where = other.where_clause.as_ref(); +// if let Some(other_where) = other_where { +// if let Some(this_where) = &mut this.where_clause { +// this_where.predicates.extend(other_where.predicates.clone()); +// } else { +// this.where_clause = Some(other_where.clone()); +// } +// } +// } + struct ErrorEnumBuilder { pub attributes: Vec, pub error_name: Ident, - pub generics: Option, + pub generics: Vec, pub error_variants: Vec, /// Once this is empty, all [ref_parts] have been resolved and [error_variants] is complete. pub ref_parts_to_resolve: Vec, } impl ErrorEnumBuilder { - fn new(error_name: Ident, attributes: Vec, generics: Option) -> Self { + fn new(error_name: Ident, attributes: Vec, generics: Vec) -> Self { Self { attributes, error_name, diff --git a/tests/mod.rs b/tests/mod.rs index 8ce9304..a58ec3d 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -681,9 +681,21 @@ pub mod generics { }, InvalidCredentials }; - LoginError = { + LoginError = { IoError(std::io::Error), - } || AuthError1 || AuthError2; + } || AuthError1 || AuthError2; + + X = { + A { + a: G + } + }; + Y = { + B { + b: H + } + }; + Z = X || Y; } #[test] @@ -715,6 +727,14 @@ pub mod generics { matches!(auth_error, AuthError1::InvalidCredentials); let auth_error: AuthError2 = auth_error.into(); matches!(auth_error, AuthError2::InvalidCredentials); + + let _x: X = X::A { a: 1 }; + //let z: Z = x.into(); + //matches!(z, Z::A { a: _ }); + + let _y: Y = Y::B { b: 1 }; + //let z: Z = y.into(); + //matches!(z, Z::B { b: _ }); } } @@ -740,9 +760,9 @@ pub mod should_not_compile_tests { } #[test] - fn multiple_different_generics() { + fn generic_specification_needed() { let t = trybuild::TestCases::new(); - t.compile_fail("tests/trybuild/multiple_different_generics.rs"); + t.compile_fail("tests/trybuild/generic_specification_needed.rs"); } #[test] diff --git a/tests/trybuild/multiple_different_generics.rs b/tests/trybuild/generic_specification_needed.rs similarity index 100% rename from tests/trybuild/multiple_different_generics.rs rename to tests/trybuild/generic_specification_needed.rs diff --git a/tests/trybuild/generic_specification_needed.stderr b/tests/trybuild/generic_specification_needed.stderr new file mode 100644 index 0000000..fce3d2f --- /dev/null +++ b/tests/trybuild/generic_specification_needed.stderr @@ -0,0 +1,13 @@ +error: A reference to AuthError1 was declared with 0 generic param(s), but the original defintion takes 1. + --> tests/trybuild/generic_specification_needed.rs:32:10 + | +32 | } || AuthError1 || AuthError2 || AuthError3; + | ^^^^^^^^^^ + +warning: unused import: `std::fmt::Debug` + --> tests/trybuild/generic_specification_needed.rs:1:5 + | +1 | use std::fmt::Debug; + | ^^^^^^^^^^^^^^^ + | + = note: `#[warn(unused_imports)]` on by default diff --git a/tests/trybuild/multiple_different_generics.stderr b/tests/trybuild/multiple_different_generics.stderr deleted file mode 100644 index 2487389..0000000 --- a/tests/trybuild/multiple_different_generics.stderr +++ /dev/null @@ -1,13 +0,0 @@ -error: Aggregating multiple different generic errors is not supported. Instead redefine the error set with the desired generics and fields. - --> tests/trybuild/multiple_different_generics.rs:32:38 - | -32 | } || AuthError1 || AuthError2 || AuthError3; - | ^^^^^^^^^^ - -warning: unused import: `std::fmt::Debug` - --> tests/trybuild/multiple_different_generics.rs:1:5 - | -1 | use std::fmt::Debug; - | ^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default