diff --git a/itest/send_test.go b/itest/send_test.go index fe5d14012..f9c29205e 100644 --- a/itest/send_test.go +++ b/itest/send_test.go @@ -5,8 +5,6 @@ import ( "context" "encoding/hex" "fmt" - "io" - "strings" "sync" "testing" "time" @@ -41,10 +39,7 @@ func testBasicSendUnidirectional(t *harnessTest) { ) // Subscribe to receive assent send events from primary tapd node. - eventNtfns, err := t.tapd.SubscribeSendAssetEventNtfns( - ctxb, &tapdevrpc.SubscribeSendAssetEventNtfnsRequest{}, - ) - require.NoError(t.t, err) + events := SubscribeSendEvents(t.t, t.tapd) // Test to ensure that we execute the transaction broadcast state. // This test is executed in a goroutine to ensure that we can receive @@ -64,10 +59,9 @@ func testBasicSendUnidirectional(t *harnessTest) { } timeout := 2 * defaultProofTransferReceiverAckTimeout - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() + assertAssetSendNtfsEvent( - t, ctx, eventNtfns, targetEventSelector, numSends, + t, events, timeout, targetEventSelector, numSends, ) }() @@ -122,7 +116,7 @@ func testBasicSendUnidirectional(t *harnessTest) { } // Close event stream. - err = eventNtfns.CloseSend() + err = events.CloseSend() require.NoError(t.t, err) wg.Wait() @@ -146,10 +140,7 @@ func testRestartReceiverCheckBalance(t *harnessTest) { ) // Subscribe to receive assent send events from primary tapd node. - eventNtfns, err := t.tapd.SubscribeSendAssetEventNtfns( - ctxb, &tapdevrpc.SubscribeSendAssetEventNtfnsRequest{}, - ) - require.NoError(t.t, err) + events := SubscribeSendEvents(t.t, t.tapd) // Test to ensure that we execute the transaction broadcast state. // This test is executed in a goroutine to ensure that we can receive @@ -169,10 +160,9 @@ func testRestartReceiverCheckBalance(t *harnessTest) { } timeout := 2 * defaultProofTransferReceiverAckTimeout - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() + assertAssetSendNtfsEvent( - t, ctx, eventNtfns, targetEventSelector, 1, + t, events, timeout, targetEventSelector, 1, ) }() @@ -236,7 +226,7 @@ func testRestartReceiverCheckBalance(t *harnessTest) { AssertNonInteractiveRecvComplete(t.t, recvTapd, 1) // Close event stream. - err = eventNtfns.CloseSend() + err = events.CloseSend() require.NoError(t.t, err) wg.Wait() @@ -587,10 +577,7 @@ func testReattemptFailedSendHashmailCourier(t *harnessTest) { ) // Subscribe to receive asset send events from primary tapd node. - eventNtfns, err := sendTapd.SubscribeSendAssetEventNtfns( - ctxb, &tapdevrpc.SubscribeSendAssetEventNtfnsRequest{}, - ) - require.NoError(t.t, err) + events := SubscribeSendEvents(t.t, sendTapd) // Test to ensure that we receive the expected number of backoff wait // event notifications. @@ -622,11 +609,9 @@ func testReattemptFailedSendHashmailCourier(t *harnessTest) { defaultProofTransferReceiverAckTimeout // Add overhead buffer to context timeout. timeout += 5 * time.Second - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() assertAssetSendNtfsEvent( - t, ctx, eventNtfns, targetEventSelector, + t, events, timeout, targetEventSelector, expectedEventCount, ) }() @@ -687,10 +672,7 @@ func testReattemptFailedSendUniCourier(t *harnessTest) { ) // Subscribe to receive asset send events from the sending tapd node. - eventNtfns, err := sendTapd.SubscribeSendAssetEventNtfns( - ctxb, &tapdevrpc.SubscribeSendAssetEventNtfnsRequest{}, - ) - require.NoError(t.t, err) + events := SubscribeSendEvents(t.t, sendTapd) // Test to ensure that we receive the expected number of backoff wait // event notifications. @@ -722,11 +704,9 @@ func testReattemptFailedSendUniCourier(t *harnessTest) { defaultProofTransferReceiverAckTimeout // Add overhead buffer to context timeout. timeout += 5 * time.Second - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() assertAssetSendNtfsEvent( - t, ctx, eventNtfns, targetEventSelector, + t, events, timeout, targetEventSelector, expectedEventCount, ) }() @@ -846,10 +826,7 @@ func testReattemptFailedReceiveUniCourier(t *harnessTest) { // Subscribe to receive asset receive events from receiving tapd node. // We'll use these events to ensure that the receiver node is making // multiple attempts to retrieve the asset proof. - eventNtfns, err := receiveTapd.SubscribeReceiveAssetEventNtfns( - ctxb, &tapdevrpc.SubscribeReceiveAssetEventNtfnsRequest{}, - ) - require.NoError(t.t, err) + events := SubscribeReceiveEvents(t.t, receiveTapd) // Test to ensure that we receive the minimum expected number of backoff // wait event notifications. @@ -883,13 +860,11 @@ func testReattemptFailedReceiveUniCourier(t *harnessTest) { defaultProofTransferReceiverAckTimeout // Add overhead buffer to context timeout. timeout += 5 * time.Second - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() // Assert that the receiver tapd node has accomplished our minimum // expected number of backoff procedure receive attempts. assertAssetRecvNtfsEvent( - t, ctx, eventNtfns, targetEventSelector, expectedEventCount, + t, timeout, events, targetEventSelector, expectedEventCount, ) t.Logf("Finished waiting for the receiving tapd node to complete " + @@ -911,7 +886,7 @@ func testReattemptFailedReceiveUniCourier(t *harnessTest) { // transfer and publishes an asset recv complete event. t.Logf("Check for asset recv complete event from receiver tapd node") assertAssetRecvCompleteEvent( - t, ctxb, 5*time.Second, recvAddr.Encoded, eventNtfns, + t, 5*time.Second, recvAddr.Encoded, events, ) } @@ -947,10 +922,7 @@ func testOfflineReceiverEventuallyReceives(t *harnessTest) { recvTapd := t.tapd // Subscribe to receive asset send events from primary tapd node. - eventNtfns, err := sendTapd.SubscribeSendAssetEventNtfns( - ctxb, &tapdevrpc.SubscribeSendAssetEventNtfnsRequest{}, - ) - require.NoError(t.t, err) + events := SubscribeSendEvents(t.t, sendTapd) // Test to ensure that we receive the expected number of backoff wait // event notifications. @@ -979,11 +951,9 @@ func testOfflineReceiverEventuallyReceives(t *harnessTest) { // Events must be received before a timeout. timeout := 5 * time.Second - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() assertAssetSendNtfsEvent( - t, ctx, eventNtfns, targetEventSelector, + t, events, timeout, targetEventSelector, expectedEventCount, ) }() @@ -1034,36 +1004,48 @@ func testOfflineReceiverEventuallyReceives(t *harnessTest) { // assertAssetSendNtfsEvent asserts that the given asset send event notification // was received. This function will block until the event is received or the // event stream is closed. -func assertAssetSendNtfsEvent(t *harnessTest, ctx context.Context, - eventNtfns tapdevrpc.TapDev_SubscribeSendAssetEventNtfnsClient, +func assertAssetSendNtfsEvent(t *harnessTest, + stream *eventSubscription[*tapdevrpc.SendAssetEvent], + timeout time.Duration, targetEventSelector func(*tapdevrpc.SendAssetEvent) bool, expectedCount int) { + success := make(chan struct{}) + timeoutChan := time.After(timeout) + + // To make sure we don't forever hang on receiving on the stream, we'll + // cancel it after the timeout. + go func() { + select { + case <-timeoutChan: + stream.cancel() + + case <-success: + } + }() + countFound := 0 for { // Ensure that the context has not been cancelled. - require.NoError(t.t, ctx.Err()) + select { + case <-stream.Context().Done(): + require.NoError(t.t, stream.Context().Err()) - if countFound == expectedCount { break + default: } - event, err := eventNtfns.Recv() + if countFound == expectedCount { + break + } - // Break if we get an EOF, which means the stream was - // closed. - // - // Use string comparison here because the RPC protocol - // does not transport wrapped error structures. - if err != nil && - strings.Contains(err.Error(), io.EOF.Error()) { + event, err := stream.Recv() + if err != nil { + require.NoError(t.t, err) break } - // If err is not EOF, then we expect it to be nil. - require.NoError(t.t, err) - // Check for target state. if targetEventSelector(event) { countFound++ @@ -1076,35 +1058,48 @@ func assertAssetSendNtfsEvent(t *harnessTest, ctx context.Context, // assertAssetRecvNtfsEvent asserts that the given asset receive event // notification was received. This function will block until the event is // received or the event stream is closed. -func assertAssetRecvNtfsEvent(t *harnessTest, ctx context.Context, - eventNtfns tapdevrpc.TapDev_SubscribeReceiveAssetEventNtfnsClient, +func assertAssetRecvNtfsEvent(t *harnessTest, timeout time.Duration, + stream *eventSubscription[*tapdevrpc.ReceiveAssetEvent], targetEventSelector func(event *tapdevrpc.ReceiveAssetEvent) bool, expectedCount int) { + success := make(chan struct{}) + timeoutChan := time.After(timeout) + + // To make sure we don't forever hang on receiving on the stream, we'll + // cancel it after the timeout. + go func() { + select { + case <-timeoutChan: + stream.cancel() + + case <-success: + } + }() + countFound := 0 for { // Ensure that the context has not been cancelled. - require.NoError(t.t, ctx.Err()) + select { + case <-stream.Context().Done(): + require.NoError(t.t, stream.Context().Err()) - if countFound == expectedCount { break + default: } - event, err := eventNtfns.Recv() - - // Break if we get an EOF, which means the stream was - // closed. - // - // Use string comparison here because the RPC protocol - // does not transport wrapped error structures. - if err != nil && - strings.Contains(err.Error(), io.EOF.Error()) { + if countFound == expectedCount { + close(success) break } - // If err is not EOF, then we expect it to be nil. - require.NoError(t.t, err) + event, err := stream.Recv() + if err != nil { + require.NoError(t.t, err) + + break + } // Check for target state. if targetEventSelector(event) { @@ -1120,12 +1115,9 @@ func assertAssetRecvNtfsEvent(t *harnessTest, ctx context.Context, // assertAssetRecvNtfsEvent asserts that the given asset receive complete event // notification was received. This function will block until the event is // received or the event stream is closed. -func assertAssetRecvCompleteEvent(t *harnessTest, ctxb context.Context, +func assertAssetRecvCompleteEvent(t *harnessTest, timeout time.Duration, encodedAddr string, - eventNtfns tapdevrpc.TapDev_SubscribeReceiveAssetEventNtfnsClient) { - - ctx, cancel := context.WithTimeout(ctxb, timeout) - defer cancel() + stream *eventSubscription[*tapdevrpc.ReceiveAssetEvent]) { eventSelector := func(event *tapdevrpc.ReceiveAssetEvent) bool { switch eventTyped := event.Event.(type) { @@ -1137,7 +1129,7 @@ func assertAssetRecvCompleteEvent(t *harnessTest, ctxb context.Context, } } - assertAssetRecvNtfsEvent(t, ctx, eventNtfns, eventSelector, 1) + assertAssetRecvNtfsEvent(t, timeout, stream, eventSelector, 1) } // testMultiInputSendNonInteractiveSingleID tests that we can properly diff --git a/itest/utils.go b/itest/utils.go index 2b7fea626..e5d7377e5 100644 --- a/itest/utils.go +++ b/itest/utils.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntest/node" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) @@ -28,6 +29,20 @@ var ( regtestParams = &chaincfg.RegressionNetParams ) +// clientEventStream is a generic interface for a client stream that allows us +// to receive events from a server. +type clientEventStream[T any] interface { + Recv() (T, error) + grpc.ClientStream +} + +// eventSubscription holds a generic client stream and its context cancel +// function. +type eventSubscription[T any] struct { + clientEventStream[T] + cancel context.CancelFunc +} + // CopyRequest is a helper function to copy a request so that we can modify it. func CopyRequest(req *mintrpc.MintAssetRequest) *mintrpc.MintAssetRequest { return proto.Clone(req).(*mintrpc.MintAssetRequest) @@ -396,3 +411,68 @@ func MintAssetsConfirmBatch(t *testing.T, minerClient *rpcclient.Client, return AssertAssetsMinted(t, tapClient, assetRequests, mintTXID, blockHash) } + +// SubscribeSendEvents subscribes to send events and returns the event stream. +func SubscribeSendEvents(t *testing.T, + tapd TapdClient) *eventSubscription[*tapdevrpc.SendAssetEvent] { + + ctxb := context.Background() + ctxt, cancel := context.WithCancel(ctxb) + + stream, err := tapd.SubscribeSendAssetEventNtfns( + ctxt, &tapdevrpc.SubscribeSendAssetEventNtfnsRequest{}, + ) + require.NoError(t, err) + + return &eventSubscription[*tapdevrpc.SendAssetEvent]{ + clientEventStream: stream, + cancel: cancel, + } +} + +// SubscribeReceiveEvents subscribes to receive events and returns the event +// stream. +func SubscribeReceiveEvents(t *testing.T, + tapd TapdClient) *eventSubscription[*tapdevrpc.ReceiveAssetEvent] { + + ctxb := context.Background() + ctxt, cancel := context.WithCancel(ctxb) + + stream, err := tapd.SubscribeReceiveAssetEventNtfns( + ctxt, &tapdevrpc.SubscribeReceiveAssetEventNtfnsRequest{}, + ) + require.NoError(t, err) + + return &eventSubscription[*tapdevrpc.ReceiveAssetEvent]{ + clientEventStream: stream, + cancel: cancel, + } +} + +// NewAddrWithEventStream creates a new TAP address and also registers a new +// event stream for receive events for the address. +func NewAddrWithEventStream(t *testing.T, tapd TapdClient, + req *taprpc.NewAddrRequest) (*taprpc.Addr, + *eventSubscription[*taprpc.ReceiveEvent]) { + + ctxb := context.Background() + ctxt, cancel := context.WithTimeout(ctxb, defaultWaitTimeout) + defer cancel() + + addr, err := tapd.NewAddr(ctxt, req) + require.NoError(t, err) + + ctxc, cancel := context.WithCancel(ctxb) + + stream, err := tapd.SubscribeReceiveEvents( + ctxc, &taprpc.SubscribeReceiveEventsRequest{ + FilterAddr: addr.Encoded, + }, + ) + require.NoError(t, err) + + return addr, &eventSubscription[*taprpc.ReceiveEvent]{ + clientEventStream: stream, + cancel: cancel, + } +}