Skip to content

Commit

Permalink
Merge pull request #1114 from anatoly32322/codegen_function_id
Browse files Browse the repository at this point in the history
Add codegeneration for stack.FunctionID
  • Loading branch information
asmyasnikov authored Mar 24, 2024
2 parents 0e4cf0a + ffd3ee0 commit 8e6ec46
Show file tree
Hide file tree
Showing 35 changed files with 521 additions and 103 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/check-codegen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Build
run: |
go install ./internal/cmd/gtrace
go install ./internal/cmd/gstack
go install go.uber.org/mock/[email protected]
- name: Clean and re-generate *_gtrace.go files
Expand All @@ -40,5 +41,9 @@ jobs:
go generate ./trace
go generate ./...
- name: Re-generate stack.FunctionID calls
run: |
gstack .
- name: Check repository diff
run: bash ./.github/scripts/check-work-copy-equals-to-committed.sh "code-generation not equal with committed"
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Added internal `gstack` codegen tool for filling `stack.FunctionID` with value from call stack

## v3.59.1
* Fixed updating last usage timestamp for smart parking of the conns

Expand Down
7 changes: 4 additions & 3 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ func (d *Driver) trace() *trace.Driver {
//
//nolint:nonamedreturns
func (d *Driver) Close(ctx context.Context) (finalErr error) {
onDone := trace.DriverOnClose(d.trace(), &ctx, stack.FunctionID(""))
onDone := trace.DriverOnClose(d.trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/ydb.(*Driver).Close"))
defer func() {
onDone(finalErr)
}()
Expand Down Expand Up @@ -248,7 +249,7 @@ func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, err error

onDone := trace.DriverOnInit(
d.trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/ydb.Open"),
d.config.Endpoint(), d.config.Database(), d.config.Secure(),
)
defer func() {
Expand Down Expand Up @@ -284,7 +285,7 @@ func New(ctx context.Context, opts ...Option) (_ *Driver, err error) {

onDone := trace.DriverOnInit(
d.trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/ydb.New"),
d.config.Endpoint(), d.config.Database(), d.config.Secure(),
)
defer func() {
Expand Down
12 changes: 7 additions & 5 deletions internal/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
address = "ydb:///" + b.driverConfig.Endpoint()
onDone = trace.DriverOnBalancerClusterDiscoveryAttempt(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID(
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).clusterDiscoveryAttempt"),
address,
)
endpoints []endpoint.Endpoint
Expand Down Expand Up @@ -173,7 +174,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
var (
onDone = trace.DriverOnBalancerUpdate(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID(
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
b.config.DetectLocalDC,
)
previousConns []conn.Conn
Expand Down Expand Up @@ -211,7 +213,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
func (b *Balancer) Close(ctx context.Context) (err error) {
onDone := trace.DriverOnBalancerClose(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close"),
)
defer func() {
onDone(err)
Expand All @@ -237,7 +239,7 @@ func New(
var (
onDone = trace.DriverOnBalancerInit(
driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.New"),
driverConfig.Balancer().String(),
)
discoveryConfig = discoveryConfig.New(append(opts,
Expand Down Expand Up @@ -371,7 +373,7 @@ func (b *Balancer) connections() *connectionsState {
func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
onDone := trace.DriverOnBalancerChooseEndpoint(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"),
)
defer func() {
if err == nil {
Expand Down
227 changes: 227 additions & 0 deletions internal/cmd/gstack/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/fs"
"os"
"path/filepath"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/cmd/gstack/utils"
)

func usage() {
fmt.Fprintf(os.Stderr, "usage: gstack [path]\n")
flag.PrintDefaults()
}

func getCallExpressionsFromExpr(expr ast.Expr) (listOfCalls []*ast.CallExpr) {
switch expr := expr.(type) {
case *ast.SelectorExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.IndexExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.StarExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.BinaryExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Y)...)
case *ast.CallExpr:
listOfCalls = append(listOfCalls, expr)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Fun)...)
for _, arg := range expr.Args {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(arg)...)
}
case *ast.CompositeLit:
for _, elt := range expr.Elts {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(elt)...)
}
case *ast.UnaryExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.X)...)
case *ast.KeyValueExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Value)...)
case *ast.FuncLit:
listOfCalls = append(listOfCalls, getListOfCallExpressionsFromBlockStmt(expr.Body)...)
}

return listOfCalls
}

func getExprFromDeclStmt(statement *ast.DeclStmt) (listOfExpressions []ast.Expr) {
decl, ok := statement.Decl.(*ast.GenDecl)
if !ok {
return listOfExpressions
}
for _, spec := range decl.Specs {
if spec, ok := spec.(*ast.ValueSpec); ok {
listOfExpressions = append(listOfExpressions, spec.Values...)
}
}

return listOfExpressions
}

func getCallExpressionsFromStmt(statement ast.Stmt) (listOfCallExpressions []*ast.CallExpr) {
var body *ast.BlockStmt
var listOfExpressions []ast.Expr
switch stmt := statement.(type) {
case *ast.IfStmt:
body = stmt.Body
case *ast.SwitchStmt:
body = stmt.Body
case *ast.TypeSwitchStmt:
body = stmt.Body
case *ast.SelectStmt:
body = stmt.Body
case *ast.ForStmt:
body = stmt.Body
case *ast.RangeStmt:
body = stmt.Body
case *ast.DeclStmt:
listOfExpressions = append(listOfExpressions, getExprFromDeclStmt(stmt)...)
for _, expr := range listOfExpressions {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr)...)
}
case *ast.CommClause:
stmts := stmt.Body
for _, stmt := range stmts {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(stmt)...)
}
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(stmt.X)...)
}
if body != nil {
listOfCallExpressions = append(
listOfCallExpressions,
getListOfCallExpressionsFromBlockStmt(body)...,
)
}

return listOfCallExpressions
}

func getListOfCallExpressionsFromBlockStmt(block *ast.BlockStmt) (listOfCallExpressions []*ast.CallExpr) {
for _, statement := range block.List {
switch expr := statement.(type) {
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr.X)...)
case *ast.ReturnStmt:
for _, result := range expr.Results {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(result)...)
}
case *ast.AssignStmt:
for _, rh := range expr.Rhs {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(rh)...)
}
default:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(statement)...)
}
}

