Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add codegeneration for stack.FunctionID #1114

Merged
merged 18 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/check-codegen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Build
run: |
go install ./internal/cmd/gtrace
go install ./internal/cmd/gstack
go install go.uber.org/mock/[email protected]

- name: Clean and re-generate *_gtrace.go files
Expand All @@ -40,5 +41,9 @@ jobs:
go generate ./trace
go generate ./...

- name: Re-generate stack.FunctionID calls
run: |
gstack .

- name: Check repository diff
run: bash ./.github/scripts/check-work-copy-equals-to-committed.sh "code-generation not equal with committed"
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
* Added query pool metrics
* Fixed logic of query session pool
* Changed initialization of internal driver clients to lazy
* Disabled the logic of background grpc-connection parking
* Disabled the logic of background grpc-connection parking
* Added internal gstack codegen for filling `stack.FunctionID` with value from call stack

## v3.58.2
* Added `trace.Query.OnSessionBegin` event
Expand Down
223 changes: 223 additions & 0 deletions internal/cmd/gstack/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/fs"
"os"
"path/filepath"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/cmd/gstack/utils"
)

func usage() {
fmt.Fprintf(os.Stderr, "usage: codegenerate [path]\n")
flag.PrintDefaults()
}

func getCallExpressionsFromExpr(expr ast.Expr) (listOfCalls []*ast.CallExpr) {
switch expr := expr.(type) {
case *ast.SelectorExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.IndexExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.StarExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.BinaryExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Y)...)
case *ast.CallExpr:
listOfCalls = append(listOfCalls, expr)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Fun)...)
for _, arg := range expr.Args {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(arg)...)
}
case *ast.CompositeLit:
for _, elt := range expr.Elts {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(elt)...)
}
case *ast.UnaryExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.X)...)
case *ast.KeyValueExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Value)...)
case *ast.FuncLit:
listOfCalls = append(listOfCalls, getListOfCallExpressionsFromBlockStmt(expr.Body)...)
}

return listOfCalls
}

func getExprFromDeclStmt(statement *ast.DeclStmt) (listOfExpressions []ast.Expr) {
decl, ok := statement.Decl.(*ast.GenDecl)
if !ok {
return listOfExpressions
}
for _, spec := range decl.Specs {
if spec, ok := spec.(*ast.ValueSpec); ok {
for _, expr := range spec.Values {
listOfExpressions = append(listOfExpressions, expr)
}
}
}
return listOfExpressions
}

func getCallExpressionsFromStmt(statement ast.Stmt) (listOfCallExpressions []*ast.CallExpr) {
var body *ast.BlockStmt
var listOfExpressions []ast.Expr
switch statement.(type) {
case *ast.IfStmt:
body = statement.(*ast.IfStmt).Body
case *ast.SwitchStmt:
body = statement.(*ast.SwitchStmt).Body
case *ast.TypeSwitchStmt:
body = statement.(*ast.TypeSwitchStmt).Body
case *ast.SelectStmt:
body = statement.(*ast.SelectStmt).Body
case *ast.ForStmt:
body = statement.(*ast.ForStmt).Body
case *ast.RangeStmt:
body = statement.(*ast.RangeStmt).Body
case *ast.DeclStmt:
listOfExpressions = append(listOfExpressions, getExprFromDeclStmt(statement.(*ast.DeclStmt))...)
for _, expr := range listOfExpressions {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr)...)
}
case *ast.CommClause:
stmts := statement.(*ast.CommClause).Body
for _, stmt := range stmts {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(stmt)...)
}
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(statement.(*ast.ExprStmt).X)...)
}
if body != nil {
listOfCallExpressions = append(
listOfCallExpressions,
getListOfCallExpressionsFromBlockStmt(body)...,
)
}

return listOfCallExpressions
}

func getListOfCallExpressionsFromBlockStmt(block *ast.BlockStmt) (listOfCallExpressions []*ast.CallExpr) {
for _, statement := range block.List {
switch expr := statement.(type) {
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr.X)...)
case *ast.ReturnStmt:
for _, result := range expr.Results {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(result)...)
}
case *ast.AssignStmt:
for _, rh := range expr.Rhs {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(rh)...)
}
default:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(statement)...)
}
}

return listOfCallExpressions
}

func format(src []byte, path string, fset *token.FileSet, file ast.File) ([]byte, error) {
var listOfArgs []utils.FunctionIDArg
for _, f := range file.Decls {
var listOfCalls []*ast.CallExpr
fn, ok := f.(*ast.FuncDecl)
if !ok {
continue
}
listOfCalls = getListOfCallExpressionsFromBlockStmt(fn.Body)
for _, call := range listOfCalls {
if function, ok := call.Fun.(*ast.SelectorExpr); ok && function.Sel.Name == "FunctionID" {
pack, ok := function.X.(*ast.Ident)
if !ok {
continue
}
if pack.Name == "stack" && len(call.Args) == 1 {
listOfArgs = append(listOfArgs, utils.FunctionIDArg{
FuncDecl: fn,
ArgPos: call.Args[0].Pos(),
ArgEnd: call.Args[0].End(),
})
}
}
}
}
if len(listOfArgs) != 0 {
fixed, err := utils.FixSource(fset, path, src, listOfArgs)
if err != nil {
return nil, err
}

return fixed, nil
}

return src, nil
}

