Skip to content

Commit

Permalink
Add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
hovsep committed Oct 23, 2024
1 parent f4ddc2f commit 85880f1
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 9 deletions.
1 change: 1 addition & 0 deletions component/activation_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (
func NewActivationResult(componentName string) *ActivationResult {
return &ActivationResult{
componentName: componentName,
Chainable: common.NewChainable(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion component/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func NewCollection() *Collection {
// ByName returns a component by its name
func (c *Collection) ByName(name string) *Component {
if c.HasChainError() {
return nil
return New("").WithChainError(c.ChainError())
}

component, ok := c.components[name]
Expand Down
20 changes: 19 additions & 1 deletion component/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@ func (c *Component) OutputByName(name string) *port.Port {

// InputByName is shortcut method
func (c *Component) InputByName(name string) *port.Port {
if c.HasChainError() {
return port.New("").WithChainError(c.ChainError())
}
inputPort := c.Inputs().ByName(name)
if inputPort.HasChainError() {
c.SetChainError(inputPort.ChainError())
return nil
}
return inputPort
}
Expand All @@ -159,6 +161,22 @@ func (c *Component) hasActivationFunction() bool {
// @TODO: hide this method from user
// @TODO: can we remove named return ?
func (c *Component) MaybeActivate() (activationResult *ActivationResult) {
//Bubble up chain errors from ports
for _, p := range c.Inputs().PortsOrNil() {
if p.HasChainError() {
c.Inputs().SetChainError(p.ChainError())
c.SetChainError(c.Inputs().ChainError())
break
}
}
for _, p := range c.Outputs().PortsOrNil() {
if p.HasChainError() {
c.Outputs().SetChainError(p.ChainError())
c.SetChainError(c.Outputs().ChainError())
break
}
}

if c.HasChainError() {
activationResult = NewActivationResult(c.Name()).WithChainError(c.ChainError())
return
Expand Down
22 changes: 21 additions & 1 deletion fmesh.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func New(name string) *FMesh {
// Components getter
func (fm *FMesh) Components() *component.Collection {
if fm.HasChainError() {
return nil
return component.NewCollection().WithChainError(fm.ChainError())
}
return fm.components
}
Expand Down Expand Up @@ -106,6 +106,9 @@ func (fm *FMesh) runCycle() (*cycle.Cycle, error) {
}

for _, c := range components {
if c.HasChainError() {
fm.SetChainError(c.ChainError())
}
wg.Add(1)

go func(component *component.Component, cycle *cycle.Cycle) {
Expand Down Expand Up @@ -135,6 +138,10 @@ func (fm *FMesh) drainComponents(cycle *cycle.Cycle) error {
for _, c := range components {
activationResult := cycle.ActivationResults().ByComponentName(c.Name())

if activationResult.HasChainError() {
return activationResult.ChainError()
}

if !activationResult.Activated() {
// Component did not activate, so it did not create new output signals, hence nothing to drain
continue
Expand Down Expand Up @@ -173,13 +180,26 @@ func (fm *FMesh) Run() (cycle.Collection, error) {
cycleNumber := 0
for {
cycleResult, err := fm.runCycle()

//Bubble up chain errors from activation results
for _, ar := range cycleResult.ActivationResults() {
if ar.HasChainError() {
fm.SetChainError(ar.ChainError())
break
}
}

if err != nil {
return nil, err
}
cycleResult.WithNumber(cycleNumber)
allCycles = allCycles.With(cycleResult)

mustStop, err := fm.mustStop(cycleResult)
if err != nil {
return nil, err
}

if mustStop {
return allCycles, err
}
Expand Down
154 changes: 154 additions & 0 deletions integration_tests/error_handling/chainable_api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package error_handling

import (
"errors"
"github.com/hovsep/fmesh"
"github.com/hovsep/fmesh/component"
"github.com/hovsep/fmesh/port"
"github.com/hovsep/fmesh/signal"
"github.com/stretchr/testify/assert"
"testing"
)

func Test_Signal(t *testing.T) {
tests := []struct {
name string
test func(t *testing.T)
}{
{
name: "no errors",
test: func(t *testing.T) {
sig := signal.New(123)
_, err := sig.Payload()
assert.False(t, sig.HasChainError())
assert.NoError(t, err)

_ = sig.PayloadOrDefault(555)
assert.False(t, sig.HasChainError())

_ = sig.PayloadOrNil()
assert.False(t, sig.HasChainError())
},
},
{
name: "error propagated from group to signal",
test: func(t *testing.T) {
emptyGroup := signal.NewGroup()

sig := emptyGroup.First()
assert.True(t, sig.HasChainError())
assert.Error(t, sig.ChainError())

_, err := sig.Payload()
assert.Error(t, err)
assert.EqualError(t, err, signal.ErrNoSignalsInGroup.Error())
},
},
}
for _, tt := range tests {
t.Run(tt.name, tt.test)
}
}

func Test_FMesh(t *testing.T) {
tests := []struct {
name string
test func(t *testing.T)
}{
{
name: "no errors",
test: func(t *testing.T) {
fm := fmesh.New("test").WithComponents(
component.New("c1").WithInputs("num1", "num2").
WithOutputs("sum").WithActivationFunc(func(inputs *port.Collection, outputs *port.Collection) error {
num1 := inputs.ByName("num1").FirstSignalPayloadOrDefault(0).(int)
num2 := inputs.ByName("num2").FirstSignalPayloadOrDefault(0).(int)
outputs.ByName("sum").PutSignals(signal.New(num1 + num2))
return nil
}),
)

fm.Components().ByName("c1").InputByName("num1").PutSignals(signal.New(10))
fm.Components().ByName("c1").InputByName("num2").PutSignals(signal.New(5))

_, err := fm.Run()
assert.False(t, fm.HasChainError())
assert.NoError(t, err)
},
},
{
name: "error propagated from component",
test: func(t *testing.T) {
fm := fmesh.New("test").WithComponents(
component.New("c1").
WithInputs("num1", "num2").
WithOutputs("sum").
WithActivationFunc(func(inputs *port.Collection, outputs *port.Collection) error {
num1 := inputs.ByName("num1").FirstSignalPayloadOrDefault(0).(int)
num2 := inputs.ByName("num2").FirstSignalPayloadOrDefault(0).(int)
outputs.ByName("sum").PutSignals(signal.New(num1 + num2))
return nil
}).
WithChainError(errors.New("some error in component")),
)

fm.Components().ByName("c1").InputByName("num1").PutSignals(signal.New(10))
fm.Components().ByName("c1").InputByName("num2").PutSignals(signal.New(5))

_, err := fm.Run()
assert.True(t, fm.HasChainError())
assert.Error(t, err)
assert.EqualError(t, err, "some error in component")
},
},
{
name: "error propagated from port",
test: func(t *testing.T) {
fm := fmesh.New("test").WithComponents(
component.New("c1").
WithInputs("num1", "num2").
WithOutputs("sum").
WithActivationFunc(func(inputs *port.Collection, outputs *port.Collection) error {
num1 := inputs.ByName("num1").FirstSignalPayloadOrDefault(0).(int)
num2 := inputs.ByName("num2").FirstSignalPayloadOrDefault(0).(int)
outputs.ByName("sum").PutSignals(signal.New(num1 + num2))
return nil
}),
)

fm.Components().ByName("c1").InputByName("num777").PutSignals(signal.New(10))
fm.Components().ByName("c1").InputByName("num2").PutSignals(signal.New(5))

_, err := fm.Run()
assert.True(t, fm.HasChainError())
assert.Error(t, err)
assert.EqualError(t, err, "port not found")
},
},
{
name: "error propagated from signal",
test: func(t *testing.T) {
fm := fmesh.New("test").WithComponents(
component.New("c1").WithInputs("num1", "num2").
WithOutputs("sum").WithActivationFunc(func(inputs *port.Collection, outputs *port.Collection) error {
num1 := inputs.ByName("num1").FirstSignalPayloadOrDefault(0).(int)
num2 := inputs.ByName("num2").FirstSignalPayloadOrDefault(0).(int)
outputs.ByName("sum").PutSignals(signal.New(num1 + num2))
return nil
}),
)

fm.Components().ByName("c1").InputByName("num1").PutSignals(signal.New(10).WithChainError(errors.New("some error in input signal")))
fm.Components().ByName("c1").InputByName("num2").PutSignals(signal.New(5))

_, err := fm.Run()
assert.True(t, fm.HasChainError())
assert.Error(t, err)
assert.EqualError(t, err, "some error in input signal")
},
},
}
for _, tt := range tests {
t.Run(tt.name, tt.test)
}
}
4 changes: 1 addition & 3 deletions port/collection.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package port

import (
"errors"
"github.com/hovsep/fmesh/common"
"github.com/hovsep/fmesh/signal"
)
Expand Down Expand Up @@ -31,8 +30,7 @@ func (collection *Collection) ByName(name string) *Port {
}
port, ok := collection.ports[name]
if !ok {
collection.SetChainError(errors.New("port not found"))
return nil
return New("").WithChainError(ErrPortNotFoundInCollection)
}
return port
}
Expand Down
7 changes: 7 additions & 0 deletions port/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package port

import "errors"

var (
ErrPortNotFoundInCollection = errors.New("port not found")
)
3 changes: 3 additions & 0 deletions port/port.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ func (p *Port) Pipes() *Group {

// setSignals sets buffer field
func (p *Port) setSignals(signals *signal.Group) {
if signals.HasChainError() {
p.SetChainError(signals.ChainError())
}
p.buffer = signals
}

Expand Down
8 changes: 8 additions & 0 deletions signal/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package signal

import "errors"

var (
ErrNoSignalsInGroup = errors.New("group has no signals")
ErrInvalidSignal = errors.New("signal is invalid")
)
5 changes: 2 additions & 3 deletions signal/group.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package signal

import (
"errors"
"github.com/hovsep/fmesh/common"
)

Expand Down Expand Up @@ -33,7 +32,7 @@ func (g *Group) First() *Signal {
}

if len(g.signals) == 0 {
return New(nil).WithChainError(errors.New("group has no signals"))
return New(nil).WithChainError(ErrNoSignalsInGroup)
}

return g.signals[0]
Expand Down Expand Up @@ -76,7 +75,7 @@ func (g *Group) With(signals ...*Signal) *Group {
copy(newSignals, g.signals)
for i, sig := range signals {
if sig == nil {
return g.WithChainError(errors.New("signal is nil"))
return g.WithChainError(ErrInvalidSignal)
}

if sig.HasChainError() {
Expand Down

0 comments on commit 85880f1

Please sign in to comment.