diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index a6874ab72..0137331df 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -308,9 +308,6 @@ def const_template(option_value, name): {%- elif str(o.bool) and o.bool['const'] != "" -%} if {{ name }} != {{ o.bool['const'] }}: raise ValidationFailed(\"{{ name }} not equal to {{ o.bool['const'] }}\") - {%- elif str(o.enum) and o.enum['const'] -%} - if {{ name }} != {{ o.enum['const'] }}: - raise ValidationFailed(\"{{ name }} not equal to {{ o.enum['const'] }}\") {%- elif str(o.bytes) and o.bytes.HasField('const') -%} {% if sys.version_info[0] >= 3 %} if {{ name }} != {{ o.bytes['const'] }}: @@ -828,17 +825,52 @@ def enum_values(field): return [x.number for x in field.enum_type.values] +def enum_name(field, number): + for x in field.enum_type.values: + if x.number == number: + return x.name + return "" + + +def enum_names(field, numbers): + m = {x.number: x.name for x in field.enum_type.values} + return "[" + "".join([m[n] for n in numbers]) + "]" + + +def enum_const_template(value, name, field): + const_tmpl = """{%- if str(value) and value['const'] -%} + if {{ name }} != {{ value['const'] }}: + raise ValidationFailed(\"{{ name }} not equal to {{ enum_name(field, value['const']) }}\") + {%- endif -%} + """ + return Template(const_tmpl).render(value=value, name=name, field=field, enum_name=enum_name, str=str) + + +def enum_in_template(value, name, field): + in_tmpl = """ + {%- if value['in'] %} + if {{ name }} not in {{ value['in'] }}: + raise ValidationFailed(\"{{ name }} not in {{ enum_names(field, value['in']) }}\") + {%- endif -%} + {%- if value['not_in'] %} + if {{ name }} in {{ value['not_in'] }}: + raise ValidationFailed(\"{{ name }} in {{ enum_names(field, value['not_in']) }}\") + {%- endif -%} + """ + return Template(in_tmpl).render(value=value, name=name, field=field, enum_names=enum_names) + + def enum_template(option_value, name, field): enum_tmpl = """ - {{ const_template(option_value, name) -}} - {{ in_template(option_value.enum, name) -}} + {{ enum_const_template(option_value.enum, name, field) -}} + {{ enum_in_template(option_value.enum, name, field) -}} {% if option_value.enum['defined_only'] %} if {{ name }} not in {{ enum_values(field) }}: raise ValidationFailed(\"{{ name }} is not defined\") {% endif %} """ - return Template(enum_tmpl).render(option_value=option_value, name=name, const_template=const_template, - in_template=in_template, field=field, enum_values=enum_values) + return Template(enum_tmpl).render(option_value=option_value, name=name, enum_const_template=enum_const_template, + enum_in_template=enum_in_template, field=field, enum_values=enum_values) def any_template(option_value, name, repeated=False): diff --git a/templates/cc/const.go b/templates/cc/const.go index 2d6af9ed7..27c8e2839 100644 --- a/templates/cc/const.go +++ b/templates/cc/const.go @@ -1,9 +1,13 @@ package cc -const constTpl = `{{ $r := .Rules }} +const constTpl = `{{ $f := .Field }}{{ $r := .Rules }} {{ if $r.Const }} if ({{ accessor . }} != {{ lit $r.GetConst }}) { + {{- if isEnum $f }} + {{ err . "value must equal " (enumVal $f $r.GetConst) }} + {{- else }} {{ err . "value must equal " (lit $r.GetConst) }} + {{- end }} } {{ end }} ` diff --git a/templates/cc/in.go b/templates/cc/in.go index c9b8f7525..57cf21f75 100644 --- a/templates/cc/in.go +++ b/templates/cc/in.go @@ -3,11 +3,19 @@ package cc const inTpl = `{{ $f := .Field -}}{{ $r := .Rules -}} {{- if $r.In }} if ({{ lookup $f "InLookup" }}.find(static_cast({{ accessor . }})) == {{ lookup $f "InLookup" }}.end()) { + {{- if isEnum $f }} + {{ err . "value must be in list " (enumList $f $r.In) }} + {{- else }} {{ err . "value must be in list " $r.In }} + {{- end }} } {{- else if $r.NotIn }} if ({{ lookup $f "NotInLookup" }}.find(static_cast({{ accessor . }})) != {{ lookup $f "NotInLookup" }}.end()) { + {{- if isEnum $f }} + {{ err . "value must not be in list " (enumList $f $r.NotIn) }} + {{- else }} {{ err . "value must not be in list " $r.NotIn }} + {{- end }} } {{- end }} ` diff --git a/templates/goshared/const.go b/templates/goshared/const.go index 18bf64665..118172e4a 100644 --- a/templates/goshared/const.go +++ b/templates/goshared/const.go @@ -1,9 +1,13 @@ package goshared -const constTpl = `{{ $r := .Rules }} +const constTpl = `{{ $f := .Field }}{{ $r := .Rules }} {{ if $r.Const }} if {{ accessor . }} != {{ lit $r.GetConst }} { + {{- if isEnum $f }} + err := {{ err . "value must equal " (enumVal $f $r.GetConst) }} + {{- else }} err := {{ err . "value must equal " $r.GetConst }} + {{- end }} if !all { return err } errors = append(errors, err) } diff --git a/templates/goshared/in.go b/templates/goshared/in.go index b6fdfa4c9..c2fc2e3bb 100644 --- a/templates/goshared/in.go +++ b/templates/goshared/in.go @@ -3,13 +3,21 @@ package goshared const inTpl = `{{ $f := .Field }}{{ $r := .Rules }} {{ if $r.In }} if _, ok := {{ lookup $f "InLookup" }}[{{ accessor . }}]; !ok { + {{- if isEnum $f }} + err := {{ err . "value must be in list " (enumList $f $r.In) }} + {{- else }} err := {{ err . "value must be in list " $r.In }} + {{- end }} if !all { return err } errors = append(errors, err) } {{ else if $r.NotIn }} if _, ok := {{ lookup $f "NotInLookup" }}[{{ accessor . }}]; ok { + {{- if isEnum $f }} + err := {{ err . "value must not be in list " (enumList $f $r.NotIn) }} + {{- else }} err := {{ err . "value must not be in list " $r.NotIn }} + {{- end }} if !all { return err } errors = append(errors, err) } diff --git a/templates/shared/BUILD.bazel b/templates/shared/BUILD.bazel index 2137345d7..0f9f5f151 100644 --- a/templates/shared/BUILD.bazel +++ b/templates/shared/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "functions.go", "reflection.go", "well_known.go", + "enums.go" ], importpath = "github.com/envoyproxy/protoc-gen-validate/templates/shared", visibility = ["//visibility:public"], diff --git a/templates/shared/enums.go b/templates/shared/enums.go new file mode 100644 index 000000000..cdc7e64b7 --- /dev/null +++ b/templates/shared/enums.go @@ -0,0 +1,46 @@ +package shared + +import ( + "fmt" + "strings" + + pgs "github.com/lyft/protoc-gen-star" +) + +func isEnum(f pgs.Field) bool { + return f.Type().IsEnum() +} + +func enumNamesMap(values []pgs.EnumValue) (m map[int32]string) { + m = make(map[int32]string) + for _, v := range values { + if _, exists := m[v.Value()]; !exists { + m[v.Value()] = v.Name().String() + } + } + return m +} + +// enumList - if type is ENUM, enum values are returned +func enumList(f pgs.Field, list []int32) string { + stringList := make([]string, 0, len(list)) + if enum := f.Type().Enum(); enum != nil { + names := enumNamesMap(enum.Values()) + for _, n := range list { + stringList = append(stringList, names[n]) + } + } else { + for _, n := range list { + stringList = append(stringList, fmt.Sprint(n)) + } + } + return "[" + strings.Join(stringList, " ") + "]" +} + +// enumVal - if type is ENUM, enum value is returned +func enumVal(f pgs.Field, val int32) string { + if enum := f.Type().Enum(); enum != nil { + return enumNamesMap(enum.Values())[val] + } + return fmt.Sprint(val) +} diff --git a/templates/shared/functions.go b/templates/shared/functions.go index 2bf111a45..3e5c26e20 100644 --- a/templates/shared/functions.go +++ b/templates/shared/functions.go @@ -3,7 +3,7 @@ package shared import ( "text/template" - "github.com/lyft/protoc-gen-star" + pgs "github.com/lyft/protoc-gen-star" ) func RegisterFunctions(tpl *template.Template, params pgs.Parameters) { @@ -16,5 +16,8 @@ func RegisterFunctions(tpl *template.Template, params pgs.Parameters) { "has": Has, "needs": Needs, "fileneeds": FileNeeds, + "isEnum": isEnum, + "enumList": enumList, + "enumVal": enumVal, }) }