From 59e75db2bfbd8b09c7eae02398bdc4f142aa3af7 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:52:15 -0400 Subject: [PATCH] Add editions helper functions for resolving features to protoutil (#283) These helpers, in particular `protoutil.ResolveFeature` and `protoutil.ResolveCustomFeature`, will be used from updated checks for `buf breaking`, which will allow the tool to understand the features-related semantics of the schema. This way, it can correctly report issues with incompatible changes to features in editions source files. And it can also allow changing a file's syntax (like migrating from proto2 or proto3 to editions) as long as there are no actual semantic changes to the schema. --- internal/editions/editions.go | 117 ++++++++- protoutil/editions.go | 140 +++++++++++ protoutil/editions_test.go | 451 ++++++++++++++++++++++++++++++++++ protoutil/protos.go | 5 +- 4 files changed, 707 insertions(+), 6 deletions(-) create mode 100644 protoutil/editions.go create mode 100644 protoutil/editions_test.go diff --git a/internal/editions/editions.go b/internal/editions/editions.go index 2bb7d432..53f361cd 100644 --- a/internal/editions/editions.go +++ b/internal/editions/editions.go @@ -23,9 +23,11 @@ import ( "sync" "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" ) var ( @@ -77,7 +79,7 @@ var _ HasFeatures = (*descriptorpb.MethodOptions)(nil) // override. If there is no overridden value, it returns a zero value. func ResolveFeature( element protoreflect.Descriptor, - field protoreflect.FieldDescriptor, + fields ...protoreflect.FieldDescriptor, ) (protoreflect.Value, error) { for { var features *descriptorpb.FeatureSet @@ -86,9 +88,25 @@ func ResolveFeature( features = withFeatures.GetFeatures() } - msgRef := features.ProtoReflect() - if msgRef.Has(field) { - return msgRef.Get(field), nil + msgRef, err := adaptFeatureSet(features, fields[0]) + if err != nil { + return protoreflect.Value{}, err + } + // Navigate the fields to find the value + var val protoreflect.Value + for i, field := range fields { + if i > 0 { + msgRef = val.Message() + } + if !msgRef.Has(field) { + val = protoreflect.Value{} + break + } + val = msgRef.Get(field) + } + if val.IsValid() { + // All fields were set! + return val, nil } parent := element.Parent() @@ -230,3 +248,94 @@ func GetFeatureDefault(edition descriptorpb.Edition, container protoreflect.Mess } return empty.Get(feature), nil } + +func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescriptor) (protoreflect.Message, error) { + msgRef := msg.ProtoReflect() + if field.IsExtension() { + // Extensions can always be used directly with the feature set, even if + // field.ContainingMessage() != FeatureSetDescriptor. + if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 { + return msgRef, nil + } + // The field is not present, but the message has unrecognized values. So + // let's try to parse the unrecognized bytes, just in case they contain + // this extension. + temp := &descriptorpb.FeatureSet{} + unmarshaler := prototext.UnmarshalOptions{ + AllowPartial: true, + Resolver: resolverForExtension{field}, + } + if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil { + return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err) + } + return temp.ProtoReflect(), nil + } + + if field.ContainingMessage() == FeatureSetDescriptor { + // Known field, not dynamically generated. Can directly use with the feature set. + return msgRef, nil + } + + // If we get here, we have a dynamic field descriptor. We want to copy its + // value into a dynamic message, which requires marshalling/unmarshalling. + msgField := FeatureSetDescriptor.Fields().ByNumber(field.Number()) + // We only need to copy over the unrecognized bytes (if any) + // and the same field (if present). + data := msgRef.GetUnknown() + if msgField != nil && msgRef.Has(msgField) { + subset := &descriptorpb.FeatureSet{} + subset.ProtoReflect().Set(msgField, msgRef.Get(msgField)) + fieldBytes, err := proto.MarshalOptions{AllowPartial: true}.Marshal(subset) + if err != nil { + return nil, fmt.Errorf("failed to marshal FeatureSet field %s to bytes: %w", field.Name(), err) + } + data = append(data, fieldBytes...) + } + if len(data) == 0 { + // No relevant data to copy over, so we can just return + // a zero value message + return dynamicpb.NewMessageType(field.ContainingMessage()).Zero(), nil + } + + other := dynamicpb.NewMessage(field.ContainingMessage()) + // We don't need to use a resolver for this step because we know that + // field is not an extension. And features are not allowed to themselves + // have extensions. + if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(data, other); err != nil { + return nil, fmt.Errorf("failed to marshal FeatureSet field %s to bytes: %w", field.Name(), err) + } + return other, nil +} + +type resolverForExtension struct { + ext protoreflect.ExtensionDescriptor +} + +func (r resolverForExtension) FindMessageByName(_ protoreflect.FullName) (protoreflect.MessageType, error) { + return nil, protoregistry.NotFound +} + +func (r resolverForExtension) FindMessageByURL(_ string) (protoreflect.MessageType, error) { + return nil, protoregistry.NotFound +} + +func (r resolverForExtension) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { + if field == r.ext.FullName() { + return asExtensionType(r.ext), nil + } + return nil, protoregistry.NotFound +} + +func (r resolverForExtension) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { + if message == r.ext.ContainingMessage().FullName() && field == r.ext.Number() { + return asExtensionType(r.ext), nil + } + return nil, protoregistry.NotFound +} + +func asExtensionType(ext protoreflect.ExtensionDescriptor) protoreflect.ExtensionType { + if xtd, ok := ext.(protoreflect.ExtensionTypeDescriptor); ok { + return xtd.Type() + } + return dynamicpb.NewExtensionType(ext) +} diff --git a/protoutil/editions.go b/protoutil/editions.go new file mode 100644 index 00000000..fb21dff6 --- /dev/null +++ b/protoutil/editions.go @@ -0,0 +1,140 @@ +// Copyright 2020-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protoutil + +import ( + "fmt" + + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" + + "github.com/bufbuild/protocompile/internal/editions" +) + +// GetFeatureDefault gets the default value for the given feature and the given +// edition. The given feature must represent a field of the google.protobuf.FeatureSet +// message and must not be an extension. +// +// If the given field is from a dynamically built descriptor (i.e. it's containing +// message descriptor is different from the linked-in descriptor for +// [*descriptorpb.FeatureSet]), the returned value may be a dynamic value. In such +// cases, the value may not be directly usable using [protoreflect.Message.Set] with +// an instance of [*descriptorpb.FeatureSet] and must instead be used with a +// [*dynamicpb.Message]. +// +// To get the default value of a custom feature, use [GetCustomFeatureDefault] +// instead. +func GetFeatureDefault(edition descriptorpb.Edition, feature protoreflect.FieldDescriptor) (protoreflect.Value, error) { + if feature.ContainingMessage().FullName() != editions.FeatureSetDescriptor.FullName() { + return protoreflect.Value{}, fmt.Errorf("feature %s is a field of %s but should be a field of %s", + feature.Name(), feature.ContainingMessage().FullName(), editions.FeatureSetDescriptor.FullName()) + } + var msgType protoreflect.MessageType + if feature.ContainingMessage() == editions.FeatureSetDescriptor { + msgType = editions.FeatureSetType + } else { + msgType = dynamicpb.NewMessageType(feature.ContainingMessage()) + } + return editions.GetFeatureDefault(edition, msgType, feature) +} + +// GetCustomFeatureDefault gets the default value for the given custom feature +// and given edition. A custom feature is a field whose containing message is the +// type of an extension field of google.protobuf.FeatureSet. The given extension +// describes that extension field and message type. The given feature must be a +// field of that extension's message type. +func GetCustomFeatureDefault(edition descriptorpb.Edition, extension protoreflect.ExtensionType, feature protoreflect.FieldDescriptor) (protoreflect.Value, error) { + extDesc := extension.TypeDescriptor() + if extDesc.ContainingMessage().FullName() != editions.FeatureSetDescriptor.FullName() { + return protoreflect.Value{}, fmt.Errorf("extension %s does not extend %s", extDesc.FullName(), editions.FeatureSetDescriptor.FullName()) + } + if extDesc.Message() == nil { + return protoreflect.Value{}, fmt.Errorf("extensions of %s should be messages; %s is instead %s", + editions.FeatureSetDescriptor.FullName(), extDesc.FullName(), extDesc.Kind().String()) + } + if feature.IsExtension() { + return protoreflect.Value{}, fmt.Errorf("feature %s is an extension, but feature extension %s may not itself have extensions", + feature.FullName(), extDesc.FullName()) + } + if feature.ContainingMessage().FullName() != extDesc.Message().FullName() { + return protoreflect.Value{}, fmt.Errorf("feature %s is a field of %s but should be a field of %s", + feature.Name(), feature.ContainingMessage().FullName(), extDesc.Message().FullName()) + } + if feature.ContainingMessage() != extDesc.Message() { + return protoreflect.Value{}, fmt.Errorf("feature %s has a different message descriptor from the given extension type for %s", + feature.Name(), extDesc.Message().FullName()) + } + return editions.GetFeatureDefault(edition, extension.Zero().Message().Type(), feature) +} + +// ResolveFeature resolves a feature for the given descriptor. +// +// If the given element is in a proto2 or proto3 syntax file, this skips +// resolution and just returns the relevant default (since such files are not +// allowed to override features). If neither the given element nor any of its +// ancestors override the given feature, the relevant default is returned. +// +// This has the same caveat as GetFeatureDefault if the given feature is from a +// dynamically built descriptor. +func ResolveFeature(element protoreflect.Descriptor, feature protoreflect.FieldDescriptor) (protoreflect.Value, error) { + edition := editions.GetEdition(element) + defaultVal, err := GetFeatureDefault(edition, feature) + if err != nil { + return protoreflect.Value{}, err + } + return resolveFeature(edition, defaultVal, element, feature) +} + +// ResolveCustomFeature resolves a custom feature for the given extension and +// field descriptor. +// +// The given extension must be an extension of google.protobuf.FeatureSet that +// represents a non-repeated message value. The given feature is a field in +// that extension's message type. +// +// If the given element is in a proto2 or proto3 syntax file, this skips +// resolution and just returns the relevant default (since such files are not +// allowed to override features). If neither the given element nor any of its +// ancestors override the given feature, the relevant default is returned. +func ResolveCustomFeature(element protoreflect.Descriptor, extension protoreflect.ExtensionType, feature protoreflect.FieldDescriptor) (protoreflect.Value, error) { + edition := editions.GetEdition(element) + defaultVal, err := GetCustomFeatureDefault(edition, extension, feature) + if err != nil { + return protoreflect.Value{}, err + } + return resolveFeature(edition, defaultVal, element, extension.TypeDescriptor(), feature) +} + +func resolveFeature( + edition descriptorpb.Edition, + defaultVal protoreflect.Value, + element protoreflect.Descriptor, + fields ...protoreflect.FieldDescriptor, +) (protoreflect.Value, error) { + if edition == descriptorpb.Edition_EDITION_PROTO2 || edition == descriptorpb.Edition_EDITION_PROTO3 { + // these syntax levels can't specify features, so we can short-circuit the search + // through the descriptor hierarchy for feature overrides + return defaultVal, nil + } + val, err := editions.ResolveFeature(element, fields...) + if err != nil { + return protoreflect.Value{}, err + } + if val.IsValid() { + return val, nil + } + return defaultVal, nil +} diff --git a/protoutil/editions_test.go b/protoutil/editions_test.go new file mode 100644 index 00000000..00061f24 --- /dev/null +++ b/protoutil/editions_test.go @@ -0,0 +1,451 @@ +// Copyright 2020-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protoutil_test + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" + + "github.com/bufbuild/protocompile" + "github.com/bufbuild/protocompile/internal/editions" + "github.com/bufbuild/protocompile/linker" + "github.com/bufbuild/protocompile/protoutil" +) + +func TestMain(m *testing.M) { + // Enable just for tests. + editions.AllowEditions = true + status := m.Run() + os.Exit(status) +} + +func TestResolveFeature(t *testing.T) { + t.Parallel() + testResolveFeature(t) +} + +func TestResolveFeature_Dynamic(t *testing.T) { + t.Parallel() + descriptorProto := protodesc.ToFileDescriptorProto( + (*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile(), + ) + // Provide our own version of descriptor.proto, so the FeatureSet + // descriptor will be dynamically built. + testResolveFeature(t, descriptorProto) + + // Also test with an extra field (not recognized by descriptorpb). + t.Run("editions-new-field", func(t *testing.T) { + t.Parallel() + var found bool + descriptorProto := proto.Clone(descriptorProto).(*descriptorpb.FileDescriptorProto) //nolint:errcheck + for _, msg := range descriptorProto.MessageType { + if msg.GetName() == "FeatureSet" { + msg.Field = append(msg.Field, &descriptorpb.FieldDescriptorProto{ + Name: proto.String("fubar"), + Number: proto.Int32(8888), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum(), + TypeName: proto.String(".google.protobuf.FeatureSet.Fubar"), + JsonName: proto.String("fubar"), + Options: &descriptorpb.FieldOptions{ + Targets: []descriptorpb.FieldOptions_OptionTargetType{ + descriptorpb.FieldOptions_TARGET_TYPE_FILE, + descriptorpb.FieldOptions_TARGET_TYPE_MESSAGE, + descriptorpb.FieldOptions_TARGET_TYPE_SERVICE, + }, + EditionDefaults: []*descriptorpb.FieldOptions_EditionDefault{ + { + Edition: descriptorpb.Edition_EDITION_PROTO2.Enum(), + Value: proto.String("FOO"), + }, + { + Edition: descriptorpb.Edition_EDITION_2023.Enum(), + Value: proto.String("BAR"), + }, + }, + }, + }) + msg.EnumType = append(msg.EnumType, &descriptorpb.EnumDescriptorProto{ + Name: proto.String("Fubar"), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("FUBAR_UNKNOWN"), + Number: proto.Int32(0), + }, + { + Name: proto.String("FOO"), + Number: proto.Int32(1), + }, + { + Name: proto.String("BAR"), + Number: proto.Int32(2), + }, + { + Name: proto.String("BAZ"), + Number: proto.Int32(3), + }, + }, + }) + found = true + break + } + } + require.True(t, found) + + sourceResolver := &protocompile.SourceResolver{ + Accessor: protocompile.SourceAccessorFromMap(map[string]string{ + "test.proto": ` + edition = "2023"; + message Foo { + option features.fubar = FOO; + } + message Bar { + // default feature value, which is Bar + } + service Baz { + option features.fubar = BAZ; + rpc Do(Foo) returns (Bar); + }`, + }), + } + file, featureSetDescriptor := compileFile(t, "test.proto", sourceResolver, descriptorProto) + + feature := featureSetDescriptor.Fields().ByName("fubar") + val, err := protoutil.ResolveFeature(file, feature) + require.NoError(t, err) + // Value is the default for edition 2023 + require.Equal(t, protoreflect.EnumNumber(2), val.Enum()) + + elem := file.FindDescriptorByName("Foo") + require.NotNil(t, elem) + val, err = protoutil.ResolveFeature(elem, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(1), val.Enum()) + + elem = file.FindDescriptorByName("Bar") + require.NotNil(t, elem) + val, err = protoutil.ResolveFeature(elem, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(2), val.Enum()) + + elem = file.FindDescriptorByName("Baz") + require.NotNil(t, elem) + val, err = protoutil.ResolveFeature(elem, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(3), val.Enum()) + }) +} + +func testResolveFeature(t *testing.T, deps ...*descriptorpb.FileDescriptorProto) { + t.Run("proto2", func(t *testing.T) { + t.Parallel() + file, featureSetDescriptor := compileFile(t, "desc_test1.proto", nil, deps...) + + feature := featureSetDescriptor.Fields().ByName("json_format") + val, err := protoutil.ResolveFeature(file, feature) + require.NoError(t, err) + // Value is the default for proto2 + require.Equal(t, descriptorpb.FeatureSet_LEGACY_BEST_EFFORT.Number(), val.Enum()) + + // Same value for a field therein + field := file.FindDescriptorByName("testprotos.AnotherTestMessage.RockNRoll.beatles") + require.NotNil(t, field) + val, err = protoutil.ResolveFeature(field, feature) + require.NoError(t, err) + require.Equal(t, descriptorpb.FeatureSet_LEGACY_BEST_EFFORT.Number(), val.Enum()) + }) + + t.Run("proto3", func(t *testing.T) { + t.Parallel() + file, featureSetDescriptor := compileFile(t, "desc_test_proto3.proto", nil, deps...) + + feature := featureSetDescriptor.Fields().ByName("utf8_validation") + val, err := protoutil.ResolveFeature(file, feature) + require.NoError(t, err) + // Value is the default for proto3 + require.Equal(t, descriptorpb.FeatureSet_VERIFY.Number(), val.Enum()) + + // Same value for a field therein + field := file.FindDescriptorByName("testprotos.TestRequest.FlagsEntry.value") + require.NotNil(t, field) + val, err = protoutil.ResolveFeature(field, feature) + require.NoError(t, err) + require.Equal(t, descriptorpb.FeatureSet_VERIFY.Number(), val.Enum()) + }) + + t.Run("editions-defaults", func(t *testing.T) { + t.Parallel() + file, featureSetDescriptor := compileFile(t, "editions/all_default_features.proto", nil, deps...) + + feature := featureSetDescriptor.Fields().ByName("repeated_field_encoding") + val, err := protoutil.ResolveFeature(file, feature) + require.NoError(t, err) + // Value is the default for editions + require.Equal(t, descriptorpb.FeatureSet_PACKED.Number(), val.Enum()) + + // Same value for a field therein + field := file.FindDescriptorByName("foo.bar.Foo.Bar.abc") + require.NotNil(t, field) + val, err = protoutil.ResolveFeature(field, feature) + require.NoError(t, err) + require.Equal(t, descriptorpb.FeatureSet_PACKED.Number(), val.Enum()) + }) + + t.Run("editions-overrides", func(t *testing.T) { + t.Parallel() + file, featureSetDescriptor := compileFile(t, "editions/features_with_overrides.proto", nil, deps...) + + feature := featureSetDescriptor.Fields().ByName("field_presence") + val, err := protoutil.ResolveFeature(file, feature) + require.NoError(t, err) + // Value is from explicit file-wide default + require.Equal(t, descriptorpb.FeatureSet_IMPLICIT.Number(), val.Enum()) + + // Overridden value for a field therein + field := file.FindDescriptorByName("foo.bar.baz.Bar.left") + require.NotNil(t, field) + val, err = protoutil.ResolveFeature(field, feature) + require.NoError(t, err) + require.Equal(t, descriptorpb.FeatureSet_EXPLICIT.Number(), val.Enum()) + + // Let's check another feature + feature = featureSetDescriptor.Fields().ByName("utf8_validation") + val, err = protoutil.ResolveFeature(file, feature) + require.NoError(t, err) + // Value is the default for editions + require.Equal(t, descriptorpb.FeatureSet_VERIFY.Number(), val.Enum()) + + field = file.FindDescriptorByName("foo.bar.baz.Foo.JklEntry.key") + require.NotNil(t, field) + val, err = protoutil.ResolveFeature(field, feature) + require.NoError(t, err) + require.Equal(t, descriptorpb.FeatureSet_NONE.Number(), val.Enum()) + }) +} + +func TestResolveCustomFeature(t *testing.T) { + t.Parallel() + descriptorProto := protodesc.ToFileDescriptorProto( + (*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile(), + ) + optionsSource := ` + edition = "2023"; + package test; + import "google/protobuf/descriptor.proto"; + extend google.protobuf.FeatureSet { + CustomFeatures custom = 9996; + } + message CustomFeatures { + bool encabulate = 1 [ + targets=TARGET_TYPE_FILE, + targets=TARGET_TYPE_FIELD, + edition_defaults ={ + edition: EDITION_PROTO2 + value: "true" + }, + edition_defaults = { + edition: EDITION_2023 + value: "false" + } + ]; + Frob nitz = 2 [ + targets=TARGET_TYPE_FILE, + targets=TARGET_TYPE_MESSAGE, + edition_defaults = { + edition: EDITION_PROTO2 + value: "POWER_CYCLE" + }, + edition_defaults = { + edition: EDITION_PROTO3 + value: "RTFM" + }, + edition_defaults = { + edition: EDITION_2023 + value: "ID_10_T" + } + ]; + } + enum Frob { + FROB_UNKNOWN = 0; + POWER_CYCLE = 1; + RTFM = 2; + ID_10_T = 3; + } + ` + + // We can do proto2 and proto3 in the same way since they + // can't override feature values. + testCases := []struct { + syntax string + expectedEncabulate bool + expectedNitz int32 + }{ + { + syntax: "proto2", + expectedEncabulate: true, + expectedNitz: 1, + }, + { + syntax: "proto3", + expectedEncabulate: true, + expectedNitz: 2, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.syntax, func(t *testing.T) { + t.Parallel() + sourceResolver := &protocompile.SourceResolver{ + Accessor: protocompile.SourceAccessorFromMap(map[string]string{ + "options.proto": optionsSource, + "test.proto": ` + syntax = "` + testCase.syntax + `"; + import "options.proto"; + message Foo { + }`, + }), + } + file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto) + optionsFile := file.FindImportByPath("options.proto") + extType := dynamicpb.NewExtensionType(optionsFile.FindDescriptorByName("test.custom").(protoreflect.ExtensionDescriptor)) + feature := optionsFile.FindDescriptorByName("test.CustomFeatures.encabulate").(protoreflect.FieldDescriptor) //nolint:errcheck + + val, err := protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedEncabulate, val.Bool()) + + // Same value for an element therein + elem := file.FindDescriptorByName("Foo") + require.NotNil(t, elem) + val, err = protoutil.ResolveCustomFeature(elem, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedEncabulate, val.Bool()) + + // Check the other feature field, too + feature = optionsFile.FindDescriptorByName("test.CustomFeatures.nitz").(protoreflect.FieldDescriptor) //nolint:errcheck + val, err = protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(testCase.expectedNitz), val.Enum()) + + val, err = protoutil.ResolveCustomFeature(elem, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(testCase.expectedNitz), val.Enum()) + }) + } + + t.Run("editions", func(t *testing.T) { + t.Parallel() + sourceResolver := &protocompile.SourceResolver{ + Accessor: protocompile.SourceAccessorFromMap(map[string]string{ + "options.proto": optionsSource, + "test.proto": ` + edition = "2023"; + import "options.proto"; + message Foo { + } + message Bar { + option features.(test.custom).nitz = RTFM; + string name = 1 [ + features.(test.custom).encabulate = true + ]; + bytes extra = 2; + }`, + }), + } + file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto) + optionsFile := file.FindImportByPath("options.proto") + extType := dynamicpb.NewExtensionType(optionsFile.FindDescriptorByName("test.custom").(protoreflect.ExtensionDescriptor)) + feature := optionsFile.FindDescriptorByName("test.CustomFeatures.encabulate").(protoreflect.FieldDescriptor) //nolint:errcheck + + val, err := protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + // Default for edition + require.False(t, val.Bool()) + + // Override + field := file.FindDescriptorByName("Bar.name") + require.NotNil(t, field) + val, err = protoutil.ResolveCustomFeature(field, extType, feature) + require.NoError(t, err) + require.True(t, val.Bool()) + + // Check the other feature field, too + feature = optionsFile.FindDescriptorByName("test.CustomFeatures.nitz").(protoreflect.FieldDescriptor) //nolint:errcheck + val, err = protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(3), val.Enum()) + + val, err = protoutil.ResolveCustomFeature(field, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(2), val.Enum()) + }) +} + +func compileFile(t *testing.T, filename string, sources *protocompile.SourceResolver, deps ...*descriptorpb.FileDescriptorProto) (result linker.File, featureSet protoreflect.MessageDescriptor) { + t.Helper() + if sources == nil { + sources = &protocompile.SourceResolver{ + ImportPaths: []string{"../internal/testdata"}, + } + } + resolver := protocompile.Resolver(sources) + if len(deps) > 0 { + resolver = addDepsToResolver(resolver, deps...) + } + compiler := &protocompile.Compiler{Resolver: resolver} + names := make([]string, len(deps)+1) + names[0] = filename + for i := range deps { + names[i+1] = deps[i].GetName() + } + files, err := compiler.Compile(context.Background(), names...) + require.NoError(t, err) + + // See if compile included version of google.protobuf.FeatureSet + var featureSetDescriptor protoreflect.MessageDescriptor + desc, err := files.AsResolver().FindDescriptorByName(editions.FeatureSetDescriptor.FullName()) + if err != nil { + featureSetDescriptor = editions.FeatureSetDescriptor + } else { + featureSetDescriptor = desc.(protoreflect.MessageDescriptor) //nolint:errcheck + } + + return files[0], featureSetDescriptor +} + +func addDepsToResolver(resolver protocompile.Resolver, deps ...*descriptorpb.FileDescriptorProto) protocompile.Resolver { + if len(deps) == 0 { + return resolver + } + depsByPath := make(map[string]*descriptorpb.FileDescriptorProto, len(deps)) + for _, dep := range deps { + depsByPath[dep.GetName()] = dep + } + return protocompile.ResolverFunc(func(path string) (protocompile.SearchResult, error) { + file := depsByPath[path] + if file != nil { + return protocompile.SearchResult{Proto: file}, nil + } + return resolver.FindFileByPath(path) + }) +} diff --git a/protoutil/protos.go b/protoutil/protos.go index 4f0f3629..9c559993 100644 --- a/protoutil/protos.go +++ b/protoutil/protos.go @@ -14,11 +14,12 @@ // Package protoutil contains useful functions for interacting with descriptors. // For now these include only functions for efficiently converting descriptors -// produced by the compiler to descriptor protos. +// produced by the compiler to descriptor protos and functions for resolving +// "features" (a core concept of Protobuf Editions). // // Despite the fact that descriptor protos are mutable, calling code should NOT // mutate any of the protos returned from this package. For efficiency, some -// protos returned from this package may be part of internal state of a compiler +// values returned from this package may reference internal state of a compiler // result, and mutating the proto could corrupt or invalidate parts of that // result. package protoutil