Skip to content

Commit

Permalink
feat: Allow overriding other generic names
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmah309 committed Dec 4, 2024
1 parent bd93839 commit 50da3fe
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 76 deletions.
53 changes: 44 additions & 9 deletions impl/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
token, Attribute, Generics, Ident, Result,
token, Attribute, Ident, Result, TypeParam,
};

const DISPLAY_ATTRIBUTE_NAME: &str = "display";
Expand Down Expand Up @@ -35,7 +35,7 @@ impl Parse for AstErrorSet {
pub(crate) struct AstErrorDeclaration {
pub(crate) attributes: Vec<Attribute>,
pub(crate) error_name: Ident,
pub(crate) generics: Option<Generics>,
pub(crate) generics: Vec<TypeParam>,
pub(crate) parts: Vec<AstInlineOrRefError>,
}

Expand All @@ -56,11 +56,7 @@ impl Parse for AstErrorDeclaration {
"Expected `=` or generic `<..>` to be next next.",
));
}
let generics = if input.peek(syn::Token![<]) {
Some(input.parse::<Generics>()?)
} else {
None
};
let generics = generics(&input)?;
let last_position_save = input.fork();
if !input.peek(syn::Token![=]) {
return Err(syn::Error::new(
Expand Down Expand Up @@ -102,8 +98,6 @@ impl Parse for AstErrorDeclaration {
}
}

pub(crate) type RefError = Ident;

#[derive(Clone)]
pub(crate) enum AstInlineOrRefError {
Inline(AstInlineError),
Expand Down Expand Up @@ -152,6 +146,23 @@ impl Parse for AstInlineError {
}
}

#[derive(Clone)]
pub(crate) struct RefError {
pub(crate) name: Ident,
pub(crate) generic_refs: Vec<Ident>,
}

impl Parse for RefError {
fn parse(input: ParseStream) -> Result<Self> {
let name = input.parse::<Ident>()?;
let generics = generics(&input)?;
Ok(RefError {
name,
generic_refs: generics,
})
}
}

//************************************************************************//

/// A variant for an error
Expand Down Expand Up @@ -233,6 +244,30 @@ impl Parse for AstErrorVariant {

//************************************************************************//

fn generics<T: Parse>(input: &ParseStream) -> Result<Vec<T>> {
if input.peek(syn::Token![<]) {
input.parse::<syn::Token![<]>()?;
let mut generics = Vec::new();
loop {
let next = input.parse::<T>();
match next {
Ok(next) => generics.push(next),
Err(_) => {}
}
let punc = input.parse::<syn::Token![,]>();
if punc.is_err() {
break;
}
}
input.parse::<syn::Token![>]>()?;
Ok(generics)
} else {
Ok(Vec::new())
}
}

//************************************************************************//

/// The format string to use for display
#[derive(Clone)]
pub(crate) struct DisplayAttribute {
Expand Down
54 changes: 29 additions & 25 deletions impl/src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
use std::usize;

use proc_macro2::TokenStream;
use quote::TokenStreamExt;
use syn::{Attribute, Generics, Ident, ImplGenerics, Lit, TypeGenerics, WhereClause};
use quote::{quote, TokenStreamExt};
use syn::{Attribute, Ident, Lit, TypeParam};

use crate::ast::{AstInlineErrorVariantField, DisplayAttribute};

Expand Down Expand Up @@ -132,11 +132,11 @@ fn add_enum(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenStream
}
}
let attributes = &error_enum.attributes;
let generic_params = &error_enum.generics;
let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics);
token_stream.append_all(quote::quote! {
#(#attributes)*
#[derive(Debug)]
pub enum #enum_name #generic_params {
pub enum #enum_name #impl_generics {
#error_variant_tokens
}
});
Expand Down Expand Up @@ -178,10 +178,10 @@ fn impl_error(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenStre
}
});
}
let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics);
let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics);
token_stream.append_all(quote::quote! {
#[allow(unused_qualifications)]
impl #impl_generics core::error::Error for #enum_name #ty_generics #where_clause {
impl #impl_generics core::error::Error for #enum_name #ty_generics {
#error_inner
}
});
Expand Down Expand Up @@ -272,9 +272,9 @@ fn impl_display(error_enum_node: &ErrorEnumGraphNode, token_stream: &mut TokenSt
}
}
}
let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics);
let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics);
token_stream.append_all(quote::quote! {
impl #impl_generics core::fmt::Display for #enum_name #ty_generics #where_clause {
impl #impl_generics core::fmt::Display for #enum_name #ty_generics {
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
match *self {
Expand Down Expand Up @@ -389,10 +389,10 @@ fn impl_froms(
error_branch_tokens.append_all(arm);
}
}
let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics);
// Dev Note: We can safely apply ty_generics to the `from` since only one level of generic inheritence is allowed.
let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics);
// todo it is not always correct to apply `ty_generics` to `from_error_enum_name`.
token_stream.append_all(quote::quote! {
impl #impl_generics From<#from_error_enum_name #ty_generics> for #error_enum_name #ty_generics #where_clause {
impl #impl_generics From<#from_error_enum_name #ty_generics> for #error_enum_name #ty_generics {
fn from(error: #from_error_enum_name #ty_generics) -> Self {
match error {
#error_branch_tokens
Expand All @@ -410,21 +410,21 @@ fn impl_froms(
continue;
}
if is_source_tuple_type(error_variant) {
let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics);
let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics);
let variant_name = &error_variant.name();
token_stream.append_all(quote::quote! {
impl #impl_generics From<#source_type> for #error_enum_name #ty_generics #where_clause {
impl #impl_generics From<#source_type> for #error_enum_name #ty_generics {
fn from(error: #source_type) -> Self {
#error_enum_name::#variant_name(error)
}
}
});
source_errors_froms_already_implemented.push(source_type);
} else if is_source_only_struct_type(error_variant) {
let (impl_generics, ty_generics, where_clause) = generic_tokens(&error_enum.generics);
let (impl_generics, ty_generics) = generic_tokens(&error_enum.generics);
let variant_name = &error_variant.name();
token_stream.append_all(quote::quote! {
impl #impl_generics From<#source_type> for #error_enum_name #ty_generics #where_clause {
impl #impl_generics From<#source_type> for #error_enum_name #ty_generics {
fn from(error: #source_type) -> Self {
#error_enum_name::#variant_name { source: error }
}
Expand Down Expand Up @@ -731,7 +731,7 @@ impl ErrorEnumGraphNode {
pub(crate) struct ErrorEnum {
pub(crate) attributes: Vec<Attribute>,
pub(crate) error_name: Ident,
pub(crate) generics: Option<Generics>,
pub(crate) generics: Vec<TypeParam>,
pub(crate) error_variants: Vec<ErrorVariant>,
}

Expand Down Expand Up @@ -802,15 +802,16 @@ fn is_opaque(input: TokenStream) -> bool {

//************************************************************************//

fn generic_tokens(generics: &Option<Generics>) -> (Option<ImplGenerics<'_>>, Option<TypeGenerics<'_>>, Option<&WhereClause>) {
let (impl_generics, ty_generics, where_clause) = if let Some(generics) = generics {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
(Some(impl_generics), Some(ty_generics), where_clause)
fn generic_tokens(generics: &Vec<TypeParam>) -> (Option<TokenStream>, Option<TokenStream>) {
if generics.is_empty() {
return (None, None);
}
else {
(None,None,None)
};
(impl_generics, ty_generics, where_clause)
let impl_clause = quote! {<#(#generics),*>};

let names = generics.iter().map(|e| &e.ident);
let ty_clause = quote! {<#(#names),*>};

(Some(impl_clause), Some(ty_clause))
}

//************************************************************************//
Expand All @@ -821,7 +822,10 @@ pub(crate) fn is_source_tuple_type(error_variant: &ErrorVariant) -> bool {

pub(crate) fn is_source_only_struct_type(error_variant: &ErrorVariant) -> bool {
return error_variant.source_type().is_some()
&& error_variant.fields().as_ref().is_some_and(|e| e.is_empty());
&& error_variant
.fields()
.as_ref()
.is_some_and(|e| e.is_empty());
}

pub(crate) fn is_source_struct_type(error_variant: &ErrorVariant) -> bool {
Expand Down
119 changes: 94 additions & 25 deletions impl/src/resolve.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use crate::ast::{AstErrorDeclaration, AstErrorSet, AstErrorVariant, RefError};
use std::collections::HashMap;

use crate::ast::{
AstErrorDeclaration, AstErrorSet, AstErrorVariant, AstInlineErrorVariantField, RefError,
};
use crate::expand::{ErrorEnum, ErrorVariant, Named, SourceStruct, SourceTuple, Struct};

use syn::{Attribute, Generics, Ident};
use syn::{Attribute, Ident, TypeParam};

/// Constructs [ErrorEnum]s from the ast, resolving any references to other sets. The returned result is
/// all error sets with the full expansion.
Expand Down Expand Up @@ -81,12 +85,12 @@ fn resolve_builders_helper<'a>(
for ref_part in ref_parts_to_resolve {
let ref_error_enum_index = error_enum_builders
.iter()
.position(|e| e.error_name == ref_part);
.position(|e| e.error_name == ref_part.name);
let ref_error_enum_index = match ref_error_enum_index {
Some(e) => e,
None => {
return Err(syn::parse::Error::new_spanned(
&ref_part,
&ref_part.name,
"Not a declared error set.",
));
}
Expand All @@ -101,28 +105,76 @@ fn resolve_builders_helper<'a>(
}
let (this_error_enum_builder, ref_error_enum_builder) =
indices::indices!(&mut *error_enum_builders, index, ref_error_enum_index);
match (
&this_error_enum_builder.generics,
&ref_error_enum_builder.generics,
) {
(Some(this_generics), Some(that_generics)) => {
if this_generics != that_generics {
// Dev Note: Merging generics may cause collisions in a combined definitions, e.g. `T: Debug` and `T`.
// or unintended sparsity, e.g. `T` and `G` when one would rather just have `T`.
return Err(syn::parse::Error::new_spanned(
&ref_part,
"Aggregating multiple different generic errors is not supported. \
Instead redefine the error set with the desired generics and fields.",
));
}
// Let the ref declaration override the original generic declaration name to avoid collisions - `.. || X<T> ..`
if ref_part.generic_refs.len() != ref_error_enum_builder.generics.len() {
Err(syn::parse::Error::new_spanned(
&ref_part.name,
format!("A reference to {} was declared with {} generic param(s), but the original defintion takes {}.", ref_part.name, ref_part.generic_refs.len(), ref_error_enum_builder.generics.len()),
))?;
}
let mut error_variants = Vec::new();
let error_variants = if ref_part.generic_refs.is_empty() {
&ref_error_enum_builder.error_variants
} else {
fn ident_to_type(ident: Ident) -> syn::Type {
let segment = syn::PathSegment {
ident,
arguments: syn::PathArguments::None, // No generic arguments
};
let path = syn::Path {
leading_colon: None,
segments: {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(segment);
punctuated
},
};
let type_path = syn::TypePath { qself: None, path };
syn::Type::Path(type_path)
}
(None, None) => {}
(None, Some(generics)) => {
this_error_enum_builder.generics = Some(generics.clone());
// rename the generics inside the variants to the new declared name - to avoid collisions.
let mut rename = HashMap::<syn::Type, syn::Type>::new();
for (ref_part_generic, ref_error_enum_generic) in ref_part
.generic_refs
.iter()
.zip(ref_error_enum_builder.generics.iter())
{
rename.insert(
ident_to_type(ref_error_enum_generic.ident.clone()),
ident_to_type(ref_part_generic.clone()),
);
}
(Some(_), None) => {}
// let error_variants = Vec::new();
for error_variant in ref_error_enum_builder.error_variants.iter() {
let newfields = if let Some(fields) = &error_variant.fields {
let mut newfields = Vec::new();
for field in fields.iter() {
if rename.contains_key(&field.r#type) {
let new_type = rename.get(&field.r#type).unwrap().clone();
newfields.push(AstInlineErrorVariantField {
name: field.name.clone(),
r#type: new_type.clone(),
});
} else {
newfields.push(field.clone());
}
}
Some(newfields)
} else {
None
};
error_variants.push(AstErrorVariant {
attributes: error_variant.attributes.clone(),
display: error_variant.display.clone(),
name: error_variant.name.clone(),
fields: newfields,
source_type: error_variant.source_type.clone(),
backtrace_type: error_variant.backtrace_type.clone(),
});
}
&error_variants
};
for variant in ref_error_enum_builder.error_variants.iter() {
for variant in error_variants {
let this_error_variants = &mut this_error_enum_builder.error_variants;
let is_variant_already_in_enum = this_error_variants
.iter()
Expand All @@ -144,17 +196,34 @@ pub(crate) fn does_occupy_the_same_space(this: &AstErrorVariant, other: &AstErro
return this.name == other.name;
}

// fn merge_generics(this: &mut Generics, other: &Generics) {
// let other_params = other.params.iter().collect::<Vec<_>>();
// for other_param in other_params {
// if !this.params.iter().any(|param| param == other_param) {
// this.params.push(other_param.clone());
// }
// }
// let other_where = other.where_clause.as_ref();
// if let Some(other_where) = other_where {
// if let Some(this_where) = &mut this.where_clause {
// this_where.predicates.extend(other_where.predicates.clone());
// } else {
// this.where_clause = Some(other_where.clone());
// }
// }
// }

struct ErrorEnumBuilder {
pub attributes: Vec<Attribute>,
pub error_name: Ident,
pub generics: Option<Generics>,
pub generics: Vec<TypeParam>,
pub error_variants: Vec<AstErrorVariant>,
/// Once this is empty, all [ref_parts] have been resolved and [error_variants] is complete.
pub ref_parts_to_resolve: Vec<RefError>,
}

impl ErrorEnumBuilder {
fn new(error_name: Ident, attributes: Vec<Attribute>, generics: Option<Generics>) -> Self {
fn new(error_name: Ident, attributes: Vec<Attribute>, generics: Vec<TypeParam>) -> Self {
Self {
attributes,
error_name,
Expand Down
Loading

0 comments on commit 50da3fe

Please sign in to comment.