From b4164b2fbb7ffae8f115e6f4ad6086c2edae0e5d Mon Sep 17 00:00:00 2001 From: hovsep Date: Wed, 30 Oct 2024 01:10:08 +0200 Subject: [PATCH 1/2] Increase coverage --- component/component.go | 55 ++++++++------ component/component_test.go | 146 +++++++++++++++++++++++++++++------- port/collection.go | 18 ++--- port/collection_test.go | 46 ++++++++---- port/port.go | 18 ++--- port/port_test.go | 142 +++++++++++++++++++++++++++++------ signal/group_test.go | 3 +- 7 files changed, 320 insertions(+), 108 deletions(-) diff --git a/component/component.go b/component/component.go index 5fffd7f..5833b54 100644 --- a/component/component.go +++ b/component/component.go @@ -42,6 +42,25 @@ func (c *Component) WithDescription(description string) *Component { return c } +// withInputPorts sets input ports collection +func (c *Component) withInputPorts(collection *port.Collection) *Component { + if collection.HasChainError() { + return c.WithChainError(collection.ChainError()) + } + c.inputs = collection + return c +} + +// withOutputPorts sets input ports collection +func (c *Component) withOutputPorts(collection *port.Collection) *Component { + if collection.HasChainError() { + return c.WithChainError(collection.ChainError()) + } + + c.outputs = collection + return c +} + // WithInputs ads input ports func (c *Component) WithInputs(portNames ...string) *Component { if c.HasChainError() { @@ -52,8 +71,7 @@ func (c *Component) WithInputs(portNames ...string) *Component { if err != nil { return c.WithChainError(err) } - c.inputs = c.Inputs().With(ports...) - return c + return c.withInputPorts(c.Inputs().With(ports...)) } // WithOutputs adds output ports @@ -61,12 +79,12 @@ func (c *Component) WithOutputs(portNames ...string) *Component { if c.HasChainError() { return c } + ports, err := port.NewGroup(portNames...).Ports() if err != nil { return c.WithChainError(err) } - c.outputs = c.Outputs().With(ports...) - return c + return c.withOutputPorts(c.Outputs().With(ports...)) } // WithInputsIndexed creates multiple prefixed ports @@ -75,8 +93,7 @@ func (c *Component) WithInputsIndexed(prefix string, startIndex int, endIndex in return c } - c.inputs = c.Inputs().WithIndexed(prefix, startIndex, endIndex) - return c + return c.withInputPorts(c.Inputs().WithIndexed(prefix, startIndex, endIndex)) } // WithOutputsIndexed creates multiple prefixed ports @@ -85,8 +102,7 @@ func (c *Component) WithOutputsIndexed(prefix string, startIndex int, endIndex i return c } - c.outputs = c.Outputs().WithIndexed(prefix, startIndex, endIndex) - return c + return c.withOutputPorts(c.Outputs().WithIndexed(prefix, startIndex, endIndex)) } // WithActivationFunc sets activation function @@ -130,8 +146,7 @@ func (c *Component) Outputs() *port.Collection { func (c *Component) OutputByName(name string) *port.Port { outputPort := c.Outputs().ByName(name) if outputPort.HasChainError() { - c.SetChainError(outputPort.ChainError()) - return nil + return port.New("").WithChainError(outputPort.ChainError()) } return outputPort } @@ -143,7 +158,7 @@ func (c *Component) InputByName(name string) *port.Port { } inputPort := c.Inputs().ByName(name) if inputPort.HasChainError() { - c.SetChainError(inputPort.ChainError()) + return port.New("").WithChainError(inputPort.ChainError()) } return inputPort } @@ -161,20 +176,12 @@ func (c *Component) hasActivationFunction() bool { // @TODO: hide this method from user // @TODO: can we remove named return ? func (c *Component) MaybeActivate() (activationResult *ActivationResult) { - //Bubble up chain errors from ports - for _, p := range c.Inputs().PortsOrNil() { - if p.HasChainError() { - c.Inputs().SetChainError(p.ChainError()) - c.SetChainError(c.Inputs().ChainError()) - break - } + if c.Inputs().HasChainError() { + c.SetChainError(c.Inputs().ChainError()) } - for _, p := range c.Outputs().PortsOrNil() { - if p.HasChainError() { - c.Outputs().SetChainError(p.ChainError()) - c.SetChainError(c.Outputs().ChainError()) - break - } + + if c.Outputs().HasChainError() { + c.SetChainError(c.Outputs().ChainError()) } if c.HasChainError() { diff --git a/component/component_test.go b/component/component_test.go index bdc4130..ae0168a 100644 --- a/component/component_test.go +++ b/component/component_test.go @@ -41,44 +41,42 @@ func TestNewComponent(t *testing.T) { } func TestComponent_FlushOutputs(t *testing.T) { - sink := port.New("sink") - - componentWithNoOutputs := New("c1") - componentWithCleanOutputs := New("c1").WithOutputs("o1", "o2") - - componentWithAllOutputsSet := New("c1").WithOutputs("o1", "o2") - componentWithAllOutputsSet.Outputs().ByNames("o1").PutSignals(signal.New(777)) - componentWithAllOutputsSet.Outputs().ByNames("o2").PutSignals(signal.New(888)) - componentWithAllOutputsSet.Outputs().ByNames("o1", "o2").PipeTo(sink) - tests := []struct { - name string - component *Component - destPort *port.Port //Where the component flushes ALL it's inputs - assertions func(t *testing.T, componentAfterFlush *Component, destPort *port.Port) + name string + getComponent func() *Component + assertions func(t *testing.T, componentAfterFlush *Component) }{ { - name: "no outputs", - component: componentWithNoOutputs, - destPort: nil, - assertions: func(t *testing.T, componentAfterFlush *Component, destPort *port.Port) { + name: "no outputs", + getComponent: func() *Component { + return New("c1") + }, + assertions: func(t *testing.T, componentAfterFlush *Component) { assert.NotNil(t, componentAfterFlush.Outputs()) assert.Zero(t, componentAfterFlush.Outputs().Len()) }, }, { - name: "output has no signal set", - component: componentWithCleanOutputs, - destPort: nil, - assertions: func(t *testing.T, componentAfterFlush *Component, destPort *port.Port) { + name: "output has no signal set", + getComponent: func() *Component { + return New("c1").WithOutputs("o1", "o2") + }, + assertions: func(t *testing.T, componentAfterFlush *Component) { assert.False(t, componentAfterFlush.Outputs().AnyHasSignals()) }, }, { - name: "happy path", - component: componentWithAllOutputsSet, - destPort: sink, - assertions: func(t *testing.T, componentAfterFlush *Component, destPort *port.Port) { + name: "happy path", + getComponent: func() *Component { + sink := port.New("sink") + c := New("c1").WithOutputs("o1", "o2") + c.Outputs().ByNames("o1").PutSignals(signal.New(777)) + c.Outputs().ByNames("o2").PutSignals(signal.New(888)) + c.Outputs().ByNames("o1", "o2").PipeTo(sink) + return c + }, + assertions: func(t *testing.T, componentAfterFlush *Component) { + destPort := componentAfterFlush.OutputByName("o1").Pipes().PortsOrNil()[0] allPayloads, err := destPort.AllSignalsPayloads() assert.NoError(t, err) assert.Contains(t, allPayloads, 777) @@ -88,11 +86,25 @@ func TestComponent_FlushOutputs(t *testing.T) { assert.False(t, componentAfterFlush.Outputs().AnyHasSignals()) }, }, + { + name: "with chain error", + getComponent: func() *Component { + sink := port.New("sink") + c := New("c").WithOutputs("o1").WithChainError(errors.New("some error")) + //Lines below are ignored as error immediately propagates up to component level + c.Outputs().ByName("o1").PipeTo(sink) + c.Outputs().ByName("o1").PutSignals(signal.New("signal from component with chain error")) + return c + }, + assertions: func(t *testing.T, componentAfterFlush *Component) { + assert.False(t, componentAfterFlush.OutputByName("o1").HasPipes()) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.component.FlushOutputs() - tt.assertions(t, tt.component, tt.destPort) + componentAfter := tt.getComponent().FlushOutputs() + tt.assertions(t, componentAfter) }) } } @@ -481,6 +493,28 @@ func TestComponent_MaybeActivate(t *testing.T) { err: NewErrWaitForInputs(true), }, }, + { + name: "with chain error from input port", + getComponent: func() *Component { + c := New("c").WithInputs("i1").WithOutputs("o1") + c.Inputs().With(port.New("p").WithChainError(errors.New("some error"))) + return c + }, + wantActivationResult: NewActivationResult("c"). + WithActivationCode(ActivationCodeUndefined). + WithChainError(errors.New("some error")), + }, + { + name: "with chain error from output port", + getComponent: func() *Component { + c := New("c").WithInputs("i1").WithOutputs("o1") + c.Outputs().With(port.New("p").WithChainError(errors.New("some error"))) + return c + }, + wantActivationResult: NewActivationResult("c"). + WithActivationCode(ActivationCodeUndefined). + WithChainError(errors.New("some error")), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -630,3 +664,59 @@ func TestComponent_WithLabels(t *testing.T) { }) } } + +func TestComponent_ShortcutMethods(t *testing.T) { + t.Run("InputByName", func(t *testing.T) { + c := New("c").WithInputs("a", "b", "c") + assert.Equal(t, port.New("b"), c.InputByName("b")) + }) + + t.Run("OutputByName", func(t *testing.T) { + c := New("c").WithOutputs("a", "b", "c") + assert.Equal(t, port.New("b"), c.OutputByName("b")) + }) +} + +func TestComponent_ClearInputs(t *testing.T) { + tests := []struct { + name string + getComponent func() *Component + assertions func(t *testing.T, componentAfter *Component) + }{ + { + name: "no side effects", + getComponent: func() *Component { + return New("c").WithInputs("i1").WithOutputs("o1") + }, + assertions: func(t *testing.T, componentAfter *Component) { + assert.Equal(t, 1, componentAfter.Inputs().Len()) + assert.Equal(t, 1, componentAfter.Outputs().Len()) + assert.False(t, componentAfter.Inputs().AnyHasSignals()) + assert.False(t, componentAfter.Outputs().AnyHasSignals()) + }, + }, + { + name: "only inputs are cleared", + getComponent: func() *Component { + c := New("c").WithInputs("i1").WithOutputs("o1") + c.Inputs().ByName("i1").PutSignals(signal.New(10)) + c.Outputs().ByName("o1").PutSignals(signal.New(20)) + return c + }, + assertions: func(t *testing.T, componentAfter *Component) { + assert.Equal(t, 1, componentAfter.Inputs().Len()) + assert.Equal(t, 1, componentAfter.Outputs().Len()) + assert.False(t, componentAfter.Inputs().AnyHasSignals()) + assert.True(t, componentAfter.Outputs().ByName("o1").HasSignals()) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + componentAfter := tt.getComponent().ClearInputs() + if tt.assertions != nil { + tt.assertions(t, componentAfter) + } + }) + } +} diff --git a/port/collection.go b/port/collection.go index 09ab582..6dde735 100644 --- a/port/collection.go +++ b/port/collection.go @@ -27,10 +27,11 @@ func NewCollection() *Collection { // ByName returns a port by its name func (collection *Collection) ByName(name string) *Port { if collection.HasChainError() { - return nil + return New("").WithChainError(collection.ChainError()) } port, ok := collection.ports[name] if !ok { + collection.SetChainError(ErrPortNotFoundInCollection) return New("").WithChainError(ErrPortNotFoundInCollection) } return port @@ -39,7 +40,7 @@ func (collection *Collection) ByName(name string) *Port { // ByNames returns multiple ports by their names func (collection *Collection) ByNames(names ...string) *Collection { if collection.HasChainError() { - return collection + return NewCollection().WithChainError(collection.ChainError()) } selectedPorts := NewCollection() @@ -84,10 +85,9 @@ func (collection *Collection) AllHaveSignals() bool { } // PutSignals adds buffer to every port in collection -// @TODO: return collection func (collection *Collection) PutSignals(signals ...*signal.Signal) *Collection { if collection.HasChainError() { - return collection + return NewCollection().WithChainError(collection.ChainError()) } for _, p := range collection.ports { @@ -115,7 +115,7 @@ func (collection *Collection) Clear() *Collection { // Flush flushes all ports in collection func (collection *Collection) Flush() *Collection { if collection.HasChainError() { - return collection + return NewCollection().WithChainError(collection.ChainError()) } for _, p := range collection.ports { @@ -144,15 +144,15 @@ func (collection *Collection) PipeTo(destPorts ...*Port) *Collection { // With adds ports to collection and returns it func (collection *Collection) With(ports ...*Port) *Collection { if collection.HasChainError() { - return collection + return NewCollection().WithChainError(collection.ChainError()) } for _, port := range ports { - collection.ports[port.Name()] = port - if port.HasChainError() { return collection.WithChainError(port.ChainError()) } + + collection.ports[port.Name()] = port } return collection @@ -161,7 +161,7 @@ func (collection *Collection) With(ports ...*Port) *Collection { // WithIndexed creates ports with names like "o1","o2","o3" and so on func (collection *Collection) WithIndexed(prefix string, startIndex int, endIndex int) *Collection { if collection.HasChainError() { - return collection + return NewCollection().WithChainError(collection.ChainError()) } indexedPorts, err := NewIndexedGroup(prefix, startIndex, endIndex).Ports() diff --git a/port/collection_test.go b/port/collection_test.go index f7518a0..99db3cb 100644 --- a/port/collection_test.go +++ b/port/collection_test.go @@ -98,7 +98,15 @@ func TestCollection_ByName(t *testing.T) { args: args{ name: "p3", }, - want: New("").WithChainError(errors.New("port not found")), + want: New("").WithChainError(ErrPortNotFoundInCollection), + }, + { + name: "with chain error", + collection: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...).WithChainError(errors.New("some error")), + args: args{ + name: "p1", + }, + want: New("").WithChainError(errors.New("some error")), }, } for _, tt := range tests { @@ -114,47 +122,55 @@ func TestCollection_ByNames(t *testing.T) { names []string } tests := []struct { - name string - ports *Collection - args args - want *Collection + name string + collection *Collection + args args + want *Collection }{ { - name: "single port found", - ports: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), + name: "single port found", + collection: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), args: args{ names: []string{"p1"}, }, want: NewCollection().With(New("p1")), }, { - name: "multiple ports found", - ports: NewCollection().With(NewGroup("p1", "p2", "p3", "p4").PortsOrNil()...), + name: "multiple ports found", + collection: NewCollection().With(NewGroup("p1", "p2", "p3", "p4").PortsOrNil()...), args: args{ names: []string{"p1", "p2"}, }, want: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), }, { - name: "single port not found", - ports: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), + name: "single port not found", + collection: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), args: args{ names: []string{"p7"}, }, want: NewCollection(), }, { - name: "some ports not found", - ports: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), + name: "some ports not found", + collection: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), args: args{ names: []string{"p1", "p2", "p3"}, }, want: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...), }, + { + name: "with chain error", + collection: NewCollection().With(NewGroup("p1", "p2").PortsOrNil()...).WithChainError(errors.New("some error")), + args: args{ + names: []string{"p1", "p2"}, + }, + want: NewCollection().WithChainError(errors.New("some error")), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.ports.ByNames(tt.args.names...)) + assert.Equal(t, tt.want, tt.collection.ByNames(tt.args.names...)) }) } } @@ -246,7 +262,7 @@ func TestCollection_Flush(t *testing.T) { assert.False(t, collection.ByName("src").HasSignals()) for _, destPort := range collection.ByName("src").Pipes().PortsOrNil() { assert.Equal(t, destPort.Buffer().Len(), 3) - allPayloads, err := destPort.Buffer().AllPayloads() + allPayloads, err := destPort.AllSignalsPayloads() assert.NoError(t, err) assert.Contains(t, allPayloads, 1) assert.Contains(t, allPayloads, 2) diff --git a/port/port.go b/port/port.go index 2d2223b..5e6652f 100644 --- a/port/port.go +++ b/port/port.go @@ -27,6 +27,7 @@ func New(name string) *Port { } // Buffer getter +// @TODO: maybe we can hide this and return signals to user code func (p *Port) Buffer() *signal.Group { if p.HasChainError() { return p.buffer.WithChainError(p.ChainError()) @@ -43,12 +44,13 @@ func (p *Port) Pipes() *Group { return p.pipes } -// setSignals sets buffer field -func (p *Port) setSignals(signals *signal.Group) { - if signals.HasChainError() { - p.SetChainError(signals.ChainError()) +// withBuffer sets buffer field +func (p *Port) withBuffer(buffer *signal.Group) *Port { + if buffer.HasChainError() { + return p.WithChainError(buffer.ChainError()) } - p.buffer = signals + p.buffer = buffer + return p } // PutSignals adds signals to buffer @@ -57,8 +59,7 @@ func (p *Port) PutSignals(signals ...*signal.Signal) *Port { if p.HasChainError() { return p } - p.setSignals(p.Buffer().With(signals...)) - return p + return p.withBuffer(p.Buffer().With(signals...)) } // WithSignals puts buffer and returns the port @@ -94,8 +95,7 @@ func (p *Port) Clear() *Port { if p.HasChainError() { return p } - p.setSignals(signal.NewGroup()) - return p + return p.withBuffer(signal.NewGroup()) } // Flush pushes buffer to pipes and clears the port diff --git a/port/port_test.go b/port/port_test.go index eec8170..c428784 100644 --- a/port/port_test.go +++ b/port/port_test.go @@ -1,6 +1,7 @@ package port import ( + "errors" "github.com/hovsep/fmesh/common" "github.com/hovsep/fmesh/signal" "github.com/stretchr/testify/assert" @@ -31,14 +32,14 @@ func TestPort_HasSignals(t *testing.T) { } } -func TestPort_Signals(t *testing.T) { +func TestPort_Buffer(t *testing.T) { tests := []struct { name string port *Port want *signal.Group }{ { - name: "no buffer", + name: "empty buffer", port: New("noSignal"), want: signal.NewGroup(), }, @@ -47,6 +48,11 @@ func TestPort_Signals(t *testing.T) { port: New("p").WithSignals(signal.New(123)), want: signal.NewGroup(123), }, + { + name: "with chain error", + port: New("p").WithChainError(errors.New("some error")), + want: signal.NewGroup().WithChainError(errors.New("some error")), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -122,56 +128,90 @@ func TestPort_PutSignals(t *testing.T) { signals signal.Signals } tests := []struct { - name string - port *Port - signalsAfter signal.Signals - args args + name string + port *Port + args args + assertions func(t *testing.T, portAfter *Port) }{ { - name: "single signal to empty port", - port: New("emptyPort"), - signalsAfter: signal.NewGroup(11).SignalsOrNil(), + name: "single signal to empty port", + port: New("emptyPort"), + assertions: func(t *testing.T, portAfter *Port) { + assert.Equal(t, signal.NewGroup(11), portAfter.Buffer()) + }, args: args{ signals: signal.NewGroup(11).SignalsOrNil(), }, }, { - name: "multiple buffer to empty port", - port: New("p"), - signalsAfter: signal.NewGroup(11, 12).SignalsOrNil(), + name: "multiple buffer to empty port", + port: New("p"), + assertions: func(t *testing.T, portAfter *Port) { + assert.Equal(t, signal.NewGroup(11, 12), portAfter.Buffer()) + }, args: args{ signals: signal.NewGroup(11, 12).SignalsOrNil(), }, }, { - name: "single signal to port with single signal", - port: New("p").WithSignals(signal.New(11)), - signalsAfter: signal.NewGroup(11, 12).SignalsOrNil(), + name: "single signal to port with single signal", + port: New("p").WithSignals(signal.New(11)), + assertions: func(t *testing.T, portAfter *Port) { + assert.Equal(t, signal.NewGroup(11, 12), portAfter.Buffer()) + }, args: args{ signals: signal.NewGroup(12).SignalsOrNil(), }, }, { - name: "single buffer to port with multiple buffer", - port: New("p").WithSignalGroups(signal.NewGroup(11, 12)), - signalsAfter: signal.NewGroup(11, 12, 13).SignalsOrNil(), + name: "single buffer to port with multiple buffer", + port: New("p").WithSignalGroups(signal.NewGroup(11, 12)), + assertions: func(t *testing.T, portAfter *Port) { + assert.Equal(t, signal.NewGroup(11, 12, 13), portAfter.Buffer()) + }, args: args{ signals: signal.NewGroup(13).SignalsOrNil(), }, }, { - name: "multiple buffer to port with multiple buffer", - port: New("p").WithSignalGroups(signal.NewGroup(55, 66)), - signalsAfter: signal.NewGroup(55, 66, 13, 14).SignalsOrNil(), + name: "multiple buffer to port with multiple buffer", + port: New("p").WithSignalGroups(signal.NewGroup(55, 66)), + assertions: func(t *testing.T, portAfter *Port) { + assert.Equal(t, signal.NewGroup(55, 66, 13, 14), portAfter.Buffer()) + }, args: args{ signals: signal.NewGroup(13, 14).SignalsOrNil(), }, }, + { + name: "chain error propagated from buffer", + port: New("p"), + assertions: func(t *testing.T, portAfter *Port) { + assert.Zero(t, portAfter.Buffer().Len()) + assert.True(t, portAfter.Buffer().HasChainError()) + }, + args: args{ + signals: signal.Signals{signal.New(111).WithChainError(errors.New("some error in signal"))}, + }, + }, + { + name: "with chain error", + port: New("p").WithChainError(errors.New("some error in port")), + args: args{ + signals: signal.Signals{signal.New(123)}, + }, + assertions: func(t *testing.T, portAfter *Port) { + assert.True(t, portAfter.HasChainError()) + assert.Zero(t, portAfter.Buffer().Len()) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { portAfter := tt.port.PutSignals(tt.args.signals...) - assert.ElementsMatch(t, tt.signalsAfter, portAfter.AllSignalsOrNil()) + if tt.assertions != nil { + tt.assertions(t, portAfter) + } }) } } @@ -339,3 +379,61 @@ func TestPort_WithLabels(t *testing.T) { }) } } + +func TestPort_Pipes(t *testing.T) { + tests := []struct { + name string + port *Port + want *Group + }{ + { + name: "no pipes", + port: New("p"), + want: NewGroup(), + }, + { + name: "with pipes", + port: New("p1").PipeTo(New("p2"), New("p3")), + want: NewGroup("p2", "p3"), + }, + { + name: "with chain error", + port: New("p").WithChainError(errors.New("some error")), + want: NewGroup().WithChainError(errors.New("some error")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.port.Pipes()) + }) + } +} + +func TestPort_ShortcutGetters(t *testing.T) { + t.Run("FirstSignalPayload", func(t *testing.T) { + port := New("p").WithSignalGroups(signal.NewGroup(4, 7, 6, 5)) + payload, err := port.FirstSignalPayload() + assert.NoError(t, err) + assert.Equal(t, 4, payload) + }) + + t.Run("FirstSignalPayloadOrNil", func(t *testing.T) { + port := New("p").WithSignals(signal.New(123).WithChainError(errors.New("some error"))) + assert.Nil(t, port.FirstSignalPayloadOrNil()) + }) + + t.Run("FirstSignalPayloadOrDefault", func(t *testing.T) { + port := New("p").WithSignals(signal.New(123).WithChainError(errors.New("some error"))) + assert.Equal(t, 888, port.FirstSignalPayloadOrDefault(888)) + }) + + t.Run("AllSignalsOrNil", func(t *testing.T) { + port := New("p").WithSignals(signal.New(123).WithChainError(errors.New("some error"))) + assert.Nil(t, port.AllSignalsOrNil()) + }) + + t.Run("AllSignalsOrDefault", func(t *testing.T) { + port := New("p").WithSignals(signal.New(123).WithChainError(errors.New("some error"))) + assert.Equal(t, signal.NewGroup(999).SignalsOrNil(), port.AllSignalsOrDefault(signal.NewGroup(999).SignalsOrNil())) + }) +} diff --git a/signal/group_test.go b/signal/group_test.go index f1f0a36..bd24e12 100644 --- a/signal/group_test.go +++ b/signal/group_test.go @@ -24,6 +24,7 @@ func TestNewGroup(t *testing.T) { signals, err := group.Signals() assert.NoError(t, err) assert.Len(t, signals, 0) + assert.Zero(t, group.Len()) }, }, { @@ -34,7 +35,7 @@ func TestNewGroup(t *testing.T) { assertions: func(t *testing.T, group *Group) { signals, err := group.Signals() assert.NoError(t, err) - assert.Len(t, signals, 3) + assert.Equal(t, group.Len(), 3) assert.Contains(t, signals, New(1)) assert.Contains(t, signals, New(nil)) assert.Contains(t, signals, New(3)) From bf55284f8c36e2499f82e4f37d3605baf0e9f029 Mon Sep 17 00:00:00 2001 From: hovsep Date: Wed, 30 Oct 2024 01:38:27 +0200 Subject: [PATCH 2/2] Return chainable with error instead of separate error when possible --- component/component.go | 29 +++++++++++++++++++++++++---- fmesh.go | 24 +++++++++--------------- fmesh_test.go | 8 +++++--- port/errors.go | 4 +++- port/group.go | 2 +- port/group_test.go | 2 +- port/port.go | 2 +- 7 files changed, 45 insertions(+), 26 deletions(-) diff --git a/component/component.go b/component/component.go index 5833b54..97bbc62 100644 --- a/component/component.go +++ b/component/component.go @@ -172,18 +172,39 @@ func (c *Component) hasActivationFunction() bool { return c.f != nil } -// MaybeActivate tries to run the activation function if all required conditions are met -// @TODO: hide this method from user -// @TODO: can we remove named return ? -func (c *Component) MaybeActivate() (activationResult *ActivationResult) { +// propagateChainErrors propagates up all chain errors that might have not been propagated yet +func (c *Component) propagateChainErrors() { if c.Inputs().HasChainError() { c.SetChainError(c.Inputs().ChainError()) + return } if c.Outputs().HasChainError() { c.SetChainError(c.Outputs().ChainError()) + return + } + + for _, p := range c.Inputs().PortsOrNil() { + if p.HasChainError() { + c.SetChainError(p.ChainError()) + return + } } + for _, p := range c.Outputs().PortsOrNil() { + if p.HasChainError() { + c.SetChainError(p.ChainError()) + return + } + } +} + +// MaybeActivate tries to run the activation function if all required conditions are met +// @TODO: hide this method from user +// @TODO: can we remove named return ? +func (c *Component) MaybeActivate() (activationResult *ActivationResult) { + c.propagateChainErrors() + if c.HasChainError() { activationResult = NewActivationResult(c.Name()).WithChainError(c.ChainError()) return diff --git a/fmesh.go b/fmesh.go index 8ff0889..741ec3b 100644 --- a/fmesh.go +++ b/fmesh.go @@ -87,22 +87,22 @@ func (fm *FMesh) WithConfig(config Config) *FMesh { } // runCycle runs one activation cycle (tries to activate ready components) -func (fm *FMesh) runCycle() (*cycle.Cycle, error) { +func (fm *FMesh) runCycle() *cycle.Cycle { + newCycle := cycle.New() + if fm.HasChainError() { - return nil, fm.ChainError() + return newCycle.WithChainError(fm.ChainError()) } if fm.Components().Len() == 0 { - return nil, errors.New("failed to run cycle: no components found") + return newCycle.WithChainError(errors.New("failed to run cycle: no components found")) } - newCycle := cycle.New() - var wg sync.WaitGroup components, err := fm.Components().Components() if err != nil { - return nil, fmt.Errorf("failed to run cycle: %w", err) + return newCycle.WithChainError(fmt.Errorf("failed to run cycle: %w", err)) } for _, c := range components { @@ -130,7 +130,7 @@ func (fm *FMesh) runCycle() (*cycle.Cycle, error) { } } - return newCycle, nil + return newCycle } // DrainComponents drains the data from activated components @@ -188,13 +188,7 @@ func (fm *FMesh) Run() (cycle.Cycles, error) { allCycles := cycle.NewGroup() cycleNumber := 0 for { - cycleResult, err := fm.runCycle() - - if err != nil { - return nil, err - } - - cycleResult.WithNumber(cycleNumber) + cycleResult := fm.runCycle().WithNumber(cycleNumber) if cycleResult.HasChainError() { fm.SetChainError(cycleResult.ChainError()) @@ -216,7 +210,7 @@ func (fm *FMesh) Run() (cycle.Cycles, error) { return cycles, stopError } - err = fm.drainComponents(cycleResult) + err := fm.drainComponents(cycleResult) if err != nil { return nil, err } diff --git a/fmesh_test.go b/fmesh_test.go index 1b4aac4..4457c10 100644 --- a/fmesh_test.go +++ b/fmesh_test.go @@ -649,11 +649,13 @@ func TestFMesh_runCycle(t *testing.T) { if tt.initFM != nil { tt.initFM(tt.fm) } - cycleResult, err := tt.fm.runCycle() + cycleResult := tt.fm.runCycle() if tt.wantError { - assert.Error(t, err) + assert.True(t, cycleResult.HasChainError()) + assert.Error(t, cycleResult.ChainError()) } else { - assert.NoError(t, err) + assert.False(t, cycleResult.HasChainError()) + assert.NoError(t, cycleResult.ChainError()) assert.Equal(t, tt.want, cycleResult) } }) diff --git a/port/errors.go b/port/errors.go index f69c5d5..d444ac8 100644 --- a/port/errors.go +++ b/port/errors.go @@ -3,5 +3,7 @@ package port import "errors" var ( - ErrPortNotFoundInCollection = errors.New("port not found") + ErrPortNotFoundInCollection = errors.New("port not found") + ErrPortNotReadyForFlush = errors.New("port is not ready for flush") + ErrInvalidRangeForIndexedGroup = errors.New("start index can not be greater than end index") ) diff --git a/port/group.go b/port/group.go index 2a878e8..ca8cb71 100644 --- a/port/group.go +++ b/port/group.go @@ -29,7 +29,7 @@ func NewGroup(names ...string) *Group { // NOTE: endIndex is inclusive, e.g. NewIndexedGroup("p", 0, 0) will create one port with name "p0" func NewIndexedGroup(prefix string, startIndex int, endIndex int) *Group { if startIndex > endIndex { - return nil + return NewGroup().WithChainError(ErrInvalidRangeForIndexedGroup) } ports := make([]*Port, endIndex-startIndex+1) diff --git a/port/group_test.go b/port/group_test.go index a29db44..e987a7a 100644 --- a/port/group_test.go +++ b/port/group_test.go @@ -80,7 +80,7 @@ func TestNewIndexedGroup(t *testing.T) { startIndex: 999, endIndex: 5, }, - want: nil, + want: NewGroup().WithChainError(ErrInvalidRangeForIndexedGroup), }, } for _, tt := range tests { diff --git a/port/port.go b/port/port.go index 5e6652f..a7c5c9f 100644 --- a/port/port.go +++ b/port/port.go @@ -107,7 +107,7 @@ func (p *Port) Flush() *Port { if !p.HasSignals() || !p.HasPipes() { //@TODO maybe better to return explicit errors - return nil + return New("").WithChainError(ErrPortNotReadyForFlush) } pipes, err := p.pipes.Ports()