Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add codegeneration for stack.FunctionID #1114

Merged
merged 18 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/check-codegen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Build
run: |
go install ./internal/cmd/gtrace
go install ./internal/cmd/gstack
go install go.uber.org/mock/[email protected]

- name: Clean and re-generate *_gtrace.go files
Expand All @@ -40,5 +41,9 @@ jobs:
go generate ./trace
go generate ./...

- name: Re-generate stack.FunctionID calls
run: |
gstack .

- name: Check repository diff
run: bash ./.github/scripts/check-work-copy-equals-to-committed.sh "code-generation not equal with committed"
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Added internal `gstack` codegen tool for filling `stack.FunctionID` with value from call stack

## v3.59.1
* Fixed updating last usage timestamp for smart parking of the conns

Expand Down
7 changes: 4 additions & 3 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ func (d *Driver) trace() *trace.Driver {
//
//nolint:nonamedreturns
func (d *Driver) Close(ctx context.Context) (finalErr error) {
onDone := trace.DriverOnClose(d.trace(), &ctx, stack.FunctionID(""))
onDone := trace.DriverOnClose(d.trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/ydb.(*Driver).Close"))
defer func() {
onDone(finalErr)
}()
Expand Down Expand Up @@ -248,7 +249,7 @@ func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, err error

onDone := trace.DriverOnInit(
d.trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/ydb.Open"),
d.config.Endpoint(), d.config.Database(), d.config.Secure(),
)
defer func() {
Expand Down Expand Up @@ -284,7 +285,7 @@ func New(ctx context.Context, opts ...Option) (_ *Driver, err error) {

onDone := trace.DriverOnInit(
d.trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/ydb.New"),
d.config.Endpoint(), d.config.Database(), d.config.Secure(),
)
defer func() {
Expand Down
12 changes: 7 additions & 5 deletions internal/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
address = "ydb:///" + b.driverConfig.Endpoint()
onDone = trace.DriverOnBalancerClusterDiscoveryAttempt(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID(
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).clusterDiscoveryAttempt"),
address,
)
endpoints []endpoint.Endpoint
Expand Down Expand Up @@ -173,7 +174,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
var (
onDone = trace.DriverOnBalancerUpdate(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID(
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
b.config.DetectLocalDC,
)
previousConns []conn.Conn
Expand Down Expand Up @@ -211,7 +213,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
func (b *Balancer) Close(ctx context.Context) (err error) {
onDone := trace.DriverOnBalancerClose(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close"),
)
defer func() {
onDone(err)
Expand All @@ -237,7 +239,7 @@ func New(
var (
onDone = trace.DriverOnBalancerInit(
driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.New"),
driverConfig.Balancer().String(),
)
discoveryConfig = discoveryConfig.New(append(opts,
Expand Down Expand Up @@ -371,7 +373,7 @@ func (b *Balancer) connections() *connectionsState {
func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
onDone := trace.DriverOnBalancerChooseEndpoint(
b.driverConfig.Trace(), &ctx,
stack.FunctionID(""),
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"),
)
defer func() {
if err == nil {
Expand Down
227 changes: 227 additions & 0 deletions internal/cmd/gstack/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/fs"
"os"
"path/filepath"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/cmd/gstack/utils"
)

func usage() {
fmt.Fprintf(os.Stderr, "usage: gstack [path]\n")
flag.PrintDefaults()
}

func getCallExpressionsFromExpr(expr ast.Expr) (listOfCalls []*ast.CallExpr) {
switch expr := expr.(type) {
case *ast.SelectorExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.IndexExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.StarExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.BinaryExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Y)...)
case *ast.CallExpr:
listOfCalls = append(listOfCalls, expr)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Fun)...)
for _, arg := range expr.Args {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(arg)...)
}
case *ast.CompositeLit:
for _, elt := range expr.Elts {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(elt)...)
}
case *ast.UnaryExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.X)...)
case *ast.KeyValueExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Value)...)
case *ast.FuncLit:
listOfCalls = append(listOfCalls, getListOfCallExpressionsFromBlockStmt(expr.Body)...)
}

return listOfCalls
}

func getExprFromDeclStmt(statement *ast.DeclStmt) (listOfExpressions []ast.Expr) {
decl, ok := statement.Decl.(*ast.GenDecl)
if !ok {
return listOfExpressions
}
for _, spec := range decl.Specs {
if spec, ok := spec.(*ast.ValueSpec); ok {
listOfExpressions = append(listOfExpressions, spec.Values...)
}
}

return listOfExpressions
}

func getCallExpressionsFromStmt(statement ast.Stmt) (listOfCallExpressions []*ast.CallExpr) {
var body *ast.BlockStmt
var listOfExpressions []ast.Expr
switch stmt := statement.(type) {
case *ast.IfStmt:
body = stmt.Body
case *ast.SwitchStmt:
body = stmt.Body
case *ast.TypeSwitchStmt:
body = stmt.Body
case *ast.SelectStmt:
body = stmt.Body
case *ast.ForStmt:
body = stmt.Body
case *ast.RangeStmt:
body = stmt.Body
case *ast.DeclStmt:
listOfExpressions = append(listOfExpressions, getExprFromDeclStmt(stmt)...)
for _, expr := range listOfExpressions {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr)...)
}
case *ast.CommClause:
stmts := stmt.Body
for _, stmt := range stmts {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(stmt)...)
}
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(stmt.X)...)
}
if body != nil {
listOfCallExpressions = append(
listOfCallExpressions,
getListOfCallExpressionsFromBlockStmt(body)...,
)
}

return listOfCallExpressions
}

func getListOfCallExpressionsFromBlockStmt(block *ast.BlockStmt) (listOfCallExpressions []*ast.CallExpr) {
for _, statement := range block.List {
switch expr := statement.(type) {
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr.X)...)
case *ast.ReturnStmt:
for _, result := range expr.Results {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(result)...)
}
case *ast.AssignStmt:
for _, rh := range expr.Rhs {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(rh)...)
}
default:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(statement)...)
}
}

