Skip to content

Commit

Permalink
Merge pull request #1817 from c9s/c9s/fix-connectivity
Browse files Browse the repository at this point in the history
FIX: refactor and redesign connectivity
  • Loading branch information
c9s authored Nov 14, 2024
2 parents 1eb0c1b + 1e57336 commit 8791da6
Show file tree
Hide file tree
Showing 7 changed files with 722 additions and 197 deletions.
23 changes: 13 additions & 10 deletions pkg/strategy/deposit2transfer/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ func (s *Strategy) checkDeposits(ctx context.Context) {
func (s *Strategy) addWatchingDeposit(deposit types.Deposit) {
s.watchingDeposits[deposit.TransactionID] = deposit

if lastTime, ok := s.lastAssetDepositTimes[deposit.Asset]; ok {
s.lastAssetDepositTimes[deposit.Asset] = later(deposit.Time.Time(), lastTime)
} else {
s.lastAssetDepositTimes[deposit.Asset] = deposit.Time.Time()
}

if s.SlackAlert != nil {
bbgo.PostLiveNote(&deposit,
livenote.Channel(s.SlackAlert.Channel),
Expand Down Expand Up @@ -339,7 +345,13 @@ func (s *Strategy) scanDepositHistory(ctx context.Context, asset string, duratio
logger.Infof("ignored expired succeedded deposit: %s %+v", deposit.TransactionID, deposit)
}
} else {
s.addWatchingDeposit(deposit)
// if the latest deposit time is not found, check if the deposit is older than 5 minutes
expiryTime := 5 * time.Minute
if deposit.Time.Before(time.Now().Add(-expiryTime)) {
logger.Infof("ignored expired (%s) succeedded deposit: %s %+v", expiryTime, deposit.TransactionID, deposit)
} else {
s.addWatchingDeposit(deposit)
}
}

case types.DepositCredited, types.DepositPending:
Expand All @@ -349,15 +361,6 @@ func (s *Strategy) scanDepositHistory(ctx context.Context, asset string, duratio
}
}

if len(deposits) > 0 {
lastDeposit := deposits[len(deposits)-1]
if lastTime, ok := s.lastAssetDepositTimes[asset]; ok {
s.lastAssetDepositTimes[asset] = later(lastDeposit.Time.Time(), lastTime)
} else {
s.lastAssetDepositTimes[asset] = lastDeposit.Time.Time()
}
}

var succeededDeposits []types.Deposit

// find and move out succeeded deposits
Expand Down
144 changes: 36 additions & 108 deletions pkg/types/connectivity.go
Original file line number Diff line number Diff line change
@@ -1,96 +1,9 @@
package types

import (
"context"
"sync"
"time"
)

type ConnectivityGroup struct {
connections []*Connectivity
mu sync.Mutex
}

func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
return &ConnectivityGroup{
connections: cons,
}
}

func (g *ConnectivityGroup) Add(con *Connectivity) {
g.mu.Lock()
defer g.mu.Unlock()

g.connections = append(g.connections, con)
}

func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()

for _, conn := range conns {
select {
case <-ctx.Done():
return false

case <-conn.connectedC:
continue

case <-conn.disconnectedC:
return true
}
}

return false
}

func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()

authedConns := make([]bool, len(conns))
allTimeout := time.After(allTimeoutDuration)
for {
for idx, con := range conns {
// if the connection is not authed, mark it as false
if !con.authed {
// authedConns[idx] = false
}

timeout := time.After(3 * time.Second)
select {
case <-ctx.Done():
return

case <-allTimeout:
return

case <-timeout:
continue

case <-con.AuthedC():
authedConns[idx] = true
}
}

if allTrue(authedConns) {
close(c)
return
}
}
}

// AllAuthedC returns a channel that will be closed when all connections are authenticated
// the returned channel will be closed when all connections are authenticated
// and the channel can only be used once (because we can't close a channel twice)
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
c := make(chan struct{})
go g.waitAllAuthed(ctx, c, timeout)
return c
}

