Skip to content

Commit

Permalink
sweep: refactor storeRecord to updateRecord
Browse files Browse the repository at this point in the history
To make it clear we are only updating fields, which will be handy for
the following commit where we start tracking for spending notifications.
  • Loading branch information
yyforyongyu committed Jan 27, 2025
1 parent 0ae9f1a commit ed4c7f7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 55 deletions.
66 changes: 29 additions & 37 deletions sweep/fee_bumper.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,19 +441,19 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
return record
}

// storeRecord stores the given record in the records map.
func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,
req *BumpRequest, f FeeFunction) {
// updateRecord updates the given record's tx and fee, and saves it in the
// records map.
func (t *TxPublisher) updateRecord(r *monitorRecord,
sweepCtx *sweepTxCtx) *monitorRecord {

r.tx = sweepCtx.tx
r.fee = sweepCtx.fee
r.outpointToTxIndex = sweepCtx.outpointToTxIndex

// Register the record.
t.records.Store(requestID, &monitorRecord{
requestID: requestID,
tx: sweepCtx.tx,
req: req,
feeFunction: f,
fee: sweepCtx.fee,
outpointToTxIndex: sweepCtx.outpointToTxIndex,
})
t.records.Store(r.requestID, r)

return r
}

// NOTE: part of the `chainio.Consumer` interface.
Expand All @@ -463,11 +463,11 @@ func (t *TxPublisher) Name() string {

// initializeTx initializes a fee function and creates an RBF-compliant tx. If
// succeeded, the initial tx is stored in the records map.
func (t *TxPublisher) initializeTx(r *monitorRecord) error {
func (t *TxPublisher) initializeTx(r *monitorRecord) (*monitorRecord, error) {
// Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(r.req)
if err != nil {
return fmt.Errorf("init fee function: %w", err)
return nil, fmt.Errorf("init fee function: %w", err)
}

// Attach the newly created fee function.
Expand All @@ -481,12 +481,12 @@ func (t *TxPublisher) initializeTx(r *monitorRecord) error {

// Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions.
err = t.createRBFCompliantTx(r)
record, err := t.createRBFCompliantTx(r)
if err != nil {
return fmt.Errorf("create RBF-compliant tx: %w", err)
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
}

return nil
return record, nil
}

// initializeFeeFunction initializes a fee function to be used for this request
Expand Down Expand Up @@ -522,7 +522,9 @@ func (t *TxPublisher) initializeFeeFunction(
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
// and redo the process until the tx is valid, or return an error when non-RBF
// related errors occur or the budget has been used up.
func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
func (t *TxPublisher) createRBFCompliantTx(
r *monitorRecord) (*monitorRecord, error) {

f := r.feeFunction

for {
Expand All @@ -533,15 +535,15 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
switch {
case err == nil:
// The tx is valid, store it.
t.storeRecord(r.requestID, sweepCtx, r.req, f)
record := t.updateRecord(r, sweepCtx)

log.Infof("Created initial sweep tx=%v for %v inputs: "+
"feerate=%v, fee=%v, inputs:\n%v",
sweepCtx.tx.TxHash(), len(r.req.Inputs),
f.FeeRate(), sweepCtx.fee,
inputTypeSummary(r.req.Inputs))

return nil
return record, nil

// If the error indicates the fees paid is not enough, we will
// ask the fee function to increase the fee rate and retry.
Expand Down Expand Up @@ -572,7 +574,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
// cluster these inputs differetly.
increased, err = f.Increment()
if err != nil {
return err
return nil, err
}
}

Expand All @@ -582,7 +584,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
// mempool acceptance.
default:
log.Debugf("Failed to create RBF-compliant tx: %v", err)
return err
return nil, err
}
}
}
Expand Down Expand Up @@ -645,13 +647,7 @@ func (t *TxPublisher) createAndCheckTx(req *BumpRequest,
// the event channel to the record. Any broadcast-related errors will not be
// returned here, instead, they will be put inside the `BumpResult` and
// returned to the caller.
func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
// Get the record being monitored.
record, ok := t.records.Load(requestID)
if !ok {
return nil, fmt.Errorf("tx record %v not found", requestID)
}

func (t *TxPublisher) broadcast(record *monitorRecord) (*BumpResult, error) {
txid := record.tx.TxHash()

tx := record.tx
Expand Down Expand Up @@ -698,7 +694,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
Fee: record.fee,
FeeRate: record.feeFunction.FeeRate(),
Err: err,
requestID: requestID,
requestID: record.requestID,
}

return result, nil
Expand Down Expand Up @@ -801,10 +797,6 @@ type monitorRecord struct {

// outpointToTxIndex is a map of outpoint to tx index.
outpointToTxIndex map[wire.OutPoint]int

// spendNotifiers is a map of spend notifiers registered for all the
// inputs.
spendNotifiers map[wire.OutPoint]*chainntnfs.SpendEvent
}

// Start starts the publisher by subscribing to block epoch updates and kicking
Expand Down Expand Up @@ -1047,7 +1039,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
// RBF rules.
//
// Create the initial tx to be broadcasted.
err = t.initializeTx(r)
record, err := t.initializeTx(r)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)

Expand All @@ -1058,7 +1050,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
}

// Successfully created the first tx, now broadcast it.
result, err = t.broadcast(r.requestID)
result, err = t.broadcast(record)
if err != nil {
// The broadcast failed, which can only happen if the tx record
// cannot be found or the aux sweeper returns an error. In
Expand Down Expand Up @@ -1203,10 +1195,10 @@ func (t *TxPublisher) createAndPublishTx(

// The tx has been created without any errors, we now register a new
// record by overwriting the same requestID.
t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction)
record := t.updateRecord(r, sweepCtx)

// Attempt to broadcast this new tx.
result, err := t.broadcast(r.requestID)
result, err := t.broadcast(record)
if err != nil {
log.Infof("Failed to broadcast replacement tx %v: %v",
sweepCtx.tx.TxHash(), err)
Expand Down
89 changes: 71 additions & 18 deletions sweep/fee_bumper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,9 @@ func TestInitializeFeeFunction(t *testing.T) {
require.Equal(t, feerate, f.FeeRate())
}

// TestStoreRecord correctly increases the request counter and saves the
// TestUpdateRecord correctly updates the fields fee and tx, and saves the
// record.
func TestStoreRecord(t *testing.T) {
func TestUpdateRecord(t *testing.T) {
t.Parallel()

// Create a test input.
Expand Down Expand Up @@ -358,8 +358,15 @@ func TestStoreRecord(t *testing.T) {
outpointToTxIndex: utxoIndex,
}

// Create a test record.
record := &monitorRecord{
requestID: initialCounter,
req: req,
feeFunction: feeFunc,
}

// Call the method under test.
tp.storeRecord(initialCounter, sweepCtx, req, feeFunc)
tp.updateRecord(record, sweepCtx)

// Read the saved record and compare.
record, ok := tp.records.Load(initialCounter)
Expand Down Expand Up @@ -676,10 +683,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
tc.setupMock()

// Call the method under test.
err := tp.createRBFCompliantTx(record)
rec, err := tp.createRBFCompliantTx(record)

// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)

if tc.expectedErr != nil {
return
}

// Assert the returned record has the following fields
// populated.
require.NotEmpty(t, rec.tx)
require.NotEmpty(t, rec.fee)
})
}
}
Expand Down Expand Up @@ -721,13 +737,13 @@ func TestTxPublisherBroadcast(t *testing.T) {
outpointToTxIndex: utxoIndex,
}

tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)

// Quickly check when the requestID cannot be found, an error is
// returned.
result, err := tp.broadcast(uint64(1000))
require.Error(t, err)
require.Nil(t, result)
// Create a test record.
record := &monitorRecord{
requestID: requestID,
req: req,
feeFunction: m.feeFunc,
}
rec := tp.updateRecord(record, sweepCtx)

testCases := []struct {
name string
Expand Down Expand Up @@ -782,7 +798,7 @@ func TestTxPublisherBroadcast(t *testing.T) {
tc.setupMock()

// Call the method under test.
result, err := tp.broadcast(requestID)
result, err := tp.broadcast(rec)

// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
Expand Down Expand Up @@ -838,7 +854,15 @@ func TestRemoveResult(t *testing.T) {
name: "remove on TxConfirmed",
setupRecord: func() uint64 {
rid := requestCounter.Add(1)
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)

// Create a test record.
record := &monitorRecord{
requestID: rid,
req: req,
feeFunction: m.feeFunc,
}

tp.updateRecord(record, sweepCtx)
tp.subscriberChans.Store(rid, nil)

return rid
Expand All @@ -854,7 +878,15 @@ func TestRemoveResult(t *testing.T) {
name: "remove on TxFailed",
setupRecord: func() uint64 {
rid := requestCounter.Add(1)
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)

// Create a test record.
record := &monitorRecord{
requestID: rid,
req: req,
feeFunction: m.feeFunc,
}

tp.updateRecord(record, sweepCtx)
tp.subscriberChans.Store(rid, nil)

return rid
Expand All @@ -871,7 +903,15 @@ func TestRemoveResult(t *testing.T) {
name: "noop when tx is not confirmed or failed",
setupRecord: func() uint64 {
rid := requestCounter.Add(1)
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)

// Create a test record.
record := &monitorRecord{
requestID: rid,
req: req,
feeFunction: m.feeFunc,
}

tp.updateRecord(record, sweepCtx)
tp.subscriberChans.Store(rid, nil)

return rid
Expand Down Expand Up @@ -937,8 +977,14 @@ func TestNotifyResult(t *testing.T) {
fee: fee,
outpointToTxIndex: utxoIndex,
}
// Create a test record.
record := &monitorRecord{
requestID: requestID,
req: req,
feeFunction: m.feeFunc,
}

tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
tp.updateRecord(record, sweepCtx)

// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
Expand Down Expand Up @@ -1250,7 +1296,14 @@ func TestHandleTxConfirmed(t *testing.T) {
outpointToTxIndex: utxoIndex,
}

tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
// Create a test record.
record := &monitorRecord{
requestID: requestID,
req: req,
feeFunction: m.feeFunc,
}

tp.updateRecord(record, sweepCtx)
record, ok := tp.records.Load(requestID)
require.True(t, ok)

Expand Down Expand Up @@ -1340,7 +1393,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
outpointToTxIndex: utxoIndex,
}

tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
tp.updateRecord(record, sweepCtx)

// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
Expand Down

0 comments on commit ed4c7f7

Please sign in to comment.