Skip to content

Commit

Permalink
Add support to read gopb tags from proto request
Browse files Browse the repository at this point in the history
  • Loading branch information
sudhiaithal committed Sep 9, 2024
1 parent d71e9a8 commit 4b09e73
Showing 1 changed file with 136 additions and 34 deletions.
170 changes: 136 additions & 34 deletions patch/patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ import (
"golang.org/x/tools/go/ast/astutil"
"google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/pluginpb"

"github.com/golang/protobuf/proto"

"github.com/alta/protopatch/lint"
"github.com/alta/protopatch/patch/gopb"
"github.com/alta/protopatch/patch/ident"
)

Expand All @@ -33,56 +37,153 @@ import (
// - (go.enum).name overrides the name of an enum type.
// - (go.value).name overrides the name of an enum value.
type Patcher struct {
gen *protogen.Plugin
fset *token.FileSet
filesByName map[string]*ast.File
info *types.Info
packages []*Package
packagesByPath map[string]*Package
packagesByName map[string]*Package
renames map[protogen.GoIdent]string
typeRenames map[protogen.GoIdent]string
valueRenames map[protogen.GoIdent]string
fieldRenames map[protogen.GoIdent]string
methodRenames map[protogen.GoIdent]string
objectRenames map[types.Object]string
tags map[protogen.GoIdent]string
fieldTags map[types.Object]string
embeds map[protogen.GoIdent]string
fieldEmbeds map[types.Object]string
types map[protogen.GoIdent]string
fieldTypes map[types.Object]string
gen *protogen.Plugin
fset *token.FileSet
filesByName map[string]*ast.File
info *types.Info
packages []*Package
packagesByPath map[string]*Package
packagesByName map[string]*Package
renames map[protogen.GoIdent]string
typeRenames map[protogen.GoIdent]string
valueRenames map[protogen.GoIdent]string
fieldRenames map[protogen.GoIdent]string
methodRenames map[protogen.GoIdent]string
objectRenames map[types.Object]string
tags map[protogen.GoIdent]string
fieldTags map[types.Object]string
embeds map[protogen.GoIdent]string
fieldEmbeds map[types.Object]string
types map[protogen.GoIdent]string
fieldTypes map[types.Object]string
processedMessages map[protogen.GoIdent]bool
}

// NewPatcher returns an initialized Patcher for gen.
func NewPatcher(gen *protogen.Plugin) (*Patcher, error) {
p := &Patcher{
gen: gen,
packagesByPath: make(map[string]*Package),
packagesByName: make(map[string]*Package),
renames: make(map[protogen.GoIdent]string),
typeRenames: make(map[protogen.GoIdent]string),
valueRenames: make(map[protogen.GoIdent]string),
fieldRenames: make(map[protogen.GoIdent]string),
methodRenames: make(map[protogen.GoIdent]string),
objectRenames: make(map[types.Object]string),
tags: make(map[protogen.GoIdent]string),
fieldTags: make(map[types.Object]string),
embeds: make(map[protogen.GoIdent]string),
fieldEmbeds: make(map[types.Object]string),
types: make(map[protogen.GoIdent]string),
fieldTypes: make(map[types.Object]string),
gen: gen,
packagesByPath: make(map[string]*Package),
packagesByName: make(map[string]*Package),
renames: make(map[protogen.GoIdent]string),
typeRenames: make(map[protogen.GoIdent]string),
valueRenames: make(map[protogen.GoIdent]string),
fieldRenames: make(map[protogen.GoIdent]string),
methodRenames: make(map[protogen.GoIdent]string),
objectRenames: make(map[types.Object]string),
tags: make(map[protogen.GoIdent]string),
fieldTags: make(map[types.Object]string),
embeds: make(map[protogen.GoIdent]string),
fieldEmbeds: make(map[types.Object]string),
types: make(map[protogen.GoIdent]string),
fieldTypes: make(map[types.Object]string),
processedMessages: make(map[protogen.GoIdent]bool),
}
return p, p.scan()
}

func getExtensionDesc(pb proto.Message, extname string) (*proto.ExtensionDesc, error) {
desc := proto.RegisteredExtensions(pb)
for _, d := range desc {
if d.Name == extname {
return d, nil
}
}
return nil, fmt.Errorf("ExtensionDesc not found")
}

func getExtension(pb proto.Message, extname string) (interface{}, error) {
d, err := getExtensionDesc(pb, extname)
if err != nil {
return nil, err
}
e, err := proto.GetExtension(pb, d)
if err != nil {
return nil, err
}
return e, err
}

func (p *Patcher) scan() error {
for _, f := range p.gen.Files {
p.scanFile(f)
}
for _, f := range p.gen.Request.ProtoFile {
found := false
mident := protogen.GoIdent{GoName: "", GoImportPath: ""}
fident := protogen.GoIdent{GoName: "", GoImportPath: ""}
for _, genFile := range p.gen.Files {
if *f.Name == genFile.Desc.Path() {
found = true
mident = protogen.GoIdent{GoName: "", GoImportPath: genFile.GoImportPath}
fident = protogen.GoIdent{GoName: "", GoImportPath: genFile.GoImportPath}
break
}
}
if !found {
panic("Not found")
}
for _, m := range f.MessageType {
mident.GoName = *m.Name
if _, ok := p.processedMessages[mident]; ok {
continue
}
for _, msgfield := range m.Field {
fident.GoName = *msgfield.Name
p.scanProtoField(mident, fident, msgfield)
}
}
}

return nil
}

func (p *Patcher) scanProtoField(mident protogen.GoIdent, fident protogen.GoIdent, f *descriptorpb.FieldDescriptorProto) {
//m := f.Parent
//o := f.Oneof

if f.TypeName == nil {
return
}
fi, err := getExtension(f.GetOptions(), "go.field")
if err != nil {
return
}
opts := fi.(*gopb.Options)

log.Printf("Parent Message %v (%v), opts %v", *f.Name, *f.TypeName, opts)
// Embed field ?
embed := false
newName := ""
if opts.GetEmbed() {
switch {
default:
embed = true
log.Printf("Embed Set %v ", *f.Name, *f.TypeName)
splitStrings := strings.Split((*f.TypeName)[1:], ".")
newName = splitStrings[len(splitStrings)-1]
}
}
if newName != "" {
if false {
p.RenameType(fident, p.nameFor(mident)+"_"+newName) // Oneof wrapper struct
p.RenameField(ident.WithChild(fident, fident.GoName), newName, false) // Oneof wrapper field (not embeddable)
} else {
p.RenameField(ident.WithChild(mident, fident.GoName), newName, embed) // Field
childID := ident.WithChild(mident, fident.GoName)
log.Printf("child %v parent %v", childID, mident.GoName)
}
p.RenameMethod(ident.WithChild(mident, "Get"+fident.GoName), "Get"+newName) // Getter
} else {
p.RenameField(ident.WithChild(mident, fident.GoName), newName, embed) // Field
}

tags := opts.GetTags()
if tags != "" {
p.Tag(ident.WithChild(mident, fident.GoName), tags) // Field tags
}
}

func (p *Patcher) scanFile(f *protogen.File) {
log.Printf("\nScan proto:\t%s", f.Desc.Path())

Expand Down Expand Up @@ -190,6 +291,7 @@ func (p *Patcher) scanMessage(m *protogen.Message, parent *protogen.Message) {
opts := messageOptions(m)
lints := fileLintOptions(m.Desc)

p.processedMessages[m.GoIdent] = true
// Rename message?
newName := opts.GetName()
if newName == "" && parent != nil && p.isRenamed(parent.GoIdent) {
Expand Down

0 comments on commit 4b09e73

Please sign in to comment.