diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 60352a62144..a9da019518c 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -54,6 +55,14 @@ type paymentLifecycle struct { // returned from the htlcswitch. switchResults lnutils.SyncMap[*channeldb.HTLCAttempt, *htlcswitch.PaymentResult] + + // activeCollectors tracks the number of active result collectors. + // So that we can resolve all acitve collectors and correctly fail + // a payment so that the payment only exits when all result collectors + // received their results. + // + // NOTE: To be used atomically. + activeCollectors int32 } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -398,6 +407,26 @@ func (p *paymentLifecycle) requestRoute( // stop signals any active shard goroutine to exit. func (p *paymentLifecycle) stop() { + // Process any remaining results that might have come in while we were + // shutting down. + for atomic.LoadInt32(&p.activeCollectors) > 0 { + select { + case <-p.resultCollected: + // Process any lingering results before exiting the + // payment lifecycle. + err := p.processSwitchResults() + if err != nil { + log.Errorf("Error processing final results "+ + "for payment %v during shutdown: %v", + p.identifier, err) + } + + case <-p.router.quit: + log.Infof("ChanRouter shutting down while collecting "+ + "lingering result for payment=%v", + p.identifier) + } + } close(p.quit) } @@ -422,7 +451,14 @@ func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { log.Debugf("Collecting result for attempt %v in payment %v", attempt.AttemptID, p.identifier) + // Increment the active collectors counter + atomic.AddInt32(&p.activeCollectors, 1) + go func() { + // Make sure we decrease the counter if this result collector + // exits. + defer atomic.AddInt32(&p.activeCollectors, -1) + result, err := p.collectResult(attempt) if err != nil { log.Errorf("Error collecting result for attempt %v in "+ @@ -443,11 +479,10 @@ func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { // Send the signal or quit. case p.resultCollected <- struct{}{}: - case <-p.quit: - log.Debugf("Lifecycle exiting while collecting "+ - "result for payment %v", p.identifier) - case <-p.router.quit: + log.Debugf("ChanRouter shutting down while collecting "+ + "result for payment=%v attemptID=%v", + p.identifier, attempt.AttemptID) } }() }