From ed4c7f772bd84c3a58cc35c3fb929bdf27643c9a Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 21 Jan 2025 01:50:58 +0800 Subject: [PATCH] sweep: refactor `storeRecord` to `updateRecord` To make it clear we are only updating fields, which will be handy for the following commit where we start tracking for spending notifications. --- sweep/fee_bumper.go | 66 +++++++++++++---------------- sweep/fee_bumper_test.go | 89 ++++++++++++++++++++++++++++++++-------- 2 files changed, 100 insertions(+), 55 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 7db98f99e7..620ad7554b 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -441,19 +441,19 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord { return record } -// storeRecord stores the given record in the records map. -func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx, - req *BumpRequest, f FeeFunction) { +// updateRecord updates the given record's tx and fee, and saves it in the +// records map. +func (t *TxPublisher) updateRecord(r *monitorRecord, + sweepCtx *sweepTxCtx) *monitorRecord { + + r.tx = sweepCtx.tx + r.fee = sweepCtx.fee + r.outpointToTxIndex = sweepCtx.outpointToTxIndex // Register the record. - t.records.Store(requestID, &monitorRecord{ - requestID: requestID, - tx: sweepCtx.tx, - req: req, - feeFunction: f, - fee: sweepCtx.fee, - outpointToTxIndex: sweepCtx.outpointToTxIndex, - }) + t.records.Store(r.requestID, r) + + return r } // NOTE: part of the `chainio.Consumer` interface. @@ -463,11 +463,11 @@ func (t *TxPublisher) Name() string { // initializeTx initializes a fee function and creates an RBF-compliant tx. If // succeeded, the initial tx is stored in the records map. -func (t *TxPublisher) initializeTx(r *monitorRecord) error { +func (t *TxPublisher) initializeTx(r *monitorRecord) (*monitorRecord, error) { // Create a fee bumping algorithm to be used for future RBF. feeAlgo, err := t.initializeFeeFunction(r.req) if err != nil { - return fmt.Errorf("init fee function: %w", err) + return nil, fmt.Errorf("init fee function: %w", err) } // Attach the newly created fee function. @@ -481,12 +481,12 @@ func (t *TxPublisher) initializeTx(r *monitorRecord) error { // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. - err = t.createRBFCompliantTx(r) + record, err := t.createRBFCompliantTx(r) if err != nil { - return fmt.Errorf("create RBF-compliant tx: %w", err) + return nil, fmt.Errorf("create RBF-compliant tx: %w", err) } - return nil + return record, nil } // initializeFeeFunction initializes a fee function to be used for this request @@ -522,7 +522,9 @@ func (t *TxPublisher) initializeFeeFunction( // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // and redo the process until the tx is valid, or return an error when non-RBF // related errors occur or the budget has been used up. -func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error { +func (t *TxPublisher) createRBFCompliantTx( + r *monitorRecord) (*monitorRecord, error) { + f := r.feeFunction for { @@ -533,7 +535,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error { switch { case err == nil: // The tx is valid, store it. - t.storeRecord(r.requestID, sweepCtx, r.req, f) + record := t.updateRecord(r, sweepCtx) log.Infof("Created initial sweep tx=%v for %v inputs: "+ "feerate=%v, fee=%v, inputs:\n%v", @@ -541,7 +543,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error { f.FeeRate(), sweepCtx.fee, inputTypeSummary(r.req.Inputs)) - return nil + return record, nil // If the error indicates the fees paid is not enough, we will // ask the fee function to increase the fee rate and retry. @@ -572,7 +574,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error { // cluster these inputs differetly. increased, err = f.Increment() if err != nil { - return err + return nil, err } } @@ -582,7 +584,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error { // mempool acceptance. default: log.Debugf("Failed to create RBF-compliant tx: %v", err) - return err + return nil, err } } } @@ -645,13 +647,7 @@ func (t *TxPublisher) createAndCheckTx(req *BumpRequest, // the event channel to the record. Any broadcast-related errors will not be // returned here, instead, they will be put inside the `BumpResult` and // returned to the caller. -func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { - // Get the record being monitored. - record, ok := t.records.Load(requestID) - if !ok { - return nil, fmt.Errorf("tx record %v not found", requestID) - } - +func (t *TxPublisher) broadcast(record *monitorRecord) (*BumpResult, error) { txid := record.tx.TxHash() tx := record.tx @@ -698,7 +694,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { Fee: record.fee, FeeRate: record.feeFunction.FeeRate(), Err: err, - requestID: requestID, + requestID: record.requestID, } return result, nil @@ -801,10 +797,6 @@ type monitorRecord struct { // outpointToTxIndex is a map of outpoint to tx index. outpointToTxIndex map[wire.OutPoint]int - - // spendNotifiers is a map of spend notifiers registered for all the - // inputs. - spendNotifiers map[wire.OutPoint]*chainntnfs.SpendEvent } // Start starts the publisher by subscribing to block epoch updates and kicking @@ -1047,7 +1039,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) { // RBF rules. // // Create the initial tx to be broadcasted. - err = t.initializeTx(r) + record, err := t.initializeTx(r) if err != nil { log.Errorf("Initial broadcast failed: %v", err) @@ -1058,7 +1050,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) { } // Successfully created the first tx, now broadcast it. - result, err = t.broadcast(r.requestID) + result, err = t.broadcast(record) if err != nil { // The broadcast failed, which can only happen if the tx record // cannot be found or the aux sweeper returns an error. In @@ -1203,10 +1195,10 @@ func (t *TxPublisher) createAndPublishTx( // The tx has been created without any errors, we now register a new // record by overwriting the same requestID. - t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction) + record := t.updateRecord(r, sweepCtx) // Attempt to broadcast this new tx. - result, err := t.broadcast(r.requestID) + result, err := t.broadcast(record) if err != nil { log.Infof("Failed to broadcast replacement tx %v: %v", sweepCtx.tx.TxHash(), err) diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 527e259538..0531dec8d9 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -313,9 +313,9 @@ func TestInitializeFeeFunction(t *testing.T) { require.Equal(t, feerate, f.FeeRate()) } -// TestStoreRecord correctly increases the request counter and saves the +// TestUpdateRecord correctly updates the fields fee and tx, and saves the // record. -func TestStoreRecord(t *testing.T) { +func TestUpdateRecord(t *testing.T) { t.Parallel() // Create a test input. @@ -358,8 +358,15 @@ func TestStoreRecord(t *testing.T) { outpointToTxIndex: utxoIndex, } + // Create a test record. + record := &monitorRecord{ + requestID: initialCounter, + req: req, + feeFunction: feeFunc, + } + // Call the method under test. - tp.storeRecord(initialCounter, sweepCtx, req, feeFunc) + tp.updateRecord(record, sweepCtx) // Read the saved record and compare. record, ok := tp.records.Load(initialCounter) @@ -676,10 +683,19 @@ func TestCreateRBFCompliantTx(t *testing.T) { tc.setupMock() // Call the method under test. - err := tp.createRBFCompliantTx(record) + rec, err := tp.createRBFCompliantTx(record) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) + + if tc.expectedErr != nil { + return + } + + // Assert the returned record has the following fields + // populated. + require.NotEmpty(t, rec.tx) + require.NotEmpty(t, rec.fee) }) } } @@ -721,13 +737,13 @@ func TestTxPublisherBroadcast(t *testing.T) { outpointToTxIndex: utxoIndex, } - tp.storeRecord(requestID, sweepCtx, req, m.feeFunc) - - // Quickly check when the requestID cannot be found, an error is - // returned. - result, err := tp.broadcast(uint64(1000)) - require.Error(t, err) - require.Nil(t, result) + // Create a test record. + record := &monitorRecord{ + requestID: requestID, + req: req, + feeFunction: m.feeFunc, + } + rec := tp.updateRecord(record, sweepCtx) testCases := []struct { name string @@ -782,7 +798,7 @@ func TestTxPublisherBroadcast(t *testing.T) { tc.setupMock() // Call the method under test. - result, err := tp.broadcast(requestID) + result, err := tp.broadcast(rec) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) @@ -838,7 +854,15 @@ func TestRemoveResult(t *testing.T) { name: "remove on TxConfirmed", setupRecord: func() uint64 { rid := requestCounter.Add(1) - tp.storeRecord(rid, sweepCtx, req, m.feeFunc) + + // Create a test record. + record := &monitorRecord{ + requestID: rid, + req: req, + feeFunction: m.feeFunc, + } + + tp.updateRecord(record, sweepCtx) tp.subscriberChans.Store(rid, nil) return rid @@ -854,7 +878,15 @@ func TestRemoveResult(t *testing.T) { name: "remove on TxFailed", setupRecord: func() uint64 { rid := requestCounter.Add(1) - tp.storeRecord(rid, sweepCtx, req, m.feeFunc) + + // Create a test record. + record := &monitorRecord{ + requestID: rid, + req: req, + feeFunction: m.feeFunc, + } + + tp.updateRecord(record, sweepCtx) tp.subscriberChans.Store(rid, nil) return rid @@ -871,7 +903,15 @@ func TestRemoveResult(t *testing.T) { name: "noop when tx is not confirmed or failed", setupRecord: func() uint64 { rid := requestCounter.Add(1) - tp.storeRecord(rid, sweepCtx, req, m.feeFunc) + + // Create a test record. + record := &monitorRecord{ + requestID: rid, + req: req, + feeFunction: m.feeFunc, + } + + tp.updateRecord(record, sweepCtx) tp.subscriberChans.Store(rid, nil) return rid @@ -937,8 +977,14 @@ func TestNotifyResult(t *testing.T) { fee: fee, outpointToTxIndex: utxoIndex, } + // Create a test record. + record := &monitorRecord{ + requestID: requestID, + req: req, + feeFunction: m.feeFunc, + } - tp.storeRecord(requestID, sweepCtx, req, m.feeFunc) + tp.updateRecord(record, sweepCtx) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -1250,7 +1296,14 @@ func TestHandleTxConfirmed(t *testing.T) { outpointToTxIndex: utxoIndex, } - tp.storeRecord(requestID, sweepCtx, req, m.feeFunc) + // Create a test record. + record := &monitorRecord{ + requestID: requestID, + req: req, + feeFunction: m.feeFunc, + } + + tp.updateRecord(record, sweepCtx) record, ok := tp.records.Load(requestID) require.True(t, ok) @@ -1340,7 +1393,7 @@ func TestHandleFeeBumpTx(t *testing.T) { outpointToTxIndex: utxoIndex, } - tp.storeRecord(requestID, sweepCtx, req, m.feeFunc) + tp.updateRecord(record, sweepCtx) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1)