func main() {
flag.Usage = usage
flag.Parse()
args := flag.Args()

if len(args) != 1 {
flag.Usage()

return
}
_, err := os.Stat(args[0])
if err != nil {
panic(err)
}

fileSystem := os.DirFS(args[0])

err = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error {
fset := token.NewFileSet()
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if filepath.Ext(path) == ".go" {
info, err := os.Stat(path)
if err != nil {
return err
}
src, err := utils.ReadFile(path, info)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, path, nil, 0)
if err != nil {
return err
}
formatted, err := format(src, path, fset, *file)
if err != nil {
return err
}
if !bytes.Equal(src, formatted) {
err = utils.WriteFile(path, formatted, info.Mode().Perm())
if err != nil {
return err
}
}

return nil
}

return nil
})
if err != nil {
panic(err)
}
}
134 changes: 134 additions & 0 deletions internal/cmd/gstack/utils/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package utils

import (
"fmt"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/version"
asmyasnikov marked this conversation as resolved.
Show resolved Hide resolved
"go/ast"
"go/parser"
"go/token"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
)

type FunctionIDArg struct {
FuncDecl *ast.FuncDecl
ArgPos token.Pos
ArgEnd token.Pos
}

func ReadFile(filename string, info fs.FileInfo) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer func(f *os.File) {
err := f.Close()
if err != nil {
}
}(f)
size := int(info.Size())
src := make([]byte, size)
n, err := io.ReadFull(f, src)
if err != nil {
return nil, err
}
if n < size {
return nil, fmt.Errorf("error: size of %s changed during reading (from %d to %d bytes)", filename, size, n)
} else if n > size {
return nil, fmt.Errorf("error: size of %s changed during reading (from %d to >=%d bytes)", filename, size, len(src))
}

return src, nil
}

func FixSource(fset *token.FileSet, path string, src []byte, listOfArgs []FunctionIDArg) ([]byte, error) {
var fixed []byte
var previousArgEnd int
for _, arg := range listOfArgs {
argPosOffset := fset.Position(arg.ArgPos).Offset
argEndOffset := fset.Position(arg.ArgEnd).Offset
argument, err := makeCall(fset, path, arg)
if err != nil {
return nil, err
}
fixed = append(fixed, src[previousArgEnd:argPosOffset]...)
fixed = append(fixed, fmt.Sprintf("\"%s\"", argument)...)
previousArgEnd = argEndOffset
}
fixed = append(fixed, src[previousArgEnd:]...)

return fixed, nil
}

func WriteFile(filename string, formatted []byte, perm fs.FileMode) error {
fout, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, perm)
if err != nil {
return err
}

defer fout.Close()

_, err = fout.Write(formatted)
if err != nil {
return err
}

return nil
}

func makeCall(fset *token.FileSet, path string, arg FunctionIDArg) (string, error) {
basePath := filepath.Join("github.com/ydb-platform/", version.Prefix, version.Major, "")
packageName, err := getPackageName(fset, arg)
if err != nil {
return "", err
}
filePath := filepath.Dir(filepath.Dir(path))
funcName, err := getFuncName(arg.FuncDecl)
if err != nil {
return "", err
}
return strings.Join([]string{filepath.Join(basePath, filePath, packageName), funcName}, "."), nil
}

func getFuncName(funcDecl *ast.FuncDecl) (string, error) {
if funcDecl.Recv != nil {
recvType := funcDecl.Recv.List[0].Type
prefix, err := getIdentNameFromExpr(recvType)
if err != nil {
return "", err
}
return strings.Join([]string{prefix, funcDecl.Name.Name}, "."), nil
}
return funcDecl.Name.Name, nil
}

func getIdentNameFromExpr(expr ast.Expr) (string, error) {
switch expr := expr.(type) {
case *ast.Ident:
return expr.Name, nil
case *ast.StarExpr:
prefix, err := getIdentNameFromExpr(expr.X)
if err != nil {
return "", err
}
return "(*" + prefix + ")", nil
case *ast.IndexExpr:
return getIdentNameFromExpr(expr.X)
case *ast.IndexListExpr:
return getIdentNameFromExpr(expr.X)
default:
return "", fmt.Errorf("error during getting ident from expr")
}
}

func getPackageName(fset *token.FileSet, arg FunctionIDArg) (string, error) {
file := fset.File(arg.ArgPos)
parsedFile, err := parser.ParseFile(fset, file.Name(), nil, parser.PackageClauseOnly)
if err != nil {
return "", fmt.Errorf("error during get package name function")
}
return parsedFile.Name.Name, nil
}
Loading
Loading