Skip to content

Commit

Permalink
sweep: add requestID to monitorRecord
Browse files Browse the repository at this point in the history
This way we can greatly simplify the method signatures, also paving the
upcoming changes where we wanna make it clear when updating the
monitorRecord, we only touch a portion of it.
  • Loading branch information
yyforyongyu committed Jan 27, 2025
1 parent 0ac5bfa commit 0ae9f1a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 63 deletions.
106 changes: 59 additions & 47 deletions sweep/fee_bumper.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,34 +410,35 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult {
lnutils.SpewLogClosure(req))

// Store the request.
requestID, record := t.storeInitialRecord(req)
record := t.storeInitialRecord(req)

// Create a chan to send the result to the caller.
subscriber := make(chan *BumpResult, 1)
t.subscriberChans.Store(requestID, subscriber)
t.subscriberChans.Store(record.requestID, subscriber)

// Publish the tx immediately if specified.
if req.Immediate {
t.handleInitialBroadcast(record, requestID)
t.handleInitialBroadcast(record)
}

return subscriber
}

// storeInitialRecord initializes a monitor record and saves it in the map.
func (t *TxPublisher) storeInitialRecord(req *BumpRequest) (
uint64, *monitorRecord) {

func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
// Increase the request counter.
//
// NOTE: this is the only place where we increase the counter.
requestID := t.requestCounter.Add(1)

// Register the record.
record := &monitorRecord{req: req}
record := &monitorRecord{
requestID: requestID,
req: req,
}
t.records.Store(requestID, record)

return requestID, record
return record
}

// storeRecord stores the given record in the records map.
Expand All @@ -446,6 +447,7 @@ func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,

