Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Filter interface #185

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions any.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type anyPB struct {
NotInValue protoreflect.Value
}

func (a anyPB) Evaluate(val protoreflect.Value, failFast bool) error {
func (a anyPB) Evaluate(_ protoreflect.Message, val protoreflect.Value, cfg *validationConfig) error {
typeURL := val.Message().Get(a.TypeURLDescriptor).String()

err := &ValidationError{}
Expand All @@ -76,7 +76,7 @@ func (a anyPB) Evaluate(val protoreflect.Value, failFast bool) error {
RuleValue: a.InValue,
RuleDescriptor: anyInRuleDescriptor,
})
if failFast {
if cfg.failFast {
return err
}
}
Expand Down
1 change: 1 addition & 0 deletions buf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ lint:
ignore_only:
PROTOVALIDATE:
- proto/tests/example/v1/validations.proto
- proto/tests/example/v1/filter.proto
breaking:
use:
- FILE
6 changes: 3 additions & 3 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (bldr *builder) processOneofConstraints(
Descriptor: oneofDesc,
Required: oneofConstraints.GetRequired(),
}
msgEval.Append(oneofEval)
msgEval.AppendNested(oneofEval)
}
}

Expand All @@ -203,7 +203,7 @@ func (bldr *builder) processFields(
msgEval.Err = err
return
}
msgEval.Append(fldEval)
msgEval.AppendNested(fldEval)
}
}

