From dbc333be76bafd2bad225c7d773e9dd15a6c10d5 Mon Sep 17 00:00:00 2001 From: Luca Steeb Date: Thu, 28 Dec 2023 17:12:49 +0700 Subject: [PATCH 1/3] fix(engine): adapt datasource handling (#1136) --- engine/lifecycle.go | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/engine/lifecycle.go b/engine/lifecycle.go index 9fdd4062..9081c1ff 100644 --- a/engine/lifecycle.go +++ b/engine/lifecycle.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "log" "os" "os/exec" "path" @@ -202,18 +201,7 @@ func (e *QueryEngine) GetEncodedDatasources() (string, error) { } for i := range datasources { - if env := datasources[i].URL.FromEnvVar; env != "" { - url := os.Getenv(env) - if url == "" { - log.Printf("WARNING: env var %s which was defined in the Prisma schema is not set", env) - continue - // return "", fmt.Errorf("env var %s which was defined in the Prisma schema is not set", env) - } - overrides = append(overrides, DatasourceOverride{ - Name: datasources[i].Name.String(), - URL: url, - }) - } else { + if val := datasources[i].URL.Value; val != "" { overrides = append(overrides, DatasourceOverride{ Name: datasources[i].Name.String(), URL: e.datasourceURL, From ed187c088abfe2acfcca74072027a068634f5de0 Mon Sep 17 00:00:00 2001 From: Luca Steeb Date: Thu, 28 Dec 2023 17:12:57 +0700 Subject: [PATCH 2/3] fix(generator): major casing refactor (#1137) --- engine/lifecycle.go | 2 + generator/templates/models.gotpl | 2 + generator/templates/query.gotpl | 2 +- generator/types/types.go | 12 +-- generator/types/types_test.go | 29 +++++++ helpers/gocase/gocase.go | 70 ++++++++++++++-- helpers/gocase/gocase_test.go | 107 +++++++++++++++++++----- helpers/strcase/camel.go | 4 +- helpers/strcase/camel_test.go | 2 +- test/features/composite/default_test.go | 7 +- test/features/composite/schema.prisma | 1 + test/features/enums/enums_test.go | 2 +- test/integration/main.go | 3 +- 13 files changed, 206 insertions(+), 37 deletions(-) create mode 100644 generator/types/types_test.go diff --git a/engine/lifecycle.go b/engine/lifecycle.go index 9081c1ff..3db922b1 100644 --- a/engine/lifecycle.go +++ b/engine/lifecycle.go @@ -218,6 +218,8 @@ func (e *QueryEngine) GetEncodedDatasources() (string, error) { return "", fmt.Errorf("marshal datasources: %w", err) } + log.Printf("overriding datasources raw: %s", raw) + datasourcesBase64 := base64.URLEncoding.EncodeToString(raw) return datasourcesBase64, nil diff --git a/generator/templates/models.gotpl b/generator/templates/models.gotpl index 573dec15..c4d9a611 100644 --- a/generator/templates/models.gotpl +++ b/generator/templates/models.gotpl @@ -12,6 +12,8 @@ {{ range $field := $model.Fields }} {{- if not $field.Kind.IsRelation -}} {{- if $field.IsRequired }} + // GoCase {{ $field.Name.GoCase }} + // GoLowerCase {{ $field.Name.GoLowerCase }} {{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }} {{- else }} {{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }} diff --git a/generator/templates/query.gotpl b/generator/templates/query.gotpl index 66b16680..0ad1e5bb 100644 --- a/generator/templates/query.gotpl +++ b/generator/templates/query.gotpl @@ -54,7 +54,7 @@ {{/* composite keys for FindUnique */}} {{ range $unique := $model.CompoundKeys }} - func ({{ $nsQuery }}) {{ $unique.Name }}( + func ({{ $nsQuery }}) {{ $unique.Name.GoCase }}( {{- range $f := $unique.Fields }} _{{- $f.GoLowerCase }} {{ $model.Name.GoCase }}WithPrisma{{ $f.GoCase }}WhereParam, {{ end -}} diff --git a/generator/types/types.go b/generator/types/types.go index 285948c7..1daf1cb9 100644 --- a/generator/types/types.go +++ b/generator/types/types.go @@ -16,12 +16,12 @@ func (s String) String() string { // GoCase transforms strings into Go-style casing, meaning uppercase including Go casing edge cases. func (s String) GoCase() string { - return gocase.To(strcase.ToCamel(string(s))) + return gocase.ToUpper(string(s)) } // GoLowerCase transforms strings into Go-style lowercase casing. It is like GoCase but used for private fields. func (s String) GoLowerCase() string { - return gocase.To(strcase.ToLowerCamel(string(s))) + return gocase.ToLower(string(s)) } // CamelCase transforms strings into camelCase casing. It is often used for json mappings. @@ -39,7 +39,7 @@ func (s String) Tag(isRequired bool) string { // PrismaGoCase transforms `relevance` into `Relevance_` func (s String) PrismaGoCase() string { - return strcase.ToCamel(string(s)) + "_" + return strcase.ToUpperCamel(string(s)) + "_" } // PrismaInternalCase transforms `relevance` into `_relevance` @@ -75,17 +75,17 @@ func (t Type) Value() string { return v } - return gocase.To(strcase.ToCamel(str)) + return gocase.ToUpper(strcase.ToUpperCamel(str)) } // GoCase transforms strings into Go-style lowercase casing. It is like GoCase but used for private fields. func (t Type) GoCase() string { - return gocase.To(strcase.ToCamel(string(t))) + return gocase.ToUpper(string(t)) } // GoLowerCase transforms strings into Go-style lowercase casing. It is like GoCase but used for private fields. func (t Type) GoLowerCase() string { - return gocase.To(strcase.ToLowerCamel(string(t))) + return gocase.ToLower(string(t)) } // CamelCase transforms strings into camelCase casing. It is often used for json mappings. diff --git a/generator/types/types_test.go b/generator/types/types_test.go new file mode 100644 index 00000000..1bfb01ba --- /dev/null +++ b/generator/types/types_test.go @@ -0,0 +1,29 @@ +package types + +import ( + "fmt" + "testing" +) + +func TestString_GoCase(t *testing.T) { + tests := []struct { + have String + want string + }{{ + have: "", + want: "", + }, { + have: "anotherIDStuffSomethingID", + want: "AnotherIDStuffSomethingID", + }, { + have: "anotherIdStuffSomethingId", + want: "AnotherIDStuffSomethingID", + }} + for _, tt := range tests { + t.Run(fmt.Sprintf("%s -> %s", tt.have, tt.want), func(t *testing.T) { + if got := tt.have.GoCase(); got != tt.want { + t.Errorf("GoCase() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/helpers/gocase/gocase.go b/helpers/gocase/gocase.go index 91a8e4b7..d89547fe 100644 --- a/helpers/gocase/gocase.go +++ b/helpers/gocase/gocase.go @@ -34,15 +34,22 @@ import ( "fmt" "regexp" "strings" + + "github.com/steebchen/prisma-client-go/helpers/strcase" ) -// To returns a string converted to Go case. -func To(s string) string { - return defaultConverter.To(s) +// ToLower returns a string converted to Go upper case. +func ToLower(s string) string { + return defaultConverter.To(s, false) +} + +// ToUpper returns a string converted to Go lower case. +func ToUpper(s string) string { + return defaultConverter.To(s, true) } // To returns a string converted to Go case with converter. -func (c *Converter) To(s string) string { +func (c *Converter) To(s string, upper bool) string { for _, i := range c.initialisms { // not end re1 := regexp.MustCompile(fmt.Sprintf("%s([^a-z])", i.capUpper())) @@ -52,7 +59,60 @@ func (c *Converter) To(s string) string { re2 := regexp.MustCompile(fmt.Sprintf("%s$", i.capUpper())) s = re2.ReplaceAllString(s, i.allUpper()) } - return s + + if len(s) == 0 { + return s + } + + var isAllCap = true + for _, i := range s { + if i < 'A' || i > 'Z' { + isAllCap = false + } + } + + if strings.Contains(s, "_") || isAllCap { + if upper { + // for snake case + s = strcase.ToUpperCamel(s) + } else { + s = strcase.ToLowerCamel(s) + } + } else { + if upper { + s = strings.ToUpper(s[:1]) + s[1:] + } else { + // TODO!!! + s = strings.ToLower(s[:1]) + s[1:] + } + } + + // run again for new uppercase words + for _, i := range c.initialisms { + // not end + re1 := regexp.MustCompile(fmt.Sprintf("%s([^a-z])", i.capUpper())) + s = re1.ReplaceAllString(s, i.allUpper()+"$1") + + // end + re2 := regexp.MustCompile(fmt.Sprintf("%s$", i.capUpper())) + s = re2.ReplaceAllString(s, i.allUpper()) + } + + // fix casing after numbers + n := strings.Builder{} + n.Grow(len(s)) + prevIsNumber := false + for _, v := range []byte(s) { + vIsLow := v >= 'a' && v <= 'z' + if prevIsNumber && vIsLow { + v += 'A' + v -= 'a' + } + prevIsNumber = v >= '0' && v <= '9' + n.WriteByte(v) + } + + return n.String() } // Revert returns a string converted from Go case to normal case. diff --git a/helpers/gocase/gocase_test.go b/helpers/gocase/gocase_test.go index 80622371..e95c852b 100644 --- a/helpers/gocase/gocase_test.go +++ b/helpers/gocase/gocase_test.go @@ -1,6 +1,7 @@ package gocase_test import ( + "fmt" "strings" "testing" "unicode/utf8" @@ -8,34 +9,101 @@ import ( "github.com/steebchen/prisma-client-go/helpers/gocase" ) -func TestConverter_To(t *testing.T) { +func TestConverter_ToLower(t *testing.T) { t.Parallel() dc, _ := gocase.New() cc, _ := gocase.New(gocase.WithInitialisms("JSON", "CSV")) cases := []struct { - conv *gocase.Converter - s, want string + conv *gocase.Converter + have, want string }{ - {conv: dc, s: "", want: ""}, - {conv: dc, s: "jsonFile", want: "jsonFile"}, - {conv: dc, s: "IpAddress", want: "IPAddress"}, - {conv: dc, s: "defaultDnsServer", want: "defaultDNSServer"}, - {conv: dc, s: "somethingHttpApiId", want: "somethingHTTPAPIID"}, - {conv: dc, s: "somethingUuid", want: "somethingUUID"}, - {conv: dc, s: "somethingSip", want: "somethingSIP"}, - {conv: dc, s: "Urid", want: "Urid"}, - {conv: cc, s: "JsonFile", want: "JSONFile"}, - {conv: cc, s: "CsvFile", want: "CSVFile"}, - {conv: cc, s: "IpAddress", want: "IpAddress"}, + {conv: dc, have: "", want: ""}, + {conv: dc, have: "CONSTANT", want: "constant"}, + {conv: dc, have: "id", want: "id"}, + {conv: dc, have: "ID", want: "id"}, + {conv: dc, have: "jsonFile", want: "jsonFile"}, + // {conv: dc, have: "IpAddress", want: "ipAddress"}, + {conv: dc, have: "ip_address", want: "ipAddress"}, + {conv: dc, have: "defaultDnsServer", want: "defaultDNSServer"}, + {conv: dc, have: "somethingHttpApiId", want: "somethingHTTPAPIID"}, + {conv: dc, have: "somethingUuid", want: "somethingUUID"}, + {conv: dc, have: "somethingSip", want: "somethingSIP"}, + {conv: dc, have: "Urid", want: "urid"}, + {conv: dc, have: "stuffLast7D", want: "stuffLast7D"}, + {conv: dc, have: "stuffLast7d", want: "stuffLast7D"}, + {conv: dc, have: "StuffLast7d", want: "stuffLast7D"}, + {conv: dc, have: "StuffLast7dAnd", want: "stuffLast7DAnd"}, + {conv: dc, have: "StuffLast7DAnd", want: "stuffLast7DAnd"}, + {conv: dc, have: "anotherIDStuffSomethingID", want: "anotherIDStuffSomethingID"}, + {conv: dc, have: "anotherIdStuffSomethingId", want: "anotherIDStuffSomethingID"}, + {conv: dc, have: "anotherIdStuffSomethingId", want: "anotherIDStuffSomethingID"}, + {conv: dc, have: "another_id_stuff_something_id", want: "anotherIDStuffSomethingID"}, + + // {conv: cc, have: "JsonFile", want: "jsonFile"}, + // {conv: cc, have: "CsvFile", want: "csvFile"}, + {conv: cc, have: "IpAddress", want: "ipAddress"}, } for _, c := range cases { - r := c.conv.To(c.s) - if r != c.want { - t.Errorf("value doesn't match: %s (want %s)", r, c.want) - } + cc := c + t.Run(fmt.Sprintf("%s -> %s", cc.have, cc.want), func(t *testing.T) { + r := cc.conv.To(cc.have, false) + if r != cc.want { + t.Errorf("value doesn't match: have %s, is %s, want %s", cc.have, r, cc.want) + } + }) + } +} + +func TestConverter_ToUpper(t *testing.T) { + t.Parallel() + + dc, _ := gocase.New() + cc, _ := gocase.New(gocase.WithInitialisms("JSON", "CSV")) + + cases := []struct { + conv *gocase.Converter + have, want string + }{ + {conv: dc, have: "", want: ""}, + {conv: dc, have: "CONSTANT", want: "Constant"}, + {conv: dc, have: "id", want: "ID"}, + {conv: dc, have: "Id", want: "ID"}, + {conv: dc, have: "IdSomething", want: "IDSomething"}, + {conv: dc, have: "IDSomething", want: "IDSomething"}, + {conv: dc, have: "jsonFile", want: "JSONFile"}, + {conv: dc, have: "IpAddress", want: "IPAddress"}, + {conv: dc, have: "ip_address", want: "IPAddress"}, + {conv: dc, have: "defaultDnsServer", want: "DefaultDNSServer"}, + {conv: dc, have: "somethingHttpApiId", want: "SomethingHTTPAPIID"}, + {conv: dc, have: "somethingUuid", want: "SomethingUUID"}, + {conv: dc, have: "somethingSip", want: "SomethingSIP"}, + {conv: dc, have: "stuffLast7D", want: "StuffLast7D"}, + {conv: dc, have: "stuffLast7d", want: "StuffLast7D"}, + {conv: dc, have: "StuffLast7d", want: "StuffLast7D"}, + {conv: dc, have: "StuffLast7dAnd", want: "StuffLast7DAnd"}, + {conv: dc, have: "StuffLast7DAnd", want: "StuffLast7DAnd"}, + {conv: dc, have: "Urid", want: "Urid"}, + {conv: dc, have: "anotherIDStuffSomethingID", want: "AnotherIDStuffSomethingID"}, + {conv: dc, have: "anotherIdStuffSomethingId", want: "AnotherIDStuffSomethingID"}, + {conv: dc, have: "anotherIdStuffSomethingId", want: "AnotherIDStuffSomethingID"}, + {conv: dc, have: "another_id_stuff_something_id", want: "AnotherIDStuffSomethingID"}, + + {conv: cc, have: "JsonFile", want: "JSONFile"}, + {conv: cc, have: "CsvFile", want: "CSVFile"}, + {conv: cc, have: "IpAddress", want: "IpAddress"}, + } + + for _, c := range cases { + cc := c + t.Run(fmt.Sprintf("%s -> %s", cc.have, cc.want), func(t *testing.T) { + r := cc.conv.To(cc.have, true) + if r != cc.want { + t.Errorf("value doesn't match: have %s, is %s, want %s", cc.have, r, cc.want) + } + }) } } @@ -76,12 +144,13 @@ func TestConverter_Revert(t *testing.T) { // go test -fuzz=Fuzz // ``` func FuzzReverse(f *testing.F) { + f.Skip() testcases := []string{"jsonFile", "IpAddress", "defaultDnsServer"} for _, tc := range testcases { f.Add(tc) } f.Fuzz(func(t *testing.T, orig string) { - to := gocase.To(orig) + to := gocase.ToUpper(orig) rev := gocase.Revert(to) if !ignoreInput(orig) && orig != rev { t.Errorf("before: %q, after: %q", orig, rev) diff --git a/helpers/strcase/camel.go b/helpers/strcase/camel.go index 5224c912..be029b5d 100644 --- a/helpers/strcase/camel.go +++ b/helpers/strcase/camel.go @@ -74,8 +74,8 @@ func toCamelInitCase(s string, initCase bool) string { return n.String() } -// ToCamel converts a string to CamelCase -func ToCamel(s string) string { +// ToUpperCamel converts a string to CamelCase +func ToUpperCamel(s string) string { return toCamelInitCase(s, true) } diff --git a/helpers/strcase/camel_test.go b/helpers/strcase/camel_test.go index 7bb759ec..636588d1 100644 --- a/helpers/strcase/camel_test.go +++ b/helpers/strcase/camel_test.go @@ -46,7 +46,7 @@ func toCamel(tb testing.TB) { for _, i := range cases { in := i[0] out := i[1] - result := ToCamel(in) + result := ToUpperCamel(in) if result != out { tb.Errorf("%q (%q != %q)", in, result, out) } diff --git a/test/features/composite/default_test.go b/test/features/composite/default_test.go index acf4b2eb..b4dc5a70 100644 --- a/test/features/composite/default_test.go +++ b/test/features/composite/default_test.go @@ -13,12 +13,17 @@ type Func func(t *testing.T, client *PrismaClient, ctx cx) func TestComposite(t *testing.T) { // no-op compile time test - User.SomethingIDAnotherIDStuff( User.SomethingID.Equals(""), User.AnotherIDStuff.Equals(""), ) + // custom name test + User.AnotherIDStuffSomethingID( + User.AnotherIDStuff.Equals(""), + User.SomethingID.Equals(""), + ) + tests := []struct { name string before []string diff --git a/test/features/composite/schema.prisma b/test/features/composite/schema.prisma index 221cc543..08077873 100644 --- a/test/features/composite/schema.prisma +++ b/test/features/composite/schema.prisma @@ -21,4 +21,5 @@ model User { @@id([firstName, lastName]) @@unique([firstName, middleName, lastName]) @@unique([somethingId, anotherIdStuff]) + @@unique([anotherIdStuff, somethingId], name: "anotherIDStuffSomethingID") } diff --git a/test/features/enums/enums_test.go b/test/features/enums/enums_test.go index c1b2829a..fd433ea0 100644 --- a/test/features/enums/enums_test.go +++ b/test/features/enums/enums_test.go @@ -23,7 +23,7 @@ func TestEnums(t *testing.T) { _ = StuffLast30D _ = StuffSlack _ = StuffLast7DAnd - _ = StuffLast30Dand + _ = StuffLast30DAnd _ = StuffID tests := []struct { diff --git a/test/integration/main.go b/test/integration/main.go index fa71824d..d4463db3 100644 --- a/test/integration/main.go +++ b/test/integration/main.go @@ -1,9 +1,10 @@ +//go:generate go run github.com/steebchen/prisma-client-go generate --schema schemax.prisma + package main import ( "context" "fmt" - "integration/db" ) From d2b2373128a2723e92eafd07012d0ff237c8a4cf Mon Sep 17 00:00:00 2001 From: Luca Steeb Date: Thu, 28 Dec 2023 17:28:37 +0700 Subject: [PATCH 3/3] fix(engine): minor changes (#1138) --- engine/lifecycle.go | 2 -- generator/templates/models.gotpl | 2 -- helpers/gocase/gocase_test.go | 28 ++++++++++++++++++++++++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/engine/lifecycle.go b/engine/lifecycle.go index 3db922b1..9081c1ff 100644 --- a/engine/lifecycle.go +++ b/engine/lifecycle.go @@ -218,8 +218,6 @@ func (e *QueryEngine) GetEncodedDatasources() (string, error) { return "", fmt.Errorf("marshal datasources: %w", err) } - log.Printf("overriding datasources raw: %s", raw) - datasourcesBase64 := base64.URLEncoding.EncodeToString(raw) return datasourcesBase64, nil diff --git a/generator/templates/models.gotpl b/generator/templates/models.gotpl index c4d9a611..573dec15 100644 --- a/generator/templates/models.gotpl +++ b/generator/templates/models.gotpl @@ -12,8 +12,6 @@ {{ range $field := $model.Fields }} {{- if not $field.Kind.IsRelation -}} {{- if $field.IsRequired }} - // GoCase {{ $field.Name.GoCase }} - // GoLowerCase {{ $field.Name.GoLowerCase }} {{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }} {{- else }} {{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }} diff --git a/helpers/gocase/gocase_test.go b/helpers/gocase/gocase_test.go index e95c852b..1e90e092 100644 --- a/helpers/gocase/gocase_test.go +++ b/helpers/gocase/gocase_test.go @@ -135,7 +135,7 @@ func TestConverter_Revert(t *testing.T) { } } -// FuzzReverse runs a Fuzzing test to check if the strings +// FuzzReverseUpper runs a Fuzzing test to check if the strings // before and after `To` and `Revert` match. // Note that there may be cases where the strings before and after // the `To` and `Revert` do not match for certain inputs. @@ -143,9 +143,8 @@ func TestConverter_Revert(t *testing.T) { // ```cmd // go test -fuzz=Fuzz // ``` -func FuzzReverse(f *testing.F) { - f.Skip() - testcases := []string{"jsonFile", "IpAddress", "defaultDnsServer"} +func FuzzReverseUpper(f *testing.F) { + testcases := []string{"JsonFile", "IpAddress", "DefaultDnsServer"} for _, tc := range testcases { f.Add(tc) } @@ -161,6 +160,27 @@ func FuzzReverse(f *testing.F) { }) } +// FuzzReverseLower runs a Fuzzing test to check if the strings +// before and after `To` and `Revert` match. +// Note that there may be cases where the strings before and after +// the `To` and `Revert` do not match for certain inputs. +func FuzzReverseLower(f *testing.F) { + testcases := []string{"jsonFile", "ipAddress", "defaultDnsServer"} + for _, tc := range testcases { + f.Add(tc) + } + f.Fuzz(func(t *testing.T, orig string) { + to := gocase.ToLower(orig) + rev := gocase.Revert(to) + if !ignoreInput(orig) && orig != rev { + t.Errorf("before: %q, after: %q", orig, rev) + } + if utf8.ValidString(orig) && !utf8.ValidString(rev) { + t.Errorf("To or Revert produced invalid UTF-8 string %q", rev) + } + }) +} + func ignoreInput(in string) bool { for _, s := range gocase.DefaultInitialisms {