Skip to content

Commit

Permalink
htlcswitch: add test for deferred processing remote adds when quiescent
Browse files Browse the repository at this point in the history
  • Loading branch information
ProofOfKeags committed Jul 26, 2024
1 parent 30c3a5c commit fa88ebb
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 2 deletions.
100 changes: 100 additions & 0 deletions htlcswitch/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7489,3 +7489,103 @@ func TestLinkFlushHooksCalled(t *testing.T) {
ctx.receiveRevAndAckAliceToBob()
assertHookCalled(true)
}

// TestLinkQuiescenceExitHopProcessingDeferred ensures that we do not send back
// htlc resolution messages in the case where the link is quiescent AND we are
// the exit hop. This is needed because we handle exit hop processing in the
// link instead of the switch and we process htlc resolutions when we receive
// a RevokeAndAck. Because of this we need to ensure that we hold off on
// processing the remote adds when we are quiescent. Later, when the channel
// update traffic is allowed to resume, we will need to verify that the actions
// we didn't run during the initial RevokeAndAck are run.
func TestLinkQuiescenceExitHopProcessingDeferred(t *testing.T) {
// Initialize two channel state machines for testing.
alice, bob, err := createTwoClusterChannels(
t, btcutil.SatoshiPerBitcoin, btcutil.SatoshiPerBitcoin,
)
require.NoError(t, err)

// Build a single edge network to test channel quiescence.
network := newTwoHopNetwork(
t, alice.channel, bob.channel, testStartingHeight,
)
aliceLink := network.aliceChannelLink
bobLink := network.bobChannelLink

// Generate an invoice for Bob so that Alice can pay him.
htlcID := uint64(0)
htlc, invoice := generateHtlcAndInvoice(t, htlcID)
err = network.bobServer.registry.AddInvoice(
nil, *invoice, htlc.PaymentHash,
)
require.NoError(t, err)

// Establish a payment circuit for Alice
circuit := &PaymentCircuit{
Incoming: CircuitKey{
HtlcID: htlcID,
},
PaymentHash: htlc.PaymentHash,
}
circuitMap := network.aliceServer.htlcSwitch.circuits
_, err = circuitMap.CommitCircuits(circuit)
require.NoError(t, err)

// Add a switch packet to Alice's switch so that she can initialize the
// payment attempt.
err = aliceLink.handleSwitchPacket(&htlcPacket{
incomingHTLCID: htlcID,
htlc: htlc,
circuit: circuit,
})
require.NoError(t, err)

// give alice enough time to fire the update_add
// TODO(proofofkeags): make this not depend on a flakey sleep.
<-time.After(time.Millisecond)

// bob initiates stfu which he can do immediately since he doesn't have
// local updates
<-bobLink.InitStfu()

// wait for other possible messages to play out
<-time.After(1 * time.Second)

ensureNoUpdateAfterStfu := func(t *testing.T, trace []lnwire.Message) {
stfuReceived := false
for _, msg := range trace {
if msg.MsgType() == lnwire.MsgStfu {
stfuReceived = true
continue
}

if stfuReceived {
switch msg.MsgType() {
case lnwire.MsgUpdateAddHTLC:
fallthrough
case lnwire.MsgUpdateFulfillHTLC:
fallthrough
case lnwire.MsgUpdateFailHTLC:
fallthrough
case lnwire.MsgUpdateFailMalformedHTLC:
fallthrough
case lnwire.MsgUpdateFee:
t.Fatalf("channel update "+
"after stfu: %v",
msg.MsgType())
default:
}
}
}
}

network.aliceServer.protocolTraceMtx.Lock()
ensureNoUpdateAfterStfu(t, network.aliceServer.protocolTrace)
network.aliceServer.protocolTraceMtx.Unlock()

network.bobServer.protocolTraceMtx.Lock()
ensureNoUpdateAfterStfu(t, network.bobServer.protocolTrace)
network.bobServer.protocolTraceMtx.Unlock()

// TODO(proofofkeags): make sure these actions are run on resume.
}
12 changes: 10 additions & 2 deletions htlcswitch/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ type mockServer struct {

t testing.TB

name string
messages chan lnwire.Message
name string
messages chan lnwire.Message
protocolTraceMtx sync.Mutex
protocolTrace []lnwire.Message

id [33]byte
htlcSwitch *Switch
Expand Down Expand Up @@ -288,6 +290,10 @@ func (s *mockServer) Start() error {
for {
select {
case msg := <-s.messages:
s.protocolTraceMtx.Lock()
s.protocolTrace = append(s.protocolTrace, msg)
s.protocolTraceMtx.Unlock()

var shouldSkip bool

for _, interceptor := range s.interceptorFuncs {
Expand Down Expand Up @@ -626,6 +632,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error {
targetChan = msg.ChanID
case *lnwire.UpdateFee:
targetChan = msg.ChanID
case *lnwire.Stfu:
targetChan = msg.ChanID
default:
return fmt.Errorf("unknown message type: %T", msg)
}
Expand Down

0 comments on commit fa88ebb

Please sign in to comment.