From 0ae9f1a72e1037a8be4b5f53f9664c0a4f8a7bf8 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 21 Jan 2025 01:19:02 +0800 Subject: [PATCH] sweep: add `requestID` to `monitorRecord` This way we can greatly simplify the method signatures, also paving the upcoming changes where we wanna make it clear when updating the monitorRecord, we only touch a portion of it. --- sweep/fee_bumper.go | 106 ++++++++++++++++++++++----------------- sweep/fee_bumper_test.go | 45 +++++++++++------ 2 files changed, 88 insertions(+), 63 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 3325b0c74a..7db98f99e7 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -410,34 +410,35 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult { lnutils.SpewLogClosure(req)) // Store the request. - requestID, record := t.storeInitialRecord(req) + record := t.storeInitialRecord(req) // Create a chan to send the result to the caller. subscriber := make(chan *BumpResult, 1) - t.subscriberChans.Store(requestID, subscriber) + t.subscriberChans.Store(record.requestID, subscriber) // Publish the tx immediately if specified. if req.Immediate { - t.handleInitialBroadcast(record, requestID) + t.handleInitialBroadcast(record) } return subscriber } // storeInitialRecord initializes a monitor record and saves it in the map. -func (t *TxPublisher) storeInitialRecord(req *BumpRequest) ( - uint64, *monitorRecord) { - +func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord { // Increase the request counter. // // NOTE: this is the only place where we increase the counter. requestID := t.requestCounter.Add(1) // Register the record. - record := &monitorRecord{req: req} + record := &monitorRecord{ + requestID: requestID, + req: req, + } t.records.Store(requestID, record) - return requestID, record + return record } // storeRecord stores the given record in the records map. @@ -446,6 +447,7 @@ func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx, // Register the record. t.records.Store(requestID, &monitorRecord{ + requestID: requestID, tx: sweepCtx.tx, req: req, feeFunction: f, @@ -461,16 +463,25 @@ func (t *TxPublisher) Name() string { // initializeTx initializes a fee function and creates an RBF-compliant tx. If // succeeded, the initial tx is stored in the records map. -func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error { +func (t *TxPublisher) initializeTx(r *monitorRecord) error { // Create a fee bumping algorithm to be used for future RBF. - feeAlgo, err := t.initializeFeeFunction(req) + feeAlgo, err := t.initializeFeeFunction(r.req) if err != nil { return fmt.Errorf("init fee function: %w", err) } + // Attach the newly created fee function. + // + // TODO(yy): current we'd initialize a monitorRecord before creating the + // fee function, while we could instead create the fee function first + // then save it to the record. To make this happen we need to change the + // conf target calculation below since we would be initializing the fee + // function one block before. + r.feeFunction = feeAlgo + // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. - err = t.createRBFCompliantTx(requestID, req, feeAlgo) + err = t.createRBFCompliantTx(r) if err != nil { return fmt.Errorf("create RBF-compliant tx: %w", err) } @@ -511,24 +522,24 @@ func (t *TxPublisher) initializeFeeFunction( // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // and redo the process until the tx is valid, or return an error when non-RBF // related errors occur or the budget has been used up. -func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest, - f FeeFunction) error { +func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error { + f := r.feeFunction for { // Create a new tx with the given fee rate and check its // mempool acceptance. - sweepCtx, err := t.createAndCheckTx(req, f) + sweepCtx, err := t.createAndCheckTx(r.req, f) switch { case err == nil: // The tx is valid, store it. - t.storeRecord(requestID, sweepCtx, req, f) + t.storeRecord(r.requestID, sweepCtx, r.req, f) log.Infof("Created initial sweep tx=%v for %v inputs: "+ "feerate=%v, fee=%v, inputs:\n%v", - sweepCtx.tx.TxHash(), len(req.Inputs), + sweepCtx.tx.TxHash(), len(r.req.Inputs), f.FeeRate(), sweepCtx.fee, - inputTypeSummary(req.Inputs)) + inputTypeSummary(r.req.Inputs)) return nil @@ -773,6 +784,9 @@ func (t *TxPublisher) handleResult(result *BumpResult) { // monitorRecord is used to keep track of the tx being monitored by the // publisher internally. type monitorRecord struct { + // requestID is the ID of the request that created this record. + requestID uint64 + // tx is the tx being monitored. tx *wire.MsgTx @@ -787,6 +801,10 @@ type monitorRecord struct { // outpointToTxIndex is a map of outpoint to tx index. outpointToTxIndex map[wire.OutPoint]int + + // spendNotifiers is a map of spend notifiers registered for all the + // inputs. + spendNotifiers map[wire.OutPoint]*chainntnfs.SpendEvent } // Start starts the publisher by subscribing to block epoch updates and kicking @@ -915,35 +933,35 @@ func (t *TxPublisher) processRecords() { t.records.ForEach(visitor) // Handle the initial broadcast. - for requestID, r := range initialRecords { - t.handleInitialBroadcast(r, requestID) + for _, r := range initialRecords { + t.handleInitialBroadcast(r) } // For records that are confirmed, we'll notify the caller about this // result. - for requestID, r := range confirmedRecords { + for _, r := range confirmedRecords { log.Debugf("Tx=%v is confirmed", r.tx.TxHash()) t.wg.Add(1) - go t.handleTxConfirmed(r, requestID) + go t.handleTxConfirmed(r) } // Get the current height to be used in the following goroutines. currentHeight := t.currentHeight.Load() // For records that are not confirmed, we perform a fee bump if needed. - for requestID, r := range feeBumpRecords { + for _, r := range feeBumpRecords { log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash()) t.wg.Add(1) - go t.handleFeeBumpTx(requestID, r, currentHeight) + go t.handleFeeBumpTx(r, currentHeight) } // For records that are failed, we'll notify the caller about this // result. - for requestID, r := range failedRecords { + for _, r := range failedRecords { log.Debugf("Tx=%v has inputs been spent by a third party, "+ "failing it now", r.tx.TxHash()) t.wg.Add(1) - go t.handleThirdPartySpent(r, requestID) + go t.handleThirdPartySpent(r) } } @@ -951,7 +969,7 @@ func (t *TxPublisher) processRecords() { // notify the subscriber then remove the record from the maps . // // NOTE: Must be run as a goroutine to avoid blocking on sending the result. -func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { +func (t *TxPublisher) handleTxConfirmed(r *monitorRecord) { defer t.wg.Done() // Create a result that will be sent to the resultChan which is @@ -959,7 +977,7 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { result := &BumpResult{ Event: TxConfirmed, Tx: r.tx, - requestID: requestID, + requestID: r.requestID, Fee: r.fee, FeeRate: r.feeFunction.FeeRate(), } @@ -1017,10 +1035,8 @@ func (t *TxPublisher) handleInitialTxError(requestID uint64, err error) { // 1. init a fee function based on the given strategy. // 2. create an RBF-compliant tx and monitor it for confirmation. // 3. notify the initial broadcast result back to the caller. -func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, - requestID uint64) { - - log.Debugf("Initial broadcast for requestID=%v", requestID) +func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) { + log.Debugf("Initial broadcast for requestID=%v", r.requestID) var ( result *BumpResult @@ -1031,18 +1047,18 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, // RBF rules. // // Create the initial tx to be broadcasted. - err = t.initializeTx(requestID, r.req) + err = t.initializeTx(r) if err != nil { log.Errorf("Initial broadcast failed: %v", err) // We now handle the initialization error and exit. - t.handleInitialTxError(requestID, err) + t.handleInitialTxError(r.requestID, err) return } // Successfully created the first tx, now broadcast it. - result, err = t.broadcast(requestID) + result, err = t.broadcast(r.requestID) if err != nil { // The broadcast failed, which can only happen if the tx record // cannot be found or the aux sweeper returns an error. In @@ -1051,7 +1067,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, result = &BumpResult{ Event: TxFailed, Err: err, - requestID: requestID, + requestID: r.requestID, } } @@ -1062,9 +1078,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, // attempt to bump the fee of the tx. // // NOTE: Must be run as a goroutine to avoid blocking on sending the result. -func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, - currentHeight int32) { - +func (t *TxPublisher) handleFeeBumpTx(r *monitorRecord, currentHeight int32) { defer t.wg.Done() oldTxid := r.tx.TxHash() @@ -1095,7 +1109,7 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, // The fee function now has a new fee rate, we will use it to bump the // fee of the tx. - resultOpt := t.createAndPublishTx(requestID, r) + resultOpt := t.createAndPublishTx(r) // If there's a result, we will notify the caller about the result. resultOpt.WhenSome(func(result BumpResult) { @@ -1109,9 +1123,7 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, // and send a TxFailed event to the subscriber. // // NOTE: Must be run as a goroutine to avoid blocking on sending the result. -func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord, - requestID uint64) { - +func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord) { defer t.wg.Done() // Create a result that will be sent to the resultChan which is @@ -1123,7 +1135,7 @@ func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord, result := &BumpResult{ Event: TxFailed, Tx: r.tx, - requestID: requestID, + requestID: r.requestID, Err: ErrThirdPartySpent, } @@ -1134,7 +1146,7 @@ func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord, // createAndPublishTx creates a new tx with a higher fee rate and publishes it // to the network. It will update the record with the new tx and fee rate if // successfully created, and return the result when published successfully. -func (t *TxPublisher) createAndPublishTx(requestID uint64, +func (t *TxPublisher) createAndPublishTx( r *monitorRecord) fn.Option[BumpResult] { // Fetch the old tx. @@ -1185,16 +1197,16 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, Event: TxFailed, Tx: oldTx, Err: err, - requestID: requestID, + requestID: r.requestID, }) } // The tx has been created without any errors, we now register a new // record by overwriting the same requestID. - t.storeRecord(requestID, sweepCtx, r.req, r.feeFunction) + t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction) // Attempt to broadcast this new tx. - result, err := t.broadcast(requestID) + result, err := t.broadcast(r.requestID) if err != nil { log.Infof("Failed to broadcast replacement tx %v: %v", sweepCtx.tx.TxHash(), err) diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 0ce83301a7..527e259538 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -664,11 +664,19 @@ func TestCreateRBFCompliantTx(t *testing.T) { tc := tc rid := requestCounter.Add(1) + + // Create a test record. + record := &monitorRecord{ + requestID: rid, + req: req, + feeFunction: m.feeFunc, + } + t.Run(tc.name, func(t *testing.T) { tc.setupMock() // Call the method under test. - err := tp.createRBFCompliantTx(rid, req, m.feeFunc) + err := tp.createRBFCompliantTx(record) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) @@ -1082,6 +1090,7 @@ func TestCreateAnPublishFail(t *testing.T) { // Overwrite the budget to make it smaller than the fee. req.Budget = 100 record := &monitorRecord{ + requestID: requestID, req: req, feeFunction: m.feeFunc, tx: &wire.MsgTx{}, @@ -1097,7 +1106,7 @@ func TestCreateAnPublishFail(t *testing.T) { mock.Anything).Return(script, nil) // Call the createAndPublish method. - resultOpt := tp.createAndPublishTx(requestID, record) + resultOpt := tp.createAndPublishTx(record) result := resultOpt.UnwrapOrFail(t) // We expect the result to be TxFailed and the error is set in the @@ -1116,7 +1125,7 @@ func TestCreateAnPublishFail(t *testing.T) { mock.Anything).Return(lnwallet.ErrMempoolFee).Once() // Call the createAndPublish method and expect a none option. - resultOpt = tp.createAndPublishTx(requestID, record) + resultOpt = tp.createAndPublishTx(record) require.True(t, resultOpt.IsNone()) // Mock the testmempoolaccept to return a fee related error that should @@ -1125,7 +1134,7 @@ func TestCreateAnPublishFail(t *testing.T) { mock.Anything).Return(chain.ErrInsufficientFee).Once() // Call the createAndPublish method and expect a none option. - resultOpt = tp.createAndPublishTx(requestID, record) + resultOpt = tp.createAndPublishTx(record) require.True(t, resultOpt.IsNone()) } @@ -1147,6 +1156,7 @@ func TestCreateAnPublishSuccess(t *testing.T) { // Create a testing monitor record. req := createTestBumpRequest() record := &monitorRecord{ + requestID: requestID, req: req, feeFunction: m.feeFunc, tx: &wire.MsgTx{}, @@ -1169,7 +1179,7 @@ func TestCreateAnPublishSuccess(t *testing.T) { mock.Anything, mock.Anything).Return(errDummy).Once() // Call the createAndPublish method and expect a failure result. - resultOpt := tp.createAndPublishTx(requestID, record) + resultOpt := tp.createAndPublishTx(record) result := resultOpt.UnwrapOrFail(t) // We expect the result to be TxFailed and the error is set. @@ -1190,7 +1200,7 @@ func TestCreateAnPublishSuccess(t *testing.T) { mock.Anything, mock.Anything).Return(nil).Once() // Call the createAndPublish method and expect a success result. - resultOpt = tp.createAndPublishTx(requestID, record) + resultOpt = tp.createAndPublishTx(record) result = resultOpt.UnwrapOrFail(t) require.True(t, resultOpt.IsSome()) @@ -1258,7 +1268,7 @@ func TestHandleTxConfirmed(t *testing.T) { tp.wg.Add(1) done := make(chan struct{}) go func() { - tp.handleTxConfirmed(record, requestID) + tp.handleTxConfirmed(record) close(done) }() @@ -1304,7 +1314,11 @@ func TestHandleFeeBumpTx(t *testing.T) { // Create a testing monitor record. req := createTestBumpRequest() + + // Create a testing record and put it in the map. + requestID := uint64(1) record := &monitorRecord{ + requestID: requestID, req: req, feeFunction: m.feeFunc, tx: tx, @@ -1317,10 +1331,7 @@ func TestHandleFeeBumpTx(t *testing.T) { utxoIndex := map[wire.OutPoint]int{ op: 0, } - - // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := uint64(1) // Create a sweepTxCtx. sweepCtx := &sweepTxCtx{ @@ -1345,7 +1356,7 @@ func TestHandleFeeBumpTx(t *testing.T) { // Call the method and expect no result received. tp.wg.Add(1) - go tp.handleFeeBumpTx(requestID, record, testHeight) + go tp.handleFeeBumpTx(record, testHeight) // Check there's no result sent back. select { @@ -1359,7 +1370,7 @@ func TestHandleFeeBumpTx(t *testing.T) { // Call the method and expect no result received. tp.wg.Add(1) - go tp.handleFeeBumpTx(requestID, record, testHeight) + go tp.handleFeeBumpTx(record, testHeight) // Check there's no result sent back. select { @@ -1391,7 +1402,7 @@ func TestHandleFeeBumpTx(t *testing.T) { // // NOTE: must be called in a goroutine in case it blocks. tp.wg.Add(1) - go tp.handleFeeBumpTx(requestID, record, testHeight) + go tp.handleFeeBumpTx(record, testHeight) select { case <-time.After(time.Second): @@ -1437,6 +1448,7 @@ func TestProcessRecords(t *testing.T) { // Create a monitor record that's confirmed. recordConfirmed := &monitorRecord{ + requestID: requestID1, req: req1, feeFunction: m.feeFunc, tx: tx1, @@ -1450,6 +1462,7 @@ func TestProcessRecords(t *testing.T) { // Create a monitor record that's not confirmed. We know it's not // confirmed because the num of confirms is zero. recordFeeBump := &monitorRecord{ + requestID: requestID2, req: req2, feeFunction: m.feeFunc, tx: tx2, @@ -1588,7 +1601,7 @@ func TestHandleInitialBroadcastSuccess(t *testing.T) { // Call the method under test. tp.wg.Add(1) - tp.handleInitialBroadcast(rec, rid) + tp.handleInitialBroadcast(rec) // Check the result is sent back. select { @@ -1659,7 +1672,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) { // Call the method under test and expect an error returned. tp.wg.Add(1) - tp.handleInitialBroadcast(rec, rid) + tp.handleInitialBroadcast(rec) // Check the result is sent back. select { @@ -1692,7 +1705,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) { // Call the method under test. tp.wg.Add(1) - tp.handleInitialBroadcast(rec, rid) + tp.handleInitialBroadcast(rec) // Check the result is sent back. select {