diff --git a/Makefile b/Makefile index c89d44c4..746cee4f 100644 --- a/Makefile +++ b/Makefile @@ -75,7 +75,7 @@ ${TEST_BIN_DIR}/peerswap: # Test section. Has commads for local and ci testing. test: - PAYMENT_RETRY_TIME=5 go test -tags dev -tags fast_test -timeout=10m -v ./... + PAYMENT_RETRY_TIME=5 go test -tags dev -tags fast_test -race -timeout=10m -v ./... .PHONY: test test-integration: test-bins diff --git a/electrum/block_subscriber.go b/electrum/block_subscriber.go index b810dced..483d180e 100644 --- a/electrum/block_subscriber.go +++ b/electrum/block_subscriber.go @@ -2,6 +2,7 @@ package electrum import ( "context" + "sync" "github.com/elementsproject/peerswap/log" ) @@ -24,6 +25,7 @@ type BlockHeaderSubscriber interface { type liquidBlockHeaderSubscriber struct { txObservers []TXObserver + mu sync.Mutex } func NewLiquidBlockHeaderSubscriber() *liquidBlockHeaderSubscriber { @@ -33,14 +35,24 @@ func NewLiquidBlockHeaderSubscriber() *liquidBlockHeaderSubscriber { var _ BlockHeaderSubscriber = (*liquidBlockHeaderSubscriber)(nil) func (h *liquidBlockHeaderSubscriber) Register(tx TXObserver) { + h.mu.Lock() + defer h.mu.Unlock() h.txObservers = append(h.txObservers, tx) } func (h *liquidBlockHeaderSubscriber) Deregister(o TXObserver) { - h.txObservers = remove(h.txObservers, o) + newObservers := make([]TXObserver, 0, len(h.txObservers)) + for _, observer := range h.txObservers { + if observer.GetSwapID() != o.GetSwapID() { + newObservers = append(newObservers, observer) + } + } + h.txObservers = newObservers } func (h *liquidBlockHeaderSubscriber) Update(ctx context.Context, blockHeight BlocKHeight) error { + h.mu.Lock() + defer h.mu.Unlock() for _, observer := range h.txObservers { callbacked, err := observer.Callback(ctx, blockHeight) if callbacked && err == nil { @@ -53,13 +65,3 @@ func (h *liquidBlockHeaderSubscriber) Update(ctx context.Context, blockHeight Bl } return nil } - -func remove(observerList []TXObserver, observerToRemove TXObserver) []TXObserver { - newObservers := make([]TXObserver, len(observerList)-1) - for _, observer := range observerList { - if observer.GetSwapID() != observerToRemove.GetSwapID() { - newObservers = append(newObservers, observer) - } - } - return newObservers -} diff --git a/electrum/block_subscriber_test.go b/electrum/block_subscriber_test.go new file mode 100644 index 00000000..a58d3132 --- /dev/null +++ b/electrum/block_subscriber_test.go @@ -0,0 +1,53 @@ +package electrum + +import ( + "context" + "sync" + "testing" + + "github.com/elementsproject/peerswap/swap" +) + +type testtxo struct { + swapid swap.SwapId +} + +func (t *testtxo) GetSwapID() swap.SwapId { + return t.swapid +} + +func (t *testtxo) Callback(context.Context, BlocKHeight) (bool, error) { + return true, nil +} + +func TestConcurrentUpdate(t *testing.T) { + t.Parallel() + observers := make([]TXObserver, 10) + for i := range observers { + observers[i] = &testtxo{ + swapid: *swap.NewSwapId(), + } + } + + h := &liquidBlockHeaderSubscriber{ + txObservers: observers, + } + + var wg sync.WaitGroup + const concurrency = 100 + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := h.Update(context.Background(), BlocKHeight(0)) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }() + } + wg.Wait() + + if len(h.txObservers) != 0 { + t.Errorf("Expected length %d, but got %d", 0, len(h.txObservers)) + } +} diff --git a/lwk/electrumtxwatcher.go b/lwk/electrumtxwatcher.go index 92462146..3ded06db 100644 --- a/lwk/electrumtxwatcher.go +++ b/lwk/electrumtxwatcher.go @@ -3,6 +3,7 @@ package lwk import ( "context" "fmt" + "sync" "time" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -28,6 +29,7 @@ type electrumTxWatcher struct { // The connection with the electrum client is // disconnected after a certain period of time. resubscribeTicker *time.Ticker + mu sync.Mutex } func NewElectrumTxWatcher(electrumClient electrum.RPC) (*electrumTxWatcher, error) { @@ -59,7 +61,9 @@ func (r *electrumTxWatcher) StartWatchingTxs() error { if r.blockHeight.Confirmed() && blockHeader.Height <= int32(r.blockHeight.Height()) { continue } + r.mu.Lock() r.blockHeight = electrum.BlocKHeight(blockHeader.Height) + r.mu.Unlock() log.Infof("New block received. block height:%d", r.blockHeight) err = r.subscriber.Update(ctx, r.blockHeight) if err != nil { @@ -89,9 +93,12 @@ func (r *electrumTxWatcher) waitForInitialBlockHeaderSubscription(ctx context.Co log.Infof("Initial block header subscription timeout.") return ctx.Err() default: + r.mu.Lock() if r.blockHeight.Confirmed() { + r.mu.Unlock() return nil } + r.mu.Unlock() } time.Sleep(heartbeatInterval) } @@ -119,9 +126,13 @@ func (r *electrumTxWatcher) AddWaitForConfirmationTx(swapIDStr, txIDStr string, } func (r *electrumTxWatcher) AddConfirmationCallback(f func(swapId string, txHex string, err error) error) { + r.mu.Lock() + defer r.mu.Unlock() r.confirmationCallback = f } func (r *electrumTxWatcher) AddCsvCallback(f func(swapId string) error) { + r.mu.Lock() + defer r.mu.Unlock() r.csvCallback = f } diff --git a/lwk/electrumtxwatcher_test.go b/lwk/electrumtxwatcher_test.go index a05a925f..1f4091bc 100644 --- a/lwk/electrumtxwatcher_test.go +++ b/lwk/electrumtxwatcher_test.go @@ -110,7 +110,6 @@ func TestElectrumTxWatcher_Callback(t *testing.T) { r.AddCsvCallback( func(swapId string) error { assert.Equal(t, wantSwapID, swapId) - assert.NoError(t, err) callbackChan <- swapId return nil },