diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index efe182f..10e6c9b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - name: Run linters uses: golangci/golangci-lint-action@v2 with: - version: v1.52 + version: v1.54 test: strategy: diff --git a/go.mod b/go.mod index 8573fcf..f618e5a 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/clear-street/reinforcer go 1.20 require ( - github.com/dave/jennifer v1.4.1 + github.com/dave/jennifer v1.7.0 github.com/mitchellh/go-homedir v1.1.0 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.29.0 diff --git a/go.sum b/go.sum index d4ccc34..250ac87 100644 --- a/go.sum +++ b/go.sum @@ -57,8 +57,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/dave/jennifer v1.4.1 h1:XyqG6cn5RQsTj3qlWQTKlRGAyrTcsk1kUmWdZBzRjDw= -github.com/dave/jennifer v1.4.1/go.mod h1:7jEdnm+qBcxl8PC0zyp7vxcpSRnzXSt9r39tpTVGlwA= +github.com/dave/jennifer v1.7.0 h1:uRbSBH9UTS64yXbh4FrMHfgfY762RD+C7bUPKODpSJE= +github.com/dave/jennifer v1.7.0/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/generator/executor/executor.go b/internal/generator/executor/executor.go index a428ce6..6bfc081 100644 --- a/internal/generator/executor/executor.go +++ b/internal/generator/executor/executor.go @@ -111,7 +111,7 @@ func createFileConfigs(discoveredSet map[string]struct{}, match map[string]*load return nil, errors.Errorf("multiple types with same name discovered with name %s", typName) } discoveredSet[typName] = struct{}{} - cfg = append(cfg, generator.NewFileConfig(typName, typName, res.Methods)) + cfg = append(cfg, generator.NewFileConfig(typName, typName, res.TypeParams, res.TypeArgs, res.Methods)) } return cfg, nil } diff --git a/internal/generator/generator.go b/internal/generator/generator.go index 75291b6..df61a13 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -21,12 +21,16 @@ type FileConfig struct { srcTypeName string // outTypeName is the desired output type name outTypeName string + // typeParams is the list of generic type parameters + typeParams []jen.Code + // typeArgs is the list of generic type arguments + typeArgs []jen.Code // methods that should be in the generated type methods []*method.Method } // NewFileConfig creates a new instance of the FileConfig which holds code generation configuration -func NewFileConfig(srcTypeName, outTypeName string, methods []*method.Method) *FileConfig { +func NewFileConfig(srcTypeName, outTypeName string, typeParams []jen.Code, typeArgs []jen.Code, methods []*method.Method) *FileConfig { // cannot use cases.Title as it will lowercase MyService to Myservice if len(srcTypeName) > 0 { srcTypeName = strings.ToUpper(string(srcTypeName[0])) + srcTypeName[1:] @@ -37,6 +41,8 @@ func NewFileConfig(srcTypeName, outTypeName string, methods []*method.Method) *F return &FileConfig{ srcTypeName: srcTypeName, outTypeName: outTypeName, + typeParams: typeParams, + typeArgs: typeArgs, methods: methods, } } @@ -136,22 +142,22 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig for _, meth := range methods { declMethods = append(declMethods, jen.Id(meth.Name).Params(meth.ParametersNameAndType...).Params(meth.ReturnTypes...)) } - f.Add(jen.Type().Id(fileCfg.targetName()).Interface( + f.Add(jen.Type().Id(fileCfg.targetName()).Types(fileCfg.typeParams...).Interface( declMethods..., )) // Declare the proxy implementation - f.Add(jen.Type().Id(fileCfg.outTypeName).Struct( + f.Add(jen.Type().Id(fileCfg.outTypeName).Types(fileCfg.typeParams...).Struct( jen.Op("*").Id("base"), - jen.Id("delegate").Id(fileCfg.targetName()), + jen.Id("delegate").Id(fileCfg.targetName()).Types(fileCfg.typeArgs...), )) // Declare the ctor - f.Add(jen.Func().Id("New"+fileCfg.outTypeName).Params( - jen.Id("delegate").Id(fileCfg.targetName()), + f.Add(jen.Func().Id("New"+fileCfg.outTypeName).Types(fileCfg.typeParams...).Params( + jen.Id("delegate").Id(fileCfg.targetName()).Types(fileCfg.typeArgs...), jen.Id("runnerFactory").Id("runnerFactory"), jen.Id("options").Op("...").Id("Option"), - ).Op("*").Id(fileCfg.outTypeName).Block( + ).Op("*").Id(fileCfg.outTypeName).Types(fileCfg.typeArgs...).Block( // if delegate == nil jen.If(jen.Id("delegate").Op("==").Nil().Block( // panic("...") @@ -163,7 +169,7 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig jen.Panic(jen.Lit("provided nil runner factory")), )), // c:= &OutTypeName{...} - jen.Id("c").Op(":=").Add(jen.Op("&").Id(fileCfg.outTypeName).Values(jen.Dict{ + jen.Id("c").Op(":=").Add(jen.Op("&").Id(fileCfg.outTypeName).Types(fileCfg.typeArgs...).Values(jen.Dict{ // embed the base struct jen.Id("base"): jen.Op("&").Id("base").Values(jen.Dict{ jen.Id("errorPredicate"): jen.Id("RetryAllErrors"), @@ -181,7 +187,7 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig // Declare all of our proxy methods for _, mm := range methods { if mm.ReturnsError { - r := retryable.NewRetryable(mm, fileCfg.outTypeName, fileCfg.receiverName()) + r := retryable.NewRetryable(mm, fileCfg.outTypeName, fileCfg.typeArgs, fileCfg.receiverName()) s, err := r.Statement() if err != nil { return "", err @@ -190,9 +196,9 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig } else { var p statement if ignoreNoReturnMethods { - p = passthrough.NewPassThrough(mm, fileCfg.outTypeName, fileCfg.receiverName()) + p = passthrough.NewPassThrough(mm, fileCfg.outTypeName, fileCfg.typeArgs, fileCfg.receiverName()) } else { - p = noret.NewNoReturn(mm, fileCfg.outTypeName, fileCfg.receiverName()) + p = noret.NewNoReturn(mm, fileCfg.outTypeName, fileCfg.typeArgs, fileCfg.receiverName()) } s, err := p.Statement() if err != nil { diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 61f2ad1..fb94b64 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -880,6 +880,128 @@ func (g *GeneratedService) SayHello(arg0 string) error { } return err } +`, + }, + }, + }, + }, + { + name: "Generic Type Parameters", + ignoreNoReturnMethods: true, + inputs: map[string]input{ + "users_service.go": { + interfaceName: "Service", + code: `package fake + +type Service[T any] interface { + SayHello(name T) error + DoNothing() +} +`, + }, + }, + outCode: &generator.Generated{ + Common: `// Code generated by reinforcer, DO NOT EDIT. + +package resilient + +import ( + "context" + goresilience "github.com/slok/goresilience" +) + +type base struct { + errorPredicate func(string, error) bool + runnerFactory runnerFactory +} +type runnerFactory interface { + GetRunner(name string) goresilience.Runner +} + +var RetryAllErrors = func(_ string, _ error) bool { + return true +} + +type Option func(*base) + +func WithRetryableErrorPredicate(fn func(string, error) bool) Option { + return func(o *base) { + o.errorPredicate = fn + } +} +func (b *base) run(ctx context.Context, name string, fn func(ctx context.Context) error) error { + return b.runnerFactory.GetRunner(name).Run(ctx, fn) +} +`, + Constants: `// Code generated by reinforcer, DO NOT EDIT. + +package resilient + +// GeneratedServiceMethods are the methods in GeneratedService +var GeneratedServiceMethods = struct { + DoNothing string + SayHello string +}{ + DoNothing: "DoNothing", + SayHello: "SayHello", +} +`, + Files: []*generator.GeneratedFile{ + { + TypeName: "GeneratedService", + Contents: `// Code generated by reinforcer, DO NOT EDIT. + +package resilient + +import "context" + +type targetService[T any] interface { + DoNothing() + SayHello(arg0 T) error +} +type GeneratedService[T any] struct { + *base + delegate targetService[T] +} + +func NewGeneratedService[T any](delegate targetService[T], runnerFactory runnerFactory, options ...Option) *GeneratedService[T] { + if delegate == nil { + panic("provided nil delegate") + } + if runnerFactory == nil { + panic("provided nil runner factory") + } + c := &GeneratedService[T]{ + base: &base{ + errorPredicate: RetryAllErrors, + runnerFactory: runnerFactory, + }, + delegate: delegate, + } + for _, o := range options { + o(c.base) + } + return c +} +func (g *GeneratedService[T]) DoNothing() { + g.delegate.DoNothing() +} +func (g *GeneratedService[T]) SayHello(arg0 T) error { + var nonRetryableErr error + err := g.run(context.Background(), GeneratedServiceMethods.SayHello, func(_ context.Context) error { + var err error + err = g.delegate.SayHello(arg0) + if g.errorPredicate(GeneratedServiceMethods.SayHello, err) { + return err + } + nonRetryableErr = err + return nil + }) + if nonRetryableErr != nil { + return nonRetryableErr + } + return err +} `, }, }, @@ -947,6 +1069,8 @@ func loadInterface(t *testing.T, filesCode map[string]input) []*generator.FileCo } loadedTypes = append(loadedTypes, generator.NewFileConfig(in.interfaceName, fmt.Sprintf("Generated%s", in.interfaceName), + svc.TypeParams, + svc.TypeArgs, svc.Methods, )) } diff --git a/internal/generator/method/method.go b/internal/generator/method/method.go index 4f5c86f..a5f65fc 100644 --- a/internal/generator/method/method.go +++ b/internal/generator/method/method.go @@ -6,17 +6,12 @@ import ( rtypes "github.com/clear-street/reinforcer/internal/types" "github.com/dave/jennifer/jen" - "github.com/pkg/errors" ) const ( ctxVarName = "ctx" ) -type named interface { - Name() string -} - // Method holds all of the data for code generation on a specific method signature type Method struct { Name string @@ -94,7 +89,7 @@ func ParseMethod(name string, signature *types.Signature) (*Method, error) { } else { paramName := fmt.Sprintf("arg%d", i) - paramType, err := toType(param.Type(), isVariadic && i == lastIndex) + paramType, err := rtypes.ToType(param.Type(), isVariadic && i == lastIndex) if err != nil { return nil, fmt.Errorf("failed to convert type=%v; error=%w", param.Type(), err) } @@ -104,7 +99,7 @@ func ParseMethod(name string, signature *types.Signature) (*Method, error) { } for i := 0; i < signature.Results().Len(); i++ { res := signature.Results().At(i) - resType, err := toType(res.Type(), false) + resType, err := rtypes.ToType(res.Type(), false) if err != nil { panic(err) } @@ -120,115 +115,3 @@ func ParseMethod(name string, signature *types.Signature) (*Method, error) { } return m, nil } - -// variadicToType generates the representation for a variadic type "...MyType" -func variadicToType(t types.Type) (jen.Code, error) { - sliceType, ok := t.(*types.Slice) - if !ok { - return nil, fmt.Errorf("expected type to be *types.Slice, got=%T", t) - } - sliceElemType, err := toType(sliceType.Elem(), false) - if err != nil { - return nil, fmt.Errorf("failed to convert slice's type; error=%w", err) - } - return jen.Op("...").Add(sliceElemType), nil -} - -// toType generates the representation for the given type -func toType(t types.Type, variadic bool) (jen.Code, error) { - if variadic { - return variadicToType(t) - } - - switch v := t.(type) { - case *types.Basic: - return jen.Id(v.Name()), nil - case *types.Chan: - rt, err := toType(v.Elem(), false) - if err != nil { - return nil, err - } - switch v.Dir() { - case types.SendRecv: - return jen.Chan().Add(rt), nil - case types.RecvOnly: - return jen.Op("<-").Chan().Add(rt), nil - default: - return jen.Chan().Op("<-").Add(rt), nil - } - case *types.Named: - typeName := v.Obj() - if _, ok := v.Underlying().(*types.Interface); ok { - if typeName.Pkg() != nil { - pkgPath := typeName.Pkg().Path() - return jen.Qual( - pkgPath, - typeName.Name(), - ), nil - } - return jen.Id(typeName.Name()), nil - } - pkgPath := typeName.Pkg().Path() - return jen.Qual( - pkgPath, - typeName.Name(), - ), nil - case *types.Pointer: - rt, err := toType(v.Elem(), false) - if err != nil { - return nil, err - } - return jen.Op("*").Add(rt), nil - case *types.Interface: - return jen.Id("interface{}"), nil - case *types.Slice: - elemType, err := toType(v.Elem(), false) - if err != nil { - return nil, err - } - return jen.Index().Add(elemType), nil - case named: - return jen.Id(v.Name()), nil - case *types.Map: - keyType, err := toType(v.Key(), false) - if err != nil { - return nil, err - } - elemType, err := toType(v.Elem(), false) - if err != nil { - return nil, err - } - return jen.Map(keyType).Add(elemType), nil - case *types.Signature: - fnVariadic := v.Variadic() - var paramTypes []jen.Code - lastIndex := v.Params().Len() - 1 - for p := 0; p < v.Params().Len(); p++ { - paramType := v.Params().At(p).Type() - tt, err := toType(paramType, lastIndex == p && fnVariadic) - if err != nil { - return nil, errors.Wrapf(err, "failed to convert type %v", paramType) - } - paramTypes = append(paramTypes, tt) - } - - var returnTypes []jen.Code - for r := 0; r < v.Results().Len(); r++ { - returnType := v.Results().At(r).Type() - tt, err := toType(returnType, false) - if err != nil { - return nil, errors.Wrapf(err, "failed to convert type %v", returnType) - } - returnTypes = append(returnTypes, tt) - } - if len(returnTypes) == 0 { - return jen.Func().Params(paramTypes...), nil - } - if len(returnTypes) > 1 { - return jen.Func().Params(paramTypes...).Parens(jen.List(returnTypes...)), nil - } - return jen.Func().Params(paramTypes...).Add(returnTypes[0]), nil - default: - return nil, fmt.Errorf("type not handled: %T", v) - } -} diff --git a/internal/generator/method/method_test.go b/internal/generator/method/method_test.go index ed06465..f1bce73 100644 --- a/internal/generator/method/method_test.go +++ b/internal/generator/method/method_test.go @@ -14,6 +14,12 @@ import ( func TestNewMethod(t *testing.T) { ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) + typedType, _ := types.Instantiate( + types.NewContext(), + types.NewNamed(types.NewTypeName(token.NoPos, types.NewPackage("github.com/clear-street/fake", "fake"), "genericType", nil), types.NewStruct(nil, nil), nil), + []types.Type{types.Typ[types.String]}, + false, + ) zero := new(int) *zero = 0 @@ -174,8 +180,8 @@ func TestNewMethod(t *testing.T) { Name: "Fn", HasContext: false, ParameterNames: []string{"arg0"}, - ParametersNameAndType: []jen.Code{jen.Id("arg0").Add(jen.Id("interface{}"))}, - ReturnTypes: []jen.Code{jen.Id("interface{}")}, + ParametersNameAndType: []jen.Code{jen.Id("arg0").Add(jen.Id("any"))}, + ReturnTypes: []jen.Code{jen.Id("any")}, }, }, { @@ -191,7 +197,7 @@ func TestNewMethod(t *testing.T) { Name: "Fn", HasContext: false, ParameterNames: []string{"arg0"}, - ParametersNameAndType: []jen.Code{jen.Id("arg0").Add(jen.Map(jen.Id("string")).Add(jen.Id("interface{}")))}, + ParametersNameAndType: []jen.Code{jen.Id("arg0").Add(jen.Map(jen.Id("string")).Add(jen.Id("any")))}, ReturnTypes: []jen.Code{jen.Map(jen.Id("string")).Add(jen.Id("int"))}, }, }, @@ -222,6 +228,23 @@ func TestNewMethod(t *testing.T) { ReturnTypes: []jen.Code{jen.Map(jen.Id("string")).Add(jen.Id("int"))}, }, }, + { + name: "Fn(arg genericType[string])", + args: args{ + name: "Fn", + signature: types.NewSignatureType(nil, nil, nil, + types.NewTuple(types.NewVar(token.NoPos, nil, "arg", typedType)), + types.NewTuple(), + false), + }, + want: &method.Method{ + Name: "Fn", + HasContext: false, + ParameterNames: []string{"arg0"}, + ParametersNameAndType: []jen.Code{jen.Id("arg0").Add(jen.Qual("github.com/clear-street/fake", "genericType").Types(jen.String()))}, + ReturnTypes: []jen.Code{}, + }, + }, } for _, tt := range tests { diff --git a/internal/generator/noret/noret.go b/internal/generator/noret/noret.go index 2054d2e..cf4dd5e 100644 --- a/internal/generator/noret/noret.go +++ b/internal/generator/noret/noret.go @@ -7,17 +7,19 @@ import ( // NoReturn is a code generator that injects the middleware to delegates that don't return anything type NoReturn struct { - method *method.Method - structName string - receiverName string + method *method.Method + structName string + structTypeArgs []jen.Code + receiverName string } // NewNoReturn is a ctor for NoReturn -func NewNoReturn(method *method.Method, structName string, receiverName string) *NoReturn { +func NewNoReturn(method *method.Method, structName string, structTypeArgs []jen.Code, receiverName string) *NoReturn { return &NoReturn{ - method: method, - structName: structName, - receiverName: receiverName, + method: method, + structName: structName, + structTypeArgs: structTypeArgs, + receiverName: receiverName, } } @@ -35,7 +37,7 @@ func (p *NoReturn) Statement() (*jen.Statement, error) { jen.Return(jen.Nil()), ) - return jen.Func().Params(jen.Id(p.receiverName).Op("*").Id(p.structName)).Id(p.method.Name).Call(methodArgParams...).Block( + return jen.Func().Params(jen.Id(p.receiverName).Op("*").Id(p.structName).Types(p.structTypeArgs...)).Id(p.method.Name).Call(methodArgParams...).Block( jen.Id("err").Op(":=").Id(p.receiverName).Dot("run").Call(ctxParam, p.method.ConstantRef(p.structName), call), jen.If(jen.Id("err").Op("!=").Nil()).Block( jen.Panic(jen.Id("err")), diff --git a/internal/generator/noret/noret_test.go b/internal/generator/noret/noret_test.go index 9364676..dffac8a 100644 --- a/internal/generator/noret/noret_test.go +++ b/internal/generator/noret/noret_test.go @@ -9,6 +9,7 @@ import ( "github.com/clear-street/reinforcer/internal/generator/method" "github.com/clear-street/reinforcer/internal/generator/noret" rtypes "github.com/clear-street/reinforcer/internal/types" + "github.com/dave/jennifer/jen" "github.com/stretchr/testify/require" ) @@ -16,11 +17,12 @@ func TestNoReturn_Statement(t *testing.T) { ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) tests := []struct { - name string - methodName string - signature *types.Signature - want string - wantErr bool + name string + methodName string + structTypeArgs []jen.Code + signature *types.Signature + want string + wantErr bool }{ { name: "MyFunction()", @@ -52,6 +54,22 @@ func TestNoReturn_Statement(t *testing.T) { if err != nil { panic(err) } +}`, + wantErr: false, + }, + { + name: "struct type args", + methodName: "MyFunction", + structTypeArgs: []jen.Code{jen.Id("T")}, + signature: types.NewSignatureType(nil, nil, nil, types.NewTuple(), types.NewTuple(), false), + want: `func (r *Resilient[T]) MyFunction() { + err := r.run(context.Background(), ResilientMethods.MyFunction, func(_ context.Context) error { + r.delegate.MyFunction() + return nil + }) + if err != nil { + panic(err) + } }`, wantErr: false, }, @@ -61,7 +79,7 @@ func TestNoReturn_Statement(t *testing.T) { t.Run(tt.name, func(t *testing.T) { m, err := method.ParseMethod(tt.methodName, tt.signature) require.NoError(t, err) - ret := noret.NewNoReturn(m, "Resilient", "r") + ret := noret.NewNoReturn(m, "Resilient", tt.structTypeArgs, "r") buf := &bytes.Buffer{} s, err := ret.Statement() if tt.wantErr { diff --git a/internal/generator/passthrough/passthrough.go b/internal/generator/passthrough/passthrough.go index dad2e7b..e688d84 100644 --- a/internal/generator/passthrough/passthrough.go +++ b/internal/generator/passthrough/passthrough.go @@ -7,17 +7,19 @@ import ( // PassThrough is a code generator that injects no middleware and acts a simple fall through call to the delegate type PassThrough struct { - method *method.Method - structName string - receiverName string + method *method.Method + structName string + structTypeArgs []jen.Code + receiverName string } // NewPassThrough is a ctor for PassThrough -func NewPassThrough(method *method.Method, structName string, receiverName string) *PassThrough { +func NewPassThrough(method *method.Method, structName string, structTypeArgs []jen.Code, receiverName string) *PassThrough { return &PassThrough{ - method: method, - structName: structName, - receiverName: receiverName, + method: method, + structName: structName, + structTypeArgs: structTypeArgs, + receiverName: receiverName, } } @@ -34,7 +36,7 @@ func (p *PassThrough) Statement() (*jen.Statement, error) { block = append(block, delegateCall) } - return jen.Func().Params(jen.Id(p.receiverName).Op("*").Id(p.structName)).Id(p.method.Name).Call(methodArgParams...).Block( + return jen.Func().Params(jen.Id(p.receiverName).Op("*").Id(p.structName).Types(p.structTypeArgs...)).Id(p.method.Name).Call(methodArgParams...).Block( block..., ), nil } diff --git a/internal/generator/passthrough/passthrough_test.go b/internal/generator/passthrough/passthrough_test.go index 6cbe5ba..3ddf433 100644 --- a/internal/generator/passthrough/passthrough_test.go +++ b/internal/generator/passthrough/passthrough_test.go @@ -9,6 +9,7 @@ import ( "github.com/clear-street/reinforcer/internal/generator/method" "github.com/clear-street/reinforcer/internal/generator/passthrough" rtypes "github.com/clear-street/reinforcer/internal/types" + "github.com/dave/jennifer/jen" "github.com/stretchr/testify/require" ) @@ -16,11 +17,12 @@ func TestPassThrough_Statement(t *testing.T) { ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) tests := []struct { - name string - methodName string - signature *types.Signature - want string - wantErr bool + name string + methodName string + structTypeArgs []jen.Code + signature *types.Signature + want string + wantErr bool }{ { name: "Function arguments and returns", @@ -52,6 +54,16 @@ func TestPassThrough_Statement(t *testing.T) { signature: types.NewSignatureType(nil, nil, nil, types.NewTuple(), types.NewTuple(), false), want: `func (r *resilient) MyFunction() { r.delegate.MyFunction() +}`, + wantErr: false, + }, + { + name: "struct type args", + methodName: "MyFunction", + structTypeArgs: []jen.Code{jen.Id("T")}, + signature: types.NewSignatureType(nil, nil, nil, types.NewTuple(), types.NewTuple(), false), + want: `func (r *resilient[T]) MyFunction() { + r.delegate.MyFunction() }`, wantErr: false, }, @@ -61,7 +73,7 @@ func TestPassThrough_Statement(t *testing.T) { t.Run(tt.name, func(t *testing.T) { m, err := method.ParseMethod(tt.methodName, tt.signature) require.NoError(t, err) - ret := passthrough.NewPassThrough(m, "resilient", "r") + ret := passthrough.NewPassThrough(m, "resilient", tt.structTypeArgs, "r") buf := &bytes.Buffer{} s, err := ret.Statement() if tt.wantErr { diff --git a/internal/generator/retryable/retryable.go b/internal/generator/retryable/retryable.go index c8529cf..29e2323 100644 --- a/internal/generator/retryable/retryable.go +++ b/internal/generator/retryable/retryable.go @@ -14,21 +14,23 @@ const ( // Retryable is a code generator for a method that can be retried on error type Retryable struct { - method *method.Method - structName string - receiverName string + method *method.Method + structName string + structTypeArgs []jen.Code + receiverName string } // NewRetryable is a constructor for Retryable, the given method must be an error-returning method -func NewRetryable(method *method.Method, structName string, receiverName string) *Retryable { +func NewRetryable(method *method.Method, structName string, structTypeArgs []jen.Code, receiverName string) *Retryable { if !method.ReturnsError { panic("method does not return an error and is thus not retryable") } return &Retryable{ - method: method, - structName: structName, - receiverName: receiverName, + method: method, + structName: structName, + structTypeArgs: structTypeArgs, + receiverName: receiverName, } } @@ -38,7 +40,7 @@ func (r *Retryable) Statement() (*jen.Statement, error) { if err != nil { return nil, err } - return jen.Func().Params(jen.Id(r.receiverName).Op("*").Id(r.structName)).Id(r.method.Name).Call(r.method.ParametersNameAndType...).Params(r.method.ReturnTypes...).Block( + return jen.Func().Params(jen.Id(r.receiverName).Op("*").Id(r.structName).Types(r.structTypeArgs...)).Id(r.method.Name).Call(r.method.ParametersNameAndType...).Params(r.method.ReturnTypes...).Block( methodCallStatements..., ), nil } diff --git a/internal/generator/retryable/retryable_test.go b/internal/generator/retryable/retryable_test.go index 9ad304f..170b67b 100644 --- a/internal/generator/retryable/retryable_test.go +++ b/internal/generator/retryable/retryable_test.go @@ -9,6 +9,7 @@ import ( "github.com/clear-street/reinforcer/internal/generator/method" "github.com/clear-street/reinforcer/internal/generator/retryable" rtypes "github.com/clear-street/reinforcer/internal/types" + "github.com/dave/jennifer/jen" "github.com/stretchr/testify/require" ) @@ -17,11 +18,12 @@ func TestRetryable_Statement(t *testing.T) { ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) tests := []struct { - name string - methodName string - signature *types.Signature - want string - wantErr bool + name string + methodName string + structTypeArgs []jen.Code + signature *types.Signature + want string + wantErr bool }{ { name: "Function returns error", @@ -91,6 +93,29 @@ func TestRetryable_Statement(t *testing.T) { return r0, nonRetryableErr } return r0, err +}`, + wantErr: false, + }, + { + name: "Function returns error", + methodName: "MyFunction", + structTypeArgs: []jen.Code{jen.Id("T")}, + signature: types.NewSignatureType(nil, nil, nil, types.NewTuple(), types.NewTuple(errVar), false), + want: `func (r *Resilient[T]) MyFunction() error { + var nonRetryableErr error + err := r.run(context.Background(), ResilientMethods.MyFunction, func(_ context.Context) error { + var err error + err = r.delegate.MyFunction() + if r.errorPredicate(ResilientMethods.MyFunction, err) { + return err + } + nonRetryableErr = err + return nil + }) + if nonRetryableErr != nil { + return nonRetryableErr + } + return err }`, wantErr: false, }, @@ -100,7 +125,7 @@ func TestRetryable_Statement(t *testing.T) { t.Run(tt.name, func(t *testing.T) { m, err := method.ParseMethod(tt.methodName, tt.signature) require.NoError(t, err) - ret := retryable.NewRetryable(m, "Resilient", "r") + ret := retryable.NewRetryable(m, "Resilient", tt.structTypeArgs, "r") buf := &bytes.Buffer{} s, err := ret.Statement() if tt.wantErr { @@ -121,7 +146,7 @@ func TestRetryable_Statement(t *testing.T) { require.Panics(t, func() { m, err := method.ParseMethod("Fn", types.NewSignatureType(nil, nil, nil, types.NewTuple(), types.NewTuple(), false)) require.NoError(t, err) - retryable.NewRetryable(m, "Resilient", "r") + retryable.NewRetryable(m, "Resilient", nil, "r") }) }) } diff --git a/internal/loader/loader.go b/internal/loader/loader.go index 43a6b06..6a4a591 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -10,6 +10,8 @@ import ( "unicode" "github.com/clear-street/reinforcer/internal/generator/method" + rtypes "github.com/clear-street/reinforcer/internal/types" + "github.com/dave/jennifer/jen" "github.com/rs/zerolog/log" "golang.org/x/tools/go/packages" ) @@ -53,8 +55,10 @@ func (l *LoadingError) Error() string { // Result holds the results of loading a particular type type Result struct { - Name string - Methods []*method.Method + Name string + TypeParams []jen.Code + TypeArgs []jen.Code + Methods []*method.Method } // Loader is a utility service for extracting type information from a go package @@ -173,7 +177,7 @@ func (l *Loader) loadExpr(path string, expr *regexp.Regexp, mode LoadMode) (*pac switch typ := obj.Type().Underlying().(type) { case *types.Interface: logger.Info().Msgf("Discovered interface type %s", typeFound) - result, err := loadFromInterface(typeFound, typ) + result, err := loadFromInterface(typeFound, typ, obj.Type()) if err != nil { return nil, nil, err } @@ -223,10 +227,21 @@ func (l *Loader) load(path string, mode LoadMode) ([]*packages.Package, error) { return pkgs, nil } -func loadFromInterface(name string, interfaceType *types.Interface) (*Result, error) { +func loadFromInterface(name string, interfaceType *types.Interface, objType types.Type) (*Result, error) { result := &Result{ Name: name, } + typeParams := objType.(*types.Named).TypeParams() + for p := 0; p < typeParams.Len(); p++ { + typeParam := typeParams.At(p) + typeParamName := typeParam.Obj().Name() + typ, err := rtypes.ToType(typeParam.Constraint(), false) + if err != nil { + return nil, fmt.Errorf("failed to convert type parameter %s; error=%w", typeParamName, err) + } + result.TypeParams = append(result.TypeParams, jen.Id(typeParamName).Add(typ)) + result.TypeArgs = append(result.TypeArgs, jen.Id(typeParamName)) + } for m := 0; m < interfaceType.NumMethods(); m++ { meth := interfaceType.Method(m) mm, err := method.ParseMethod(meth.Name(), meth.Type().(*types.Signature)) @@ -254,10 +269,19 @@ func loadFromStruct(f *ast.File, name string, info *types.Info) (*Result, error) for _, l := range fn.Recv.List { var ident *ast.Ident switch t := l.Type.(type) { - case *ast.StarExpr: - ident = t.X.(*ast.Ident) case *ast.Ident: ident = t + case *ast.StarExpr: + switch t := t.X.(type) { + case *ast.Ident: + ident = t + case *ast.IndexExpr: + // single type parameter + ident = t.X.(*ast.Ident) + case *ast.IndexListExpr: + // multiple type parameters + ident = t.X.(*ast.Ident) + } } if ident == nil || ident.Name != name { diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 06cb138..59a3aba 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -98,7 +98,7 @@ type Service interface { import "context" -type service struct { +type service struct { } func (s *service) GetUserID(ctx context.Context, userID string) (string, error) { @@ -120,6 +120,88 @@ func (s *service) GetUserID(ctx context.Context, userID string) (string, error) require.Equal(t, 1, len(svc.Methods)) require.Equal(t, "GetUserID", svc.Methods[0].Name) }) + + t.Run("Load struct with method with generic typed argument", func(t *testing.T) { + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "github.com/clear-street", + Files: map[string]interface{}{ + "fake/fake.go": `package fake + +type genericType[T any] struct { + value T +} + +type genericService struct{} + +func (g *genericService) DoTheThing(t genericType[string]) (string, error) { return t.value, nil } +`, + }}}) + defer exported.Cleanup() + + l := loader.NewLoader(func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) { + exported.Config.Mode = cfg.Mode + return packages.Load(exported.Config, patterns...) + }) + + svc, err := l.LoadOne("github.com/clear-street/fake", "genericService", loader.PackageLoadMode) + require.NoError(t, err) + require.NotNil(t, svc) + require.Equal(t, "genericService", svc.Name) + require.Equal(t, 1, len(svc.Methods)) + require.Equal(t, "DoTheThing", svc.Methods[0].Name) + }) + + t.Run("Load struct with generic type param", func(t *testing.T) { + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "github.com/clear-street", + Files: map[string]interface{}{ + "fake/fake.go": `package fake + +type genericService[T any] struct{} + +func (g *genericService[T]) DoTheThing() (string, error) { return "yep", nil } +`, + }}}) + defer exported.Cleanup() + + l := loader.NewLoader(func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) { + exported.Config.Mode = cfg.Mode + return packages.Load(exported.Config, patterns...) + }) + + svc, err := l.LoadOne("github.com/clear-street/fake", "genericService", loader.PackageLoadMode) + require.NoError(t, err) + require.NotNil(t, svc) + require.Equal(t, "genericService", svc.Name) + require.Equal(t, 1, len(svc.Methods)) + require.Equal(t, "DoTheThing", svc.Methods[0].Name) + }) + + t.Run("Load struct with generic type param list", func(t *testing.T) { + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "github.com/clear-street", + Files: map[string]interface{}{ + "fake/fake.go": `package fake + +type genericService[T any, U any] struct{} + +func (g *genericService[T, U]) DoTheThing() (string, error) { return "yep", nil } +`, + }}}) + defer exported.Cleanup() + + l := loader.NewLoader(func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) { + exported.Config.Mode = cfg.Mode + return packages.Load(exported.Config, patterns...) + }) + + svc, err := l.LoadOne("github.com/clear-street/fake", "genericService", loader.PackageLoadMode) + require.NoError(t, err) + require.NotNil(t, svc) + require.Equal(t, "genericService", svc.Name) + require.Equal(t, 1, len(svc.Methods)) + require.Equal(t, "DoTheThing", svc.Methods[0].Name) + }) } func TestLoadMatched(t *testing.T) { diff --git a/internal/types/types.go b/internal/types/types.go index f597016..b3dc057 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -1,11 +1,18 @@ package types import ( + "fmt" "go/types" + "github.com/dave/jennifer/jen" + "github.com/pkg/errors" "golang.org/x/tools/go/packages" ) +type named interface { + Name() string +} + // ErrType is the types.Type for the error interface var ErrType types.Type @@ -60,3 +67,126 @@ func IsContextType(t types.Type) bool { } return types.Implements(t, ContextType) } + +// variadicToType generates the representation for a variadic type "...MyType" +func variadicToType(t types.Type) (jen.Code, error) { + sliceType, ok := t.(*types.Slice) + if !ok { + return nil, fmt.Errorf("expected type to be *types.Slice, got=%T", t) + } + sliceElemType, err := ToType(sliceType.Elem(), false) + if err != nil { + return nil, fmt.Errorf("failed to convert slice's type; error=%w", err) + } + return jen.Op("...").Add(sliceElemType), nil +} + +// ToType generates the representation for the given type +func ToType(t types.Type, variadic bool) (jen.Code, error) { + if variadic { + return variadicToType(t) + } + + switch v := t.(type) { + case *types.Basic: + return jen.Id(v.Name()), nil + case *types.Chan: + rt, err := ToType(v.Elem(), false) + if err != nil { + return nil, err + } + switch v.Dir() { + case types.SendRecv: + return jen.Chan().Add(rt), nil + case types.RecvOnly: + return jen.Op("<-").Chan().Add(rt), nil + default: + return jen.Chan().Op("<-").Add(rt), nil + } + case *types.Named: + typeName := v.Obj() + if _, ok := v.Underlying().(*types.Interface); ok { + if typeName.Pkg() != nil { + pkgPath := typeName.Pkg().Path() + return jen.Qual( + pkgPath, + typeName.Name(), + ), nil + } + return jen.Id(typeName.Name()), nil + } + pkgPath := typeName.Pkg().Path() + var typeArgs []jen.Code + for p := 0; p < v.TypeArgs().Len(); p++ { + typeArg := v.TypeArgs().At(p) + tt, err := ToType(typeArg, false) + if err != nil { + return nil, errors.Wrapf(err, "failed to convert type %v", typeArg) + } + typeArgs = append(typeArgs, tt) + } + return jen.Qual( + pkgPath, + typeName.Name(), + ).Types(typeArgs...), nil + case *types.Pointer: + rt, err := ToType(v.Elem(), false) + if err != nil { + return nil, err + } + return jen.Op("*").Add(rt), nil + case *types.Interface: + return jen.Id("any"), nil + case *types.Slice: + elemType, err := ToType(v.Elem(), false) + if err != nil { + return nil, err + } + return jen.Index().Add(elemType), nil + case named: + return jen.Id(v.Name()), nil + case *types.Map: + keyType, err := ToType(v.Key(), false) + if err != nil { + return nil, err + } + elemType, err := ToType(v.Elem(), false) + if err != nil { + return nil, err + } + return jen.Map(keyType).Add(elemType), nil + case *types.Signature: + fnVariadic := v.Variadic() + var paramTypes []jen.Code + lastIndex := v.Params().Len() - 1 + for p := 0; p < v.Params().Len(); p++ { + paramType := v.Params().At(p).Type() + tt, err := ToType(paramType, lastIndex == p && fnVariadic) + if err != nil { + return nil, errors.Wrapf(err, "failed to convert type %v", paramType) + } + paramTypes = append(paramTypes, tt) + } + + var returnTypes []jen.Code + for r := 0; r < v.Results().Len(); r++ { + returnType := v.Results().At(r).Type() + tt, err := ToType(returnType, false) + if err != nil { + return nil, errors.Wrapf(err, "failed to convert type %v", returnType) + } + returnTypes = append(returnTypes, tt) + } + if len(returnTypes) == 0 { + return jen.Func().Params(paramTypes...), nil + } + if len(returnTypes) > 1 { + return jen.Func().Params(paramTypes...).Parens(jen.List(returnTypes...)), nil + } + return jen.Func().Params(paramTypes...).Add(returnTypes[0]), nil + case *types.TypeParam: + return jen.Id(v.Obj().Name()), nil + default: + return nil, fmt.Errorf("type not handled: %T", v) + } +}