-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
199 lines (178 loc) · 5.12 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
package main
import (
"fmt"
"io/ioutil"
"path/filepath"
"sort"
"strings"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/pointer"
"golang.org/x/tools/go/ssa/ssautil"
)
// main is a development entry point: it computes the project functions
// reachable from main() for a hard-coded local checkout and prints the
// resulting list together with any error.
func main() {
	const projectPath = "/Users/mac/Downloads/goc-master"
	// Alternative checkout used during development:
	// const projectPath = "/Users/mac/Desktop/mixStew"
	calls, err := getProjectUsedCall(projectPath)
	fmt.Printf("res is: %+v err:%+v\n", calls, err)
}
// getProjectUsedCall returns the sorted list of project-internal functions
// that are transitively reachable from the project's root "<module>.main".
//
// The call graph from parseProjectCallMap is first pruned so that only calls
// between two distinct, non-init functions of the project's own module
// remain. Then every function reachable from main through one or more of the
// remaining edges is collected. main itself appears in the result only if it
// is reachable through a cycle.
func getProjectUsedCall(projectPath string) ([]string, error) {
	projectModule, err := parseProjectModule(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "parseProjectModule fail")
	}
	log.Debugf("projectModule: %+v", projectModule)
	callMap, err := parseProjectCallMap(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "parseProjectCallMap fail")
	}
	log.Debugf("callMap: %+v", callMap)

	pruneCallMap(callMap, projectModule)
	return reachableCalls(callMap, projectModule+".main"), nil
}

// pruneCallMap deletes, in place, every edge that isProjectEdge rejects and
// then drops callers that are left with no outgoing edges.
func pruneCallMap(callMap map[string]map[string]bool, module string) {
	for caller, callees := range callMap {
		for callee := range callees {
			// Deleting the current key while ranging over a map is allowed in Go.
			if !isProjectEdge(caller, callee, module) {
				delete(callees, callee)
			}
		}
		if len(callees) == 0 {
			delete(callMap, caller)
		}
	}
}

// isProjectEdge reports whether the call caller -> callee should be kept.
// An edge is dropped when either side lies outside module, when either side
// is a package initializer, or when it is a direct self-call.
func isProjectEdge(caller, callee, module string) bool {
	switch {
	case !inModule(caller, module) || !inModule(callee, module):
		return false
	case isInitCall(caller) || isInitCall(callee):
		return false
	case caller == callee:
		return false
	}
	return true
}

// inModule reports whether the normalized function name belongs to module,
// either to its root package ("module.F", "module.T.M") or to a sub-package
// ("module/pkg.F"). Requiring "." or "/" right after the module path keeps a
// sibling module such as "example.com/foobar" from matching module
// "example.com/foo", which a plain substring test would accept.
func inModule(name, module string) bool {
	if !strings.HasPrefix(name, module) {
		return false
	}
	rest := name[len(module):]
	return strings.HasPrefix(rest, ".") || strings.HasPrefix(rest, "/")
}

