diff --git a/README.md b/README.md index b0aa552..f5e5cf7 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ go install github.com/Buzzvil/recovergoroutine recovergoroutine -recover="" ./... # -recover string -# Custom recover method name. Currently, it is difficult to determine -# if a CustomRecover function declared in another package is valid, -# so this option can be used to resolve it. +# Custom recovery method name. You can use this option +# when you want to call a method defined in a struct or +# use CustomRecover declared in an external package. ``` Check out the test cases for validation [examples](./test/src/faildata/failcode.go). diff --git a/recovergoroutine/recovergoroutine.go b/recovergoroutine/recovergoroutine.go index 9482b06..8f245f2 100644 --- a/recovergoroutine/recovergoroutine.go +++ b/recovergoroutine/recovergoroutine.go @@ -2,15 +2,12 @@ package recovergoroutine import ( "flag" - "fmt" "go/ast" - "go/parser" - "go/types" - "reflect" - "golang.org/x/tools/go/analysis" ) +type message string + var customRecover string func NewAnalyzer() *analysis.Analyzer { @@ -25,8 +22,7 @@ func NewAnalyzer() *analysis.Analyzer { &customRecover, "recover", "", - "It is difficult to determine if a CustomRecover function declared in another package is valid,"+ - " so this option can be used to resolve it.", + "You can use this option when you want to call a method defined in a struct or use CustomRecover declared in an external package.", ) return analyzer @@ -41,12 +37,7 @@ func run(pass *analysis.Pass) (interface{}, error) { return true } - ok, err := safeGoStmt(goStmt, pass) - if err != nil { - runErr = err - return false - } - + ok, msg := safeGoStmt(goStmt) if ok { return true } @@ -55,7 +46,7 @@ func run(pass *analysis.Pass) (interface{}, error) { Pos: goStmt.Pos(), End: 0, Category: "goroutine", - Message: "goroutine must have recover", + Message: string(msg), }) return false @@ -65,43 +56,28 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, runErr } -func safeGoStmt(goStmt *ast.GoStmt, pass *analysis.Pass) (bool, error) { +func safeGoStmt(goStmt *ast.GoStmt) (bool, message) { fn := goStmt.Call switch fun := fn.Fun.(type) { - case *ast.SelectorExpr: - return safeSelectorExpr(fun, pass, safeFunc) case *ast.FuncLit: - return safeFunc(fun, pass) - case *ast.Ident: - if fun.Obj == nil { - return false, nil - } - - funcDecl, ok := fun.Obj.Decl.(*ast.FuncDecl) - if !ok { - return false, nil + if !safeFunc(fun) { + return false, "goroutine must have recover" } - - return safeFunc(funcDecl, pass) + return true, "" } - return false, fmt.Errorf("unexpected goroutine function type: %v", reflect.TypeOf(fn.Fun).String()) + return false, "use function literals when using goroutines" } -func safeFunc(node ast.Node, pass *analysis.Pass) (bool, error) { +func safeFunc(node ast.Node) bool { result := false - var err error ast.Inspect(node, func(node ast.Node) bool { deferStmt, ok := node.(*ast.DeferStmt) if !ok { return true } - ok, err = hasRecover(deferStmt.Call, pass) - if err != nil { - return false - } - + ok = hasRecover(deferStmt.Call) if ok { result = true return false @@ -110,12 +86,11 @@ func safeFunc(node ast.Node, pass *analysis.Pass) (bool, error) { return !result }) - return result, err + return result } -func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) { +func hasRecover(expr ast.Node) bool { var result bool - var err error ast.Inspect(expr, func(node ast.Node) bool { switch n := node.(type) { case *ast.CallExpr: @@ -128,13 +103,7 @@ func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) { return true } - var ok bool - ok, err = safeSelectorExpr(n, pass, hasRecover) - if err != nil { - return false - } - - if ok || n.Sel.Name == customRecover { + if n.Sel.Name == customRecover { result = true return false } @@ -142,55 +111,7 @@ func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) { return true }) - return result, err -} - -func safeSelectorExpr( - expr *ast.SelectorExpr, - pass *analysis.Pass, - methodChecker func(node ast.Node, pass *analysis.Pass) (bool, error), -) (bool, error) { - ident, ok := expr.X.(*ast.Ident) - if !ok { - return false, nil - } - - methodName := expr.Sel.Name - objType := pass.TypesInfo.ObjectOf(ident) - pointerType, ok := objType.Type().(*types.Pointer) - if !ok { - return false, nil - } - - named, ok := pointerType.Elem().(*types.Named) - if !ok { - return false, nil - } - - result := false - for i := 0; i < named.NumMethods(); i++ { - if named.Method(i).Name() != methodName { - continue - } - - fset := pass.Fset - position := fset.Position(named.Method(i).Pos()) - file, err := parser.ParseFile(fset, position.Filename, nil, 0) - if err != nil { - return false, fmt.Errorf("parse file: %w", err) - } - - for _, decl := range file.Decls { - if funcDecl, ok := decl.(*ast.FuncDecl); ok { - if funcDecl.Name.Name == methodName { - result, err = methodChecker(funcDecl, pass) - break - } - } - } - } - - return result, nil + return result } func isRecover(callExpr *ast.CallExpr) bool { @@ -199,7 +120,7 @@ func isRecover(callExpr *ast.CallExpr) bool { return false } - return ident.Name == "recover" + return ident.Name == "recover" || ident.Name == customRecover } func isCustomRecover(callExpr *ast.CallExpr) bool { diff --git a/test/src/custom/recover.go b/test/src/custom/recover.go new file mode 100644 index 0000000..8554486 --- /dev/null +++ b/test/src/custom/recover.go @@ -0,0 +1 @@ +package custom diff --git a/test/src/succdata/succcode.go b/test/src/succdata/succcode.go index 0d2d7e4..cdc4bbf 100644 --- a/test/src/succdata/succcode.go +++ b/test/src/succdata/succcode.go @@ -1,6 +1,10 @@ package succdata func whenASTFuncLit() { + go func() { + defer recover() + }() + go func() { defer func() { if r := recover(); r != nil { @@ -23,54 +27,4 @@ func whenASTFuncLit() { defer rec() }() - - go func() { - defer customRecover() - }() - -} - -func whenIdent() { - go runGoroutine() - go nestedFunc1() -} - -func whenCallMethod() { - foo := &Foo{} - go foo.run() - go func() { - defer foo.Recover() - }() -} - -func runGoroutine() { - defer func() { - recover() - }() -} - -func nestedFunc1() { - // must have recover in parent caller - nestedFunc2() - defer func() { - recover() - }() -} - -func nestedFunc2() {} - -func customRecover() { - recover() -} - -type Foo struct{} - -func (a *Foo) run() { - defer func() { - recover() - }() -} - -func (a *Foo) Recover() { - recover() }