Skip to content

Commit

Permalink
Add pattern, support reference constraints on primitives, and add num…
Browse files Browse the repository at this point in the history
…ber/integer constraints (#264)

* add pattern

* support primitive types having restraints

* Numeric types

* Fix a bug where primitve string types can't pattern match

* Fix lint

* Fix test case output

* Min int size

* remote refs

* fix more lint

* Use math.Round instead of directly casting first

* Try with floats instead...

* PR feedback, fix typos, and reversed old for exclusive min/max
  • Loading branch information
nolag authored Sep 16, 2024
1 parent e14741f commit bbba7f7
Show file tree
Hide file tree
Showing 54 changed files with 3,071 additions and 153 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,16 @@ only specific validations remain to be fully implemented.
* [x] `type` (single)
* [x] `type` (multiple; **note**: partial support, limited validation)
* [ ] `const`
* [ ] Numeric validation (§6.2)
* [ ] `multipleOf`
* [ ] `maximum`
* [ ] `exclusiveMaximum`
* [ ] `minimum`
* [ ] `exclusiveMinimum`
* [ ] String validation (§6.3)
* [X] Numeric validation (§6.2)
* [X] `multipleOf`
* [X] `maximum`
* [X] `exclusiveMaximum`
* [X] `minimum`
* [X] `exclusiveMinimum`
* [X] String validation (§6.3)
* [X] `maxLength`
* [X] `minLength`
* [ ] `pattern`
* [X] `pattern`
* [ ] Array validation (§6.4)
* [X] `items`
* [x] `maxItems`
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ go 1.22.0
toolchain go1.22.7

require (
github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b // indirect
github.com/fatih/color v1.17.0 // indirect
github.com/goccy/go-yaml v1.12.0
github.com/mitchellh/go-wordwrap v1.0.1
github.com/pkg/errors v0.9.1
github.com/sanity-io/litter v1.5.5
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
)

require (
github.com/fatih/color v1.17.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
Expand Down
7 changes: 7 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ var (
yamlExtensions []string
tags []string
structNameFromTitle bool
minSizedInts bool

errFlagFormat = errors.New("flag must be in the format URI=PACKAGE")

Expand Down Expand Up @@ -75,6 +76,7 @@ var (
StructNameFromTitle: structNameFromTitle,
Tags: tags,
OnlyModels: onlyModels,
MinSizedInts: minSizedInts,
}
for _, id := range allKeys(schemaPackageMap, schemaOutputMap, schemaRootTypeMap) {
mapping := generator.SchemaMapping{SchemaID: id}
Expand Down Expand Up @@ -166,6 +168,11 @@ also look for foo.json if --resolve-extension json is provided.`)
"Use the schema title as the generated struct name")
rootCmd.PersistentFlags().StringSliceVar(&tags, "tags", []string{"json", "yaml", "mapstructure"},
`Specify which struct tags to generate. Defaults are json, yaml, mapstructure`)
rootCmd.PersistentFlags().BoolVar(
&minSizedInts,
"min-sized-ints",
false,
"Uses sized int and uint values based on the min and max values for the field")

abortWithErr(rootCmd.Execute())
}
Expand Down
116 changes: 115 additions & 1 deletion pkg/codegen/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package codegen
import (
"errors"
"fmt"
"math"

"github.com/atombender/go-jsonschema/pkg/mathutils"
"github.com/atombender/go-jsonschema/pkg/schemas"
)

Expand Down Expand Up @@ -33,7 +35,16 @@ func isPointerType(t Type) bool {
}
}

func PrimitiveTypeFromJSONSchemaType(jsType, format string, pointer bool) (Type, error) {
func PrimitiveTypeFromJSONSchemaType(
jsType,
format string,
pointer,
minIntSize bool,
minimum **float64,
maximum **float64,
exclusiveMinimum **any,
exclusiveMaximum **any,
) (Type, error) {
var t Type

switch jsType {
Expand Down Expand Up @@ -119,6 +130,22 @@ func PrimitiveTypeFromJSONSchemaType(jsType, format string, pointer bool) (Type,

case schemas.TypeNameInteger:
t := PrimitiveType{"int"}

if minIntSize {
newType, removeMin, removeMax := getMinIntType(*minimum, *maximum, *exclusiveMinimum, *exclusiveMaximum)
t.Type = newType

if removeMin {
*minimum = nil
*exclusiveMaximum = nil
}

if removeMax {
*maximum = nil
*exclusiveMinimum = nil
}
}

if pointer {
return WrapTypeInPointer(t), nil
}
Expand All @@ -142,3 +169,90 @@ func PrimitiveTypeFromJSONSchemaType(jsType, format string, pointer bool) (Type,

return nil, fmt.Errorf("%w %q", errUnknownJSONSchemaType, jsType)
}

// getMinIntType returns the smallest integer type that can represent the bounds, and if the bounds can be removed.
func getMinIntType(
minimum, maximum *float64, exclusiveMinimum, exclusiveMaximum *any,
) (string, bool, bool) {
nMin, nMax, nExclusiveMin, nExclusiveMax := mathutils.NormalizeBounds(
minimum, maximum, exclusiveMinimum, exclusiveMaximum,
)

if nExclusiveMin && nMin != nil {
*nMin += 1.0
}

if nExclusiveMax && nMax != nil {
*nMax -= 1.0
}

if nMin != nil && *nMin >= 0 {
return adjustForUnsignedBounds(nMin, nMax)
}

return adjustForSignedBounds(nMin, nMax)
}

const i64 = "int64"

func adjustForSignedBounds(nMin, nMax *float64) (string, bool, bool) {
var minRounded, maxRounded float64

if nMin != nil {
minRounded = math.Round(*nMin)
}

if nMax != nil {
maxRounded = math.Round(*nMax)
}

switch {
case nMin == nil && nMax == nil:
return i64, false, false

case nMin == nil:
return i64, false, maxRounded == float64(math.MaxInt64)

case nMax == nil:
return i64, minRounded == float64(math.MinInt64), false

case minRounded < float64(math.MinInt32) || maxRounded > float64(math.MaxInt32):
return i64, minRounded == float64(math.MinInt64), maxRounded == float64(math.MaxInt64)

case minRounded < float64(math.MinInt16) || maxRounded > float64(math.MaxInt16):
return "int32", minRounded == float64(math.MinInt32), maxRounded == float64(math.MaxInt32)

case minRounded < float64(math.MinInt8) || maxRounded > float64(math.MaxInt8):
return "int16", minRounded == float64(math.MinInt16), maxRounded == float64(math.MaxInt16)

default:
return "int8", minRounded == float64(math.MinInt8), maxRounded == float64(math.MaxInt8)
}
}

func adjustForUnsignedBounds(nMin, nMax *float64) (string, bool, bool) {
removeMin := nMin != nil && *nMin == 0.0

var maxRounded float64

if nMax != nil {
maxRounded = math.Round(*nMax)
}

switch {
case nMax == nil:
return "uint64", removeMin, false

case maxRounded > float64(math.MaxUint32):
return "uint64", removeMin, maxRounded == float64(math.MaxUint64)

case maxRounded > float64(math.MaxUint16):
return "uint32", removeMin, maxRounded == float64(math.MaxUint32)

case maxRounded > float64(math.MaxUint8):
return "uint16", removeMin, maxRounded == float64(math.MaxUint16)

default:
return "uint8", removeMin, maxRounded == float64(math.MaxUint8)
}
}
4 changes: 4 additions & 0 deletions pkg/generator/config.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package generator

import "github.com/atombender/go-jsonschema/pkg/schemas"

type Config struct {
SchemaMappings []SchemaMapping
ExtraImports bool
Expand All @@ -12,6 +14,8 @@ type Config struct {
Warner func(string)
Tags []string
OnlyModels bool
MinSizedInts bool
Loader schemas.Loader
}

type SchemaMapping struct {
Expand Down
39 changes: 18 additions & 21 deletions pkg/generator/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,19 @@ var (
errEnumNonPrimitiveVal = errors.New("enum has non-primitive value")
errMapURIToPackageName = errors.New("unable to map schema URI to Go package name")
errExpectedNamedType = errors.New("expected named type")
errUnsupportedRefFormat = errors.New("unsupported $ref format")
errConflictSameFile = errors.New("conflict: same file")
errDefinitionDoesNotExistInSchema = errors.New("definition does not exist in schema")
errCannotGenerateReferencedType = errors.New("cannot generate referenced type")
)

type Generator struct {
caser *text.Caser
config Config
inScope map[qualifiedDefinition]struct{}
outputs map[string]*output
schemaCacheByFileName map[string]*schemas.Schema
warner func(string)
formatters []formatter
fileLoader schemas.Loader
caser *text.Caser
config Config
inScope map[qualifiedDefinition]struct{}
outputs map[string]*output
warner func(string)
formatters []formatter
loader schemas.Loader
}

type qualifiedDefinition struct {
Expand All @@ -56,19 +54,18 @@ func New(config Config) (*Generator, error) {
}

generator := &Generator{
caser: text.NewCaser(config.Capitalizations, config.ResolveExtensions),
config: config,
inScope: map[qualifiedDefinition]struct{}{},
outputs: map[string]*output{},
schemaCacheByFileName: map[string]*schemas.Schema{},
warner: config.Warner,
formatters: formatters,
caser: text.NewCaser(config.Capitalizations, config.ResolveExtensions),
config: config,
inScope: map[qualifiedDefinition]struct{}{},
outputs: map[string]*output{},
warner: config.Warner,
formatters: formatters,
loader: config.Loader,
}

generator.fileLoader = schemas.NewCachedLoader(
schemas.NewFileLoader(config.ResolveExtensions, config.YAMLExtensions),
generator.schemaCacheByFileName,
)
if config.Loader == nil {
generator.loader = schemas.NewDefaultCacheLoader(config.ResolveExtensions, config.YAMLExtensions)
}

return generator, nil
}
Expand Down Expand Up @@ -125,7 +122,7 @@ func (g *Generator) DoFile(fileName string) error {
return fmt.Errorf("error parsing from standard input: %w", err)
}
} else {
schema, err = g.fileLoader.Load(fileName, "")
schema, err = g.loader.Load(fileName, "")
if err != nil {
return fmt.Errorf("error parsing from file %s: %w", fileName, err)
}
Expand Down
37 changes: 26 additions & 11 deletions pkg/generator/json_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,44 @@ const (
type jsonFormatter struct{}

func (jf *jsonFormatter) generate(declType codegen.TypeDecl, validators []validator) func(*codegen.Emitter) {
var beforeValidators []validator

var afterValidators []validator

forceBefore := false

for _, v := range validators {
desc := v.desc()
if desc.beforeJSONUnmarshal {
beforeValidators = append(beforeValidators, v)
} else {
afterValidators = append(afterValidators, v)
forceBefore = forceBefore || desc.requiresRawAfter
}
}

return func(out *codegen.Emitter) {
out.Commentf("Unmarshal%s implements %s.Unmarshaler.", strings.ToUpper(formatJSON), formatJSON)
out.Printlnf("func (j *%s) Unmarshal%s(b []byte) error {", declType.Name, strings.ToUpper(formatJSON))
out.Indent(1)
out.Printlnf("var %s map[string]interface{}", varNameRawMap)
out.Printlnf("if err := %s.Unmarshal(b, &%s); err != nil { return err }",
formatJSON, varNameRawMap)

for _, v := range validators {
if v.desc().beforeJSONUnmarshal {
v.generate(out)
}
if forceBefore || len(beforeValidators) != 0 {
out.Printlnf("var %s map[string]interface{}", varNameRawMap)
out.Printlnf("if err := %s.Unmarshal(b, &%s); err != nil { return err }",
formatJSON, varNameRawMap)
}

for _, v := range beforeValidators {
v.generate(out)
}

out.Printlnf("type Plain %s", declType.Name)
out.Printlnf("var %s Plain", varNamePlainStruct)
out.Printlnf("if err := %s.Unmarshal(b, &%s); err != nil { return err }",
formatJSON, varNamePlainStruct)

for _, v := range validators {
if !v.desc().beforeJSONUnmarshal {
v.generate(out)
}
for _, v := range afterValidators {
v.generate(out)
}

out.Printlnf("*j = %s(%s)", declType.Name, varNamePlainStruct)
Expand Down
Loading

0 comments on commit bbba7f7

Please sign in to comment.