diff --git a/patch/patcher.go b/patch/patcher.go index c3855a7..390202f 100644 --- a/patch/patcher.go +++ b/patch/patcher.go @@ -18,9 +18,13 @@ import ( "golang.org/x/tools/go/ast/astutil" "google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo" "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" + "github.com/golang/protobuf/proto" + "github.com/alta/protopatch/lint" + "github.com/alta/protopatch/patch/gopb" "github.com/alta/protopatch/patch/ident" ) @@ -33,56 +37,153 @@ import ( // - (go.enum).name overrides the name of an enum type. // - (go.value).name overrides the name of an enum value. type Patcher struct { - gen *protogen.Plugin - fset *token.FileSet - filesByName map[string]*ast.File - info *types.Info - packages []*Package - packagesByPath map[string]*Package - packagesByName map[string]*Package - renames map[protogen.GoIdent]string - typeRenames map[protogen.GoIdent]string - valueRenames map[protogen.GoIdent]string - fieldRenames map[protogen.GoIdent]string - methodRenames map[protogen.GoIdent]string - objectRenames map[types.Object]string - tags map[protogen.GoIdent]string - fieldTags map[types.Object]string - embeds map[protogen.GoIdent]string - fieldEmbeds map[types.Object]string - types map[protogen.GoIdent]string - fieldTypes map[types.Object]string + gen *protogen.Plugin + fset *token.FileSet + filesByName map[string]*ast.File + info *types.Info + packages []*Package + packagesByPath map[string]*Package + packagesByName map[string]*Package + renames map[protogen.GoIdent]string + typeRenames map[protogen.GoIdent]string + valueRenames map[protogen.GoIdent]string + fieldRenames map[protogen.GoIdent]string + methodRenames map[protogen.GoIdent]string + objectRenames map[types.Object]string + tags map[protogen.GoIdent]string + fieldTags map[types.Object]string + embeds map[protogen.GoIdent]string + fieldEmbeds map[types.Object]string + types map[protogen.GoIdent]string + fieldTypes map[types.Object]string + processedMessages map[protogen.GoIdent]bool } // NewPatcher returns an initialized Patcher for gen. func NewPatcher(gen *protogen.Plugin) (*Patcher, error) { p := &Patcher{ - gen: gen, - packagesByPath: make(map[string]*Package), - packagesByName: make(map[string]*Package), - renames: make(map[protogen.GoIdent]string), - typeRenames: make(map[protogen.GoIdent]string), - valueRenames: make(map[protogen.GoIdent]string), - fieldRenames: make(map[protogen.GoIdent]string), - methodRenames: make(map[protogen.GoIdent]string), - objectRenames: make(map[types.Object]string), - tags: make(map[protogen.GoIdent]string), - fieldTags: make(map[types.Object]string), - embeds: make(map[protogen.GoIdent]string), - fieldEmbeds: make(map[types.Object]string), - types: make(map[protogen.GoIdent]string), - fieldTypes: make(map[types.Object]string), + gen: gen, + packagesByPath: make(map[string]*Package), + packagesByName: make(map[string]*Package), + renames: make(map[protogen.GoIdent]string), + typeRenames: make(map[protogen.GoIdent]string), + valueRenames: make(map[protogen.GoIdent]string), + fieldRenames: make(map[protogen.GoIdent]string), + methodRenames: make(map[protogen.GoIdent]string), + objectRenames: make(map[types.Object]string), + tags: make(map[protogen.GoIdent]string), + fieldTags: make(map[types.Object]string), + embeds: make(map[protogen.GoIdent]string), + fieldEmbeds: make(map[types.Object]string), + types: make(map[protogen.GoIdent]string), + fieldTypes: make(map[types.Object]string), + processedMessages: make(map[protogen.GoIdent]bool), } return p, p.scan() } +func getExtensionDesc(pb proto.Message, extname string) (*proto.ExtensionDesc, error) { + desc := proto.RegisteredExtensions(pb) + for _, d := range desc { + if d.Name == extname { + return d, nil + } + } + return nil, fmt.Errorf("ExtensionDesc not found") +} + +func getExtension(pb proto.Message, extname string) (interface{}, error) { + d, err := getExtensionDesc(pb, extname) + if err != nil { + return nil, err + } + e, err := proto.GetExtension(pb, d) + if err != nil { + return nil, err + } + return e, err +} + func (p *Patcher) scan() error { for _, f := range p.gen.Files { p.scanFile(f) } + for _, f := range p.gen.Request.ProtoFile { + found := false + mident := protogen.GoIdent{GoName: "", GoImportPath: ""} + fident := protogen.GoIdent{GoName: "", GoImportPath: ""} + for _, genFile := range p.gen.Files { + if *f.Name == genFile.Desc.Path() { + found = true + mident = protogen.GoIdent{GoName: "", GoImportPath: genFile.GoImportPath} + fident = protogen.GoIdent{GoName: "", GoImportPath: genFile.GoImportPath} + break + } + } + if !found { + panic("Not found") + } + for _, m := range f.MessageType { + mident.GoName = *m.Name + if _, ok := p.processedMessages[mident]; ok { + continue + } + for _, msgfield := range m.Field { + fident.GoName = *msgfield.Name + p.scanProtoField(mident, fident, msgfield) + } + } + } + return nil } +func (p *Patcher) scanProtoField(mident protogen.GoIdent, fident protogen.GoIdent, f *descriptorpb.FieldDescriptorProto) { + //m := f.Parent + //o := f.Oneof + + if f.TypeName == nil { + return + } + fi, err := getExtension(f.GetOptions(), "go.field") + if err != nil { + return + } + opts := fi.(*gopb.Options) + + log.Printf("Parent Message %v (%v), opts %v", *f.Name, *f.TypeName, opts) + // Embed field ? + embed := false + newName := "" + if opts.GetEmbed() { + switch { + default: + embed = true + log.Printf("Embed Set %v ", *f.Name, *f.TypeName) + splitStrings := strings.Split((*f.TypeName)[1:], ".") + newName = splitStrings[len(splitStrings)-1] + } + } + if newName != "" { + if false { + p.RenameType(fident, p.nameFor(mident)+"_"+newName) // Oneof wrapper struct + p.RenameField(ident.WithChild(fident, fident.GoName), newName, false) // Oneof wrapper field (not embeddable) + } else { + p.RenameField(ident.WithChild(mident, fident.GoName), newName, embed) // Field + childID := ident.WithChild(mident, fident.GoName) + log.Printf("child %v parent %v", childID, mident.GoName) + } + p.RenameMethod(ident.WithChild(mident, "Get"+fident.GoName), "Get"+newName) // Getter + } else { + p.RenameField(ident.WithChild(mident, fident.GoName), newName, embed) // Field + } + + tags := opts.GetTags() + if tags != "" { + p.Tag(ident.WithChild(mident, fident.GoName), tags) // Field tags + } +} + func (p *Patcher) scanFile(f *protogen.File) { log.Printf("\nScan proto:\t%s", f.Desc.Path()) @@ -190,6 +291,7 @@ func (p *Patcher) scanMessage(m *protogen.Message, parent *protogen.Message) { opts := messageOptions(m) lints := fileLintOptions(m.Desc) + p.processedMessages[m.GoIdent] = true // Rename message? newName := opts.GetName() if newName == "" && parent != nil && p.isRenamed(parent.GoIdent) {