// Register the record.
t.records.Store(requestID, &monitorRecord{
requestID: requestID,
tx: sweepCtx.tx,
req: req,
feeFunction: f,
Expand All @@ -461,16 +463,25 @@ func (t *TxPublisher) Name() string {

// initializeTx initializes a fee function and creates an RBF-compliant tx. If
// succeeded, the initial tx is stored in the records map.
func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error {
func (t *TxPublisher) initializeTx(r *monitorRecord) error {
// Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(req)
feeAlgo, err := t.initializeFeeFunction(r.req)
if err != nil {
return fmt.Errorf("init fee function: %w", err)
}

// Attach the newly created fee function.
//
// TODO(yy): current we'd initialize a monitorRecord before creating the
// fee function, while we could instead create the fee function first
// then save it to the record. To make this happen we need to change the
// conf target calculation below since we would be initializing the fee
// function one block before.
r.feeFunction = feeAlgo

// Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions.
err = t.createRBFCompliantTx(requestID, req, feeAlgo)
err = t.createRBFCompliantTx(r)
if err != nil {
return fmt.Errorf("create RBF-compliant tx: %w", err)
}
Expand Down Expand Up @@ -511,24 +522,24 @@ func (t *TxPublisher) initializeFeeFunction(
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
// and redo the process until the tx is valid, or return an error when non-RBF
// related errors occur or the budget has been used up.
func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
f FeeFunction) error {
func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
f := r.feeFunction

for {
// Create a new tx with the given fee rate and check its
// mempool acceptance.
sweepCtx, err := t.createAndCheckTx(req, f)
sweepCtx, err := t.createAndCheckTx(r.req, f)

switch {
case err == nil:
// The tx is valid, store it.
t.storeRecord(requestID, sweepCtx, req, f)
t.storeRecord(r.requestID, sweepCtx, r.req, f)

log.Infof("Created initial sweep tx=%v for %v inputs: "+
"feerate=%v, fee=%v, inputs:\n%v",
sweepCtx.tx.TxHash(), len(req.Inputs),
sweepCtx.tx.TxHash(), len(r.req.Inputs),
f.FeeRate(), sweepCtx.fee,
inputTypeSummary(req.Inputs))
inputTypeSummary(r.req.Inputs))

return nil

Expand Down Expand Up @@ -773,6 +784,9 @@ func (t *TxPublisher) handleResult(result *BumpResult) {
// monitorRecord is used to keep track of the tx being monitored by the
// publisher internally.
type monitorRecord struct {
// requestID is the ID of the request that created this record.
requestID uint64

// tx is the tx being monitored.
tx *wire.MsgTx

Expand All @@ -787,6 +801,10 @@ type monitorRecord struct {

// outpointToTxIndex is a map of outpoint to tx index.
outpointToTxIndex map[wire.OutPoint]int

// spendNotifiers is a map of spend notifiers registered for all the
// inputs.
spendNotifiers map[wire.OutPoint]*chainntnfs.SpendEvent
}

// Start starts the publisher by subscribing to block epoch updates and kicking
Expand Down Expand Up @@ -915,51 +933,51 @@ func (t *TxPublisher) processRecords() {
t.records.ForEach(visitor)

// Handle the initial broadcast.
for requestID, r := range initialRecords {
t.handleInitialBroadcast(r, requestID)
for _, r := range initialRecords {
t.handleInitialBroadcast(r)
}

// For records that are confirmed, we'll notify the caller about this
// result.
for requestID, r := range confirmedRecords {
for _, r := range confirmedRecords {
log.Debugf("Tx=%v is confirmed", r.tx.TxHash())
t.wg.Add(1)
go t.handleTxConfirmed(r, requestID)
go t.handleTxConfirmed(r)
}

// Get the current height to be used in the following goroutines.
currentHeight := t.currentHeight.Load()

// For records that are not confirmed, we perform a fee bump if needed.
for requestID, r := range feeBumpRecords {
for _, r := range feeBumpRecords {
log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash())
t.wg.Add(1)
go t.handleFeeBumpTx(requestID, r, currentHeight)
go t.handleFeeBumpTx(r, currentHeight)
}

// For records that are failed, we'll notify the caller about this
// result.
for requestID, r := range failedRecords {
for _, r := range failedRecords {
log.Debugf("Tx=%v has inputs been spent by a third party, "+
"failing it now", r.tx.TxHash())
t.wg.Add(1)
go t.handleThirdPartySpent(r, requestID)
go t.handleThirdPartySpent(r)
}
}

// handleTxConfirmed is called when a monitored tx is confirmed. It will
// notify the subscriber then remove the record from the maps .
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
func (t *TxPublisher) handleTxConfirmed(r *monitorRecord) {
defer t.wg.Done()

// Create a result that will be sent to the resultChan which is
// listened by the caller.
result := &BumpResult{
Event: TxConfirmed,
Tx: r.tx,
requestID: requestID,
requestID: r.requestID,
Fee: r.fee,
FeeRate: r.feeFunction.FeeRate(),
}
Expand Down Expand Up @@ -1017,10 +1035,8 @@ func (t *TxPublisher) handleInitialTxError(requestID uint64, err error) {
// 1. init a fee function based on the given strategy.
// 2. create an RBF-compliant tx and monitor it for confirmation.
// 3. notify the initial broadcast result back to the caller.
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
requestID uint64) {

log.Debugf("Initial broadcast for requestID=%v", requestID)
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
log.Debugf("Initial broadcast for requestID=%v", r.requestID)

var (
result *BumpResult
Expand All @@ -1031,18 +1047,18 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
// RBF rules.
//
// Create the initial tx to be broadcasted.
err = t.initializeTx(requestID, r.req)
err = t.initializeTx(r)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)

// We now handle the initialization error and exit.
t.handleInitialTxError(requestID, err)
t.handleInitialTxError(r.requestID, err)

return
}

// Successfully created the first tx, now broadcast it.
result, err = t.broadcast(requestID)
result, err = t.broadcast(r.requestID)
if err != nil {
// The broadcast failed, which can only happen if the tx record
// cannot be found or the aux sweeper returns an error. In
Expand All @@ -1051,7 +1067,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
result = &BumpResult{
Event: TxFailed,
Err: err,
requestID: requestID,
requestID: r.requestID,
}
}

Expand All @@ -1062,9 +1078,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
// attempt to bump the fee of the tx.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord,
currentHeight int32) {

func (t *TxPublisher) handleFeeBumpTx(r *monitorRecord, currentHeight int32) {
defer t.wg.Done()

oldTxid := r.tx.TxHash()
Expand Down Expand Up @@ -1095,7 +1109,7 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord,

// The fee function now has a new fee rate, we will use it to bump the
// fee of the tx.
resultOpt := t.createAndPublishTx(requestID, r)
resultOpt := t.createAndPublishTx(r)

// If there's a result, we will notify the caller about the result.
resultOpt.WhenSome(func(result BumpResult) {
Expand All @@ -1109,9 +1123,7 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord,
// and send a TxFailed event to the subscriber.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord,
requestID uint64) {

func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord) {
defer t.wg.Done()

// Create a result that will be sent to the resultChan which is
Expand All @@ -1123,7 +1135,7 @@ func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord,
result := &BumpResult{
Event: TxFailed,
Tx: r.tx,
requestID: requestID,
requestID: r.requestID,
Err: ErrThirdPartySpent,
}

Expand All @@ -1134,7 +1146,7 @@ func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord,
// createAndPublishTx creates a new tx with a higher fee rate and publishes it
// to the network. It will update the record with the new tx and fee rate if
// successfully created, and return the result when published successfully.
func (t *TxPublisher) createAndPublishTx(requestID uint64,
func (t *TxPublisher) createAndPublishTx(
r *monitorRecord) fn.Option[BumpResult] {

// Fetch the old tx.
Expand Down Expand Up @@ -1185,16 +1197,16 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64,
Event: TxFailed,
Tx: oldTx,
Err: err,
requestID: requestID,
requestID: r.requestID,
})
}

// The tx has been created without any errors, we now register a new
// record by overwriting the same requestID.
t.storeRecord(requestID, sweepCtx, r.req, r.feeFunction)
t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction)

// Attempt to broadcast this new tx.
result, err := t.broadcast(requestID)
result, err := t.broadcast(r.requestID)
if err != nil {
log.Infof("Failed to broadcast replacement tx %v: %v",
sweepCtx.tx.TxHash(), err)
Expand Down
Loading

0 comments on commit 0ae9f1a

Please sign in to comment.