diff --git a/Makefile b/Makefile index c89d44c4..746cee4f 100644 --- a/Makefile +++ b/Makefile @@ -75,7 +75,7 @@ ${TEST_BIN_DIR}/peerswap: # Test section. Has commads for local and ci testing. test: - PAYMENT_RETRY_TIME=5 go test -tags dev -tags fast_test -timeout=10m -v ./... + PAYMENT_RETRY_TIME=5 go test -tags dev -tags fast_test -race -timeout=10m -v ./... .PHONY: test test-integration: test-bins diff --git a/electrum/block_subscriber.go b/electrum/block_subscriber.go index b810dced..0079b91a 100644 --- a/electrum/block_subscriber.go +++ b/electrum/block_subscriber.go @@ -2,6 +2,7 @@ package electrum import ( "context" + "sync" "github.com/elementsproject/peerswap/log" ) @@ -24,6 +25,7 @@ type BlockHeaderSubscriber interface { type liquidBlockHeaderSubscriber struct { txObservers []TXObserver + mu sync.Mutex } func NewLiquidBlockHeaderSubscriber() *liquidBlockHeaderSubscriber { @@ -33,14 +35,20 @@ func NewLiquidBlockHeaderSubscriber() *liquidBlockHeaderSubscriber { var _ BlockHeaderSubscriber = (*liquidBlockHeaderSubscriber)(nil) func (h *liquidBlockHeaderSubscriber) Register(tx TXObserver) { + h.mu.Lock() + defer h.mu.Unlock() h.txObservers = append(h.txObservers, tx) } func (h *liquidBlockHeaderSubscriber) Deregister(o TXObserver) { + h.mu.Lock() + defer h.mu.Unlock() h.txObservers = remove(h.txObservers, o) } func (h *liquidBlockHeaderSubscriber) Update(ctx context.Context, blockHeight BlocKHeight) error { + h.mu.Lock() + defer h.mu.Unlock() for _, observer := range h.txObservers { callbacked, err := observer.Callback(ctx, blockHeight) if callbacked && err == nil { diff --git a/electrum/block_subscriber_test.go b/electrum/block_subscriber_test.go new file mode 100644 index 00000000..0fc05ec0 --- /dev/null +++ b/electrum/block_subscriber_test.go @@ -0,0 +1,52 @@ +package electrum + +import ( + "context" + + "github.com/elementsproject/peerswap/swap" +) + +type testtxo struct{} + +// GetSwapID() swap.SwapId +// // Callback calls the callback function if the condition is match. +// // Returns true if the callback function is called. +// Callback(context.Context, BlocKHeight) (bool, error) + +func (t *testtxo) GetSwapID() swap.SwapId { + return swap.SwapId{} +} + +func (t *testtxo) Callback(context.Context, BlocKHeight) (bool, error) { + return true, nil +} + +// func TestConcurrentUpdate(t *testing.T) { +// observers := make([]TXObserver, 10) +// for i := range observers { +// observers[i] = &testtxo{} +// } + +// h := &liquidBlockHeaderSubscriber{ +// txObservers: observers, +// } + +// var wg sync.WaitGroup +// for i := 0; i < 10; i++ { +// wg.Add(1) +// go func() { +// defer wg.Done() +// err := h.Update(context.Background(), BlocKHeight(0)) +// if err != nil { +// t.Errorf("Unexpected error: %v", err) +// } +// }() +// } +// wg.Wait() + +// // Add assertions here to check the state of h.txObservers +// // For example, you might check the length of h.txObservers +// if len(h.txObservers) != 0 { +// t.Errorf("Expected length %d, but got %d", 0, len(h.txObservers)) +// } +// } diff --git a/lwk/electrumtxwatcher.go b/lwk/electrumtxwatcher.go index 92462146..3ded06db 100644 --- a/lwk/electrumtxwatcher.go +++ b/lwk/electrumtxwatcher.go @@ -3,6 +3,7 @@ package lwk import ( "context" "fmt" + "sync" "time" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -28,6 +29,7 @@ type electrumTxWatcher struct { // The connection with the electrum client is // disconnected after a certain period of time. resubscribeTicker *time.Ticker + mu sync.Mutex } func NewElectrumTxWatcher(electrumClient electrum.RPC) (*electrumTxWatcher, error) { @@ -59,7 +61,9 @@ func (r *electrumTxWatcher) StartWatchingTxs() error { if r.blockHeight.Confirmed() && blockHeader.Height <= int32(r.blockHeight.Height()) { continue } + r.mu.Lock() r.blockHeight = electrum.BlocKHeight(blockHeader.Height) + r.mu.Unlock() log.Infof("New block received. block height:%d", r.blockHeight) err = r.subscriber.Update(ctx, r.blockHeight) if err != nil { @@ -89,9 +93,12 @@ func (r *electrumTxWatcher) waitForInitialBlockHeaderSubscription(ctx context.Co log.Infof("Initial block header subscription timeout.") return ctx.Err() default: + r.mu.Lock() if r.blockHeight.Confirmed() { + r.mu.Unlock() return nil } + r.mu.Unlock() } time.Sleep(heartbeatInterval) } @@ -119,9 +126,13 @@ func (r *electrumTxWatcher) AddWaitForConfirmationTx(swapIDStr, txIDStr string, } func (r *electrumTxWatcher) AddConfirmationCallback(f func(swapId string, txHex string, err error) error) { + r.mu.Lock() + defer r.mu.Unlock() r.confirmationCallback = f } func (r *electrumTxWatcher) AddCsvCallback(f func(swapId string) error) { + r.mu.Lock() + defer r.mu.Unlock() r.csvCallback = f } diff --git a/lwk/electrumtxwatcher_test.go b/lwk/electrumtxwatcher_test.go index a05a925f..1f4091bc 100644 --- a/lwk/electrumtxwatcher_test.go +++ b/lwk/electrumtxwatcher_test.go @@ -110,7 +110,6 @@ func TestElectrumTxWatcher_Callback(t *testing.T) { r.AddCsvCallback( func(swapId string) error { assert.Equal(t, wantSwapID, swapId) - assert.NoError(t, err) callbackChan <- swapId return nil },