return listOfCallExpressions
}

func format(src []byte, path string, fset *token.FileSet, file *ast.File) ([]byte, error) {
var listOfArgs []utils.FunctionIDArg
for _, f := range file.Decls {
var listOfCalls []*ast.CallExpr
fn, ok := f.(*ast.FuncDecl)
if !ok {
continue
}
listOfCalls = getListOfCallExpressionsFromBlockStmt(fn.Body)
for _, call := range listOfCalls {
if function, ok := call.Fun.(*ast.SelectorExpr); ok && function.Sel.Name == "FunctionID" {
pack, ok := function.X.(*ast.Ident)
if !ok {
continue
}
if pack.Name == "stack" && len(call.Args) == 1 {
listOfArgs = append(listOfArgs, utils.FunctionIDArg{
FuncDecl: fn,
ArgPos: call.Args[0].Pos(),
ArgEnd: call.Args[0].End(),
})
}
}
}
}
if len(listOfArgs) != 0 {
fixed, err := utils.FixSource(fset, path, src, listOfArgs)
if err != nil {
return nil, err
}

return fixed, nil
}

return src, nil
}

func processFile(src []byte, path string, fset *token.FileSet, file *ast.File, info os.FileInfo) error {
formatted, err := format(src, path, fset, file)
if err != nil {
return err
}
if !bytes.Equal(src, formatted) {
err = utils.WriteFile(path, formatted, info.Mode().Perm())
if err != nil {
return err
}
}

return nil
}

func main() {
flag.Usage = usage
flag.Parse()
args := flag.Args()

if len(args) != 1 {
flag.Usage()

return
}
_, err := os.Stat(args[0])
if err != nil {
panic(err)
}

fileSystem := os.DirFS(args[0])

err = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error {
fset := token.NewFileSet()
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if filepath.Ext(path) == ".go" {
info, err := os.Stat(path)
if err != nil {
return err
}
src, err := utils.ReadFile(path, info)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, path, nil, 0)
if err != nil {
return err
}

return processFile(src, path, fset, file, info)
}

return nil
})
if err != nil {
panic(err)
}
}
Loading
Loading