Skip to content

Commit

Permalink
lwk: ensure thread safety
Browse files Browse the repository at this point in the history
Add race detector flag to go test command
for better concurrency checks.
Fixes have been implemented for the race conditions
detected in the above tests.
  • Loading branch information
YusukeShimizu committed May 30, 2024
1 parent 83603a2 commit 885181a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ${TEST_BIN_DIR}/peerswap:

# Test section. Has commads for local and ci testing.
test:
PAYMENT_RETRY_TIME=5 go test -tags dev -tags fast_test -timeout=10m -v ./...
PAYMENT_RETRY_TIME=5 go test -tags dev -tags fast_test -race -timeout=10m -v ./...
.PHONY: test

test-integration: test-bins
Expand Down
8 changes: 8 additions & 0 deletions electrum/block_subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package electrum

import (
"context"
"sync"

"github.com/elementsproject/peerswap/log"
)
Expand All @@ -24,6 +25,7 @@ type BlockHeaderSubscriber interface {

type liquidBlockHeaderSubscriber struct {
txObservers []TXObserver
mu sync.Mutex
}

func NewLiquidBlockHeaderSubscriber() *liquidBlockHeaderSubscriber {
Expand All @@ -33,14 +35,20 @@ func NewLiquidBlockHeaderSubscriber() *liquidBlockHeaderSubscriber {
var _ BlockHeaderSubscriber = (*liquidBlockHeaderSubscriber)(nil)

func (h *liquidBlockHeaderSubscriber) Register(tx TXObserver) {
h.mu.Lock()
defer h.mu.Unlock()
h.txObservers = append(h.txObservers, tx)
}

func (h *liquidBlockHeaderSubscriber) Deregister(o TXObserver) {
h.mu.Lock()
defer h.mu.Unlock()
h.txObservers = remove(h.txObservers, o)
}

func (h *liquidBlockHeaderSubscriber) Update(ctx context.Context, blockHeight BlocKHeight) error {
h.mu.Lock()
defer h.mu.Unlock()
for _, observer := range h.txObservers {
callbacked, err := observer.Callback(ctx, blockHeight)
if callbacked && err == nil {
Expand Down
52 changes: 52 additions & 0 deletions electrum/block_subscriber_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package electrum

import (
"context"

"github.com/elementsproject/peerswap/swap"
)

type testtxo struct{}

// GetSwapID() swap.SwapId
// // Callback calls the callback function if the condition is match.
// // Returns true if the callback function is called.
// Callback(context.Context, BlocKHeight) (bool, error)

func (t *testtxo) GetSwapID() swap.SwapId {
return swap.SwapId{}
}

func (t *testtxo) Callback(context.Context, BlocKHeight) (bool, error) {
return true, nil
}

// func TestConcurrentUpdate(t *testing.T) {
// observers := make([]TXObserver, 10)
// for i := range observers {
// observers[i] = &testtxo{}
// }

// h := &liquidBlockHeaderSubscriber{
// txObservers: observers,
// }

// var wg sync.WaitGroup
// for i := 0; i < 10; i++ {
// wg.Add(1)
// go func() {
// defer wg.Done()
// err := h.Update(context.Background(), BlocKHeight(0))
// if err != nil {
// t.Errorf("Unexpected error: %v", err)
// }
// }()
// }
// wg.Wait()

// // Add assertions here to check the state of h.txObservers
// // For example, you might check the length of h.txObservers
// if len(h.txObservers) != 0 {
// t.Errorf("Expected length %d, but got %d", 0, len(h.txObservers))
// }
// }
11 changes: 11 additions & 0 deletions lwk/electrumtxwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package lwk
import (
"context"
"fmt"
"sync"
"time"

"github.com/btcsuite/btcd/chaincfg/chainhash"
Expand All @@ -28,6 +29,7 @@ type electrumTxWatcher struct {
// The connection with the electrum client is
// disconnected after a certain period of time.
resubscribeTicker *time.Ticker
mu sync.Mutex
}

func NewElectrumTxWatcher(electrumClient electrum.RPC) (*electrumTxWatcher, error) {
Expand Down Expand Up @@ -59,7 +61,9 @@ func (r *electrumTxWatcher) StartWatchingTxs() error {
if r.blockHeight.Confirmed() && blockHeader.Height <= int32(r.blockHeight.Height()) {
continue
}
r.mu.Lock()
r.blockHeight = electrum.BlocKHeight(blockHeader.Height)
r.mu.Unlock()
log.Infof("New block received. block height:%d", r.blockHeight)
err = r.subscriber.Update(ctx, r.blockHeight)
if err != nil {
Expand Down Expand Up @@ -89,9 +93,12 @@ func (r *electrumTxWatcher) waitForInitialBlockHeaderSubscription(ctx context.Co
log.Infof("Initial block header subscription timeout.")
return ctx.Err()
default:
r.mu.Lock()
if r.blockHeight.Confirmed() {
r.mu.Unlock()
return nil
}
r.mu.Unlock()
}
time.Sleep(heartbeatInterval)
}
Expand Down Expand Up @@ -119,9 +126,13 @@ func (r *electrumTxWatcher) AddWaitForConfirmationTx(swapIDStr, txIDStr string,
}

func (r *electrumTxWatcher) AddConfirmationCallback(f func(swapId string, txHex string, err error) error) {
r.mu.Lock()
defer r.mu.Unlock()
r.confirmationCallback = f
}
func (r *electrumTxWatcher) AddCsvCallback(f func(swapId string) error) {
r.mu.Lock()
defer r.mu.Unlock()
r.csvCallback = f
}

Expand Down
1 change: 0 additions & 1 deletion lwk/electrumtxwatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ func TestElectrumTxWatcher_Callback(t *testing.T) {
r.AddCsvCallback(
func(swapId string) error {
assert.Equal(t, wantSwapID, swapId)
assert.NoError(t, err)
callbackChan <- swapId
return nil
},
Expand Down

0 comments on commit 885181a

Please sign in to comment.