diff --git a/internal/gen.go b/internal/gen.go index f88ae2c..f263baf 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -124,9 +124,40 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat enums, structs = filterUnusedStructs(enums, structs, queries) } + if err := validate(options, enums, structs, queries); err != nil { + return nil, err + } + return generate(req, options, enums, structs, queries) } +func validate(options *opts.Options, enums []Enum, structs []Struct, queries []Query) error { + enumNames := make(map[string]struct{}) + for _, enum := range enums { + enumNames[enum.Name] = struct{}{} + enumNames["Null"+enum.Name] = struct{}{} + } + structNames := make(map[string]struct{}) + for _, struckt := range structs { + if _, ok := enumNames[struckt.Name]; ok { + return fmt.Errorf("struct name conflicts with enum name: %s", struckt.Name) + } + structNames[struckt.Name] = struct{}{} + } + if !options.EmitExportedQueries { + return nil + } + for _, query := range queries { + if _, ok := enumNames[query.ConstantName]; ok { + return fmt.Errorf("query constant name conflicts with enum name: %s", query.ConstantName) + } + if _, ok := structNames[query.ConstantName]; ok { + return fmt.Errorf("query constant name conflicts with struct name: %s", query.ConstantName) + } + } + return nil +} + func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.GenerateResponse, error) { i := &importer{ Options: options, diff --git a/internal/result.go b/internal/result.go index d0ed307..714b1e1 100644 --- a/internal/result.go +++ b/internal/result.go @@ -312,16 +312,16 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] return qs, nil } +var cmdReturnsData = map[string]struct{}{ + metadata.CmdBatchMany: {}, + metadata.CmdBatchOne: {}, + metadata.CmdMany: {}, + metadata.CmdOne: {}, +} + func putOutColumns(query *plugin.Query) bool { - if len(query.Columns) > 0 { - return true - } - for _, allowed := range []string{metadata.CmdMany, metadata.CmdOne, metadata.CmdBatchMany} { - if query.Cmd == allowed { - return true - } - } - return false + _, found := cmdReturnsData[query.Cmd] + return found } // It's possible that this method will generate duplicate JSON tag values diff --git a/internal/result_test.go b/internal/result_test.go index be30e1e..fd8bd11 100644 --- a/internal/result_test.go +++ b/internal/result_test.go @@ -50,7 +50,7 @@ func TestPutOutColumns_ForZeroColumns(t *testing.T) { }, { cmd: metadata.CmdBatchOne, - want: false, + want: true, }, } for _, tc := range tests {