func allTrue(bools []bool) bool {
for _, b := range bools {
if !b {
Expand All @@ -101,6 +14,7 @@ func allTrue(bools []bool) bool {
return true
}

//go:generate callbackgen -type Connectivity
type Connectivity struct {
authed bool
authedC chan struct{}
Expand All @@ -109,7 +23,12 @@ type Connectivity struct {
connectedC chan struct{}
disconnectedC chan struct{}

mu sync.Mutex
connectCallbacks []func()
disconnectCallbacks []func()
authCallbacks []func()

stream Stream
mu sync.Mutex
}

func NewConnectivity() *Connectivity {
Expand Down Expand Up @@ -141,31 +60,39 @@ func (c *Connectivity) IsAuthed() (authed bool) {
return authed
}

func (c *Connectivity) handleConnect() {
func (c *Connectivity) setConnect() {
c.mu.Lock()
defer c.mu.Unlock()

c.connected = true
close(c.connectedC)
c.disconnectedC = make(chan struct{})
if !c.connected {
c.connected = true
close(c.connectedC)
c.disconnectedC = make(chan struct{})
}
c.mu.Unlock()
c.EmitConnect()
}

func (c *Connectivity) handleDisconnect() {
func (c *Connectivity) setDisconnect() {
c.mu.Lock()
defer c.mu.Unlock()

c.connected = false
c.authedC = make(chan struct{})
c.connectedC = make(chan struct{})
close(c.disconnectedC)
if c.connected {
c.connected = false
c.authed = false
c.authedC = make(chan struct{})
c.connectedC = make(chan struct{})
close(c.disconnectedC)
}
c.mu.Unlock()
c.EmitDisconnect()
}

func (c *Connectivity) handleAuth() {
func (c *Connectivity) setAuthed() {
c.mu.Lock()
defer c.mu.Unlock()
if !c.authed {
c.authed = true
close(c.authedC)
}
c.mu.Unlock()

c.authed = true
close(c.authedC)
c.EmitAuth()
}

func (c *Connectivity) AuthedC() chan struct{} {
Expand All @@ -187,7 +114,8 @@ func (c *Connectivity) DisconnectedC() chan struct{} {
}

func (c *Connectivity) Bind(stream Stream) {
stream.OnConnect(c.handleConnect)
stream.OnDisconnect(c.handleDisconnect)
stream.OnAuth(c.handleAuth)
stream.OnConnect(c.setConnect)
stream.OnDisconnect(c.setDisconnect)
stream.OnAuth(c.setAuthed)
c.stream = stream
}
35 changes: 35 additions & 0 deletions pkg/types/connectivity_callbacks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

92 changes: 13 additions & 79 deletions pkg/types/connectivity_test.go
Original file line number Diff line number Diff line change
@@ -1,105 +1,39 @@
package types

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestConnectivity(t *testing.T) {
t.Run("general", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
conn1.setConnect()
conn1.setAuthed()
conn1.setDisconnect()
})

t.Run("reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
conn1.setConnect()
conn1.setAuthed()
conn1.setDisconnect()

conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
conn1.setConnect()
conn1.setAuthed()
conn1.setDisconnect()
})

t.Run("no-auth reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleDisconnect()
conn1.setConnect()
conn1.setDisconnect()

conn1.handleConnect()
conn1.handleDisconnect()
conn1.setConnect()
conn1.setDisconnect()
})
}

func TestConnectivityGroupAuthC(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2

ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
allAuthedC := group.AllAuthedC(ctx, time.Second)

time.Sleep(delay)
conn1.handleConnect()
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
conn1.handleAuth()
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))

time.Sleep(delay)
conn2.handleConnect()
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))

conn2.handleAuth()
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))

assert.True(t, waitSigChan(allAuthedC, timeout))
}

func TestConnectivityGroupReconnect(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2

ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)

time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
conn1authC := conn1.authedC

time.Sleep(delay)
conn2.handleConnect()
conn2.handleAuth()

assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")

assert.False(t, group.AnyDisconnected(ctx))

// this should re-allocate authedC
conn1.handleDisconnect()
assert.NotEqual(t, conn1authC, conn1.authedC)

assert.True(t, group.AnyDisconnected(ctx))

assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")

time.Sleep(delay)

conn1.handleConnect()
conn1.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
}

func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
select {
case <-time.After(timeoutDuration):
Expand Down
Loading

0 comments on commit 8791da6

Please sign in to comment.