// reachableCalls performs a breadth-first search over callMap starting at
// src. It returns, sorted, every function reached through at least one edge,
// or nil when nothing is reachable.
func reachableCalls(callMap map[string]map[string]bool, src string) []string {
	seen := make(map[string]bool)
	queue := []string{src}
	for len(queue) > 0 {
		caller := queue[0]
		queue = queue[1:]
		for callee := range callMap[caller] {
			if !seen[callee] {
				seen[callee] = true
				queue = append(queue, callee)
			}
		}
	}
	var calls []string
	for c := range seen {
		calls = append(calls, c)
	}
	sort.Strings(calls)
	return calls
}
// parseProjectCallMap builds the call graph of the project rooted at
// projectPath and flattens it into an adjacency set, where
// graph[caller][callee] is true for every edge visited in the analysis call
// graph. Both names are normalized by parseCallEdge.
//
// The project's go.mod is read first. Its module path is only logged here,
// but a missing or unreadable go.mod still makes this function fail.
func parseProjectCallMap(projectPath string) (map[string]map[string]bool, error) {
	module, err := parseProjectModule(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "parseProjectModule fail")
	}
	log.Debugf("projectModule: %+v", module)

	analysis, err := analyzeProject(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "analyzeProject fail")
	}
	log.Debugf("analyzeProject: %+v", analysis)

	graph := make(map[string]map[string]bool)
	record := func(edge *callgraph.Edge) error {
		if edge == nil {
			return nil
		}
		caller, callee, perr := parseCallEdge(edge)
		if perr != nil {
			return errors.Wrap(perr, "parseCallEdge fail")
		}
		targets, ok := graph[caller]
		if !ok {
			targets = make(map[string]bool)
			graph[caller] = targets
		}
		targets[callee] = true
		return nil
	}

	if err := callgraph.GraphVisitEdges(analysis.CallGraph, record); err != nil {
		return nil, errors.Wrap(err, "GraphVisitEdges fail")
	}
	return graph, nil
}
// parseProjectModule returns the module path declared by the "module"
// directive in projectPath/go.mod.
//
// The directive is not assumed to be on the first line. go.mod may start with
// comments or blank lines, the directive may carry a trailing "//" comment,
// and the path may be quoted. The file is therefore scanned for the first
// "module <path>" line.
// An error is returned when go.mod cannot be read or declares no module path.
func parseProjectModule(projectPath string) (string, error) {
	modFilename := filepath.Join(projectPath, "go.mod")
	content, err := ioutil.ReadFile(modFilename)
	if err != nil {
		return "", errors.Wrap(err, "ioutil.ReadFile fail")
	}
	for _, line := range strings.Split(string(content), "\n") {
		// Drop a trailing comment; strings.Fields below also discards the
		// "\r" left behind by CRLF line endings.
		if i := strings.Index(line, "//"); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) != 2 || fields[0] != "module" {
			continue
		}
		module := fields[1]
		// Accept the quoted forms `module "path"` and module `path`.
		if n := len(module); n >= 2 && (module[0] == '"' || module[0] == '`') && module[n-1] == module[0] {
			module = module[1 : n-1]
		}
		return module, nil
	}
	return "", errors.Errorf("no module directive found in %s", modFilename)
}
// analyzeProject loads the Go package in projectPath (packages.Load with no
// patterns defaults to the package in Config.Dir), builds its SSA form, and
// runs pointer analysis over it. The returned result includes a call graph.
//
// packages.Load reports per-package problems (syntax/type errors, missing
// dependencies) on Package.Errors rather than through its own error result.
// Those problems are checked explicitly so the analysis never runs on a
// partially built program. Pointer analysis requires main packages, so only
// packages named "main" that define main() are passed in Config.Mains.
func analyzeProject(projectPath string) (*pointer.Result, error) {
	// Load packages with full syntax and type information.
	pkgs, err := packages.Load(&packages.Config{
		Mode: packages.LoadAllSyntax,
		Dir:  projectPath,
	})
	if err != nil {
		return nil, errors.Wrap(err, "packages.Load fail")
	}
	log.Debugf("pkgs: %+v", pkgs)
	if n := packages.PrintErrors(pkgs); n > 0 {
		return nil, errors.Errorf("packages.Load: %d error(s) in loaded packages", n)
	}

	// Build SSA for the loaded packages and all their dependencies.
	prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
	prog.Build()
	log.Debugf("ssaPkgs: %+v", ssaPkgs)

	// Keep only main packages. Filtering in place reuses ssaPkgs' storage,
	// which is safe because ssaPkgs is not used after this loop.
	mains := ssaPkgs[:0]
	for _, p := range ssaPkgs {
		if p != nil && p.Pkg.Name() == "main" && p.Func("main") != nil {
			mains = append(mains, p)
		}
	}
	if len(mains) == 0 {
		return nil, errors.Errorf("no main package found in %s", projectPath)
	}

	// Run pointer analysis to obtain the call graph.
	return pointer.Analyze(&pointer.Config{
		Mains:          mains,
		BuildCallGraph: true,
	})
}
// parseCallEdge splits the textual form of a call-graph edge, which must
// contain exactly one "-->" separator, into its caller and callee parts. Each
// part is normalized with getCallRoute. An error is returned when the edge
// text does not split into exactly two parts.
func parseCallEdge(edge *callgraph.Edge) (string, string, error) {
	const arrow = "-->"
	text := fmt.Sprintf("%+v", edge)
	sides := strings.Split(text, arrow)
	if n := len(sides); n != 2 {
		return "", "", fmt.Errorf("invalid format: %v", text)
	}
	return getCallRoute(sides[0]), getCallRoute(sides[1]), nil
}
// callRoutePunctuation deletes the '*', '(', ')', '<' and '>' characters that
// getCallRoute strips from node names.
var callRoutePunctuation = strings.NewReplacer("*", "", "(", "", ")", "", "<", "", ">", "")

// getCallRoute normalizes one side of an edge string into a plain dotted
// name. It trims surrounding whitespace and drops everything up to and
// including the first ':' (for example a node-ID prefix such as "n3:"). It
// then removes '*', '(', ')', '<', '>' and cuts the name at the first '$' or
// '#'. For example, "n1:(*example.com/m.T).M$1" becomes "example.com/m.T.M".
func getCallRoute(nodeStr string) string {
	route := strings.TrimSpace(nodeStr)
	if colon := strings.Index(route, ":"); colon >= 0 {
		route = route[colon+1:]
	}
	route = callRoutePunctuation.Replace(route)
	// Truncating at the earliest of '$' and '#' is the same as truncating at
	// '$' and then at '#'.
	if cut := strings.IndexAny(route, "$#"); cut >= 0 {
		route = route[:cut]
	}
	return strings.TrimSpace(route)
}
// isInitCall reports whether the normalized function name ends in ".init",
// i.e. names a package initializer.
func isInitCall(call string) bool {
	const initSuffix = ".init"
	return strings.TrimSuffix(call, initSuffix) != call
}