Expand Down Expand Up @@ -337,7 +337,7 @@ func (bldr *builder) processEmbeddedMessage(
"failed to compile embedded type %s for %s: %w",
fdesc.Message().FullName(), fdesc.FullName(), err)}
}
valEval.Append(&embeddedMessage{
valEval.AppendNested(&embeddedMessage{
base: newBase(valEval),
message: embedEval,
})
Expand Down
8 changes: 4 additions & 4 deletions cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ type celPrograms struct {
programSet
}

func (c celPrograms) Evaluate(val protoreflect.Value, failFast bool) error {
err := c.programSet.Eval(val, failFast)
func (c celPrograms) Evaluate(_ protoreflect.Message, val protoreflect.Value, cfg *validationConfig) error {
err := c.programSet.Eval(val, cfg)
if err != nil {
var valErr *ValidationError
if errors.As(err, &valErr) {
Expand All @@ -42,8 +42,8 @@ func (c celPrograms) Evaluate(val protoreflect.Value, failFast bool) error {
return err
}

func (c celPrograms) EvaluateMessage(msg protoreflect.Message, failFast bool) error {
return c.programSet.Eval(protoreflect.ValueOfMessage(msg), failFast)
func (c celPrograms) EvaluateMessage(msg protoreflect.Message, cfg *validationConfig) error {
return c.programSet.Eval(protoreflect.ValueOfMessage(msg), cfg)
}

func (c celPrograms) Tautology() bool {
Expand Down
2 changes: 1 addition & 1 deletion enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type definedEnum struct {
ValueDescriptors protoreflect.EnumValueDescriptors
}

func (d definedEnum) Evaluate(val protoreflect.Value, _ bool) error {
func (d definedEnum) Evaluate(_ protoreflect.Message, val protoreflect.Value, _ *validationConfig) error {
if d.ValueDescriptors.ByNumber(val.Enum()) == nil {
return &ValidationError{Violations: []*Violation{{
Proto: &validate.Violation{
Expand Down
6 changes: 3 additions & 3 deletions error_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
// to failFast or the result is not a ValidationError).
//
//nolint:errorlint
func mergeViolations(dst, src error, failFast bool) (ok bool, err error) {
func mergeViolations(dst, src error, cfg *validationConfig) (ok bool, err error) {
if src == nil {
return true, dst
}
Expand All @@ -42,7 +42,7 @@ func mergeViolations(dst, src error, failFast bool) (ok bool, err error) {
}

if dst == nil {
return !(failFast && len(srcValErrs.Violations) > 0), src
return !(cfg.failFast && len(srcValErrs.Violations) > 0), src
}

dstValErrs, ok := dst.(*ValidationError)
Expand All @@ -52,7 +52,7 @@ func mergeViolations(dst, src error, failFast bool) (ok bool, err error) {
}

dstValErrs.Violations = append(dstValErrs.Violations, srcValErrs.Violations...)
return !(failFast && len(dstValErrs.Violations) > 0), dst
return !(cfg.failFast && len(dstValErrs.Violations) > 0), dst
}

// fieldPathElement returns a buf.validate.fieldPathElement that corresponds to
Expand Down
24 changes: 12 additions & 12 deletions error_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ func TestMerge(t *testing.T) {

t.Run("no errors", func(t *testing.T) {
t.Parallel()
ok, err := mergeViolations(nil, nil, true)
ok, err := mergeViolations(nil, nil, &validationConfig{failFast: true})
require.NoError(t, err)
assert.True(t, ok)
ok, err = mergeViolations(nil, nil, false)
ok, err = mergeViolations(nil, nil, &validationConfig{failFast: false})
require.NoError(t, err)
assert.True(t, ok)
})
Expand All @@ -43,23 +43,23 @@ func TestMerge(t *testing.T) {
t.Run("non-validation", func(t *testing.T) {
t.Parallel()
someErr := errors.New("some error")
ok, err := mergeViolations(nil, someErr, true)
ok, err := mergeViolations(nil, someErr, &validationConfig{failFast: true})
assert.Equal(t, someErr, err)
assert.False(t, ok)
ok, err = mergeViolations(nil, someErr, false)
ok, err = mergeViolations(nil, someErr, &validationConfig{failFast: false})
assert.Equal(t, someErr, err)
assert.False(t, ok)
})

t.Run("validation", func(t *testing.T) {
t.Parallel()
exErr := &ValidationError{Violations: []*Violation{{Proto: &validate.Violation{ConstraintId: proto.String("foo")}}}}
ok, err := mergeViolations(nil, exErr, true)
ok, err := mergeViolations(nil, exErr, &validationConfig{failFast: true})
var valErr *ValidationError
require.ErrorAs(t, err, &valErr)
assert.True(t, proto.Equal(exErr.ToProto(), valErr.ToProto()))
assert.False(t, ok)
ok, err = mergeViolations(nil, exErr, false)
ok, err = mergeViolations(nil, exErr, &validationConfig{failFast: false})
require.ErrorAs(t, err, &valErr)
assert.True(t, proto.Equal(exErr.ToProto(), valErr.ToProto()))
assert.True(t, ok)
Expand All @@ -73,10 +73,10 @@ func TestMerge(t *testing.T) {
t.Parallel()
dstErr := errors.New("some error")
srcErr := &ValidationError{Violations: []*Violation{{Proto: &validate.Violation{ConstraintId: proto.String("foo")}}}}
ok, err := mergeViolations(dstErr, srcErr, true)
ok, err := mergeViolations(dstErr, srcErr, &validationConfig{failFast: true})
assert.Equal(t, dstErr, err)
assert.False(t, ok)
ok, err = mergeViolations(dstErr, srcErr, false)
ok, err = mergeViolations(dstErr, srcErr, &validationConfig{failFast: false})
assert.Equal(t, dstErr, err)
assert.False(t, ok)
})
Expand All @@ -85,10 +85,10 @@ func TestMerge(t *testing.T) {
t.Parallel()
dstErr := &ValidationError{Violations: []*Violation{{Proto: &validate.Violation{ConstraintId: proto.String("foo")}}}}
srcErr := errors.New("some error")
ok, err := mergeViolations(dstErr, srcErr, true)
ok, err := mergeViolations(dstErr, srcErr, &validationConfig{failFast: true})
assert.Equal(t, srcErr, err)
assert.False(t, ok)
ok, err = mergeViolations(dstErr, srcErr, false)
ok, err = mergeViolations(dstErr, srcErr, &validationConfig{failFast: false})
assert.Equal(t, srcErr, err)
assert.False(t, ok)
})
Expand All @@ -102,13 +102,13 @@ func TestMerge(t *testing.T) {
{Proto: &validate.Violation{ConstraintId: proto.String("foo")}},
{Proto: &validate.Violation{ConstraintId: proto.String("bar")}},
}}
ok, err := mergeViolations(dstErr, srcErr, true)
ok, err := mergeViolations(dstErr, srcErr, &validationConfig{failFast: true})
var valErr *ValidationError
require.ErrorAs(t, err, &valErr)
assert.True(t, proto.Equal(exErr.ToProto(), valErr.ToProto()))
assert.False(t, ok)
dstErr = &ValidationError{Violations: []*Violation{{Proto: &validate.Violation{ConstraintId: proto.String("foo")}}}}
ok, err = mergeViolations(dstErr, srcErr, false)
ok, err = mergeViolations(dstErr, srcErr, &validationConfig{failFast: false})
require.ErrorAs(t, err, &valErr)
assert.True(t, proto.Equal(exErr.ToProto(), valErr.ToProto()))
assert.True(t, ok)
Expand Down
20 changes: 10 additions & 10 deletions evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type evaluator interface {
// - errors.CompilationError: this evaluator (or child evaluator) failed to
// build. This error is not recoverable.
//
Evaluate(val protoreflect.Value, failFast bool) error
Evaluate(msg protoreflect.Message, val protoreflect.Value, cfg *validationConfig) error
}

// messageEvaluator is essentially the same as evaluator, but specialized for
Expand All @@ -45,19 +45,19 @@ type messageEvaluator interface {

// EvaluateMessage checks that the provided msg is valid. See
// evaluator.Evaluate for behavior
EvaluateMessage(msg protoreflect.Message, failFast bool) error
EvaluateMessage(msg protoreflect.Message, cfg *validationConfig) error
}

// evaluators are a set of evaluator applied together to a value. Evaluation
// merges all errors.ValidationError violations or short-circuits if failFast is
// true or a different error is returned.
type evaluators []evaluator

func (e evaluators) Evaluate(val protoreflect.Value, failFast bool) (err error) {
func (e evaluators) Evaluate(msg protoreflect.Message, val protoreflect.Value, cfg *validationConfig) (err error) {
var ok bool
for _, eval := range e {
evalErr := eval.Evaluate(val, failFast)
if ok, err = mergeViolations(err, evalErr, failFast); !ok {
evalErr := eval.Evaluate(msg, val, cfg)
if ok, err = mergeViolations(err, evalErr, cfg); !ok {
return err
}
}
Expand All @@ -77,15 +77,15 @@ func (e evaluators) Tautology() bool {
// behavior details.
type messageEvaluators []messageEvaluator

func (m messageEvaluators) Evaluate(val protoreflect.Value, failFast bool) error {
return m.EvaluateMessage(val.Message(), failFast)
func (m messageEvaluators) Evaluate(val protoreflect.Value, cfg *validationConfig) error {
return m.EvaluateMessage(val.Message(), cfg)
}

func (m messageEvaluators) EvaluateMessage(msg protoreflect.Message, failFast bool) (err error) {
func (m messageEvaluators) EvaluateMessage(msg protoreflect.Message, cfg *validationConfig) (err error) {
var ok bool
for _, eval := range m {
evalErr := eval.EvaluateMessage(msg, failFast)
if ok, err = mergeViolations(err, evalErr, failFast); !ok {
evalErr := eval.EvaluateMessage(msg, cfg)
if ok, err = mergeViolations(err, evalErr, cfg); !ok {
return err
}
}
Expand Down
12 changes: 8 additions & 4 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@ type field struct {
Zero protoreflect.Value
}

func (f field) Evaluate(val protoreflect.Value, failFast bool) error {
return f.EvaluateMessage(val.Message(), failFast)
func (f field) Evaluate(_ protoreflect.Message, val protoreflect.Value, cfg *validationConfig) error {
return f.EvaluateMessage(val.Message(), cfg)
}

func (f field) EvaluateMessage(msg protoreflect.Message, failFast bool) (err error) {
func (f field) EvaluateMessage(msg protoreflect.Message, cfg *validationConfig) (err error) {
if !cfg.filter.ShouldValidate(msg, f.Value.Descriptor) {
return nil
}

if f.Required && !msg.Has(f.Value.Descriptor) {
return &ValidationError{Violations: []*Violation{{
Proto: &validate.Violation{
Expand All @@ -76,7 +80,7 @@ func (f field) EvaluateMessage(msg protoreflect.Message, failFast bool) (err err
if f.IgnoreDefault && val.Equal(f.Zero) {
return nil
}
return f.Value.Evaluate(val, failFast)
return f.Value.EvaluateField(msg, val, cfg, true)
}

func (f field) Tautology() bool {
Expand Down
50 changes: 50 additions & 0 deletions filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2023-2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package protovalidate

import "google.golang.org/protobuf/reflect/protoreflect"

// The Filter interface determines which constraints should be validated.
type Filter interface {
// ShouldValidate returns whether constraints for a given message, field, or
// oneof should be evaluated. For a message or oneof, this only determines
// whether message-level or oneof-level constraints should be evaluated, and
// ShouldValidate will still be called for each field in the message. If
// ShouldValidate returns false for a specific field, all constraints nested
// in submessages of that field will be skipped as well.
// For a message, the message argument provides the message itself. For a
// field or oneof, the message argument provides the containing message.
ShouldValidate(message protoreflect.Message, descriptor protoreflect.Descriptor) bool
}

// FilterFunc is a function type that implements the Filter interface, as a
// convenience for simple filters. A FilterFunc should follow the same semantics
// as the ShouldValidate method of Filter.
type FilterFunc func(protoreflect.Message, protoreflect.Descriptor) bool

func (f FilterFunc) ShouldValidate(
message protoreflect.Message,
descriptor protoreflect.Descriptor,
) bool {
return f(message, descriptor)
}

type nopFilter struct{}

func (nopFilter) ShouldValidate(_ protoreflect.Message, _ protoreflect.Descriptor) bool {
return true
}

var _ Filter = nopFilter{}
Loading