return listOfCallExpressions
}

func format(src []byte, path string, fset *token.FileSet, file *ast.File) ([]byte, error) {
var listOfArgs []utils.FunctionIDArg
for _, f := range file.Decls {
var listOfCalls []*ast.CallExpr
fn, ok := f.(*ast.FuncDecl)
if !ok {
continue
}
listOfCalls = getListOfCallExpressionsFromBlockStmt(fn.Body)
for _, call := range listOfCalls {
if function, ok := call.Fun.(*ast.SelectorExpr); ok && function.Sel.Name == "FunctionID" {
pack, ok := function.X.(*ast.Ident)
if !ok {
continue
}
if pack.Name == "stack" && len(call.Args) == 1 {
listOfArgs = append(listOfArgs, utils.FunctionIDArg{
FuncDecl: fn,
ArgPos: call.Args[0].Pos(),
ArgEnd: call.Args[0].End(),
})
}
}
}
}
if len(listOfArgs) != 0 {
fixed, err := utils.FixSource(fset, path, src, listOfArgs)
if err != nil {
return nil, err
}

return fixed, nil
}

return src, nil
}

func processFile(src []byte, path string, fset *token.FileSet, file *ast.File, info os.FileInfo) error {
formatted, err := format(src, path, fset, file)
if err != nil {
return err
}
if !bytes.Equal(src, formatted) {
err = utils.WriteFile(path, formatted, info.Mode().Perm())
if err != nil {
return err
}
}

return nil
}

func main() {
flag.Usage = usage
flag.Parse()
args := flag.Args()

if len(args) != 1 {
flag.Usage()

return
}
_, err := os.Stat(args[0])
if err != nil {
panic(err)
}

fileSystem := os.DirFS(args[0])

err = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error {
fset := token.NewFileSet()
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if filepath.Ext(path) == ".go" {
info, err := os.Stat(path)
if err != nil {
return err
}
src, err := utils.ReadFile(path, info)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, path, nil, 0)
if err != nil {
return err
}

return processFile(src, path, fset, file, info)
}

return nil
})
if err != nil {
panic(err)
}
}
Loading

0 comments on commit 8e6ec46

Please sign in to comment.