From 726728cd2d848eea3f22339e0f815dfe25f47274 Mon Sep 17 00:00:00 2001 From: chankyin Date: Mon, 16 Dec 2024 18:37:20 +0800 Subject: [PATCH 1/2] feat(util/component): add Declared.SkipFutureChanges --- util/component/component.go | 119 ++++++++++++++++++++++++++---------- 1 file changed, 88 insertions(+), 31 deletions(-) diff --git a/util/component/component.go b/util/component/component.go index 231ac8a..c052d22 100644 --- a/util/component/component.go +++ b/util/component/component.go @@ -86,6 +86,7 @@ type Component interface { // Describes this component for dependency resolution. manifest() manifest + canBeMergedInto() bool // Updates the other component upon name duplication. mergeInto(other Component) []*depRequest @@ -356,21 +357,53 @@ type componentImpl[Args any, Options any, Deps any, State any] struct { // Merge arguments of this instantiation into the previous instantiation. // // Optionally requests extra dependencies. - onMerge func(*Args, *Deps, *DepRequests) + mergeIntoFn func(Component) []*depRequest + + // Records whether SkipFutureMerges() was called. + skipFutureMerges bool phase *atomic.Pointer[string] } -type setOnMerge[Args any, Deps any] interface { +type declaredComp[Args any, Deps any] interface { setOnMerge(onMerge func(*Args, *Deps, *DepRequests)) + setSkipFutureMerges() } -// -//nolint:unused // Implements unexported interface setOnMerge, false positive from unused lint +//nolint:unused // Implements unexported interface declaredComp, false positive from unused lint func (impl *componentImpl[Args, Options, Deps, State]) setOnMerge( onMerge func(*Args, *Deps, *DepRequests), ) { - impl.onMerge = onMerge + impl.mergeIntoFn = strictMergeIntoFn[Args, Options, Deps, State](impl.name, onMerge) +} + +//nolint:unused // Implements unexported interface declaredComp, false positive from unused lint +func (impl *componentImpl[Args, Options, Deps, State]) setSkipFutureMerges() { + impl.skipFutureMerges = true +} + +func strictMergeIntoFn[Args any, Options any, Deps any, State any]( + compName string, + onMerge func(*Args, *Deps, *DepRequests), +) func(Component) []*depRequest { + return func(other Component) []*depRequest { + if !other.canBeMergedInto() { + return []*depRequest{} + } + + if other, isValidType := other.(*componentImpl[Args, Options, Deps, State]); isValidType { + reqs := DepRequests{requests: nil} + onMerge(&other.Args, &other.Deps, &reqs) + return reqs.requests + } else { + panic(fmt.Sprintf( + "cannot merge %q [%v, %v, %v, %v] into incompatible Component type %T", + compName, + util.Type[Args](), util.Type[Options](), util.Type[Deps](), util.Type[State](), + other, + )) + } + } } const phaseStarted = "Started" @@ -388,7 +421,7 @@ func isPhaseReady(phase string) bool { // // Refer to package documentation for the description of the arguments. func Declare[Args any, Options any, Deps any, State any, Api any]( - name func(args Args) string, + nameFn func(args Args) string, newOptions func(args Args, fs *flag.FlagSet) Options, newDeps func(args Args, requests *DepRequests) Deps, init func(ctx context.Context, args Args, options Options, deps Deps) (*State, error), @@ -396,6 +429,7 @@ func Declare[Args any, Options any, Deps any, State any, Api any]( api func(d *Data[Args, Options, Deps, State]) Api, ) DeclaredCtor[Args, Deps, Api] { return func(args Args) Declared[Api] { + name := nameFn(args) impl := &componentImpl[Args, Options, Deps, State]{ Data: Data[Args, Options, Deps, State]{ Args: args, @@ -403,13 +437,13 @@ func Declare[Args any, Options any, Deps any, State any, Api any]( Deps: util.Zero[Deps](), State: nil, }, - name: name(args), - optionsFn: newOptions, - depsFn: newDeps, - init: init, - lifecycle: lifecycle, - onMerge: nil, - phase: nil, + name: name, + optionsFn: newOptions, + depsFn: newDeps, + init: init, + lifecycle: lifecycle, + mergeIntoFn: strictMergeIntoFn[Args, Options, Deps, State](name, func(*Args, *Deps, *DepRequests) {}), + phase: nil, } if start := impl.lifecycle.Start; start != nil { @@ -442,12 +476,26 @@ func Declare[Args any, Options any, Deps any, State any, Api any]( type DeclaredCtor[Args any, Deps any, Api any] func(Args) Declared[Api] +// Avoid merging with future dependency requests with the same name, silently dropping them instead. +// +// This is equivalent to [Declared.SkipFutureMerges], only differing by call site +// (before vs after passing args). +func (ctor DeclaredCtor[Args, Deps, Api]) SkipFutureMerges() DeclaredCtor[Args, Deps, Api] { + return func(args Args) Declared[Api] { + decl := ctor(args) + decl.SkipFutureMerges() + return decl + } +} + +// If another component with the same name already exists and was not created from `ApiOnly`, +// merge them together by calling `onMerge` on the states of the preceding instance. func (ctor DeclaredCtor[Args, Deps, Api]) WithMergeFn( onMerge func(*Args, *Deps, *DepRequests), ) DeclaredCtor[Args, Deps, Api] { return func(args Args) Declared[Api] { decl := ctor(args).(*declaredImpl[Args, Deps, Api]) - impl := decl.comp.(setOnMerge[Args, Deps]) + impl := decl.comp.(declaredComp[Args, Deps]) impl.setOnMerge(onMerge) return decl @@ -461,6 +509,17 @@ type Declared[Api any] interface { GetNew() (Component, func() Api) set(comp Component, typedApi func() Api) + + // Do not merge with future dependency requests with the same name, silently dropping them instead. + // Transitive dependencies from the skipped requests will not be processed. + // + // Custom main files may register components with `SkipFutureMerges` at the beginning + // to override the "default" implementation registered afterwards. + // The custom main file must be aware of the actual implementation getting skipped + // and ensure that the overriding implementation (the receiver of this method) replaces it fully. + // + // Always returns the receiver. + SkipFutureMerges() Declared[Api] } // A component declaration. @@ -483,6 +542,11 @@ func (decl *declaredImpl[Args, Deps, Api]) set(comp Component, typedApi func() A decl.api = typedApi } +func (decl *declaredImpl[Args, Deps, Api]) SkipFutureMerges() Declared[Api] { + decl.comp.(declaredComp[Args, Deps]).setSkipFutureMerges() + return decl +} + //nolint:unused // Used from asRawDep, false positive from unused lint type rawDep[Api any] struct { api *func() Api @@ -534,24 +598,12 @@ func (impl *componentImpl[Args, Options, Deps, State]) manifest() manifest { } } -func (impl *componentImpl[Args, Options, Deps, State]) mergeInto(other Component) []*depRequest { - switch other := other.(type) { - case *componentImpl[Args, Options, Deps, State]: - deps := DepRequests{requests: nil} - - if impl.onMerge != nil { - impl.onMerge(&other.Args, &other.Deps, &deps) - } - - return deps.requests - - case emptyComponent: - // do not merge into empty components since the implementation is exclusively determined by the test case - return []*depRequest{} +func (impl *componentImpl[Args, Options, Deps, State]) canBeMergedInto() bool { + return !impl.skipFutureMerges +} - default: - panic(fmt.Sprintf("cannot merge %q (%T) into incompatible Component type %T", impl.name, impl, other)) - } +func (impl *componentImpl[Args, Options, Deps, State]) mergeInto(other Component) []*depRequest { + return impl.mergeIntoFn(other) } func (impl *componentImpl[Args, Options, Deps, State]) AddFlags(fs *flag.FlagSet) { @@ -628,6 +680,11 @@ func (comp emptyComponent) manifest() manifest { } } +func (comp emptyComponent) canBeMergedInto() bool { + // do not merge into empty components since the implementation is exclusively determined by the test case + return false +} + func (comp emptyComponent) mergeInto(other Component) []*depRequest { panic(fmt.Sprintf("component %q is already registered as %T", comp.name, other)) } From 857451312675b9206377bb4f70dcb9ed10cf9bdd Mon Sep 17 00:00:00 2001 From: chankyin Date: Tue, 17 Dec 2024 18:01:53 +0800 Subject: [PATCH 2/2] test(util/component): add unit tests for component deduplication --- util/cmd/cmd.go | 6 +- util/cmd/cmd_test.go | 217 ++++++++++++++++++++++++++++++++++++ util/component/component.go | 83 ++++++++------ 3 files changed, 268 insertions(+), 38 deletions(-) create mode 100644 util/cmd/cmd_test.go diff --git a/util/cmd/cmd.go b/util/cmd/cmd.go index 5447690..a819e3a 100644 --- a/util/cmd/cmd.go +++ b/util/cmd/cmd.go @@ -120,12 +120,16 @@ func tryRun(requests []func(*component.DepRequests)) error { // and does not block until shutdown. // Returns as soon as startup completes to allow the caller to orchestrate integration tests. func MockStartup(ctx context.Context, requests []func(*component.DepRequests)) component.ApiMap { + return MockStartupWithCliArgs(ctx, requests, []string{}) +} + +func MockStartupWithCliArgs(ctx context.Context, requests []func(*component.DepRequests), cliArgs []string) component.ApiMap { components := component.ResolveList(requests) fs := new(pflag.FlagSet) setupFlags(components, fs) - if err := fs.Parse([]string{}); err != nil { + if err := fs.Parse(cliArgs); err != nil { panic(err) } diff --git a/util/cmd/cmd_test.go b/util/cmd/cmd_test.go new file mode 100644 index 0000000..16460e5 --- /dev/null +++ b/util/cmd/cmd_test.go @@ -0,0 +1,217 @@ +// Copyright 2024 The Podseidon Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd_test + +import ( + "context" + "flag" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/utils/ptr" + + "github.com/kubewharf/podseidon/util/cmd" + "github.com/kubewharf/podseidon/util/component" + "github.com/kubewharf/podseidon/util/util" +) + +var TestingT = component.Declare( + func(_ *testing.T) string { return "testing" }, + func(_ *testing.T, _ *flag.FlagSet) util.Empty { return util.Empty{} }, + func(_ *testing.T, _ *component.DepRequests) util.Empty { return util.Empty{} }, + func(_ context.Context, _ *testing.T, _ util.Empty, _ util.Empty) (*util.Empty, error) { + return &util.Empty{}, nil + }, + component.Lifecycle[*testing.T, util.Empty, util.Empty, util.Empty]{ + Start: nil, + Join: nil, + HealthChecks: nil, + }, + func(d *component.Data[*testing.T, util.Empty, util.Empty, util.Empty]) *testing.T { return d.Args }, +) + +var Counter = component.Declare[CounterArgs, CounterOptions, CounterDeps, CounterState, CounterApi]( + func(args CounterArgs) string { return args.Name }, + func(_ CounterArgs, fs *flag.FlagSet) CounterOptions { + return CounterOptions{ + Multiplier: fs.Int("multiplier", 10, ""), + ExpectMutableField: fs.String("expect", "", ""), + } + }, + func(_ CounterArgs, reqs *component.DepRequests) CounterDeps { + return CounterDeps{Testing: component.DepPtr(reqs, TestingT(nil))} + }, + func(_ context.Context, _ CounterArgs, _ CounterOptions, _ CounterDeps) (*CounterState, error) { + return &CounterState{mutableField: []string{}}, nil + }, + component.Lifecycle[CounterArgs, CounterOptions, CounterDeps, CounterState]{ + Start: func(_ context.Context, _ *CounterArgs, options *CounterOptions, deps *CounterDeps, state *CounterState) error { + assert.Equalf( + deps.Testing.Get(), + *options.ExpectMutableField, + strings.Join(state.mutableField, ","), + "start is called strictly after all init", + ) + return nil + }, + Join: nil, + HealthChecks: nil, + }, + func(d *component.Data[CounterArgs, CounterOptions, CounterDeps, CounterState]) CounterApi { + return CounterApi{multiplier: *d.Options.Multiplier, state: d.State} + }, +) + +type CounterArgs struct { + Name string +} + +type CounterOptions struct { + Multiplier *int + ExpectMutableField *string +} + +type CounterDeps struct { + Testing component.Dep[*testing.T] +} + +type CounterState struct { + mutableField []string +} + +type CounterApi struct { + multiplier int + state *CounterState +} + +func (api CounterApi) Update(value int) { + api.state.mutableField = append(api.state.mutableField, fmt.Sprint(api.multiplier*value)) +} + +// the pointer is intentionally added to trigger panic when it is unused despite expected to be used. +func Mutator(mutatorArgs MutatorArgs, counterDepName *string, mutatorDepNames *[]string) component.Declared[util.Empty] { + return component.Declare( + func(args MutatorArgs) string { return args.Name }, + func(_ MutatorArgs, fs *flag.FlagSet) MutatorOptions { + return MutatorOptions{Value: fs.Int("value", 0, "")} + }, + func(_ MutatorArgs, reqs *component.DepRequests) MutatorDeps { + counterDep := component.DepPtr(reqs, Counter(CounterArgs{Name: *counterDepName})) + + for _, mutatorDepName := range *mutatorDepNames { + component.DepPtr(reqs, Mutator(MutatorArgs{Name: mutatorDepName}, nil, nil)) + } + + return MutatorDeps{ + Counter: counterDep, + } + }, + func(_ context.Context, _ MutatorArgs, options MutatorOptions, deps MutatorDeps) (*util.Empty, error) { + deps.Counter.Get().Update(*options.Value) + return &util.Empty{}, nil + }, + component.Lifecycle[MutatorArgs, MutatorOptions, MutatorDeps, util.Empty]{ + Start: nil, + Join: nil, + HealthChecks: nil, + }, + func(*component.Data[MutatorArgs, MutatorOptions, MutatorDeps, util.Empty]) util.Empty { + return util.Empty{} + }, + )(mutatorArgs) +} + +type MutatorArgs struct { + Name string +} + +type MutatorOptions struct { + Value *int +} + +type MutatorDeps struct { + Counter component.Dep[CounterApi] +} + +func TestDeduplication(t *testing.T) { + t.Parallel() + + cmd.MockStartupWithCliArgs( + context.Background(), + []func(*component.DepRequests){ + component.RequireDep(TestingT(t).SkipFutureMerges()), + component.RequireDep(Mutator(MutatorArgs{ + Name: "mutator1", + }, ptr.To("counter1"), ptr.To[[]string](nil))), + component.RequireDep(Mutator(MutatorArgs{ + Name: "mutator2", + }, ptr.To("counter1"), ptr.To([]string{"mutator1"}))), + component.RequireDep(Mutator(MutatorArgs{ + Name: "mutator3", + }, ptr.To("counter2"), ptr.To([]string{"mutator1"}))), + }, + []string{ + "--counter1-multiplier=100", + "--counter1-expect=100,200", + "--counter2-expect=30", + "--mutator1-value=1", + "--mutator2-value=2", + "--mutator3-value=3", + }, + ) +} + +func Cyclic(name string, deps ...func() component.Declared[util.Empty]) component.Declared[util.Empty] { + return component.Declare( + func(_ util.Empty) string { return name }, + func(_ util.Empty, _ *flag.FlagSet) util.Empty { return util.Empty{} }, + func(_ util.Empty, reqs *component.DepRequests) util.Empty { + for _, dep := range deps { + component.DepPtr(reqs, dep()) + } + + return util.Empty{} + }, + func(_ context.Context, _ util.Empty, _ util.Empty, _ util.Empty) (*util.Empty, error) { + return &util.Empty{}, nil + }, + component.Lifecycle[util.Empty, util.Empty, util.Empty, util.Empty]{ + Start: nil, + Join: nil, + HealthChecks: nil, + }, + func(d *component.Data[util.Empty, util.Empty, util.Empty, util.Empty]) util.Empty { return d.Args }, + )(util.Empty{}) +} + +func Cyclic1() component.Declared[util.Empty] { return Cyclic("cyclic1", Cyclic2) } +func Cyclic2() component.Declared[util.Empty] { return Cyclic("cyclic2", Cyclic1) } + +func TestCyclicDependency(t *testing.T) { + t.Parallel() + + assert.PanicsWithValue(t, `cyclic dependency detected: "cyclic1"`, func() { + cmd.MockStartupWithCliArgs( + context.Background(), + []func(*component.DepRequests){ + component.RequireDep(Cyclic1()), + component.RequireDep(Cyclic2()), + }, + []string{}, + ) + }) +} diff --git a/util/component/component.go b/util/component/component.go index c052d22..c33bde1 100644 --- a/util/component/component.go +++ b/util/component/component.go @@ -83,8 +83,10 @@ var ErrRecursiveDependencies = errors.TagErrorf( // // This interface is only useful for lifecycle orchestration and should not be implemented by other packages. type Component interface { + getName() string + // Describes this component for dependency resolution. - manifest() manifest + depRequests() []*depRequest canBeMergedInto() bool // Updates the other component upon name duplication. @@ -106,11 +108,6 @@ type Component interface { RegisterHealthChecks(handler *healthz.Handler, onFail func(name string, err error)) } -type manifest struct { - Name string - Dependencies []*depRequest -} - // A registry of dependencies requested by components. type DepRequests struct { requests []*depRequest @@ -164,7 +161,7 @@ func DepPtr[Api any](requests *DepRequests, base Declared[Api]) Dep[Api] { if !ok { panic(fmt.Sprintf( "Components of types %T and %T declared the same name %q with incompatible APIs %T and %v", - comp, base, comp.manifest().Name, util.Type[Api]().Out(0), reflect.TypeOf(api).Out(0), + comp, base, comp.getName(), util.Type[Api]().Out(0), reflect.TypeOf(api).Out(0), )) } @@ -258,10 +255,14 @@ func resolveRequest( request *depRequest, ) (string, Component, any) { requestComp, requestApi := request.getNew() - manifest := requestComp.manifest() + name := requestComp.getName() // already exists, return previous value - if prev, hasPrev := componentMap[manifest.Name]; hasPrev { + if prev, hasPrev := componentMap[name]; hasPrev { + if prev == nil { + panic(fmt.Sprintf("cyclic dependency detected: %q", name)) + } + deps := requestComp.mergeInto(prev.comp) // resolve incremental dependencies @@ -271,25 +272,27 @@ func resolveRequest( prev.deps.Insert(depName) } - return manifest.Name, prev.comp, prev.apiGetter + return name, prev.comp, prev.apiGetter } + componentMap[name] = nil + requestDeps := sets.New[string]() // new component; resolve dependencies, init and return the instance we got - for _, dep := range manifest.Dependencies { + for _, dep := range requestComp.depRequests() { depName, depComp, depApi := resolveRequest(componentMap, dep) dep.set(depComp, depApi) requestDeps.Insert(depName) } - componentMap[manifest.Name] = &componentMapEntry{ + componentMap[name] = &componentMapEntry{ comp: requestComp, apiGetter: requestApi, deps: requestDeps, } - return manifest.Name, requestComp, requestApi + return name, requestComp, requestApi } // Accessor to interact with components by name. @@ -391,18 +394,20 @@ func strictMergeIntoFn[Args any, Options any, Deps any, State any]( return []*depRequest{} } - if other, isValidType := other.(*componentImpl[Args, Options, Deps, State]); isValidType { - reqs := DepRequests{requests: nil} - onMerge(&other.Args, &other.Deps, &reqs) - return reqs.requests - } else { + otherTyped, isValidType := other.(*componentImpl[Args, Options, Deps, State]) + if !isValidType { panic(fmt.Sprintf( "cannot merge %q [%v, %v, %v, %v] into incompatible Component type %T", compName, util.Type[Args](), util.Type[Options](), util.Type[Deps](), util.Type[State](), - other, + otherTyped, )) } + + reqs := DepRequests{requests: nil} + onMerge(&otherTyped.Args, &otherTyped.Deps, &reqs) + + return reqs.requests } } @@ -437,13 +442,14 @@ func Declare[Args any, Options any, Deps any, State any, Api any]( Deps: util.Zero[Deps](), State: nil, }, - name: name, - optionsFn: newOptions, - depsFn: newDeps, - init: init, - lifecycle: lifecycle, - mergeIntoFn: strictMergeIntoFn[Args, Options, Deps, State](name, func(*Args, *Deps, *DepRequests) {}), - phase: nil, + name: name, + optionsFn: newOptions, + depsFn: newDeps, + init: init, + lifecycle: lifecycle, + mergeIntoFn: strictMergeIntoFn[Args, Options, Deps, State](name, func(*Args, *Deps, *DepRequests) {}), + skipFutureMerges: false, + phase: nil, } if start := impl.lifecycle.Start; start != nil { @@ -484,6 +490,7 @@ func (ctor DeclaredCtor[Args, Deps, Api]) SkipFutureMerges() DeclaredCtor[Args, return func(args Args) Declared[Api] { decl := ctor(args) decl.SkipFutureMerges() + return decl } } @@ -588,14 +595,15 @@ type Lifecycle[Args any, Options any, Deps any, State any] struct { // and returns a non-nil error for unready status. type HealthChecks map[string]func() error -func (impl *componentImpl[Args, Options, Deps, State]) manifest() manifest { +func (impl *componentImpl[Args, Options, Deps, State]) getName() string { + return impl.name +} + +func (impl *componentImpl[Args, Options, Deps, State]) depRequests() []*depRequest { deps := DepRequests{requests: []*depRequest{}} impl.Deps = impl.depsFn(impl.Args, &deps) - return manifest{ - Name: impl.name, - Dependencies: deps.requests, - } + return deps.requests } func (impl *componentImpl[Args, Options, Deps, State]) canBeMergedInto() bool { @@ -673,14 +681,15 @@ type emptyComponent struct { name string } -func (comp emptyComponent) manifest() manifest { - return manifest{ - Name: comp.name, - Dependencies: []*depRequest{}, - } +func (comp emptyComponent) getName() string { + return comp.name +} + +func (emptyComponent) depRequests() []*depRequest { + return nil } -func (comp emptyComponent) canBeMergedInto() bool { +func (emptyComponent) canBeMergedInto() bool { // do not merge into empty components since the implementation is exclusively determined by the test case return false }