From 5c18b5a042ef823ab279885e2fbaceaab7885d5e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 18:34:09 -0400 Subject: [PATCH 01/20] routing: remove un-used method from routingGraph interface We really want to narrow down the interface we provide the router, so let's start here. --- routing/graph.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/routing/graph.go b/routing/graph.go index 3e466a3df6..dafadb8923 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -23,11 +23,6 @@ type routingGraph interface { // fetchNodeFeatures returns the features of the given node. fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) - - // FetchAmountPairCapacity determines the maximal capacity between two - // pairs of nodes. - FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, - amount lnwire.MilliSatoshi) (btcutil.Amount, error) } // CachedGraph is a routingGraph implementation that retrieves from the @@ -97,8 +92,6 @@ func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( // FetchAmountPairCapacity determines the maximal public capacity between two // nodes depending on the amount we try to send. -// -// NOTE: Part of the routingGraph interface. func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { From 5a903c270f3c1e5e4a049a0e407c744ea796dd39 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 18:47:15 -0400 Subject: [PATCH 02/20] routing: remove sourceNode from routingGraph interface In this commit, we further reduce the routingGraph interface and this time we make it more node-agnostic so that it can be backed by any graph and not one with a concept of "sourceNode". --- routing/graph.go | 12 +----------- routing/integrated_routing_context_test.go | 2 +- routing/pathfind.go | 9 ++++----- routing/pathfind_test.go | 3 ++- routing/payment_session.go | 11 ++++++----- routing/payment_session_source.go | 4 ++-- routing/payment_session_test.go | 10 +++++----- routing/router.go | 5 +++-- 8 files changed, 24 insertions(+), 32 deletions(-) diff --git a/routing/graph.go b/routing/graph.go index dafadb8923..1f0abf9c06 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -18,9 +18,6 @@ type routingGraph interface { forEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error - // sourceNode returns the source node of the graph. - sourceNode() route.Vertex - // fetchNodeFeatures returns the features of the given node. fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } @@ -73,13 +70,6 @@ func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } -// sourceNode returns the source node of the graph. -// -// NOTE: Part of the routingGraph interface. -func (g *CachedGraph) sourceNode() route.Vertex { - return g.source -} - // fetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // @@ -99,7 +89,7 @@ func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, // // Note: Inbound fees are not used here because this method is only used // by a deprecated router rpc. 
- u := newNodeEdgeUnifier(g.sourceNode(), nodeTo, false, nil) + u := newNodeEdgeUnifier(g.source, nodeTo, false, nil) err := u.addGraphPolicies(g) if err != nil { diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 4215d3b254..95a5eaf65f 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -200,7 +200,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, } session, err := newPaymentSession( - &payment, getBandwidthHints, + &payment, c.graph.source.pubkey, getBandwidthHints, func() (routingGraph, func(), error) { return c.graph, func() {}, nil }, diff --git a/routing/pathfind.go b/routing/pathfind.go index 208a550858..d7d2893b0b 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -48,7 +48,7 @@ const ( // pathFinder defines the interface of a path finding algorithm. type pathFinder = func(g *graphParams, r *RestrictParams, - cfg *PathFindingConfig, source, target route.Vertex, + cfg *PathFindingConfig, self, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( []*unifiedEdge, float64, error) @@ -521,8 +521,9 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // path and accurately check the amount to forward at every node against the // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, - source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, - finalHtlcExpiry int32) ([]*unifiedEdge, float64, error) { + self, source, target route.Vertex, amt lnwire.MilliSatoshi, + timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64, + error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -583,8 +584,6 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // If we are routing from ourselves, check that we have enough local // balance available. - self := g.graph.sourceNode() - if source == self { max, total, err := getOutgoingBalance( self, outgoingChanMap, g.bandwidthHints, g.graph, diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 0f2a2659b1..a35c9e2f7b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3218,7 +3218,8 @@ func dbFindPath(graph *channeldb.ChannelGraph, bandwidthHints: bandwidthHints, graph: routingGraph, }, - r, cfg, source, target, amt, timePref, finalHtlcExpiry, + r, cfg, sourceNode.PubKeyBytes, source, target, amt, timePref, + finalHtlcExpiry, ) return route, err diff --git a/routing/payment_session.go b/routing/payment_session.go index 2d174244c8..bdd1948128 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -163,6 +163,8 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. type paymentSession struct { + selfNode route.Vertex + additionalEdges map[route.Vertex][]AdditionalEdge getBandwidthHints func(routingGraph) (bandwidthHints, error) @@ -192,7 +194,7 @@ type paymentSession struct { } // newPaymentSession instantiates a new payment session. 
-func newPaymentSession(p *LightningPayment, +func newPaymentSession(p *LightningPayment, selfNode route.Vertex, getBandwidthHints func(routingGraph) (bandwidthHints, error), getRoutingGraph func() (routingGraph, func(), error), missionControl MissionController, pathFindingConfig PathFindingConfig) ( @@ -206,6 +208,7 @@ func newPaymentSession(p *LightningPayment, logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier()) return &paymentSession{ + selfNode: selfNode, additionalEdges: edges, getBandwidthHints: getBandwidthHints, payment: p, @@ -296,8 +299,6 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, p.log.Debugf("pathfinding for amt=%v", maxAmt) - sourceVertex := routingGraph.sourceNode() - // Find a route for the current amount. path, _, err := p.pathFinder( &graphParams{ @@ -306,7 +307,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, graph: routingGraph, }, restrictions, &p.pathFindingConfig, - sourceVertex, p.payment.Target, + p.selfNode, p.selfNode, p.payment.Target, maxAmt, p.payment.TimePref, finalHtlcExpiry, ) @@ -384,7 +385,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // this into a route by applying the time-lock and fee // requirements. route, err := newRoute( - sourceVertex, path, height, + p.selfNode, path, height, finalHopParams{ amt: maxAmt, totalAmt: p.payment.Amount, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index b96a2294ba..ba010391bd 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -73,8 +73,8 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( } session, err := newPaymentSession( - p, getBandwidthHints, m.getRoutingGraph, - m.MissionControl, m.PathFindingConfig, + p, m.SourceNode.PubKeyBytes, getBandwidthHints, + m.getRoutingGraph, m.MissionControl, m.PathFindingConfig, ) if err != nil { return nil, err diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 75b84a51a3..b7efed5b7c 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -115,7 +115,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( - payment, + payment, route.Vertex{}, func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, @@ -195,7 +195,7 @@ func TestRequestRoute(t *testing.T) { } session, err := newPaymentSession( - payment, + payment, route.Vertex{}, func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, @@ -211,9 +211,9 @@ func TestRequestRoute(t *testing.T) { // Override pathfinder with a mock. session.pathFinder = func(_ *graphParams, r *RestrictParams, - _ *PathFindingConfig, _, _ route.Vertex, _ lnwire.MilliSatoshi, - _ float64, _ int32) ([]*unifiedEdge, float64, - error) { + _ *PathFindingConfig, _, _, _ route.Vertex, + _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge, + float64, error) { // We expect find path to receive a cltv limit excluding the // final cltv delta (including the block padding). 
diff --git a/routing/router.go b/routing/router.go index 149cd34156..597705754b 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2148,8 +2148,9 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, bandwidthHints: bandwidthHints, graph: r.cachedGraph, }, - req.Restrictions, &r.cfg.PathFindingConfig, req.Source, - req.Target, req.Amount, req.TimePreference, finalHtlcExpiry, + req.Restrictions, &r.cfg.PathFindingConfig, + r.selfNode.PubKeyBytes, req.Source, req.Target, req.Amount, + req.TimePreference, finalHtlcExpiry, ) if err != nil { return nil, 0, err From 3f121cbe81282c308ebdb8f86155e919de702644 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 25 Jun 2024 19:22:00 -0700 Subject: [PATCH 03/20] routing: rename and export routingGraph In preparation for structs outside of the `routing` package implementing this interface, export `routingGraph` and rename it to `Graph` so as to avoid stuttering. --- routing/bandwidth.go | 4 +-- routing/graph.go | 32 +++++++++++----------- routing/integrated_routing_context_test.go | 4 +-- routing/mock_graph_test.go | 18 ++++++------ routing/pathfind.go | 10 +++---- routing/payment_session.go | 8 +++--- routing/payment_session_source.go | 4 +-- routing/payment_session_test.go | 10 +++---- routing/router.go | 6 ++-- routing/unified_edges.go | 4 +-- 10 files changed, 50 insertions(+), 50 deletions(-) diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 19c6087018..0868255685 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -39,7 +39,7 @@ type bandwidthManager struct { // hints for the edges we directly have open ourselves. Obtaining these hints // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. -func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, +func newBandwidthManager(graph Graph, sourceNode route.Vertex, linkQuery getLinkQuery) (*bandwidthManager, error) { manager := &bandwidthManager{ @@ -49,7 +49,7 @@ func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, // First, we'll collect the set of outbound edges from the target // source node and add them to our bandwidth manager's map of channels. - err := graph.forEachNodeChannel(sourceNode, + err := graph.ForEachNodeChannel(sourceNode, func(channel *channeldb.DirectedChannel) error { shortID := lnwire.NewShortChanIDFromInt( channel.ChannelID, diff --git a/routing/graph.go b/routing/graph.go index 1f0abf9c06..1f4b24bb51 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -10,19 +10,19 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) -// routingGraph is an abstract interface that provides information about nodes -// and edges to pathfinding. -type routingGraph interface { - // forEachNodeChannel calls the callback for every channel of the given +// Graph is an abstract interface that provides information about nodes and +// edges to pathfinding. +type Graph interface { + // ForEachNodeChannel calls the callback for every channel of the given // node. - forEachNodeChannel(nodePub route.Vertex, + ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error - // fetchNodeFeatures returns the features of the given node. - fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) + // FetchNodeFeatures returns the features of the given node. 
+ FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// CachedGraph is a routingGraph implementation that retrieves from the +// CachedGraph is a Graph implementation that retrieves from the // database. type CachedGraph struct { graph *channeldb.ChannelGraph @@ -30,9 +30,9 @@ type CachedGraph struct { source route.Vertex } -// A compile time assertion to make sure CachedGraph implements the routingGraph +// A compile time assertion to make sure CachedGraph implements the Graph // interface. -var _ routingGraph = (*CachedGraph)(nil) +var _ Graph = (*CachedGraph)(nil) // NewCachedGraph instantiates a new db-connected routing graph. It implicitly // instantiates a new read transaction. @@ -61,20 +61,20 @@ func (g *CachedGraph) Close() error { return g.tx.Rollback() } -// forEachNodeChannel calls the callback for every channel of the given node. +// ForEachNodeChannel calls the callback for every channel of the given node. // -// NOTE: Part of the routingGraph interface. -func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, +// NOTE: Part of the Graph interface. +func (g *CachedGraph) ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } -// fetchNodeFeatures returns the features of the given node. If the node is +// FetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // -// NOTE: Part of the routingGraph interface. -func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( +// NOTE: Part of the Graph interface. +func (g *CachedGraph) FetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return g.graph.FetchNodeFeatures(nodePub) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 95a5eaf65f..02cd6f0477 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -163,7 +163,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, c.t.Fatal(err) } - getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) { + getBandwidthHints := func(_ Graph) (bandwidthHints, error) { // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { @@ -201,7 +201,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, session, err := newPaymentSession( &payment, c.graph.source.pubkey, getBandwidthHints, - func() (routingGraph, func(), error) { + func() (Graph, func(), error) { return c.graph, func() {}, nil }, mc, c.pathFindingCfg, diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 2ec9a0f989..348eb3746f 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -164,8 +164,8 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // forEachNodeChannel calls the callback for every channel of the given node. // -// NOTE: Part of the routingGraph interface. -func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, +// NOTE: Part of the Graph interface. +func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { // Look up the mock node. @@ -213,15 +213,15 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // sourceNode returns the source node of the graph. // -// NOTE: Part of the routingGraph interface. 
+// NOTE: Part of the Graph interface. func (m *mockGraph) sourceNode() route.Vertex { return m.source.pubkey } // fetchNodeFeatures returns the features of the given node. // -// NOTE: Part of the routingGraph interface. -func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) ( +// NOTE: Part of the Graph interface. +func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return lnwire.EmptyFeatureVector(), nil @@ -230,7 +230,7 @@ func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) ( // FetchAmountPairCapacity returns the maximal capacity between nodes in the // graph. // -// NOTE: Part of the routingGraph interface. +// NOTE: Part of the Graph interface. func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { @@ -244,7 +244,7 @@ func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, return nil } - err := m.forEachNodeChannel(nodeFrom, cb) + err := m.ForEachNodeChannel(nodeFrom, cb) if err != nil { return 0, err } @@ -295,5 +295,5 @@ func (m *mockGraph) sendHtlc(route *route.Route) (htlcResult, error) { return source.fwd(nil, next) } -// Compile-time check for the routingGraph interface. -var _ routingGraph = &mockGraph{} +// Compile-time check for the Graph interface. +var _ Graph = &mockGraph{} diff --git a/routing/pathfind.go b/routing/pathfind.go index d7d2893b0b..083af04dbd 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -369,7 +369,7 @@ func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi, // graphParams wraps the set of graph parameters passed to findPath. type graphParams struct { // graph is the ChannelGraph to be used during path finding. - graph routingGraph + graph Graph // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the @@ -464,7 +464,7 @@ type PathFindingConfig struct { // available balance. func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, - g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { + g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi cb := func(channel *channeldb.DirectedChannel) error { @@ -502,7 +502,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } // Iterate over all channels of the to node. - err := g.forEachNodeChannel(node, cb) + err := g.ForEachNodeChannel(node, cb) if err != nil { return 0, 0, err } @@ -542,7 +542,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, features := r.DestFeatures if features == nil { var err error - features, err = g.graph.fetchNodeFeatures(target) + features, err = g.graph.FetchNodeFeatures(target) if err != nil { return nil, 0, err } @@ -920,7 +920,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } // Fetch node features fresh from the graph. 
- fromFeatures, err := g.graph.fetchNodeFeatures(node) + fromFeatures, err := g.graph.FetchNodeFeatures(node) if err != nil { return nil, err } diff --git a/routing/payment_session.go b/routing/payment_session.go index bdd1948128..6cfbeddf46 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -167,7 +167,7 @@ type paymentSession struct { additionalEdges map[route.Vertex][]AdditionalEdge - getBandwidthHints func(routingGraph) (bandwidthHints, error) + getBandwidthHints func(Graph) (bandwidthHints, error) payment *LightningPayment @@ -175,7 +175,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (routingGraph, func(), error) + getRoutingGraph func() (Graph, func(), error) // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probability. @@ -195,8 +195,8 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, selfNode route.Vertex, - getBandwidthHints func(routingGraph) (bandwidthHints, error), - getRoutingGraph func() (routingGraph, func(), error), + getBandwidthHints func(Graph) (bandwidthHints, error), + getRoutingGraph func() (Graph, func(), error), missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index ba010391bd..51bfc97811 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -46,7 +46,7 @@ type SessionSource struct { // getRoutingGraph returns a routing graph and a clean-up function for // pathfinding. -func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { +func (m *SessionSource) getRoutingGraph() (Graph, func(), error) { routingTx, err := NewCachedGraph(m.SourceNode, m.Graph) if err != nil { return nil, nil, err @@ -66,7 +66,7 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) { + getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( graph, m.SourceNode.PubKeyBytes, m.GetLink, ) diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index b7efed5b7c..9356a2be01 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -116,10 +116,10 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. 
session, err := newPaymentSession( payment, route.Vertex{}, - func(routingGraph) (bandwidthHints, error) { + func(Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (routingGraph, func(), error) { + func() (Graph, func(), error) { return &sessionGraph{}, func() {}, nil }, &MissionControl{}, @@ -196,10 +196,10 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, route.Vertex{}, - func(routingGraph) (bandwidthHints, error) { + func(Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (routingGraph, func(), error) { + func() (Graph, func(), error) { return &sessionGraph{}, func() {}, nil }, &MissionControl{}, @@ -253,7 +253,7 @@ func TestRequestRoute(t *testing.T) { } type sessionGraph struct { - routingGraph + Graph } func (g *sessionGraph) sourceNode() route.Vertex { diff --git a/routing/router.go b/routing/router.go index 597705754b..9af047f4ec 100644 --- a/routing/router.go +++ b/routing/router.go @@ -453,9 +453,9 @@ type ChannelRouter struct { // when doing any path finding. selfNode *channeldb.LightningNode - // cachedGraph is an instance of routingGraph that caches the source + // cachedGraph is an instance of Graph that caches the source // node as well as the channel graph itself in memory. - cachedGraph routingGraph + cachedGraph Graph // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to @@ -3177,7 +3177,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // getRouteUnifiers returns a list of edge unifiers for the given route. func getRouteUnifiers(source route.Vertex, hops []route.Vertex, useMinAmt bool, runningAmt lnwire.MilliSatoshi, - outgoingChans map[uint64]struct{}, graph routingGraph, + outgoingChans map[uint64]struct{}, graph Graph, bandwidthHints *bandwidthManager) ([]*edgeUnifier, lnwire.MilliSatoshi, error) { diff --git a/routing/unified_edges.go b/routing/unified_edges.go index d39eda1efd..a0300eea4b 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -94,7 +94,7 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. -func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { +func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { cb := func(channel *channeldb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have @@ -120,7 +120,7 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { } // Iterate over all channels of the to node. - return g.forEachNodeChannel(u.toNode, cb) + return g.ForEachNodeChannel(u.toNode, cb) } // unifiedEdge is the individual channel data that is kept inside an edgeUnifier From 90d6b863a8007153f3a296fb0b9f5ff4771a04d8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 25 Jun 2024 19:27:13 -0700 Subject: [PATCH 04/20] routing+refactor: remove the need to give CachedGraph source node access In preparation for the next commit. 
--- routing/graph.go | 20 ++++++++------------ routing/mock_graph_test.go | 25 ------------------------- routing/pathfind_test.go | 2 +- routing/payment_session_source.go | 2 +- routing/router.go | 3 +-- rpcserver.go | 9 ++++----- 6 files changed, 15 insertions(+), 46 deletions(-) diff --git a/routing/graph.go b/routing/graph.go index 1f4b24bb51..0c4d2e1d41 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -25,9 +25,8 @@ type Graph interface { // CachedGraph is a Graph implementation that retrieves from the // database. type CachedGraph struct { - graph *channeldb.ChannelGraph - tx kvdb.RTx - source route.Vertex + graph *channeldb.ChannelGraph + tx kvdb.RTx } // A compile time assertion to make sure CachedGraph implements the Graph @@ -36,18 +35,15 @@ var _ Graph = (*CachedGraph)(nil) // NewCachedGraph instantiates a new db-connected routing graph. It implicitly // instantiates a new read transaction. -func NewCachedGraph(sourceNode *channeldb.LightningNode, - graph *channeldb.ChannelGraph) (*CachedGraph, error) { - +func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) { tx, err := graph.NewPathFindTx() if err != nil { return nil, err } return &CachedGraph{ - graph: graph, - tx: tx, - source: sourceNode.PubKeyBytes, + graph: graph, + tx: tx, }, nil } @@ -82,16 +78,16 @@ func (g *CachedGraph) FetchNodeFeatures(nodePub route.Vertex) ( // FetchAmountPairCapacity determines the maximal public capacity between two // nodes depending on the amount we try to send. -func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, +func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { // Create unified edges for all incoming connections. // // Note: Inbound fees are not used here because this method is only used // by a deprecated router rpc. - u := newNodeEdgeUnifier(g.source, nodeTo, false, nil) + u := newNodeEdgeUnifier(source, nodeTo, false, nil) - err := u.addGraphPolicies(g) + err := u.addGraphPolicies(graph) if err != nil { return 0, err } diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 348eb3746f..de03412343 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -227,31 +227,6 @@ func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) ( return lnwire.EmptyFeatureVector(), nil } -// FetchAmountPairCapacity returns the maximal capacity between nodes in the -// graph. -// -// NOTE: Part of the Graph interface. -func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, - amount lnwire.MilliSatoshi) (btcutil.Amount, error) { - - var capacity btcutil.Amount - - cb := func(channel *channeldb.DirectedChannel) error { - if channel.OtherNode == nodeTo { - capacity = channel.Capacity - } - - return nil - } - - err := m.ForEachNodeChannel(nodeFrom, cb) - if err != nil { - return 0, err - } - - return capacity, nil -} - // htlcResult describes the resolution of an htlc. If failure is nil, the htlc // was settled. 
type htlcResult struct { diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index a35c9e2f7b..0eea8edc03 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3201,7 +3201,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, return nil, err } - routingGraph, err := NewCachedGraph(sourceNode, graph) + routingGraph, err := NewCachedGraph(graph) if err != nil { return nil, err } diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 51bfc97811..cc90c465bf 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -47,7 +47,7 @@ type SessionSource struct { // getRoutingGraph returns a routing graph and a clean-up function for // pathfinding. func (m *SessionSource) getRoutingGraph() (Graph, func(), error) { - routingTx, err := NewCachedGraph(m.SourceNode, m.Graph) + routingTx, err := NewCachedGraph(m.Graph) if err != nil { return nil, nil, err } diff --git a/routing/router.go b/routing/router.go index 9af047f4ec..9af37a5434 100644 --- a/routing/router.go +++ b/routing/router.go @@ -517,8 +517,7 @@ func New(cfg Config) (*ChannelRouter, error) { r := &ChannelRouter{ cfg: &cfg, cachedGraph: &CachedGraph{ - graph: cfg.Graph, - source: selfNode.PubKeyBytes, + graph: cfg.Graph, }, networkUpdates: make(chan *routingMsg), topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, diff --git a/rpcserver.go b/rpcserver.go index a306d4fd29..59e3196fe4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -691,9 +691,7 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { - routingGraph, err := routing.NewCachedGraph( - selfNode, graph, - ) + routingGraph, err := routing.NewCachedGraph(graph) if err != nil { return 0, err } @@ -706,8 +704,9 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, } }() - return routingGraph.FetchAmountPairCapacity( - nodeFrom, nodeTo, amount, + return routing.FetchAmountPairCapacity( + routingGraph, selfNode.PubKeyBytes, nodeFrom, + nodeTo, amount, ) }, FetchChannelEndpoints: func(chanID uint64) (route.Vertex, From 8c0df98439c560af689c1dded9570b6c27ccf00c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 25 Jun 2024 19:58:57 -0700 Subject: [PATCH 05/20] multi: add abstraction for Router and SessionSource graph access In this commit, we completely remove the Router's dependence on a Graph source that requires a `kvdb.RTx`. In so doing, we are more prepared for a future where the Graph source is backed by different DB structure such as pure SQL. The two areas affected here are: the ChannelRouter's graph access that it uses for pathfinding. And the SessionSource's graph access that it uses for payments. The ChannelRouter gets given a Graph and the SessionSource is given a GraphSessionFactory which it can use to create a new session. Behind the scenes, this will acquire a kvdb.RTx that will be used for calls to the Graph's `ForEachNodeChannel` method. 
--- channeldb/graphsession/graph_session.go | 141 +++++++++++++++++++++ routing/graph.go | 63 ++------- routing/integrated_routing_context_test.go | 91 ++++++++++++- routing/pathfind_test.go | 10 +- routing/payment_session.go | 23 ++-- routing/payment_session_source.go | 24 +--- routing/payment_session_test.go | 8 +- routing/router.go | 20 ++- routing/router_test.go | 7 +- rpcserver.go | 18 +-- server.go | 11 +- 11 files changed, 288 insertions(+), 128 deletions(-) create mode 100644 channeldb/graphsession/graph_session.go diff --git a/channeldb/graphsession/graph_session.go b/channeldb/graphsession/graph_session.go new file mode 100644 index 0000000000..30f1903287 --- /dev/null +++ b/channeldb/graphsession/graph_session.go @@ -0,0 +1,141 @@ +package graphsession + +import ( + "fmt" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing" + "github.com/lightningnetwork/lnd/routing/route" +) + +// Factory implements the routing.GraphSessionFactory and can be used to start +// a session with a ReadOnlyGraph. +type Factory struct { + graph ReadOnlyGraph +} + +// NewGraphSessionFactory constructs a new Factory which can then be used to +// start a new session. +func NewGraphSessionFactory(graph ReadOnlyGraph) routing.GraphSessionFactory { + return &Factory{ + graph: graph, + } +} + +// NewGraphSession will produce a new Graph to use for a path-finding session. +// It returns the Graph along with a call-back that must be called once Graph +// access is complete. This call-back will close any read-only transaction that +// was created at Graph construction time. +// +// NOTE: This is part of the routing.GraphSessionFactory interface. +func (g *Factory) NewGraphSession() (routing.Graph, func() error, error) { + tx, err := g.graph.NewPathFindTx() + if err != nil { + return nil, nil, err + } + + session := &session{ + graph: g.graph, + tx: tx, + } + + return session, session.close, nil +} + +// A compile-time check to ensure that Factory implements the +// routing.GraphSessionFactory interface. +var _ routing.GraphSessionFactory = (*Factory)(nil) + +// session is an implementation of the routing.Graph interface where the same +// read-only transaction is held across calls to the graph and can be used to +// access the backing channel graph. +type session struct { + graph graph + tx kvdb.RTx +} + +// NewRoutingGraph constructs a session that which does not first start a +// read-only transaction and so each call on the routing.Graph will create a +// new transaction. +func NewRoutingGraph(graph ReadOnlyGraph) routing.Graph { + return &session{ + graph: graph, + } +} + +// close closes the read-only transaction being used to access the backing +// graph. If no transaction was started then this is a no-op. +func (g *session) close() error { + if g.tx == nil { + return nil + } + + err := g.tx.Rollback() + if err != nil { + return fmt.Errorf("error closing db tx: %w", err) + } + + return nil +} + +// ForEachNodeChannel calls the callback for every channel of the given node. +// +// NOTE: Part of the routing.Graph interface. +func (g *session) ForEachNodeChannel(nodePub route.Vertex, + cb func(channel *channeldb.DirectedChannel) error) error { + + return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) +} + +// FetchNodeFeatures returns the features of the given node. If the node is +// unknown, assume no additional features are supported. 
+// +// NOTE: Part of the routing.Graph interface. +func (g *session) FetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + return g.graph.FetchNodeFeatures(nodePub) +} + +// A compile-time check to ensure that *session implements the +// routing.Graph interface. +var _ routing.Graph = (*session)(nil) + +// ReadOnlyGraph is a graph extended with a call to create a new read-only +// transaction that can then be used to make further queries to the graph. +type ReadOnlyGraph interface { + // NewPathFindTx returns a new read transaction that can be used for a + // single path finding session. Will return nil if the graph cache is + // enabled. + NewPathFindTx() (kvdb.RTx, error) + + graph +} + +// graph describes the API necessary for a graph source to have access to on a +// database implementation, like channeldb.ChannelGraph, in order to be used by +// the Router for pathfinding. +type graph interface { + // ForEachNodeDirectedChannel iterates through all channels of a given + // node, executing the passed callback on the directed edge representing + // the channel and its incoming policy. If the callback returns an + // error, then the iteration is halted with the error propagated back + // up to the caller. + // + // Unknown policies are passed into the callback as nil values. + // + // NOTE: if a nil tx is provided, then it is expected that the + // implementation create a read only tx. + ForEachNodeDirectedChannel(tx kvdb.RTx, node route.Vertex, + cb func(channel *channeldb.DirectedChannel) error) error + + // FetchNodeFeatures returns the features of a given node. If no + // features are known for the node, an empty feature vector is returned. + FetchNodeFeatures(node route.Vertex) (*lnwire.FeatureVector, error) +} + +// A compile-time check to ensure that *channeldb.ChannelGraph implements the +// graph interface. +var _ graph = (*channeldb.ChannelGraph)(nil) diff --git a/routing/graph.go b/routing/graph.go index 0c4d2e1d41..2b1c85bc8f 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -5,7 +5,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -22,58 +21,16 @@ type Graph interface { FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// CachedGraph is a Graph implementation that retrieves from the -// database. -type CachedGraph struct { - graph *channeldb.ChannelGraph - tx kvdb.RTx -} - -// A compile time assertion to make sure CachedGraph implements the Graph -// interface. -var _ Graph = (*CachedGraph)(nil) - -// NewCachedGraph instantiates a new db-connected routing graph. It implicitly -// instantiates a new read transaction. -func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) { - tx, err := graph.NewPathFindTx() - if err != nil { - return nil, err - } - - return &CachedGraph{ - graph: graph, - tx: tx, - }, nil -} - -// Close attempts to close the underlying db transaction. This is a no-op in -// case the underlying graph uses an in-memory cache. -func (g *CachedGraph) Close() error { - if g.tx == nil { - return nil - } - - return g.tx.Rollback() -} - -// ForEachNodeChannel calls the callback for every channel of the given node. -// -// NOTE: Part of the Graph interface. 
-func (g *CachedGraph) ForEachNodeChannel(nodePub route.Vertex, - cb func(channel *channeldb.DirectedChannel) error) error { - - return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) -} - -// FetchNodeFeatures returns the features of the given node. If the node is -// unknown, assume no additional features are supported. -// -// NOTE: Part of the Graph interface. -func (g *CachedGraph) FetchNodeFeatures(nodePub route.Vertex) ( - *lnwire.FeatureVector, error) { - - return g.graph.FetchNodeFeatures(nodePub) +// GraphSessionFactory can be used to produce a new Graph instance which can +// then be used for a path-finding session. Depending on the implementation, +// the Graph session will represent a DB connection where a read-lock is being +// held across calls to the backing Graph. +type GraphSessionFactory interface { + // NewGraphSession will produce a new Graph to use for a path-finding + // session. It returns the Graph along with a call-back that must be + // called once Graph access is complete. This call-back will close any + // read-only transaction that was created at Graph construction time. + NewGraphSession() (Graph, func() error, error) } // FetchAmountPairCapacity determines the maximal public capacity between two diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 02cd6f0477..ee6fed295a 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -201,10 +202,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, session, err := newPaymentSession( &payment, c.graph.source.pubkey, getBandwidthHints, - func() (Graph, func(), error) { - return c.graph, func() {}, nil - }, - mc, c.pathFindingCfg, + newMockGraphSessionFactory(c.graph), mc, c.pathFindingCfg, ) if err != nil { c.t.Fatal(err) @@ -307,3 +305,88 @@ func getNodeIndex(route *route.Route, failureSource route.Vertex) *int { } return nil } + +type mockGraphSessionFactory struct { + Graph +} + +func newMockGraphSessionFactory(graph Graph) GraphSessionFactory { + return &mockGraphSessionFactory{Graph: graph} +} + +func (m *mockGraphSessionFactory) NewGraphSession() (Graph, func() error, + error) { + + return m, func() error { + return nil + }, nil +} + +var _ GraphSessionFactory = (*mockGraphSessionFactory)(nil) +var _ Graph = (*mockGraphSessionFactory)(nil) + +type mockGraphSessionFactoryChanDB struct { + graph *channeldb.ChannelGraph +} + +func newMockGraphSessionFactoryFromChanDB( + graph *channeldb.ChannelGraph) *mockGraphSessionFactoryChanDB { + + return &mockGraphSessionFactoryChanDB{ + graph: graph, + } +} + +func (g *mockGraphSessionFactoryChanDB) NewGraphSession() (Graph, func() error, + error) { + + tx, err := g.graph.NewPathFindTx() + if err != nil { + return nil, nil, err + } + + session := &mockGraphSessionChanDB{ + graph: g.graph, + tx: tx, + } + + return session, session.close, nil +} + +var _ GraphSessionFactory = (*mockGraphSessionFactoryChanDB)(nil) + +type mockGraphSessionChanDB struct { + graph *channeldb.ChannelGraph + tx kvdb.RTx +} + +func newMockGraphSessionChanDB(graph *channeldb.ChannelGraph) Graph { + return &mockGraphSessionChanDB{ + graph: graph, + } +} + +func (g *mockGraphSessionChanDB) close() error { + if g.tx == nil { + return nil + } + + err := g.tx.Rollback() 
+ if err != nil { + return fmt.Errorf("error closing db tx: %w", err) + } + + return nil +} + +func (g *mockGraphSessionChanDB) ForEachNodeChannel(nodePub route.Vertex, + cb func(channel *channeldb.DirectedChannel) error) error { + + return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) +} + +func (g *mockGraphSessionChanDB) FetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + return g.graph.FetchNodeFeatures(nodePub) +} diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 0eea8edc03..d430f30373 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3201,14 +3201,16 @@ func dbFindPath(graph *channeldb.ChannelGraph, return nil, err } - routingGraph, err := NewCachedGraph(graph) + graphSessFactory := newMockGraphSessionFactoryFromChanDB(graph) + + graphSess, closeGraphSess, err := graphSessFactory.NewGraphSession() if err != nil { return nil, err } defer func() { - if err := routingGraph.Close(); err != nil { - log.Errorf("Error closing db tx: %v", err) + if err := closeGraphSess(); err != nil { + log.Errorf("Error closing graph session: %v", err) } }() @@ -3216,7 +3218,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, - graph: routingGraph, + graph: graphSess, }, r, cfg, sourceNode.PubKeyBytes, source, target, amt, timePref, finalHtlcExpiry, diff --git a/routing/payment_session.go b/routing/payment_session.go index 6cfbeddf46..0d46f71199 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -175,7 +175,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (Graph, func(), error) + graphSessFactory GraphSessionFactory // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probability. @@ -196,9 +196,8 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, selfNode route.Vertex, getBandwidthHints func(Graph) (bandwidthHints, error), - getRoutingGraph func() (Graph, func(), error), - missionControl MissionController, pathFindingConfig PathFindingConfig) ( - *paymentSession, error) { + graphSessFactory GraphSessionFactory, missionControl MissionController, + pathFindingConfig PathFindingConfig) (*paymentSession, error) { edges, err := RouteHintsToEdges(p.RouteHints, p.Target) if err != nil { @@ -213,7 +212,7 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, getBandwidthHints: getBandwidthHints, payment: p, pathFinder: findPath, - getRoutingGraph: getRoutingGraph, + graphSessFactory: graphSessFactory, pathFindingConfig: pathFindingConfig, missionControl: missionControl, minShardAmt: DefaultShardMinAmt, @@ -280,8 +279,8 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } for { - // Get a routing graph. - routingGraph, cleanup, err := p.getRoutingGraph() + // Get a routing graph session. + graph, closeGraph, err := p.graphSessFactory.NewGraphSession() if err != nil { return nil, err } @@ -292,7 +291,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // don't have enough bandwidth to carry the payment. New // bandwidth hints are queried for every new path finding // attempt, because concurrent payments may change balances. 
- bandwidthHints, err := p.getBandwidthHints(routingGraph) + bandwidthHints, err := p.getBandwidthHints(graph) if err != nil { return nil, err } @@ -304,15 +303,17 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, - graph: routingGraph, + graph: graph, }, restrictions, &p.pathFindingConfig, p.selfNode, p.selfNode, p.payment.Target, maxAmt, p.payment.TimePref, finalHtlcExpiry, ) - // Close routing graph. - cleanup() + // Close routing graph session. + if err := closeGraph(); err != nil { + log.Errorf("could not close graph session: %v", err) + } switch { case err == errNoPathFound: diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index cc90c465bf..46e7a42aa1 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -16,9 +16,10 @@ var _ PaymentSessionSource = (*SessionSource)(nil) // SessionSource defines a source for the router to retrieve new payment // sessions. type SessionSource struct { - // Graph is the channel graph that will be used to gather metrics from - // and also to carry out path finding queries. - Graph *channeldb.ChannelGraph + // GraphSessionFactory can be used to gain access to a Graph session. + // If the backing DB allows it, this will mean that a read transaction + // is being held during the use of the session. + GraphSessionFactory GraphSessionFactory // SourceNode is the graph's source node. SourceNode *channeldb.LightningNode @@ -44,21 +45,6 @@ type SessionSource struct { PathFindingConfig PathFindingConfig } -// getRoutingGraph returns a routing graph and a clean-up function for -// pathfinding. -func (m *SessionSource) getRoutingGraph() (Graph, func(), error) { - routingTx, err := NewCachedGraph(m.Graph) - if err != nil { - return nil, nil, err - } - return routingTx, func() { - err := routingTx.Close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }, nil -} - // NewPaymentSession creates a new payment session backed by the latest prune // view from Mission Control. 
An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the @@ -74,7 +60,7 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( session, err := newPaymentSession( p, m.SourceNode.PubKeyBytes, getBandwidthHints, - m.getRoutingGraph, m.MissionControl, m.PathFindingConfig, + m.GraphSessionFactory, m.MissionControl, m.PathFindingConfig, ) if err != nil { return nil, err diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 9356a2be01..f6873aa752 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -119,9 +119,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { func(Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (Graph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + newMockGraphSessionFactory(&sessionGraph{}), &MissionControl{}, PathFindingConfig{}, ) @@ -199,9 +197,7 @@ func TestRequestRoute(t *testing.T) { func(Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (Graph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + newMockGraphSessionFactory(&sessionGraph{}), &MissionControl{}, PathFindingConfig{}, ) diff --git a/routing/router.go b/routing/router.go index 9af37a5434..c52adccc3e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -319,6 +319,9 @@ type ChannelPolicy struct { // the configuration MUST be non-nil for the ChannelRouter to carry out its // duties. type Config struct { + // RoutingGraph is a graph source that will be used for pathfinding. + RoutingGraph Graph + // Graph is the channel graph that the ChannelRouter will use to gather // metrics from and also to carry out path finding queries. // TODO(roasbeef): make into an interface @@ -453,10 +456,6 @@ type ChannelRouter struct { // when doing any path finding. selfNode *channeldb.LightningNode - // cachedGraph is an instance of Graph that caches the source - // node as well as the channel graph itself in memory. - cachedGraph Graph - // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to // UpdateFilter. @@ -515,10 +514,7 @@ func New(cfg Config) (*ChannelRouter, error) { } r := &ChannelRouter{ - cfg: &cfg, - cachedGraph: &CachedGraph{ - graph: cfg.Graph, - }, + cfg: &cfg, networkUpdates: make(chan *routingMsg), topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, ntfnClientUpdates: make(chan *topologyClientUpdate), @@ -2118,7 +2114,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( - r.cachedGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, + r.cfg.RoutingGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, ) if err != nil { return nil, 0, err @@ -2145,7 +2141,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, &graphParams{ additionalEdges: req.RouteHints, bandwidthHints: bandwidthHints, - graph: r.cachedGraph, + graph: r.cfg.RoutingGraph, }, req.Restrictions, &r.cfg.PathFindingConfig, r.selfNode.PubKeyBytes, req.Source, req.Target, req.Amount, @@ -3131,7 +3127,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. 
bandwidthHints, err := newBandwidthManager( - r.cachedGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, + r.cfg.RoutingGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, ) if err != nil { return nil, err @@ -3147,7 +3143,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, sourceNode := r.selfNode.PubKeyBytes unifiers, senderAmt, err := getRouteUnifiers( sourceNode, hops, useMinAmt, runningAmt, outgoingChans, - r.cachedGraph, bandwidthHints, + r.cfg.RoutingGraph, bandwidthHints, ) if err != nil { return nil, err diff --git a/routing/router_test.go b/routing/router_test.go index 47bdf7a17a..d6c5e6172d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -77,6 +77,7 @@ func (c *testCtx) RestartRouter(t *testing.T) { // With the chainView reset, we'll now re-create the router itself, and // start it. router, err := New(Config{ + RoutingGraph: newMockGraphSessionChanDB(c.graph), Graph: c.graph, Chain: c.chain, ChainView: c.chainView, @@ -140,7 +141,9 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, sourceNode, err := graphInstance.graph.SourceNode() require.NoError(t, err) sessionSource := &SessionSource{ - Graph: graphInstance.graph, + GraphSessionFactory: newMockGraphSessionFactoryFromChanDB( + graphInstance.graph, + ), SourceNode: sourceNode, GetLink: graphInstance.getLink, PathFindingConfig: pathFindingConfig, @@ -154,6 +157,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, } router, err := New(Config{ + RoutingGraph: newMockGraphSessionChanDB(graphInstance.graph), Graph: graphInstance.graph, Chain: chain, ChainView: chainView, @@ -1763,6 +1767,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { // Create new router with same graph database. router, err := New(Config{ + RoutingGraph: newMockGraphSessionChanDB(ctx.graph), Graph: ctx.graph, Chain: ctx.chain, ChainView: ctx.chainView, diff --git a/rpcserver.go b/rpcserver.go index 59e3196fe4..c2ab4eb543 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -41,6 +41,7 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/graphsession" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/contractcourt" @@ -691,22 +692,9 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { - routingGraph, err := routing.NewCachedGraph(graph) - if err != nil { - return 0, err - } - defer func() { - closeErr := routingGraph.Close() - if closeErr != nil { - rpcsLog.Errorf("not able to close "+ - "routing graph tx: %v", - closeErr) - } - }() - return routing.FetchAmountPairCapacity( - routingGraph, selfNode.PubKeyBytes, nodeFrom, - nodeTo, amount, + graphsession.NewRoutingGraph(graph), + selfNode.PubKeyBytes, nodeFrom, nodeTo, amount, ) }, FetchChannelEndpoints: func(chanID uint64) (route.Vertex, diff --git a/server.go b/server.go index 6ab197f86a..5b96ec7854 100644 --- a/server.go +++ b/server.go @@ -32,6 +32,7 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/graphsession" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/clock" @@ -956,7 +957,9 @@ func 
newServer(cfg *Config, listenAddrs []net.Addr, return nil, fmt.Errorf("error getting source node: %w", err) } paymentSessionSource := &routing.SessionSource{ - Graph: chanGraph, + GraphSessionFactory: graphsession.NewGraphSessionFactory( + chanGraph, + ), SourceNode: sourceNode, MissionControl: s.missionControl, GetLink: s.htlcSwitch.GetLinkByShortID, @@ -967,9 +970,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.controlTower = routing.NewControlTower(paymentControl) - strictPruning := (cfg.Bitcoin.Node == "neutrino" || - cfg.Routing.StrictZombiePruning) + strictPruning := cfg.Bitcoin.Node == "neutrino" || + cfg.Routing.StrictZombiePruning + s.chanRouter, err = routing.New(routing.Config{ + RoutingGraph: graphsession.NewRoutingGraph(chanGraph), Graph: chanGraph, Chain: cc.ChainIO, ChainView: cc.ChainView, From cf3de72503b737cc8be74164a420c29c5c86e13e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 25 Jun 2024 20:08:57 -0700 Subject: [PATCH 06/20] routing: let SelfNode be passed via config Instead of querying it from the graph since this will be removed in a future commit. --- routing/pathfind_test.go | 24 ++++-------------------- routing/router.go | 29 +++++++++++------------------ routing/router_test.go | 15 ++++++++++++--- server.go | 1 + 4 files changed, 28 insertions(+), 41 deletions(-) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index d430f30373..e48f3c7fbd 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2227,18 +2227,13 @@ func TestPathFindSpecExample(t *testing.T) { // Carol, so we set "B" as the source node so path finding starts from // Bob. bob := ctx.aliases["B"] - bobNode, err := ctx.graph.FetchLightningNode(nil, bob) - require.NoError(t, err, "unable to find bob") - if err := ctx.graph.SetSourceNode(bobNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } // Query for a route of 4,999,999 mSAT to carol. carol := ctx.aliases["C"] const amt lnwire.MilliSatoshi = 4999999 req, err := NewRouteRequest( - bobNode.PubKeyBytes, &carol, amt, 0, noRestrictions, nil, nil, - nil, MinCLTVDelta, + bob, &carol, amt, 0, noRestrictions, nil, nil, nil, + MinCLTVDelta, ) require.NoError(t, err, "invalid route request") @@ -2276,22 +2271,11 @@ func TestPathFindSpecExample(t *testing.T) { // Next, we'll set A as the source node so we can assert that we create // the proper route for any queries starting with Alice. alice := ctx.aliases["A"] - aliceNode, err := ctx.graph.FetchLightningNode(nil, alice) - require.NoError(t, err, "unable to find alice") - if err := ctx.graph.SetSourceNode(aliceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - ctx.router.selfNode = aliceNode - source, err := ctx.graph.SourceNode() - require.NoError(t, err, "unable to retrieve source node") - if source.PubKeyBytes != alice { - t.Fatalf("source node not set") - } // We'll now request a route from A -> B -> C. req, err = NewRouteRequest( - source.PubKeyBytes, &carol, amt, 0, noRestrictions, nil, nil, - nil, MinCLTVDelta, + alice, &carol, amt, 0, noRestrictions, nil, nil, nil, + MinCLTVDelta, ) require.NoError(t, err, "invalid route request") diff --git a/routing/router.go b/routing/router.go index c52adccc3e..04b6cad457 100644 --- a/routing/router.go +++ b/routing/router.go @@ -319,6 +319,10 @@ type ChannelPolicy struct { // the configuration MUST be non-nil for the ChannelRouter to carry out its // duties. type Config struct { + // SelfNode is the public key of the node that this channel router + // belongs to. 
+ SelfNode route.Vertex + // RoutingGraph is a graph source that will be used for pathfinding. RoutingGraph Graph @@ -451,11 +455,6 @@ type ChannelRouter struct { // initialized with. cfg *Config - // selfNode is the center of the star-graph centered around the - // ChannelRouter. The ChannelRouter uses this node as a starting point - // when doing any path finding. - selfNode *channeldb.LightningNode - // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to // UpdateFilter. @@ -508,18 +507,12 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil) // channel graph is a subset of the UTXO set) set, then the router will proceed // to fully sync to the latest state of the UTXO set. func New(cfg Config) (*ChannelRouter, error) { - selfNode, err := cfg.Graph.SourceNode() - if err != nil { - return nil, err - } - r := &ChannelRouter{ cfg: &cfg, networkUpdates: make(chan *routingMsg), topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, ntfnClientUpdates: make(chan *topologyClientUpdate), channelEdgeMtx: multimutex.NewMutex[uint64](), - selfNode: selfNode, statTicker: ticker.New(defaultStatInterval), stats: new(routerStats), quit: make(chan struct{}), @@ -968,8 +961,8 @@ func (r *ChannelRouter) pruneZombieChans() error { // A helper method to detect if the channel belongs to this node isSelfChannelEdge := func(info *models.ChannelEdgeInfo) bool { - return info.NodeKey1Bytes == r.selfNode.PubKeyBytes || - info.NodeKey2Bytes == r.selfNode.PubKeyBytes + return info.NodeKey1Bytes == r.cfg.SelfNode || + info.NodeKey2Bytes == r.cfg.SelfNode } // First, we'll collect all the channels which are eligible for garbage @@ -2114,7 +2107,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, + r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, ) if err != nil { return nil, 0, err @@ -2144,7 +2137,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, graph: r.cfg.RoutingGraph, }, req.Restrictions, &r.cfg.PathFindingConfig, - r.selfNode.PubKeyBytes, req.Source, req.Target, req.Amount, + r.cfg.SelfNode, req.Source, req.Target, req.Amount, req.TimePreference, finalHtlcExpiry, ) if err != nil { @@ -2944,7 +2937,7 @@ func (r *ChannelRouter) ForEachNode( func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { - return r.cfg.Graph.ForEachNodeChannel(nil, r.selfNode.PubKeyBytes, + return r.cfg.Graph.ForEachNodeChannel(nil, r.cfg.SelfNode, func(tx kvdb.RTx, c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -3127,7 +3120,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. 
bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, + r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, ) if err != nil { return nil, err @@ -3140,7 +3133,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, return nil, err } - sourceNode := r.selfNode.PubKeyBytes + sourceNode := r.cfg.SelfNode unifiers, senderAmt, err := getRouteUnifiers( sourceNode, hops, useMinAmt, runningAmt, outgoingChans, r.cfg.RoutingGraph, bandwidthHints, diff --git a/routing/router_test.go b/routing/router_test.go index d6c5e6172d..6f11894dc7 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -74,9 +74,13 @@ func (c *testCtx) RestartRouter(t *testing.T) { // filter between restarts. c.chainView.Reset() + source, err := c.graph.SourceNode() + require.NoError(t, err) + // With the chainView reset, we'll now re-create the router itself, and // start it. router, err := New(Config{ + SelfNode: source.PubKeyBytes, RoutingGraph: newMockGraphSessionChanDB(c.graph), Graph: c.graph, Chain: c.chain, @@ -157,6 +161,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, } router, err := New(Config{ + SelfNode: sourceNode.PubKeyBytes, RoutingGraph: newMockGraphSessionChanDB(graphInstance.graph), Graph: graphInstance.graph, Chain: chain, @@ -278,7 +283,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) { } req, err := NewRouteRequest( - ctx.router.selfNode.PubKeyBytes, &target, paymentAmt, 0, + ctx.router.cfg.SelfNode, &target, paymentAmt, 0, restrictions, nil, nil, nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") @@ -1541,7 +1546,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { copy(targetPubKeyBytes[:], targetNode.SerializeCompressed()) req, err := NewRouteRequest( - ctx.router.selfNode.PubKeyBytes, &targetPubKeyBytes, + ctx.router.cfg.SelfNode, &targetPubKeyBytes, paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") @@ -1583,7 +1588,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // Should still be able to find the route, and the info should be // updated. req, err = NewRouteRequest( - ctx.router.selfNode.PubKeyBytes, &targetPubKeyBytes, + ctx.router.cfg.SelfNode, &targetPubKeyBytes, paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") @@ -1765,8 +1770,12 @@ func TestWakeUpOnStaleBranch(t *testing.T) { // Give time to process new blocks. time.Sleep(time.Millisecond * 500) + source, err := ctx.graph.SourceNode() + require.NoError(t, err) + // Create new router with same graph database. router, err := New(Config{ + SelfNode: source.PubKeyBytes, RoutingGraph: newMockGraphSessionChanDB(ctx.graph), Graph: ctx.graph, Chain: ctx.chain, diff --git a/server.go b/server.go index 5b96ec7854..b3852dbcc4 100644 --- a/server.go +++ b/server.go @@ -974,6 +974,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, cfg.Routing.StrictZombiePruning s.chanRouter, err = routing.New(routing.Config{ + SelfNode: selfNode.PubKeyBytes, RoutingGraph: graphsession.NewRoutingGraph(chanGraph), Graph: chanGraph, Chain: cc.ChainIO, From 71e93526d6bd07f3cf5f5e7df48eb0dcf676a67e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 19:51:11 -0400 Subject: [PATCH 07/20] multi+refactor: let FetchChanInfos not take tx In preparation for having a clean graph DB interface, refactor FetchChanInfos so that no transaction can be provided. 
--- channeldb/graph.go | 17 +++++++++++++---- channeldb/graph_test.go | 2 +- discovery/chan_series.go | 2 +- routing/router.go | 2 +- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 227d5fadd1..0ac0ec2f72 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2374,10 +2374,19 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, // skipped and the result will contain only those edges that exist at the time // of the query. This can be used to respond to peer queries that are seeking to // fill in gaps in their view of the channel graph. +func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { + return c.fetchChanInfos(nil, chanIDs) +} + +// fetchChanInfos returns the set of channel edges that correspond to the passed +// channel ID's. If an edge is the query is unknown to the database, it will +// skipped and the result will contain only those edges that exist at the time +// of the query. This can be used to respond to peer queries that are seeking to +// fill in gaps in their view of the channel graph. // // NOTE: An optional transaction may be provided. If none is provided, then a // new one will be created. -func (c *ChannelGraph) FetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( +func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( []ChannelEdge, error) { // TODO(roasbeef): sort cids? @@ -2958,8 +2967,8 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, // key. If the node isn't found in the database, then ErrGraphNodeNotFound is // returned. An optional transaction may be provided. If none is provided, then // a new one will be created. -func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( - *LightningNode, error) { +func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, + nodePub route.Vertex) (*LightningNode, error) { var node *LightningNode fetch := func(tx kvdb.RTx) error { @@ -3705,7 +3714,7 @@ func (c *ChannelGraph) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error { // We need to add the channel back into our graph cache, otherwise we // won't use it for path finding. if c.graphCache != nil { - edgeInfos, err := c.FetchChanInfos(tx, []uint64{chanID}) + edgeInfos, err := c.fetchChanInfos(tx, []uint64{chanID}) if err != nil { return err } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 717b99fd1b..9e430296bd 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -2685,7 +2685,7 @@ func TestFetchChanInfos(t *testing.T) { // We'll now attempt to query for the range of channel ID's we just // inserted into the database. We should get the exact same set of // edges back. 
- resp, err := graph.FetchChanInfos(nil, edgeQuery) + resp, err := graph.FetchChanInfos(edgeQuery) require.NoError(t, err, "unable to fetch chan edges") if len(resp) != len(edges) { t.Fatalf("expected %v edges, instead got %v", len(edges), diff --git a/discovery/chan_series.go b/discovery/chan_series.go index bd6571b87d..34e6d4a9db 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -249,7 +249,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, chanIDs = append(chanIDs, chanID.ToUint64()) } - channels, err := c.graph.FetchChanInfos(nil, chanIDs) + channels, err := c.graph.FetchChanInfos(chanIDs) if err != nil { return nil, err } diff --git a/routing/router.go b/routing/router.go index 04b6cad457..c968f8d321 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1023,7 +1023,7 @@ func (r *ChannelRouter) pruneZombieChans() error { } disabledEdges, err := r.cfg.Graph.FetchChanInfos( - nil, disabledChanIDs, + disabledChanIDs, ) if err != nil { return fmt.Errorf("unable to fetch disabled channels "+ From c20d759d4198ca8ad89c85d9771edea6066c2ba8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 20:29:26 -0400 Subject: [PATCH 08/20] refactor: create FetchLightningNode with no tx param In preparation for adding a clean Graph DB interface, we create a version of FetchLightningNode that doesnt allow a caller to provide in a transaction. --- autopilot/graph.go | 6 ++++-- channeldb/db.go | 2 +- channeldb/graph.go | 23 +++++++++++++++++++++-- channeldb/graph_test.go | 14 +++++++------- routing/router.go | 2 +- routing/router_test.go | 4 ++-- rpcserver.go | 4 ++-- server.go | 2 +- 8 files changed, 39 insertions(+), 18 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index 74b04c5034..83447af9b6 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -105,7 +105,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return nil } - node, err := d.db.FetchLightningNode(tx, ep.ToNode) + node, err := d.db.FetchLightningNodeTx( + tx, ep.ToNode, + ) if err != nil { return err } @@ -164,7 +166,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, return nil, err } - dbNode, err := d.db.FetchLightningNode(nil, vertex) + dbNode, err := d.db.FetchLightningNode(vertex) switch { case err == channeldb.ErrGraphNodeNotFound: fallthrough diff --git a/channeldb/db.go b/channeldb/db.go index 93bb239bb6..1e210d032a 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -1351,7 +1351,7 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, if err != nil { return nil, err } - graphNode, err := d.graph.FetchLightningNode(nil, pubKey) + graphNode, err := d.graph.FetchLightningNode(pubKey) if err != nil && err != ErrGraphNodeNotFound { return nil, err } else if err == ErrGraphNodeNotFound { diff --git a/channeldb/graph.go b/channeldb/graph.go index 0ac0ec2f72..a37ea9682c 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -529,7 +529,7 @@ func (c *ChannelGraph) FetchNodeFeatures( } // Fallback that uses the database. - targetNode, err := c.FetchLightningNode(nil, node) + targetNode, err := c.FetchLightningNode(node) switch err { // If the node exists and has features, return them directly. case nil: @@ -2963,11 +2963,30 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, return nodeIsPublic, nil } +// FetchLightningNodeTx attempts to look up a target node by its identity +// public key. If the node isn't found in the database, then +// ErrGraphNodeNotFound is returned. 
An optional transaction may be provided. +// If none is provided, then a new one will be created. +func (c *ChannelGraph) FetchLightningNodeTx(tx kvdb.RTx, nodePub route.Vertex) ( + *LightningNode, error) { + + return c.fetchLightningNode(tx, nodePub) +} + // FetchLightningNode attempts to look up a target node by its identity public // key. If the node isn't found in the database, then ErrGraphNodeNotFound is +// returned. +func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (*LightningNode, + error) { + + return c.fetchLightningNode(nil, nodePub) +} + +// fetchLightningNode attempts to look up a target node by its identity public +// key. If the node isn't found in the database, then ErrGraphNodeNotFound is // returned. An optional transaction may be provided. If none is provided, then // a new one will be created. -func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, +func (c *ChannelGraph) fetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (*LightningNode, error) { var node *LightningNode diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 9e430296bd..46bc0d3fda 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -141,7 +141,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(nil, testPub) + dbNode, err := graph.FetchLightningNode(testPub) require.NoError(t, err, "unable to locate node") if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { @@ -164,7 +164,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(nil, testPub) + _, err = graph.FetchLightningNode(testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -192,7 +192,7 @@ func TestPartialNode(t *testing.T) { // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(nil, testPub) + dbNode, err := graph.FetchLightningNode(testPub) require.NoError(t, err, "unable to locate node") if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { @@ -222,7 +222,7 @@ func TestPartialNode(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(nil, testPub) + _, err = graph.FetchLightningNode(testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -3014,7 +3014,7 @@ func TestPruneGraphNodes(t *testing.T) { // Finally, we'll ensure that node3, the only fully unconnected node as // properly deleted from the graph and not another node in its place. - _, err = graph.FetchLightningNode(nil, node3.PubKeyBytes) + _, err = graph.FetchLightningNode(node3.PubKeyBytes) if err == nil { t.Fatalf("node 3 should have been deleted!") } @@ -3048,13 +3048,13 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // Ensure that node1 was inserted as a full node, while node2 only has // a shell node present. 
- node1, err = graph.FetchLightningNode(nil, node1.PubKeyBytes) + node1, err = graph.FetchLightningNode(node1.PubKeyBytes) require.NoError(t, err, "unable to fetch node1") if !node1.HaveNodeAnnouncement { t.Fatalf("have shell announcement for node1, shouldn't") } - node2, err = graph.FetchLightningNode(nil, node2.PubKeyBytes) + node2, err = graph.FetchLightningNode(node2.PubKeyBytes) require.NoError(t, err, "unable to fetch node2") if node2.HaveNodeAnnouncement { t.Fatalf("should have shell announcement for node2, but is full") diff --git a/routing/router.go b/routing/router.go index c968f8d321..c91e0b1879 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2915,7 +2915,7 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( func (r *ChannelRouter) FetchLightningNode( node route.Vertex) (*channeldb.LightningNode, error) { - return r.cfg.Graph.FetchLightningNode(nil, node) + return r.cfg.Graph.FetchLightningNode(node) } // ForEachNode is used to iterate over every node in router topology. diff --git a/routing/router_test.go b/routing/router_test.go index 6f11894dc7..49ca6a2665 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1596,14 +1596,14 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { _, _, err = ctx.router.FindRoute(req) require.NoError(t, err, "unable to find any routes") - copy1, err := ctx.graph.FetchLightningNode(nil, pub1) + copy1, err := ctx.graph.FetchLightningNode(pub1) require.NoError(t, err, "unable to fetch node") if copy1.Alias != n1.Alias { t.Fatalf("fetched node not equal to original") } - copy2, err := ctx.graph.FetchLightningNode(nil, pub2) + copy2, err := ctx.graph.FetchLightningNode(pub2) require.NoError(t, err, "unable to fetch node") if copy2.Alias != n2.Alias { diff --git a/rpcserver.go b/rpcserver.go index c2ab4eb543..2012e75d9f 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6345,7 +6345,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // With the public key decoded, attempt to fetch the node corresponding // to this public key. If the node cannot be found, then an error will // be returned. - node, err := graph.FetchLightningNode(nil, pubKey) + node, err := graph.FetchLightningNode(pubKey) switch { case err == channeldb.ErrGraphNodeNotFound: return nil, status.Error(codes.NotFound, err.Error()) @@ -7393,7 +7393,7 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context, return "", err } - peer, err := r.server.graphDB.FetchLightningNode(nil, vertex) + peer, err := r.server.graphDB.FetchLightningNode(vertex) if err != nil { return "", err } diff --git a/server.go b/server.go index b3852dbcc4..555513406a 100644 --- a/server.go +++ b/server.go @@ -4634,7 +4634,7 @@ func (s *server) fetchNodeAdvertisedAddrs(pub *btcec.PublicKey) ([]net.Addr, err return nil, err } - node, err := s.graphDB.FetchLightningNode(nil, vertex) + node, err := s.graphDB.FetchLightningNode(vertex) if err != nil { return nil, err } From e9c89ae0ec3c846b1fa84a4d879cdf200804a71f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 20:34:53 -0400 Subject: [PATCH 09/20] multi+refactor: create ForEachNodeChannel with no tx param In prep for a clean Graph DB interface, we add a version of ForEachNodeChannel that does not take in an existing db transaction. 
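For readers following along outside the tree, patches 07 through 09 all converge
on the same shape: keep one transaction-aware implementation, expose a
transaction-free method for external callers, and retain a *Tx variant only for
the few internal call sites (autopilot, cached traversal) that must reuse an
open read transaction. The following stand-alone sketch shows that wrapper
pattern; the store, node and rtx names are hypothetical stand-ins for the real
ChannelGraph, LightningNode and kvdb.RTx types, not lnd declarations:

    package main

    import "fmt"

    // rtx stands in for kvdb.RTx, the read-transaction type the real
    // methods accept. Here it is only a placeholder interface.
    type rtx interface{}

    type node struct{ alias string }

    type store struct {
            nodes map[string]node
    }

    // FetchNode is the narrow, transaction-free API for external
    // callers. It delegates with a nil transaction, so callers can no
    // longer thread their own tx through the public graph API.
    func (s *store) FetchNode(pub string) (node, error) {
            return s.fetchNode(nil, pub)
    }

    // FetchNodeTx keeps the old behaviour available for the internal
    // callers that genuinely need to reuse an open transaction.
    func (s *store) FetchNodeTx(tx rtx, pub string) (node, error) {
            return s.fetchNode(tx, pub)
    }

    // fetchNode is the single implementation. A nil tx means "open a
    // fresh read transaction"; non-nil means "run inside the caller's".
    func (s *store) fetchNode(tx rtx, pub string) (node, error) {
            _ = tx // the real code would call kvdb.View when tx is nil
            n, ok := s.nodes[pub]
            if !ok {
                    return node{}, fmt.Errorf("node %s not found", pub)
            }
            return n, nil
    }

    func main() {
            s := &store{nodes: map[string]node{"02ab": {alias: "carol"}}}
            n, err := s.FetchNode("02ab") // no tx at the call site
            fmt.Println(n.alias, err)
    }

The call-site win is visible throughout the diffs above: every external caller
drops its nil argument, and the ability to pass a transaction disappears from
the exported surface that the upcoming graph DB interface has to describe.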
--- autopilot/graph.go | 2 +- channeldb/graph.go | 26 +++++++++++++++++++++----- channeldb/graph_test.go | 4 ++-- routing/router.go | 2 +- rpcserver.go | 4 ++-- server.go | 2 +- 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index 83447af9b6..2ce49c1272 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -89,7 +89,7 @@ func (d *dbNode) Addrs() []net.Addr { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { - return d.db.ForEachNodeChannel(d.tx, d.node.PubKeyBytes, + return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes, func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep, _ *models.ChannelEdgePolicy) error { diff --git a/channeldb/graph.go b/channeldb/graph.go index a37ea9682c..4146721660 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -565,7 +565,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { channels := make(map[uint64]*DirectedChannel) - err := c.ForEachNodeChannel(tx, node.PubKeyBytes, + err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes, func(tx kvdb.RTx, e *models.ChannelEdgeInfo, p1 *models.ChannelEdgePolicy, p2 *models.ChannelEdgePolicy) error { @@ -2931,7 +2931,7 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, // used to terminate the check early. nodeIsPublic := false errDone := errors.New("done") - err := c.ForEachNodeChannel(tx, nodePub, func(tx kvdb.RTx, + err := c.ForEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx, info *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -3224,13 +3224,29 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. +func (c *ChannelGraph) ForEachNodeChannel(nodePub route.Vertex, + cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + return nodeTraversal(nil, nodePub[:], c.db, cb) +} + +// ForEachNodeChannelTx iterates through all channels of the given node, +// executing the passed callback with an edge info structure and the policies +// of each end of the channel. The first edge policy is the outgoing edge *to* +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. // // If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should +// should be passed as the first argument. Otherwise, the first argument should // be nil and a fresh transaction will be created to execute the graph // traversal. 
-func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub route.Vertex, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, +func (c *ChannelGraph) ForEachNodeChannelTx(tx kvdb.RTx, + nodePub route.Vertex, cb func(kvdb.RTx, *models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { return nodeTraversal(tx, nodePub[:], c.db, cb) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 46bc0d3fda..2568512385 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -1055,7 +1055,7 @@ func TestGraphTraversal(t *testing.T) { // outgoing channels for a particular node. numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] - err = graph.ForEachNodeChannel(nil, firstNode.PubKeyBytes, + err = graph.ForEachNodeChannel(firstNode.PubKeyBytes, func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { @@ -2737,7 +2737,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Ensure that channel is reported with unknown policies. checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 - err := graph.ForEachNodeChannel(nil, node.PubKeyBytes, + err := graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { diff --git a/routing/router.go b/routing/router.go index c91e0b1879..33f5a7814a 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2937,7 +2937,7 @@ func (r *ChannelRouter) ForEachNode( func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { - return r.cfg.Graph.ForEachNodeChannel(nil, r.cfg.SelfNode, + return r.cfg.Graph.ForEachNodeChannel(r.cfg.SelfNode, func(tx kvdb.RTx, c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/rpcserver.go b/rpcserver.go index 2012e75d9f..011674e6c4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6361,7 +6361,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, channels []*lnrpc.ChannelEdge ) - err = graph.ForEachNodeChannel(nil, node.PubKeyBytes, + err = graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, edge *models.ChannelEdgeInfo, c1, c2 *models.ChannelEdgePolicy) error { @@ -7014,7 +7014,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, } var feeReports []*lnrpc.ChannelFeeReport - err = channelGraph.ForEachNodeChannel(nil, selfNode.PubKeyBytes, + err = channelGraph.ForEachNodeChannel(selfNode.PubKeyBytes, func(_ kvdb.RTx, chanInfo *models.ChannelEdgeInfo, edgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/server.go b/server.go index 555513406a..f32fb611e9 100644 --- a/server.go +++ b/server.go @@ -3119,7 +3119,7 @@ func (s *server) establishPersistentConnections() error { // TODO(roasbeef): instead iterate over link nodes and query graph for // each of the nodes. selfPub := s.identityECDH.PubKey().SerializeCompressed() - err = s.graphDB.ForEachNodeChannel(nil, sourceNode.PubKeyBytes, func( + err = s.graphDB.ForEachNodeChannel(sourceNode.PubKeyBytes, func( tx kvdb.RTx, chanInfo *models.ChannelEdgeInfo, policy, _ *models.ChannelEdgePolicy) error { From c1d7a9d2e75b9f0f30fade57ffb05660bb844bbe Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 20:05:20 -0400 Subject: [PATCH 10/20] multi: move ChannelGraphSource interface ... to the new `graph` package in preparation for the implementation of the interface being moved to this new package. 
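Moving the interface into the package that consumes it keeps the dependency
arrow pointing the right way, and the `var _ graph.ChannelGraphSource =
(*ChannelRouter)(nil)` line further down in this diff is the usual compile-time
guard that the implementation still satisfies the interface after the move. A
stripped-down sketch of that idiom, using toy graphSource and router types
rather than the real lnd ones:

    package main

    import "fmt"

    // graphSource is a toy stand-in for graph.ChannelGraphSource: the
    // consumer package declares only the behaviour it needs.
    type graphSource interface {
            CurrentBlockHeight() (uint32, error)
    }

    // router is a toy stand-in for routing.ChannelRouter.
    type router struct{ height uint32 }

    func (r *router) CurrentBlockHeight() (uint32, error) {
            return r.height, nil
    }

    // Compile-time check: if router ever stops satisfying graphSource
    // (say, a method is renamed during a refactor), the build fails on
    // this line instead of at some distant call site.
    var _ graphSource = (*router)(nil)

    func main() {
            var src graphSource = &router{height: 840_000}
            h, _ := src.CurrentBlockHeight()
            fmt.Println("height:", h)
    }

The blank-identifier assignment costs nothing at runtime; it exists purely so
that interface drift between the graph and routing packages is caught by the
compiler rather than by a failing integration test.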
--- discovery/gossiper.go | 3 +- discovery/gossiper_test.go | 3 +- graph/interfaces.go | 90 ++++++++++++++++++++++++++++++++++++++ routing/router.go | 79 +-------------------------------- 4 files changed, 96 insertions(+), 79 deletions(-) create mode 100644 graph/interfaces.go diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 8a1c30136f..6786b9db0a 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" @@ -169,7 +170,7 @@ type Config struct { // topology of lightning network. After incoming channel, node, channel // updates announcements are validated they are sent to the router in // order to be included in the LN graph. - Router routing.ChannelGraphSource + Router graph.ChannelGraphSource // ChanSeries is an interfaces that provides access to a time series // view of the current known channel graph. Each GossipSyncer enabled diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index a7b8505298..4971e980aa 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" @@ -108,7 +109,7 @@ func newMockRouter(height uint32) *mockGraphSource { } } -var _ routing.ChannelGraphSource = (*mockGraphSource)(nil) +var _ graph.ChannelGraphSource = (*mockGraphSource)(nil) func (r *mockGraphSource) AddNode(node *channeldb.LightningNode, _ ...batch.SchedulerOption) error { diff --git a/graph/interfaces.go b/graph/interfaces.go new file mode 100644 index 0000000000..b49a07af96 --- /dev/null +++ b/graph/interfaces.go @@ -0,0 +1,90 @@ +package graph + +import ( + "time" + + "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// ChannelGraphSource represents the source of information about the topology +// of the lightning network. It's responsible for the addition of nodes, edges, +// applying edge updates, and returning the current block height with which the +// topology is synchronized. +// +//nolint:interfacebloat +type ChannelGraphSource interface { + // AddNode is used to add information about a node to the router + // database. If the node with this pubkey is not present in an existing + // channel, it will be ignored. + AddNode(node *channeldb.LightningNode, + op ...batch.SchedulerOption) error + + // AddEdge is used to add edge/channel to the topology of the router, + // after all information about channel will be gathered this + // edge/channel might be used in construction of payment path. + AddEdge(edge *models.ChannelEdgeInfo, + op ...batch.SchedulerOption) error + + // AddProof updates the channel edge info with proof which is needed to + // properly announce the edge to the rest of the network. 
+ AddProof(chanID lnwire.ShortChannelID, + proof *models.ChannelAuthProof) error + + // UpdateEdge is used to update edge information, without this message + // edge considered as not fully constructed. + UpdateEdge(policy *models.ChannelEdgePolicy, + op ...batch.SchedulerOption) error + + // IsStaleNode returns true if the graph source has a node announcement + // for the target node with a more recent timestamp. This method will + // also return true if we don't have an active channel announcement for + // the target node. + IsStaleNode(node route.Vertex, timestamp time.Time) bool + + // IsPublicNode determines whether the given vertex is seen as a public + // node in the graph from the graph's source node's point of view. + IsPublicNode(node route.Vertex) (bool, error) + + // IsKnownEdge returns true if the graph source already knows of the + // passed channel ID either as a live or zombie edge. + IsKnownEdge(chanID lnwire.ShortChannelID) bool + + // IsStaleEdgePolicy returns true if the graph source has a channel + // edge for the passed channel ID (and flags) that have a more recent + // timestamp. + IsStaleEdgePolicy(chanID lnwire.ShortChannelID, timestamp time.Time, + flags lnwire.ChanUpdateChanFlags) bool + + // MarkEdgeLive clears an edge from our zombie index, deeming it as + // live. + MarkEdgeLive(chanID lnwire.ShortChannelID) error + + // ForAllOutgoingChannels is used to iterate over all channels + // emanating from the "source" node which is the center of the + // star-graph. + ForAllOutgoingChannels(cb func(tx kvdb.RTx, + c *models.ChannelEdgeInfo, + e *models.ChannelEdgePolicy) error) error + + // CurrentBlockHeight returns the block height from POV of the router + // subsystem. + CurrentBlockHeight() (uint32, error) + + // GetChannelByID return the channel by the channel id. + GetChannelByID(chanID lnwire.ShortChannelID) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // FetchLightningNode attempts to look up a target node by its identity + // public key. channeldb.ErrGraphNodeNotFound is returned if the node + // doesn't exist within the graph. + FetchLightningNode(route.Vertex) (*channeldb.LightningNode, error) + + // ForEachNode is used to iterate over every node in the known graph. + ForEachNode(func(node *channeldb.LightningNode) error) error +} diff --git a/routing/router.go b/routing/router.go index 33f5a7814a..f722c26645 100644 --- a/routing/router.go +++ b/routing/router.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -127,82 +128,6 @@ var ( ErrSkipTempErr = errors.New("cannot skip temp error for non-MPP") ) -// ChannelGraphSource represents the source of information about the topology -// of the lightning network. It's responsible for the addition of nodes, edges, -// applying edge updates, and returning the current block height with which the -// topology is synchronized. -type ChannelGraphSource interface { - // AddNode is used to add information about a node to the router - // database. If the node with this pubkey is not present in an existing - // channel, it will be ignored. 
- AddNode(node *channeldb.LightningNode, - op ...batch.SchedulerOption) error - - // AddEdge is used to add edge/channel to the topology of the router, - // after all information about channel will be gathered this - // edge/channel might be used in construction of payment path. - AddEdge(edge *models.ChannelEdgeInfo, - op ...batch.SchedulerOption) error - - // AddProof updates the channel edge info with proof which is needed to - // properly announce the edge to the rest of the network. - AddProof(chanID lnwire.ShortChannelID, - proof *models.ChannelAuthProof) error - - // UpdateEdge is used to update edge information, without this message - // edge considered as not fully constructed. - UpdateEdge(policy *models.ChannelEdgePolicy, - op ...batch.SchedulerOption) error - - // IsStaleNode returns true if the graph source has a node announcement - // for the target node with a more recent timestamp. This method will - // also return true if we don't have an active channel announcement for - // the target node. - IsStaleNode(node route.Vertex, timestamp time.Time) bool - - // IsPublicNode determines whether the given vertex is seen as a public - // node in the graph from the graph's source node's point of view. - IsPublicNode(node route.Vertex) (bool, error) - - // IsKnownEdge returns true if the graph source already knows of the - // passed channel ID either as a live or zombie edge. - IsKnownEdge(chanID lnwire.ShortChannelID) bool - - // IsStaleEdgePolicy returns true if the graph source has a channel - // edge for the passed channel ID (and flags) that have a more recent - // timestamp. - IsStaleEdgePolicy(chanID lnwire.ShortChannelID, timestamp time.Time, - flags lnwire.ChanUpdateChanFlags) bool - - // MarkEdgeLive clears an edge from our zombie index, deeming it as - // live. - MarkEdgeLive(chanID lnwire.ShortChannelID) error - - // ForAllOutgoingChannels is used to iterate over all channels - // emanating from the "source" node which is the center of the - // star-graph. - ForAllOutgoingChannels(cb func(tx kvdb.RTx, - c *models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy) error) error - - // CurrentBlockHeight returns the block height from POV of the router - // subsystem. - CurrentBlockHeight() (uint32, error) - - // GetChannelByID return the channel by the channel id. - GetChannelByID(chanID lnwire.ShortChannelID) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) - - // FetchLightningNode attempts to look up a target node by its identity - // public key. channeldb.ErrGraphNodeNotFound is returned if the node - // doesn't exist within the graph. - FetchLightningNode(route.Vertex) (*channeldb.LightningNode, error) - - // ForEachNode is used to iterate over every node in the known graph. - ForEachNode(func(node *channeldb.LightningNode) error) error -} - // PaymentAttemptDispatcher is used by the router to send payment attempts onto // the network, and receive their results. type PaymentAttemptDispatcher interface { @@ -499,7 +424,7 @@ type ChannelRouter struct { // A compile time check to ensure ChannelRouter implements the // ChannelGraphSource interface. -var _ ChannelGraphSource = (*ChannelRouter)(nil) +var _ graph.ChannelGraphSource = (*ChannelRouter)(nil) // New creates a new instance of the ChannelRouter with the specified // configuration parameters. 
As part of initialization, if the router detects From be84d6974ee7763af2582c45d97c96f76e1ea668 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 16 Jun 2024 19:00:45 -0400 Subject: [PATCH 11/20] channeldb: add a graph.DB interface ..which describes the database methods that are required for graph maintaining and building. --- graph/interfaces.go | 190 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/graph/interfaces.go b/graph/interfaces.go index b49a07af96..7ae79f9a9f 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -3,6 +3,8 @@ package graph import ( "time" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -88,3 +90,191 @@ type ChannelGraphSource interface { // ForEachNode is used to iterate over every node in the known graph. ForEachNode(func(node *channeldb.LightningNode) error) error } + +// DB is an interface describing a persisted Lightning Network graph. +// +//nolint:interfacebloat +type DB interface { + // PruneTip returns the block height and hash of the latest block that + // has been used to prune channels in the graph. Knowing the "prune tip" + // allows callers to tell if the graph is currently in sync with the + // current best known UTXO state. + PruneTip() (*chainhash.Hash, uint32, error) + + // PruneGraph prunes newly closed channels from the channel graph in + // response to a new block being solved on the network. Any transactions + // which spend the funding output of any known channels within the graph + // will be deleted. Additionally, the "prune tip", or the last block + // which has been used to prune the graph is stored so callers can + // ensure the graph is fully in sync with the current UTXO state. A + // slice of channels that have been closed by the target block are + // returned if the function succeeds without error. + PruneGraph(spentOutputs []*wire.OutPoint, blockHash *chainhash.Hash, + blockHeight uint32) ([]*models.ChannelEdgeInfo, error) + + // ChannelView returns the verifiable edge information for each active + // channel within the known channel graph. The set of UTXO's (along with + // their scripts) returned are the ones that need to be watched on + // chain to detect channel closes on the resident blockchain. + ChannelView() ([]channeldb.EdgePoint, error) + + // PruneGraphNodes is a garbage collection method which attempts to + // prune out any nodes from the channel graph that are currently + // unconnected. This ensure that we only maintain a graph of reachable + // nodes. In the event that a pruned node gains more channels, it will + // be re-added back to the graph. + PruneGraphNodes() error + + // SourceNode returns the source node of the graph. The source node is + // treated as the center node within a star-graph. This method may be + // used to kick off a path finding algorithm in order to explore the + // reachability of another node based off the source node. + SourceNode() (*channeldb.LightningNode, error) + + // DisabledChannelIDs returns the channel ids of disabled channels. + // A channel is disabled when two of the associated ChanelEdgePolicies + // have their disabled bit on. + DisabledChannelIDs() ([]uint64, error) + + // FetchChanInfos returns the set of channel edges that correspond to + // the passed channel ID's. 
If an edge is the query is unknown to the + // database, it will skipped and the result will contain only those + // edges that exist at the time of the query. This can be used to + // respond to peer queries that are seeking to fill in gaps in their + // view of the channel graph. + FetchChanInfos(chanIDs []uint64) ([]channeldb.ChannelEdge, error) + + // ChanUpdatesInHorizon returns all the known channel edges which have + // at least one edge that has an update timestamp within the specified + // horizon. + ChanUpdatesInHorizon(startTime, endTime time.Time) ( + []channeldb.ChannelEdge, error) + + // DeleteChannelEdges removes edges with the given channel IDs from the + // database and marks them as zombies. This ensures that we're unable to + // re-add it to our database once again. If an edge does not exist + // within the database, then ErrEdgeNotFound will be returned. If + // strictZombiePruning is true, then when we mark these edges as + // zombies, we'll set up the keys such that we require the node that + // failed to send the fresh update to be the one that resurrects the + // channel from its zombie state. The markZombie bool denotes whether + // to mark the channel as a zombie. + DeleteChannelEdges(strictZombiePruning, markZombie bool, + chanIDs ...uint64) error + + // DisconnectBlockAtHeight is used to indicate that the block specified + // by the passed height has been disconnected from the main chain. This + // will "rewind" the graph back to the height below, deleting channels + // that are no longer confirmed from the graph. The prune log will be + // set to the last prune height valid for the remaining chain. + // Channels that were removed from the graph resulting from the + // disconnected block are returned. + DisconnectBlockAtHeight(height uint32) ([]*models.ChannelEdgeInfo, + error) + + // HasChannelEdge returns true if the database knows of a channel edge + // with the passed channel ID, and false otherwise. If an edge with that + // ID is found within the graph, then two time stamps representing the + // last time the edge was updated for both directed edges are returned + // along with the boolean. If it is not found, then the zombie index is + // checked and its result is returned as the second boolean. + HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, bool, error) + + // FetchChannelEdgesByID attempts to lookup the two directed edges for + // the channel identified by the channel ID. If the channel can't be + // found, then ErrEdgeNotFound is returned. A struct which houses the + // general information for the channel itself is returned as well as + // two structs that contain the routing policies for the channel in + // either direction. + // + // ErrZombieEdge an be returned if the edge is currently marked as a + // zombie within the database. In this case, the ChannelEdgePolicy's + // will be nil, and the ChannelEdgeInfo will only include the public + // keys of each node. + FetchChannelEdgesByID(chanID uint64) (*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + + // AddLightningNode adds a vertex/node to the graph database. If the + // node is not in the database from before, this will add a new, + // unconnected one to the graph. If it is present from before, this will + // update that node's information. Note that this method is expected to + // only be called to update an already present node from a node + // announcement, or to insert a node found in a channel update. 
+ AddLightningNode(node *channeldb.LightningNode, + op ...batch.SchedulerOption) error + + // AddChannelEdge adds a new (undirected, blank) edge to the graph + // database. An undirected edge from the two target nodes are created. + // The information stored denotes the static attributes of the channel, + // such as the channelID, the keys involved in creation of the channel, + // and the set of features that the channel supports. The chanPoint and + // chanID are used to uniquely identify the edge globally within the + // database. + AddChannelEdge(edge *models.ChannelEdgeInfo, + op ...batch.SchedulerOption) error + + // MarkEdgeZombie attempts to mark a channel identified by its channel + // ID as a zombie. This method is used on an ad-hoc basis, when channels + // need to be marked as zombies outside the normal pruning cycle. + MarkEdgeZombie(chanID uint64, pubKey1, pubKey2 [33]byte) error + + // UpdateEdgePolicy updates the edge routing policy for a single + // directed edge within the database for the referenced channel. The + // `flags` attribute within the ChannelEdgePolicy determines which of + // the directed edges are being updated. If the flag is 1, then the + // first node's information is being updated, otherwise it's the second + // node's information. The node ordering is determined by the + // lexicographical ordering of the identity public keys of the nodes on + // either side of the channel. + UpdateEdgePolicy(edge *models.ChannelEdgePolicy, + op ...batch.SchedulerOption) error + + // HasLightningNode determines if the graph has a vertex identified by + // the target node identity public key. If the node exists in the + // database, a timestamp of when the data for the node was lasted + // updated is returned along with a true boolean. Otherwise, an empty + // time.Time is returned with a false boolean. + HasLightningNode(nodePub [33]byte) (time.Time, bool, error) + + // FetchLightningNode attempts to look up a target node by its identity + // public key. If the node isn't found in the database, then + // ErrGraphNodeNotFound is returned. + FetchLightningNode(nodePub route.Vertex) (*channeldb.LightningNode, + error) + + // ForEachNode iterates through all the stored vertices/nodes in the + // graph, executing the passed callback with each node encountered. If + // the callback returns an error, then the transaction is aborted and + // the iteration stops early. + ForEachNode(cb func(kvdb.RTx, *channeldb.LightningNode) error) error + + // ForEachNodeChannel iterates through all channels of the given node, + // executing the passed callback with an edge info structure and the + // policies of each end of the channel. The first edge policy is the + // outgoing edge *to* the connecting node, while the second is the + // incoming edge *from* the connecting node. If the callback returns an + // error, then the iteration is halted with the error propagated back up + // to the caller. + // + // Unknown policies are passed into the callback as nil values. + ForEachNodeChannel(nodePub route.Vertex, cb func(kvdb.RTx, + *models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error + + // UpdateChannelEdge retrieves and update edge of the graph database. + // Method only reserved for updating an edge info after its already been + // created. In order to maintain this constraints, we return an error in + // the scenario that an edge info hasn't yet been created yet, but + // someone attempts to update it. 
+ UpdateChannelEdge(edge *models.ChannelEdgeInfo) error + + // IsPublicNode is a helper method that determines whether the node with + // the given public key is seen as a public node in the graph from the + // graph's source node's point of view. + IsPublicNode(pubKey [33]byte) (bool, error) + + // MarkEdgeLive clears an edge from our zombie index, deeming it as + // live. + MarkEdgeLive(chanID uint64) error +} From 30e6671a130ea072b89e529f9592e693be4eddff Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 16 Jun 2024 19:02:55 -0400 Subject: [PATCH 12/20] routing: use new graph.DB interface in ChannelRouter --- routing/notifications.go | 3 ++- routing/router.go | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/routing/notifications.go b/routing/notifications.go index 3afbb15330..7263b9a47c 100644 --- a/routing/notifications.go +++ b/routing/notifications.go @@ -13,6 +13,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnwire" ) @@ -313,7 +314,7 @@ type ChannelEdgeUpdate struct { // constitutes. This function will also fetch any required auxiliary // information required to create the topology change update from the graph // database. -func addToTopologyChange(graph *channeldb.ChannelGraph, update *TopologyChange, +func addToTopologyChange(graph graph.DB, update *TopologyChange, msg interface{}) error { switch m := msg.(type) { diff --git a/routing/router.go b/routing/router.go index f722c26645..276744d1b3 100644 --- a/routing/router.go +++ b/routing/router.go @@ -253,8 +253,7 @@ type Config struct { // Graph is the channel graph that the ChannelRouter will use to gather // metrics from and also to carry out path finding queries. - // TODO(roasbeef): make into an interface - Graph *channeldb.ChannelGraph + Graph graph.DB // Chain is the router's source to the most up-to-date blockchain data. // All incoming advertised channels will be checked against the chain From 0b7364f54bd204fc9a007df90dce62a7b85176e8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 16 Jun 2024 19:15:05 -0400 Subject: [PATCH 13/20] graph+server: add template for new graph Builder sub-system This is preparation for an upcoming commit that will move over various responsibilities from the ChannelRouter to the graph Builder. So that that commit can be a pure code-move commit, the template for the new sub-system is added up front here. --- graph/builder.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ graph/log.go | 47 +++++++++++++++++++++++++++++++++++++++ log.go | 2 ++ server.go | 14 ++++++++++++ 4 files changed, 120 insertions(+) create mode 100644 graph/builder.go create mode 100644 graph/log.go diff --git a/graph/builder.go b/graph/builder.go new file mode 100644 index 0000000000..633e33bd18 --- /dev/null +++ b/graph/builder.go @@ -0,0 +1,57 @@ +package graph + +import ( + "sync" + "sync/atomic" +) + +// Config holds the configuration required by the Builder. +type Config struct{} + +// Builder builds and maintains a view of the Lightning Network graph. +type Builder struct { + started atomic.Bool + stopped atomic.Bool + + cfg *Config + + quit chan struct{} + wg sync.WaitGroup +} + +// NewBuilder constructs a new Builder. 
+func NewBuilder(cfg *Config) (*Builder, error) { + return &Builder{ + cfg: cfg, + quit: make(chan struct{}), + }, nil +} + +// Start launches all the goroutines the Builder requires to carry out its +// duties. If the builder has already been started, then this method is a noop. +func (b *Builder) Start() error { + if !b.started.CompareAndSwap(false, true) { + return nil + } + + log.Info("Builder starting") + + return nil +} + +// Stop signals to the Builder that it should halt all routines. This method +// will *block* until all goroutines have excited. If the builder has already +// stopped then this method will return immediately. +func (b *Builder) Stop() error { + if !b.stopped.CompareAndSwap(false, true) { + return nil + } + + log.Info("Builder shutting down...") + defer log.Debug("Builder shutdown complete") + + close(b.quit) + b.wg.Wait() + + return nil +} diff --git a/graph/log.go b/graph/log.go new file mode 100644 index 0000000000..2bd55297a0 --- /dev/null +++ b/graph/log.go @@ -0,0 +1,47 @@ +package graph + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This means the +// package will not perform any logging by default until the caller requests +// it. +var log btclog.Logger + +const Subsystem = "GRPH" + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled by +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. This +// should be used in preference to SetLogWriter if the caller is also using +// btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} + +// logClosure is used to provide a closure over expensive logging operations so +// don't have to be performed when the logging level doesn't warrant it. +type logClosure func() string + +// String invokes the underlying function and returns the result. +func (c logClosure) String() string { + return c() +} + +// newLogClosure returns a new closure over a function that returns a string +// which itself provides a Stringer interface so that it can be used with the +// logging system. 
+func newLogClosure(c func() string) logClosure { + return logClosure(c) +} diff --git a/log.go b/log.go index f6da0235a9..1b170f5ea8 100644 --- a/log.go +++ b/log.go @@ -18,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" @@ -179,6 +180,7 @@ func SetupLoggers(root *build.RotatingLogWriter, interceptor signal.Interceptor) AddSubLogger(root, btcwallet.Subsystem, interceptor, btcwallet.UseLogger) AddSubLogger(root, rpcwallet.Subsystem, interceptor, rpcwallet.UseLogger) AddSubLogger(root, peersrpc.Subsystem, interceptor, peersrpc.UseLogger) + AddSubLogger(root, graph.Subsystem, interceptor, graph.UseLogger) } // AddSubLogger is a helper method to conveniently create and register the diff --git a/server.go b/server.go index f32fb611e9..17108ee469 100644 --- a/server.go +++ b/server.go @@ -41,6 +41,7 @@ import ( "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -271,6 +272,8 @@ type server struct { missionControl *routing.MissionControl + graphBuilder *graph.Builder + chanRouter *routing.ChannelRouter controlTower routing.ControlTower @@ -973,6 +976,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, strictPruning := cfg.Bitcoin.Node == "neutrino" || cfg.Routing.StrictZombiePruning + s.graphBuilder, err = graph.NewBuilder(&graph.Config{}) + if err != nil { + return nil, fmt.Errorf("can't create graph builder: %w", err) + } + s.chanRouter, err = routing.New(routing.Config{ SelfNode: selfNode.PubKeyBytes, RoutingGraph: graphsession.NewRoutingGraph(chanGraph), @@ -2019,6 +2027,12 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.authGossiper.Stop) + if err := s.graphBuilder.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(s.graphBuilder.Stop) + if err := s.chanRouter.Start(); err != nil { startErr = err return From 7f1be39d45574962a6a7a4379eca4e318913d5fc Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 16 Jun 2024 19:30:01 -0400 Subject: [PATCH 14/20] refactor: move various duties from ChannelRouter to graph.Builder This commit is a large refactor that moves over various responsibilities from the ChannelRouter to the graph.Builder. These include all graph related tasks such as: - graph pruning - validation of new network updates & persisting new updates - notifying topology update clients of any changes. This is a large commit but: - many of the files are purely moved from `routing` to `graph` - the business logic put in the graph Builder is copied exactly as is from the ChannelRouter with one exception: - The ChannelRouter just needs to be able to call the Builder's `ApplyChannelUpdate` method. So this is now exported and provided to the ChannelRouter as a config option. - The trickiest part was just moving over the test code since quite a bit had to be duplicated. 
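As a sketch of the config-option wiring described above: the router ends up
depending on a plain function value rather than on the Builder type, so only
the server wiring needs to import both packages. The names below (update,
builder, routerCfg, channelRouter) and the exact signature are illustrative
assumptions rather than the real lnd declarations; the real method operates on
an *lnwire.ChannelUpdate:

    package main

    import "fmt"

    // update stands in for an lnwire.ChannelUpdate.
    type update struct{ chanID uint64 }

    // builder is a toy graph.Builder: it owns validation and
    // persistence of graph updates.
    type builder struct{}

    // ApplyChannelUpdate is exported so that other subsystems can hand
    // updates to the graph layer; it reports whether the update was
    // applied.
    func (b *builder) ApplyChannelUpdate(u *update) bool {
            fmt.Println("builder applied update for channel", u.chanID)
            return true
    }

    // routerCfg shows the injection point: the router holds a function
    // value, not a reference to the builder type itself.
    type routerCfg struct {
            ApplyChannelUpdate func(*update) bool
    }

    type channelRouter struct{ cfg routerCfg }

    func (r *channelRouter) handleFailureUpdate(u *update) {
            // The router can refresh its view of an edge without
            // knowing anything about the builder behind the function.
            if !r.cfg.ApplyChannelUpdate(u) {
                    fmt.Println("update rejected; edge left unchanged")
            }
    }

    func main() {
            b := &builder{}
            r := &channelRouter{cfg: routerCfg{
                    ApplyChannelUpdate: b.ApplyChannelUpdate,
            }}
            r.handleFailureUpdate(&update{chanID: 123})
    }

Injecting the method value at construction time also keeps the two subsystems
testable in isolation: router tests can supply a stub function without standing
up a Builder at all.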
--- autopilot/manager.go | 4 +- discovery/chan_series.go | 6 +- discovery/gossiper.go | 43 +- discovery/gossiper_test.go | 5 +- funding/manager.go | 26 +- {routing => graph}/ann_validation.go | 2 +- graph/builder.go | 1757 ++++++++- graph/builder_test.go | 2051 ++++++++++ {routing => graph}/errors.go | 20 +- {routing => graph}/notifications.go | 25 +- {routing => graph}/notifications_test.go | 303 +- graph/setup_test.go | 11 + {routing => graph}/stats.go | 2 +- graph/testdata/basic_graph.json | 298 ++ graph/testdata/spec_example.json | 147 + {routing => graph}/validation_barrier.go | 2 +- {routing => graph}/validation_barrier_test.go | 18 +- netann/channel_update_test.go | 4 +- pilot.go | 2 +- routing/pathfind_test.go | 125 +- routing/payment_lifecycle.go | 2 +- routing/payment_session.go | 3 +- routing/router.go | 2004 +--------- routing/router_test.go | 3304 +++++------------ rpcserver.go | 11 +- server.go | 48 +- 26 files changed, 5741 insertions(+), 4482 deletions(-) rename {routing => graph}/ann_validation.go (99%) create mode 100644 graph/builder_test.go rename {routing => graph}/errors.go (80%) rename {routing => graph}/notifications.go (95%) rename {routing => graph}/notifications_test.go (78%) create mode 100644 graph/setup_test.go rename {routing => graph}/stats.go (98%) create mode 100644 graph/testdata/basic_graph.json create mode 100644 graph/testdata/spec_example.json rename {routing => graph}/validation_barrier.go (99%) rename {routing => graph}/validation_barrier_test.go (91%) diff --git a/autopilot/manager.go b/autopilot/manager.go index e5999f5182..dba4cc6cc5 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -6,9 +6,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing" ) // ManagerCfg houses a set of values and methods that is passed to the Manager @@ -36,7 +36,7 @@ type ManagerCfg struct { // SubscribeTopology is used to get a subscription for topology changes // on the network. - SubscribeTopology func() (*routing.TopologyClient, error) + SubscribeTopology func() (*graph.TopologyClient, error) } // Manager is struct that manages an autopilot agent, making it possible to diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 34e6d4a9db..8cbca1277d 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -5,9 +5,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" ) @@ -136,7 +136,7 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, if edge1 != nil { // We don't want to send channel updates that don't // conform to the spec (anymore). 
- err := routing.ValidateChannelUpdateFields(0, edge1) + err := graph.ValidateChannelUpdateFields(0, edge1) if err != nil { log.Errorf("not sending invalid channel "+ "update %v: %v", edge1, err) @@ -145,7 +145,7 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, } } if edge2 != nil { - err := routing.ValidateChannelUpdateFields(0, edge2) + err := graph.ValidateChannelUpdateFields(0, edge2) if err != nil { log.Errorf("not sending invalid channel "+ "update %v: %v", edge2, err) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 6786b9db0a..752ee7446c 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -29,7 +29,6 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/multimutex" "github.com/lightningnetwork/lnd/netann" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/ticker" "golang.org/x/time/rate" @@ -1361,7 +1360,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // We'll use this validation to ensure that we process jobs in their // dependency order during parallel validation. - validationBarrier := routing.NewValidationBarrier(1000, d.quit) + validationBarrier := graph.NewValidationBarrier(1000, d.quit) for { select { @@ -1486,7 +1485,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // // NOTE: must be run as a goroutine. func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, - deDuped *deDupedAnnouncements, vb *routing.ValidationBarrier) { + deDuped *deDupedAnnouncements, vb *graph.ValidationBarrier) { defer d.wg.Done() defer vb.CompleteJob() @@ -1502,10 +1501,10 @@ func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, log.Debugf("Validating network message %s got err: %v", nMsg.msg.MsgType(), err) - if !routing.IsError( + if !graph.IsError( err, - routing.ErrVBarrierShuttingDown, - routing.ErrParentValidationFailed, + graph.ErrVBarrierShuttingDown, + graph.ErrParentValidationFailed, ) { log.Warnf("unexpected error during validation "+ @@ -1861,7 +1860,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( if err != nil { return nil, err } - err = routing.ValidateChannelAnn(chanAnn) + err = graph.ValidateChannelAnn(chanAnn) if err != nil { err := fmt.Errorf("assembled channel announcement proof "+ "for shortChanID=%v isn't valid: %v", @@ -1910,7 +1909,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement, op ...batch.SchedulerOption) error { - if err := routing.ValidateNodeAnn(msg); err != nil { + if err := graph.ValidateNodeAnn(msg); err != nil { return fmt.Errorf("unable to validate node announcement: %w", err) } @@ -2064,7 +2063,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( "with chan_id=%v", msg.ShortChannelID) } - err := routing.VerifyChannelUpdateSignature(msg, pubKey) + err := graph.VerifyChannelUpdateSignature(msg, pubKey) if err != nil { return fmt.Errorf("unable to verify channel "+ "update signature: %v", err) @@ -2201,7 +2200,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. 
- err = routing.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate) + err = graph.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate) if err != nil { return nil, nil, fmt.Errorf("generated invalid channel "+ "update sig: %v", err) @@ -2338,11 +2337,11 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, log.Debugf("Adding node: %x got error: %v", nodeAnn.NodeID, err) - if !routing.IsError( + if !graph.IsError( err, - routing.ErrOutdated, - routing.ErrIgnored, - routing.ErrVBarrierShuttingDown, + graph.ErrOutdated, + graph.ErrIgnored, + graph.ErrVBarrierShuttingDown, ) { log.Error(err) @@ -2457,7 +2456,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // the signatures within the proof as it should be well formed. var proof *models.ChannelAuthProof if nMsg.isRemote { - if err := routing.ValidateChannelAnn(ann); err != nil { + if err := graph.ValidateChannelAnn(ann); err != nil { err := fmt.Errorf("unable to validate announcement: "+ "%v", err) @@ -2538,7 +2537,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the edge was rejected due to already being known, then it // may be the case that this new message has a fresh channel // proof, so we'll check. - if routing.IsError(err, routing.ErrIgnored) { + if graph.IsError(err, graph.ErrIgnored) { // Attempt to process the rejected message to see if we // get any new announcements. anns, rErr := d.processRejectedEdge(ann, proof) @@ -2862,7 +2861,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // Validate the channel announcement with the expected public key and // channel capacity. In the case of an invalid channel update, we'll // return an error to the caller and exit early. - err = routing.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, upd) + err = graph.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, upd) if err != nil { rErr := fmt.Errorf("unable to validate channel update "+ "announcement for short_chan_id=%v: %v", @@ -2947,10 +2946,10 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } if err := d.cfg.Router.UpdateEdge(update, ops...); err != nil { - if routing.IsError( - err, routing.ErrOutdated, - routing.ErrIgnored, - routing.ErrVBarrierShuttingDown, + if graph.IsError( + err, graph.ErrOutdated, + graph.ErrIgnored, + graph.ErrVBarrierShuttingDown, ) { log.Debugf("Update edge for short_chan_id(%v) got: %v", @@ -3268,7 +3267,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // With all the necessary components assembled validate the full // channel announcement proof. 
- if err := routing.ValidateChannelAnn(chanAnn); err != nil { + if err := graph.ValidateChannelAnn(chanAnn); err != nil { err := fmt.Errorf("channel announcement proof for "+ "short_chan_id=%v isn't valid: %v", shortChanID, err) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 4971e980aa..33d87416ac 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -33,7 +33,6 @@ import ( "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/ticker" "github.com/stretchr/testify/require" @@ -351,7 +350,7 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, // Since it exists within our zombie index, we'll check that it // respects the router's live edge horizon to determine whether // it is stale or not. - return time.Since(timestamp) > routing.DefaultChannelPruneExpiry + return time.Since(timestamp) > graph.DefaultChannelPruneExpiry } switch { @@ -2258,7 +2257,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // We'll generate a channel update with a timestamp far enough in the // past to consider it a zombie. - zombieTimestamp := time.Now().Add(-routing.DefaultChannelPruneExpiry) + zombieTimestamp := time.Now().Add(-graph.DefaultChannelPruneExpiry) batch.chanUpdAnn2.Timestamp = uint32(zombieTimestamp.Unix()) if err := signUpdate(remoteKeyPriv2, batch.chanUpdAnn2); err != nil { t.Fatalf("unable to sign update with new timestamp: %v", err) diff --git a/funding/manager.go b/funding/manager.go index 67e4f33c5c..70d5cd9c43 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -23,6 +23,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/labels" @@ -33,7 +34,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing" "golang.org/x/crypto/salsa20" ) @@ -3415,10 +3415,10 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, select { case err := <-errChan: if err != nil { - if routing.IsError(err, routing.ErrOutdated, - routing.ErrIgnored) { + if graph.IsError(err, graph.ErrOutdated, + graph.ErrIgnored) { - log.Debugf("Router rejected "+ + log.Debugf("Graph rejected "+ "ChannelAnnouncement: %v", err) } else { return fmt.Errorf("error sending channel "+ @@ -3435,10 +3435,10 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, select { case err := <-errChan: if err != nil { - if routing.IsError(err, routing.ErrOutdated, - routing.ErrIgnored) { + if graph.IsError(err, graph.ErrOutdated, + graph.ErrIgnored) { - log.Debugf("Router rejected "+ + log.Debugf("Graph rejected "+ "ChannelUpdate: %v", err) } else { return fmt.Errorf("error sending channel "+ @@ -4354,10 +4354,10 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, select { case err := <-errChan: if err != nil { - if routing.IsError(err, routing.ErrOutdated, - routing.ErrIgnored) { + if graph.IsError(err, graph.ErrOutdated, + graph.ErrIgnored) { - log.Debugf("Router rejected "+ + log.Debugf("Graph rejected "+ 
"AnnounceSignatures: %v", err) } else { log.Errorf("Unable to send channel "+ @@ -4384,10 +4384,10 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, select { case err := <-errChan: if err != nil { - if routing.IsError(err, routing.ErrOutdated, - routing.ErrIgnored) { + if graph.IsError(err, graph.ErrOutdated, + graph.ErrIgnored) { - log.Debugf("Router rejected "+ + log.Debugf("Graph rejected "+ "NodeAnnouncement: %v", err) } else { log.Errorf("Unable to send node "+ diff --git a/routing/ann_validation.go b/graph/ann_validation.go similarity index 99% rename from routing/ann_validation.go rename to graph/ann_validation.go index aca071e17e..3936b4652f 100644 --- a/routing/ann_validation.go +++ b/graph/ann_validation.go @@ -1,4 +1,4 @@ -package routing +package graph import ( "bytes" diff --git a/graph/builder.go b/graph/builder.go index 633e33bd18..4a3445cfc4 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1,29 +1,183 @@ package graph import ( + "bytes" + "fmt" + "runtime" + "strings" "sync" "sync/atomic" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnutils" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwallet/btcwallet" + "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/multimutex" + "github.com/lightningnetwork/lnd/routing/chainview" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/ticker" +) + +const ( + // DefaultChannelPruneExpiry is the default duration used to determine + // if a channel should be pruned or not. + DefaultChannelPruneExpiry = time.Hour * 24 * 14 + + // DefaultFirstTimePruneDelay is the time we'll wait after startup + // before attempting to prune the graph for zombie channels. We don't + // do it immediately after startup to allow lnd to start up without + // getting blocked by this job. + DefaultFirstTimePruneDelay = 30 * time.Second + + // defaultStatInterval governs how often the router will log non-empty + // stats related to processing new channels, updates, or node + // announcements. + defaultStatInterval = time.Minute +) + +var ( + // ErrGraphBuilderShuttingDown is returned if the graph builder is in + // the process of shutting down. + ErrGraphBuilderShuttingDown = fmt.Errorf("graph builder shutting down") ) // Config holds the configuration required by the Builder. -type Config struct{} +type Config struct { + // SelfNode is the public key of the node that this channel router + // belongs to. + SelfNode route.Vertex + + // Graph is the channel graph that the ChannelRouter will use to gather + // metrics from and also to carry out path finding queries. + Graph DB + + // Chain is the router's source to the most up-to-date blockchain data. + // All incoming advertised channels will be checked against the chain + // to ensure that the channels advertised are still open. 
+ Chain lnwallet.BlockChainIO + + // ChainView is an instance of a FilteredChainView which is used to + // watch the sub-set of the UTXO set (the set of active channels) that + // we need in order to properly maintain the channel graph. + ChainView chainview.FilteredChainView + + // Notifier is a reference to the ChainNotifier, used to grab + // the latest blocks if the router is missing any. + Notifier chainntnfs.ChainNotifier + + // ChannelPruneExpiry is the duration used to determine if a channel + // should be pruned or not. If the delta between now and when the + // channel was last updated is greater than ChannelPruneExpiry, then + // the channel is marked as a zombie channel eligible for pruning. + ChannelPruneExpiry time.Duration + + // GraphPruneInterval is used as an interval to determine how often we + // should examine the channel graph to garbage collect zombie channels. + GraphPruneInterval time.Duration + + // FirstTimePruneDelay is the time we'll wait after startup before + // attempting to prune the graph for zombie channels. We don't do it + // immediately after startup to allow lnd to start up without getting + // blocked by this job. + FirstTimePruneDelay time.Duration + + // AssumeChannelValid toggles whether the router will check for + // spentness of channel outpoints. For neutrino, this saves long rescans + // from blocking initial usage of the daemon. + AssumeChannelValid bool + + // StrictZombiePruning determines if we attempt to prune zombie + // channels according to a stricter criteria. If true, then we'll prune + // a channel if only *one* of the edges is considered a zombie. + // Otherwise, we'll only prune the channel when both edges have a very + // dated last update. + StrictZombiePruning bool + + // IsAlias returns whether a passed ShortChannelID is an alias. This is + // only used for our local channels. + IsAlias func(scid lnwire.ShortChannelID) bool +} // Builder builds and maintains a view of the Lightning Network graph. type Builder struct { started atomic.Bool stopped atomic.Bool + ntfnClientCounter uint64 // To be used atomically. + bestHeight uint32 // To be used atomically. + cfg *Config + // newBlocks is a channel in which new blocks connected to the end of + // the main chain are sent over, and blocks updated after a call to + // UpdateFilter. + newBlocks <-chan *chainview.FilteredBlock + + // staleBlocks is a channel in which blocks disconnected from the end + // of our currently known best chain are sent over. + staleBlocks <-chan *chainview.FilteredBlock + + // networkUpdates is a channel that carries new topology updates + // messages from outside the Builder to be processed by the + // networkHandler. + networkUpdates chan *routingMsg + + // topologyClients maps a client's unique notification ID to a + // topologyClient client that contains its notification dispatch + // channel. + topologyClients *lnutils.SyncMap[uint64, *topologyClient] + + // ntfnClientUpdates is a channel that's used to send new updates to + // topology notification clients to the Builder. Updates either + // add a new notification client, or cancel notifications for an + // existing client. + ntfnClientUpdates chan *topologyClientUpdate + + // channelEdgeMtx is a mutex we use to make sure we process only one + // ChannelEdgePolicy at a time for a given channelID, to ensure + // consistency between the various database accesses. 
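The channelEdgeMtx described just above is a keyed mutex: one logical lock per channel ID, so that policy writes for different channels never contend with each other. A minimal self-contained sketch of the pattern, assuming only the Lock(id)/Unlock(id) shape that processUpdate uses further down; a real implementation such as multimutex also reclaims unused entries, which this toy omits:

    package main

    import "sync"

    // keyedMutex is a toy stand-in for multimutex.Mutex[uint64]: one lock
    // per key instead of a single global mutex.
    type keyedMutex struct {
        mu    sync.Mutex
        locks map[uint64]*sync.Mutex
    }

    func (k *keyedMutex) Lock(id uint64) {
        k.mu.Lock()
        if k.locks == nil {
            k.locks = make(map[uint64]*sync.Mutex)
        }
        l, ok := k.locks[id]
        if !ok {
            l = new(sync.Mutex)
            k.locks[id] = l
        }
        k.mu.Unlock()

        l.Lock()
    }

    // Unlock assumes Lock(id) was called first.
    func (k *keyedMutex) Unlock(id uint64) {
        k.mu.Lock()
        l := k.locks[id]
        k.mu.Unlock()

        l.Unlock()
    }

    func main() {
        var m keyedMutex
        m.Lock(42)
        // ... serialized database accesses for channel 42 ...
        m.Unlock(42)
    }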
+ channelEdgeMtx *multimutex.Mutex[uint64] + + // statTicker is a resumable ticker that logs the router's progress as + // it discovers channels or receives updates. + statTicker ticker.Ticker + + // stats tracks newly processed channels, updates, and node + // announcements over a window of defaultStatInterval. + stats *routerStats + quit chan struct{} wg sync.WaitGroup } +// A compile time check to ensure Builder implements the +// ChannelGraphSource interface. +var _ ChannelGraphSource = (*Builder)(nil) + // NewBuilder constructs a new Builder. func NewBuilder(cfg *Config) (*Builder, error) { return &Builder{ - cfg: cfg, - quit: make(chan struct{}), + cfg: cfg, + networkUpdates: make(chan *routingMsg), + topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, + ntfnClientUpdates: make(chan *topologyClientUpdate), + channelEdgeMtx: multimutex.NewMutex[uint64](), + statTicker: ticker.New(defaultStatInterval), + stats: new(routerStats), + quit: make(chan struct{}), }, nil } @@ -36,6 +190,114 @@ func (b *Builder) Start() error { log.Info("Builder starting") + bestHash, bestHeight, err := b.cfg.Chain.GetBestBlock() + if err != nil { + return err + } + + // If the graph has never been pruned, or hasn't fully been created yet, + // then we don't treat this as an explicit error. + if _, _, err := b.cfg.Graph.PruneTip(); err != nil { + switch { + case errors.Is(err, channeldb.ErrGraphNeverPruned): + fallthrough + + case errors.Is(err, channeldb.ErrGraphNotFound): + // If the graph has never been pruned, then we'll set + // the prune height to the current best height of the + // chain backend. + _, err = b.cfg.Graph.PruneGraph( + nil, bestHash, uint32(bestHeight), + ) + if err != nil { + return err + } + + default: + return err + } + } + + // If AssumeChannelValid is present, then we won't rely on pruning + // channels from the graph based on their spentness, but whether they + // are considered zombies or not. We will start zombie pruning after a + // small delay, to avoid slowing down startup of lnd. + if b.cfg.AssumeChannelValid { + time.AfterFunc(b.cfg.FirstTimePruneDelay, func() { + select { + case <-b.quit: + return + default: + } + + log.Info("Initial zombie prune starting") + if err := b.pruneZombieChans(); err != nil { + log.Errorf("Unable to prune zombies: %v", err) + } + }) + } else { + // Otherwise, we'll use our filtered chain view to prune + // channels as soon as they are detected as spent on-chain. + if err := b.cfg.ChainView.Start(); err != nil { + return err + } + + // Once the instance is active, we'll fetch the channel we'll + // receive notifications over. + b.newBlocks = b.cfg.ChainView.FilteredBlocks() + b.staleBlocks = b.cfg.ChainView.DisconnectedBlocks() + + // Before we perform our manual block pruning, we'll construct + // and apply a fresh chain filter to the active + // FilteredChainView instance. We do this before, as otherwise + // we may miss on-chain events as the filter hasn't properly + // been applied. + channelView, err := b.cfg.Graph.ChannelView() + if err != nil && !errors.Is( + err, channeldb.ErrGraphNoEdgesFound, + ) { + return err + } + + log.Infof("Filtering chain using %v channels active", + len(channelView)) + + if len(channelView) != 0 { + err = b.cfg.ChainView.UpdateFilter( + channelView, uint32(bestHeight), + ) + if err != nil { + return err + } + } + + // The graph pruning might have taken a while and there could be + // new blocks available. 
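The var _ ChannelGraphSource = (*Builder)(nil) line above is Go's standard compile-time interface assertion: it allocates nothing at runtime and breaks the build the moment Builder stops satisfying the interface. A standalone illustration of the idiom:

    package main

    type greeter interface {
        greet() string
    }

    type englishGreeter struct{}

    func (englishGreeter) greet() string { return "hello" }

    // Compile-time check: fails to build if englishGreeter ever stops
    // implementing greeter, with zero runtime cost.
    var _ greeter = (*englishGreeter)(nil)

    func main() {}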
+ _, bestHeight, err = b.cfg.Chain.GetBestBlock() + if err != nil { + return err + } + b.bestHeight = uint32(bestHeight) + + // Before we begin normal operation of the router, we first need + // to synchronize the channel graph to the latest state of the + // UTXO set. + if err := b.syncGraphWithChain(); err != nil { + return err + } + + // Finally, before we proceed, we'll prune any unconnected nodes + // from the graph in order to ensure we maintain a tight graph + // of "useful" nodes. + err = b.cfg.Graph.PruneGraphNodes() + if err != nil && err != channeldb.ErrGraphNodesNotFound { + return err + } + } + + b.wg.Add(1) + go b.networkHandler() + return nil } @@ -50,8 +312,1497 @@ func (b *Builder) Stop() error { log.Info("Builder shutting down...") defer log.Debug("Builder shutdown complete") + // Our filtered chain view could've only been started if + // AssumeChannelValid isn't present. + if !b.cfg.AssumeChannelValid { + if err := b.cfg.ChainView.Stop(); err != nil { + return err + } + } + close(b.quit) b.wg.Wait() return nil } + +// syncGraphWithChain attempts to synchronize the current channel graph with +// the latest UTXO set state. This process involves pruning from the channel +// graph any channels which have been closed by spending their funding output +// since we've been down. +func (b *Builder) syncGraphWithChain() error { + // First, we'll need to check to see if we're already in sync with the + // latest state of the UTXO set. + bestHash, bestHeight, err := b.cfg.Chain.GetBestBlock() + if err != nil { + return err + } + b.bestHeight = uint32(bestHeight) + + pruneHash, pruneHeight, err := b.cfg.Graph.PruneTip() + if err != nil { + switch { + // If the graph has never been pruned, or hasn't fully been + // created yet, then we don't treat this as an explicit error. + case err == channeldb.ErrGraphNeverPruned: + case err == channeldb.ErrGraphNotFound: + default: + return err + } + } + + log.Infof("Prune tip for Channel Graph: height=%v, hash=%v", + pruneHeight, pruneHash) + + switch { + + // If the graph has never been pruned, then we can exit early as this + // entails it's being created for the first time and hasn't seen any + // block or created channels. + case pruneHeight == 0 || pruneHash == nil: + return nil + + // If the block hashes and heights match exactly, then we don't need to + // prune the channel graph as we're already fully in sync. + case bestHash.IsEqual(pruneHash) && uint32(bestHeight) == pruneHeight: + return nil + } + + // If the main chain blockhash at prune height is different from the + // prune hash, this might indicate the database is on a stale branch. + mainBlockHash, err := b.cfg.Chain.GetBlockHash(int64(pruneHeight)) + if err != nil { + return err + } + + // While we are on a stale branch of the chain, walk backwards to find + // first common block. + for !pruneHash.IsEqual(mainBlockHash) { + log.Infof("channel graph is stale. Disconnecting block %v "+ + "(hash=%v)", pruneHeight, pruneHash) + // Prune the graph for every channel that was opened at height + // >= pruneHeight. + _, err := b.cfg.Graph.DisconnectBlockAtHeight(pruneHeight) + if err != nil { + return err + } + + pruneHash, pruneHeight, err = b.cfg.Graph.PruneTip() + if err != nil { + switch { + // If at this point the graph has never been pruned, we + // can exit as this entails we are back to the point + // where it hasn't seen any block or created channels, + // alas there's nothing left to prune. 
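The stale-branch recovery in syncGraphWithChain, which continues below, is easiest to follow on a toy chain: walk the prune tip backwards until it rejoins the main chain, then replay forward and prune closed channels block by block. A self-contained sketch with made-up block hashes:

    // Toy model of the walk-back/walk-forward reconciliation in
    // syncGraphWithChain. All heights and hashes are illustrative.
    package main

    import "fmt"

    func main() {
        // Height -> block hash on the main chain after a reorg.
        mainChain := map[int]string{1: "a", 2: "b", 3: "c2", 4: "d2"}

        // The graph was last pruned on a branch that is now stale.
        pruneHeight, pruneHash := 4, "d1"
        graphChain := map[int]string{1: "a", 2: "b", 3: "c1", 4: "d1"}

        // Walk backwards until the prune tip lies on the main chain.
        for pruneHash != mainChain[pruneHeight] {
            fmt.Printf("disconnecting stale block %d (%s)\n",
                pruneHeight, pruneHash)
            pruneHeight--
            pruneHash = graphChain[pruneHeight]
        }

        // Replay forward, filtering each block for closed channels.
        for h := pruneHeight + 1; h <= len(mainChain); h++ {
            fmt.Printf("filtering block %d (%s) for closed channels\n",
                h, mainChain[h])
        }
    }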
+ case err == channeldb.ErrGraphNeverPruned: + return nil + case err == channeldb.ErrGraphNotFound: + return nil + default: + return err + } + } + mainBlockHash, err = b.cfg.Chain.GetBlockHash(int64(pruneHeight)) + if err != nil { + return err + } + } + + log.Infof("Syncing channel graph from height=%v (hash=%v) to height=%v "+ + "(hash=%v)", pruneHeight, pruneHash, bestHeight, bestHash) + + // If we're not yet caught up, then we'll walk forward in the chain + // pruning the channel graph with each new block that hasn't yet been + // consumed by the channel graph. + var spentOutputs []*wire.OutPoint + for nextHeight := pruneHeight + 1; nextHeight <= uint32(bestHeight); nextHeight++ { + // Break out of the rescan early if a shutdown has been + // requested, otherwise long rescans will block the daemon from + // shutting down promptly. + select { + case <-b.quit: + return ErrGraphBuilderShuttingDown + default: + } + + // Using the next height, request a manual block pruning from + // the chainview for the particular block hash. + log.Infof("Filtering block for closed channels, at height: %v", + int64(nextHeight)) + nextHash, err := b.cfg.Chain.GetBlockHash(int64(nextHeight)) + if err != nil { + return err + } + log.Tracef("Running block filter on block with hash: %v", + nextHash) + filterBlock, err := b.cfg.ChainView.FilterBlock(nextHash) + if err != nil { + return err + } + + // We're only interested in all prior outputs that have been + // spent in the block, so collate all the referenced previous + // outpoints within each tx and input. + for _, tx := range filterBlock.Transactions { + for _, txIn := range tx.TxIn { + spentOutputs = append(spentOutputs, + &txIn.PreviousOutPoint) + } + } + } + + // With the spent outputs gathered, attempt to prune the channel graph, + // also passing in the best hash+height so the prune tip can be updated. + closedChans, err := b.cfg.Graph.PruneGraph( + spentOutputs, bestHash, uint32(bestHeight), + ) + if err != nil { + return err + } + + log.Infof("Graph pruning complete: %v channels were closed since "+ + "height %v", len(closedChans), pruneHeight) + return nil +} + +// isZombieChannel takes two edge policy updates and determines if the +// corresponding channel should be considered a zombie. The first boolean is +// true if the policy update from node 1 is considered a zombie, the second +// boolean is that of node 2, and the final boolean is true if the channel +// is considered a zombie. +func (b *Builder) isZombieChannel(e1, + e2 *models.ChannelEdgePolicy) (bool, bool, bool) { + + chanExpiry := b.cfg.ChannelPruneExpiry + + e1Zombie := e1 == nil || time.Since(e1.LastUpdate) >= chanExpiry + e2Zombie := e2 == nil || time.Since(e2.LastUpdate) >= chanExpiry + + var e1Time, e2Time time.Time + if e1 != nil { + e1Time = e1.LastUpdate + } + if e2 != nil { + e2Time = e2.LastUpdate + } + + return e1Zombie, e2Zombie, b.IsZombieChannel(e1Time, e2Time) +} + +// IsZombieChannel takes the timestamps of the latest channel updates for a +// channel and returns true if the channel should be considered a zombie based +// on these timestamps. +func (b *Builder) IsZombieChannel(updateTime1, + updateTime2 time.Time) bool { + + chanExpiry := b.cfg.ChannelPruneExpiry + + e1Zombie := updateTime1.IsZero() || + time.Since(updateTime1) >= chanExpiry + + e2Zombie := updateTime2.IsZero() || + time.Since(updateTime2) >= chanExpiry + + // If we're using strict zombie pruning, then a channel is only + // considered live if both edges have a recent update we know of. 
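The strict-versus-lenient decision that follows is worth a worked example: with one fresh and one stale edge, the channel survives lenient pruning but is a zombie under strict pruning. A self-contained sketch, using a 14-day expiry to mirror DefaultChannelPruneExpiry:

    // Worked example of the zombie decision: an edge is a zombie if its
    // last update is missing or older than the expiry. Strict pruning ORs
    // the two edges, lenient pruning ANDs them.
    package main

    import (
        "fmt"
        "time"
    )

    func isZombie(t1, t2 time.Time, expiry time.Duration, strict bool) bool {
        e1 := t1.IsZero() || time.Since(t1) >= expiry
        e2 := t2.IsZero() || time.Since(t2) >= expiry
        if strict {
            return e1 || e2
        }
        return e1 && e2
    }

    func main() {
        expiry := 14 * 24 * time.Hour
        fresh := time.Now().Add(-time.Hour)
        stale := time.Now().Add(-30 * 24 * time.Hour)

        // One fresh edge keeps the channel alive only in lenient mode.
        fmt.Println(isZombie(fresh, stale, expiry, false)) // false
        fmt.Println(isZombie(fresh, stale, expiry, true))  // true
    }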
+ if b.cfg.StrictZombiePruning { + return e1Zombie || e2Zombie + } + + // Otherwise, if we're using the less strict variant, then a channel is + // considered live if either of the edges have a recent update. + return e1Zombie && e2Zombie +} + +// pruneZombieChans is a method that will be called periodically to prune out +// any "zombie" channels. We consider channels zombies if *both* edges haven't +// been updated since our zombie horizon. If AssumeChannelValid is present, +// we'll also consider channels zombies if *both* edges are disabled. This +// usually signals that a channel has been closed on-chain. We do this +// periodically to keep a healthy, lively routing table. +func (b *Builder) pruneZombieChans() error { + chansToPrune := make(map[uint64]struct{}) + chanExpiry := b.cfg.ChannelPruneExpiry + + log.Infof("Examining channel graph for zombie channels") + + // A helper method to detect if the channel belongs to this node + isSelfChannelEdge := func(info *models.ChannelEdgeInfo) bool { + return info.NodeKey1Bytes == b.cfg.SelfNode || + info.NodeKey2Bytes == b.cfg.SelfNode + } + + // First, we'll collect all the channels which are eligible for garbage + // collection due to being zombies. + filterPruneChans := func(info *models.ChannelEdgeInfo, + e1, e2 *models.ChannelEdgePolicy) error { + + // Exit early in case this channel is already marked to be + // pruned + _, markedToPrune := chansToPrune[info.ChannelID] + if markedToPrune { + return nil + } + + // We'll ensure that we don't attempt to prune our *own* + // channels from the graph, as in any case this should be + // re-advertised by the sub-system above us. + if isSelfChannelEdge(info) { + return nil + } + + e1Zombie, e2Zombie, isZombieChan := b.isZombieChannel(e1, e2) + + if e1Zombie { + log.Tracef("Node1 pubkey=%x of chan_id=%v is zombie", + info.NodeKey1Bytes, info.ChannelID) + } + + if e2Zombie { + log.Tracef("Node2 pubkey=%x of chan_id=%v is zombie", + info.NodeKey2Bytes, info.ChannelID) + } + + // If either edge hasn't been updated for a period of + // chanExpiry, then we'll mark the channel itself as eligible + // for graph pruning. + if !isZombieChan { + return nil + } + + log.Debugf("ChannelID(%v) is a zombie, collecting to prune", + info.ChannelID) + + // TODO(roasbeef): add ability to delete single directional edge + chansToPrune[info.ChannelID] = struct{}{} + + return nil + } + + // If AssumeChannelValid is present we'll look at the disabled bit for + // both edges. If they're both disabled, then we can interpret this as + // the channel being closed and can prune it from our graph. + if b.cfg.AssumeChannelValid { + disabledChanIDs, err := b.cfg.Graph.DisabledChannelIDs() + if err != nil { + return fmt.Errorf("unable to get disabled channels "+ + "ids chans: %v", err) + } + + disabledEdges, err := b.cfg.Graph.FetchChanInfos( + disabledChanIDs, + ) + if err != nil { + return fmt.Errorf("unable to fetch disabled channels "+ + "edges chans: %v", err) + } + + // Ensuring we won't prune our own channel from the graph. 
+ for _, disabledEdge := range disabledEdges { + if !isSelfChannelEdge(disabledEdge.Info) { + chansToPrune[disabledEdge.Info.ChannelID] = + struct{}{} + } + } + } + + startTime := time.Unix(0, 0) + endTime := time.Now().Add(-1 * chanExpiry) + oldEdges, err := b.cfg.Graph.ChanUpdatesInHorizon(startTime, endTime) + if err != nil { + return fmt.Errorf("unable to fetch expired channel updates "+ + "chans: %v", err) + } + + for _, u := range oldEdges { + filterPruneChans(u.Info, u.Policy1, u.Policy2) + } + + log.Infof("Pruning %v zombie channels", len(chansToPrune)) + if len(chansToPrune) == 0 { + return nil + } + + // With the set of zombie-like channels obtained, we'll do another pass + // to delete them from the channel graph. + toPrune := make([]uint64, 0, len(chansToPrune)) + for chanID := range chansToPrune { + toPrune = append(toPrune, chanID) + log.Tracef("Pruning zombie channel with ChannelID(%v)", chanID) + } + err = b.cfg.Graph.DeleteChannelEdges( + b.cfg.StrictZombiePruning, true, toPrune..., + ) + if err != nil { + return fmt.Errorf("unable to delete zombie channels: %w", err) + } + + // With the channels pruned, we'll also attempt to prune any nodes that + // were a part of them. + err = b.cfg.Graph.PruneGraphNodes() + if err != nil && err != channeldb.ErrGraphNodesNotFound { + return fmt.Errorf("unable to prune graph nodes: %w", err) + } + + return nil +} + +// handleNetworkUpdate is responsible for processing the update message and +// notifies topology changes, if any. +// +// NOTE: must be run inside goroutine. +func (b *Builder) handleNetworkUpdate(vb *ValidationBarrier, + update *routingMsg) { + + defer b.wg.Done() + defer vb.CompleteJob() + + // If this message has an existing dependency, then we'll wait until + // that has been fully validated before we proceed. + err := vb.WaitForDependants(update.msg) + if err != nil { + switch { + case IsError(err, ErrVBarrierShuttingDown): + update.err <- err + + case IsError(err, ErrParentValidationFailed): + update.err <- newErrf(ErrIgnored, err.Error()) + + default: + log.Warnf("unexpected error during validation "+ + "barrier shutdown: %v", err) + update.err <- err + } + + return + } + + // Process the routing update to determine if this is either a new + // update from our PoV or an update to a prior vertex/edge we + // previously accepted. + err = b.processUpdate(update.msg, update.op...) + update.err <- err + + // If this message had any dependencies, then we can now signal them to + // continue. + allowDependents := err == nil || IsError(err, ErrIgnored, ErrOutdated) + vb.SignalDependants(update.msg, allowDependents) + + // If the error is not nil here, there's no need to send topology + // change. + if err != nil { + // We now decide to log an error or not. If allowDependents is + // false, it means there is an error and the error is neither + // ErrIgnored or ErrOutdated. In this case, we'll log an error. + // Otherwise, we'll add debug log only. + if allowDependents { + log.Debugf("process network updates got: %v", err) + } else { + log.Errorf("process network updates got: %v", err) + } + + return + } + + // Otherwise, we'll send off a new notification for the newly accepted + // update, if any. 
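For the notification side that this handler feeds, a hedged sketch of a subscriber loop. It assumes the TopologyClient shape from the moved notifications.go (a TopologyChanges channel plus a Cancel closure) and a SubscribeTopology entry point on the Builder; the builder and quit variables and the log line are illustrative:

    // Hedged consumer sketch; field and method names are assumed to carry
    // over unchanged from routing/notifications.go.
    client, err := builder.SubscribeTopology()
    if err != nil {
        return err
    }
    defer client.Cancel()

    for {
        select {
        case change, ok := <-client.TopologyChanges:
            if !ok {
                return nil
            }
            log.Infof("%d channels closed, %d edges updated",
                len(change.ClosedChannels),
                len(change.ChannelEdgeUpdates))

        case <-quit:
            return nil
        }
    }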
+ topChange := &TopologyChange{} + err = addToTopologyChange(b.cfg.Graph, topChange, update.msg) + if err != nil { + log.Errorf("unable to update topology change notification: %v", + err) + return + } + + if !topChange.isEmpty() { + b.notifyTopologyChange(topChange) + } +} + +// networkHandler is the primary goroutine for the Builder. The roles of +// this goroutine include answering queries related to the state of the +// network, pruning the graph on new block notification, applying network +// updates, and registering new topology clients. +// +// NOTE: This MUST be run as a goroutine. +func (b *Builder) networkHandler() { + defer b.wg.Done() + + graphPruneTicker := time.NewTicker(b.cfg.GraphPruneInterval) + defer graphPruneTicker.Stop() + + defer b.statTicker.Stop() + + b.stats.Reset() + + // We'll use this validation barrier to ensure that we process all jobs + // in the proper order during parallel validation. + // + // NOTE: For AssumeChannelValid, we bump up the maximum number of + // concurrent validation requests since there are no blocks being + // fetched. This significantly increases the performance of IGD for + // neutrino nodes. + // + // However, we dial back to use multiple of the number of cores when + // fully validating, to avoid fetching up to 1000 blocks from the + // backend. On bitcoind, this will empirically cause massive latency + // spikes when executing this many concurrent RPC calls. Critical + // subsystems or basic rpc calls that rely on calls such as GetBestBlock + // will hang due to excessive load. + // + // See https://github.com/lightningnetwork/lnd/issues/4892. + var validationBarrier *ValidationBarrier + if b.cfg.AssumeChannelValid { + validationBarrier = NewValidationBarrier(1000, b.quit) + } else { + validationBarrier = NewValidationBarrier( + 4*runtime.NumCPU(), b.quit, + ) + } + + for { + + // If there are stats, resume the statTicker. + if !b.stats.Empty() { + b.statTicker.Resume() + } + + select { + // A new fully validated network update has just arrived. As a + // result we'll modify the channel graph accordingly depending + // on the exact type of the message. + case update := <-b.networkUpdates: + // We'll set up any dependants, and wait until a free + // slot for this job opens up, this allows us to not + // have thousands of goroutines active. + validationBarrier.InitJobDependencies(update.msg) + + b.wg.Add(1) + go b.handleNetworkUpdate(validationBarrier, update) + + // TODO(roasbeef): remove all unconnected vertexes + // after N blocks pass with no corresponding + // announcements. + + case chainUpdate, ok := <-b.staleBlocks: + // If the channel has been closed, then this indicates + // the daemon is shutting down, so we exit ourselves. + if !ok { + return + } + + // Since this block is stale, we update our best height + // to the previous block. + blockHeight := uint32(chainUpdate.Height) + atomic.StoreUint32(&b.bestHeight, blockHeight-1) + + // Update the channel graph to reflect that this block + // was disconnected. + _, err := b.cfg.Graph.DisconnectBlockAtHeight(blockHeight) + if err != nil { + log.Errorf("unable to prune graph with stale "+ + "block: %v", err) + continue + } + + // TODO(halseth): notify client about the reorg? + + // A new block has arrived, so we can prune the channel graph + // of any channels which were closed in the block. + case chainUpdate, ok := <-b.newBlocks: + // If the channel has been closed, then this indicates + // the daemon is shutting down, so we exit ourselves. 
+ if !ok {
+ return
+ }
+
+ // We'll ensure that any new blocks received attach
+ // directly to the end of our main chain. If not, then
+ // we've somehow missed some blocks. Here we'll catch
+ // up the chain with the latest blocks.
+ currentHeight := atomic.LoadUint32(&b.bestHeight)
+ switch {
+ case chainUpdate.Height == currentHeight+1:
+ err := b.updateGraphWithClosedChannels(
+ chainUpdate,
+ )
+ if err != nil {
+ log.Errorf("unable to prune graph "+
+ "with closed channels: %v", err)
+ }
+
+ case chainUpdate.Height > currentHeight+1:
+ log.Errorf("out of order block: expecting "+
+ "height=%v, got height=%v",
+ currentHeight+1, chainUpdate.Height)
+
+ err := b.getMissingBlocks(currentHeight, chainUpdate)
+ if err != nil {
+ log.Errorf("unable to retrieve missing "+
+ "blocks: %v", err)
+ }
+
+ case chainUpdate.Height < currentHeight+1:
+ log.Errorf("out of order block: expecting "+
+ "height=%v, got height=%v",
+ currentHeight+1, chainUpdate.Height)
+
+ log.Infof("Skipping channel pruning since "+
+ "received block height %v was already"+
+ " processed.", chainUpdate.Height)
+ }
+
+ // A new notification client update has arrived. We're either
+ // gaining a new client, or cancelling notifications for an
+ // existing client.
+ case ntfnUpdate := <-b.ntfnClientUpdates:
+ clientID := ntfnUpdate.clientID
+
+ if ntfnUpdate.cancel {
+ client, ok := b.topologyClients.LoadAndDelete(
+ clientID,
+ )
+ if ok {
+ close(client.exit)
+ client.wg.Wait()
+
+ close(client.ntfnChan)
+ }
+
+ continue
+ }
+
+ b.topologyClients.Store(clientID, &topologyClient{
+ ntfnChan: ntfnUpdate.ntfnChan,
+ exit: make(chan struct{}),
+ })
+
+ // The graph prune ticker has ticked, so we'll examine the
+ // state of the known graph to filter out any zombie channels
+ // for pruning.
+ case <-graphPruneTicker.C:
+ if err := b.pruneZombieChans(); err != nil {
+ log.Errorf("Unable to prune zombies: %v", err)
+ }
+
+ // Log any stats if we've processed a non-empty number of
+ // channels, updates, or nodes. We'll only pause the ticker if
+ // the last window contained no updates to avoid resuming and
+ // pausing while consecutive windows contain new info.
+ case <-b.statTicker.Ticks():
+ if !b.stats.Empty() {
+ log.Infof(b.stats.String())
+ } else {
+ b.statTicker.Pause()
+ }
+ b.stats.Reset()
+
+ // The router has been signalled to exit, so we exit our main
+ // loop and let the wait group be decremented.
+ case <-b.quit:
+ return
+ }
+ }
+}
+
+// getMissingBlocks walks through all missing blocks and updates the graph
+// with any closed channels accordingly.
+func (b *Builder) getMissingBlocks(currentHeight uint32,
+ chainUpdate *chainview.FilteredBlock) error {
+
+ outdatedHash, err := b.cfg.Chain.GetBlockHash(int64(currentHeight))
+ if err != nil {
+ return err
+ }
+
+ outdatedBlock := &chainntnfs.BlockEpoch{
+ Height: int32(currentHeight),
+ Hash: outdatedHash,
+ }
+
+ epochClient, err := b.cfg.Notifier.RegisterBlockEpochNtfn(
+ outdatedBlock,
+ )
+ if err != nil {
+ return err
+ }
+ defer epochClient.Cancel()
+
+ blockDifference := int(chainUpdate.Height - currentHeight)
+
+ // We'll walk through all the outdated blocks and make sure we're able
+ // to update the graph with any closed channels from them.
+ for i := 0; i < blockDifference; i++ { + var ( + missingBlock *chainntnfs.BlockEpoch + ok bool + ) + + select { + case missingBlock, ok = <-epochClient.Epochs: + if !ok { + return nil + } + + case <-b.quit: + return nil + } + + filteredBlock, err := b.cfg.ChainView.FilterBlock( + missingBlock.Hash, + ) + if err != nil { + return err + } + + err = b.updateGraphWithClosedChannels( + filteredBlock, + ) + if err != nil { + return err + } + } + + return nil +} + +// updateGraphWithClosedChannels prunes the channel graph of closed channels +// that are no longer needed. +func (b *Builder) updateGraphWithClosedChannels( + chainUpdate *chainview.FilteredBlock) error { + + // Once a new block arrives, we update our running track of the height + // of the chain tip. + blockHeight := chainUpdate.Height + + atomic.StoreUint32(&b.bestHeight, blockHeight) + log.Infof("Pruning channel graph using block %v (height=%v)", + chainUpdate.Hash, blockHeight) + + // We're only interested in all prior outputs that have been spent in + // the block, so collate all the referenced previous outpoints within + // each tx and input. + var spentOutputs []*wire.OutPoint + for _, tx := range chainUpdate.Transactions { + for _, txIn := range tx.TxIn { + spentOutputs = append(spentOutputs, + &txIn.PreviousOutPoint) + } + } + + // With the spent outputs gathered, attempt to prune the channel graph, + // also passing in the hash+height of the block being pruned so the + // prune tip can be updated. + chansClosed, err := b.cfg.Graph.PruneGraph(spentOutputs, + &chainUpdate.Hash, chainUpdate.Height) + if err != nil { + log.Errorf("unable to prune routing table: %v", err) + return err + } + + log.Infof("Block %v (height=%v) closed %v channels", chainUpdate.Hash, + blockHeight, len(chansClosed)) + + if len(chansClosed) == 0 { + return err + } + + // Notify all currently registered clients of the newly closed channels. + closeSummaries := createCloseSummaries(blockHeight, chansClosed...) + b.notifyTopologyChange(&TopologyChange{ + ClosedChannels: closeSummaries, + }) + + return nil +} + +// assertNodeAnnFreshness returns a non-nil error if we have an announcement in +// the database for the passed node with a timestamp newer than the passed +// timestamp. ErrIgnored will be returned if we already have the node, and +// ErrOutdated will be returned if we have a timestamp that's after the new +// timestamp. +func (b *Builder) assertNodeAnnFreshness(node route.Vertex, + msgTimestamp time.Time) error { + + // If we are not already aware of this node, it means that we don't + // know about any channel using this node. To avoid a DoS attack by + // node announcements, we will ignore such nodes. If we do know about + // this node, check that this update brings info newer than what we + // already have. + lastUpdate, exists, err := b.cfg.Graph.HasLightningNode(node) + if err != nil { + return errors.Errorf("unable to query for the "+ + "existence of node: %v", err) + } + if !exists { + return newErrf(ErrIgnored, "Ignoring node announcement"+ + " for node not found in channel graph (%x)", + node[:]) + } + + // If we've reached this point then we're aware of the vertex being + // advertised. So we now check if the new message has a new time stamp, + // if not then we won't accept the new data as it would override newer + // data. 
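assertNodeAnnFreshness, concluded below, distinguishes two rejection classes: ErrIgnored when we know of no channel involving the node at all, and ErrOutdated when the announcement's timestamp is not strictly newer than what we already have. A toy model of just that rule (the helper and error names here are hypothetical stand-ins):

    package main

    import (
        "errors"
        "fmt"
        "time"
    )

    var (
        errIgnored  = errors.New("ignored")
        errOutdated = errors.New("outdated")
    )

    // checkFreshness mirrors the two-step rule: unknown node -> ignored
    // (DoS protection), known node with a non-newer timestamp -> outdated.
    func checkFreshness(known bool, last, msg time.Time) error {
        if !known {
            return errIgnored
        }
        if !last.Before(msg) {
            return errOutdated
        }
        return nil
    }

    func main() {
        t0 := time.Unix(1000, 0)
        fmt.Println(checkFreshness(false, time.Time{}, t0)) // ignored
        fmt.Println(checkFreshness(true, t0, t0))           // outdated (equal)
        fmt.Println(checkFreshness(true, t0, t0.Add(1)))    // <nil>
    }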
+ if !lastUpdate.Before(msgTimestamp) {
+ return newErrf(ErrOutdated, "Ignoring outdated "+
+ "announcement for %x", node[:])
+ }
+
+ return nil
+}
+
+// addZombieEdge adds a channel that failed complete validation into the zombie
+// index so we can avoid having to re-validate it in the future.
+func (b *Builder) addZombieEdge(chanID uint64) error {
+ // If the edge fails validation we'll mark the edge itself as a zombie
+ // so we don't continue to request it. We use the "zero key" for both
+ // node pubkeys so this edge can't be resurrected.
+ var zeroKey [33]byte
+ err := b.cfg.Graph.MarkEdgeZombie(chanID, zeroKey, zeroKey)
+ if err != nil {
+ return fmt.Errorf("unable to mark spent chan(id=%v) as a "+
+ "zombie: %w", chanID, err)
+ }
+
+ return nil
+}
+
+// makeFundingScript is used to make the funding script for both segwit v0 and
+// segwit v1 (taproot) channels.
+//
+// TODO(roasbeef): export and use elsewhere?
+func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte,
+ chanFeatures []byte) ([]byte, error) {
+
+ legacyFundingScript := func() ([]byte, error) {
+ witnessScript, err := input.GenMultiSigScript(
+ bitcoinKey1, bitcoinKey2,
+ )
+ if err != nil {
+ return nil, err
+ }
+ pkScript, err := input.WitnessScriptHash(witnessScript)
+ if err != nil {
+ return nil, err
+ }
+
+ return pkScript, nil
+ }
+
+ if len(chanFeatures) == 0 {
+ return legacyFundingScript()
+ }
+
+ // In order to make the correct funding script, we'll need to parse the
+ // chanFeatures bytes into a feature vector we can interact with.
+ rawFeatures := lnwire.NewRawFeatureVector()
+ err := rawFeatures.Decode(bytes.NewReader(chanFeatures))
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse chan feature "+
+ "bits: %w", err)
+ }
+
+ chanFeatureBits := lnwire.NewFeatureVector(
+ rawFeatures, lnwire.Features,
+ )
+ if chanFeatureBits.HasFeature(
+ lnwire.SimpleTaprootChannelsOptionalStaging,
+ ) {
+
+ pubKey1, err := btcec.ParsePubKey(bitcoinKey1)
+ if err != nil {
+ return nil, err
+ }
+ pubKey2, err := btcec.ParsePubKey(bitcoinKey2)
+ if err != nil {
+ return nil, err
+ }
+
+ fundingScript, _, err := input.GenTaprootFundingScript(
+ pubKey1, pubKey2, 0,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return fundingScript, nil
+ }
+
+ return legacyFundingScript()
+}
+
+// processUpdate processes a new authenticated network update, be it a channel
+// announcement, a node announcement, or a channel update. If the update
+// didn't affect the internal state of the graph due to being out of date,
+// invalid, or redundant, then an error is returned.
+func (b *Builder) processUpdate(msg interface{},
+ op ...batch.SchedulerOption) error {
+
+ switch msg := msg.(type) {
+ case *channeldb.LightningNode:
+ // Before we add the node to the database, we'll check to see
+ // if the announcement is "fresh" or not. If it isn't, then
+ // we'll return an error.
+ err := b.assertNodeAnnFreshness(msg.PubKeyBytes, msg.LastUpdate)
+ if err != nil {
+ return err
+ }
+
+ if err := b.cfg.Graph.AddLightningNode(msg, op...); err != nil {
+ return errors.Errorf("unable to add node %x to the "+
+ "graph: %v", msg.PubKeyBytes, err)
+ }
+
+ log.Tracef("Updated vertex data for node=%x", msg.PubKeyBytes)
+ b.stats.incNumNodeUpdates()
+
+ case *models.ChannelEdgeInfo:
+ log.Debugf("Received ChannelEdgeInfo for channel %v",
+ msg.ChannelID)
+
+ // Prior to processing the announcement we first check if we
+ // already know of this channel; if so, then we can exit early.
+ _, _, exists, isZombie, err := b.cfg.Graph.HasChannelEdge( + msg.ChannelID, + ) + if err != nil && err != channeldb.ErrGraphNoEdgesFound { + return errors.Errorf("unable to check for edge "+ + "existence: %v", err) + } + if isZombie { + return newErrf(ErrIgnored, "ignoring msg for zombie "+ + "chan_id=%v", msg.ChannelID) + } + if exists { + return newErrf(ErrIgnored, "ignoring msg for known "+ + "chan_id=%v", msg.ChannelID) + } + + // If AssumeChannelValid is present, then we are unable to + // perform any of the expensive checks below, so we'll + // short-circuit our path straight to adding the edge to our + // graph. If the passed ShortChannelID is an alias, then we'll + // skip validation as it will not map to a legitimate tx. This + // is not a DoS vector as only we can add an alias + // ChannelAnnouncement from the gossiper. + scid := lnwire.NewShortChanIDFromInt(msg.ChannelID) + if b.cfg.AssumeChannelValid || b.cfg.IsAlias(scid) { + if err := b.cfg.Graph.AddChannelEdge(msg, op...); err != nil { + return fmt.Errorf("unable to add edge: %w", err) + } + log.Tracef("New channel discovered! Link "+ + "connects %x and %x with ChannelID(%v)", + msg.NodeKey1Bytes, msg.NodeKey2Bytes, + msg.ChannelID) + b.stats.incNumEdgesDiscovered() + + break + } + + // Before we can add the channel to the channel graph, we need + // to obtain the full funding outpoint that's encoded within + // the channel ID. + channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) + fundingTx, err := b.fetchFundingTxWrapper(&channelID) + if err != nil { + // In order to ensure we don't erroneously mark a + // channel as a zombie due to an RPC failure, we'll + // attempt to string match for the relevant errors. + // + // * btcd: + // * https://github.com/btcsuite/btcd/blob/master/rpcserver.go#L1316 + // * https://github.com/btcsuite/btcd/blob/master/rpcserver.go#L1086 + // * bitcoind: + // * https://github.com/bitcoin/bitcoin/blob/7fcf53f7b4524572d1d0c9a5fdc388e87eb02416/src/rpc/blockchain.cpp#L770 + // * https://github.com/bitcoin/bitcoin/blob/7fcf53f7b4524572d1d0c9a5fdc388e87eb02416/src/rpc/blockchain.cpp#L954 + switch { + case strings.Contains(err.Error(), "not found"): + fallthrough + + case strings.Contains(err.Error(), "out of range"): + // If the funding transaction isn't found at + // all, then we'll mark the edge itself as a + // zombie so we don't continue to request it. + // We use the "zero key" for both node pubkeys + // so this edge can't be resurrected. + zErr := b.addZombieEdge(msg.ChannelID) + if zErr != nil { + return zErr + } + + default: + } + + return newErrf(ErrNoFundingTransaction, "unable to "+ + "locate funding tx: %v", err) + } + + // Recreate witness output to be sure that declared in channel + // edge bitcoin keys and channel value corresponds to the + // reality. + fundingPkScript, err := makeFundingScript( + msg.BitcoinKey1Bytes[:], msg.BitcoinKey2Bytes[:], + msg.Features, + ) + if err != nil { + return err + } + + // Next we'll validate that this channel is actually well + // formed. If this check fails, then this channel either + // doesn't exist, or isn't the one that was meant to be created + // according to the passed channel proofs. + fundingPoint, err := chanvalidate.Validate(&chanvalidate.Context{ + Locator: &chanvalidate.ShortChanIDChanLocator{ + ID: channelID, + }, + MultiSigPkScript: fundingPkScript, + FundingTx: fundingTx, + }) + if err != nil { + // Mark the edge as a zombie so we won't try to + // re-validate it on start up. 
+ if err := b.addZombieEdge(msg.ChannelID); err != nil {
+ return err
+ }
+
+ return newErrf(ErrInvalidFundingOutput, "output "+
+ "failed validation: %w", err)
+ }
+
+ // Now that we have the funding outpoint of the channel, ensure
+ // that it hasn't yet been spent. If so, then this channel has
+ // been closed so we'll ignore it.
+ chanUtxo, err := b.cfg.Chain.GetUtxo(
+ fundingPoint, fundingPkScript, channelID.BlockHeight,
+ b.quit,
+ )
+ if err != nil {
+ if errors.Is(err, btcwallet.ErrOutputSpent) {
+ zErr := b.addZombieEdge(msg.ChannelID)
+ if zErr != nil {
+ return zErr
+ }
+ }
+
+ return newErrf(ErrChannelSpent, "unable to fetch utxo "+
+ "for chan_id=%v, chan_point=%v: %v",
+ msg.ChannelID, fundingPoint, err)
+ }
+
+ // TODO(roasbeef): this is a hack, needs to be removed
+ // after commitment fees are dynamic.
+ msg.Capacity = btcutil.Amount(chanUtxo.Value)
+ msg.ChannelPoint = *fundingPoint
+ if err := b.cfg.Graph.AddChannelEdge(msg, op...); err != nil {
+ return errors.Errorf("unable to add edge: %v", err)
+ }
+
+ log.Debugf("New channel discovered! Link "+
+ "connects %x and %x with ChannelPoint(%v): "+
+ "chan_id=%v, capacity=%v",
+ msg.NodeKey1Bytes, msg.NodeKey2Bytes,
+ fundingPoint, msg.ChannelID, msg.Capacity)
+ b.stats.incNumEdgesDiscovered()
+
+ // As a new edge has been added to the channel graph, we'll
+ // update the current UTXO filter within our active
+ // FilteredChainView so we are notified if/when this channel is
+ // closed.
+ filterUpdate := []channeldb.EdgePoint{
+ {
+ FundingPkScript: fundingPkScript,
+ OutPoint: *fundingPoint,
+ },
+ }
+ err = b.cfg.ChainView.UpdateFilter(
+ filterUpdate, atomic.LoadUint32(&b.bestHeight),
+ )
+ if err != nil {
+ return errors.Errorf("unable to update chain "+
+ "view: %v", err)
+ }
+
+ case *models.ChannelEdgePolicy:
+ log.Debugf("Received ChannelEdgePolicy for channel %v",
+ msg.ChannelID)
+
+ // We make sure to hold the mutex for this channel ID,
+ // such that no other goroutine is concurrently doing
+ // database accesses for the same channel ID.
+ b.channelEdgeMtx.Lock(msg.ChannelID)
+ defer b.channelEdgeMtx.Unlock(msg.ChannelID)
+
+ edge1Timestamp, edge2Timestamp, exists, isZombie, err :=
+ b.cfg.Graph.HasChannelEdge(msg.ChannelID)
+ if err != nil && err != channeldb.ErrGraphNoEdgesFound {
+ return errors.Errorf("unable to check for edge "+
+ "existence: %v", err)
+
+ }
+
+ // If the channel is marked as a zombie in our database, and
+ // we consider this a stale update, then we should not apply the
+ // policy.
+ isStaleUpdate := time.Since(msg.LastUpdate) > b.cfg.ChannelPruneExpiry
+ if isZombie && isStaleUpdate {
+ return newErrf(ErrIgnored, "ignoring stale update "+
+ "(flags=%v|%v) for zombie chan_id=%v",
+ msg.MessageFlags, msg.ChannelFlags,
+ msg.ChannelID)
+ }
+
+ // If the channel doesn't exist in our database, we cannot
+ // apply the updated policy.
+ if !exists {
+ return newErrf(ErrIgnored, "ignoring update "+
+ "(flags=%v|%v) for unknown chan_id=%v",
+ msg.MessageFlags, msg.ChannelFlags,
+ msg.ChannelID)
+ }
+
+ // As edges are directional, each node has a unique policy for
+ // the direction of the edge it controls. Therefore, we first
+ // check if we already have the most up-to-date information for
+ // that edge. If this message has a timestamp that is not
+ // strictly newer than what we already know of, we can exit
+ // early.
+ switch {
+
+ // A flag set of 0 indicates this is an announcement for the
+ // "first" node in the channel.
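The switch that follows keys off the low bit of ChannelFlags, which per BOLT 7 selects node_1 (the lexicographically lesser public key) or node_2. A tiny standalone illustration:

    package main

    import "fmt"

    const chanUpdateDirection = 1 // mirrors lnwire.ChanUpdateDirection

    // policyOwner reports which node's policy a channel_update carries.
    func policyOwner(channelFlags uint8) string {
        if channelFlags&chanUpdateDirection == 0 {
            return "node1" // lexicographically lesser pubkey
        }
        return "node2"
    }

    func main() {
        fmt.Println(policyOwner(0x00)) // node1
        fmt.Println(policyOwner(0x01)) // node2
        fmt.Println(policyOwner(0x03)) // node2 (disabled bit set too)
    }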
+ case msg.ChannelFlags&lnwire.ChanUpdateDirection == 0: + + // Ignore outdated message. + if !edge1Timestamp.Before(msg.LastUpdate) { + return newErrf(ErrOutdated, "Ignoring "+ + "outdated update (flags=%v|%v) for "+ + "known chan_id=%v", msg.MessageFlags, + msg.ChannelFlags, msg.ChannelID) + } + + // Similarly, a flag set of 1 indicates this is an announcement + // for the "second" node in the channel. + case msg.ChannelFlags&lnwire.ChanUpdateDirection == 1: + + // Ignore outdated message. + if !edge2Timestamp.Before(msg.LastUpdate) { + return newErrf(ErrOutdated, "Ignoring "+ + "outdated update (flags=%v|%v) for "+ + "known chan_id=%v", msg.MessageFlags, + msg.ChannelFlags, msg.ChannelID) + } + } + + // Now that we know this isn't a stale update, we'll apply the + // new edge policy to the proper directional edge within the + // channel graph. + if err = b.cfg.Graph.UpdateEdgePolicy(msg, op...); err != nil { + err := errors.Errorf("unable to add channel: %v", err) + log.Error(err) + return err + } + + log.Tracef("New channel update applied: %v", + newLogClosure(func() string { return spew.Sdump(msg) })) + b.stats.incNumChannelUpdates() + + default: + return errors.Errorf("wrong routing update message type") + } + + return nil +} + +// fetchFundingTxWrapper is a wrapper around fetchFundingTx, except that it +// will exit if the router has stopped. +func (b *Builder) fetchFundingTxWrapper(chanID *lnwire.ShortChannelID) ( + *wire.MsgTx, error) { + + txChan := make(chan *wire.MsgTx, 1) + errChan := make(chan error, 1) + + go func() { + tx, err := b.fetchFundingTx(chanID) + if err != nil { + errChan <- err + return + } + + txChan <- tx + }() + + select { + case tx := <-txChan: + return tx, nil + + case err := <-errChan: + return nil, err + + case <-b.quit: + return nil, ErrGraphBuilderShuttingDown + } +} + +// fetchFundingTx returns the funding transaction identified by the passed +// short channel ID. +// +// TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to +// later use getblocktxn) +func (b *Builder) fetchFundingTx( + chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { + + // First fetch the block hash by the block number encoded, then use + // that hash to fetch the block itself. + blockNum := int64(chanID.BlockHeight) + blockHash, err := b.cfg.Chain.GetBlockHash(blockNum) + if err != nil { + return nil, err + } + fundingBlock, err := b.cfg.Chain.GetBlock(blockHash) + if err != nil { + return nil, err + } + + // As a sanity check, ensure that the advertised transaction index is + // within the bounds of the total number of transactions within a + // block. + numTxns := uint32(len(fundingBlock.Transactions)) + if chanID.TxIndex > numTxns-1 { + return nil, fmt.Errorf("tx_index=#%v "+ + "is out of range (max_index=%v), network_chan_id=%v", + chanID.TxIndex, numTxns-1, chanID) + } + + return fundingBlock.Transactions[chanID.TxIndex].Copy(), nil +} + +// routingMsg couples a routing related routing topology update to the +// error channel. +type routingMsg struct { + msg interface{} + op []batch.SchedulerOption + err chan error +} + +// ApplyChannelUpdate validates a channel update and if valid, applies it to the +// database. It returns a bool indicating whether the updates were successful. 
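fetchFundingTx above relies on the short channel ID layout: funding block height, transaction index within that block, and output index, packed into a uint64. A small example using the lnwire type (the printed form is illustrative):

    package main

    import (
        "fmt"

        "github.com/lightningnetwork/lnd/lnwire"
    )

    func main() {
        scid := lnwire.ShortChannelID{
            BlockHeight: 500_000, // block containing the funding tx
            TxIndex:     1024,    // position of the tx in that block
            TxPosition:  1,       // funding output index within the tx
        }

        // Round-trip through the packed uint64 wire encoding.
        packed := scid.ToUint64()
        fmt.Println(packed, lnwire.NewShortChanIDFromInt(packed))
    }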
+func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate) bool {
+ ch, _, _, err := b.GetChannelByID(msg.ShortChannelID)
+ if err != nil {
+ log.Errorf("Unable to retrieve channel by id: %v", err)
+ return false
+ }
+
+ var pubKey *btcec.PublicKey
+
+ switch msg.ChannelFlags & lnwire.ChanUpdateDirection {
+ case 0:
+ pubKey, _ = ch.NodeKey1()
+
+ case 1:
+ pubKey, _ = ch.NodeKey2()
+ }
+
+ // Exit early if the pubkey cannot be decided.
+ if pubKey == nil {
+ log.Errorf("Unable to decide pubkey with ChannelFlags=%v",
+ msg.ChannelFlags)
+ return false
+ }
+
+ err = ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg)
+ if err != nil {
+ log.Errorf("Unable to validate channel update: %v", err)
+ return false
+ }
+
+ err = b.UpdateEdge(&models.ChannelEdgePolicy{
+ SigBytes: msg.Signature.ToSignatureBytes(),
+ ChannelID: msg.ShortChannelID.ToUint64(),
+ LastUpdate: time.Unix(int64(msg.Timestamp), 0),
+ MessageFlags: msg.MessageFlags,
+ ChannelFlags: msg.ChannelFlags,
+ TimeLockDelta: msg.TimeLockDelta,
+ MinHTLC: msg.HtlcMinimumMsat,
+ MaxHTLC: msg.HtlcMaximumMsat,
+ FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee),
+ FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate),
+ ExtraOpaqueData: msg.ExtraOpaqueData,
+ })
+ if err != nil && !IsError(err, ErrIgnored, ErrOutdated) {
+ log.Errorf("Unable to apply channel update: %v", err)
+ return false
+ }
+
+ return true
+}
+
+// AddNode is used to add information about a node to the router database. If
+// the node with this pubkey is not present in an existing channel, it will
+// be ignored.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) AddNode(node *channeldb.LightningNode,
+ op ...batch.SchedulerOption) error {
+
+ rMsg := &routingMsg{
+ msg: node,
+ op: op,
+ err: make(chan error, 1),
+ }
+
+ select {
+ case b.networkUpdates <- rMsg:
+ select {
+ case err := <-rMsg.err:
+ return err
+ case <-b.quit:
+ return ErrGraphBuilderShuttingDown
+ }
+ case <-b.quit:
+ return ErrGraphBuilderShuttingDown
+ }
+}
+
+// AddEdge is used to add an edge/channel to the topology of the router. Once
+// all information about the channel has been gathered, this edge/channel
+// might be used in the construction of a payment path.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) AddEdge(edge *models.ChannelEdgeInfo,
+ op ...batch.SchedulerOption) error {
+
+ rMsg := &routingMsg{
+ msg: edge,
+ op: op,
+ err: make(chan error, 1),
+ }
+
+ select {
+ case b.networkUpdates <- rMsg:
+ select {
+ case err := <-rMsg.err:
+ return err
+ case <-b.quit:
+ return ErrGraphBuilderShuttingDown
+ }
+ case <-b.quit:
+ return ErrGraphBuilderShuttingDown
+ }
+}
+
+// UpdateEdge is used to update edge information. Without this message, the
+// edge is considered as not fully constructed.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy,
+ op ...batch.SchedulerOption) error {
+
+ rMsg := &routingMsg{
+ msg: update,
+ op: op,
+ err: make(chan error, 1),
+ }
+
+ select {
+ case b.networkUpdates <- rMsg:
+ select {
+ case err := <-rMsg.err:
+ return err
+ case <-b.quit:
+ return ErrGraphBuilderShuttingDown
+ }
+ case <-b.quit:
+ return ErrGraphBuilderShuttingDown
+ }
+}
+
+// CurrentBlockHeight returns the block height from the PoV of the router
+// subsystem.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
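AddNode, AddEdge and UpdateEdge above all share one shape: hand the message to the single networkHandler goroutine over networkUpdates, then wait on a per-request error channel, bailing out on shutdown at either step. A self-contained sketch of just that pattern (all names here are illustrative):

    package main

    import (
        "errors"
        "fmt"
    )

    type request struct {
        msg string
        err chan error
    }

    var errShuttingDown = errors.New("shutting down")

    // send mirrors the nested select: submit the request, then wait for its
    // reply, giving up at either stage if quit closes.
    func send(updates chan *request, quit chan struct{}, msg string) error {
        req := &request{msg: msg, err: make(chan error, 1)}
        select {
        case updates <- req:
            select {
            case err := <-req.err:
                return err
            case <-quit:
                return errShuttingDown
            }
        case <-quit:
            return errShuttingDown
        }
    }

    func main() {
        updates := make(chan *request)
        quit := make(chan struct{})

        // A stand-in for networkHandler: serve exactly one request.
        go func() {
            req := <-updates
            req.err <- nil
        }()

        fmt.Println(send(updates, quit, "node announcement")) // <nil>

        // Simulate shutdown: later sends fail fast instead of blocking.
        close(quit)
        fmt.Println(send(updates, quit, "too late")) // shutting down
    }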
+func (b *Builder) CurrentBlockHeight() (uint32, error) {
+	_, height, err := b.cfg.Chain.GetBestBlock()
+	return uint32(height), err
+}
+
+// SyncedHeight returns the block height to which the router subsystem is
+// currently synced. This can differ from the above chain height if the
+// goroutine responsible for processing the blocks isn't yet up to speed.
+func (b *Builder) SyncedHeight() uint32 {
+	return atomic.LoadUint32(&b.bestHeight)
+}
+
+// GetChannelByID returns the channel identified by the given channel ID.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) GetChannelByID(chanID lnwire.ShortChannelID) (
+	*models.ChannelEdgeInfo,
+	*models.ChannelEdgePolicy,
+	*models.ChannelEdgePolicy, error) {
+
+	return b.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64())
+}
+
+// FetchLightningNode attempts to look up a target node by its identity public
+// key. channeldb.ErrGraphNodeNotFound is returned if the node doesn't exist
+// within the graph.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) FetchLightningNode(
+	node route.Vertex) (*channeldb.LightningNode, error) {
+
+	return b.cfg.Graph.FetchLightningNode(node)
+}
+
+// ForEachNode is used to iterate over every node in the router topology.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) ForEachNode(
+	cb func(*channeldb.LightningNode) error) error {
+
+	return b.cfg.Graph.ForEachNode(
+		func(_ kvdb.RTx, n *channeldb.LightningNode) error {
+			return cb(n)
+		})
+}
+
+// ForAllOutgoingChannels is used to iterate over all outgoing channels owned
+// by the router.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx,
+	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error {
+
+	return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode,
+		func(tx kvdb.RTx, c *models.ChannelEdgeInfo,
+			e *models.ChannelEdgePolicy,
+			_ *models.ChannelEdgePolicy) error {
+
+			if e == nil {
+				return fmt.Errorf("channel from self node " +
+					"has no policy")
+			}
+
+			return cb(tx, c, e)
+		},
+	)
+}
+
+// AddProof updates the channel edge info with the proof which is needed to
+// properly announce the edge to the rest of the network.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) AddProof(chanID lnwire.ShortChannelID,
+	proof *models.ChannelAuthProof) error {
+
+	info, _, _, err := b.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64())
+	if err != nil {
+		return err
+	}
+
+	info.AuthProof = proof
+	return b.cfg.Graph.UpdateChannelEdge(info)
+}
+
+// IsStaleNode returns true if the graph source has a node announcement for the
+// target node with a more recent timestamp.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) IsStaleNode(node route.Vertex,
+	timestamp time.Time) bool {
+
+	// If our attempt to assert that the node announcement is fresh fails,
+	// then we know that this is actually a stale announcement.
+	err := b.assertNodeAnnFreshness(node, timestamp)
+	if err != nil {
+		log.Debugf("Checking stale node %x got %v", node, err)
+		return true
+	}
+
+	return false
+}
+
+// IsPublicNode determines whether the given vertex is seen as a public node in
+// the graph from the graph's source node's point of view.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
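+//
+// Broadly speaking, a node counts as public once at least one of its channels
+// has been announced with an authentication proof. A hypothetical caller
+// gating gossip propagation might use it as follows:
+//
+//	public, err := b.IsPublicNode(vertex)
+//	if err == nil && !public {
+//		// Skip rebroadcasting this node's announcement.
+//	}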
+func (b *Builder) IsPublicNode(node route.Vertex) (bool, error) {
+	return b.cfg.Graph.IsPublicNode(node)
+}
+
+// IsKnownEdge returns true if the graph source already knows of the passed
+// channel ID either as a live or zombie edge.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
+	_, _, exists, isZombie, _ := b.cfg.Graph.HasChannelEdge(
+		chanID.ToUint64(),
+	)
+	return exists || isZombie
+}
+
+// IsStaleEdgePolicy returns true if the graph source has a channel edge for
+// the passed channel ID (and flags) that has a more recent timestamp.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
+func (b *Builder) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
+	timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
+
+	edge1Timestamp, edge2Timestamp, exists, isZombie, err :=
+		b.cfg.Graph.HasChannelEdge(chanID.ToUint64())
+	if err != nil {
+		log.Debugf("Check stale edge policy got error: %v", err)
+		return false
+	}
+
+	// If we know of the edge as a zombie, then we'll make some additional
+	// checks to determine if the new policy is fresh.
+	if isZombie {
+		// When running with AssumeChannelValid, we also prune channels
+		// if both of their edges are disabled. We'll mark the new
+		// policy as stale if it remains disabled.
+		if b.cfg.AssumeChannelValid {
+			isDisabled := flags&lnwire.ChanUpdateDisabled ==
+				lnwire.ChanUpdateDisabled
+			if isDisabled {
+				return true
+			}
+		}
+
+		// Otherwise, we'll fall back to our usual ChannelPruneExpiry.
+		return time.Since(timestamp) > b.cfg.ChannelPruneExpiry
+	}
+
+	// If we don't know of the edge, then it means it's fresh (thus not
+	// stale).
+	if !exists {
+		return false
+	}
+
+	// As edges are directional, each node has a unique policy for the
+	// direction of the edge it controls. Therefore, we first check if we
+	// already have the most up-to-date information for that edge. If so,
+	// then we can exit early.
+	switch {
+	// A flag set of 0 indicates this is an announcement for the "first"
+	// node in the channel.
+	case flags&lnwire.ChanUpdateDirection == 0:
+		return !edge1Timestamp.Before(timestamp)
+
+	// Similarly, a flag set of 1 indicates this is an announcement for the
+	// "second" node in the channel.
+	case flags&lnwire.ChanUpdateDirection == 1:
+		return !edge2Timestamp.Before(timestamp)
+	}
+
+	return false
+}
+
+// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
+//
+// NOTE: This method is part of the ChannelGraphSource interface.
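+//
+// This is the counterpart to the zombie handling in IsStaleEdgePolicy above:
+// once a sufficiently fresh update arrives for a channel that was written
+// off, a caller might resurrect the edge along these lines (sketch only):
+//
+//	if err := b.MarkEdgeLive(scid); err != nil {
+//		// The edge may simply not be in the zombie index.
+//	}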
+func (b *Builder) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
+	return b.cfg.Graph.MarkEdgeLive(chanID.ToUint64())
+}
diff --git a/graph/builder_test.go b/graph/builder_test.go
new file mode 100644
index 0000000000..d3e25d2aab
--- /dev/null
+++ b/graph/builder_test.go
@@ -0,0 +1,2051 @@
+package graph
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"image/color"
+	"math/rand"
+	"net"
+	"os"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/btcsuite/btcd/btcec/v2"
+	"github.com/btcsuite/btcd/btcutil"
+	"github.com/btcsuite/btcd/chaincfg/chainhash"
+	"github.com/btcsuite/btcd/wire"
+	"github.com/go-errors/errors"
+	"github.com/lightningnetwork/lnd/chainntnfs"
+	"github.com/lightningnetwork/lnd/channeldb"
+	"github.com/lightningnetwork/lnd/channeldb/models"
+	"github.com/lightningnetwork/lnd/htlcswitch"
+	"github.com/lightningnetwork/lnd/lntest/wait"
+	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/routing/route"
+	"github.com/stretchr/testify/require"
+)
+
+const (
+	// basicGraphFilePath is the file path for a basic graph used within
+	// the tests. The basic graph consists of 5 nodes with 5 channels
+	// connecting them.
+	basicGraphFilePath = "testdata/basic_graph.json"
+
+	testTimeout = 5 * time.Second
+)
+
+// TestAddProof checks that we can update the channel proof after the channel
+// info was added to the database.
+func TestAddProof(t *testing.T) {
+	t.Parallel()
+
+	ctx := createTestCtxSingleNode(t, 0)
+
+	// Before creating our edge, we'll create two new nodes within the
+	// network that the channel will connect.
+	node1 := createTestNode(t)
+	node2 := createTestNode(t)
+
+	// In order to be able to add the edge we should have a valid funding
+	// UTXO within the blockchain.
+	fundingTx, _, chanID, err := createChannelEdge(
+		ctx, bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(), 100, 0,
+	)
+	require.NoError(t, err, "unable to create channel edge")
+	fundingBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{fundingTx},
+	}
+	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+
+	// With the funding UTXO recreated, we add the edge without the proof.
+	edge := &models.ChannelEdgeInfo{
+		ChannelID:     chanID.ToUint64(),
+		NodeKey1Bytes: node1.PubKeyBytes,
+		NodeKey2Bytes: node2.PubKeyBytes,
+		AuthProof:     nil,
+	}
+	copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+
+	require.NoError(t, ctx.builder.AddEdge(edge))
+
+	// Now we'll attempt to update the proof and check that it has been
+	// properly updated.
+	require.NoError(t, ctx.builder.AddProof(*chanID, &testAuthProof))
+
+	info, _, _, err := ctx.builder.GetChannelByID(*chanID)
+	require.NoError(t, err, "unable to get channel")
+	require.NotNil(t, info.AuthProof)
+}
+
+// TestIgnoreNodeAnnouncement tests that adding a node to the router that is
+// not known from any channel announcement leads to the announcement being
+// ignored.
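+//
+// This mirrors the gossip rules: node announcements are only accepted once
+// the node is already known from some channel, which the builder surfaces as
+// ErrIgnored below.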
+func TestIgnoreNodeAnnouncement(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath)
+
+	pub := priv1.PubKey()
+	node := &channeldb.LightningNode{
+		HaveNodeAnnouncement: true,
+		LastUpdate:           time.Unix(123, 0),
+		Addresses:            testAddrs,
+		Color:                color.RGBA{1, 2, 3, 0},
+		Alias:                "node11",
+		AuthSigBytes:         testSig.Serialize(),
+		Features:             testFeatures,
+	}
+	copy(node.PubKeyBytes[:], pub.SerializeCompressed())
+
+	err := ctx.builder.AddNode(node)
+	if !IsError(err, ErrIgnored) {
+		t.Fatalf("expected to get ErrIgnore, instead got: %v", err)
+	}
+}
+
+// TestIgnoreChannelEdgePolicyForUnknownChannel checks that a router will
+// ignore a channel policy for a channel not in the graph.
+func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+
+	// Set up an initially empty network.
+	var testChannels []*testChannel
+	testGraph, err := createTestGraphFromChannels(
+		t, true, testChannels, "roasbeef",
+	)
+	require.NoError(t, err, "unable to create graph")
+
+	ctx := createTestCtxFromGraphInstance(
+		t, startingBlockHeight, testGraph, false,
+	)
+
+	var pub1 [33]byte
+	copy(pub1[:], priv1.PubKey().SerializeCompressed())
+
+	var pub2 [33]byte
+	copy(pub2[:], priv2.PubKey().SerializeCompressed())
+
+	// Add the edge between the two unknown nodes to the graph, and check
+	// that the nodes are found after the fact.
+	fundingTx, _, chanID, err := createChannelEdge(
+		ctx, bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(), 10000, 500,
+	)
+	require.NoError(t, err, "unable to create channel edge")
+	fundingBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{fundingTx},
+	}
+	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+
+	edge := &models.ChannelEdgeInfo{
+		ChannelID:        chanID.ToUint64(),
+		NodeKey1Bytes:    pub1,
+		NodeKey2Bytes:    pub2,
+		BitcoinKey1Bytes: pub1,
+		BitcoinKey2Bytes: pub2,
+		AuthProof:        nil,
+	}
+	edgePolicy := &models.ChannelEdgePolicy{
+		SigBytes:                  testSig.Serialize(),
+		ChannelID:                 edge.ChannelID,
+		LastUpdate:                testTime,
+		TimeLockDelta:             10,
+		MinHTLC:                   1,
+		FeeBaseMSat:               10,
+		FeeProportionalMillionths: 10000,
+	}
+
+	// Attempt to update the edge. This should be ignored, since the edge
+	// is not yet added to the router.
+	err = ctx.builder.UpdateEdge(edgePolicy)
+	if !IsError(err, ErrIgnored) {
+		t.Fatalf("expected to get ErrIgnore, instead got: %v", err)
+	}
+
+	// Add the edge.
+	require.NoErrorf(t, ctx.builder.AddEdge(edge), "expected to be able "+
+		"to add edge to the channel graph, even though the vertexes "+
+		"were unknown")
+
+	// Now updating the edge policy should succeed.
+	require.NoError(t, ctx.builder.UpdateEdge(edgePolicy))
+}
+
+// TestWakeUpOnStaleBranch tests that upon startup of the ChannelRouter, if
+// the chain previously reflected in the channel graph is stale (overtaken by
+// a longer chain), the channel router will prune the graph for any channels
+// confirmed on the stale chain, and resync to the main chain.
+func TestWakeUpOnStaleBranch(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxSingleNode(t, startingBlockHeight)
+
+	const chanValue = 10000
+
+	// chanID1 will not be reorged out.
+	var chanID1 uint64
+
+	// chanID2 will be reorged out.
+	var chanID2 uint64
+
+	// Create 10 common blocks, confirming chanID1.
+	for i := uint32(1); i <= 10; i++ {
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		height := startingBlockHeight + i
+		if i == 5 {
+			fundingTx, _, chanID, err := createChannelEdge(ctx,
+				bitcoinKey1.SerializeCompressed(),
+				bitcoinKey2.SerializeCompressed(),
+				chanValue, height)
+			if err != nil {
+				t.Fatalf("unable to create channel edge: %v",
+					err)
+			}
+			block.Transactions = append(block.Transactions,
+				fundingTx)
+			chanID1 = chanID.ToUint64()
+		}
+		ctx.chain.addBlock(block, height, rand.Uint32())
+		ctx.chain.setBestBlock(int32(height))
+		ctx.chainView.notifyBlock(block.BlockHash(), height,
+			[]*wire.MsgTx{}, t)
+	}
+
+	// Give time to process new blocks.
+	time.Sleep(time.Millisecond * 500)
+
+	_, forkHeight, err := ctx.chain.GetBestBlock()
+	require.NoError(t, err, "unable to get best block")
+
+	// Create 10 blocks on the minority chain, confirming chanID2.
+	for i := uint32(1); i <= 10; i++ {
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		height := uint32(forkHeight) + i
+		if i == 5 {
+			fundingTx, _, chanID, err := createChannelEdge(ctx,
+				bitcoinKey1.SerializeCompressed(),
+				bitcoinKey2.SerializeCompressed(),
+				chanValue, height)
+			if err != nil {
+				t.Fatalf("unable to create channel edge: %v",
+					err)
+			}
+			block.Transactions = append(block.Transactions,
+				fundingTx)
+			chanID2 = chanID.ToUint64()
+		}
+		ctx.chain.addBlock(block, height, rand.Uint32())
+		ctx.chain.setBestBlock(int32(height))
+		ctx.chainView.notifyBlock(block.BlockHash(), height,
+			[]*wire.MsgTx{}, t)
+	}
+
+	// Give time to process new blocks.
+	time.Sleep(time.Millisecond * 500)
+
+	// Now add the two edges to the channel graph, and check that they
+	// correctly show up in the database.
+	node1 := createTestNode(t)
+	node2 := createTestNode(t)
+
+	edge1 := &models.ChannelEdgeInfo{
+		ChannelID:     chanID1,
+		NodeKey1Bytes: node1.PubKeyBytes,
+		NodeKey2Bytes: node2.PubKeyBytes,
+		AuthProof: &models.ChannelAuthProof{
+			NodeSig1Bytes:    testSig.Serialize(),
+			NodeSig2Bytes:    testSig.Serialize(),
+			BitcoinSig1Bytes: testSig.Serialize(),
+			BitcoinSig2Bytes: testSig.Serialize(),
+		},
+	}
+	copy(edge1.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge1.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+
+	if err := ctx.builder.AddEdge(edge1); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	edge2 := &models.ChannelEdgeInfo{
+		ChannelID:     chanID2,
+		NodeKey1Bytes: node1.PubKeyBytes,
+		NodeKey2Bytes: node2.PubKeyBytes,
+		AuthProof: &models.ChannelAuthProof{
+			NodeSig1Bytes:    testSig.Serialize(),
+			NodeSig2Bytes:    testSig.Serialize(),
+			BitcoinSig1Bytes: testSig.Serialize(),
+			BitcoinSig2Bytes: testSig.Serialize(),
+		},
+	}
+	copy(edge2.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge2.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+
+	if err := ctx.builder.AddEdge(edge2); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	// Check that the fundingTxs are in the graph db.
+	_, _, has, isZombie, err := ctx.graph.HasChannelEdge(chanID1)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID1)
+	}
+	if !has {
+		t.Fatalf("could not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	_, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID2)
+	}
+	if !has {
+		t.Fatalf("could not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	// Stop the router, so we can reorg the chain while it's offline.
+	if err := ctx.builder.Stop(); err != nil {
+		t.Fatalf("unable to stop router: %v", err)
+	}
+
+	// Create a 15 block fork.
+	for i := uint32(1); i <= 15; i++ {
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		height := uint32(forkHeight) + i
+		ctx.chain.addBlock(block, height, rand.Uint32())
+		ctx.chain.setBestBlock(int32(height))
+	}
+
+	// Give time to process new blocks.
+	time.Sleep(time.Millisecond * 500)
+
+	selfNode, err := ctx.graph.SourceNode()
+	require.NoError(t, err)
+
+	// Create a new router with the same graph database.
+	router, err := NewBuilder(&Config{
+		SelfNode:           selfNode.PubKeyBytes,
+		Graph:              ctx.graph,
+		Chain:              ctx.chain,
+		ChainView:          ctx.chainView,
+		ChannelPruneExpiry: time.Hour * 24,
+		GraphPruneInterval: time.Hour * 2,
+
+		// We'll set the delay to zero to prune immediately.
+		FirstTimePruneDelay: 0,
+		IsAlias: func(scid lnwire.ShortChannelID) bool {
+			return false
+		},
+	})
+	require.NoError(t, err)
+
+	// It should resync to the longer chain on startup.
+	if err := router.Start(); err != nil {
+		t.Fatalf("unable to start router: %v", err)
+	}
+
+	// The channel with chanID2 should not be in the database anymore,
+	// since it is not confirmed on the longest chain. chanID1 should
+	// still be.
+	_, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID1)
+	require.NoError(t, err)
+
+	if !has {
+		t.Fatalf("did not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	_, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID2)
+	}
+	if has {
+		t.Fatalf("found edge in graph")
+	}
+	if isZombie {
+		t.Fatal("reorged edge should not be marked as zombie")
+	}
+}
+
+// TestDisconnectedBlocks checks that the router handles a reorg happening when
+// it is active.
+func TestDisconnectedBlocks(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxSingleNode(t, startingBlockHeight)
+
+	const chanValue = 10000
+
+	// chanID1 will not be reorged out, while chanID2 will be reorged out.
+	var chanID1, chanID2 uint64
+
+	// Create 10 common blocks, confirming chanID1.
+	for i := uint32(1); i <= 10; i++ {
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		height := startingBlockHeight + i
+		if i == 5 {
+			fundingTx, _, chanID, err := createChannelEdge(ctx,
+				bitcoinKey1.SerializeCompressed(),
+				bitcoinKey2.SerializeCompressed(),
+				chanValue, height)
+			if err != nil {
+				t.Fatalf("unable to create channel edge: %v",
+					err)
+			}
+			block.Transactions = append(block.Transactions,
+				fundingTx)
+			chanID1 = chanID.ToUint64()
+		}
+		ctx.chain.addBlock(block, height, rand.Uint32())
+		ctx.chain.setBestBlock(int32(height))
+		ctx.chainView.notifyBlock(block.BlockHash(), height,
+			[]*wire.MsgTx{}, t)
+	}
+
+	// Give time to process new blocks.
+	time.Sleep(time.Millisecond * 500)
+
+	_, forkHeight, err := ctx.chain.GetBestBlock()
+	require.NoError(t, err, "unable to get best block")
+
+	// Create 10 blocks on the minority chain, confirming chanID2.
+	var minorityChain []*wire.MsgBlock
+	for i := uint32(1); i <= 10; i++ {
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		height := uint32(forkHeight) + i
+		if i == 5 {
+			fundingTx, _, chanID, err := createChannelEdge(ctx,
+				bitcoinKey1.SerializeCompressed(),
+				bitcoinKey2.SerializeCompressed(),
+				chanValue, height)
+			if err != nil {
+				t.Fatalf("unable to create channel edge: %v",
+					err)
+			}
+			block.Transactions = append(block.Transactions,
+				fundingTx)
+			chanID2 = chanID.ToUint64()
+		}
+		minorityChain = append(minorityChain, block)
+		ctx.chain.addBlock(block, height, rand.Uint32())
+		ctx.chain.setBestBlock(int32(height))
+		ctx.chainView.notifyBlock(block.BlockHash(), height,
+			[]*wire.MsgTx{}, t)
+	}
+
+	// Give time to process new blocks.
+	time.Sleep(time.Millisecond * 500)
+
+	// Now add the two edges to the channel graph, and check that they
+	// correctly show up in the database.
+	node1 := createTestNode(t)
+	node2 := createTestNode(t)
+
+	edge1 := &models.ChannelEdgeInfo{
+		ChannelID:        chanID1,
+		NodeKey1Bytes:    node1.PubKeyBytes,
+		NodeKey2Bytes:    node2.PubKeyBytes,
+		BitcoinKey1Bytes: node1.PubKeyBytes,
+		BitcoinKey2Bytes: node2.PubKeyBytes,
+		AuthProof: &models.ChannelAuthProof{
+			NodeSig1Bytes:    testSig.Serialize(),
+			NodeSig2Bytes:    testSig.Serialize(),
+			BitcoinSig1Bytes: testSig.Serialize(),
+			BitcoinSig2Bytes: testSig.Serialize(),
+		},
+	}
+	copy(edge1.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge1.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+
+	if err := ctx.builder.AddEdge(edge1); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	edge2 := &models.ChannelEdgeInfo{
+		ChannelID:        chanID2,
+		NodeKey1Bytes:    node1.PubKeyBytes,
+		NodeKey2Bytes:    node2.PubKeyBytes,
+		BitcoinKey1Bytes: node1.PubKeyBytes,
+		BitcoinKey2Bytes: node2.PubKeyBytes,
+		AuthProof: &models.ChannelAuthProof{
+			NodeSig1Bytes:    testSig.Serialize(),
+			NodeSig2Bytes:    testSig.Serialize(),
+			BitcoinSig1Bytes: testSig.Serialize(),
+			BitcoinSig2Bytes: testSig.Serialize(),
+		},
+	}
+	copy(edge2.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge2.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+
+	if err := ctx.builder.AddEdge(edge2); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	// Check that the fundingTxs are in the graph db.
+	_, _, has, isZombie, err := ctx.graph.HasChannelEdge(chanID1)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID1)
+	}
+	if !has {
+		t.Fatalf("could not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	_, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID2)
+	}
+	if !has {
+		t.Fatalf("could not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	// Create a 15 block fork. We first let the chainView notify the router
+	// about stale blocks, before sending the now connected blocks. We do
+	// this because we expect this order from the chainview.
+	ctx.chainView.notifyStaleBlockAck = make(chan struct{}, 1)
+	for i := len(minorityChain) - 1; i >= 0; i-- {
+		block := minorityChain[i]
+		height := uint32(forkHeight) + uint32(i) + 1
+		ctx.chainView.notifyStaleBlock(block.BlockHash(), height,
+			block.Transactions, t)
+		<-ctx.chainView.notifyStaleBlockAck
+	}
+
+	time.Sleep(time.Second * 2)
+
+	ctx.chainView.notifyBlockAck = make(chan struct{}, 1)
+	for i := uint32(1); i <= 15; i++ {
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		height := uint32(forkHeight) + i
+		ctx.chain.addBlock(block, height, rand.Uint32())
+		ctx.chain.setBestBlock(int32(height))
+		ctx.chainView.notifyBlock(block.BlockHash(), height,
+			block.Transactions, t)
+		<-ctx.chainView.notifyBlockAck
+	}
+
+	time.Sleep(time.Millisecond * 500)
+
+	// chanID2 should not be in the database anymore, since it is not
+	// confirmed on the longest chain. chanID1 should still be.
+	_, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID1)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID1)
+	}
+	if !has {
+		t.Fatalf("did not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	_, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID2)
+	}
+	if has {
+		t.Fatalf("found edge in graph")
+	}
+	if isZombie {
+		t.Fatal("reorged edge should not be marked as zombie")
+	}
+}
+
+// TestRouterChansClosedOfflinePruneGraph tests that if channels we know of
+// are closed while we're offline, then once we resume operation of the
+// ChannelRouter, the channels are properly pruned.
+func TestRouterChansClosedOfflinePruneGraph(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxSingleNode(t, startingBlockHeight)
+
+	const chanValue = 10000
+
+	// First, we'll create a channel, to be mined shortly at height 102.
+	block102 := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{},
+	}
+	nextHeight := startingBlockHeight + 1
+	fundingTx1, chanUTXO, chanID1, err := createChannelEdge(ctx,
+		bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(),
+		chanValue, uint32(nextHeight))
+	require.NoError(t, err, "unable to create channel edge")
+	block102.Transactions = append(block102.Transactions, fundingTx1)
+	ctx.chain.addBlock(block102, uint32(nextHeight), rand.Uint32())
+	ctx.chain.setBestBlock(int32(nextHeight))
+	ctx.chainView.notifyBlock(block102.BlockHash(), uint32(nextHeight),
+		[]*wire.MsgTx{}, t)
+
+	// We'll now create the edges and nodes within the database required
+	// for the ChannelRouter to properly recognize the channel we added
+	// above.
+	node1 := createTestNode(t)
+	node2 := createTestNode(t)
+
+	edge1 := &models.ChannelEdgeInfo{
+		ChannelID:     chanID1.ToUint64(),
+		NodeKey1Bytes: node1.PubKeyBytes,
+		NodeKey2Bytes: node2.PubKeyBytes,
+		AuthProof: &models.ChannelAuthProof{
+			NodeSig1Bytes:    testSig.Serialize(),
+			NodeSig2Bytes:    testSig.Serialize(),
+			BitcoinSig1Bytes: testSig.Serialize(),
+			BitcoinSig2Bytes: testSig.Serialize(),
+		},
+	}
+	copy(edge1.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge1.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+	if err := ctx.builder.AddEdge(edge1); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	// The router should now be aware of the channel we created above.
+	_, _, hasChan, isZombie, err := ctx.graph.HasChannelEdge(
+		chanID1.ToUint64(),
+	)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID1)
+	}
+	if !hasChan {
+		t.Fatalf("could not find edge in graph")
+	}
+	if isZombie {
+		t.Fatal("edge was marked as zombie")
+	}
+
+	// With the transaction included, and the router's database state
+	// updated, we'll now mine 5 additional blocks on top of it.
+	for i := 0; i < 5; i++ {
+		nextHeight++
+
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+		ctx.chain.addBlock(block, uint32(nextHeight), rand.Uint32())
+		ctx.chain.setBestBlock(int32(nextHeight))
+		ctx.chainView.notifyBlock(block.BlockHash(), uint32(nextHeight),
+			[]*wire.MsgTx{}, t)
+	}
+
+	// At this point, the chain height should be 107.
+	_, chainHeight, err := ctx.chain.GetBestBlock()
+	require.NoError(t, err, "unable to get best block")
+	if chainHeight != 107 {
+		t.Fatalf("incorrect chain height: expected %v, got %v",
+			107, chainHeight)
+	}
+
+	// Next, we'll "shut down" the router in order to simulate downtime.
+	if err := ctx.builder.Stop(); err != nil {
+		t.Fatalf("unable to shutdown router: %v", err)
+	}
+
+	// While the router is "offline" we'll mine 5 additional blocks, with
+	// the third block closing the channel we created above.
+	for i := 0; i < 5; i++ {
+		nextHeight++
+
+		block := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+		}
+
+		if i == 2 {
+			// For the third block, we'll add a transaction that
+			// closes the channel we created above by spending the
+			// output.
+			closingTx := wire.NewMsgTx(2)
+			closingTx.AddTxIn(&wire.TxIn{
+				PreviousOutPoint: *chanUTXO,
+			})
+			block.Transactions = append(block.Transactions,
+				closingTx)
+		}
+
+		ctx.chain.addBlock(block, uint32(nextHeight), rand.Uint32())
+		ctx.chain.setBestBlock(int32(nextHeight))
+		ctx.chainView.notifyBlock(block.BlockHash(), uint32(nextHeight),
+			[]*wire.MsgTx{}, t)
+	}
+
+	// At this point, the chain height should be 112.
+	_, chainHeight, err = ctx.chain.GetBestBlock()
+	require.NoError(t, err, "unable to get best block")
+	if chainHeight != 112 {
+		t.Fatalf("incorrect chain height: expected %v, got %v",
+			112, chainHeight)
+	}
+
+	// Now we'll re-start the ChannelRouter. It should recognize that it's
+	// behind the main chain and prune all the blocks that it missed while
+	// it was down.
+	ctx.RestartBuilder(t)
+
+	// At this point, the channel that was pruned should no longer be known
+	// by the router.
+	_, _, hasChan, isZombie, err = ctx.graph.HasChannelEdge(
+		chanID1.ToUint64(),
+	)
+	if err != nil {
+		t.Fatalf("error looking for edge: %v", chanID1)
+	}
+	if hasChan {
+		t.Fatalf("channel was found in graph but shouldn't have been")
+	}
+	if isZombie {
+		t.Fatal("closed channel should not be marked as zombie")
+	}
+}
+
+// TestPruneChannelGraphStaleEdges ensures that we properly prune stale edges
+// from the channel graph.
+func TestPruneChannelGraphStaleEdges(t *testing.T) {
+	t.Parallel()
+
+	freshTimestamp := time.Now()
+	staleTimestamp := time.Unix(0, 0)
+
+	// We'll create the following test graph so that two of the channels
+	// are pruned.
+	testChannels := []*testChannel{
+		// No edges.
+		{
+			Node1:     &testChannelEnd{Alias: "a"},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 1,
+		},
+
+		// Only one edge with a stale timestamp.
+		{
+			Node1: &testChannelEnd{
+				Alias: "d",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: staleTimestamp,
+				},
+			},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 2,
+		},
+
+		// Only one edge with a stale timestamp, but it's the source
+		// node so it won't get pruned.
+		{
+			Node1: &testChannelEnd{
+				Alias: "a",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: staleTimestamp,
+				},
+			},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 3,
+		},
+
+		// Only one edge with a fresh timestamp.
+		{
+			Node1: &testChannelEnd{
+				Alias: "a",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: freshTimestamp,
+				},
+			},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 4,
+		},
+
+		// One edge fresh, one edge stale. This will be pruned with
+		// strict pruning activated.
+		{
+			Node1: &testChannelEnd{
+				Alias: "c",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: freshTimestamp,
+				},
+			},
+			Node2: &testChannelEnd{
+				Alias: "d",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: staleTimestamp,
+				},
+			},
+			Capacity:  100000,
+			ChannelID: 5,
+		},
+
+		// Both edges fresh.
+		symmetricTestChannel("g", "h", 100000, &testChannelPolicy{
+			LastUpdate: freshTimestamp,
+		}, 6),
+
+		// Both edges stale, only one pruned. This should be pruned for
+		// both normal and strict pruning.
+		symmetricTestChannel("e", "f", 100000, &testChannelPolicy{
+			LastUpdate: staleTimestamp,
+		}, 7),
+	}
+
+	for _, strictPruning := range []bool{true, false} {
+		// We'll create our test graph and router backed with these
+		// test channels we've created.
+		testGraph, err := createTestGraphFromChannels(
+			t, true, testChannels, "a",
+		)
+		if err != nil {
+			t.Fatalf("unable to create test graph: %v", err)
+		}
+
+		const startingHeight = 100
+		ctx := createTestCtxFromGraphInstance(
+			t, startingHeight, testGraph, strictPruning,
+		)
+
+		// All of the channels should exist before pruning them.
+		assertChannelsPruned(t, ctx.graph, testChannels)
+
+		// Proceed to prune the stale channels.
+		if err := ctx.builder.pruneZombieChans(); err != nil {
+			t.Fatalf("unable to prune zombie channels: %v", err)
+		}
+
+		// We expect to prune channels that have both edges stale or,
+		// under strict pruning, channels where one of the known edges
+		// is stale.
+		var prunedChannels []uint64
+		if strictPruning {
+			prunedChannels = []uint64{2, 5, 7}
+		} else {
+			prunedChannels = []uint64{2, 7}
+		}
+		assertChannelsPruned(t, ctx.graph, testChannels, prunedChannels...)
+	}
+}
+
+// TestPruneChannelGraphDoubleDisabled tests that we can properly prune
+// channels with both edges disabled from our channel graph.
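+//
+// The rationale (as also noted in IsStaleEdgePolicy): when running with
+// AssumeChannelValid we cannot rely on the spentness of the funding output
+// to detect closed channels, so a channel with both edges disabled is
+// treated as closed and pruned.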
+func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
+	t.Parallel()
+
+	t.Run("no_assumechannelvalid", func(t *testing.T) {
+		testPruneChannelGraphDoubleDisabled(t, false)
+	})
+	t.Run("assumechannelvalid", func(t *testing.T) {
+		testPruneChannelGraphDoubleDisabled(t, true)
+	})
+}
+
+func testPruneChannelGraphDoubleDisabled(t *testing.T, assumeValid bool) {
+	// We'll create the following test graph so that only the last channel
+	// is pruned. We'll use a fresh timestamp to ensure they're not pruned
+	// according to that heuristic.
+	timestamp := time.Now()
+	testChannels := []*testChannel{
+		// Channel from self shouldn't be pruned.
+		symmetricTestChannel(
+			"self", "a", 100000, &testChannelPolicy{
+				LastUpdate: timestamp,
+				Disabled:   true,
+			}, 99,
+		),
+
+		// No edges.
+		{
+			Node1:     &testChannelEnd{Alias: "a"},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 1,
+		},
+
+		// Only one edge disabled.
+		{
+			Node1: &testChannelEnd{
+				Alias: "a",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: timestamp,
+					Disabled:   true,
+				},
+			},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 2,
+		},
+
+		// Only one edge enabled.
+		{
+			Node1: &testChannelEnd{
+				Alias: "a",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: timestamp,
+					Disabled:   false,
+				},
+			},
+			Node2:     &testChannelEnd{Alias: "b"},
+			Capacity:  100000,
+			ChannelID: 3,
+		},
+
+		// One edge disabled, one edge enabled.
+		{
+			Node1: &testChannelEnd{
+				Alias: "a",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: timestamp,
+					Disabled:   true,
+				},
+			},
+			Node2: &testChannelEnd{
+				Alias: "b",
+				testChannelPolicy: &testChannelPolicy{
+					LastUpdate: timestamp,
+					Disabled:   false,
+				},
+			},
+			Capacity:  100000,
+			ChannelID: 4,
+		},
+
+		// Both edges enabled.
+		symmetricTestChannel("c", "d", 100000, &testChannelPolicy{
+			LastUpdate: timestamp,
+			Disabled:   false,
+		}, 5),
+
+		// Both edges disabled, only one pruned.
+		symmetricTestChannel("e", "f", 100000, &testChannelPolicy{
+			LastUpdate: timestamp,
+			Disabled:   true,
+		}, 6),
+	}
+
+	// We'll create our test graph and router backed with these test
+	// channels we've created.
+	testGraph, err := createTestGraphFromChannels(
+		t, true, testChannels, "self",
+	)
+	require.NoError(t, err, "unable to create test graph")
+
+	const startingHeight = 100
+	ctx := createTestCtxFromGraphInstanceAssumeValid(
+		t, startingHeight, testGraph, assumeValid, false,
+	)
+
+	// When not using AssumeChannelValid, all the channels should exist
+	// within the graph before pruning them. Otherwise, we should have
+	// already pruned the last channel on startup.
+	if !assumeValid {
+		assertChannelsPruned(t, ctx.graph, testChannels)
+	} else {
+		// Sleep to allow the pruning to finish.
+		time.Sleep(200 * time.Millisecond)
+
+		prunedChannel := testChannels[len(testChannels)-1].ChannelID
+		assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
+	}
+
+	if err := ctx.builder.pruneZombieChans(); err != nil {
+		t.Fatalf("unable to prune zombie channels: %v", err)
+	}
+
+	// If we attempted to prune them without AssumeChannelValid being set,
+	// none should be pruned. Otherwise the last channel should still be
+	// pruned.
+	if !assumeValid {
+		assertChannelsPruned(t, ctx.graph, testChannels)
+	} else {
+		prunedChannel := testChannels[len(testChannels)-1].ChannelID
+		assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
+	}
+}
+
+// TestIsStaleNode tests that the IsStaleNode method properly detects stale
+// node announcements.
+func TestIsStaleNode(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxSingleNode(t, startingBlockHeight)
+
+	// Before we can insert a node into the database, we need to create a
+	// channel that it's linked to.
+	var (
+		pub1 [33]byte
+		pub2 [33]byte
+	)
+	copy(pub1[:], priv1.PubKey().SerializeCompressed())
+	copy(pub2[:], priv2.PubKey().SerializeCompressed())
+
+	fundingTx, _, chanID, err := createChannelEdge(ctx,
+		bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(),
+		10000, 500)
+	require.NoError(t, err, "unable to create channel edge")
+	fundingBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{fundingTx},
+	}
+	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+
+	edge := &models.ChannelEdgeInfo{
+		ChannelID:        chanID.ToUint64(),
+		NodeKey1Bytes:    pub1,
+		NodeKey2Bytes:    pub2,
+		BitcoinKey1Bytes: pub1,
+		BitcoinKey2Bytes: pub2,
+		AuthProof:        nil,
+	}
+	if err := ctx.builder.AddEdge(edge); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	// Before we add the node, if we query for staleness, we should get
+	// false, as we haven't added the full node.
+	updateTimeStamp := time.Unix(123, 0)
+	if ctx.builder.IsStaleNode(pub1, updateTimeStamp) {
+		t.Fatalf("incorrectly detected node as stale")
+	}
+
+	// With the node stub in the database, we'll add the full node
+	// announcement to the database.
+	n1 := &channeldb.LightningNode{
+		HaveNodeAnnouncement: true,
+		LastUpdate:           updateTimeStamp,
+		Addresses:            testAddrs,
+		Color:                color.RGBA{1, 2, 3, 0},
+		Alias:                "node11",
+		AuthSigBytes:         testSig.Serialize(),
+		Features:             testFeatures,
+	}
+	copy(n1.PubKeyBytes[:], priv1.PubKey().SerializeCompressed())
+	if err := ctx.builder.AddNode(n1); err != nil {
+		t.Fatalf("could not add node: %v", err)
+	}
+
+	// If we use the same timestamp and query for staleness, we should get
+	// true.
+	if !ctx.builder.IsStaleNode(pub1, updateTimeStamp) {
+		t.Fatalf("failure to detect stale node update")
+	}
+
+	// If we update the timestamp and once again query for staleness, it
+	// should report false.
+	newTimeStamp := time.Unix(1234, 0)
+	if ctx.builder.IsStaleNode(pub1, newTimeStamp) {
+		t.Fatalf("incorrectly detected node as stale")
+	}
+}
+
+// TestIsKnownEdge tests that the IsKnownEdge method properly detects
+// already-known channel announcements.
+func TestIsKnownEdge(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxSingleNode(t, startingBlockHeight)
+
+	// First, we'll create a new channel edge (just the info) and insert it
+	// into the database.
+	var (
+		pub1 [33]byte
+		pub2 [33]byte
+	)
+	copy(pub1[:], priv1.PubKey().SerializeCompressed())
+	copy(pub2[:], priv2.PubKey().SerializeCompressed())
+
+	fundingTx, _, chanID, err := createChannelEdge(ctx,
+		bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(),
+		10000, 500)
+	require.NoError(t, err, "unable to create channel edge")
+	fundingBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{fundingTx},
+	}
+	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+
+	edge := &models.ChannelEdgeInfo{
+		ChannelID:        chanID.ToUint64(),
+		NodeKey1Bytes:    pub1,
+		NodeKey2Bytes:    pub2,
+		BitcoinKey1Bytes: pub1,
+		BitcoinKey2Bytes: pub2,
+		AuthProof:        nil,
+	}
+	if err := ctx.builder.AddEdge(edge); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	// Now that the edge has been inserted, querying whether the router
+	// already knows of the edge should return true.
+	if !ctx.builder.IsKnownEdge(*chanID) {
+		t.Fatalf("router should detect edge as known")
+	}
+}
+
+// TestIsStaleEdgePolicy tests that the IsStaleEdgePolicy method properly
+// detects stale channel edge update announcements.
+func TestIsStaleEdgePolicy(t *testing.T) {
+	t.Parallel()
+
+	const startingBlockHeight = 101
+	ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath)
+
+	// First, we'll create a new channel edge (just the info) and insert it
+	// into the database.
+	var (
+		pub1 [33]byte
+		pub2 [33]byte
+	)
+	copy(pub1[:], priv1.PubKey().SerializeCompressed())
+	copy(pub2[:], priv2.PubKey().SerializeCompressed())
+
+	fundingTx, _, chanID, err := createChannelEdge(ctx,
+		bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(),
+		10000, 500)
+	require.NoError(t, err, "unable to create channel edge")
+	fundingBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{fundingTx},
+	}
+	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+
+	// If we query for staleness before adding the edge, we should get
+	// false.
+	updateTimeStamp := time.Unix(123, 0)
+	if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) {
+		t.Fatalf("router failed to detect fresh edge policy")
+	}
+	if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) {
+		t.Fatalf("router failed to detect fresh edge policy")
+	}
+
+	edge := &models.ChannelEdgeInfo{
+		ChannelID:        chanID.ToUint64(),
+		NodeKey1Bytes:    pub1,
+		NodeKey2Bytes:    pub2,
+		BitcoinKey1Bytes: pub1,
+		BitcoinKey2Bytes: pub2,
+		AuthProof:        nil,
+	}
+	if err := ctx.builder.AddEdge(edge); err != nil {
+		t.Fatalf("unable to add edge: %v", err)
+	}
+
+	// We'll also add two edge policies, one for each direction.
+	edgePolicy := &models.ChannelEdgePolicy{
+		SigBytes:                  testSig.Serialize(),
+		ChannelID:                 edge.ChannelID,
+		LastUpdate:                updateTimeStamp,
+		TimeLockDelta:             10,
+		MinHTLC:                   1,
+		FeeBaseMSat:               10,
+		FeeProportionalMillionths: 10000,
+	}
+	edgePolicy.ChannelFlags = 0
+	if err := ctx.builder.UpdateEdge(edgePolicy); err != nil {
+		t.Fatalf("unable to update edge policy: %v", err)
+	}
+
+	edgePolicy = &models.ChannelEdgePolicy{
+		SigBytes:                  testSig.Serialize(),
+		ChannelID:                 edge.ChannelID,
+		LastUpdate:                updateTimeStamp,
+		TimeLockDelta:             10,
+		MinHTLC:                   1,
+		FeeBaseMSat:               10,
+		FeeProportionalMillionths: 10000,
+	}
+	edgePolicy.ChannelFlags = 1
+	if err := ctx.builder.UpdateEdge(edgePolicy); err != nil {
+		t.Fatalf("unable to update edge policy: %v", err)
+	}
+
+	// Now that the edges have been added, an identical (chanID, flag,
+	// timestamp) tuple for each edge should be detected as a stale edge.
+	if !ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) {
+		t.Fatalf("router failed to detect stale edge policy")
+	}
+	if !ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) {
+		t.Fatalf("router failed to detect stale edge policy")
+	}
+
+	// If we now update the timestamp for both edges, the router should
+	// detect that this tuple represents a fresh edge.
+	updateTimeStamp = time.Unix(9999, 0)
+	if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) {
+		t.Fatalf("router failed to detect fresh edge policy")
+	}
+	if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) {
+		t.Fatalf("router failed to detect fresh edge policy")
+	}
+}
+
+// edgeCreationModifier is an enum-like type used to skip certain steps when
+// creating a channel in the test context.
+type edgeCreationModifier uint8
+
+const (
+	// edgeCreationNoFundingTx is used to skip adding the funding
+	// transaction of an edge to the chain.
+	edgeCreationNoFundingTx edgeCreationModifier = iota
+
+	// edgeCreationNoUTXO is used to skip adding the UTXO of a channel to
+	// the UTXO set.
+	edgeCreationNoUTXO
+
+	// edgeCreationBadScript is used to create the edge, but use the wrong
+	// script, which should cause it to fail output validation.
+	edgeCreationBadScript
+)
+
+// newChannelEdgeInfo is a helper function used to create a new channel edge,
+// possibly skipping adding it to parts of the chain/state as well.
+func newChannelEdgeInfo(t *testing.T, ctx *testCtx, fundingHeight uint32,
+	ecm edgeCreationModifier) (*models.ChannelEdgeInfo, error) {
+
+	node1 := createTestNode(t)
+	node2 := createTestNode(t)
+
+	fundingTx, _, chanID, err := createChannelEdge(
+		ctx, bitcoinKey1.SerializeCompressed(),
+		bitcoinKey2.SerializeCompressed(), 100, fundingHeight,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("unable to create edge: %w", err)
+	}
+
+	edge := &models.ChannelEdgeInfo{
+		ChannelID:     chanID.ToUint64(),
+		NodeKey1Bytes: node1.PubKeyBytes,
+		NodeKey2Bytes: node2.PubKeyBytes,
+	}
+	copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
+	copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+
+	if ecm == edgeCreationNoFundingTx {
+		return edge, nil
+	}
+
+	fundingBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{fundingTx},
+	}
+	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+
+	if ecm == edgeCreationNoUTXO {
+		ctx.chain.delUtxo(wire.OutPoint{
+			Hash: fundingTx.TxHash(),
+		})
+	}
+
+	if ecm == edgeCreationBadScript {
+		fundingTx.TxOut[0].PkScript[0] ^= 1
+	}
+
+	return edge, nil
+}
+
+func assertChanChainRejection(t *testing.T, ctx *testCtx,
+	edge *models.ChannelEdgeInfo, failCode errorCode) {
+
+	t.Helper()
+
+	err := ctx.builder.AddEdge(edge)
+	if !IsError(err, failCode) {
+		t.Fatalf("validation should have failed: %v", err)
+	}
+
+	// This channel should now be present in the zombie channel index.
+	_, _, _, isZombie, err := ctx.graph.HasChannelEdge(
+		edge.ChannelID,
+	)
+	require.Nil(t, err)
+	require.True(t, isZombie, "edge should be marked as zombie")
+}
+
+// TestChannelOnChainRejectionZombie tests that if we fail validating a channel
+// due to some sort of on-chain rejection (no funding transaction, or invalid
+// UTXO), then we'll mark the channel as a zombie.
+func TestChannelOnChainRejectionZombie(t *testing.T) {
+	t.Parallel()
+
+	ctx := createTestCtxSingleNode(t, 0)
+
+	// To start, we'll make an edge for the channel, but we won't add the
+	// funding transaction to the mock blockchain, which should cause the
+	// validation to fail below.
+	edge, err := newChannelEdgeInfo(t, ctx, 1, edgeCreationNoFundingTx)
+	require.Nil(t, err)
+
+	// We expect this to fail as the transaction isn't present in the
+	// chain (nor the block).
+	assertChanChainRejection(t, ctx, edge, ErrNoFundingTransaction)
+
+	// Next, we'll make another channel edge, but this time we'll actually
+	// add its funding transaction to the chain.
+	edge, err = newChannelEdgeInfo(t, ctx, 2, edgeCreationNoUTXO)
+	require.Nil(t, err)
+
+	// We'll then remove the funding output from the set of UTXOs, which
+	// should cause the spentness validation to fail.
+	assertChanChainRejection(t, ctx, edge, ErrChannelSpent)
+
+	// If we cause the funding transaction in the chain to fail validation,
+	// we should see similar behavior.
+	edge, err = newChannelEdgeInfo(t, ctx, 3, edgeCreationBadScript)
+	require.Nil(t, err)
+	assertChanChainRejection(t, ctx, edge, ErrInvalidFundingOutput)
+}
+
+// TestBlockDifferenceFix tests that when the router is behind on blocks, it
+// catches up to the best block head.
+func TestBlockDifferenceFix(t *testing.T) {
+	t.Parallel()
+
+	initialBlockHeight := uint32(0)
+
+	// Starting height here is set to 0, which is behind where we want to
+	// be.
+	ctx := createTestCtxSingleNode(t, initialBlockHeight)
+
+	// Add initial block to our mini blockchain.
+	block := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{},
+	}
+	ctx.chain.addBlock(block, initialBlockHeight, rand.Uint32())
+
+	// Let's generate a new block at height 5, five blocks above where our
+	// node is at.
+	newBlock := &wire.MsgBlock{
+		Transactions: []*wire.MsgTx{},
+	}
+	newBlockHeight := uint32(5)
+
+	blockDifference := newBlockHeight - initialBlockHeight
+
+	ctx.chainView.notifyBlockAck = make(chan struct{}, 1)
+
+	ctx.chain.addBlock(newBlock, newBlockHeight, rand.Uint32())
+	ctx.chain.setBestBlock(int32(newBlockHeight))
+	ctx.chainView.notifyBlock(block.BlockHash(), newBlockHeight,
+		[]*wire.MsgTx{}, t)
+
+	<-ctx.chainView.notifyBlockAck
+
+	// At this point, the chain notifier should have noticed that we're
+	// behind on blocks, and will send the n missing blocks that we
+	// need to the client's epochs channel. Let's replicate this
+	// functionality.
+	for i := 0; i < int(blockDifference); i++ {
+		currBlockHeight := int32(i + 1)
+
+		nonce := rand.Uint32()
+
+		newBlock := &wire.MsgBlock{
+			Transactions: []*wire.MsgTx{},
+			Header:       wire.BlockHeader{Nonce: nonce},
+		}
+		ctx.chain.addBlock(newBlock, uint32(currBlockHeight), nonce)
+		currHash := newBlock.Header.BlockHash()
+
+		newEpoch := &chainntnfs.BlockEpoch{
+			Height: currBlockHeight,
+			Hash:   &currHash,
+		}
+
+		ctx.notifier.EpochChan <- newEpoch
+
+		ctx.chainView.notifyBlock(currHash,
+			uint32(currBlockHeight), block.Transactions, t)
+
+		<-ctx.chainView.notifyBlockAck
+	}
+
+	err := wait.NoError(func() error {
+		// Then router height should be updated to the latest block.
+		if atomic.LoadUint32(&ctx.builder.bestHeight) != newBlockHeight {
+			return fmt.Errorf("height should have been updated "+
+				"to %v, instead got %v", newBlockHeight,
+				ctx.builder.bestHeight)
+		}
+
+		return nil
+	}, testTimeout)
+	require.NoError(t, err, "block height wasn't updated")
+}
+
+func createTestCtxFromFile(t *testing.T,
+	startingHeight uint32, testGraph string) *testCtx {
+
+	// We'll attempt to locate and parse out the file that encodes the
+	// graph that our tests should be run against.
+	graphInstance, err := parseTestGraph(t, true, testGraph)
+	require.NoError(t, err, "unable to create test graph")
+
+	return createTestCtxFromGraphInstance(
+		t, startingHeight, graphInstance, false,
+	)
+}
+
+// parseTestGraph returns a fully populated ChannelGraph given a path to a JSON
+// file which encodes a test graph.
+func parseTestGraph(t *testing.T, useCache bool, path string) (
+	*testGraphInstance, error) {
+
+	graphJSON, err := os.ReadFile(path)
+	if err != nil {
+		return nil, err
+	}
+
+	// First unmarshal the JSON graph into an instance of the testGraph
+	// struct. Using the struct tags created above in the struct, the JSON
+	// will be properly parsed into the struct above.
+	var g testGraph
+	if err := json.Unmarshal(graphJSON, &g); err != nil {
+		return nil, err
+	}
+
+	// We'll use this fake address for the IP address of all the nodes in
+	// our tests. This value isn't needed for path finding so it doesn't
+	// need to be unique.
+	var testAddrs []net.Addr
+	testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888")
+	if err != nil {
+		return nil, err
+	}
+	testAddrs = append(testAddrs, testAddr)
+
+	// Next, create a temporary graph database for usage within the test.
+	graph, graphBackend, err := makeTestGraph(t, useCache)
+	if err != nil {
+		return nil, err
+	}
+
+	aliasMap := make(map[string]route.Vertex)
+	privKeyMap := make(map[string]*btcec.PrivateKey)
+	channelIDs := make(map[route.Vertex]map[route.Vertex]uint64)
+	links := make(map[lnwire.ShortChannelID]htlcswitch.ChannelLink)
+	var source *channeldb.LightningNode
+
+	// First we insert all the nodes within the graph as vertexes.
+	for _, node := range g.Nodes {
+		pubBytes, err := hex.DecodeString(node.PubKey)
+		if err != nil {
+			return nil, err
+		}
+
+		dbNode := &channeldb.LightningNode{
+			HaveNodeAnnouncement: true,
+			AuthSigBytes:         testSig.Serialize(),
+			LastUpdate:           testTime,
+			Addresses:            testAddrs,
+			Alias:                node.Alias,
+			Features:             testFeatures,
+		}
+		copy(dbNode.PubKeyBytes[:], pubBytes)
+
+		// We require all aliases within the graph to be unique for our
+		// tests.
+		if _, ok := aliasMap[node.Alias]; ok {
+			return nil, errors.New("aliases for nodes " +
+				"must be unique!")
+		}
+
+		// If the alias is unique, then add the node to the
+		// alias map for easy lookup.
+		aliasMap[node.Alias] = dbNode.PubKeyBytes
+
+		// Private keys are needed for signing error messages. If set,
+		// check the consistency with the public key.
+		privBytes, err := hex.DecodeString(node.PrivKey)
+		if err != nil {
+			return nil, err
+		}
+		if len(privBytes) > 0 {
+			key, derivedPub := btcec.PrivKeyFromBytes(
+				privBytes,
+			)
+
+			if !bytes.Equal(
+				pubBytes, derivedPub.SerializeCompressed(),
+			) {
+
+				return nil, fmt.Errorf("%s public key and "+
+					"private key are inconsistent\n"+
+					"got %x\nwant %x\n",
+					node.Alias,
+					derivedPub.SerializeCompressed(),
+					pubBytes,
+				)
+			}
+
+			privKeyMap[node.Alias] = key
+		}
+
+		// If the node is tagged as the source, then we create a
+		// pointer to it so we can mark the source in the graph
+		// properly.
+		if node.Source {
+			// If we come across a node that's marked as the
+			// source, and we've already set the source in a prior
+			// iteration, then the JSON has an error as only ONE
+			// node can be the source in the graph.
+			if source != nil {
+				return nil, errors.New("JSON is invalid: " +
+					"multiple nodes are tagged as the " +
+					"source")
+			}
+
+			source = dbNode
+		}
+
+		// With the node fully parsed, add it as a vertex within the
+		// graph.
+		if err := graph.AddLightningNode(dbNode); err != nil {
+			return nil, err
+		}
+	}
+
+	if source != nil {
+		// Set the selected source node.
+		if err := graph.SetSourceNode(source); err != nil {
+			return nil, err
+		}
+	}
+
+	// With all the vertexes inserted, we can now insert the edges into the
+	// test graph.
+	for _, edge := range g.Edges {
+		node1Bytes, err := hex.DecodeString(edge.Node1)
+		if err != nil {
+			return nil, err
+		}
+
+		node2Bytes, err := hex.DecodeString(edge.Node2)
+		if err != nil {
+			return nil, err
+		}
+
+		if bytes.Compare(node1Bytes, node2Bytes) == 1 {
+			return nil, fmt.Errorf(
+				"channel %v node order incorrect",
+				edge.ChannelID,
+			)
+		}
+
+		fundingTXID := strings.Split(edge.ChannelPoint, ":")[0]
+		txidBytes, err := chainhash.NewHashFromStr(fundingTXID)
+		if err != nil {
+			return nil, err
+		}
+		fundingPoint := wire.OutPoint{
+			Hash:  *txidBytes,
+			Index: 0,
+		}
+
+		// We first insert the existence of the edge between the two
+		// nodes.
+		edgeInfo := models.ChannelEdgeInfo{
+			ChannelID:    edge.ChannelID,
+			AuthProof:    &testAuthProof,
+			ChannelPoint: fundingPoint,
+			Capacity:     btcutil.Amount(edge.Capacity),
+		}
+
+		copy(edgeInfo.NodeKey1Bytes[:], node1Bytes)
+		copy(edgeInfo.NodeKey2Bytes[:], node2Bytes)
+		copy(edgeInfo.BitcoinKey1Bytes[:], node1Bytes)
+		copy(edgeInfo.BitcoinKey2Bytes[:], node2Bytes)
+
+		shortID := lnwire.NewShortChanIDFromInt(edge.ChannelID)
+		links[shortID] = &mockLink{
+			bandwidth: lnwire.MilliSatoshi(
+				edgeInfo.Capacity * 1000,
+			),
+		}
+
+		err = graph.AddChannelEdge(&edgeInfo)
+		if err != nil && err != channeldb.ErrEdgeAlreadyExist {
+			return nil, err
+		}
+
+		channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags)
+		isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0
+		targetNode := edgeInfo.NodeKey1Bytes
+		if isUpdate1 {
+			targetNode = edgeInfo.NodeKey2Bytes
+		}
+
+		edgePolicy := &models.ChannelEdgePolicy{
+			SigBytes:                  testSig.Serialize(),
+			MessageFlags:              lnwire.ChanUpdateMsgFlags(edge.MessageFlags),
+			ChannelFlags:              channelFlags,
+			ChannelID:                 edge.ChannelID,
+			LastUpdate:                testTime,
+			TimeLockDelta:             edge.Expiry,
+			MinHTLC:                   lnwire.MilliSatoshi(edge.MinHTLC),
+			MaxHTLC:                   lnwire.MilliSatoshi(edge.MaxHTLC),
+			FeeBaseMSat:               lnwire.MilliSatoshi(edge.FeeBaseMsat),
+			FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate),
+			ToNode:                    targetNode,
+		}
+		if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
+			return nil, err
+		}
+
+		// We also store the channel ID info for each of the nodes.
+		node1Vertex, err := route.NewVertexFromBytes(node1Bytes)
+		if err != nil {
+			return nil, err
+		}
+
+		node2Vertex, err := route.NewVertexFromBytes(node2Bytes)
+		if err != nil {
+			return nil, err
+		}
+
+		if _, ok := channelIDs[node1Vertex]; !ok {
+			channelIDs[node1Vertex] = map[route.Vertex]uint64{}
+		}
+		channelIDs[node1Vertex][node2Vertex] = edge.ChannelID
+
+		if _, ok := channelIDs[node2Vertex]; !ok {
+			channelIDs[node2Vertex] = map[route.Vertex]uint64{}
+		}
+		channelIDs[node2Vertex][node1Vertex] = edge.ChannelID
+	}
+
+	return &testGraphInstance{
+		graph:        graph,
+		graphBackend: graphBackend,
+		aliasMap:     aliasMap,
+		privKeyMap:   privKeyMap,
+		channelIDs:   channelIDs,
+		links:        links,
+	}, nil
+}
+
+// testGraph is the struct which corresponds to the JSON format used to encode
+// graphs within the files in the testdata directory.
+//
+// TODO(roasbeef): add test graph auto-generator
+type testGraph struct {
+	Info  []string   `json:"info"`
+	Nodes []testNode `json:"nodes"`
+	Edges []testChan `json:"edges"`
+}
+
+// testNode represents a node within the test graph above. We skip certain
+// information such as the node's IP address as that information isn't needed
+// for our tests. Private keys are optional. If set, they should be consistent
+// with the public key. The private key is used to sign error messages
+// sent from the node.
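+//
+// An illustrative entry from the testdata JSON (values abbreviated, not
+// taken from a real file):
+//
+//	{
+//		"source": true,
+//		"pubkey": "03c8...",
+//		"privkey": "",
+//		"alias": "roasbeef"
+//	}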
+type testNode struct { + Source bool `json:"source"` + PubKey string `json:"pubkey"` + PrivKey string `json:"privkey"` + Alias string `json:"alias"` +} + +// testChan represents the JSON version of a payment channel. This struct +// matches the Json that's encoded under the "edges" key within the test graph. +type testChan struct { + Node1 string `json:"node_1"` + Node2 string `json:"node_2"` + ChannelID uint64 `json:"channel_id"` + ChannelPoint string `json:"channel_point"` + ChannelFlags uint8 `json:"channel_flags"` + MessageFlags uint8 `json:"message_flags"` + Expiry uint16 `json:"expiry"` + MinHTLC int64 `json:"min_htlc"` + MaxHTLC int64 `json:"max_htlc"` + FeeBaseMsat int64 `json:"fee_base_msat"` + FeeRate int64 `json:"fee_rate"` + Capacity int64 `json:"capacity"` +} + +type testChannel struct { + Node1 *testChannelEnd + Node2 *testChannelEnd + Capacity btcutil.Amount + ChannelID uint64 +} + +type testChannelEnd struct { + Alias string + *testChannelPolicy +} + +func symmetricTestChannel(alias1, alias2 string, capacity btcutil.Amount, + policy *testChannelPolicy, chanID ...uint64) *testChannel { + + // Leaving id zero will result in auto-generation of a channel id during + // graph construction. + var id uint64 + if len(chanID) > 0 { + id = chanID[0] + } + + policy2 := *policy + + return asymmetricTestChannel( + alias1, alias2, capacity, policy, &policy2, id, + ) +} + +func asymmetricTestChannel(alias1, alias2 string, capacity btcutil.Amount, + policy1, policy2 *testChannelPolicy, id uint64) *testChannel { + + return &testChannel{ + Capacity: capacity, + Node1: &testChannelEnd{ + Alias: alias1, + testChannelPolicy: policy1, + }, + Node2: &testChannelEnd{ + Alias: alias2, + testChannelPolicy: policy2, + }, + ChannelID: id, + } +} + +// assertChannelsPruned ensures that only the given channels are pruned from the +// graph out of the set of all channels. +func assertChannelsPruned(t *testing.T, graph *channeldb.ChannelGraph, + channels []*testChannel, prunedChanIDs ...uint64) { + + t.Helper() + + pruned := make(map[uint64]struct{}, len(channels)) + for _, chanID := range prunedChanIDs { + pruned[chanID] = struct{}{} + } + + for _, channel := range channels { + _, shouldPrune := pruned[channel.ChannelID] + _, _, exists, isZombie, err := graph.HasChannelEdge( + channel.ChannelID, + ) + if err != nil { + t.Fatalf("unable to determine existence of "+ + "channel=%v in the graph: %v", + channel.ChannelID, err) + } + if !shouldPrune && !exists { + t.Fatalf("expected channel=%v to exist within "+ + "the graph", channel.ChannelID) + } + if shouldPrune && exists { + t.Fatalf("expected channel=%v to not exist "+ + "within the graph", channel.ChannelID) + } + if !shouldPrune && isZombie { + t.Fatalf("expected channel=%v to not be marked "+ + "as zombie", channel.ChannelID) + } + if shouldPrune && !isZombie { + t.Fatalf("expected channel=%v to be marked as "+ + "zombie", channel.ChannelID) + } + } +} + +type testChannelPolicy struct { + Expiry uint16 + MinHTLC lnwire.MilliSatoshi + MaxHTLC lnwire.MilliSatoshi + FeeBaseMsat lnwire.MilliSatoshi + FeeRate lnwire.MilliSatoshi + InboundFeeBaseMsat int64 + InboundFeeRate int64 + LastUpdate time.Time + Disabled bool + Features *lnwire.FeatureVector +} + +// createTestGraphFromChannels returns a fully populated ChannelGraph based on a set of +// test channels. Additional required information like keys are derived in +// a deterministic way and added to the channel graph. A list of nodes is +// not required and derived from the channel data. 
+func createTestGraphFromChannels(t *testing.T, useCache bool,
+	testChannels []*testChannel, source string) (*testGraphInstance, error) {
+
+	// We'll use this fake address for the IP address of all the nodes in
+	// our tests. This value isn't needed for path finding so it doesn't
+	// need to be unique.
+	var testAddrs []net.Addr
+	testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888")
+	if err != nil {
+		return nil, err
+	}
+	testAddrs = append(testAddrs, testAddr)
+
+	// Next, create a temporary graph database for usage within the test.
+	graph, graphBackend, err := makeTestGraph(t, useCache)
+	if err != nil {
+		return nil, err
+	}
+
+	aliasMap := make(map[string]route.Vertex)
+	privKeyMap := make(map[string]*btcec.PrivateKey)
+
+	nodeIndex := byte(0)
+	addNodeWithAlias := func(alias string, features *lnwire.FeatureVector) (
+		*channeldb.LightningNode, error) {
+
+		keyBytes := []byte{
+			0, 0, 0, 0, 0, 0, 0, 0,
+			0, 0, 0, 0, 0, 0, 0, 0,
+			0, 0, 0, 0, 0, 0, 0, 0,
+			0, 0, 0, 0, 0, 0, 0, nodeIndex + 1,
+		}
+
+		privKey, pubKey := btcec.PrivKeyFromBytes(keyBytes)
+
+		if features == nil {
+			features = lnwire.EmptyFeatureVector()
+		}
+
+		dbNode := &channeldb.LightningNode{
+			HaveNodeAnnouncement: true,
+			AuthSigBytes:         testSig.Serialize(),
+			LastUpdate:           testTime,
+			Addresses:            testAddrs,
+			Alias:                alias,
+			Features:             features,
+		}
+
+		copy(dbNode.PubKeyBytes[:], pubKey.SerializeCompressed())
+
+		privKeyMap[alias] = privKey
+
+		// With the node fully parsed, add it as a vertex within the
+		// graph.
+		if err := graph.AddLightningNode(dbNode); err != nil {
+			return nil, err
+		}
+
+		aliasMap[alias] = dbNode.PubKeyBytes
+		nodeIndex++
+
+		return dbNode, nil
+	}
+
+	// Add the source node.
+	dbNode, err := addNodeWithAlias(source, lnwire.EmptyFeatureVector())
+	if err != nil {
+		return nil, err
+	}
+
+	if err = graph.SetSourceNode(dbNode); err != nil {
+		return nil, err
+	}
+
+	// Initialize a variable that keeps track of the next channel id to
+	// assign if none is specified.
+	nextUnassignedChannelID := uint64(100000)
+
+	links := make(map[lnwire.ShortChannelID]htlcswitch.ChannelLink)
+
+	for _, testChannel := range testChannels {
+		for _, node := range []*testChannelEnd{
+			testChannel.Node1, testChannel.Node2,
+		} {
+			_, exists := aliasMap[node.Alias]
+			if !exists {
+				var features *lnwire.FeatureVector
+				if node.testChannelPolicy != nil {
+					features =
+						node.testChannelPolicy.Features
+				}
+				_, err := addNodeWithAlias(
+					node.Alias, features,
+				)
+				if err != nil {
+					return nil, err
+				}
+			}
+		}
+
+		channelID := testChannel.ChannelID
+
+		// If no channel id is specified, generate an id.
+		if channelID == 0 {
+			channelID = nextUnassignedChannelID
+			nextUnassignedChannelID++
+		}
+
+		var hash [sha256.Size]byte
+		hash[len(hash)-1] = byte(channelID)
+
+		fundingPoint := &wire.OutPoint{
+			Hash:  chainhash.Hash(hash),
+			Index: 0,
+		}
+
+		capacity := lnwire.MilliSatoshi(testChannel.Capacity * 1000)
+		shortID := lnwire.NewShortChanIDFromInt(channelID)
+		links[shortID] = &mockLink{
+			bandwidth: capacity,
+		}
+
+		// Sort the nodes so that node1 holds the lexicographically
+		// smaller public key, matching the ordering enforced when
+		// parsing test graphs from JSON.
+		node1 := testChannel.Node1
+		node2 := testChannel.Node2
+		node1Vertex := aliasMap[node1.Alias]
+		node2Vertex := aliasMap[node2.Alias]
+		if bytes.Compare(node1Vertex[:], node2Vertex[:]) == 1 {
+			node1, node2 = node2, node1
+			node1Vertex, node2Vertex = node2Vertex, node1Vertex
+		}
+
+		// We first insert the existence of the edge between the two
+		// nodes.
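+		// Note that each channel is persisted only once, as a single
+		// ChannelEdgeInfo keyed by its channel ID; the per-direction
+		// fee and timelock parameters live in two separate
+		// ChannelEdgePolicy records, distinguished by the
+		// ChanUpdateDirection bit of ChannelFlags (set on the node2
+		// policy below).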
+ edgeInfo := models.ChannelEdgeInfo{ + ChannelID: channelID, + AuthProof: &testAuthProof, + ChannelPoint: *fundingPoint, + Capacity: testChannel.Capacity, + + NodeKey1Bytes: node1Vertex, + BitcoinKey1Bytes: node1Vertex, + NodeKey2Bytes: node2Vertex, + BitcoinKey2Bytes: node2Vertex, + } + + err = graph.AddChannelEdge(&edgeInfo) + if err != nil && err != channeldb.ErrEdgeAlreadyExist { + return nil, err + } + + getExtraData := func( + end *testChannelEnd) lnwire.ExtraOpaqueData { + + var extraData lnwire.ExtraOpaqueData + inboundFee := lnwire.Fee{ + BaseFee: int32(end.InboundFeeBaseMsat), + FeeRate: int32(end.InboundFeeRate), + } + require.NoError(t, extraData.PackRecords(&inboundFee)) + + return extraData + } + + if node1.testChannelPolicy != nil { + var msgFlags lnwire.ChanUpdateMsgFlags + if node1.MaxHTLC != 0 { + msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc + } + var channelFlags lnwire.ChanUpdateChanFlags + if node1.Disabled { + channelFlags |= lnwire.ChanUpdateDisabled + } + + edgePolicy := &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + MessageFlags: msgFlags, + ChannelFlags: channelFlags, + ChannelID: channelID, + LastUpdate: node1.LastUpdate, + TimeLockDelta: node1.Expiry, + MinHTLC: node1.MinHTLC, + MaxHTLC: node1.MaxHTLC, + FeeBaseMSat: node1.FeeBaseMsat, + FeeProportionalMillionths: node1.FeeRate, + ToNode: node2Vertex, + ExtraOpaqueData: getExtraData(node1), + } + if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + return nil, err + } + } + + if node2.testChannelPolicy != nil { + var msgFlags lnwire.ChanUpdateMsgFlags + if node2.MaxHTLC != 0 { + msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc + } + var channelFlags lnwire.ChanUpdateChanFlags + if node2.Disabled { + channelFlags |= lnwire.ChanUpdateDisabled + } + channelFlags |= lnwire.ChanUpdateDirection + + edgePolicy := &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + MessageFlags: msgFlags, + ChannelFlags: channelFlags, + ChannelID: channelID, + LastUpdate: node2.LastUpdate, + TimeLockDelta: node2.Expiry, + MinHTLC: node2.MinHTLC, + MaxHTLC: node2.MaxHTLC, + FeeBaseMSat: node2.FeeBaseMsat, + FeeProportionalMillionths: node2.FeeRate, + ToNode: node1Vertex, + ExtraOpaqueData: getExtraData(node2), + } + if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + return nil, err + } + } + + channelID++ + } + + return &testGraphInstance{ + graph: graph, + graphBackend: graphBackend, + aliasMap: aliasMap, + privKeyMap: privKeyMap, + links: links, + }, nil +} + +type mockLink struct { + htlcswitch.ChannelLink + bandwidth lnwire.MilliSatoshi + mayAddOutgoingErr error + ineligible bool +} + +// Bandwidth returns the bandwidth the mock was configured with. +func (m *mockLink) Bandwidth() lnwire.MilliSatoshi { + return m.bandwidth +} + +// EligibleToForward returns the mock's configured eligibility. +func (m *mockLink) EligibleToForward() bool { + return !m.ineligible +} + +// MayAddOutgoingHtlc returns the error configured in our mock. 
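+// Together with Bandwidth and EligibleToForward above, this is the set of
+// link methods the router's bandwidth hints consult when deciding whether a
+// local channel can carry a given HTLC.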
+func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
+	return m.mayAddOutgoingErr
+}
diff --git a/routing/errors.go b/graph/errors.go
similarity index 80%
rename from routing/errors.go
rename to graph/errors.go
index 95ed613ca0..c0d6b8904a 100644
--- a/routing/errors.go
+++ b/graph/errors.go
@@ -1,4 +1,4 @@
-package routing
+package graph
 
 import "github.com/go-errors/errors"
 
@@ -39,27 +39,27 @@ const (
 	ErrParentValidationFailed
 )
 
-// routerError is a structure that represent the error inside the routing package,
+// graphError is a structure that represents an error inside the graph package,
 // this structure carries additional information about error code in order to
 // be able distinguish errors outside of the current package.
-type routerError struct {
+type graphError struct {
 	err  *errors.Error
 	code errorCode
 }
 
 // Error represents errors as the string
 // NOTE: Part of the error interface.
-func (e *routerError) Error() string {
+func (e *graphError) Error() string {
 	return e.err.Error()
 }
 
-// A compile time check to ensure routerError implements the error interface.
-var _ error = (*routerError)(nil)
+// A compile time check to ensure graphError implements the error interface.
+var _ error = (*graphError)(nil)
 
-// newErrf creates a routerError by the given error formatted description and
+// newErrf creates a graphError from the given formatted error description and
 // its corresponding error code.
-func newErrf(code errorCode, format string, a ...interface{}) *routerError {
-	return &routerError{
+func newErrf(code errorCode, format string, a ...interface{}) *graphError {
+	return &graphError{
 		code: code,
 		err:  errors.Errorf(format, a...),
 	}
@@ -68,7 +68,7 @@ func newErrf(code errorCode, format string, a ...interface{}) *routerError {
 // IsError is a helper function which is needed to have ability to check that
 // returned error has specific error code.
 func IsError(e interface{}, codes ...errorCode) bool {
-	err, ok := e.(*routerError)
+	err, ok := e.(*graphError)
 	if !ok {
 		return false
 	}
diff --git a/routing/notifications.go b/graph/notifications.go
similarity index 95%
rename from routing/notifications.go
rename to graph/notifications.go
index 7263b9a47c..36f4e09a97 100644
--- a/routing/notifications.go
+++ b/graph/notifications.go
@@ -1,4 +1,4 @@
-package routing
+package graph
 
 import (
 	"fmt"
@@ -13,7 +13,6 @@ import (
 	"github.com/go-errors/errors"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/channeldb/models"
-	"github.com/lightningnetwork/lnd/graph"
 	"github.com/lightningnetwork/lnd/lnwire"
 )
 
@@ -57,16 +56,16 @@ type topologyClientUpdate struct {
 // topology occurs. Changes that will be sent at notifications include: new
 // nodes appearing, node updating their attributes, new channels, channels
 // closing, and updates in the routing policies of a channel's directed edges.
-func (r *ChannelRouter) SubscribeTopology() (*TopologyClient, error) {
+func (b *Builder) SubscribeTopology() (*TopologyClient, error) {
 	// If the router is not yet started, return an error to avoid a
 	// deadlock waiting for it to handle the subscription request.
-	if atomic.LoadUint32(&r.started) == 0 {
+	if !b.started.Load() {
 		return nil, fmt.Errorf("router not started")
 	}
 
 	// We'll first atomically obtain the next ID for this client from the
 	// incrementing client ID counter.
- clientID := atomic.AddUint64(&r.ntfnClientCounter, 1) + clientID := atomic.AddUint64(&b.ntfnClientCounter, 1) log.Debugf("New graph topology client subscription, client %v", clientID) @@ -74,12 +73,12 @@ func (r *ChannelRouter) SubscribeTopology() (*TopologyClient, error) { ntfnChan := make(chan *TopologyChange, 10) select { - case r.ntfnClientUpdates <- &topologyClientUpdate{ + case b.ntfnClientUpdates <- &topologyClientUpdate{ cancel: false, clientID: clientID, ntfnChan: ntfnChan, }: - case <-r.quit: + case <-b.quit: return nil, errors.New("ChannelRouter shutting down") } @@ -87,11 +86,11 @@ func (r *ChannelRouter) SubscribeTopology() (*TopologyClient, error) { TopologyChanges: ntfnChan, Cancel: func() { select { - case r.ntfnClientUpdates <- &topologyClientUpdate{ + case b.ntfnClientUpdates <- &topologyClientUpdate{ cancel: true, clientID: clientID, }: - case <-r.quit: + case <-b.quit: return } }, @@ -117,7 +116,7 @@ type topologyClient struct { // notifyTopologyChange notifies all registered clients of a new change in // graph topology in a non-blocking. -func (r *ChannelRouter) notifyTopologyChange(topologyDiff *TopologyChange) { +func (b *Builder) notifyTopologyChange(topologyDiff *TopologyChange) { // notifyClient is a helper closure that will send topology updates to // the given client. @@ -146,7 +145,7 @@ func (r *ChannelRouter) notifyTopologyChange(topologyDiff *TopologyChange) { // Similarly, if the ChannelRouter itself exists early, // then we'll also exit ourselves. - case <-r.quit: + case <-b.quit: } }(client) @@ -158,7 +157,7 @@ func (r *ChannelRouter) notifyTopologyChange(topologyDiff *TopologyChange) { // Range over the set of active clients, and attempt to send the // topology updates. - r.topologyClients.Range(notifyClient) + b.topologyClients.Range(notifyClient) } // TopologyChange represents a new set of modifications to the channel graph. @@ -314,7 +313,7 @@ type ChannelEdgeUpdate struct { // constitutes. This function will also fetch any required auxiliary // information required to create the topology change update from the graph // database. 
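A consumer of this subscription API follows a simple pattern, sketched below
(a minimal sketch: it assumes a started Builder b, and process is a stand-in
for application logic; the TopologyClient fields are the ones constructed in
the hunk above):

	client, err := b.SubscribeTopology()
	if err != nil {
		return err
	}
	defer client.Cancel()

	for change := range client.TopologyChanges {
		// Each TopologyChange batches the node announcements, channel
		// edge updates, and channel closes observed by the builder.
		process(change)
	}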
-func addToTopologyChange(graph graph.DB, update *TopologyChange, +func addToTopologyChange(graph DB, update *TopologyChange, msg interface{}) error { switch m := msg.(type) { diff --git a/routing/notifications_test.go b/graph/notifications_test.go similarity index 78% rename from routing/notifications_test.go rename to graph/notifications_test.go index 4e095649b5..290eec0e2a 100644 --- a/routing/notifications_test.go +++ b/graph/notifications_test.go @@ -1,7 +1,8 @@ -package routing +package graph import ( "bytes" + "encoding/hex" "fmt" "image/color" prand "math/rand" @@ -11,13 +12,17 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/kvdb" + lnmock "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/btcwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -49,15 +54,28 @@ var ( bitcoinKey2 = priv2.PubKey() timeout = time.Second * 5 + + testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") + testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) + + testAuthProof = models.ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + } ) -func createTestNode() (*channeldb.LightningNode, error) { +func createTestNode(t *testing.T) *channeldb.LightningNode { updateTime := prand.Int63() priv, err := btcec.NewPrivateKey() - if err != nil { - return nil, errors.Errorf("unable create private key: %v", err) - } + require.NoError(t, err) pub := priv.PubKey().SerializeCompressed() n := &channeldb.LightningNode{ @@ -71,7 +89,7 @@ func createTestNode() (*channeldb.LightningNode, error) { } copy(n.PubKeyBytes[:], pub) - return n, nil + return n } func randEdgePolicy(chanID *lnwire.ShortChannelID, @@ -271,7 +289,7 @@ type mockChainView struct { } // A compile time check to ensure mockChainView implements the -// chainview.FilteredChainView. +// chainview.FilteredChainViewReader. 
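(The `var _ chainview.FilteredChainView = (*mockChainView)(nil)` line below is
a zero-cost, compile-time assertion: the build fails if *mockChainView ever
stops satisfying the interface. The same guard appears for graphError in the
errors.go hunk above.)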
var _ chainview.FilteredChainView = (*mockChainView)(nil) func newMockChainView(chain lnwallet.BlockChainIO) *mockChainView { @@ -302,6 +320,15 @@ func (m *mockChainView) UpdateFilter(ops []channeldb.EdgePoint, updateHeight uin return nil } +func (m *mockChainView) Start() error { + return nil +} + +func (m *mockChainView) Stop() error { + close(m.quit) + return nil +} + func (m *mockChainView) notifyBlock(hash chainhash.Hash, height uint32, txns []*wire.MsgTx, t *testing.T) { @@ -405,15 +432,6 @@ func (m *mockChainView) FilterBlock(blockHash *chainhash.Hash) (*chainview.Filte return filteredBlock, nil } -func (m *mockChainView) Start() error { - return nil -} - -func (m *mockChainView) Stop() error { - close(m.quit) - return nil -} - // TestEdgeUpdateNotification tests that when edges are updated or added, // a proper notification is sent of to all registered clients. func TestEdgeUpdateNotification(t *testing.T) { @@ -437,10 +455,8 @@ func TestEdgeUpdateNotification(t *testing.T) { // Next we'll create two test nodes that the fake channel will be open // between. - node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") + node1 := createTestNode(t) + node2 := createTestNode(t) // Finally, to conclude our test set up, we'll create a channel // update to announce the created channel between the two nodes. @@ -458,13 +474,13 @@ func TestEdgeUpdateNotification(t *testing.T) { copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - if err := ctx.router.AddEdge(edge); err != nil { + if err := ctx.builder.AddEdge(edge); err != nil { t.Fatalf("unable to add edge: %v", err) } // With the channel edge now in place, we'll subscribe for topology // notifications. - ntfnClient, err := ctx.router.SubscribeTopology() + ntfnClient, err := ctx.builder.SubscribeTopology() require.NoError(t, err, "unable to subscribe for channel notifications") // Create random policy edges that are stemmed to the channel id @@ -477,10 +493,10 @@ func TestEdgeUpdateNotification(t *testing.T) { require.NoError(t, err, "unable to create a random chan policy") edge2.ChannelFlags = 1 - if err := ctx.router.UpdateEdge(edge1); err != nil { + if err := ctx.builder.UpdateEdge(edge1); err != nil { t.Fatalf("unable to add edge update: %v", err) } - if err := ctx.router.UpdateEdge(edge2); err != nil { + if err := ctx.builder.UpdateEdge(edge2); err != nil { t.Fatalf("unable to add edge update: %v", err) } @@ -625,10 +641,8 @@ func TestNodeUpdateNotification(t *testing.T) { // Create two nodes acting as endpoints in the created channel, and use // them to trigger notifications by sending updated node announcement // messages. - node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") + node1 := createTestNode(t) + node2 := createTestNode(t) testFeaturesBuf := new(bytes.Buffer) require.NoError(t, testFeatures.Encode(testFeaturesBuf)) @@ -649,20 +663,20 @@ func TestNodeUpdateNotification(t *testing.T) { // Adding the edge will add the nodes to the graph, but with no info // except the pubkey known. - if err := ctx.router.AddEdge(edge); err != nil { + if err := ctx.builder.AddEdge(edge); err != nil { t.Fatalf("unable to add edge: %v", err) } // Create a new client to receive notifications. 
- ntfnClient, err := ctx.router.SubscribeTopology() + ntfnClient, err := ctx.builder.SubscribeTopology() require.NoError(t, err, "unable to subscribe for channel notifications") // Change network topology by adding the updated info for the two nodes // to the channel router. - if err := ctx.router.AddNode(node1); err != nil { + if err := ctx.builder.AddNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - if err := ctx.router.AddNode(node2); err != nil { + if err := ctx.builder.AddNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -756,7 +770,7 @@ func TestNodeUpdateNotification(t *testing.T) { nodeUpdateAnn.LastUpdate = node1.LastUpdate.Add(300 * time.Millisecond) // Add new node topology update to the channel router. - if err := ctx.router.AddNode(&nodeUpdateAnn); err != nil { + if err := ctx.builder.AddNode(&nodeUpdateAnn); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -788,7 +802,7 @@ func TestNotificationCancellation(t *testing.T) { ctx := createTestCtxSingleNode(t, startingBlockHeight) // Create a new client to receive notifications. - ntfnClient, err := ctx.router.SubscribeTopology() + ntfnClient, err := ctx.builder.SubscribeTopology() require.NoError(t, err, "unable to subscribe for channel notifications") // We'll create the utxo for a new channel. @@ -808,10 +822,8 @@ func TestNotificationCancellation(t *testing.T) { // We'll create a fresh new node topology update to feed to the channel // router. - node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") + node1 := createTestNode(t) + node2 := createTestNode(t) // Before we send the message to the channel router, we'll cancel the // notifications for this client. As a result, the notification @@ -832,15 +844,15 @@ func TestNotificationCancellation(t *testing.T) { } copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - if err := ctx.router.AddEdge(edge); err != nil { + if err := ctx.builder.AddEdge(edge); err != nil { t.Fatalf("unable to add edge: %v", err) } - if err := ctx.router.AddNode(node1); err != nil { + if err := ctx.builder.AddNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - if err := ctx.router.AddNode(node2); err != nil { + if err := ctx.builder.AddNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -883,10 +895,8 @@ func TestChannelCloseNotification(t *testing.T) { // Next we'll create two test nodes that the fake channel will be open // between. - node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") + node1 := createTestNode(t) + node2 := createTestNode(t) // Finally, to conclude our test set up, we'll create a channel // announcement to announce the created channel between the two nodes. @@ -903,13 +913,13 @@ func TestChannelCloseNotification(t *testing.T) { } copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - if err := ctx.router.AddEdge(edge); err != nil { + if err := ctx.builder.AddEdge(edge); err != nil { t.Fatalf("unable to add edge: %v", err) } // With the channel edge now in place, we'll subscribe for topology // notifications. 
-	ntfnClient, err := ctx.router.SubscribeTopology()
+	ntfnClient, err := ctx.builder.SubscribeTopology()
 	require.NoError(t, err, "unable to subscribe for channel notifications")
 
 	// Next, we'll simulate the closure of our channel by generating a new
@@ -999,3 +1009,200 @@ func TestEncodeHexColor(t *testing.T) {
 	}
 }
+
+type testCtx struct {
+	builder *Builder
+
+	graph *channeldb.ChannelGraph
+
+	aliases map[string]route.Vertex
+
+	privKeys map[string]*btcec.PrivateKey
+
+	channelIDs map[route.Vertex]map[route.Vertex]uint64
+
+	chain     *mockChain
+	chainView *mockChainView
+
+	notifier *lnmock.ChainNotifier
+}
+
+func (c *testCtx) getChannelIDFromAlias(t *testing.T, a, b string) uint64 {
+	vertexA, ok := c.aliases[a]
+	require.True(t, ok, "cannot find aliases for %s", a)
+
+	vertexB, ok := c.aliases[b]
+	require.True(t, ok, "cannot find aliases for %s", b)
+
+	channelIDMap, ok := c.channelIDs[vertexA]
+	require.True(t, ok, "cannot find channelID map %s(%s)", vertexA, a)
+
+	channelID, ok := channelIDMap[vertexB]
+	require.True(t, ok, "cannot find channelID using %s(%s)", vertexB, b)
+
+	return channelID
+}
+
+func createTestCtxSingleNode(t *testing.T,
+	startingHeight uint32) *testCtx {
+
+	graph, graphBackend, err := makeTestGraph(t, true)
+	require.NoError(t, err, "failed to make test graph")
+
+	sourceNode := createTestNode(t)
+
+	require.NoError(t,
+		graph.SetSourceNode(sourceNode), "failed to set source node",
+	)
+
+	graphInstance := &testGraphInstance{
+		graph:        graph,
+		graphBackend: graphBackend,
+	}
+
+	return createTestCtxFromGraphInstance(
+		t, startingHeight, graphInstance, false,
+	)
+}
+
+func (c *testCtx) RestartBuilder(t *testing.T) {
+	c.chainView.Reset()
+
+	selfNode, err := c.graph.SourceNode()
+	require.NoError(t, err)
+
+	// With the chainView reset, we'll now re-create the builder itself, and
+	// start it.
+	builder, err := NewBuilder(&Config{
+		SelfNode:            selfNode.PubKeyBytes,
+		Graph:               c.graph,
+		Chain:               c.chain,
+		ChainView:           c.chainView,
+		Notifier:            c.builder.cfg.Notifier,
+		ChannelPruneExpiry:  time.Hour * 24,
+		GraphPruneInterval:  time.Hour * 2,
+		AssumeChannelValid:  c.builder.cfg.AssumeChannelValid,
+		FirstTimePruneDelay: c.builder.cfg.FirstTimePruneDelay,
+		StrictZombiePruning: c.builder.cfg.StrictZombiePruning,
+		IsAlias: func(scid lnwire.ShortChannelID) bool {
+			return false
+		},
+	})
+	require.NoError(t, err)
+	require.NoError(t, builder.Start())
+
+	// Finally, we'll swap out the pointer in the testCtx with this fresh
+	// instance of the builder.
+	c.builder = builder
+}
+
+// makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing
+// purposes.
+func makeTestGraph(t *testing.T, useCache bool) (*channeldb.ChannelGraph,
+	kvdb.Backend, error) {
+
+	// Create the channel graph for the first time.
+	backend, backendCleanup, err := kvdb.GetTestBackend(t.TempDir(), "cgr")
+	if err != nil {
+		return nil, nil, err
+	}
+
+	t.Cleanup(backendCleanup)
+
+	opts := channeldb.DefaultOptions()
+	graph, err := channeldb.NewChannelGraph(
+		backend, opts.RejectCacheSize, opts.ChannelCacheSize,
+		opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
+		useCache, false,
+	)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return graph, backend, nil
+}
+
+type testGraphInstance struct {
+	graph        *channeldb.ChannelGraph
+	graphBackend kvdb.Backend
+
+	// aliasMap is a map from a node's alias to its public key. This map is
+	// provided to allow an easy lookup from a node's human-memorable alias
+	// to its exact public key.
+ aliasMap map[string]route.Vertex + + // privKeyMap maps a node alias to its private key. This is used to be + // able to mock a remote node's signing behaviour. + privKeyMap map[string]*btcec.PrivateKey + + // channelIDs stores the channel ID for each node. + channelIDs map[route.Vertex]map[route.Vertex]uint64 + + // links maps channel ids to a mock channel update handler. + links map[lnwire.ShortChannelID]htlcswitch.ChannelLink +} + +func createTestCtxFromGraphInstance(t *testing.T, + startingHeight uint32, graphInstance *testGraphInstance, + strictPruning bool) *testCtx { + + return createTestCtxFromGraphInstanceAssumeValid( + t, startingHeight, graphInstance, false, strictPruning, + ) +} + +func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, + startingHeight uint32, graphInstance *testGraphInstance, + assumeValid bool, strictPruning bool) *testCtx { + + // We'll initialize an instance of the channel router with mock + // versions of the chain and channel notifier. As we don't need to test + // any p2p functionality, the peer send and switch send messages won't + // be populated. + chain := newMockChain(startingHeight) + chainView := newMockChainView(chain) + + notifier := &lnmock.ChainNotifier{ + EpochChan: make(chan *chainntnfs.BlockEpoch), + SpendChan: make(chan *chainntnfs.SpendDetail), + ConfChan: make(chan *chainntnfs.TxConfirmation), + } + + selfnode, err := graphInstance.graph.SourceNode() + require.NoError(t, err) + + graphBuilder, err := NewBuilder(&Config{ + SelfNode: selfnode.PubKeyBytes, + Graph: graphInstance.graph, + Chain: chain, + ChainView: chainView, + Notifier: notifier, + ChannelPruneExpiry: time.Hour * 24, + GraphPruneInterval: time.Hour * 2, + AssumeChannelValid: assumeValid, + FirstTimePruneDelay: 0, + StrictZombiePruning: strictPruning, + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, + }) + require.NoError(t, err) + require.NoError(t, graphBuilder.Start()) + + ctx := &testCtx{ + builder: graphBuilder, + graph: graphInstance.graph, + aliases: graphInstance.aliasMap, + privKeys: graphInstance.privKeyMap, + channelIDs: graphInstance.channelIDs, + chain: chain, + chainView: chainView, + notifier: notifier, + } + + t.Cleanup(func() { + graphBuilder.Stop() + }) + + return ctx +} diff --git a/graph/setup_test.go b/graph/setup_test.go new file mode 100644 index 0000000000..a1e2f4dfff --- /dev/null +++ b/graph/setup_test.go @@ -0,0 +1,11 @@ +package graph + +import ( + "testing" + + "github.com/lightningnetwork/lnd/kvdb" +) + +func TestMain(m *testing.M) { + kvdb.RunTests(m) +} diff --git a/routing/stats.go b/graph/stats.go similarity index 98% rename from routing/stats.go rename to graph/stats.go index b960025c1c..91e897ae53 100644 --- a/routing/stats.go +++ b/graph/stats.go @@ -1,4 +1,4 @@ -package routing +package graph import ( "fmt" diff --git a/graph/testdata/basic_graph.json b/graph/testdata/basic_graph.json new file mode 100644 index 0000000000..7e4e3636ed --- /dev/null +++ b/graph/testdata/basic_graph.json @@ -0,0 +1,298 @@ +{ + "info": [ + "This file encodes a basic graph that resembles the following ascii graph:", + "", +" 50k satoshis ┌──────┐ ", +" ┌───────────────────▶│luo ji│◀─┐ ", +" │ └──────┘ │ ┌──────┐ ", +" │ │ | elst | ", +" │ │ └──────┘ ", +" │ │ ▲ ", +" │ │ | 100k sat ", +" │ │ ▼ ", +" ▼ │ ┌──────┐ ", +" ┌────────┐ │ │sophon│◀┐ ", +" │satoshi │ │ └──────┘ │ ", +" └────────┘ │ ▲ │ ", +" ▲ │ | │ 110k satoshis ", +" │ ┌───────────────────┘ | │ ", +" │ │ 100k satoshis | │ ", +" │ │ | │ ", +" │ │ 120k sat | │ ┌────────┐ 
", +" └──────────┤ (hi fee) ▼ └─▶│son goku│ ", +" 10k satoshis │ ┌────────────┐ └────────┘ ", +" │ | pham nuwen | ▲ ", +" │ └────────────┘ │ ", +" │ ▲ │ ", +" ▼ | 120k sat (hi fee) │ ", +" ┌──────────┐ | │ ", +" │ roasbeef │◀──────────────┴──────────────────────┘ ", +" └──────────┘ 100k satoshis ", + +" the graph also includes a channel from roasbeef to sophon via pham nuwen" + ], + "nodes": [ + { + "source": true, + "pubkey": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "alias": "roasbeef" + }, + { + "source": false, + "pubkey": "026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318", + "privkey": "82b266f659bd83a976bac11b2cc442baec5508e84e61085d7ec2b0fc52156c87", + "alias": "songoku" + }, + { + "source": false, + "pubkey": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "alias": "satoshi" + }, + { + "source": false, + "pubkey": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "alias": "luoji" + }, + { + "source": false, + "pubkey": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "alias": "sophon" + }, + { + "source": false, + "pubkey": "02a1d2856be336a58af08989aea0d8c41e072ccc392c46f8ce0e6e069f002035f3", + "alias": "phamnuwen" + }, + { + "source": false, + "pubkey": "02a4b236b69b09b8efe6ccf822fa95ee95a0196451f4d066a450b7489e2e354a64", + "alias": "elst" + } + ], + "edges": [ + { + "node_1": "02a4b236b69b09b8efe6ccf822fa95ee95a0196451f4d066a450b7489e2e354a64", + "node_2": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "channel_id": 15433, + "channel_point": "33bd5d49a50e284221561b91e781f1fca0d60341c9f9dd785b5e379a6d88af3d:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1000, + "max_htlc": 100000000, + "fee_base_msat": 200, + "fee_rate": 0, + "capacity": 100000 + }, + { + "node_1": "02a4b236b69b09b8efe6ccf822fa95ee95a0196451f4d066a450b7489e2e354a64", + "node_2": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "channel_id": 15433, + "channel_point": "33bd5d49a50e284221561b91e781f1fca0d60341c9f9dd785b5e379a6d88af3d:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1000, + "max_htlc": 100000000, + "fee_base_msat": 200, + "fee_rate": 0, + "capacity": 100000 + }, + { + "node_1": "02a1d2856be336a58af08989aea0d8c41e072ccc392c46f8ce0e6e069f002035f3", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 999991, + "channel_point": "48a0e8b856fef01d9feda7d25a4fac6dae48749e28ba356b92d712ab7f5bd2d0:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1000, + "max_htlc": 120000000, + "fee_base_msat": 10000, + "fee_rate": 100000, + "capacity": 120000 + }, + { + "node_1": "02a1d2856be336a58af08989aea0d8c41e072ccc392c46f8ce0e6e069f002035f3", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 999991, + "channel_point": "48a0e8b856fef01d9feda7d25a4fac6dae48749e28ba356b92d712ab7f5bd2d0:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1000, + "max_htlc": 120000000, + "fee_base_msat": 10000, + "fee_rate": 100000, + "capacity": 120000 + }, + { + "node_1": "02a1d2856be336a58af08989aea0d8c41e072ccc392c46f8ce0e6e069f002035f3", + "node_2": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "channel_id": 99999, + "channel_point": "05ffda8890d0a4fffe0ddca0b1932ba0415b1d5868a99515384a4e7883d96b88:0", + "channel_flags": 1, + "message_flags": 1, + 
"expiry": 1, + "min_htlc": 1000, + "max_htlc": 120000000, + "fee_base_msat": 10000, + "fee_rate": 100000, + "capacity": 120000 + }, + { + "node_1": "02a1d2856be336a58af08989aea0d8c41e072ccc392c46f8ce0e6e069f002035f3", + "node_2": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "channel_id": 99999, + "channel_point": "05ffda8890d0a4fffe0ddca0b1932ba0415b1d5868a99515384a4e7883d96b88:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1000, + "max_htlc": 120000000, + "fee_base_msat": 10000, + "fee_rate": 100000, + "capacity": 120000 + }, + { + "node_1": "026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 12345, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1000, + "max_htlc": 100000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 100000 + }, + { + "node_1": "026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 12345, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 100000 + }, + { + "node_1": "026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318", + "node_2": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "channel_id": 3495345, + "channel_point": "9f155756b33a0a6827713965babbd561b55f9520444ac5db0cf7cb2eb0deb5bc:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 110000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 110000 + }, + { + "node_1": "026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318", + "node_2": "036264734b40c9e91d3d990a8cdfbbe23b5b0b7ad3cd0e080a25dcd05d39eeb7eb", + "channel_id": 3495345, + "channel_point": "9f155756b33a0a6827713965babbd561b55f9520444ac5db0cf7cb2eb0deb5bc:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 110000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 110000 + }, + { + "node_1": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 2340213491, + "channel_point": "72cd6e8422c407fb6d098690f1130b7ded7ec2f7f5e1d30bd9d521f015363793:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 10000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 10000 + }, + { + "node_1": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 2340213491, + "channel_point": "72cd6e8422c407fb6d098690f1130b7ded7ec2f7f5e1d30bd9d521f015363793:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 10000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 10000 + }, + { + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 689530843, + "channel_point": 
"25376aa6cb81913ad30416bd22d4083241bd6d68e811d0284d3c3a17795c458a:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 10, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 100000 + }, + { + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 689530843, + "channel_point": "25376aa6cb81913ad30416bd22d4083241bd6d68e811d0284d3c3a17795c458a:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 100000 + }, + { + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 523452362, + "channel_point": "704a5675c91b1c674309a6475fc51072c2913d6117ee6103c9f1b86956bcbe02:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 50000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 50000 + }, + { + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 523452362, + "channel_point": "704a5675c91b1c674309a6475fc51072c2913d6117ee6103c9f1b86956bcbe02:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 1, + "min_htlc": 1, + "max_htlc": 50000000, + "fee_base_msat": 10, + "fee_rate": 1000, + "capacity": 50000 + } + ] +} diff --git a/graph/testdata/spec_example.json b/graph/testdata/spec_example.json new file mode 100644 index 0000000000..f0a730c3a9 --- /dev/null +++ b/graph/testdata/spec_example.json @@ -0,0 +1,147 @@ +{ + "nodes": [ + { + "source": false, + "pubkey": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "alias": "A" + }, + { + "source": true, + "pubkey": "032b480de5d002f1a8fd1fe1bbf0a0f1b07760f65f052e66d56f15d71097c01add", + "alias": "B" + }, + { + "source": false, + "pubkey": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "alias": "C" + }, + { + "source": false, + "pubkey": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "alias": "D" + } + ], + "edges": [ + { + + "comment": "A -> B channel", + "node_1": "032b480de5d002f1a8fd1fe1bbf0a0f1b07760f65f052e66d56f15d71097c01add", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 12345, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 10, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 100, + "fee_rate": 1000, + "capacity": 100000 + }, + { + "comment": "B -> A channel", + "node_1": "032b480de5d002f1a8fd1fe1bbf0a0f1b07760f65f052e66d56f15d71097c01add", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 12345, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 20, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 200, + "fee_rate": 2000, + "capacity": 100000 + }, + { + "comment": "A -> D channel", + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 12345839, + "channel_point": 
"89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 10, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 100, + "fee_rate": 1000, + "capacity": 100000 + }, + { + "comment": "D -> A channel", + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6", + "channel_id": 12345839, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 40, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 400, + "fee_rate": 4000, + "capacity": 100000 + }, + { + "comment": "D -> C channel", + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 1234583, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 40, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 400, + "fee_rate": 4000, + "capacity": 100000 + }, + { + "comment": "C -> D channel", + "node_1": "02e7b1aaac10977c38e9c61c74dc66840de211bcec3021603e7977bc5e28edabfd", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 1234583, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 30, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 300, + "fee_rate": 3000, + "capacity": 100000 + }, + { + "comment": "C -> B channel", + "node_1": "032b480de5d002f1a8fd1fe1bbf0a0f1b07760f65f052e66d56f15d71097c01add", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 1234589, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 1, + "message_flags": 1, + "expiry": 30, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 300, + "fee_rate": 3000, + "capacity": 100000 + }, + { + "comment": "B -> C channel", + "node_1": "032b480de5d002f1a8fd1fe1bbf0a0f1b07760f65f052e66d56f15d71097c01add", + "node_2": "03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99", + "channel_id": 1234589, + "channel_point": "89dc56859c6a082d15ba1a7f6cb6be3fea62e1746e2cb8497b1189155c21a233:0", + "channel_flags": 0, + "message_flags": 1, + "expiry": 20, + "min_htlc": 1, + "max_htlc": 100000000, + "fee_base_msat": 200, + "fee_rate": 2000, + "capacity": 100000 + } + ] +} diff --git a/routing/validation_barrier.go b/graph/validation_barrier.go similarity index 99% rename from routing/validation_barrier.go rename to graph/validation_barrier.go index aeef3d4b81..2f3c8c02ce 100644 --- a/routing/validation_barrier.go +++ b/graph/validation_barrier.go @@ -1,4 +1,4 @@ -package routing +package graph import ( "fmt" diff --git a/routing/validation_barrier_test.go b/graph/validation_barrier_test.go similarity index 91% rename from routing/validation_barrier_test.go rename to graph/validation_barrier_test.go index 2eda0120fc..da404443f5 100644 --- a/routing/validation_barrier_test.go +++ b/graph/validation_barrier_test.go @@ -1,12 +1,12 @@ -package routing_test +package graph_test import ( "encoding/binary" "testing" "time" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnwire" - 
"github.com/lightningnetwork/lnd/routing" ) // TestValidationBarrierSemaphore checks basic properties of the validation @@ -21,7 +21,7 @@ func TestValidationBarrierSemaphore(t *testing.T) { ) quit := make(chan struct{}) - barrier := routing.NewValidationBarrier(numTasks, quit) + barrier := graph.NewValidationBarrier(numTasks, quit) // Saturate the semaphore with jobs. for i := 0; i < numTasks; i++ { @@ -69,7 +69,7 @@ func TestValidationBarrierQuit(t *testing.T) { ) quit := make(chan struct{}) - barrier := routing.NewValidationBarrier(2*numTasks, quit) + barrier := graph.NewValidationBarrier(2*numTasks, quit) // Create a set of unique channel announcements that we will prep for // validation. @@ -141,8 +141,8 @@ func TestValidationBarrierQuit(t *testing.T) { switch { // First half should return without failure. - case i < numTasks/4 && !routing.IsError( - err, routing.ErrParentValidationFailed, + case i < numTasks/4 && !graph.IsError( + err, graph.ErrParentValidationFailed, ): t.Fatalf("unexpected failure while waiting: %v", err) @@ -150,11 +150,11 @@ func TestValidationBarrierQuit(t *testing.T) { t.Fatalf("unexpected failure while waiting: %v", err) // Last half should return the shutdown error. - case i >= numTasks/2 && !routing.IsError( - err, routing.ErrVBarrierShuttingDown, + case i >= numTasks/2 && !graph.IsError( + err, graph.ErrVBarrierShuttingDown, ): t.Fatalf("expected failure after quitting: want %v, "+ - "got %v", routing.ErrVBarrierShuttingDown, err) + "got %v", graph.ErrVBarrierShuttingDown, err) } } } diff --git a/netann/channel_update_test.go b/netann/channel_update_test.go index e49e5c65e8..7af51effc0 100644 --- a/netann/channel_update_test.go +++ b/netann/channel_update_test.go @@ -7,11 +7,11 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" - "github.com/lightningnetwork/lnd/routing" ) type mockSigner struct { @@ -182,7 +182,7 @@ func TestUpdateDisableFlag(t *testing.T) { // Finally, validate the signature using the router's // verification logic. - err = routing.VerifyChannelUpdateSignature( + err = graph.VerifyChannelUpdateSignature( newUpdate, pubKey, ) if err != nil { diff --git a/pilot.go b/pilot.go index 380dc5f092..2a37b080d0 100644 --- a/pilot.go +++ b/pilot.go @@ -295,6 +295,6 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, }, nil }, SubscribeTransactions: svr.cc.Wallet.SubscribeTransactions, - SubscribeTopology: svr.chanRouter.SubscribeTopology, + SubscribeTopology: svr.graphBuilder.SubscribeTopology, }, nil } diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index e48f3c7fbd..e4912d988b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1392,10 +1392,6 @@ func TestNewRoute(t *testing.T) { // to fail or succeed. expectError bool - // expectedErrorCode indicates the expected error code when - // expectError is true. 
- expectedErrorCode errorCode - expectedMPP *record.MPP }{ { @@ -1606,23 +1602,9 @@ func TestNewRoute(t *testing.T) { metadata: testCase.metadata, }, nil, ) + require.NoError(t, err) - if testCase.expectError { - expectedCode := testCase.expectedErrorCode - if err == nil || !IsError(err, expectedCode) { - t.Fatalf("expected newRoute to fail "+ - "with error code %v but got "+ - "%v instead", - expectedCode, err) - } - } else { - if err != nil { - t.Errorf("unable to create path: %v", err) - return - } - - assertRoute(t, route) - } + assertRoute(t, route) }) } } @@ -2232,8 +2214,8 @@ func TestPathFindSpecExample(t *testing.T) { carol := ctx.aliases["C"] const amt lnwire.MilliSatoshi = 4999999 req, err := NewRouteRequest( - bob, &carol, amt, 0, noRestrictions, nil, nil, nil, - MinCLTVDelta, + bob, &carol, amt, 0, noRestrictions, nil, nil, + nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") @@ -2244,33 +2226,18 @@ func TestPathFindSpecExample(t *testing.T) { // // It should be sending the exact payment amount as there are no // additional hops. - if route.TotalAmount != amt { - t.Fatalf("wrong total amount: got %v, expected %v", - route.TotalAmount, amt) - } - if route.Hops[0].AmtToForward != amt { - t.Fatalf("wrong forward amount: got %v, expected %v", - route.Hops[0].AmtToForward, amt) - } - - fee := route.HopFee(0) - if fee != 0 { - t.Fatalf("wrong hop fee: got %v, expected %v", fee, 0) - } + require.Equal(t, amt, route.TotalAmount) + require.Equal(t, amt, route.Hops[0].AmtToForward) + require.Zero(t, route.HopFee(0)) // The CLTV expiry should be the current height plus 18 (the expiry for // the B -> C channel. - if route.TotalTimeLock != - startingHeight+MinCLTVDelta { - - t.Fatalf("wrong total time lock: got %v, expecting %v", - route.TotalTimeLock, - startingHeight+MinCLTVDelta) - } + require.EqualValues(t, startingHeight+MinCLTVDelta, route.TotalTimeLock) // Next, we'll set A as the source node so we can assert that we create // the proper route for any queries starting with Alice. alice := ctx.aliases["A"] + ctx.router.cfg.SelfNode = alice // We'll now request a route from A -> B -> C. req, err = NewRouteRequest( @@ -2283,32 +2250,21 @@ func TestPathFindSpecExample(t *testing.T) { require.NoError(t, err, "unable to find routes") // The route should be two hops. - if len(route.Hops) != 2 { - t.Fatalf("route should be %v hops, is instead %v", 2, - len(route.Hops)) - } + require.Len(t, route.Hops, 2) // The total amount should factor in a fee of 10199 and also use a CLTV // delta total of 38 (20 + 18), expectedAmt := lnwire.MilliSatoshi(5010198) - if route.TotalAmount != expectedAmt { - t.Fatalf("wrong amount: got %v, expected %v", - route.TotalAmount, expectedAmt) - } + require.Equal(t, expectedAmt, route.TotalAmount) + expectedDelta := uint32(20 + MinCLTVDelta) - if route.TotalTimeLock != startingHeight+expectedDelta { - t.Fatalf("wrong total time lock: got %v, expecting %v", - route.TotalTimeLock, startingHeight+expectedDelta) - } + require.Equal(t, startingHeight+expectedDelta, route.TotalTimeLock) // Ensure that the hops of the route are properly crafted. // // After taking the fee, Bob should be forwarding the remainder which // is the exact payment to Bob. 
- if route.Hops[0].AmtToForward != amt { - t.Fatalf("wrong forward amount: got %v, expected %v", - route.Hops[0].AmtToForward, amt) - } + require.Equal(t, amt, route.Hops[0].AmtToForward) // We shouldn't pay any fee for the first, hop, but the fee for the // second hop posted fee should be exactly: @@ -2317,59 +2273,31 @@ func TestPathFindSpecExample(t *testing.T) { // hop, so we should get a fee of exactly: // // * 200 + 4999999 * 2000 / 1000000 = 10199 - - fee = route.HopFee(0) - if fee != 10199 { - t.Fatalf("wrong hop fee: got %v, expected %v", fee, 10199) - } + require.EqualValues(t, 10199, route.HopFee(0)) // While for the final hop, as there's no additional hop afterwards, we // pay no fee. - fee = route.HopFee(1) - if fee != 0 { - t.Fatalf("wrong hop fee: got %v, expected %v", fee, 0) - } + require.Zero(t, route.HopFee(1)) // The outgoing CLTV value itself should be the current height plus 30 // to meet Carol's requirements. - if route.Hops[0].OutgoingTimeLock != - startingHeight+MinCLTVDelta { - - t.Fatalf("wrong total time lock: got %v, expecting %v", - route.Hops[0].OutgoingTimeLock, - startingHeight+MinCLTVDelta) - } + require.EqualValues(t, startingHeight+MinCLTVDelta, + route.Hops[0].OutgoingTimeLock) // For B -> C, we assert that the final hop also has the proper // parameters. lastHop := route.Hops[1] - if lastHop.AmtToForward != amt { - t.Fatalf("wrong forward amount: got %v, expected %v", - lastHop.AmtToForward, amt) - } - if lastHop.OutgoingTimeLock != - startingHeight+MinCLTVDelta { - - t.Fatalf("wrong total time lock: got %v, expecting %v", - lastHop.OutgoingTimeLock, - startingHeight+MinCLTVDelta) - } + require.EqualValues(t, amt, lastHop.AmtToForward) + require.EqualValues(t, startingHeight+MinCLTVDelta, lastHop.OutgoingTimeLock) } func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, path []*unifiedEdge, nodeAliases ...string) { - if len(path) != len(nodeAliases) { - t.Fatalf("number of hops=(%v) and number of aliases=(%v) do "+ - "not match", len(path), len(nodeAliases)) - } + require.Len(t, path, len(nodeAliases)) for i, hop := range path { - if hop.policy.ToNodePubKey() != aliasMap[nodeAliases[i]] { - t.Fatalf("expected %v to be pos #%v in hop, instead "+ - "%v was", nodeAliases[i], i, - hop.policy.ToNodePubKey()) - } + require.Equal(t, aliasMap[nodeAliases[i]], hop.policy.ToNodePubKey()) } } @@ -2380,9 +2308,7 @@ func TestNewRouteFromEmptyHops(t *testing.T) { var source route.Vertex _, err := route.NewRouteFromHops(0, 0, source, []*route.Hop{}) - if err != route.ErrNoRouteHopsProvided { - t.Fatalf("expected empty hops error: instead got: %v", err) - } + require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) } // runRestrictOutgoingChannel asserts that a outgoing channel restriction is @@ -2425,11 +2351,6 @@ func runRestrictOutgoingChannel(t *testing.T, useCache bool) { ctx := newPathFindingTestContext(t, useCache, testChannels, "roasbeef") - const ( - startingHeight = 100 - finalHopCLTV = 1 - ) - paymentAmt := lnwire.NewMSatFromSatoshis(100) target := ctx.keyFromAlias("target") outgoingChannelID := uint64(chanSourceB1) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 8769ca5d31..5e7f8b4918 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -912,7 +912,7 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route, } // Apply channel update to the channel edge policy in our db. 
- if !p.router.applyChannelUpdate(update) { + if !p.router.cfg.ApplyChannelUpdate(update) { log.Debugf("Invalid channel update received: node=%v", errVertex) } diff --git a/routing/payment_session.go b/routing/payment_session.go index 0d46f71199..84f2135d79 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -9,6 +9,7 @@ import ( "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -412,7 +413,7 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool { // Validate the message signature. - if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil { + if err := graph.VerifyChannelUpdateSignature(msg, pubKey); err != nil { log.Errorf( "Unable to validate channel update signature: %v", err, ) diff --git a/routing/router.go b/routing/router.go index 276744d1b3..293f2c9fc7 100644 --- a/routing/router.go +++ b/routing/router.go @@ -5,41 +5,27 @@ import ( "context" "fmt" "math" - "runtime" - "strings" "sync" "sync/atomic" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/amp" - "github.com/lightningnetwork/lnd/batch" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/htlcswitch" - "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" - "github.com/lightningnetwork/lnd/lnwallet/btcwallet" - "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/multimutex" "github.com/lightningnetwork/lnd/record" - "github.com/lightningnetwork/lnd/routing/chainview" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/shards" - "github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/zpay32" ) @@ -49,21 +35,6 @@ const ( // trying more routes for a payment. DefaultPayAttemptTimeout = time.Second * 60 - // DefaultChannelPruneExpiry is the default duration used to determine - // if a channel should be pruned or not. - DefaultChannelPruneExpiry = time.Hour * 24 * 14 - - // DefaultFirstTimePruneDelay is the time we'll wait after startup - // before attempting to prune the graph for zombie channels. We don't - // do it immediately after startup to allow lnd to start up without - // getting blocked by this job. - DefaultFirstTimePruneDelay = 30 * time.Second - - // defaultStatInterval governs how often the router will log non-empty - // stats related to processing new channels, updates, or node - // announcements. - defaultStatInterval = time.Minute - // MinCLTVDelta is the minimum CLTV value accepted by LND for all // timelock deltas. 
	// This includes both forwarding CLTV deltas set on
 	// channel updates, as well as final CLTV deltas used to create BOLT 11
@@ -251,24 +222,11 @@ type Config struct {
 	// RoutingGraph is a graph source that will be used for pathfinding.
 	RoutingGraph Graph
 
-	// Graph is the channel graph that the ChannelRouter will use to gather
-	// metrics from and also to carry out path finding queries.
-	Graph graph.DB
-
 	// Chain is the router's source to the most up-to-date blockchain data.
 	// All incoming advertised channels will be checked against the chain
 	// to ensure that the channels advertised are still open.
 	Chain lnwallet.BlockChainIO
 
-	// ChainView is an instance of a FilteredChainView which is used to
-	// watch the sub-set of the UTXO set (the set of active channels) that
-	// we need in order to properly maintain the channel graph.
-	ChainView chainview.FilteredChainView
-
-	// Notifier is a reference to the ChainNotifier, used to grab
-	// the latest blocks if the router is missing any.
-	Notifier chainntnfs.ChainNotifier
-
 	// Payer is an instance of a PaymentAttemptDispatcher and is used by
 	// the router to send payment attempts onto the network, and receive
 	// their results.
@@ -291,22 +249,6 @@ type Config struct {
 	// sessions.
 	SessionSource PaymentSessionSource
 
-	// ChannelPruneExpiry is the duration used to determine if a channel
-	// should be pruned or not. If the delta between now and when the
-	// channel was last updated is greater than ChannelPruneExpiry, then
-	// the channel is marked as a zombie channel eligible for pruning.
-	ChannelPruneExpiry time.Duration
-
-	// GraphPruneInterval is used as an interval to determine how often we
-	// should examine the channel graph to garbage collect zombie channels.
-	GraphPruneInterval time.Duration
-
-	// FirstTimePruneDelay is the time we'll wait after startup before
-	// attempting to prune the graph for zombie channels. We don't do it
-	// immediately after startup to allow lnd to start up without getting
-	// blocked by this job.
-	FirstTimePruneDelay time.Duration
-
 	// QueryBandwidth is a method that allows the router to query the lower
 	// link layer to determine the up-to-date available bandwidth at a
 	// prospective link to be traversed. If the link isn't available, then
@@ -321,1534 +263,178 @@ type Config struct {
 	// the switch can properly handle the HTLC.
 	NextPaymentID func() (uint64, error)
 
-	// AssumeChannelValid toggles whether the router will check for
-	// spentness of channel outpoints. For neutrino, this saves long rescans
-	// from blocking initial usage of the daemon.
-	AssumeChannelValid bool
-
 	// PathFindingConfig defines global path finding parameters.
 	PathFindingConfig PathFindingConfig
 
 	// Clock is mockable time provider.
 	Clock clock.Clock
 
-	// StrictZombiePruning determines if we attempt to prune zombie
-	// channels according to a stricter criteria. If true, then we'll prune
-	// a channel if only *one* of the edges is considered a zombie.
-	// Otherwise, we'll only prune the channel when both edges have a very
-	// dated last update.
-	StrictZombiePruning bool
-
-	// IsAlias returns whether a passed ShortChannelID is an alias. This is
-	// only used for our local channels.
-	IsAlias func(scid lnwire.ShortChannelID) bool
+	// ApplyChannelUpdate can be called to apply a channel update, received
+	// as part of a payment failure, to the graph.
+	ApplyChannelUpdate func(msg *lnwire.ChannelUpdate) bool
 }
 
 // EdgeLocator is a struct used to identify a specific edge.
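The ApplyChannelUpdate closure added above replaces the router's own
applyChannelUpdate method (see the payment_lifecycle.go hunk earlier). The
call site that builds this Config is not part of this excerpt; presumably it
forwards to the graph builder, which now owns writes to the graph, along
these lines (hypothetical wiring):

	ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate,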
type EdgeLocator struct { - // ChannelID is the channel of this edge. - ChannelID uint64 - - // Direction takes the value of 0 or 1 and is identical in definition to - // the channel direction flag. A value of 0 means the direction from the - // lower node pubkey to the higher. - Direction uint8 -} - -// String returns a human-readable version of the edgeLocator values. -func (e *EdgeLocator) String() string { - return fmt.Sprintf("%v:%v", e.ChannelID, e.Direction) -} - -// ChannelRouter is the layer 3 router within the Lightning stack. Below the -// ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain -// itself. The primary role of the ChannelRouter is to respond to queries for -// potential routes that can support a payment amount, and also general graph -// reachability questions. The router will prune the channel graph -// automatically as new blocks are discovered which spend certain known funding -// outpoints, thereby closing their respective channels. -type ChannelRouter struct { - ntfnClientCounter uint64 // To be used atomically. - - started uint32 // To be used atomically. - stopped uint32 // To be used atomically. - - bestHeight uint32 // To be used atomically. - - // cfg is a copy of the configuration struct that the ChannelRouter was - // initialized with. - cfg *Config - - // newBlocks is a channel in which new blocks connected to the end of - // the main chain are sent over, and blocks updated after a call to - // UpdateFilter. - newBlocks <-chan *chainview.FilteredBlock - - // staleBlocks is a channel in which blocks disconnected from the end - // of our currently known best chain are sent over. - staleBlocks <-chan *chainview.FilteredBlock - - // networkUpdates is a channel that carries new topology updates - // messages from outside the ChannelRouter to be processed by the - // networkHandler. - networkUpdates chan *routingMsg - - // topologyClients maps a client's unique notification ID to a - // topologyClient client that contains its notification dispatch - // channel. - topologyClients *lnutils.SyncMap[uint64, *topologyClient] - - // ntfnClientUpdates is a channel that's used to send new updates to - // topology notification clients to the ChannelRouter. Updates either - // add a new notification client, or cancel notifications for an - // existing client. - ntfnClientUpdates chan *topologyClientUpdate - - // channelEdgeMtx is a mutex we use to make sure we process only one - // ChannelEdgePolicy at a time for a given channelID, to ensure - // consistency between the various database accesses. - channelEdgeMtx *multimutex.Mutex[uint64] - - // statTicker is a resumable ticker that logs the router's progress as - // it discovers channels or receives updates. - statTicker ticker.Ticker - - // stats tracks newly processed channels, updates, and node - // announcements over a window of defaultStatInterval. - stats *routerStats - - quit chan struct{} - wg sync.WaitGroup -} - -// A compile time check to ensure ChannelRouter implements the -// ChannelGraphSource interface. -var _ graph.ChannelGraphSource = (*ChannelRouter)(nil) - -// New creates a new instance of the ChannelRouter with the specified -// configuration parameters. As part of initialization, if the router detects -// that the channel graph isn't fully in sync with the latest UTXO (since the -// channel graph is a subset of the UTXO set) set, then the router will proceed -// to fully sync to the latest state of the UTXO set. 
-func New(cfg Config) (*ChannelRouter, error) { - r := &ChannelRouter{ - cfg: &cfg, - networkUpdates: make(chan *routingMsg), - topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, - ntfnClientUpdates: make(chan *topologyClientUpdate), - channelEdgeMtx: multimutex.NewMutex[uint64](), - statTicker: ticker.New(defaultStatInterval), - stats: new(routerStats), - quit: make(chan struct{}), - } - - return r, nil -} - -// Start launches all the goroutines the ChannelRouter requires to carry out -// its duties. If the router has already been started, then this method is a -// noop. -func (r *ChannelRouter) Start() error { - if !atomic.CompareAndSwapUint32(&r.started, 0, 1) { - return nil - } - - log.Info("Channel Router starting") - - bestHash, bestHeight, err := r.cfg.Chain.GetBestBlock() - if err != nil { - return err - } - - // If the graph has never been pruned, or hasn't fully been created yet, - // then we don't treat this as an explicit error. - if _, _, err := r.cfg.Graph.PruneTip(); err != nil { - switch { - case errors.Is(err, channeldb.ErrGraphNeverPruned): - fallthrough - - case errors.Is(err, channeldb.ErrGraphNotFound): - // If the graph has never been pruned, then we'll set - // the prune height to the current best height of the - // chain backend. - _, err = r.cfg.Graph.PruneGraph( - nil, bestHash, uint32(bestHeight), - ) - if err != nil { - return err - } - - default: - return err - } - } - - // If AssumeChannelValid is present, then we won't rely on pruning - // channels from the graph based on their spentness, but whether they - // are considered zombies or not. We will start zombie pruning after a - // small delay, to avoid slowing down startup of lnd. - if r.cfg.AssumeChannelValid { - time.AfterFunc(r.cfg.FirstTimePruneDelay, func() { - select { - case <-r.quit: - return - default: - } - - log.Info("Initial zombie prune starting") - if err := r.pruneZombieChans(); err != nil { - log.Errorf("Unable to prune zombies: %v", err) - } - }) - } else { - // Otherwise, we'll use our filtered chain view to prune - // channels as soon as they are detected as spent on-chain. - if err := r.cfg.ChainView.Start(); err != nil { - return err - } - - // Once the instance is active, we'll fetch the channel we'll - // receive notifications over. - r.newBlocks = r.cfg.ChainView.FilteredBlocks() - r.staleBlocks = r.cfg.ChainView.DisconnectedBlocks() - - // Before we perform our manual block pruning, we'll construct - // and apply a fresh chain filter to the active - // FilteredChainView instance. We do this before, as otherwise - // we may miss on-chain events as the filter hasn't properly - // been applied. - channelView, err := r.cfg.Graph.ChannelView() - if err != nil && !errors.Is( - err, channeldb.ErrGraphNoEdgesFound, - ) { - - return err - } - - log.Infof("Filtering chain using %v channels active", - len(channelView)) - - if len(channelView) != 0 { - err = r.cfg.ChainView.UpdateFilter( - channelView, uint32(bestHeight), - ) - if err != nil { - return err - } - } - - // The graph pruning might have taken a while and there could be - // new blocks available. - _, bestHeight, err = r.cfg.Chain.GetBestBlock() - if err != nil { - return err - } - r.bestHeight = uint32(bestHeight) - - // Before we begin normal operation of the router, we first need - // to synchronize the channel graph to the latest state of the - // UTXO set. 
- if err := r.syncGraphWithChain(); err != nil { - return err - } - - // Finally, before we proceed, we'll prune any unconnected nodes - // from the graph in order to ensure we maintain a tight graph - // of "useful" nodes. - err = r.cfg.Graph.PruneGraphNodes() - if err != nil && !errors.Is( - err, channeldb.ErrGraphNodesNotFound, - ) { - - return err - } - } - - // If any payments are still in flight, we resume, to make sure their - // results are properly handled. - payments, err := r.cfg.Control.FetchInFlightPayments() - if err != nil { - return err - } - - // Before we restart existing payments and start accepting more - // payments to be made, we clean the network result store of the - // Switch. We do this here at startup to ensure no more payments can be - // made concurrently, so we know the toKeep map will be up-to-date - // until the cleaning has finished. - toKeep := make(map[uint64]struct{}) - for _, p := range payments { - for _, a := range p.HTLCs { - toKeep[a.AttemptID] = struct{}{} - } - } - - log.Debugf("Cleaning network result store.") - if err := r.cfg.Payer.CleanStore(toKeep); err != nil { - return err - } - - for _, payment := range payments { - log.Infof("Resuming payment %v", payment.Info.PaymentIdentifier) - r.wg.Add(1) - go func(payment *channeldb.MPPayment) { - defer r.wg.Done() - - // Get the hashes used for the outstanding HTLCs. - htlcs := make(map[uint64]lntypes.Hash) - for _, a := range payment.HTLCs { - a := a - - // We check whether the individual attempts - // have their HTLC hash set, if not we'll fall - // back to the overall payment hash. - hash := payment.Info.PaymentIdentifier - if a.Hash != nil { - hash = *a.Hash - } - - htlcs[a.AttemptID] = hash - } - - // Since we are not supporting creating more shards - // after a restart (only receiving the result of the - // shards already outstanding), we create a simple - // shard tracker that will map the attempt IDs to - // hashes used for the HTLCs. This will be enough also - // for AMP payments, since we only need the hashes for - // the individual HTLCs to regenerate the circuits, and - // we don't currently persist the root share necessary - // to re-derive them. - shardTracker := shards.NewSimpleShardTracker( - payment.Info.PaymentIdentifier, htlcs, - ) - - // We create a dummy, empty payment session such that - // we won't make another payment attempt when the - // result for the in-flight attempt is received. - paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() - - // We pass in a non-timeout context, to indicate we - // don't need it to timeout. It will stop immediately - // after the existing attempt has finished anyway. We - // also set a zero fee limit, as no more routes should - // be tried. - noTimeout := time.Duration(0) - _, _, err := r.sendPayment( - context.Background(), 0, - payment.Info.PaymentIdentifier, noTimeout, - paySession, shardTracker, - ) - if err != nil { - log.Errorf("Resuming payment %v failed: %v.", - payment.Info.PaymentIdentifier, err) - return - } - - log.Infof("Resumed payment %v completed.", - payment.Info.PaymentIdentifier) - }(payment) - } - - r.wg.Add(1) - go r.networkHandler() - - return nil -} - -// Stop signals the ChannelRouter to gracefully halt all routines. This method -// will *block* until all goroutines have excited. If the channel router has -// already stopped then this method will return immediately. 
-func (r *ChannelRouter) Stop() error { - if !atomic.CompareAndSwapUint32(&r.stopped, 0, 1) { - return nil - } - - log.Info("Channel Router shutting down...") - defer log.Debug("Channel Router shutdown complete") - - // Our filtered chain view could've only been started if - // AssumeChannelValid isn't present. - if !r.cfg.AssumeChannelValid { - if err := r.cfg.ChainView.Stop(); err != nil { - return err - } - } - - close(r.quit) - r.wg.Wait() - - return nil -} - -// syncGraphWithChain attempts to synchronize the current channel graph with -// the latest UTXO set state. This process involves pruning from the channel -// graph any channels which have been closed by spending their funding output -// since we've been down. -func (r *ChannelRouter) syncGraphWithChain() error { - // First, we'll need to check to see if we're already in sync with the - // latest state of the UTXO set. - bestHash, bestHeight, err := r.cfg.Chain.GetBestBlock() - if err != nil { - return err - } - r.bestHeight = uint32(bestHeight) - - pruneHash, pruneHeight, err := r.cfg.Graph.PruneTip() - if err != nil { - switch { - // If the graph has never been pruned, or hasn't fully been - // created yet, then we don't treat this as an explicit error. - case errors.Is(err, channeldb.ErrGraphNeverPruned): - case errors.Is(err, channeldb.ErrGraphNotFound): - default: - return err - } - } - - log.Infof("Prune tip for Channel Graph: height=%v, hash=%v", - pruneHeight, pruneHash) - - switch { - - // If the graph has never been pruned, then we can exit early as this - // entails it's being created for the first time and hasn't seen any - // block or created channels. - case pruneHeight == 0 || pruneHash == nil: - return nil - - // If the block hashes and heights match exactly, then we don't need to - // prune the channel graph as we're already fully in sync. - case bestHash.IsEqual(pruneHash) && uint32(bestHeight) == pruneHeight: - return nil - } - - // If the main chain blockhash at prune height is different from the - // prune hash, this might indicate the database is on a stale branch. - mainBlockHash, err := r.cfg.Chain.GetBlockHash(int64(pruneHeight)) - if err != nil { - return err - } - - // While we are on a stale branch of the chain, walk backwards to find - // first common block. - for !pruneHash.IsEqual(mainBlockHash) { - log.Infof("channel graph is stale. Disconnecting block %v "+ - "(hash=%v)", pruneHeight, pruneHash) - // Prune the graph for every channel that was opened at height - // >= pruneHeight. - _, err := r.cfg.Graph.DisconnectBlockAtHeight(pruneHeight) - if err != nil { - return err - } - - pruneHash, pruneHeight, err = r.cfg.Graph.PruneTip() - if err != nil { - switch { - // If at this point the graph has never been pruned, we - // can exit as this entails we are back to the point - // where it hasn't seen any block or created channels, - // alas there's nothing left to prune. - case errors.Is(err, channeldb.ErrGraphNeverPruned): - return nil - - case errors.Is(err, channeldb.ErrGraphNotFound): - return nil - - default: - return err - } - } - mainBlockHash, err = r.cfg.Chain.GetBlockHash(int64(pruneHeight)) - if err != nil { - return err - } - } - - log.Infof("Syncing channel graph from height=%v (hash=%v) to height=%v "+ - "(hash=%v)", pruneHeight, pruneHash, bestHeight, bestHash) - - // If we're not yet caught up, then we'll walk forward in the chain - // pruning the channel graph with each new block that hasn't yet been - // consumed by the channel graph. 
- var spentOutputs []*wire.OutPoint - for nextHeight := pruneHeight + 1; nextHeight <= uint32(bestHeight); nextHeight++ { - // Break out of the rescan early if a shutdown has been - // requested, otherwise long rescans will block the daemon from - // shutting down promptly. - select { - case <-r.quit: - return ErrRouterShuttingDown - default: - } - - // Using the next height, request a manual block pruning from - // the chainview for the particular block hash. - log.Infof("Filtering block for closed channels, at height: %v", - int64(nextHeight)) - nextHash, err := r.cfg.Chain.GetBlockHash(int64(nextHeight)) - if err != nil { - return err - } - log.Tracef("Running block filter on block with hash: %v", - nextHash) - filterBlock, err := r.cfg.ChainView.FilterBlock(nextHash) - if err != nil { - return err - } - - // We're only interested in all prior outputs that have been - // spent in the block, so collate all the referenced previous - // outpoints within each tx and input. - for _, tx := range filterBlock.Transactions { - for _, txIn := range tx.TxIn { - spentOutputs = append(spentOutputs, - &txIn.PreviousOutPoint) - } - } - } - - // With the spent outputs gathered, attempt to prune the channel graph, - // also passing in the best hash+height so the prune tip can be updated. - closedChans, err := r.cfg.Graph.PruneGraph( - spentOutputs, bestHash, uint32(bestHeight), - ) - if err != nil { - return err - } - - log.Infof("Graph pruning complete: %v channels were closed since "+ - "height %v", len(closedChans), pruneHeight) - return nil -} - -// isZombieChannel takes two edge policy updates and determines if the -// corresponding channel should be considered a zombie. The first boolean is -// true if the policy update from node 1 is considered a zombie, the second -// boolean is that of node 2, and the final boolean is true if the channel -// is considered a zombie. -func (r *ChannelRouter) isZombieChannel(e1, - e2 *models.ChannelEdgePolicy) (bool, bool, bool) { - - chanExpiry := r.cfg.ChannelPruneExpiry - - e1Zombie := e1 == nil || time.Since(e1.LastUpdate) >= chanExpiry - e2Zombie := e2 == nil || time.Since(e2.LastUpdate) >= chanExpiry - - var e1Time, e2Time time.Time - if e1 != nil { - e1Time = e1.LastUpdate - } - if e2 != nil { - e2Time = e2.LastUpdate - } - - return e1Zombie, e2Zombie, r.IsZombieChannel(e1Time, e2Time) -} - -// IsZombieChannel takes the timestamps of the latest channel updates for a -// channel and returns true if the channel should be considered a zombie based -// on these timestamps. -func (r *ChannelRouter) IsZombieChannel(updateTime1, - updateTime2 time.Time) bool { - - chanExpiry := r.cfg.ChannelPruneExpiry - - e1Zombie := updateTime1.IsZero() || - time.Since(updateTime1) >= chanExpiry - - e2Zombie := updateTime2.IsZero() || - time.Since(updateTime2) >= chanExpiry - - // If we're using strict zombie pruning, then a channel is only - // considered live if both edges have a recent update we know of. - if r.cfg.StrictZombiePruning { - return e1Zombie || e2Zombie - } - - // Otherwise, if we're using the less strict variant, then a channel is - // considered live if either of the edges have a recent update. - return e1Zombie && e2Zombie -} - -// pruneZombieChans is a method that will be called periodically to prune out -// any "zombie" channels. We consider channels zombies if *both* edges haven't -// been updated since our zombie horizon. If AssumeChannelValid is present, -// we'll also consider channels zombies if *both* edges are disabled. 
This -// usually signals that a channel has been closed on-chain. We do this -// periodically to keep a healthy, lively routing table. -func (r *ChannelRouter) pruneZombieChans() error { - chansToPrune := make(map[uint64]struct{}) - chanExpiry := r.cfg.ChannelPruneExpiry - - log.Infof("Examining channel graph for zombie channels") - - // A helper method to detect if the channel belongs to this node - isSelfChannelEdge := func(info *models.ChannelEdgeInfo) bool { - return info.NodeKey1Bytes == r.cfg.SelfNode || - info.NodeKey2Bytes == r.cfg.SelfNode - } - - // First, we'll collect all the channels which are eligible for garbage - // collection due to being zombies. - filterPruneChans := func(info *models.ChannelEdgeInfo, - e1, e2 *models.ChannelEdgePolicy) error { - - // Exit early in case this channel is already marked to be - // pruned - _, markedToPrune := chansToPrune[info.ChannelID] - if markedToPrune { - return nil - } - - // We'll ensure that we don't attempt to prune our *own* - // channels from the graph, as in any case this should be - // re-advertised by the sub-system above us. - if isSelfChannelEdge(info) { - return nil - } - - e1Zombie, e2Zombie, isZombieChan := r.isZombieChannel(e1, e2) - - if e1Zombie { - log.Tracef("Node1 pubkey=%x of chan_id=%v is zombie", - info.NodeKey1Bytes, info.ChannelID) - } - - if e2Zombie { - log.Tracef("Node2 pubkey=%x of chan_id=%v is zombie", - info.NodeKey2Bytes, info.ChannelID) - } - - // If either edge hasn't been updated for a period of - // chanExpiry, then we'll mark the channel itself as eligible - // for graph pruning. - if !isZombieChan { - return nil - } - - log.Debugf("ChannelID(%v) is a zombie, collecting to prune", - info.ChannelID) - - // TODO(roasbeef): add ability to delete single directional edge - chansToPrune[info.ChannelID] = struct{}{} - - return nil - } - - // If AssumeChannelValid is present we'll look at the disabled bit for - // both edges. If they're both disabled, then we can interpret this as - // the channel being closed and can prune it from our graph. - if r.cfg.AssumeChannelValid { - disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs() - if err != nil { - return fmt.Errorf("unable to get disabled channels "+ - "ids chans: %v", err) - } - - disabledEdges, err := r.cfg.Graph.FetchChanInfos( - disabledChanIDs, - ) - if err != nil { - return fmt.Errorf("unable to fetch disabled channels "+ - "edges chans: %v", err) - } - - // Ensuring we won't prune our own channel from the graph. - for _, disabledEdge := range disabledEdges { - if !isSelfChannelEdge(disabledEdge.Info) { - chansToPrune[disabledEdge.Info.ChannelID] = - struct{}{} - } - } - } - - startTime := time.Unix(0, 0) - endTime := time.Now().Add(-1 * chanExpiry) - oldEdges, err := r.cfg.Graph.ChanUpdatesInHorizon(startTime, endTime) - if err != nil { - return fmt.Errorf("unable to fetch expired channel updates "+ - "chans: %v", err) - } - - for _, u := range oldEdges { - err = filterPruneChans(u.Info, u.Policy1, u.Policy2) - if err != nil { - log.Warnf("Filter pruning channels: %w\n", err) - } - } - - log.Infof("Pruning %v zombie channels", len(chansToPrune)) - if len(chansToPrune) == 0 { - return nil - } - - // With the set of zombie-like channels obtained, we'll do another pass - // to delete them from the channel graph. 
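The strict/lenient distinction described above reduces to a single boolean switch over the two per-edge staleness flags. A self-contained restatement of the predicate from IsZombieChannel (illustration only, not lnd code):

package main

import (
	"fmt"
	"time"
)

// isZombie restates the pruning predicate: an edge is individually "zombie"
// if it has never been updated or its last update is older than chanExpiry.
// Under strict pruning one stale edge is enough to prune the channel;
// otherwise both edges must be stale.
func isZombie(update1, update2 time.Time, chanExpiry time.Duration,
	strict bool) bool {

	e1Zombie := update1.IsZero() || time.Since(update1) >= chanExpiry
	e2Zombie := update2.IsZero() || time.Since(update2) >= chanExpiry

	if strict {
		return e1Zombie || e2Zombie
	}

	return e1Zombie && e2Zombie
}

func main() {
	expiry := 14 * 24 * time.Hour // matches DefaultChannelPruneExpiry
	fresh := time.Now()
	stale := time.Now().Add(-15 * 24 * time.Hour)

	// One stale edge: pruned under strict pruning only.
	fmt.Println(isZombie(fresh, stale, expiry, true))  // true
	fmt.Println(isZombie(fresh, stale, expiry, false)) // false
}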
- toPrune := make([]uint64, 0, len(chansToPrune)) - for chanID := range chansToPrune { - toPrune = append(toPrune, chanID) - log.Tracef("Pruning zombie channel with ChannelID(%v)", chanID) - } - err = r.cfg.Graph.DeleteChannelEdges( - r.cfg.StrictZombiePruning, true, toPrune..., - ) - if err != nil { - return fmt.Errorf("unable to delete zombie channels: %w", err) - } - - // With the channels pruned, we'll also attempt to prune any nodes that - // were a part of them. - err = r.cfg.Graph.PruneGraphNodes() - if err != nil && !errors.Is(err, channeldb.ErrGraphNodesNotFound) { - return fmt.Errorf("unable to prune graph nodes: %w", err) - } - - return nil -} - -// handleNetworkUpdate is responsible for processing the update message and -// notifies topology changes, if any. -// -// NOTE: must be run inside goroutine. -func (r *ChannelRouter) handleNetworkUpdate(vb *ValidationBarrier, - update *routingMsg) { - - defer r.wg.Done() - defer vb.CompleteJob() - - // If this message has an existing dependency, then we'll wait until - // that has been fully validated before we proceed. - err := vb.WaitForDependants(update.msg) - if err != nil { - switch { - case IsError(err, ErrVBarrierShuttingDown): - update.err <- err - - case IsError(err, ErrParentValidationFailed): - update.err <- newErrf(ErrIgnored, err.Error()) - - default: - log.Warnf("unexpected error during validation "+ - "barrier shutdown: %v", err) - update.err <- err - } - - return - } - - // Process the routing update to determine if this is either a new - // update from our PoV or an update to a prior vertex/edge we - // previously accepted. - err = r.processUpdate(update.msg, update.op...) - update.err <- err - - // If this message had any dependencies, then we can now signal them to - // continue. - allowDependents := err == nil || IsError(err, ErrIgnored, ErrOutdated) - vb.SignalDependants(update.msg, allowDependents) - - // If the error is not nil here, there's no need to send topology - // change. - if err != nil { - // We now decide to log an error or not. If allowDependents is - // false, it means there is an error and the error is neither - // ErrIgnored nor ErrOutdated. In this case, we'll log an error. - // Otherwise, we'll add debug log only. - if allowDependents { - log.Debugf("process network updates got: %v", err) - } else { - log.Errorf("process network updates got: %v", err) - } - - return - } - - // Otherwise, we'll send off a new notification for the newly accepted - // update, if any. - topChange := &TopologyChange{} - err = addToTopologyChange(r.cfg.Graph, topChange, update.msg) - if err != nil { - log.Errorf("unable to update topology change notification: %v", - err) - return - } - - if !topChange.isEmpty() { - r.notifyTopologyChange(topChange) - } -} - -// networkHandler is the primary goroutine for the ChannelRouter. The roles of -// this goroutine include answering queries related to the state of the -// network, pruning the graph on new block notification, applying network -// updates, and registering new topology clients. -// -// NOTE: This MUST be run as a goroutine. -func (r *ChannelRouter) networkHandler() { - defer r.wg.Done() - - graphPruneTicker := time.NewTicker(r.cfg.GraphPruneInterval) - defer graphPruneTicker.Stop() - - defer r.statTicker.Stop() - - r.stats.Reset() - - // We'll use this validation barrier to ensure that we process all jobs - // in the proper order during parallel validation. 
- // - // NOTE: For AssumeChannelValid, we bump up the maximum number of - // concurrent validation requests since there are no blocks being - // fetched. This significantly increases the performance of IGD for - // neutrino nodes. - // - // However, we dial back to use multiple of the number of cores when - // fully validating, to avoid fetching up to 1000 blocks from the - // backend. On bitcoind, this will empirically cause massive latency - // spikes when executing this many concurrent RPC calls. Critical - // subsystems or basic rpc calls that rely on calls such as GetBestBlock - // will hang due to excessive load. - // - // See https://github.com/lightningnetwork/lnd/issues/4892. - var validationBarrier *ValidationBarrier - if r.cfg.AssumeChannelValid { - validationBarrier = NewValidationBarrier(1000, r.quit) - } else { - validationBarrier = NewValidationBarrier( - 4*runtime.NumCPU(), r.quit, - ) - } - - for { - - // If there are stats, resume the statTicker. - if !r.stats.Empty() { - r.statTicker.Resume() - } - - select { - // A new fully validated network update has just arrived. As a - // result we'll modify the channel graph accordingly depending - // on the exact type of the message. - case update := <-r.networkUpdates: - // We'll set up any dependants, and wait until a free - // slot for this job opens up, this allows us to not - // have thousands of goroutines active. - validationBarrier.InitJobDependencies(update.msg) - - r.wg.Add(1) - go r.handleNetworkUpdate(validationBarrier, update) - - // TODO(roasbeef): remove all unconnected vertexes - // after N blocks pass with no corresponding - // announcements. - - case chainUpdate, ok := <-r.staleBlocks: - // If the channel has been closed, then this indicates - // the daemon is shutting down, so we exit ourselves. - if !ok { - return - } - - // Since this block is stale, we update our best height - // to the previous block. - blockHeight := chainUpdate.Height - atomic.StoreUint32(&r.bestHeight, blockHeight-1) - - // Update the channel graph to reflect that this block - // was disconnected. - _, err := r.cfg.Graph.DisconnectBlockAtHeight(blockHeight) - if err != nil { - log.Errorf("unable to prune graph with stale "+ - "block: %v", err) - continue - } - - // TODO(halseth): notify client about the reorg? - - // A new block has arrived, so we can prune the channel graph - // of any channels which were closed in the block. - case chainUpdate, ok := <-r.newBlocks: - // If the channel has been closed, then this indicates - // the daemon is shutting down, so we exit ourselves. - if !ok { - return - } - - // We'll ensure that any new blocks received attach - // directly to the end of our main chain. If not, then - // we've somehow missed some blocks. Here we'll catch - // up the chain with the latest blocks. 
- currentHeight := atomic.LoadUint32(&r.bestHeight) - switch { - case chainUpdate.Height == currentHeight+1: - err := r.updateGraphWithClosedChannels( - chainUpdate, - ) - if err != nil { - log.Errorf("unable to prune graph "+ - "with closed channels: %v", err) - } - - case chainUpdate.Height > currentHeight+1: - log.Errorf("out of order block: expecting "+ - "height=%v, got height=%v", - currentHeight+1, chainUpdate.Height) - - err := r.getMissingBlocks(currentHeight, chainUpdate) - if err != nil { - log.Errorf("unable to retrieve missing"+ - "blocks: %v", err) - } - - case chainUpdate.Height < currentHeight+1: - log.Errorf("out of order block: expecting "+ - "height=%v, got height=%v", - currentHeight+1, chainUpdate.Height) - - log.Infof("Skipping channel pruning since "+ - "received block height %v was already"+ - " processed.", chainUpdate.Height) - } - - // A new notification client update has arrived. We're either - // gaining a new client, or cancelling notifications for an - // existing client. - case ntfnUpdate := <-r.ntfnClientUpdates: - clientID := ntfnUpdate.clientID - - if ntfnUpdate.cancel { - client, ok := r.topologyClients.LoadAndDelete( - clientID, - ) - if ok { - close(client.exit) - client.wg.Wait() - - close(client.ntfnChan) - } - - continue - } - - r.topologyClients.Store(clientID, &topologyClient{ - ntfnChan: ntfnUpdate.ntfnChan, - exit: make(chan struct{}), - }) - - // The graph prune ticker has ticked, so we'll examine the - // state of the known graph to filter out any zombie channels - // for pruning. - case <-graphPruneTicker.C: - if err := r.pruneZombieChans(); err != nil { - log.Errorf("Unable to prune zombies: %v", err) - } - - // Log any stats if we've processed a non-empty number of - // channels, updates, or nodes. We'll only pause the ticker if - // the last window contained no updates to avoid resuming and - // pausing while consecutive windows contain new info. - case <-r.statTicker.Ticks(): - if !r.stats.Empty() { - log.Infof(r.stats.String()) - } else { - r.statTicker.Pause() - } - r.stats.Reset() - - // The router has been signalled to exit, to we exit our main - // loop so the wait group can be decremented. - case <-r.quit: - return - } - } -} - -// getMissingBlocks walks through all missing blocks and updates the graph -// closed channels accordingly. -func (r *ChannelRouter) getMissingBlocks(currentHeight uint32, - chainUpdate *chainview.FilteredBlock) error { - - outdatedHash, err := r.cfg.Chain.GetBlockHash(int64(currentHeight)) - if err != nil { - return err - } - - outdatedBlock := &chainntnfs.BlockEpoch{ - Height: int32(currentHeight), - Hash: outdatedHash, - } - - epochClient, err := r.cfg.Notifier.RegisterBlockEpochNtfn( - outdatedBlock, - ) - if err != nil { - return err - } - defer epochClient.Cancel() - - blockDifference := int(chainUpdate.Height - currentHeight) - - // We'll walk through all the outdated blocks and make sure we're able - // to update the graph with any closed channels from them. 
- for i := 0; i < blockDifference; i++ { - var ( - missingBlock *chainntnfs.BlockEpoch - ok bool - ) - - select { - case missingBlock, ok = <-epochClient.Epochs: - if !ok { - return nil - } - - case <-r.quit: - return nil - } - - filteredBlock, err := r.cfg.ChainView.FilterBlock( - missingBlock.Hash, - ) - if err != nil { - return err - } - - err = r.updateGraphWithClosedChannels( - filteredBlock, - ) - if err != nil { - return err - } - } - - return nil -} - -// updateGraphWithClosedChannels prunes the channel graph of closed channels -// that are no longer needed. -func (r *ChannelRouter) updateGraphWithClosedChannels( - chainUpdate *chainview.FilteredBlock) error { - - // Once a new block arrives, we update our running track of the height - // of the chain tip. - blockHeight := chainUpdate.Height - - atomic.StoreUint32(&r.bestHeight, blockHeight) - log.Infof("Pruning channel graph using block %v (height=%v)", - chainUpdate.Hash, blockHeight) - - // We're only interested in all prior outputs that have been spent in - // the block, so collate all the referenced previous outpoints within - // each tx and input. - var spentOutputs []*wire.OutPoint - for _, tx := range chainUpdate.Transactions { - for _, txIn := range tx.TxIn { - spentOutputs = append(spentOutputs, - &txIn.PreviousOutPoint) - } - } - - // With the spent outputs gathered, attempt to prune the channel graph, - // also passing in the hash+height of the block being pruned so the - // prune tip can be updated. - chansClosed, err := r.cfg.Graph.PruneGraph(spentOutputs, - &chainUpdate.Hash, chainUpdate.Height) - if err != nil { - log.Errorf("unable to prune routing table: %v", err) - return err - } - - log.Infof("Block %v (height=%v) closed %v channels", chainUpdate.Hash, - blockHeight, len(chansClosed)) - - if len(chansClosed) == 0 { - return err - } - - // Notify all currently registered clients of the newly closed channels. - closeSummaries := createCloseSummaries(blockHeight, chansClosed...) - r.notifyTopologyChange(&TopologyChange{ - ClosedChannels: closeSummaries, - }) - - return nil -} - -// assertNodeAnnFreshness returns a non-nil error if we have an announcement in -// the database for the passed node with a timestamp newer than the passed -// timestamp. ErrIgnored will be returned if we already have the node, and -// ErrOutdated will be returned if we have a timestamp that's after the new -// timestamp. -func (r *ChannelRouter) assertNodeAnnFreshness(node route.Vertex, - msgTimestamp time.Time) error { - - // If we are not already aware of this node, it means that we don't - // know about any channel using this node. To avoid a DoS attack by - // node announcements, we will ignore such nodes. If we do know about - // this node, check that this update brings info newer than what we - // already have. - lastUpdate, exists, err := r.cfg.Graph.HasLightningNode(node) - if err != nil { - return errors.Errorf("unable to query for the "+ - "existence of node: %v", err) - } - if !exists { - return newErrf(ErrIgnored, "Ignoring node announcement"+ - " for node not found in channel graph (%x)", - node[:]) - } - - // If we've reached this point then we're aware of the vertex being - // advertised. So we now check if the new message has a new time stamp, - // if not then we won't accept the new data as it would override newer - // data. 
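One subtlety in the guard that follows: because it uses !lastUpdate.Before(msgTimestamp), an announcement whose timestamp merely equals the stored one is also rejected as outdated; only a strictly newer timestamp is accepted. A tiny self-contained illustration:

package main

import (
	"fmt"
	"time"
)

func main() {
	lastUpdate := time.Date(2024, time.June, 14, 12, 0, 0, 0, time.UTC)

	older := lastUpdate.Add(-time.Minute)
	same := lastUpdate
	newer := lastUpdate.Add(time.Minute)

	// Mirrors the assertNodeAnnFreshness guard: a message is rejected as
	// outdated unless its timestamp is strictly newer than what we have.
	for _, ts := range []time.Time{older, same, newer} {
		outdated := !lastUpdate.Before(ts)
		fmt.Printf("msg=%v outdated=%v\n", ts, outdated)
	}
	// Older and same are reported as outdated; only newer is accepted.
}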
- if !lastUpdate.Before(msgTimestamp) { - return newErrf(ErrOutdated, "Ignoring outdated "+ - "announcement for %x", node[:]) - } - - return nil -} - -// addZombieEdge adds a channel that failed complete validation into the zombie -// index, so we can avoid having to re-validate it in the future. -func (r *ChannelRouter) addZombieEdge(chanID uint64) error { - // If the edge fails validation we'll mark the edge itself as a zombie, - // so we don't continue to request it. We use the "zero key" for both - // node pubkeys so this edge can't be resurrected. - var zeroKey [33]byte - err := r.cfg.Graph.MarkEdgeZombie(chanID, zeroKey, zeroKey) - if err != nil { - return fmt.Errorf("unable to mark spent chan(id=%v) as a "+ - "zombie: %w", chanID, err) - } - - return nil -} - -// makeFundingScript is used to make the funding script for both segwit v0 and -// segwit v1 (taproot) channels. -// -// TODO(roasbeef: export and use elsewhere? -func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, - chanFeatures []byte) ([]byte, error) { - - legacyFundingScript := func() ([]byte, error) { - witnessScript, err := input.GenMultiSigScript( - bitcoinKey1, bitcoinKey2, - ) - if err != nil { - return nil, err - } - pkScript, err := input.WitnessScriptHash(witnessScript) - if err != nil { - return nil, err - } - - return pkScript, nil - } - - if len(chanFeatures) == 0 { - return legacyFundingScript() - } - - // In order to make the correct funding script, we'll need to parse the - // chanFeatures bytes into a feature vector we can interact with. - rawFeatures := lnwire.NewRawFeatureVector() - err := rawFeatures.Decode(bytes.NewReader(chanFeatures)) - if err != nil { - return nil, fmt.Errorf("unable to parse chan feature "+ - "bits: %w", err) - } - - chanFeatureBits := lnwire.NewFeatureVector( - rawFeatures, lnwire.Features, - ) - if chanFeatureBits.HasFeature( - lnwire.SimpleTaprootChannelsOptionalStaging, - ) { - - pubKey1, err := btcec.ParsePubKey(bitcoinKey1) - if err != nil { - return nil, err - } - pubKey2, err := btcec.ParsePubKey(bitcoinKey2) - if err != nil { - return nil, err - } - - fundingScript, _, err := input.GenTaprootFundingScript( - pubKey1, pubKey2, 0, - ) - if err != nil { - return nil, err - } - - return fundingScript, nil - } - - return legacyFundingScript() -} - -// processUpdate processes a new relate authenticated channel/edge, node or -// channel/edge update network update. If the update didn't affect the internal -// state of the draft due to either being out of date, invalid, or redundant, -// then error is returned. -func (r *ChannelRouter) processUpdate(msg interface{}, - op ...batch.SchedulerOption) error { - - switch msg := msg.(type) { - case *channeldb.LightningNode: - // Before we add the node to the database, we'll check to see - // if the announcement is "fresh" or not. If it isn't, then - // we'll return an error. - err := r.assertNodeAnnFreshness(msg.PubKeyBytes, msg.LastUpdate) - if err != nil { - return err - } - - if err := r.cfg.Graph.AddLightningNode(msg, op...); err != nil { - return errors.Errorf("unable to add node %x to the "+ - "graph: %v", msg.PubKeyBytes, err) - } - - log.Tracef("Updated vertex data for node=%x", msg.PubKeyBytes) - r.stats.incNumNodeUpdates() - - case *models.ChannelEdgeInfo: - log.Debugf("Received ChannelEdgeInfo for channel %v", - msg.ChannelID) - - // Prior to processing the announcement we first check if we - // already know of this channel, if so, then we can exit early. 
- _, _, exists, isZombie, err := r.cfg.Graph.HasChannelEdge( - msg.ChannelID, - ) - if err != nil && !errors.Is( - err, channeldb.ErrGraphNoEdgesFound, - ) { - - return errors.Errorf("unable to check for edge "+ - "existence: %v", err) - } - if isZombie { - return newErrf(ErrIgnored, "ignoring msg for zombie "+ - "chan_id=%v", msg.ChannelID) - } - if exists { - return newErrf(ErrIgnored, "ignoring msg for known "+ - "chan_id=%v", msg.ChannelID) - } - - // If AssumeChannelValid is present, then we are unable to - // perform any of the expensive checks below, so we'll - // short-circuit our path straight to adding the edge to our - // graph. If the passed ShortChannelID is an alias, then we'll - // skip validation as it will not map to a legitimate tx. This - // is not a DoS vector as only we can add an alias - // ChannelAnnouncement from the gossiper. - scid := lnwire.NewShortChanIDFromInt(msg.ChannelID) - if r.cfg.AssumeChannelValid || r.cfg.IsAlias(scid) { - if err := r.cfg.Graph.AddChannelEdge(msg, op...); err != nil { - return fmt.Errorf("unable to add edge: %w", err) - } - log.Tracef("New channel discovered! Link "+ - "connects %x and %x with ChannelID(%v)", - msg.NodeKey1Bytes, msg.NodeKey2Bytes, - msg.ChannelID) - r.stats.incNumEdgesDiscovered() - - break - } - - // Before we can add the channel to the channel graph, we need - // to obtain the full funding outpoint that's encoded within - // the channel ID. - channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) - fundingTx, err := r.fetchFundingTxWrapper(&channelID) - if err != nil { - // In order to ensure we don't erroneously mark a - // channel as a zombie due to an RPC failure, we'll - // attempt to string match for the relevant errors. - // - // * btcd: - // * https://github.com/btcsuite/btcd/blob/master/rpcserver.go#L1316 - // * https://github.com/btcsuite/btcd/blob/master/rpcserver.go#L1086 - // * bitcoind: - // * https://github.com/bitcoin/bitcoin/blob/7fcf53f7b4524572d1d0c9a5fdc388e87eb02416/src/rpc/blockchain.cpp#L770 - // * https://github.com/bitcoin/bitcoin/blob/7fcf53f7b4524572d1d0c9a5fdc388e87eb02416/src/rpc/blockchain.cpp#L954 - switch { - case strings.Contains(err.Error(), "not found"): - fallthrough - - case strings.Contains(err.Error(), "out of range"): - // If the funding transaction isn't found at - // all, then we'll mark the edge itself as a - // zombie, so we don't continue to request it. - // We use the "zero key" for both node pubkeys - // so this edge can't be resurrected. - zErr := r.addZombieEdge(msg.ChannelID) - if zErr != nil { - return zErr - } - - default: - } - - return newErrf(ErrNoFundingTransaction, "unable to "+ - "locate funding tx: %v", err) - } - - // Recreate witness output to be sure that declared in channel - // edge bitcoin keys and channel value corresponds to the - // reality. - fundingPkScript, err := makeFundingScript( - msg.BitcoinKey1Bytes[:], msg.BitcoinKey2Bytes[:], - msg.Features, - ) - if err != nil { - return err - } - - // Next we'll validate that this channel is actually - // well-formed. If this check fails, then this channel either - // doesn't exist, or isn't the one that was meant to be created - // according to the passed channel proofs. - fundingPoint, err := chanvalidate.Validate(&chanvalidate.Context{ - Locator: &chanvalidate.ShortChanIDChanLocator{ - ID: channelID, - }, - MultiSigPkScript: fundingPkScript, - FundingTx: fundingTx, - }) - if err != nil { - // Mark the edge as a zombie, so we won't try to - // re-validate it on start up. 
-		if err := r.addZombieEdge(msg.ChannelID); err != nil {
-			return err
-		}
-
-			return newErrf(ErrInvalidFundingOutput, "output "+
-				"failed validation: %w", err)
-		}
-
-		// Now that we have the funding outpoint of the channel, ensure
-		// that it hasn't yet been spent. If so, then this channel has
-		// been closed, so we'll ignore it.
-		chanUtxo, err := r.cfg.Chain.GetUtxo(
-			fundingPoint, fundingPkScript, channelID.BlockHeight,
-			r.quit,
-		)
-		if err != nil {
-			if errors.Is(err, btcwallet.ErrOutputSpent) {
-				zErr := r.addZombieEdge(msg.ChannelID)
-				if zErr != nil {
-					return zErr
-				}
-			}
-
-			return newErrf(ErrChannelSpent, "unable to fetch utxo "+
-				"for chan_id=%v, chan_point=%v: %v",
-				msg.ChannelID, fundingPoint, err)
-		}
-
-		// TODO(roasbeef): this is a hack, needs to be removed
-		// after commitment fees are dynamic.
-		msg.Capacity = btcutil.Amount(chanUtxo.Value)
-		msg.ChannelPoint = *fundingPoint
-		if err := r.cfg.Graph.AddChannelEdge(msg, op...); err != nil {
-			return errors.Errorf("unable to add edge: %v", err)
-		}
-
-		log.Debugf("New channel discovered! Link "+
-			"connects %x and %x with ChannelPoint(%v): "+
-			"chan_id=%v, capacity=%v",
-			msg.NodeKey1Bytes, msg.NodeKey2Bytes,
-			fundingPoint, msg.ChannelID, msg.Capacity)
-		r.stats.incNumEdgesDiscovered()
-
-		// As a new edge has been added to the channel graph, we'll
-		// update the current UTXO filter within our active
-		// FilteredChainView, so we are notified if/when this channel is
-		// closed.
-		filterUpdate := []channeldb.EdgePoint{
-			{
-				FundingPkScript: fundingPkScript,
-				OutPoint:        *fundingPoint,
-			},
-		}
-		err = r.cfg.ChainView.UpdateFilter(
-			filterUpdate, atomic.LoadUint32(&r.bestHeight),
-		)
-		if err != nil {
-			return errors.Errorf("unable to update chain "+
-				"view: %v", err)
-		}
-
-	case *models.ChannelEdgePolicy:
-		log.Debugf("Received ChannelEdgePolicy for channel %v",
-			msg.ChannelID)
+	// ChannelID is the channel of this edge.
+	ChannelID uint64
 
-		// We make sure to hold the mutex for this channel ID,
-		// such that no other goroutine is concurrently doing
-		// database accesses for the same channel ID.
-		r.channelEdgeMtx.Lock(msg.ChannelID)
-		defer r.channelEdgeMtx.Unlock(msg.ChannelID)
+	// Direction takes the value of 0 or 1 and is identical in definition to
+	// the channel direction flag. A value of 0 means the direction from the
+	// lower node pubkey to the higher.
+	Direction uint8
+}
 
-		edge1Timestamp, edge2Timestamp, exists, isZombie, err :=
-			r.cfg.Graph.HasChannelEdge(msg.ChannelID)
-		if err != nil && !errors.Is(
-			err, channeldb.ErrGraphNoEdgesFound,
-		) {
+// String returns a human-readable version of the edgeLocator values.
+func (e *EdgeLocator) String() string {
+	return fmt.Sprintf("%v:%v", e.ChannelID, e.Direction)
+}
 
-			return errors.Errorf("unable to check for edge "+
-				"existence: %v", err)
+// ChannelRouter is the layer 3 router within the Lightning stack. Below the
+// ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain
+// itself. The primary role of the ChannelRouter is to respond to queries for
+// potential routes that can support a payment amount, and also general graph
+// reachability questions.
+type ChannelRouter struct {
+	started uint32 // To be used atomically.
+	stopped uint32 // To be used atomically.
-	}
+	// cfg is a copy of the configuration struct that the ChannelRouter was
+	// initialized with.
+	cfg *Config
 
-	// If the channel is marked as a zombie in our database, and
-	// we consider this a stale update, then we should not apply the
-	// policy.
-	isStaleUpdate := time.Since(msg.LastUpdate) > r.cfg.ChannelPruneExpiry
-	if isZombie && isStaleUpdate {
-		return newErrf(ErrIgnored, "ignoring stale update "+
-			"(flags=%v|%v) for zombie chan_id=%v",
-			msg.MessageFlags, msg.ChannelFlags,
-			msg.ChannelID)
-	}
+	quit chan struct{}
+	wg   sync.WaitGroup
+}
 
-	// If the channel doesn't exist in our database, we cannot
-	// apply the updated policy.
-	if !exists {
-		return newErrf(ErrIgnored, "ignoring update "+
-			"(flags=%v|%v) for unknown chan_id=%v",
-			msg.MessageFlags, msg.ChannelFlags,
-			msg.ChannelID)
-	}
+// New creates a new instance of the ChannelRouter with the specified
+// configuration parameters.
+func New(cfg Config) (*ChannelRouter, error) {
+	return &ChannelRouter{
+		cfg:  &cfg,
+		quit: make(chan struct{}),
+	}, nil
+}
 
-	// As edges are directional edge node has a unique policy for
-	// the direction of the edge they control. Therefore, we first
-	// check if we already have the most up-to-date information for
-	// that edge. If this message has a timestamp not strictly
-	// newer than what we already know of we can exit early.
-	switch {
-
-	// A flag set of 0 indicates this is an announcement for the
-	// "first" node in the channel.
-	case msg.ChannelFlags&lnwire.ChanUpdateDirection == 0:
-
-		// Ignore outdated message.
-		if !edge1Timestamp.Before(msg.LastUpdate) {
-			return newErrf(ErrOutdated, "Ignoring "+
-				"outdated update (flags=%v|%v) for "+
-				"known chan_id=%v", msg.MessageFlags,
-				msg.ChannelFlags, msg.ChannelID)
-		}
+// Start launches all the goroutines the ChannelRouter requires to carry out
+// its duties. If the router has already been started, then this method is a
+// noop.
+func (r *ChannelRouter) Start() error {
+	if !atomic.CompareAndSwapUint32(&r.started, 0, 1) {
+		return nil
+	}
 
-	// Similarly, a flag set of 1 indicates this is an announcement
-	// for the "second" node in the channel.
-	case msg.ChannelFlags&lnwire.ChanUpdateDirection == 1:
+	log.Info("Channel Router starting")
 
-		// Ignore outdated message.
-		if !edge2Timestamp.Before(msg.LastUpdate) {
-			return newErrf(ErrOutdated, "Ignoring "+
-				"outdated update (flags=%v|%v) for "+
-				"known chan_id=%v", msg.MessageFlags,
-				msg.ChannelFlags, msg.ChannelID)
-		}
-	}
+	// If any payments are still in flight, we resume, to make sure their
+	// results are properly handled.
+	payments, err := r.cfg.Control.FetchInFlightPayments()
+	if err != nil {
+		return err
+	}
 
-	// Now that we know this isn't a stale update, we'll apply the
-	// new edge policy to the proper directional edge within the
-	// channel graph.
-	if err = r.cfg.Graph.UpdateEdgePolicy(msg, op...); err != nil {
-		err := errors.Errorf("unable to add channel: %v", err)
-		log.Error(err)
-		return err
+	// Before we restart existing payments and start accepting more
+	// payments to be made, we clean the network result store of the
+	// Switch.
We do this here at startup to ensure no more payments can be + // made concurrently, so we know the toKeep map will be up-to-date + // until the cleaning has finished. + toKeep := make(map[uint64]struct{}) + for _, p := range payments { + for _, a := range p.HTLCs { + toKeep[a.AttemptID] = struct{}{} } + } - log.Tracef("New channel update applied: %v", - newLogClosure(func() string { return spew.Sdump(msg) })) - r.stats.incNumChannelUpdates() - - default: - return errors.Errorf("wrong routing update message type") + log.Debugf("Cleaning network result store.") + if err := r.cfg.Payer.CleanStore(toKeep); err != nil { + return err } - return nil -} + for _, payment := range payments { + log.Infof("Resuming payment %v", payment.Info.PaymentIdentifier) + r.wg.Add(1) + go func(payment *channeldb.MPPayment) { + defer r.wg.Done() -// fetchFundingTxWrapper is a wrapper around fetchFundingTx, except that it -// will exit if the router has stopped. -func (r *ChannelRouter) fetchFundingTxWrapper(chanID *lnwire.ShortChannelID) ( - *wire.MsgTx, error) { + // Get the hashes used for the outstanding HTLCs. + htlcs := make(map[uint64]lntypes.Hash) + for _, a := range payment.HTLCs { + a := a - txChan := make(chan *wire.MsgTx, 1) - errChan := make(chan error, 1) + // We check whether the individual attempts + // have their HTLC hash set, if not we'll fall + // back to the overall payment hash. + hash := payment.Info.PaymentIdentifier + if a.Hash != nil { + hash = *a.Hash + } - go func() { - tx, err := r.fetchFundingTx(chanID) - if err != nil { - errChan <- err - return - } + htlcs[a.AttemptID] = hash + } - txChan <- tx - }() + // Since we are not supporting creating more shards + // after a restart (only receiving the result of the + // shards already outstanding), we create a simple + // shard tracker that will map the attempt IDs to + // hashes used for the HTLCs. This will be enough also + // for AMP payments, since we only need the hashes for + // the individual HTLCs to regenerate the circuits, and + // we don't currently persist the root share necessary + // to re-derive them. + shardTracker := shards.NewSimpleShardTracker( + payment.Info.PaymentIdentifier, htlcs, + ) - select { - case tx := <-txChan: - return tx, nil + // We create a dummy, empty payment session such that + // we won't make another payment attempt when the + // result for the in-flight attempt is received. + paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() - case err := <-errChan: - return nil, err + // We pass in a non-timeout context, to indicate we + // don't need it to timeout. It will stop immediately + // after the existing attempt has finished anyway. We + // also set a zero fee limit, as no more routes should + // be tried. + noTimeout := time.Duration(0) + _, _, err := r.sendPayment( + context.Background(), 0, + payment.Info.PaymentIdentifier, noTimeout, + paySession, shardTracker, + ) + if err != nil { + log.Errorf("Resuming payment %v failed: %v.", + payment.Info.PaymentIdentifier, err) + return + } - case <-r.quit: - return nil, ErrRouterShuttingDown + log.Infof("Resumed payment %v completed.", + payment.Info.PaymentIdentifier) + }(payment) } + + return nil } -// fetchFundingTx returns the funding transaction identified by the passed -// short channel ID. -// -// TODO(roasbeef): replace with call to GetBlockTransaction? 
(would allow to
-// later use getblocktxn)
-func (r *ChannelRouter) fetchFundingTx(
-	chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) {
-
-	// First fetch the block hash by the block number encoded, then use
-	// that hash to fetch the block itself.
-	blockNum := int64(chanID.BlockHeight)
-	blockHash, err := r.cfg.Chain.GetBlockHash(blockNum)
-	if err != nil {
-		return nil, err
-	}
-	fundingBlock, err := r.cfg.Chain.GetBlock(blockHash)
-	if err != nil {
-		return nil, err
+// Stop signals the ChannelRouter to gracefully halt all routines. This method
+// will *block* until all goroutines have exited. If the channel router has
+// already stopped then this method will return immediately.
+func (r *ChannelRouter) Stop() error {
+	if !atomic.CompareAndSwapUint32(&r.stopped, 0, 1) {
+		return nil
 	}
 
-	// As a sanity check, ensure that the advertised transaction index is
-	// within the bounds of the total number of transactions within a
-	// block.
-	numTxns := uint32(len(fundingBlock.Transactions))
-	if chanID.TxIndex > numTxns-1 {
-		return nil, fmt.Errorf("tx_index=#%v "+
-			"is out of range (max_index=%v), network_chan_id=%v",
-			chanID.TxIndex, numTxns-1, chanID)
-	}
+	log.Info("Channel Router shutting down...")
+	defer log.Debug("Channel Router shutdown complete")
 
-	return fundingBlock.Transactions[chanID.TxIndex].Copy(), nil
-}
+	close(r.quit)
+	r.wg.Wait()
 
-// routingMsg couples a routing related routing topology update to the
-// error channel.
-type routingMsg struct {
-	msg interface{}
-	op  []batch.SchedulerOption
-	err chan error
+	return nil
 }
 
 // RouteRequest contains the parameters for a pathfinding request. It may
@@ -2672,328 +1258,6 @@ func (r *ChannelRouter) extractChannelUpdate(
 	return update
 }
 
-// applyChannelUpdate validates a channel update and if valid, applies it to the
-// database. It returns a bool indicating whether the updates were successful.
-func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate) bool {
-	ch, _, _, err := r.GetChannelByID(msg.ShortChannelID)
-	if err != nil {
-		log.Errorf("Unable to retrieve channel by id: %v", err)
-		return false
-	}
-
-	var pubKey *btcec.PublicKey
-
-	switch msg.ChannelFlags & lnwire.ChanUpdateDirection {
-	case 0:
-		pubKey, _ = ch.NodeKey1()
-
-	case 1:
-		pubKey, _ = ch.NodeKey2()
-	}
-
-	// Exit early if the pubkey cannot be decided.
-	if pubKey == nil {
-		log.Errorf("Unable to decide pubkey with ChannelFlags=%v",
-			msg.ChannelFlags)
-		return false
-	}
-
-	err = ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg)
-	if err != nil {
-		log.Errorf("Unable to validate channel update: %v", err)
-		return false
-	}
-
-	err = r.UpdateEdge(&models.ChannelEdgePolicy{
-		SigBytes:                  msg.Signature.ToSignatureBytes(),
-		ChannelID:                 msg.ShortChannelID.ToUint64(),
-		LastUpdate:                time.Unix(int64(msg.Timestamp), 0),
-		MessageFlags:              msg.MessageFlags,
-		ChannelFlags:              msg.ChannelFlags,
-		TimeLockDelta:             msg.TimeLockDelta,
-		MinHTLC:                   msg.HtlcMinimumMsat,
-		MaxHTLC:                   msg.HtlcMaximumMsat,
-		FeeBaseMSat:               lnwire.MilliSatoshi(msg.BaseFee),
-		FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate),
-		ExtraOpaqueData:           msg.ExtraOpaqueData,
-	})
-	if err != nil && !IsError(err, ErrIgnored, ErrOutdated) {
-		log.Errorf("Unable to apply channel update: %v", err)
-		return false
-	}
-
-	return true
-}
-
-// AddNode is used to add information about a node to the router database. If
-// the node with this pubkey is not present in an existing channel, it will
-// be ignored.
-//
-// NOTE: This method is part of the ChannelGraphSource interface.
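The switch on msg.ChannelFlags & lnwire.ChanUpdateDirection in the removed applyChannelUpdate above (and again in IsStaleEdgePolicy below) is the standard BOLT 7 convention for addressing one of a channel's two directed edges. A self-contained sketch of just that check:

package main

import (
	"fmt"

	"github.com/lightningnetwork/lnd/lnwire"
)

// whichNode mirrors the ChannelFlags&lnwire.ChanUpdateDirection switches: a
// cleared direction bit means the update was signed by node 1 (the
// lexicographically smaller pubkey), a set bit means node 2.
func whichNode(flags lnwire.ChanUpdateChanFlags) int {
	if flags&lnwire.ChanUpdateDirection == 0 {
		return 1
	}

	return 2
}

func main() {
	fmt.Println(whichNode(0))                          // 1
	fmt.Println(whichNode(lnwire.ChanUpdateDirection)) // 2
}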
-func (r *ChannelRouter) AddNode(node *channeldb.LightningNode, - op ...batch.SchedulerOption) error { - - rMsg := &routingMsg{ - msg: node, - op: op, - err: make(chan error, 1), - } - - select { - case r.networkUpdates <- rMsg: - select { - case err := <-rMsg.err: - return err - case <-r.quit: - return ErrRouterShuttingDown - } - case <-r.quit: - return ErrRouterShuttingDown - } -} - -// AddEdge is used to add edge/channel to the topology of the router, after all -// information about channel will be gathered this edge/channel might be used -// in construction of payment path. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) AddEdge(edge *models.ChannelEdgeInfo, - op ...batch.SchedulerOption) error { - - rMsg := &routingMsg{ - msg: edge, - op: op, - err: make(chan error, 1), - } - - select { - case r.networkUpdates <- rMsg: - select { - case err := <-rMsg.err: - return err - case <-r.quit: - return ErrRouterShuttingDown - } - case <-r.quit: - return ErrRouterShuttingDown - } -} - -// UpdateEdge is used to update edge information, without this message edge -// considered as not fully constructed. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) UpdateEdge(update *models.ChannelEdgePolicy, - op ...batch.SchedulerOption) error { - - rMsg := &routingMsg{ - msg: update, - op: op, - err: make(chan error, 1), - } - - select { - case r.networkUpdates <- rMsg: - select { - case err := <-rMsg.err: - return err - case <-r.quit: - return ErrRouterShuttingDown - } - case <-r.quit: - return ErrRouterShuttingDown - } -} - -// CurrentBlockHeight returns the block height from POV of the router subsystem. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) CurrentBlockHeight() (uint32, error) { - _, height, err := r.cfg.Chain.GetBestBlock() - return uint32(height), err -} - -// SyncedHeight returns the block height to which the router subsystem currently -// is synced to. This can differ from the above chain height if the goroutine -// responsible for processing the blocks isn't yet up to speed. -func (r *ChannelRouter) SyncedHeight() uint32 { - return atomic.LoadUint32(&r.bestHeight) -} - -// GetChannelByID return the channel by the channel id. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { - - return r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) -} - -// FetchLightningNode attempts to look up a target node by its identity public -// key. channeldb.ErrGraphNodeNotFound is returned if the node doesn't exist -// within the graph. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) FetchLightningNode( - node route.Vertex) (*channeldb.LightningNode, error) { - - return r.cfg.Graph.FetchLightningNode(node) -} - -// ForEachNode is used to iterate over every node in router topology. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) ForEachNode( - cb func(*channeldb.LightningNode) error) error { - - return r.cfg.Graph.ForEachNode( - func(_ kvdb.RTx, n *channeldb.LightningNode) error { - return cb(n) - }) -} - -// ForAllOutgoingChannels is used to iterate over all outgoing channels owned by -// the router. -// -// NOTE: This method is part of the ChannelGraphSource interface. 
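AddNode, AddEdge and UpdateEdge above all share one shutdown-safe request/response shape built around routingMsg: send the request or bail on quit, then wait for the reply or bail on quit. A generic, self-contained sketch of that pattern (hypothetical names, not lnd code):

package main

import (
	"errors"
	"fmt"
)

var errShuttingDown = errors.New("shutting down")

type request struct {
	msg interface{}
	err chan error
}

// submit sends a request to a worker goroutine and waits for its reply,
// giving up on either leg if quit closes first. This is the same
// double-select used by the removed AddNode/AddEdge/UpdateEdge methods.
func submit(updates chan *request, quit chan struct{},
	msg interface{}) error {

	req := &request{msg: msg, err: make(chan error, 1)}

	select {
	case updates <- req:
		select {
		case err := <-req.err:
			return err
		case <-quit:
			return errShuttingDown
		}
	case <-quit:
		return errShuttingDown
	}
}

func main() {
	updates := make(chan *request)
	quit := make(chan struct{})

	// A trivial worker that accepts every message.
	go func() {
		for req := range updates {
			req.err <- nil
		}
	}()

	fmt.Println(submit(updates, quit, "node announcement")) // <nil>
}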
-func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { - - return r.cfg.Graph.ForEachNodeChannel(r.cfg.SelfNode, - func(tx kvdb.RTx, c *models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { - - if e == nil { - return fmt.Errorf("channel from self node " + - "has no policy") - } - - return cb(tx, c, e) - }, - ) -} - -// AddProof updates the channel edge info with proof which is needed to -// properly announce the edge to the rest of the network. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) AddProof(chanID lnwire.ShortChannelID, - proof *models.ChannelAuthProof) error { - - info, _, _, err := r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) - if err != nil { - return err - } - - info.AuthProof = proof - return r.cfg.Graph.UpdateChannelEdge(info) -} - -// IsStaleNode returns true if the graph source has a node announcement for the -// target node with a more recent timestamp. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) IsStaleNode(node route.Vertex, - timestamp time.Time) bool { - - // If our attempt to assert that the node announcement is fresh fails, - // then we know that this is actually a stale announcement. - err := r.assertNodeAnnFreshness(node, timestamp) - if err != nil { - log.Debugf("Checking stale node %x got %v", node, err) - return true - } - - return false -} - -// IsPublicNode determines whether the given vertex is seen as a public node in -// the graph from the graph's source node's point of view. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) IsPublicNode(node route.Vertex) (bool, error) { - return r.cfg.Graph.IsPublicNode(node) -} - -// IsKnownEdge returns true if the graph source already knows of the passed -// channel ID either as a live or zombie edge. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) IsKnownEdge(chanID lnwire.ShortChannelID) bool { - _, _, exists, isZombie, _ := r.cfg.Graph.HasChannelEdge( - chanID.ToUint64(), - ) - return exists || isZombie -} - -// IsStaleEdgePolicy returns true if the graph source has a channel edge for -// the passed channel ID (and flags) that have a more recent timestamp. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, - timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool { - - edge1Timestamp, edge2Timestamp, exists, isZombie, err := - r.cfg.Graph.HasChannelEdge(chanID.ToUint64()) - if err != nil { - log.Debugf("Check stale edge policy got error: %v", err) - return false - - } - - // If we know of the edge as a zombie, then we'll make some additional - // checks to determine if the new policy is fresh. - if isZombie { - // When running with AssumeChannelValid, we also prune channels - // if both of their edges are disabled. We'll mark the new - // policy as stale if it remains disabled. - if r.cfg.AssumeChannelValid { - isDisabled := flags&lnwire.ChanUpdateDisabled == - lnwire.ChanUpdateDisabled - if isDisabled { - return true - } - } - - // Otherwise, we'll fall back to our usual ChannelPruneExpiry. - return time.Since(timestamp) > r.cfg.ChannelPruneExpiry - } - - // If we don't know of the edge, then it means it's fresh (thus not - // stale). 
- if !exists { - return false - } - - // As edges are directional edge node has a unique policy for the - // direction of the edge they control. Therefore, we first check if we - // already have the most up-to-date information for that edge. If so, - // then we can exit early. - switch { - // A flag set of 0 indicates this is an announcement for the "first" - // node in the channel. - case flags&lnwire.ChanUpdateDirection == 0: - return !edge1Timestamp.Before(timestamp) - - // Similarly, a flag set of 1 indicates this is an announcement for the - // "second" node in the channel. - case flags&lnwire.ChanUpdateDirection == 1: - return !edge2Timestamp.Before(timestamp) - } - - return false -} - -// MarkEdgeLive clears an edge from our zombie index, deeming it as live. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { - return r.cfg.Graph.MarkEdgeLive(chanID.ToUint64()) -} - // ErrNoChannel is returned when a route cannot be built because there are no // channels that satisfy all requirements. type ErrNoChannel struct { diff --git a/routing/router_test.go b/routing/router_test.go index 49ca6a2665..824d6aed9a 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -6,6 +6,8 @@ import ( "image/color" "math" "math/rand" + "net" + "sync" "sync/atomic" "testing" "time" @@ -16,15 +18,16 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/htlcswitch" - lnmock "github.com/lightningnetwork/lnd/lntest/mock" - "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" @@ -33,11 +36,38 @@ import ( "github.com/stretchr/testify/require" ) -var uniquePaymentID uint64 = 1 // to be used atomically +var ( + uniquePaymentID uint64 = 1 // to be used atomically + + testAddr = &net.TCPAddr{IP: (net.IP)([]byte{0xA, 0x0, 0x0, 0x1}), + Port: 9000} + testAddrs = []net.Addr{testAddr} + + testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) + + testHash = [32]byte{ + 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, + 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, + 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, + 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, + } + + testTime = time.Date(2018, time.January, 9, 14, 00, 00, 0, time.UTC) + + priv1, _ = btcec.NewPrivateKey() + bitcoinKey1 = priv1.PubKey() + + priv2, _ = btcec.NewPrivateKey() + bitcoinKey2 = priv2.PubKey() + + timeout = time.Second * 5 +) type testCtx struct { router *ChannelRouter + graphBuilder *mockGraphBuilder + graph *channeldb.ChannelGraph aliases map[string]route.Vertex @@ -45,12 +75,6 @@ type testCtx struct { privKeys map[string]*btcec.PrivateKey channelIDs map[route.Vertex]map[route.Vertex]uint64 - - chain *mockChain - - chainView *mockChainView - - notifier *lnmock.ChainNotifier } func (c *testCtx) getChannelIDFromAlias(t *testing.T, a, b string) uint64 { @@ -69,57 +93,22 @@ func (c 
*testCtx) getChannelIDFromAlias(t *testing.T, a, b string) uint64 { return channelID } -func (c *testCtx) RestartRouter(t *testing.T) { - // First, we'll reset the chainView's state as it doesn't persist the - // filter between restarts. - c.chainView.Reset() - - source, err := c.graph.SourceNode() - require.NoError(t, err) - - // With the chainView reset, we'll now re-create the router itself, and - // start it. - router, err := New(Config{ - SelfNode: source.PubKeyBytes, - RoutingGraph: newMockGraphSessionChanDB(c.graph), - Graph: c.graph, - Chain: c.chain, - ChainView: c.chainView, - Payer: &mockPaymentAttemptDispatcherOld{}, - Control: makeMockControlTower(), - ChannelPruneExpiry: time.Hour * 24, - GraphPruneInterval: time.Hour * 2, - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - require.NoError(t, err, "unable to create router") - require.NoError(t, router.Start(), "unable to start router") - - // Finally, we'll swap out the pointer in the testCtx with this fresh - // instance of the router. - c.router = router -} - -func createTestCtxFromGraphInstance(t *testing.T, - startingHeight uint32, graphInstance *testGraphInstance, - strictPruning bool) *testCtx { +func createTestCtxFromGraphInstance(t *testing.T, startingHeight uint32, + graphInstance *testGraphInstance) *testCtx { return createTestCtxFromGraphInstanceAssumeValid( - t, startingHeight, graphInstance, false, strictPruning, + t, startingHeight, graphInstance, ) } func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, - startingHeight uint32, graphInstance *testGraphInstance, - assumeValid bool, strictPruning bool) *testCtx { + startingHeight uint32, graphInstance *testGraphInstance) *testCtx { // We'll initialize an instance of the channel router with mock // versions of the chain and channel notifier. As we don't need to test // any p2p functionality, the peer send and switch send messages won't // be populated. 
chain := newMockChain(startingHeight)
-	chainView := newMockChainView(chain)

 	pathFindingConfig := PathFindingConfig{
 		MinProbability: 0.01,
@@ -154,50 +143,34 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
 		MissionControl: mc,
 	}

-	notifier := &lnmock.ChainNotifier{
-		EpochChan: make(chan *chainntnfs.BlockEpoch),
-		SpendChan: make(chan *chainntnfs.SpendDetail),
-		ConfChan:  make(chan *chainntnfs.TxConfirmation),
-	}
+	graphBuilder := newMockGraphBuilder(graphInstance.graph)

 	router, err := New(Config{
-		SelfNode:            sourceNode.PubKeyBytes,
-		RoutingGraph:        newMockGraphSessionChanDB(graphInstance.graph),
-		Graph:               graphInstance.graph,
-		Chain:               chain,
-		ChainView:           chainView,
-		Payer:               &mockPaymentAttemptDispatcherOld{},
-		Notifier:            notifier,
-		Control:             makeMockControlTower(),
-		MissionControl:      mc,
-		SessionSource:       sessionSource,
-		ChannelPruneExpiry:  time.Hour * 24,
-		GraphPruneInterval:  time.Hour * 2,
-		GetLink:             graphInstance.getLink,
+		SelfNode:       sourceNode.PubKeyBytes,
+		RoutingGraph:   newMockGraphSessionChanDB(graphInstance.graph),
+		Chain:          chain,
+		Payer:          &mockPaymentAttemptDispatcherOld{},
+		Control:        makeMockControlTower(),
+		MissionControl: mc,
+		SessionSource:  sessionSource,
+		GetLink:        graphInstance.getLink,
 		NextPaymentID: func() (uint64, error) {
 			next := atomic.AddUint64(&uniquePaymentID, 1)
 			return next, nil
 		},
-		PathFindingConfig:   pathFindingConfig,
-		Clock:               clock.NewTestClock(time.Unix(1, 0)),
-		AssumeChannelValid:  assumeValid,
-		StrictZombiePruning: strictPruning,
-		IsAlias: func(scid lnwire.ShortChannelID) bool {
-			return false
-		},
+		PathFindingConfig:  pathFindingConfig,
+		Clock:              clock.NewTestClock(time.Unix(1, 0)),
+		ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate,
 	})
-
 	require.NoError(t, err, "unable to create router")
 	require.NoError(t, router.Start(), "unable to start router")

 	ctx := &testCtx{
-		router:     router,
-		graph:      graphInstance.graph,
-		aliases:    graphInstance.aliasMap,
-		privKeys:   graphInstance.privKeyMap,
-		channelIDs: graphInstance.channelIDs,
-		chain:      chain,
-		chainView:  chainView,
-		notifier:   notifier,
+		router:       router,
+		graphBuilder: graphBuilder,
+		graph:        graphInstance.graph,
+		aliases:      graphInstance.aliasMap,
+		privKeys:     graphInstance.privKeyMap,
+		channelIDs:   graphInstance.channelIDs,
 	}

 	t.Cleanup(func() {
@@ -207,27 +180,27 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
 	return ctx
 }

-func createTestCtxSingleNode(t *testing.T,
-	startingHeight uint32) *testCtx {
-
-	graph, graphBackend, err := makeTestGraph(t, true)
-	require.NoError(t, err, "failed to make test graph")
+func createTestNode() (*channeldb.LightningNode, error) {
+	updateTime := rand.Int63()

-	sourceNode, err := createTestNode()
-	require.NoError(t, err, "failed to create test node")
-
-	require.NoError(t,
-		graph.SetSourceNode(sourceNode), "failed to set source node",
-	)
+	priv, err := btcec.NewPrivateKey()
+	if err != nil {
+		return nil, errors.Errorf("unable to create private key: %v", err)
+	}

-	graphInstance := &testGraphInstance{
-		graph:        graph,
-		graphBackend: graphBackend,
+	pub := priv.PubKey().SerializeCompressed()
+	n := &channeldb.LightningNode{
+		HaveNodeAnnouncement: true,
+		LastUpdate:           time.Unix(updateTime, 0),
+		Addresses:            testAddrs,
+		Color:                color.RGBA{1, 2, 3, 0},
+		Alias:                "kek" + string(pub[:]),
+		AuthSigBytes:         testSig.Serialize(),
+		Features:             testFeatures,
 	}
+	copy(n.PubKeyBytes[:], pub)

-	return createTestCtxFromGraphInstance(
-		t, startingHeight, graphInstance, false,
-	)
+	return n, nil
 }

 func createTestCtxFromFile(t *testing.T,
@@ -238,9 +211,7 @@
func createTestCtxFromFile(t *testing.T,
 	graphInstance, err := parseTestGraph(t, true, testGraph)
 	require.NoError(t, err, "unable to create test graph")

-	return createTestCtxFromGraphInstance(
-		t, startingHeight, graphInstance, false,
-	)
+	return createTestCtxFromGraphInstance(t, startingHeight, graphInstance)
 }

 // Add valid signature to channel update simulated as error received from the
@@ -474,13 +445,11 @@ func TestChannelUpdateValidation(t *testing.T) {
 	require.NoError(t, err, "unable to create graph")

 	const startingBlockHeight = 101
-	ctx := createTestCtxFromGraphInstance(
-		t, startingBlockHeight, testGraph, true,
-	)
+	ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph)

 	// Assert that the initially configured fee is retrieved correctly.
-	_, e1, e2, err := ctx.router.GetChannelByID(
-		lnwire.NewShortChanIDFromInt(1),
+	_, e1, e2, err := ctx.graph.FetchChannelEdgesByID(
+		lnwire.NewShortChanIDFromInt(1).ToUint64(),
 	)
 	require.NoError(t, err, "cannot retrieve channel")

@@ -541,14 +510,18 @@ func TestChannelUpdateValidation(t *testing.T) {
 	// empty for this test.
 	var payment lntypes.Hash

+	// Instruct the mock graph builder to reject the next update we send
+	// it.
+	ctx.graphBuilder.setNextReject(true)
+
 	// Send off the payment request to the router. The specified route
 	// should be attempted and the channel update should be received by
-	// router and ignored because it is missing a valid signature.
+	// the graph and ignored because it is missing a valid signature.
 	_, err = ctx.router.SendToRoute(payment, rt)
 	require.Error(t, err, "expected route to fail with channel update")

-	_, e1, e2, err = ctx.router.GetChannelByID(
-		lnwire.NewShortChanIDFromInt(1),
+	_, e1, e2, err = ctx.graph.FetchChannelEdgesByID(
+		lnwire.NewShortChanIDFromInt(1).ToUint64(),
 	)
 	require.NoError(t, err, "cannot retrieve channel")

@@ -560,14 +533,17 @@ func TestChannelUpdateValidation(t *testing.T) {
 	// Next, add a signature to the channel update.
 	signErrChanUpdate(t, testGraph.privKeyMap["b"], &errChanUpdate)

+	// Let the graph builder accept the next update.
+	ctx.graphBuilder.setNextReject(false)
+
 	// Retry the payment using the same route as before.
 	_, err = ctx.router.SendToRoute(payment, rt)
 	require.Error(t, err, "expected route to fail with channel update")

 	// This time a valid signature was supplied and the policy change should
 	// have been applied to the graph.
-	_, e1, e2, err = ctx.router.GetChannelByID(
-		lnwire.NewShortChanIDFromInt(1),
+	_, e1, e2, err = ctx.graph.FetchChannelEdgesByID(
+		lnwire.NewShortChanIDFromInt(1).ToUint64(),
 	)
 	require.NoError(t, err, "cannot retrieve channel")

@@ -1202,1677 +1178,186 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
 	)
 }

-// TestAddProof checks that we can update the channel proof after channel
-// info was added to the database.
-func TestAddProof(t *testing.T) {
+// TestFindPathFeeWeighting tests that the findPath method will properly prefer
+// routes with lower fees over routes with lower time lock values. This is
+// meant to exercise the fact that the internal findPath method ranks edges
+// with the square of the total fee in order to bias towards lower fees.
+func TestFindPathFeeWeighting(t *testing.T) {
 	t.Parallel()

-	ctx := createTestCtxSingleNode(t, 0)
+	const startingBlockHeight = 101
+	ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath)

-	// Before creating out edge, we'll create two new nodes within the
-	// network that the channel will connect.
-	node1, err := createTestNode()
-	if err != nil {
-		t.Fatal(err)
-	}
-	node2, err := createTestNode()
-	if err != nil {
-		t.Fatal(err)
-	}
+	var preImage [32]byte
+	copy(preImage[:], bytes.Repeat([]byte{9}, 32))

-	// In order to be able to add the edge we should have a valid funding
-	// UTXO within the blockchain.
-	fundingTx, _, chanID, err := createChannelEdge(ctx,
-		bitcoinKey1.SerializeCompressed(), bitcoinKey2.SerializeCompressed(),
-		100, 0)
-	require.NoError(t, err, "unable create channel edge")
-	fundingBlock := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{fundingTx},
-	}
-	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+	sourceNode, err := ctx.graph.SourceNode()
+	require.NoError(t, err, "unable to fetch source node")

-	// After utxo was recreated adding the edge without the proof.
-	edge := &models.ChannelEdgeInfo{
-		ChannelID:     chanID.ToUint64(),
-		NodeKey1Bytes: node1.PubKeyBytes,
-		NodeKey2Bytes: node2.PubKeyBytes,
-		AuthProof:     nil,
-	}
-	copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
-	copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+	amt := lnwire.MilliSatoshi(100)

-	if err := ctx.router.AddEdge(edge); err != nil {
-		t.Fatalf("unable to add edge: %v", err)
-	}
+	target := ctx.aliases["luoji"]

-	// Now we'll attempt to update the proof and check that it has been
-	// properly updated.
-	if err := ctx.router.AddProof(*chanID, &testAuthProof); err != nil {
-		t.Fatalf("unable to add proof: %v", err)
-	}
+	// We'll now attempt to find a path using this set up. Due to
+	// the edge weighting, we should select the direct path over the 2 hop
+	// path even though the direct path has a higher potential time lock.
+	path, err := dbFindPath(
+		ctx.graph, nil, &mockBandwidthHints{},
+		noRestrictions,
+		testPathFindingConfig,
+		sourceNode.PubKeyBytes, target, amt, 0, 0,
+	)
+	require.NoError(t, err, "unable to find path")

-	info, _, _, err := ctx.router.GetChannelByID(*chanID)
-	require.NoError(t, err, "unable to get channel")
-	if info.AuthProof == nil {
-		t.Fatal("proof have been updated")
+	// The route that was chosen should be exactly one hop, and should be
+	// directly to luoji.
+	if len(path) != 1 {
+		t.Fatalf("expected path length of 1, instead was: %v", len(path))
+	}
+	if path[0].policy.ToNodePubKey() != ctx.aliases["luoji"] {
+		t.Fatalf("wrong node: %v", path[0].policy.ToNodePubKey())
 	}
 }

-// TestIgnoreNodeAnnouncement tests that adding a node to the router that is
-// not known from any channel announcement, leads to the announcement being
-// ignored.
-func TestIgnoreNodeAnnouncement(t *testing.T) {
+// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket
+// function is able to gracefully handle being passed a nil set of hops for the
+// route by the caller.
+func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { t.Parallel() - const startingBlockHeight = 101 - ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath) - - pub := priv1.PubKey() - node := &channeldb.LightningNode{ - HaveNodeAnnouncement: true, - LastUpdate: time.Unix(123, 0), - Addresses: testAddrs, - Color: color.RGBA{1, 2, 3, 0}, - Alias: "node11", - AuthSigBytes: testSig.Serialize(), - Features: testFeatures, - } - copy(node.PubKeyBytes[:], pub.SerializeCompressed()) - - err := ctx.router.AddNode(node) - if !IsError(err, ErrIgnored) { - t.Fatalf("expected to get ErrIgnore, instead got: %v", err) + sessionKey, _ := btcec.NewPrivateKey() + emptyRoute := &route.Route{} + _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) + if err != route.ErrNoRouteHopsProvided { + t.Fatalf("expected empty hops error: instead got: %v", err) } } -// TestIgnoreChannelEdgePolicyForUnknownChannel checks that a router will -// ignore a channel policy for a channel not in the graph. -func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { +// TestUnknownErrorSource tests that if the source of an error is unknown, all +// edges along the route will be pruned. +func TestUnknownErrorSource(t *testing.T) { t.Parallel() - const startingBlockHeight = 101 + // Setup a network. It contains two paths to c: a->b->c and an + // alternative a->d->c. + chanCapSat := btcutil.Amount(100000) + testChannels := []*testChannel{ + symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 1), + symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 3), + symmetricTestChannel("a", "d", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + FeeBaseMsat: 100000, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 2), + symmetricTestChannel("d", "c", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + FeeBaseMsat: 100000, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 4), + } - // Setup an initially empty network. - testChannels := []*testChannel{} - testGraph, err := createTestGraphFromChannels( - t, true, testChannels, "roasbeef", - ) + testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") require.NoError(t, err, "unable to create graph") - ctx := createTestCtxFromGraphInstance( - t, startingBlockHeight, testGraph, false, + const startingBlockHeight = 101 + ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) + + // Create a payment to node c. + payment := createDummyLightningPayment( + t, ctx.aliases["c"], lnwire.NewMSatFromSatoshis(1000), ) - var pub1 [33]byte - copy(pub1[:], priv1.PubKey().SerializeCompressed()) + // We'll modify the SendToSwitch method so that it simulates hop b as a + // node that returns an unparsable failure if approached via the a->b + // channel. + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - var pub2 [33]byte - copy(pub2[:], priv2.PubKey().SerializeCompressed()) + // If channel a->b is used, return an error without + // source and message. The sender won't know the origin + // of the error. 
+			if firstHop.ToUint64() == 1 {
+				return [32]byte{},
+					htlcswitch.ErrUnreadableFailureMessage
+			}

-	// Add the edge between the two unknown nodes to the graph, and check
-	// that the nodes are found after the fact.
-	fundingTx, _, chanID, err := createChannelEdge(
-		ctx, bitcoinKey1.SerializeCompressed(),
-		bitcoinKey2.SerializeCompressed(), 10000, 500,
-	)
-	require.NoError(t, err, "unable to create channel edge")
-	fundingBlock := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{fundingTx},
-	}
-	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+			// Otherwise the payment succeeds.
+			return lntypes.Preimage{}, nil
+		})

-	edge := &models.ChannelEdgeInfo{
-		ChannelID:     chanID.ToUint64(),
-		NodeKey1Bytes: pub1,
-		NodeKey2Bytes: pub2,
-		BitcoinKey1Bytes: pub1,
-		BitcoinKey2Bytes: pub2,
-		AuthProof:     nil,
-	}
-	edgePolicy := &models.ChannelEdgePolicy{
-		SigBytes:                  testSig.Serialize(),
-		ChannelID:                 edge.ChannelID,
-		LastUpdate:                testTime,
-		TimeLockDelta:             10,
-		MinHTLC:                   1,
-		FeeBaseMSat:               10,
-		FeeProportionalMillionths: 10000,
-	}
+	// Send off the payment request to the router. The expectation is that
+	// the route a->b->c is tried first. An unreadable failure is returned
+	// which should prune the channel a->b. We expect the payment to
+	// succeed via a->d.
+	_, _, err = ctx.router.SendPayment(payment)
+	require.NoErrorf(t, err, "unable to send payment: %v",
+		payment.paymentHash)

-	// Attempt to update the edge. This should be ignored, since the edge
-	// is not yet added to the router.
-	err = ctx.router.UpdateEdge(edgePolicy)
-	if !IsError(err, ErrIgnored) {
-		t.Fatalf("expected to get ErrIgnore, instead got: %v", err)
-	}
+	// Next we modify payment result to return an unknown failure.
+	ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult(
+		func(firstHop lnwire.ShortChannelID) ([32]byte, error) {

-	// Add the edge.
-	if err := ctx.router.AddEdge(edge); err != nil {
-		t.Fatalf("expected to be able to add edge to the channel graph,"+
-			" even though the vertexes were unknown: %v.", err)
-	}
+			// If channel a->d is used, simulate that the failure
+			// couldn't be decoded (FailureMessage is nil).
+			if firstHop.ToUint64() == 2 {
+				return [32]byte{},
+					htlcswitch.NewUnknownForwardingError(1)
+			}
+
+			// Otherwise the payment succeeds.
+			return lntypes.Preimage{}, nil
+		})

-	// Now updating the edge policy should succeed.
-	if err := ctx.router.UpdateEdge(edgePolicy); err != nil {
-		t.Fatalf("unable to update edge policy: %v", err)
+	// Send off the payment request to the router. We expect the payment to
+	// fail because both routes have been pruned.
+	payment.paymentHash[1] ^= 1
+	_, _, err = ctx.router.SendPayment(payment)
+	if err == nil {
+		t.Fatalf("expected payment to fail")
 	}
 }

-// TestAddEdgeUnknownVertexes tests that if an edge is added that contains two
-// vertexes which we don't know of, the edge should be available for use
-// regardless. This is due to the fact that we don't actually need node
-// announcements for the channel vertexes to be able to use the channel.
-func TestAddEdgeUnknownVertexes(t *testing.T) {
+// TestSendToRouteStructuredError asserts that SendToRoute returns a structured
+// error.
+func TestSendToRouteStructuredError(t *testing.T) {
 	t.Parallel()

-	const startingBlockHeight = 101
-	ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath)
-
-	var pub1 [33]byte
-	copy(pub1[:], priv1.PubKey().SerializeCompressed())
+	// Setup a three node network.
+ chanCapSat := btcutil.Amount(100000) + testChannels := []*testChannel{ + symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 1), + symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 2), + } - var pub2 [33]byte - copy(pub2[:], priv2.PubKey().SerializeCompressed()) + testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") + require.NoError(t, err, "unable to create graph") - // The two nodes we are about to add should not exist yet. - _, exists1, err := ctx.graph.HasLightningNode(pub1) - require.NoError(t, err, "unable to query graph") - if exists1 { - t.Fatalf("node already existed") - } - _, exists2, err := ctx.graph.HasLightningNode(pub2) - require.NoError(t, err, "unable to query graph") - if exists2 { - t.Fatalf("node already existed") - } - - // Add the edge between the two unknown nodes to the graph, and check - // that the nodes are found after the fact. - fundingTx, _, chanID, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - 10000, 500, - ) - require.NoError(t, err, "unable to create channel edge") - fundingBlock := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{fundingTx}, - } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - - edge := &models.ChannelEdgeInfo{ - ChannelID: chanID.ToUint64(), - NodeKey1Bytes: pub1, - NodeKey2Bytes: pub2, - BitcoinKey1Bytes: pub1, - BitcoinKey2Bytes: pub2, - AuthProof: nil, - } - if err := ctx.router.AddEdge(edge); err != nil { - t.Fatalf("expected to be able to add edge to the channel graph,"+ - " even though the vertexes were unknown: %v.", err) - } - - // We must add the edge policy to be able to use the edge for route - // finding. - edgePolicy := &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - ToNode: edge.NodeKey2Bytes, - } - edgePolicy.ChannelFlags = 0 - - if err := ctx.router.UpdateEdge(edgePolicy); err != nil { - t.Fatalf("unable to update edge policy: %v", err) - } - - // Create edge in the other direction as well. - edgePolicy = &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - ToNode: edge.NodeKey1Bytes, - } - edgePolicy.ChannelFlags = 1 - - if err := ctx.router.UpdateEdge(edgePolicy); err != nil { - t.Fatalf("unable to update edge policy: %v", err) - } - - // After adding the edge between the two previously unknown nodes, they - // should have been added to the graph. - _, exists1, err = ctx.graph.HasLightningNode(pub1) - require.NoError(t, err, "unable to query graph") - if !exists1 { - t.Fatalf("node1 was not added to the graph") - } - _, exists2, err = ctx.graph.HasLightningNode(pub2) - require.NoError(t, err, "unable to query graph") - if !exists2 { - t.Fatalf("node2 was not added to the graph") - } - - // We will connect node1 to the rest of the test graph, and make sure - // we can find a route to node2, which will use the just added channel - // edge. 
- - // We will connect node 1 to "sophon" - connectNode := ctx.aliases["sophon"] - connectNodeKey, err := btcec.ParsePubKey(connectNode[:]) - if err != nil { - t.Fatal(err) - } - - var ( - pubKey1 *btcec.PublicKey - pubKey2 *btcec.PublicKey - ) - node1Bytes := priv1.PubKey().SerializeCompressed() - node2Bytes := connectNode - if bytes.Compare(node1Bytes[:], node2Bytes[:]) == -1 { - pubKey1 = priv1.PubKey() - pubKey2 = connectNodeKey - } else { - pubKey1 = connectNodeKey - pubKey2 = priv1.PubKey() - } - - fundingTx, _, chanID, err = createChannelEdge(ctx, - pubKey1.SerializeCompressed(), pubKey2.SerializeCompressed(), - 10000, 510) - require.NoError(t, err, "unable to create channel edge") - fundingBlock = &wire.MsgBlock{ - Transactions: []*wire.MsgTx{fundingTx}, - } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - - edge = &models.ChannelEdgeInfo{ - ChannelID: chanID.ToUint64(), - AuthProof: nil, - } - copy(edge.NodeKey1Bytes[:], node1Bytes) - edge.NodeKey2Bytes = node2Bytes - copy(edge.BitcoinKey1Bytes[:], node1Bytes) - edge.BitcoinKey2Bytes = node2Bytes - - if err := ctx.router.AddEdge(edge); err != nil { - t.Fatalf("unable to add edge to the channel graph: %v.", err) - } - - edgePolicy = &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - ToNode: edge.NodeKey2Bytes, - } - edgePolicy.ChannelFlags = 0 - - if err := ctx.router.UpdateEdge(edgePolicy); err != nil { - t.Fatalf("unable to update edge policy: %v", err) - } - - edgePolicy = &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - ToNode: edge.NodeKey1Bytes, - } - edgePolicy.ChannelFlags = 1 - - if err := ctx.router.UpdateEdge(edgePolicy); err != nil { - t.Fatalf("unable to update edge policy: %v", err) - } - - // We should now be able to find a route to node 2. - paymentAmt := lnwire.NewMSatFromSatoshis(100) - targetNode := priv2.PubKey() - var targetPubKeyBytes route.Vertex - copy(targetPubKeyBytes[:], targetNode.SerializeCompressed()) - - req, err := NewRouteRequest( - ctx.router.cfg.SelfNode, &targetPubKeyBytes, - paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, - ) - require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) - require.NoError(t, err, "unable to find any routes") - - // Now check that we can update the node info for the partial node - // without messing up the channel graph. - n1 := &channeldb.LightningNode{ - HaveNodeAnnouncement: true, - LastUpdate: time.Unix(123, 0), - Addresses: testAddrs, - Color: color.RGBA{1, 2, 3, 0}, - Alias: "node11", - AuthSigBytes: testSig.Serialize(), - Features: testFeatures, - } - copy(n1.PubKeyBytes[:], priv1.PubKey().SerializeCompressed()) - - if err := ctx.router.AddNode(n1); err != nil { - t.Fatalf("could not add node: %v", err) - } - - n2 := &channeldb.LightningNode{ - HaveNodeAnnouncement: true, - LastUpdate: time.Unix(123, 0), - Addresses: testAddrs, - Color: color.RGBA{1, 2, 3, 0}, - Alias: "node22", - AuthSigBytes: testSig.Serialize(), - Features: testFeatures, - } - copy(n2.PubKeyBytes[:], priv2.PubKey().SerializeCompressed()) - - if err := ctx.router.AddNode(n2); err != nil { - t.Fatalf("could not add node: %v", err) - } - - // Should still be able to find the route, and the info should be - // updated. 
- req, err = NewRouteRequest( - ctx.router.cfg.SelfNode, &targetPubKeyBytes, - paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, - ) - require.NoError(t, err, "invalid route request") - - _, _, err = ctx.router.FindRoute(req) - require.NoError(t, err, "unable to find any routes") - - copy1, err := ctx.graph.FetchLightningNode(pub1) - require.NoError(t, err, "unable to fetch node") - - if copy1.Alias != n1.Alias { - t.Fatalf("fetched node not equal to original") - } - - copy2, err := ctx.graph.FetchLightningNode(pub2) - require.NoError(t, err, "unable to fetch node") - - if copy2.Alias != n2.Alias { - t.Fatalf("fetched node not equal to original") - } -} - -// TestWakeUpOnStaleBranch tests that upon startup of the ChannelRouter, if the -// the chain previously reflected in the channel graph is stale (overtaken by a -// longer chain), the channel router will prune the graph for any channels -// confirmed on the stale chain, and resync to the main chain. -func TestWakeUpOnStaleBranch(t *testing.T) { - t.Parallel() - - const startingBlockHeight = 101 - ctx := createTestCtxSingleNode(t, startingBlockHeight) - - const chanValue = 10000 - - // chanID1 will not be reorged out. - var chanID1 uint64 - - // chanID2 will be reorged out. - var chanID2 uint64 - - // Create 10 common blocks, confirming chanID1. - for i := uint32(1); i <= 10; i++ { - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - height := startingBlockHeight + i - if i == 5 { - fundingTx, _, chanID, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - chanValue, height) - if err != nil { - t.Fatalf("unable create channel edge: %v", err) - } - block.Transactions = append(block.Transactions, - fundingTx) - chanID1 = chanID.ToUint64() - - } - ctx.chain.addBlock(block, height, rand.Uint32()) - ctx.chain.setBestBlock(int32(height)) - ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}, t) - } - - // Give time to process new blocks - time.Sleep(time.Millisecond * 500) - - _, forkHeight, err := ctx.chain.GetBestBlock() - require.NoError(t, err, "unable to ge best block") - - // Create 10 blocks on the minority chain, confirming chanID2. - for i := uint32(1); i <= 10; i++ { - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - height := uint32(forkHeight) + i - if i == 5 { - fundingTx, _, chanID, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - chanValue, height) - if err != nil { - t.Fatalf("unable create channel edge: %v", err) - } - block.Transactions = append(block.Transactions, - fundingTx) - chanID2 = chanID.ToUint64() - } - ctx.chain.addBlock(block, height, rand.Uint32()) - ctx.chain.setBestBlock(int32(height)) - ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}, t) - } - // Give time to process new blocks - time.Sleep(time.Millisecond * 500) - - // Now add the two edges to the channel graph, and check that they - // correctly show up in the database. 
- node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") - - edge1 := &models.ChannelEdgeInfo{ - ChannelID: chanID1, - NodeKey1Bytes: node1.PubKeyBytes, - NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - } - copy(edge1.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) - copy(edge1.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - - if err := ctx.router.AddEdge(edge1); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - edge2 := &models.ChannelEdgeInfo{ - ChannelID: chanID2, - NodeKey1Bytes: node1.PubKeyBytes, - NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - } - copy(edge2.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) - copy(edge2.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - - if err := ctx.router.AddEdge(edge2); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - // Check that the fundingTxs are in the graph db. - _, _, has, isZombie, err := ctx.graph.HasChannelEdge(chanID1) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID1) - } - if !has { - t.Fatalf("could not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID2) - } - if !has { - t.Fatalf("could not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - // Stop the router, so we can reorg the chain while its offline. - if err := ctx.router.Stop(); err != nil { - t.Fatalf("unable to stop router: %v", err) - } - - // Create a 15 block fork. - for i := uint32(1); i <= 15; i++ { - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - height := uint32(forkHeight) + i - ctx.chain.addBlock(block, height, rand.Uint32()) - ctx.chain.setBestBlock(int32(height)) - } - - // Give time to process new blocks. - time.Sleep(time.Millisecond * 500) - - source, err := ctx.graph.SourceNode() - require.NoError(t, err) - - // Create new router with same graph database. - router, err := New(Config{ - SelfNode: source.PubKeyBytes, - RoutingGraph: newMockGraphSessionChanDB(ctx.graph), - Graph: ctx.graph, - Chain: ctx.chain, - ChainView: ctx.chainView, - Payer: &mockPaymentAttemptDispatcherOld{}, - Control: makeMockControlTower(), - ChannelPruneExpiry: time.Hour * 24, - GraphPruneInterval: time.Hour * 2, - - // We'll set the delay to zero to prune immediately. - FirstTimePruneDelay: 0, - - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - if err != nil { - t.Fatalf("unable to create router %v", err) - } - - // It should resync to the longer chain on startup. - if err := router.Start(); err != nil { - t.Fatalf("unable to start router: %v", err) - } - - // The channel with chanID2 should not be in the database anymore, - // since it is not confirmed on the longest chain. chanID1 should - // still be. 
- _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID1) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID1) - } - if !has { - t.Fatalf("did not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID2) - } - if has { - t.Fatalf("found edge in graph") - } - if isZombie { - t.Fatal("reorged edge should not be marked as zombie") - } -} - -// TestDisconnectedBlocks checks that the router handles a reorg happening when -// it is active. -func TestDisconnectedBlocks(t *testing.T) { - t.Parallel() - - const startingBlockHeight = 101 - ctx := createTestCtxSingleNode(t, startingBlockHeight) - - const chanValue = 10000 - - // chanID1 will not be reorged out, while chanID2 will be reorged out. - var chanID1, chanID2 uint64 - - // Create 10 common blocks, confirming chanID1. - for i := uint32(1); i <= 10; i++ { - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - height := startingBlockHeight + i - if i == 5 { - fundingTx, _, chanID, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - chanValue, height) - if err != nil { - t.Fatalf("unable create channel edge: %v", err) - } - block.Transactions = append(block.Transactions, - fundingTx) - chanID1 = chanID.ToUint64() - - } - ctx.chain.addBlock(block, height, rand.Uint32()) - ctx.chain.setBestBlock(int32(height)) - ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}, t) - } - - // Give time to process new blocks - time.Sleep(time.Millisecond * 500) - - _, forkHeight, err := ctx.chain.GetBestBlock() - require.NoError(t, err, "unable to get best block") - - // Create 10 blocks on the minority chain, confirming chanID2. - var minorityChain []*wire.MsgBlock - for i := uint32(1); i <= 10; i++ { - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - height := uint32(forkHeight) + i - if i == 5 { - fundingTx, _, chanID, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - chanValue, height) - if err != nil { - t.Fatalf("unable create channel edge: %v", err) - } - block.Transactions = append(block.Transactions, - fundingTx) - chanID2 = chanID.ToUint64() - } - minorityChain = append(minorityChain, block) - ctx.chain.addBlock(block, height, rand.Uint32()) - ctx.chain.setBestBlock(int32(height)) - ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}, t) - } - // Give time to process new blocks - time.Sleep(time.Millisecond * 500) - - // Now add the two edges to the channel graph, and check that they - // correctly show up in the database. 
- node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") - - edge1 := &models.ChannelEdgeInfo{ - ChannelID: chanID1, - NodeKey1Bytes: node1.PubKeyBytes, - NodeKey2Bytes: node2.PubKeyBytes, - BitcoinKey1Bytes: node1.PubKeyBytes, - BitcoinKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - } - copy(edge1.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) - copy(edge1.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - - if err := ctx.router.AddEdge(edge1); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - edge2 := &models.ChannelEdgeInfo{ - ChannelID: chanID2, - NodeKey1Bytes: node1.PubKeyBytes, - NodeKey2Bytes: node2.PubKeyBytes, - BitcoinKey1Bytes: node1.PubKeyBytes, - BitcoinKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - } - copy(edge2.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) - copy(edge2.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - - if err := ctx.router.AddEdge(edge2); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - // Check that the fundingTxs are in the graph db. - _, _, has, isZombie, err := ctx.graph.HasChannelEdge(chanID1) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID1) - } - if !has { - t.Fatalf("could not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID2) - } - if !has { - t.Fatalf("could not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - // Create a 15 block fork. We first let the chainView notify the router - // about stale blocks, before sending the now connected blocks. We do - // this because we expect this order from the chainview. - ctx.chainView.notifyStaleBlockAck = make(chan struct{}, 1) - for i := len(minorityChain) - 1; i >= 0; i-- { - block := minorityChain[i] - height := uint32(forkHeight) + uint32(i) + 1 - ctx.chainView.notifyStaleBlock(block.BlockHash(), height, - block.Transactions, t) - <-ctx.chainView.notifyStaleBlockAck - } - - time.Sleep(time.Second * 2) - - ctx.chainView.notifyBlockAck = make(chan struct{}, 1) - for i := uint32(1); i <= 15; i++ { - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - height := uint32(forkHeight) + i - ctx.chain.addBlock(block, height, rand.Uint32()) - ctx.chain.setBestBlock(int32(height)) - ctx.chainView.notifyBlock(block.BlockHash(), height, - block.Transactions, t) - <-ctx.chainView.notifyBlockAck - } - - time.Sleep(time.Millisecond * 500) - - // chanID2 should not be in the database anymore, since it is not - // confirmed on the longest chain. chanID1 should still be. 
- _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID1) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID1) - } - if !has { - t.Fatalf("did not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID2) - } - if has { - t.Fatalf("found edge in graph") - } - if isZombie { - t.Fatal("reorged edge should not be marked as zombie") - } -} - -// TestChansClosedOfflinePruneGraph tests that if channels we know of are -// closed while we're offline, then once we resume operation of the -// ChannelRouter, then the channels are properly pruned. -func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { - t.Parallel() - - const startingBlockHeight = 101 - ctx := createTestCtxSingleNode(t, startingBlockHeight) - - const chanValue = 10000 - - // First, we'll create a channel, to be mined shortly at height 102. - block102 := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - nextHeight := startingBlockHeight + 1 - fundingTx1, chanUTXO, chanID1, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - chanValue, uint32(nextHeight)) - require.NoError(t, err, "unable create channel edge") - block102.Transactions = append(block102.Transactions, fundingTx1) - ctx.chain.addBlock(block102, uint32(nextHeight), rand.Uint32()) - ctx.chain.setBestBlock(int32(nextHeight)) - ctx.chainView.notifyBlock(block102.BlockHash(), uint32(nextHeight), - []*wire.MsgTx{}, t) - - // We'll now create the edges and nodes within the database required - // for the ChannelRouter to properly recognize the channel we added - // above. - node1, err := createTestNode() - require.NoError(t, err, "unable to create test node") - node2, err := createTestNode() - require.NoError(t, err, "unable to create test node") - edge1 := &models.ChannelEdgeInfo{ - ChannelID: chanID1.ToUint64(), - NodeKey1Bytes: node1.PubKeyBytes, - NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - } - copy(edge1.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed()) - copy(edge1.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) - if err := ctx.router.AddEdge(edge1); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - // The router should now be aware of the channel we created above. - _, _, hasChan, isZombie, err := ctx.graph.HasChannelEdge(chanID1.ToUint64()) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID1) - } - if !hasChan { - t.Fatalf("could not find edge in graph") - } - if isZombie { - t.Fatal("edge was marked as zombie") - } - - // With the transaction included, and the router's database state - // updated, we'll now mine 5 additional blocks on top of it. - for i := 0; i < 5; i++ { - nextHeight++ - - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - ctx.chain.addBlock(block, uint32(nextHeight), rand.Uint32()) - ctx.chain.setBestBlock(int32(nextHeight)) - ctx.chainView.notifyBlock(block.BlockHash(), uint32(nextHeight), - []*wire.MsgTx{}, t) - } - - // At this point, our starting height should be 107. 
- _, chainHeight, err := ctx.chain.GetBestBlock() - require.NoError(t, err, "unable to get best block") - if chainHeight != 107 { - t.Fatalf("incorrect chain height: expected %v, got %v", - 107, chainHeight) - } - - // Next, we'll "shut down" the router in order to simulate downtime. - if err := ctx.router.Stop(); err != nil { - t.Fatalf("unable to shutdown router: %v", err) - } - - // While the router is "offline" we'll mine 5 additional blocks, with - // the second block closing the channel we created above. - for i := 0; i < 5; i++ { - nextHeight++ - - block := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - } - - if i == 2 { - // For the second block, we'll add a transaction that - // closes the channel we created above by spending the - // output. - closingTx := wire.NewMsgTx(2) - closingTx.AddTxIn(&wire.TxIn{ - PreviousOutPoint: *chanUTXO, - }) - block.Transactions = append(block.Transactions, - closingTx) - } - - ctx.chain.addBlock(block, uint32(nextHeight), rand.Uint32()) - ctx.chain.setBestBlock(int32(nextHeight)) - ctx.chainView.notifyBlock(block.BlockHash(), uint32(nextHeight), - []*wire.MsgTx{}, t) - } - - // At this point, our starting height should be 112. - _, chainHeight, err = ctx.chain.GetBestBlock() - require.NoError(t, err, "unable to get best block") - if chainHeight != 112 { - t.Fatalf("incorrect chain height: expected %v, got %v", - 112, chainHeight) - } - - // Now we'll re-start the ChannelRouter. It should recognize that it's - // behind the main chain and prune all the blocks that it missed while - // it was down. - ctx.RestartRouter(t) - - // At this point, the channel that was pruned should no longer be known - // by the router. - _, _, hasChan, isZombie, err = ctx.graph.HasChannelEdge(chanID1.ToUint64()) - if err != nil { - t.Fatalf("error looking for edge: %v", chanID1) - } - if hasChan { - t.Fatalf("channel was found in graph but shouldn't have been") - } - if isZombie { - t.Fatal("closed channel should not be marked as zombie") - } -} - -// TestPruneChannelGraphStaleEdges ensures that we properly prune stale edges -// from the channel graph. -func TestPruneChannelGraphStaleEdges(t *testing.T) { - t.Parallel() - - freshTimestamp := time.Now() - staleTimestamp := time.Unix(0, 0) - - // We'll create the following test graph so that two of the channels - // are pruned. - testChannels := []*testChannel{ - // No edges. - { - Node1: &testChannelEnd{Alias: "a"}, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 1, - }, - - // Only one edge with a stale timestamp. - { - Node1: &testChannelEnd{ - Alias: "d", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: staleTimestamp, - }, - }, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 2, - }, - - // Only one edge with a stale timestamp, but it's the source - // node so it won't get pruned. - { - Node1: &testChannelEnd{ - Alias: "a", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: staleTimestamp, - }, - }, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 3, - }, - - // Only one edge with a fresh timestamp. - { - Node1: &testChannelEnd{ - Alias: "a", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: freshTimestamp, - }, - }, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 4, - }, - - // One edge fresh, one edge stale. This will be pruned with - // strict pruning activated. 
- { - Node1: &testChannelEnd{ - Alias: "c", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: freshTimestamp, - }, - }, - Node2: &testChannelEnd{ - Alias: "d", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: staleTimestamp, - }, - }, - Capacity: 100000, - ChannelID: 5, - }, - - // Both edges fresh. - symmetricTestChannel("g", "h", 100000, &testChannelPolicy{ - LastUpdate: freshTimestamp, - }, 6), - - // Both edges stale, only one pruned. This should be pruned for - // both normal and strict pruning. - symmetricTestChannel("e", "f", 100000, &testChannelPolicy{ - LastUpdate: staleTimestamp, - }, 7), - } - - for _, strictPruning := range []bool{true, false} { - // We'll create our test graph and router backed with these test - // channels we've created. - testGraph, err := createTestGraphFromChannels( - t, true, testChannels, "a", - ) - if err != nil { - t.Fatalf("unable to create test graph: %v", err) - } - - const startingHeight = 100 - ctx := createTestCtxFromGraphInstance( - t, startingHeight, testGraph, strictPruning, - ) - - // All of the channels should exist before pruning them. - assertChannelsPruned(t, ctx.graph, testChannels) - - // Proceed to prune the channels - only the last one should be pruned. - if err := ctx.router.pruneZombieChans(); err != nil { - t.Fatalf("unable to prune zombie channels: %v", err) - } - - // We expect channels that have either both edges stale, or one edge - // stale with both known. - var prunedChannels []uint64 - if strictPruning { - prunedChannels = []uint64{2, 5, 7} - } else { - prunedChannels = []uint64{2, 7} - } - assertChannelsPruned(t, ctx.graph, testChannels, prunedChannels...) - } -} - -// TestPruneChannelGraphDoubleDisabled test that we can properly prune channels -// with both edges disabled from our channel graph. -func TestPruneChannelGraphDoubleDisabled(t *testing.T) { - t.Parallel() - - t.Run("no_assumechannelvalid", func(t *testing.T) { - testPruneChannelGraphDoubleDisabled(t, false) - }) - t.Run("assumechannelvalid", func(t *testing.T) { - testPruneChannelGraphDoubleDisabled(t, true) - }) -} - -func testPruneChannelGraphDoubleDisabled(t *testing.T, assumeValid bool) { - // We'll create the following test graph so that only the last channel - // is pruned. We'll use a fresh timestamp to ensure they're not pruned - // according to that heuristic. - timestamp := time.Now() - testChannels := []*testChannel{ - // Channel from self shouldn't be pruned. - symmetricTestChannel( - "self", "a", 100000, &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: true, - }, 99, - ), - - // No edges. - { - Node1: &testChannelEnd{Alias: "a"}, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 1, - }, - - // Only one edge disabled. - { - Node1: &testChannelEnd{ - Alias: "a", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: true, - }, - }, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 2, - }, - - // Only one edge enabled. - { - Node1: &testChannelEnd{ - Alias: "a", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: false, - }, - }, - Node2: &testChannelEnd{Alias: "b"}, - Capacity: 100000, - ChannelID: 3, - }, - - // One edge disabled, one edge enabled. 
- { - Node1: &testChannelEnd{ - Alias: "a", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: true, - }, - }, - Node2: &testChannelEnd{ - Alias: "b", - testChannelPolicy: &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: false, - }, - }, - Capacity: 100000, - ChannelID: 1, - }, - - // Both edges enabled. - symmetricTestChannel("c", "d", 100000, &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: false, - }, 2), - - // Both edges disabled, only one pruned. - symmetricTestChannel("e", "f", 100000, &testChannelPolicy{ - LastUpdate: timestamp, - Disabled: true, - }, 3), - } - - // We'll create our test graph and router backed with these test - // channels we've created. - testGraph, err := createTestGraphFromChannels( - t, true, testChannels, "self", - ) - require.NoError(t, err, "unable to create test graph") - - const startingHeight = 100 - ctx := createTestCtxFromGraphInstanceAssumeValid( - t, startingHeight, testGraph, assumeValid, false, - ) - - // All the channels should exist within the graph before pruning them - // when not using AssumeChannelValid, otherwise we should have pruned - // the last channel on startup. - if !assumeValid { - assertChannelsPruned(t, ctx.graph, testChannels) - } else { - // Sleep to allow the pruning to finish. - time.Sleep(200 * time.Millisecond) - - prunedChannel := testChannels[len(testChannels)-1].ChannelID - assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel) - } - - if err := ctx.router.pruneZombieChans(); err != nil { - t.Fatalf("unable to prune zombie channels: %v", err) - } - - // If we attempted to prune them without AssumeChannelValid being set, - // none should be pruned. Otherwise the last channel should still be - // pruned. - if !assumeValid { - assertChannelsPruned(t, ctx.graph, testChannels) - } else { - prunedChannel := testChannels[len(testChannels)-1].ChannelID - assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel) - } -} - -// TestFindPathFeeWeighting tests that the findPath method will properly prefer -// routes with lower fees over routes with lower time lock values. This is -// meant to exercise the fact that the internal findPath method ranks edges -// with the square of the total fee in order bias towards lower fees. -func TestFindPathFeeWeighting(t *testing.T) { - t.Parallel() - - const startingBlockHeight = 101 - ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath) - - var preImage [32]byte - copy(preImage[:], bytes.Repeat([]byte{9}, 32)) - - sourceNode, err := ctx.graph.SourceNode() - require.NoError(t, err, "unable to fetch source node") - - amt := lnwire.MilliSatoshi(100) - - target := ctx.aliases["luoji"] - - // We'll now attempt a path finding attempt using this set up. Due to - // the edge weighting, we should select the direct path over the 2 hop - // path even though the direct path has a higher potential time lock. - path, err := dbFindPath( - ctx.graph, nil, &mockBandwidthHints{}, - noRestrictions, - testPathFindingConfig, - sourceNode.PubKeyBytes, target, amt, 0, 0, - ) - require.NoError(t, err, "unable to find path") - - // The route that was chosen should be exactly one hop, and should be - // directly to luoji. - if len(path) != 1 { - t.Fatalf("expected path length of 1, instead was: %v", len(path)) - } - if path[0].policy.ToNodePubKey() != ctx.aliases["luoji"] { - t.Fatalf("wrong node: %v", path[0].policy.ToNodePubKey()) - } -} - -// TestIsStaleNode tests that the IsStaleNode method properly detects stale -// node announcements. 
-func TestIsStaleNode(t *testing.T) {
-	t.Parallel()
-
-	const startingBlockHeight = 101
-	ctx := createTestCtxSingleNode(t, startingBlockHeight)
-
-	// Before we can insert a node into the database, we need to create a
-	// channel that it's linked to.
-	var (
-		pub1 [33]byte
-		pub2 [33]byte
-	)
-	copy(pub1[:], priv1.PubKey().SerializeCompressed())
-	copy(pub2[:], priv2.PubKey().SerializeCompressed())
-
-	fundingTx, _, chanID, err := createChannelEdge(ctx,
-		bitcoinKey1.SerializeCompressed(),
-		bitcoinKey2.SerializeCompressed(),
-		10000, 500)
-	require.NoError(t, err, "unable to create channel edge")
-	fundingBlock := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{fundingTx},
-	}
-	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
-
-	edge := &models.ChannelEdgeInfo{
-		ChannelID:        chanID.ToUint64(),
-		NodeKey1Bytes:    pub1,
-		NodeKey2Bytes:    pub2,
-		BitcoinKey1Bytes: pub1,
-		BitcoinKey2Bytes: pub2,
-		AuthProof:        nil,
-	}
-	if err := ctx.router.AddEdge(edge); err != nil {
-		t.Fatalf("unable to add edge: %v", err)
-	}
-
-	// Before we add the node, if we query for staleness, we should get
-	// false, as we haven't added the full node.
-	updateTimeStamp := time.Unix(123, 0)
-	if ctx.router.IsStaleNode(pub1, updateTimeStamp) {
-		t.Fatalf("incorrectly detected node as stale")
-	}
-
-	// With the node stub in the database, we'll add the full node
-	// announcement to the database.
-	n1 := &channeldb.LightningNode{
-		HaveNodeAnnouncement: true,
-		LastUpdate:           updateTimeStamp,
-		Addresses:            testAddrs,
-		Color:                color.RGBA{1, 2, 3, 0},
-		Alias:                "node11",
-		AuthSigBytes:         testSig.Serialize(),
-		Features:             testFeatures,
-	}
-	copy(n1.PubKeyBytes[:], priv1.PubKey().SerializeCompressed())
-	if err := ctx.router.AddNode(n1); err != nil {
-		t.Fatalf("could not add node: %v", err)
-	}
-
-	// If we use the same timestamp and query for staleness, we should get
-	// true.
-	if !ctx.router.IsStaleNode(pub1, updateTimeStamp) {
-		t.Fatalf("failure to detect stale node update")
-	}
-
-	// If we update the timestamp and once again query for staleness, it
-	// should report false.
-	newTimeStamp := time.Unix(1234, 0)
-	if ctx.router.IsStaleNode(pub1, newTimeStamp) {
-		t.Fatalf("incorrectly detected node as stale")
-	}
-}
-
-// TestIsKnownEdge tests that the IsKnownEdge method properly detects already
-// known channel announcements.
-func TestIsKnownEdge(t *testing.T) {
-	t.Parallel()
-
-	const startingBlockHeight = 101
-	ctx := createTestCtxSingleNode(t, startingBlockHeight)
-
-	// First, we'll create a new channel edge (just the info) and insert it
-	// into the database.
-	var (
-		pub1 [33]byte
-		pub2 [33]byte
-	)
-	copy(pub1[:], priv1.PubKey().SerializeCompressed())
-	copy(pub2[:], priv2.PubKey().SerializeCompressed())
-
-	fundingTx, _, chanID, err := createChannelEdge(ctx,
-		bitcoinKey1.SerializeCompressed(),
-		bitcoinKey2.SerializeCompressed(),
-		10000, 500)
-	require.NoError(t, err, "unable to create channel edge")
-	fundingBlock := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{fundingTx},
-	}
-	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
-
-	edge := &models.ChannelEdgeInfo{
-		ChannelID:        chanID.ToUint64(),
-		NodeKey1Bytes:    pub1,
-		NodeKey2Bytes:    pub2,
-		BitcoinKey1Bytes: pub1,
-		BitcoinKey2Bytes: pub2,
-		AuthProof:        nil,
-	}
-	if err := ctx.router.AddEdge(edge); err != nil {
-		t.Fatalf("unable to add edge: %v", err)
-	}
-
-	// Now that the edge has been inserted, querying whether the router
-	// already knows of the edge should return true. 
- if !ctx.router.IsKnownEdge(*chanID) { - t.Fatalf("router should detect edge as known") - } -} - -// TestIsStaleEdgePolicy tests that the IsStaleEdgePolicy properly detects -// stale channel edge update announcements. -func TestIsStaleEdgePolicy(t *testing.T) { - t.Parallel() - - const startingBlockHeight = 101 - ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath) - - // First, we'll create a new channel edge (just the info) and insert it - // into the database. - var ( - pub1 [33]byte - pub2 [33]byte - ) - copy(pub1[:], priv1.PubKey().SerializeCompressed()) - copy(pub2[:], priv2.PubKey().SerializeCompressed()) - - fundingTx, _, chanID, err := createChannelEdge(ctx, - bitcoinKey1.SerializeCompressed(), - bitcoinKey2.SerializeCompressed(), - 10000, 500) - require.NoError(t, err, "unable to create channel edge") - fundingBlock := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{fundingTx}, - } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - - // If we query for staleness before adding the edge, we should get - // false. - updateTimeStamp := time.Unix(123, 0) - if ctx.router.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) { - t.Fatalf("router failed to detect fresh edge policy") - } - if ctx.router.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) { - t.Fatalf("router failed to detect fresh edge policy") - } - - edge := &models.ChannelEdgeInfo{ - ChannelID: chanID.ToUint64(), - NodeKey1Bytes: pub1, - NodeKey2Bytes: pub2, - BitcoinKey1Bytes: pub1, - BitcoinKey2Bytes: pub2, - AuthProof: nil, - } - if err := ctx.router.AddEdge(edge); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - // We'll also add two edge policies, one for each direction. - edgePolicy := &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: updateTimeStamp, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - } - edgePolicy.ChannelFlags = 0 - if err := ctx.router.UpdateEdge(edgePolicy); err != nil { - t.Fatalf("unable to update edge policy: %v", err) - } - - edgePolicy = &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: updateTimeStamp, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - } - edgePolicy.ChannelFlags = 1 - if err := ctx.router.UpdateEdge(edgePolicy); err != nil { - t.Fatalf("unable to update edge policy: %v", err) - } - - // Now that the edges have been added, an identical (chanID, flag, - // timestamp) tuple for each edge should be detected as a stale edge. - if !ctx.router.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) { - t.Fatalf("router failed to detect stale edge policy") - } - if !ctx.router.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) { - t.Fatalf("router failed to detect stale edge policy") - } - - // If we now update the timestamp for both edges, the router should - // detect that this tuple represents a fresh edge. - updateTimeStamp = time.Unix(9999, 0) - if ctx.router.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) { - t.Fatalf("router failed to detect fresh edge policy") - } - if ctx.router.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) { - t.Fatalf("router failed to detect fresh edge policy") - } -} - -// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket -// function is able to gracefully handle being passed a nil set of hops for the -// route by the caller. 
-func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) {
-	t.Parallel()
-
-	sessionKey, _ := btcec.NewPrivateKey()
-	emptyRoute := &route.Route{}
-	_, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey)
-	if err != route.ErrNoRouteHopsProvided {
-		t.Fatalf("expected empty hops error: instead got: %v", err)
-	}
-}
-
-// TestUnknownErrorSource tests that if the source of an error is unknown, all
-// edges along the route will be pruned.
-func TestUnknownErrorSource(t *testing.T) {
-	t.Parallel()
-
-	// Setup a network. It contains two paths to c: a->b->c and an
-	// alternative a->d->c.
-	chanCapSat := btcutil.Amount(100000)
-	testChannels := []*testChannel{
-		symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{
-			Expiry:  144,
-			FeeRate: 400,
-			MinHTLC: 1,
-			MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
-		}, 1),
-		symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{
-			Expiry:  144,
-			FeeRate: 400,
-			MinHTLC: 1,
-			MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
-		}, 3),
-		symmetricTestChannel("a", "d", chanCapSat, &testChannelPolicy{
-			Expiry:      144,
-			FeeRate:     400,
-			FeeBaseMsat: 100000,
-			MinHTLC:     1,
-			MaxHTLC:     lnwire.NewMSatFromSatoshis(chanCapSat),
-		}, 2),
-		symmetricTestChannel("d", "c", chanCapSat, &testChannelPolicy{
-			Expiry:      144,
-			FeeRate:     400,
-			FeeBaseMsat: 100000,
-			MinHTLC:     1,
-			MaxHTLC:     lnwire.NewMSatFromSatoshis(chanCapSat),
-		}, 4),
-	}
-
-	testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a")
-	require.NoError(t, err, "unable to create graph")
-
-	const startingBlockHeight = 101
-	ctx := createTestCtxFromGraphInstance(
-		t, startingBlockHeight, testGraph, false,
-	)
-
-	// Create a payment to node c.
-	payment := createDummyLightningPayment(
-		t, ctx.aliases["c"], lnwire.NewMSatFromSatoshis(1000),
-	)
-
-	// We'll modify the SendToSwitch method so that it simulates hop b as a
-	// node that returns an unparsable failure if approached via the a->b
-	// channel.
-	ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult(
-		func(firstHop lnwire.ShortChannelID) ([32]byte, error) {
-
-			// If channel a->b is used, return an error without
-			// source and message. The sender won't know the origin
-			// of the error.
-			if firstHop.ToUint64() == 1 {
-				return [32]byte{},
-					htlcswitch.ErrUnreadableFailureMessage
-			}
-
-			// Otherwise the payment succeeds.
-			return lntypes.Preimage{}, nil
-		})
-
-	// Send off the payment request to the router. The expectation is that
-	// the route a->b->c is tried first. An unreadable failure is returned,
-	// which should prune the channel a->b. We expect the payment to
-	// succeed via a->d.
-	_, _, err = ctx.router.SendPayment(payment)
-	require.NoErrorf(t, err, "unable to send payment: %v",
-		payment.paymentHash)
-
-	// Next we modify payment result to return an unknown failure.
-	ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult(
-		func(firstHop lnwire.ShortChannelID) ([32]byte, error) {
-
-			// If channel a->d is used, simulate that the failure
-			// couldn't be decoded (FailureMessage is nil).
-			if firstHop.ToUint64() == 2 {
-				return [32]byte{},
-					htlcswitch.NewUnknownForwardingError(1)
-			}
-
-			// Otherwise the payment succeeds.
-			return lntypes.Preimage{}, nil
-		})
-
-	// Send off the payment request to the router. We expect the payment to
-	// fail because both routes have been pruned. 
- payment.paymentHash[1] ^= 1 - _, _, err = ctx.router.SendPayment(payment) - if err == nil { - t.Fatalf("expected payment to fail") - } -} - -// assertChannelsPruned ensures that only the given channels are pruned from the -// graph out of the set of all channels. -func assertChannelsPruned(t *testing.T, graph *channeldb.ChannelGraph, - channels []*testChannel, prunedChanIDs ...uint64) { - - t.Helper() - - pruned := make(map[uint64]struct{}, len(channels)) - for _, chanID := range prunedChanIDs { - pruned[chanID] = struct{}{} - } - - for _, channel := range channels { - _, shouldPrune := pruned[channel.ChannelID] - _, _, exists, isZombie, err := graph.HasChannelEdge( - channel.ChannelID, - ) - if err != nil { - t.Fatalf("unable to determine existence of "+ - "channel=%v in the graph: %v", - channel.ChannelID, err) - } - if !shouldPrune && !exists { - t.Fatalf("expected channel=%v to exist within "+ - "the graph", channel.ChannelID) - } - if shouldPrune && exists { - t.Fatalf("expected channel=%v to not exist "+ - "within the graph", channel.ChannelID) - } - if !shouldPrune && isZombie { - t.Fatalf("expected channel=%v to not be marked "+ - "as zombie", channel.ChannelID) - } - if shouldPrune && !isZombie { - t.Fatalf("expected channel=%v to be marked as "+ - "zombie", channel.ChannelID) - } - } -} - -// TestSendToRouteStructuredError asserts that SendToRoute returns a structured -// error. -func TestSendToRouteStructuredError(t *testing.T) { - t.Parallel() - - // Setup a three node network. - chanCapSat := btcutil.Amount(100000) - testChannels := []*testChannel{ - symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ - Expiry: 144, - FeeRate: 400, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), - }, 1), - symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ - Expiry: 144, - FeeRate: 400, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), - }, 2), - } - - testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") - require.NoError(t, err, "unable to create graph") - - const startingBlockHeight = 101 - ctx := createTestCtxFromGraphInstance( - t, startingBlockHeight, testGraph, false, - ) + const startingBlockHeight = 101 + ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) // Set up an init channel for the control tower, such that we can make // sure the payment is initiated correctly. @@ -2985,9 +1470,7 @@ func TestSendToRouteMaxHops(t *testing.T) { const startingBlockHeight = 101 - ctx := createTestCtxFromGraphInstance( - t, startingBlockHeight, testGraph, false, - ) + ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) // Create a 30 hop route that exceeds the maximum hop limit. const payAmt = lnwire.MilliSatoshi(10000) @@ -3090,9 +1573,7 @@ func TestBuildRoute(t *testing.T) { const startingBlockHeight = 101 - ctx := createTestCtxFromGraphInstance( - t, startingBlockHeight, testGraph, false, - ) + ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) checkHops := func(rt *route.Route, expected []uint64, payAddr [32]byte) { @@ -3222,263 +1703,329 @@ func TestGetPathEdges(t *testing.T) { continue } - require.NoError(t, err) - require.Equal(t, pathEdges, tc.expectedEdges) - require.Equal(t, amt, tc.expectedAmt) - } -} + require.NoError(t, err) + require.Equal(t, pathEdges, tc.expectedEdges) + require.Equal(t, amt, tc.expectedAmt) + } +} + +// TestSendToRouteSkipTempErrSuccess validates a successful payment send. 
+func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
+	t.Parallel()
+
+	var (
+		payHash lntypes.Hash
+		payAmt  = lnwire.MilliSatoshi(10000)
+	)
+
+	preimage := lntypes.Preimage{1}
+	testAttempt := makeSettledAttempt(t, int(payAmt), preimage)
+
+	node, err := createTestNode()
+	require.NoError(t, err)
+
+	// Create a simple 1-hop route.
+	hops := []*route.Hop{
+		{
+			ChannelID:        1,
+			PubKeyBytes:      node.PubKeyBytes,
+			AmtToForward:     payAmt,
+			OutgoingTimeLock: 120,
+			MPP:              record.NewMPP(payAmt, [32]byte{}),
+		},
+	}
+	rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops)
+	require.NoError(t, err)
+
+	// Create mockers.
+	controlTower := &mockControlTower{}
+	payer := &mockPaymentAttemptDispatcher{}
+	missionControl := &mockMissionControl{}
+
+	// Create the router.
+	router := &ChannelRouter{cfg: &Config{
+		Control:        controlTower,
+		Payer:          payer,
+		MissionControl: missionControl,
+		Clock:          clock.NewTestClock(time.Unix(1, 0)),
+		NextPaymentID: func() (uint64, error) {
+			return 0, nil
+		},
+	}}
+
+	// Register mockers with the expected method calls.
+	controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
+	controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
+	controlTower.On("SettleAttempt",
+		payHash, mock.Anything, mock.Anything,
+	).Return(testAttempt, nil)

+	payer.On("SendHTLC",
+		mock.Anything, mock.Anything, mock.Anything,
+	).Return(nil)

-// edgeCreationModifier is an enum-like type used to modify steps that are
-// skipped when creating a channel in the test context.
-type edgeCreationModifier uint8

+	// Create a buffered chan that will be returned by GetAttemptResult.
+	resultChan := make(chan *htlcswitch.PaymentResult, 1)
+	payer.On("GetAttemptResult",
+		mock.Anything, mock.Anything, mock.Anything,
+	).Return(resultChan, nil).Run(func(_ mock.Arguments) {
+		// Send a successful payment result.
+		resultChan <- &htlcswitch.PaymentResult{}
+	})

-const (
-	// edgeCreationNoFundingTx is used to skip adding the funding
-	// transaction of an edge to the chain.
-	edgeCreationNoFundingTx edgeCreationModifier = iota

+	missionControl.On("ReportPaymentSuccess",
+		mock.Anything, rt,
+	).Return(nil)

-	// edgeCreationNoUTXO is used to skip adding the UTXO of a channel to
-	// the UTXO set.
-	edgeCreationNoUTXO

+	// Mock the control tower to return the mocked payment.
+	payment := &mockMPPayment{}
+	controlTower.On("FetchPayment", payHash).Return(payment, nil).Once()

-	// edgeCreationBadScript is used to create the edge, but uses the wrong
-	// script, which should cause it to fail output validation.
-	edgeCreationBadScript
-)

+	// Mock the payment to return nil failure reason.
+	payment.On("TerminalInfo").Return(nil, nil)

-// newChannelEdgeInfo is a helper function used to create a new channel edge,
-// possibly skipping adding it to parts of the chain/state as well.
-func newChannelEdgeInfo(ctx *testCtx, fundingHeight uint32,
-	ecm edgeCreationModifier) (*models.ChannelEdgeInfo, error) {

+	// Expect a successful send to route.
+	attempt, err := router.SendToRouteSkipTempErr(payHash, rt)
+	require.NoError(t, err)
+	require.Equal(t, testAttempt, attempt)

-	node1, err := createTestNode()
-	if err != nil {
-		return nil, err
-	}
-	node2, err := createTestNode()
-	if err != nil {
-		return nil, err
-	}

+	// Assert the above methods are called as expected. 
+ controlTower.AssertExpectations(t)
+	payer.AssertExpectations(t)
+	missionControl.AssertExpectations(t)
+	payment.AssertExpectations(t)
+}
+
+// TestSendToRouteSkipTempErrNonMPP checks that an error is returned when
+// skipping temp error for non-MPP.
+func TestSendToRouteSkipTempErrNonMPP(t *testing.T) {
+	t.Parallel()

-	fundingTx, _, chanID, err := createChannelEdge(
-		ctx, bitcoinKey1.SerializeCompressed(),
-		bitcoinKey2.SerializeCompressed(), 100, fundingHeight,
+	var (
+		payHash lntypes.Hash
+		payAmt  = lnwire.MilliSatoshi(10000)
 	)
-	if err != nil {
-		return nil, fmt.Errorf("unable to create edge: %w", err)
-	}

-	edge := &models.ChannelEdgeInfo{
-		ChannelID:     chanID.ToUint64(),
-		NodeKey1Bytes: node1.PubKeyBytes,
-		NodeKey2Bytes: node2.PubKeyBytes,
-	}
-	copy(edge.BitcoinKey1Bytes[:], bitcoinKey1.SerializeCompressed())
-	copy(edge.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
+	node, err := createTestNode()
+	require.NoError(t, err)

-	if ecm == edgeCreationNoFundingTx {
-		return edge, nil
+	// Create a simple 1-hop route without the MPP field.
+	hops := []*route.Hop{
+		{
+			ChannelID:    1,
+			PubKeyBytes:  node.PubKeyBytes,
+			AmtToForward: payAmt,
+		},
 	}
+	rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops)
+	require.NoError(t, err)

-	fundingBlock := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{fundingTx},
-	}
-	ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight)
+	// Create mockers.
+	controlTower := &mockControlTower{}
+	payer := &mockPaymentAttemptDispatcher{}
+	missionControl := &mockMissionControl{}

-	if ecm == edgeCreationNoUTXO {
-		ctx.chain.delUtxo(wire.OutPoint{
-			Hash: fundingTx.TxHash(),
-		})
-	}
+	// Create the router.
+	router := &ChannelRouter{cfg: &Config{
+		Control:        controlTower,
+		Payer:          payer,
+		MissionControl: missionControl,
+		Clock:          clock.NewTestClock(time.Unix(1, 0)),
+		NextPaymentID: func() (uint64, error) {
+			return 0, nil
+		},
+	}}

-	if ecm == edgeCreationBadScript {
-		fundingTx.TxOut[0].PkScript[0] ^= 1
-	}
+	// Expect an error to be returned.
+	attempt, err := router.SendToRouteSkipTempErr(payHash, rt)
+	require.ErrorIs(t, err, ErrSkipTempErr)
+	require.Nil(t, attempt)

-	return edge, nil
+	// Assert the above methods are not called.
+	controlTower.AssertExpectations(t)
+	payer.AssertExpectations(t)
+	missionControl.AssertExpectations(t)
 }

-func assertChanChainRejection(t *testing.T, ctx *testCtx,
-	edge *models.ChannelEdgeInfo, failCode errorCode) {
-
-	t.Helper()
-
-	err := ctx.router.AddEdge(edge)
-	if !IsError(err, failCode) {
-		t.Fatalf("validation should have failed: %v", err)
-	}
+// TestSendToRouteSkipTempErrTempFailure validates that a temporary failure
+// won't cause the payment to be failed.
+func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
+	t.Parallel()

-	// This channel should now be present in the zombie channel index.
-	_, _, _, isZombie, err := ctx.graph.HasChannelEdge(
-		edge.ChannelID,
+	var (
+		payHash lntypes.Hash
+		payAmt  = lnwire.MilliSatoshi(10000)
 	)
-	require.Nil(t, err)
-	require.True(t, isZombie, "edge should be marked as zombie")
-}
-
-// TestChannelOnChainRejectionZombie tests that if we fail validating a channel
-// due to some sort of on-chain rejection (no funding transaction, or invalid
-// UTXO), then we'll mark the channel as a zombie. 
-func TestChannelOnChainRejectionZombie(t *testing.T) {
-	t.Parallel()
-	ctx := createTestCtxSingleNode(t, 0)
+	testAttempt := makeFailedAttempt(t, int(payAmt))
+	node, err := createTestNode()
+	require.NoError(t, err)

-	// To start, we'll make an edge for the channel, but we won't add the
-	// funding transaction to the mock blockchain, which should cause the
-	// validation to fail below.
-	edge, err := newChannelEdgeInfo(ctx, 1, edgeCreationNoFundingTx)
-	require.Nil(t, err)
+	// Create a simple 1-hop route.
+	hops := []*route.Hop{
+		{
+			ChannelID:        1,
+			PubKeyBytes:      node.PubKeyBytes,
+			AmtToForward:     payAmt,
+			OutgoingTimeLock: 120,
+			MPP:              record.NewMPP(payAmt, [32]byte{}),
+		},
+	}
+	rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops)
+	require.NoError(t, err)

-	// We expect this to fail as the transaction isn't present in the
-	// chain (nor the block).
-	assertChanChainRejection(t, ctx, edge, ErrNoFundingTransaction)
+	// Create mockers.
+	controlTower := &mockControlTower{}
+	payer := &mockPaymentAttemptDispatcher{}
+	missionControl := &mockMissionControl{}

-	// Next, we'll make another channel edge, but actually add it to the
-	// graph this time.
-	edge, err = newChannelEdgeInfo(ctx, 2, edgeCreationNoUTXO)
-	require.Nil(t, err)
+	// Create the router.
+	router := &ChannelRouter{cfg: &Config{
+		Control:        controlTower,
+		Payer:          payer,
+		MissionControl: missionControl,
+		Clock:          clock.NewTestClock(time.Unix(1, 0)),
+		NextPaymentID: func() (uint64, error) {
+			return 0, nil
+		},
+	}}

-	// Instead now, we'll remove it from the set of UTXOs which should
-	// cause the spentness validation to fail.
-	assertChanChainRejection(t, ctx, edge, ErrChannelSpent)
+	// Create the error to be returned.
+	tempErr := htlcswitch.NewForwardingError(
+		&lnwire.FailTemporaryChannelFailure{}, 1,
+	)

-	// If we cause the funding transaction in the chain to fail validation,
-	// we should see similar behavior.
-	edge, err = newChannelEdgeInfo(ctx, 3, edgeCreationBadScript)
-	require.Nil(t, err)
-	assertChanChainRejection(t, ctx, edge, ErrInvalidFundingOutput)
-}
+	// Register mockers with the expected method calls.
+	controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
+	controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
+	controlTower.On("FailAttempt",
+		payHash, mock.Anything, mock.Anything,
+	).Return(testAttempt, nil)

-func createDummyTestGraph(t *testing.T) *testGraphInstance {
-	// Setup two simple channels such that we can mock sending along this
-	// route.
-	chanCapSat := btcutil.Amount(100000)
-	testChannels := []*testChannel{
-		symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{
-			Expiry:  144,
-			FeeRate: 400,
-			MinHTLC: 1,
-			MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
-		}, 1),
-		symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{
-			Expiry:  144,
-			FeeRate: 400,
-			MinHTLC: 1,
-			MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
-		}, 2),
-	}
+	payer.On("SendHTLC",
+		mock.Anything, mock.Anything, mock.Anything,
+	).Return(tempErr)

-	testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a")
-	require.NoError(t, err, "failed to create graph")
-	return testGraph
-}
+	// Mock the control tower to return the mocked payment. 
+ payment := &mockMPPayment{}
+	controlTower.On("FetchPayment", payHash).Return(payment, nil).Once()

-func createDummyLightningPayment(t *testing.T,
-	target route.Vertex, amt lnwire.MilliSatoshi) *LightningPayment {
+	// Mock the mission control to return a nil reason from reporting the
+	// attempt failure.
+	missionControl.On("ReportPaymentFail",
+		mock.Anything, rt, mock.Anything, mock.Anything,
+	).Return(nil, nil)

-	var preImage lntypes.Preimage
-	_, err := rand.Read(preImage[:])
-	require.NoError(t, err, "unable to generate preimage")
+	// Mock the payment to return nil failure reason.
+	payment.On("TerminalInfo").Return(nil, nil)

-	payHash := preImage.Hash()
+	// Expect a failed send to route.
+	attempt, err := router.SendToRouteSkipTempErr(payHash, rt)
+	require.Equal(t, tempErr, err)
+	require.Equal(t, testAttempt, attempt)

-	return &LightningPayment{
-		Target:      target,
-		Amount:      amt,
-		FeeLimit:    noFeeLimit,
-		paymentHash: &payHash,
-	}
+	// Assert the above methods are called as expected.
+	controlTower.AssertExpectations(t)
+	payer.AssertExpectations(t)
+	missionControl.AssertExpectations(t)
+	payment.AssertExpectations(t)
 }

-// TestBlockDifferenceFix tests that when the router is behind on blocks, it
-// catches up to the best block.
-func TestBlockDifferenceFix(t *testing.T) {
-	t.Parallel()
-
-	initialBlockHeight := uint32(0)
-
-	// Starting height here is set to 0, which is behind where we want to
-	// be.
-	ctx := createTestCtxSingleNode(t, initialBlockHeight)
+// TestSendToRouteSkipTempErrPermanentFailure validates that a permanent
+// failure will fail the payment.
+func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) {
+	var (
+		payHash lntypes.Hash
+		payAmt  = lnwire.MilliSatoshi(10000)
+	)

-	// Add initial block to our mini blockchain.
-	block := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{},
-	}
-	ctx.chain.addBlock(block, initialBlockHeight, rand.Uint32())
+	testAttempt := makeFailedAttempt(t, int(payAmt))
+	node, err := createTestNode()
+	require.NoError(t, err)

-	// Let's generate a new block of height 5, 5 above where our node is at.
-	newBlock := &wire.MsgBlock{
-		Transactions: []*wire.MsgTx{},
+	// Create a simple 1-hop route.
+	hops := []*route.Hop{
+		{
+			ChannelID:        1,
+			PubKeyBytes:      node.PubKeyBytes,
+			AmtToForward:     payAmt,
+			OutgoingTimeLock: 120,
+			MPP:              record.NewMPP(payAmt, [32]byte{}),
+		},
 	}
-	newBlockHeight := uint32(5)
-
-	blockDifference := newBlockHeight - initialBlockHeight
-
-	ctx.chainView.notifyBlockAck = make(chan struct{}, 1)
-
-	ctx.chain.addBlock(newBlock, newBlockHeight, rand.Uint32())
-	ctx.chain.setBestBlock(int32(newBlockHeight))
-	ctx.chainView.notifyBlock(block.BlockHash(), newBlockHeight,
-		[]*wire.MsgTx{}, t)
+	rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops)
+	require.NoError(t, err)

-	<-ctx.chainView.notifyBlockAck
+	// Create mockers.
+	controlTower := &mockControlTower{}
+	payer := &mockPaymentAttemptDispatcher{}
+	missionControl := &mockMissionControl{}

-	// At this point, the chain notifier should have noticed that we're
-	// behind on blocks, and will send the n missing blocks that we
-	// need to the client's epochs channel. 
Let's replicate this - // functionality. - for i := 0; i < int(blockDifference); i++ { - currBlockHeight := int32(i + 1) + // Register mockers with the expected method calls. + controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) + controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) - nonce := rand.Uint32() + controlTower.On("FailAttempt", + payHash, mock.Anything, mock.Anything, + ).Return(testAttempt, nil) - newBlock := &wire.MsgBlock{ - Transactions: []*wire.MsgTx{}, - Header: wire.BlockHeader{Nonce: nonce}, - } - ctx.chain.addBlock(newBlock, uint32(currBlockHeight), nonce) - currHash := newBlock.Header.BlockHash() + // Expect the payment to be failed. + controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) - newEpoch := &chainntnfs.BlockEpoch{ - Height: currBlockHeight, - Hash: &currHash, - } + // Mock an error to be returned from sending the htlc. + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(permErr) - ctx.notifier.EpochChan <- newEpoch + failureReason := channeldb.FailureReasonPaymentDetails + missionControl.On("ReportPaymentFail", + mock.Anything, rt, mock.Anything, mock.Anything, + ).Return(&failureReason, nil) - ctx.chainView.notifyBlock(currHash, - uint32(currBlockHeight), block.Transactions, t) + // Mock the control tower to return the mocked payment. + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - <-ctx.chainView.notifyBlockAck - } + // Mock the payment to return a failure reason. + payment.On("TerminalInfo").Return(nil, &failureReason) - err := wait.NoError(func() error { - // Then router height should be updated to the latest block. - if atomic.LoadUint32(&ctx.router.bestHeight) != newBlockHeight { - return fmt.Errorf("height should have been updated "+ - "to %v, instead got %v", newBlockHeight, - ctx.router.bestHeight) - } + // Expect a failed send to route. + attempt, err := router.SendToRouteSkipTempErr(payHash, rt) + require.Equal(t, permErr, err) + require.Equal(t, testAttempt, attempt) - return nil - }, testTimeout) - require.NoError(t, err, "block height wasn't updated") + // Assert the above methods are called as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } -// TestSendToRouteSkipTempErrSuccess validates a successful payment send. -func TestSendToRouteSkipTempErrSuccess(t *testing.T) { - t.Parallel() - +// TestSendToRouteTempFailure validates a temporary failure will cause the +// payment to be failed. +func TestSendToRouteTempFailure(t *testing.T) { var ( payHash lntypes.Hash payAmt = lnwire.MilliSatoshi(10000) ) - preimage := lntypes.Preimage{1} - testAttempt := makeSettledAttempt(t, int(payAmt), preimage) - + testAttempt := makeFailedAttempt(t, int(payAmt)) node, err := createTestNode() require.NoError(t, err) @@ -3511,29 +2058,24 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { }, }} + // Create the error to be returned. + tempErr := htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, 1, + ) + // Register mockers with the expected method calls. 
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) - controlTower.On("SettleAttempt", + controlTower.On("FailAttempt", payHash, mock.Anything, mock.Anything, ).Return(testAttempt, nil) - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) + // Expect the payment to be failed. + controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) - // Create a buffered chan and it will be returned by GetAttemptResult. - resultChan := make(chan *htlcswitch.PaymentResult, 1) - payer.On("GetAttemptResult", + payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, - ).Return(resultChan, nil).Run(func(_ mock.Arguments) { - // Send a successful payment result. - resultChan <- &htlcswitch.PaymentResult{} - }) - - missionControl.On("ReportPaymentSuccess", - mock.Anything, rt, - ).Return(nil) + ).Return(tempErr) // Mock the control tower to return the mocked payment. payment := &mockMPPayment{} @@ -3542,9 +2084,14 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { // Mock the payment to return nil failure reason. payment.On("TerminalInfo").Return(nil, nil) - // Expect a successful send to route. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt) - require.NoError(t, err) + // Return a nil reason to mock a temporary failure. + missionControl.On("ReportPaymentFail", + mock.Anything, rt, mock.Anything, mock.Anything, + ).Return(nil, nil) + + // Expect a failed send to route. + attempt, err := router.SendToRoute(payHash, rt) + require.Equal(t, tempErr, err) require.Equal(t, testAttempt, attempt) // Assert the above methods are called as expected. @@ -3554,459 +2101,518 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { payment.AssertExpectations(t) } -// TestSendToRouteSkipTempErrNonMPP checks that an error is return when -// skipping temp error for non-MPP. -func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { +// TestNewRouteRequest tests creation of route requests for blinded and +// unblinded routes. +func TestNewRouteRequest(t *testing.T) { t.Parallel() + //nolint:lll + source, err := route.NewVertexFromStr("0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6") + require.NoError(t, err) + sourcePubkey, err := btcec.ParsePubKey(source[:]) + require.NoError(t, err) + + //nolint:lll + v1, err := route.NewVertexFromStr("026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318") + require.NoError(t, err) + pubkey1, err := btcec.ParsePubKey(v1[:]) + require.NoError(t, err) + + //nolint:lll + v2, err := route.NewVertexFromStr("03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99") + require.NoError(t, err) + pubkey2, err := btcec.ParsePubKey(v2[:]) + require.NoError(t, err) + var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) + unblindedCltv uint16 = 500 + blindedCltv uint16 = 1000 ) - node, err := createTestNode() - require.NoError(t, err) + blindedSelfIntro := &BlindedPayment{ + CltvExpiryDelta: blindedCltv, + BlindedPath: &sphinx.BlindedPath{ + IntroductionPoint: sourcePubkey, + BlindedHops: []*sphinx.BlindedHopInfo{{}}, + }, + } - // Create a simple 1-hop route without the MPP field. 
- hops := []*route.Hop{ - { - ChannelID: 1, - PubKeyBytes: node.PubKeyBytes, - AmtToForward: payAmt, + blindedOtherIntro := &BlindedPayment{ + CltvExpiryDelta: blindedCltv, + BlindedPath: &sphinx.BlindedPath{ + IntroductionPoint: pubkey1, + BlindedHops: []*sphinx.BlindedHopInfo{ + {}, + }, }, } - rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops) - require.NoError(t, err) - // Create mockers. - controlTower := &mockControlTower{} - payer := &mockPaymentAttemptDispatcher{} - missionControl := &mockMissionControl{} + blindedMultiHop := &BlindedPayment{ + CltvExpiryDelta: blindedCltv, + BlindedPath: &sphinx.BlindedPath{ + IntroductionPoint: pubkey1, + BlindedHops: []*sphinx.BlindedHopInfo{ + {}, + { + BlindedNodePub: pubkey2, + }, + }, + }, + } - // Create the router. - router := &ChannelRouter{cfg: &Config{ - Control: controlTower, - Payer: payer, - MissionControl: missionControl, - Clock: clock.NewTestClock(time.Unix(1, 0)), - NextPaymentID: func() (uint64, error) { - return 0, nil + testCases := []struct { + name string + target *route.Vertex + routeHints RouteHints + blindedPayment *BlindedPayment + finalExpiry uint16 + + expectedTarget route.Vertex + expectedCltv uint16 + err error + }{ + { + name: "blinded and target", + target: &v1, + blindedPayment: blindedOtherIntro, + err: ErrTargetAndBlinded, }, - }} + { + // For single-hop blinded we have a final cltv. + name: "blinded intro node only", + blindedPayment: blindedOtherIntro, + expectedTarget: v1, + expectedCltv: blindedCltv, + err: nil, + }, + { + // For multi-hop blinded, we have no final cltv. + name: "blinded multi-hop", + blindedPayment: blindedMultiHop, + expectedTarget: v2, + expectedCltv: 0, + err: nil, + }, + { + name: "unblinded", + target: &v2, + finalExpiry: unblindedCltv, + expectedTarget: v2, + expectedCltv: unblindedCltv, + err: nil, + }, + { + name: "source node intro", + blindedPayment: blindedSelfIntro, + err: ErrSelfIntro, + }, + { + name: "hints and blinded", + blindedPayment: blindedMultiHop, + routeHints: make( + map[route.Vertex][]AdditionalEdge, + ), + err: ErrHintsAndBlinded, + }, + { + name: "expiry and blinded", + blindedPayment: blindedMultiHop, + finalExpiry: unblindedCltv, + err: ErrExpiryAndBlinded, + }, + { + name: "invalid blinded payment", + blindedPayment: &BlindedPayment{}, + err: ErrNoBlindedPath, + }, + } - // Expect an error to be returned. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt) - require.ErrorIs(t, ErrSkipTempErr, err) - require.Nil(t, attempt) + for _, testCase := range testCases { + testCase := testCase - // Assert the above methods are not called. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - missionControl.AssertExpectations(t) + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + req, err := NewRouteRequest( + source, testCase.target, 1000, 0, nil, nil, + testCase.routeHints, testCase.blindedPayment, + testCase.finalExpiry, + ) + require.ErrorIs(t, err, testCase.err) + + // Skip request validation if we got a non-nil error. + if err != nil { + return + } + + require.Equal(t, req.Target, testCase.expectedTarget) + require.Equal( + t, req.FinalExpiry, testCase.expectedCltv, + ) + }) + } } -// TestSendToRouteSkipTempErrTempFailure validates a temporary failure won't -// cause the payment to be failed. 
-func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { +// TestAddEdgeUnknownVertexes tests that if an edge is added that contains two +// vertexes which we don't know of, the edge should be available for use +// regardless. This is due to the fact that we don't actually need node +// announcements for the channel vertexes to be able to use the channel. +func TestAddEdgeUnknownVertexes(t *testing.T) { t.Parallel() - var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) + const startingBlockHeight = 101 + ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath) + + var pub1 [33]byte + copy(pub1[:], priv1.PubKey().SerializeCompressed()) + + var pub2 [33]byte + copy(pub2[:], priv2.PubKey().SerializeCompressed()) + + // The two nodes we are about to add should not exist yet. + _, exists1, err := ctx.graph.HasLightningNode(pub1) + require.NoError(t, err, "unable to query graph") + require.False(t, exists1) + + _, exists2, err := ctx.graph.HasLightningNode(pub2) + require.NoError(t, err, "unable to query graph") + require.False(t, exists2) + + // Add the edge between the two unknown nodes to the graph, and check + // that the nodes are found after the fact. + _, _, chanID, err := createChannelEdge( + bitcoinKey1.SerializeCompressed(), + bitcoinKey2.SerializeCompressed(), + 10000, 500, ) + require.NoError(t, err, "unable to create channel edge") - testAttempt := makeFailedAttempt(t, int(payAmt)) - node, err := createTestNode() - require.NoError(t, err) + edge := &models.ChannelEdgeInfo{ + ChannelID: chanID.ToUint64(), + NodeKey1Bytes: pub1, + NodeKey2Bytes: pub2, + BitcoinKey1Bytes: pub1, + BitcoinKey2Bytes: pub2, + AuthProof: nil, + } + require.NoError(t, ctx.graph.AddChannelEdge(edge)) - // Create a simple 1-hop route. - hops := []*route.Hop{ - { - ChannelID: 1, - PubKeyBytes: node.PubKeyBytes, - AmtToForward: payAmt, - OutgoingTimeLock: 120, - MPP: record.NewMPP(payAmt, [32]byte{}), - }, + // We must add the edge policy to be able to use the edge for route + // finding. + edgePolicy := &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + ToNode: edge.NodeKey2Bytes, } - rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops) - require.NoError(t, err) + edgePolicy.ChannelFlags = 0 - // Create mockers. - controlTower := &mockControlTower{} - payer := &mockPaymentAttemptDispatcher{} - missionControl := &mockMissionControl{} + require.NoError(t, ctx.graph.UpdateEdgePolicy(edgePolicy)) - // Create the router. - router := &ChannelRouter{cfg: &Config{ - Control: controlTower, - Payer: payer, - MissionControl: missionControl, - Clock: clock.NewTestClock(time.Unix(1, 0)), - NextPaymentID: func() (uint64, error) { - return 0, nil - }, - }} + // Create edge in the other direction as well. + edgePolicy = &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + ToNode: edge.NodeKey1Bytes, + } + edgePolicy.ChannelFlags = 1 - // Create the error to be returned. - tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, 1, - ) + require.NoError(t, ctx.graph.UpdateEdgePolicy(edgePolicy)) - // Register mockers with the expected method calls. 
- controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) - controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) - controlTower.On("FailAttempt", - payHash, mock.Anything, mock.Anything, - ).Return(testAttempt, nil) + // After adding the edge between the two previously unknown nodes, they + // should have been added to the graph. + _, exists1, err = ctx.graph.HasLightningNode(pub1) + require.NoError(t, err, "unable to query graph") + require.True(t, exists1) - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(tempErr) + _, exists2, err = ctx.graph.HasLightningNode(pub2) + require.NoError(t, err, "unable to query graph") + require.True(t, exists2) - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + // We will connect node1 to the rest of the test graph, and make sure + // we can find a route to node2, which will use the just added channel + // edge. - // Mock the mission control to return a nil reason from reporting the - // attempt failure. - missionControl.On("ReportPaymentFail", - mock.Anything, rt, mock.Anything, mock.Anything, - ).Return(nil, nil) + // We will connect node 1 to "sophon" + connectNode := ctx.aliases["sophon"] + connectNodeKey, err := btcec.ParsePubKey(connectNode[:]) + require.NoError(t, err) - // Mock the payment to return nil failure reason. - payment.On("TerminalInfo").Return(nil, nil) + var ( + pubKey1 *btcec.PublicKey + pubKey2 *btcec.PublicKey + ) + node1Bytes := priv1.PubKey().SerializeCompressed() + node2Bytes := connectNode + if bytes.Compare(node1Bytes[:], node2Bytes[:]) == -1 { + pubKey1 = priv1.PubKey() + pubKey2 = connectNodeKey + } else { + pubKey1 = connectNodeKey + pubKey2 = priv1.PubKey() + } - // Expect a failed send to route. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt) - require.Equal(t, tempErr, err) - require.Equal(t, testAttempt, attempt) + _, _, chanID, err = createChannelEdge( + pubKey1.SerializeCompressed(), pubKey2.SerializeCompressed(), + 10000, 510) + require.NoError(t, err, "unable to create channel edge") - // Assert the above methods are called as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) -} + edge = &models.ChannelEdgeInfo{ + ChannelID: chanID.ToUint64(), + AuthProof: nil, + } + copy(edge.NodeKey1Bytes[:], node1Bytes) + edge.NodeKey2Bytes = node2Bytes + copy(edge.BitcoinKey1Bytes[:], node1Bytes) + edge.BitcoinKey2Bytes = node2Bytes -// TestSendToRouteSkipTempErrPermanentFailure validates a permanent failure -// will fail the payment. -func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { - var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) - ) + require.NoError(t, ctx.graph.AddChannelEdge(edge)) - testAttempt := makeFailedAttempt(t, int(payAmt)) - node, err := createTestNode() - require.NoError(t, err) + edgePolicy = &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + ToNode: edge.NodeKey2Bytes, + } + edgePolicy.ChannelFlags = 0 - // Create a simple 1-hop route. 
- hops := []*route.Hop{ - { - ChannelID: 1, - PubKeyBytes: node.PubKeyBytes, - AmtToForward: payAmt, - OutgoingTimeLock: 120, - MPP: record.NewMPP(payAmt, [32]byte{}), - }, + require.NoError(t, ctx.graph.UpdateEdgePolicy(edgePolicy)) + + edgePolicy = &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + ToNode: edge.NodeKey1Bytes, } - rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops) - require.NoError(t, err) + edgePolicy.ChannelFlags = 1 - // Create mockers. - controlTower := &mockControlTower{} - payer := &mockPaymentAttemptDispatcher{} - missionControl := &mockMissionControl{} + require.NoError(t, ctx.graph.UpdateEdgePolicy(edgePolicy)) - // Create the router. - router := &ChannelRouter{cfg: &Config{ - Control: controlTower, - Payer: payer, - MissionControl: missionControl, - Clock: clock.NewTestClock(time.Unix(1, 0)), - NextPaymentID: func() (uint64, error) { - return 0, nil - }, - }} + // We should now be able to find a route to node 2. + paymentAmt := lnwire.NewMSatFromSatoshis(100) + targetNode := priv2.PubKey() + var targetPubKeyBytes route.Vertex + copy(targetPubKeyBytes[:], targetNode.SerializeCompressed()) - // Create the error to be returned. - permErr := htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, 1, + req, err := NewRouteRequest( + ctx.router.cfg.SelfNode, &targetPubKeyBytes, + paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, ) + require.NoError(t, err, "invalid route request") + _, _, err = ctx.router.FindRoute(req) + require.NoError(t, err, "unable to find any routes") - // Register mockers with the expected method calls. - controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) - controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) + // Now check that we can update the node info for the partial node + // without messing up the channel graph. + n1 := &channeldb.LightningNode{ + HaveNodeAnnouncement: true, + LastUpdate: time.Unix(123, 0), + Addresses: testAddrs, + Color: color.RGBA{1, 2, 3, 0}, + Alias: "node11", + AuthSigBytes: testSig.Serialize(), + Features: testFeatures, + } + copy(n1.PubKeyBytes[:], priv1.PubKey().SerializeCompressed()) - controlTower.On("FailAttempt", - payHash, mock.Anything, mock.Anything, - ).Return(testAttempt, nil) + require.NoError(t, ctx.graph.AddLightningNode(n1)) - // Expect the payment to be failed. - controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) + n2 := &channeldb.LightningNode{ + HaveNodeAnnouncement: true, + LastUpdate: time.Unix(123, 0), + Addresses: testAddrs, + Color: color.RGBA{1, 2, 3, 0}, + Alias: "node22", + AuthSigBytes: testSig.Serialize(), + Features: testFeatures, + } + copy(n2.PubKeyBytes[:], priv2.PubKey().SerializeCompressed()) - // Mock an error to be returned from sending the htlc. - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(permErr) + require.NoError(t, ctx.graph.AddLightningNode(n2)) - failureReason := channeldb.FailureReasonPaymentDetails - missionControl.On("ReportPaymentFail", - mock.Anything, rt, mock.Anything, mock.Anything, - ).Return(&failureReason, nil) + // Should still be able to find the route, and the info should be + // updated. 
+ req, err = NewRouteRequest( + ctx.router.cfg.SelfNode, &targetPubKeyBytes, + paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, + ) + require.NoError(t, err, "invalid route request") - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + _, _, err = ctx.router.FindRoute(req) + require.NoError(t, err, "unable to find any routes") - // Mock the payment to return a failure reason. - payment.On("TerminalInfo").Return(nil, &failureReason) + copy1, err := ctx.graph.FetchLightningNode(pub1) + require.NoError(t, err, "unable to fetch node") - // Expect a failed send to route. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt) - require.Equal(t, permErr, err) - require.Equal(t, testAttempt, attempt) + require.Equal(t, n1.Alias, copy1.Alias) - // Assert the above methods are called as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) + copy2, err := ctx.graph.FetchLightningNode(pub2) + require.NoError(t, err, "unable to fetch node") + + require.Equal(t, n2.Alias, copy2.Alias) } -// TestSendToRouteTempFailure validates a temporary failure will cause the -// payment to be failed. -func TestSendToRouteTempFailure(t *testing.T) { - var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) - ) +func createDummyLightningPayment(t *testing.T, + target route.Vertex, amt lnwire.MilliSatoshi) *LightningPayment { - testAttempt := makeFailedAttempt(t, int(payAmt)) - node, err := createTestNode() - require.NoError(t, err) + var preImage lntypes.Preimage + _, err := rand.Read(preImage[:]) + require.NoError(t, err, "unable to generate preimage") - // Create a simple 1-hop route. - hops := []*route.Hop{ - { - ChannelID: 1, - PubKeyBytes: node.PubKeyBytes, - AmtToForward: payAmt, - OutgoingTimeLock: 120, - MPP: record.NewMPP(payAmt, [32]byte{}), - }, + payHash := preImage.Hash() + + return &LightningPayment{ + Target: target, + Amount: amt, + FeeLimit: noFeeLimit, + paymentHash: &payHash, } - rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops) - require.NoError(t, err) +} - // Create mockers. - controlTower := &mockControlTower{} - payer := &mockPaymentAttemptDispatcher{} - missionControl := &mockMissionControl{} +type mockGraphBuilder struct { + rejectUpdate bool + updateEdge func(update *models.ChannelEdgePolicy) error +} - // Create the router. - router := &ChannelRouter{cfg: &Config{ - Control: controlTower, - Payer: payer, - MissionControl: missionControl, - Clock: clock.NewTestClock(time.Unix(1, 0)), - NextPaymentID: func() (uint64, error) { - return 0, nil +func newMockGraphBuilder(graph graph.DB) *mockGraphBuilder { + return &mockGraphBuilder{ + updateEdge: func(update *models.ChannelEdgePolicy) error { + return graph.UpdateEdgePolicy(update) }, - }} + } +} - // Create the error to be returned. - tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, 1, - ) +func (m *mockGraphBuilder) setNextReject(reject bool) { + m.rejectUpdate = reject +} - // Register mockers with the expected method calls. 
- controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) - controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) - controlTower.On("FailAttempt", - payHash, mock.Anything, mock.Anything, - ).Return(testAttempt, nil) +func (m *mockGraphBuilder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate) bool { + if m.rejectUpdate { + return false + } + + err := m.updateEdge(&models.ChannelEdgePolicy{ + SigBytes: msg.Signature.ToSignatureBytes(), + ChannelID: msg.ShortChannelID.ToUint64(), + LastUpdate: time.Unix(int64(msg.Timestamp), 0), + MessageFlags: msg.MessageFlags, + ChannelFlags: msg.ChannelFlags, + TimeLockDelta: msg.TimeLockDelta, + MinHTLC: msg.HtlcMinimumMsat, + MaxHTLC: msg.HtlcMaximumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), + ExtraOpaqueData: msg.ExtraOpaqueData, + }) - // Expect the payment to be failed. - controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) + return err == nil +} - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(tempErr) +type mockChain struct { + lnwallet.BlockChainIO - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + blocks map[chainhash.Hash]*wire.MsgBlock + blockIndex map[uint32]chainhash.Hash + blockHeightIndex map[chainhash.Hash]uint32 - // Mock the payment to return nil failure reason. - payment.On("TerminalInfo").Return(nil, nil) + utxos map[wire.OutPoint]wire.TxOut - // Return a nil reason to mock a temporary failure. - missionControl.On("ReportPaymentFail", - mock.Anything, rt, mock.Anything, mock.Anything, - ).Return(nil, nil) + bestHeight int32 - // Expect a failed send to route. - attempt, err := router.SendToRoute(payHash, rt) - require.Equal(t, tempErr, err) - require.Equal(t, testAttempt, attempt) + sync.RWMutex +} - // Assert the above methods are called as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) +func newMockChain(currentHeight uint32) *mockChain { + return &mockChain{ + bestHeight: int32(currentHeight), + blocks: make(map[chainhash.Hash]*wire.MsgBlock), + utxos: make(map[wire.OutPoint]wire.TxOut), + blockIndex: make(map[uint32]chainhash.Hash), + blockHeightIndex: make(map[chainhash.Hash]uint32), + } } -// TestNewRouteRequest tests creation of route requests for blinded and -// unblinded routes. 
-func TestNewRouteRequest(t *testing.T) { - t.Parallel() +func (m *mockChain) GetBestBlock() (*chainhash.Hash, int32, error) { + m.RLock() + defer m.RUnlock() - //nolint:lll - source, err := route.NewVertexFromStr("0367cec75158a4129177bfb8b269cb586efe93d751b43800d456485e81c2620ca6") - require.NoError(t, err) - sourcePubkey, err := btcec.ParsePubKey(source[:]) - require.NoError(t, err) + blockHash := m.blockIndex[uint32(m.bestHeight)] - //nolint:lll - v1, err := route.NewVertexFromStr("026c43a8ac1cd8519985766e90748e1e06871dab0ff6b8af27e8c1a61640481318") - require.NoError(t, err) - pubkey1, err := btcec.ParsePubKey(v1[:]) - require.NoError(t, err) + return &blockHash, m.bestHeight, nil +} - //nolint:lll - v2, err := route.NewVertexFromStr("03c19f0027ffbb0ae0e14a4d958788793f9d74e107462473ec0c3891e4feb12e99") - require.NoError(t, err) - pubkey2, err := btcec.ParsePubKey(v2[:]) - require.NoError(t, err) +func (m *mockChain) setBestBlock(height int32) { + m.Lock() + defer m.Unlock() - var ( - unblindedCltv uint16 = 500 - blindedCltv uint16 = 1000 - ) + m.bestHeight = height +} - blindedSelfIntro := &BlindedPayment{ - CltvExpiryDelta: blindedCltv, - BlindedPath: &sphinx.BlindedPath{ - IntroductionPoint: sourcePubkey, - BlindedHops: []*sphinx.BlindedHopInfo{{}}, - }, - } +func (m *mockChain) addUtxo(op wire.OutPoint, out *wire.TxOut) { + m.Lock() + m.utxos[op] = *out + m.Unlock() +} - blindedOtherIntro := &BlindedPayment{ - CltvExpiryDelta: blindedCltv, - BlindedPath: &sphinx.BlindedPath{ - IntroductionPoint: pubkey1, - BlindedHops: []*sphinx.BlindedHopInfo{ - {}, - }, - }, - } +func (m *mockChain) delUtxo(op wire.OutPoint) { + m.Lock() + delete(m.utxos, op) + m.Unlock() +} - blindedMultiHop := &BlindedPayment{ - CltvExpiryDelta: blindedCltv, - BlindedPath: &sphinx.BlindedPath{ - IntroductionPoint: pubkey1, - BlindedHops: []*sphinx.BlindedHopInfo{ - {}, - { - BlindedNodePub: pubkey2, - }, - }, - }, - } +func (m *mockChain) addBlock(block *wire.MsgBlock, height uint32, nonce uint32) { + m.Lock() + block.Header.Nonce = nonce + hash := block.Header.BlockHash() + m.blocks[hash] = block + m.blockIndex[height] = hash + m.blockHeightIndex[hash] = height + m.Unlock() +} - testCases := []struct { - name string - target *route.Vertex - routeHints RouteHints - blindedPayment *BlindedPayment - finalExpiry uint16 +func createChannelEdge(bitcoinKey1, bitcoinKey2 []byte, + chanValue btcutil.Amount, fundingHeight uint32) (*wire.MsgTx, + *wire.OutPoint, *lnwire.ShortChannelID, error) { - expectedTarget route.Vertex - expectedCltv uint16 - err error - }{ - { - name: "blinded and target", - target: &v1, - blindedPayment: blindedOtherIntro, - err: ErrTargetAndBlinded, - }, - { - // For single-hop blinded we have a final cltv. - name: "blinded intro node only", - blindedPayment: blindedOtherIntro, - expectedTarget: v1, - expectedCltv: blindedCltv, - err: nil, - }, - { - // For multi-hop blinded, we have no final cltv. 
- name: "blinded multi-hop", - blindedPayment: blindedMultiHop, - expectedTarget: v2, - expectedCltv: 0, - err: nil, - }, - { - name: "unblinded", - target: &v2, - finalExpiry: unblindedCltv, - expectedTarget: v2, - expectedCltv: unblindedCltv, - err: nil, - }, - { - name: "source node intro", - blindedPayment: blindedSelfIntro, - err: ErrSelfIntro, - }, - { - name: "hints and blinded", - blindedPayment: blindedMultiHop, - routeHints: make( - map[route.Vertex][]AdditionalEdge, - ), - err: ErrHintsAndBlinded, - }, - { - name: "expiry and blinded", - blindedPayment: blindedMultiHop, - finalExpiry: unblindedCltv, - err: ErrExpiryAndBlinded, - }, - { - name: "invalid blinded payment", - blindedPayment: &BlindedPayment{}, - err: ErrNoBlindedPath, - }, + fundingTx := wire.NewMsgTx(2) + _, tx, err := input.GenFundingPkScript( + bitcoinKey1, + bitcoinKey2, + int64(chanValue), + ) + if err != nil { + return nil, nil, nil, err } - for _, testCase := range testCases { - testCase := testCase - - t.Run(testCase.name, func(t *testing.T) { - t.Parallel() - - req, err := NewRouteRequest( - source, testCase.target, 1000, 0, nil, nil, - testCase.routeHints, testCase.blindedPayment, - testCase.finalExpiry, - ) - require.ErrorIs(t, err, testCase.err) - - // Skip request validation if we got a non-nil error. - if err != nil { - return - } + fundingTx.TxOut = append(fundingTx.TxOut, tx) + chanUtxo := wire.OutPoint{ + Hash: fundingTx.TxHash(), + Index: 0, + } - require.Equal(t, req.Target, testCase.expectedTarget) - require.Equal( - t, req.FinalExpiry, testCase.expectedCltv, - ) - }) + // Our fake channel will be "confirmed" at height 101. + chanID := &lnwire.ShortChannelID{ + BlockHeight: fundingHeight, + TxIndex: 0, + TxPosition: 0, } + + return fundingTx, &chanUtxo, chanID, nil } diff --git a/rpcserver.go b/rpcserver.go index 011674e6c4..1ba82015e7 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -49,6 +49,7 @@ import ( "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -3075,7 +3076,7 @@ func (r *rpcServer) GetInfo(_ context.Context, // date, we add the router's state to it. So the flag will only toggle // to true once the router was also able to catch up. if !r.cfg.Routing.AssumeChannelValid { - routerHeight := r.server.chanRouter.SyncedHeight() + routerHeight := r.server.graphBuilder.SyncedHeight() isSynced = isSynced && uint32(bestHeight) == routerHeight } @@ -3118,7 +3119,7 @@ func (r *rpcServer) GetInfo(_ context.Context, // TODO(roasbeef): add synced height n stuff isTestNet := chainreg.IsTestnet(&r.cfg.ActiveNetParams) - nodeColor := routing.EncodeHexColor(nodeAnn.RGBColor) + nodeColor := graph.EncodeHexColor(nodeAnn.RGBColor) version := build.Version() + " commit=" + build.Commit return &lnrpc.GetInfoResponse{ @@ -6418,7 +6419,7 @@ func marshalNode(node *channeldb.LightningNode) *lnrpc.LightningNode { PubKey: hex.EncodeToString(node.PubKeyBytes[:]), Addresses: nodeAddrs, Alias: node.Alias, - Color: routing.EncodeHexColor(node.Color), + Color: graph.EncodeHexColor(node.Color), Features: features, CustomRecords: customRecords, } @@ -6613,7 +6614,7 @@ func (r *rpcServer) SubscribeChannelGraph(req *lnrpc.GraphTopologySubscription, // First, we start by subscribing to a new intent to receive // notifications from the channel router. 
- client, err := r.server.chanRouter.SubscribeTopology() + client, err := r.server.graphBuilder.SubscribeTopology() if err != nil { return err } @@ -6665,7 +6666,7 @@ func (r *rpcServer) SubscribeChannelGraph(req *lnrpc.GraphTopologySubscription, // marshallTopologyChange performs a mapping from the topology change struct // returned by the router to the form of notifications expected by the current // gRPC service. -func marshallTopologyChange(topChange *routing.TopologyChange) *lnrpc.GraphTopologyUpdate { +func marshallTopologyChange(topChange *graph.TopologyChange) *lnrpc.GraphTopologyUpdate { // encodeKey is a simple helper function that converts a live public // key into a hex-encoded version of the compressed serialization for diff --git a/server.go b/server.go index 17108ee469..bf57dd418a 100644 --- a/server.go +++ b/server.go @@ -342,7 +342,7 @@ type server struct { // updatePersistentPeerAddrs subscribes to topology changes and stores // advertised addresses for any NodeAnnouncements from our persisted peers. func (s *server) updatePersistentPeerAddrs() error { - graphSub, err := s.chanRouter.SubscribeTopology() + graphSub, err := s.graphBuilder.SubscribeTopology() if err != nil { return err } @@ -976,33 +976,37 @@ func newServer(cfg *Config, listenAddrs []net.Addr, strictPruning := cfg.Bitcoin.Node == "neutrino" || cfg.Routing.StrictZombiePruning - s.graphBuilder, err = graph.NewBuilder(&graph.Config{}) - if err != nil { - return nil, fmt.Errorf("can't create graph builder: %w", err) - } - - s.chanRouter, err = routing.New(routing.Config{ + s.graphBuilder, err = graph.NewBuilder(&graph.Config{ SelfNode: selfNode.PubKeyBytes, - RoutingGraph: graphsession.NewRoutingGraph(chanGraph), Graph: chanGraph, Chain: cc.ChainIO, ChainView: cc.ChainView, Notifier: cc.ChainNotifier, - Payer: s.htlcSwitch, - Control: s.controlTower, - MissionControl: s.missionControl, - SessionSource: paymentSessionSource, - ChannelPruneExpiry: routing.DefaultChannelPruneExpiry, + ChannelPruneExpiry: graph.DefaultChannelPruneExpiry, GraphPruneInterval: time.Hour, - FirstTimePruneDelay: routing.DefaultFirstTimePruneDelay, - GetLink: s.htlcSwitch.GetLinkByShortID, + FirstTimePruneDelay: graph.DefaultFirstTimePruneDelay, AssumeChannelValid: cfg.Routing.AssumeChannelValid, - NextPaymentID: sequencer.NextID, - PathFindingConfig: pathFindingConfig, - Clock: clock.NewDefaultClock(), StrictZombiePruning: strictPruning, IsAlias: aliasmgr.IsAlias, }) + if err != nil { + return nil, fmt.Errorf("can't create graph builder: %w", err) + } + + s.chanRouter, err = routing.New(routing.Config{ + SelfNode: selfNode.PubKeyBytes, + RoutingGraph: graphsession.NewRoutingGraph(chanGraph), + Chain: cc.ChainIO, + Payer: s.htlcSwitch, + Control: s.controlTower, + MissionControl: s.missionControl, + SessionSource: paymentSessionSource, + GetLink: s.htlcSwitch.GetLinkByShortID, + NextPaymentID: sequencer.NextID, + PathFindingConfig: pathFindingConfig, + Clock: clock.NewDefaultClock(), + ApplyChannelUpdate: s.graphBuilder.ApplyChannelUpdate, + }) if err != nil { return nil, fmt.Errorf("can't create router: %w", err) } @@ -1018,7 +1022,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.authGossiper = discovery.New(discovery.Config{ - Router: s.chanRouter, + Router: s.graphBuilder, Notifier: s.cc.ChainNotifier, ChainHash: *s.cfg.ActiveNetParams.GenesisHash, Broadcast: s.BroadcastMessage, @@ -1053,11 +1057,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, FindBaseByAlias: s.aliasMgr.FindBaseSCID, GetAlias: 
s.aliasMgr.GetPeerAlias, FindChannel: s.findChannel, - IsStillZombieChannel: s.chanRouter.IsZombieChannel, + IsStillZombieChannel: s.graphBuilder.IsZombieChannel, }, nodeKeyDesc) s.localChanMgr = &localchans.Manager{ - ForAllOutgoingChannels: s.chanRouter.ForAllOutgoingChannels, + ForAllOutgoingChannels: s.graphBuilder.ForAllOutgoingChannels, PropagateChanPolicyUpdate: s.authGossiper.PropagateChanPolicyUpdate, UpdateForwardingPolicies: s.htlcSwitch.UpdateForwardingPolicies, FetchChannel: s.chanStateDB.FetchChannel, @@ -4667,7 +4671,7 @@ func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( ourPubKey := s.identityECDH.PubKey().SerializeCompressed() return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { - info, edge1, edge2, err := s.chanRouter.GetChannelByID(cid) + info, edge1, edge2, err := s.graphBuilder.GetChannelByID(cid) if err != nil { return nil, err } From 9327a83cd2e326edf92afa29553939b8ce6b1575 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 18 Jun 2024 12:02:10 -0700 Subject: [PATCH 15/20] discovery: rename Gossiper graph dep --- discovery/gossiper.go | 44 +++++++++++++++++++------------------- discovery/gossiper_test.go | 4 ++-- server.go | 2 +- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 752ee7446c..c3031011dc 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -165,11 +165,11 @@ type Config struct { // * also need to do same for Notifier ChainHash chainhash.Hash - // Router is the subsystem which is responsible for managing the + // Graph is the subsystem which is responsible for managing the // topology of lightning network. After incoming channel, node, channel // updates announcements are validated they are sent to the router in // order to be included in the LN graph. - Router graph.ChannelGraphSource + Graph graph.ChannelGraphSource // ChanSeries is an interfaces that provides access to a time series // view of the current known channel graph. Each GossipSyncer enabled @@ -591,7 +591,7 @@ func (d *AuthenticatedGossiper) start() error { } d.blockEpochs = blockEpochs - height, err := d.cfg.Router.CurrentBlockHeight() + height, err := d.cfg.Graph.CurrentBlockHeight() if err != nil { return err } @@ -1595,7 +1595,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { havePublicChannels bool edgesToUpdate []updateTuple ) - err := d.cfg.Router.ForAllOutgoingChannels(func( + err := d.cfg.Graph.ForAllOutgoingChannels(func( _ kvdb.RTx, info *models.ChannelEdgeInfo, edge *models.ChannelEdgePolicy) error { @@ -1831,7 +1831,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( // First, we'll fetch the state of the channel as we know if from the // database. - chanInfo, e1, e2, err := d.cfg.Router.GetChannelByID( + chanInfo, e1, e2, err := d.cfg.Graph.GetChannelByID( chanAnnMsg.ShortChannelID, ) if err != nil { @@ -1871,7 +1871,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( // If everything checks out, then we'll add the fully assembled proof // to the database. - err = d.cfg.Router.AddProof(chanAnnMsg.ShortChannelID, proof) + err = d.cfg.Graph.AddProof(chanAnnMsg.ShortChannelID, proof) if err != nil { err := fmt.Errorf("unable add proof to shortChanID=%v: %w", chanAnnMsg.ShortChannelID, err) @@ -1928,7 +1928,7 @@ func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement, ExtraOpaqueData: msg.ExtraOpaqueData, } - return d.cfg.Router.AddNode(node, op...) + return d.cfg.Graph.AddNode(node, op...) 
} // isPremature decides whether a given network message has a block height+delta @@ -2072,7 +2072,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( // With the signature valid, we'll proceed to mark the // edge as live and wait for the channel announcement to // come through again. - err = d.cfg.Router.MarkEdgeLive(scid) + err = d.cfg.Graph.MarkEdgeLive(scid) switch { case errors.Is(err, channeldb.ErrZombieEdgeNotFound): log.Errorf("edge with chan_id=%v was not found in the "+ @@ -2099,7 +2099,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( func (d *AuthenticatedGossiper) fetchNodeAnn( pubKey [33]byte) (*lnwire.NodeAnnouncement, error) { - node, err := d.cfg.Router.FetchLightningNode(pubKey) + node, err := d.cfg.Graph.FetchLightningNode(pubKey) if err != nil { return nil, err } @@ -2112,7 +2112,7 @@ func (d *AuthenticatedGossiper) fetchNodeAnn( func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { switch msg := msg.(type) { case *lnwire.AnnounceSignatures: - chanInfo, _, _, err := d.cfg.Router.GetChannelByID( + chanInfo, _, _, err := d.cfg.Graph.GetChannelByID( msg.ShortChannelID, ) @@ -2134,7 +2134,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { return chanInfo.AuthProof != nil case *lnwire.ChannelUpdate: - _, p1, p2, err := d.cfg.Router.GetChannelByID(msg.ShortChannelID) + _, p1, p2, err := d.cfg.Graph.GetChannelByID(msg.ShortChannelID) // If the channel cannot be found, it is most likely a leftover // message for a channel that was closed, so we can consider it @@ -2207,7 +2207,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, } // Finally, we'll write the new edge policy to disk. - if err := d.cfg.Router.UpdateEdge(edge); err != nil { + if err := d.cfg.Graph.UpdateEdge(edge); err != nil { return nil, nil, err } @@ -2327,7 +2327,7 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, // We'll quickly ask the router if it already has a newer update for // this node so we can skip validating signatures if not required. - if d.cfg.Router.IsStaleNode(nodeAnn.NodeID, timestamp) { + if d.cfg.Graph.IsStaleNode(nodeAnn.NodeID, timestamp) { log.Debugf("Skipped processing stale node: %x", nodeAnn.NodeID) nMsg.err <- nil return nil, true @@ -2354,7 +2354,7 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, // In order to ensure we don't leak unadvertised nodes, we'll make a // quick check to ensure this node intends to publicly advertise itself // to the network. - isPublic, err := d.cfg.Router.IsPublicNode(nodeAnn.NodeID) + isPublic, err := d.cfg.Graph.IsPublicNode(nodeAnn.NodeID) if err != nil { log.Errorf("Unable to determine if node %x is advertised: %v", nodeAnn.NodeID, err) @@ -2447,7 +2447,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // At this point, we'll now ask the router if this is a zombie/known // edge. If so we can skip all the processing below. - if d.cfg.Router.IsKnownEdge(ann.ShortChannelID) { + if d.cfg.Graph.IsKnownEdge(ann.ShortChannelID) { nMsg.err <- nil return nil, true } @@ -2527,9 +2527,9 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // database and is now making decisions based on this DB state, before // it writes to the DB. d.channelMtx.Lock(ann.ShortChannelID.ToUint64()) - err := d.cfg.Router.AddEdge(edge, ops...) + err := d.cfg.Graph.AddEdge(edge, ops...) 
if err != nil { - log.Debugf("Router rejected edge for short_chan_id(%v): %v", + log.Debugf("Graph rejected edge for short_chan_id(%v): %v", ann.ShortChannelID.ToUint64(), err) defer d.channelMtx.Unlock(ann.ShortChannelID.ToUint64()) @@ -2725,7 +2725,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, graphScid = upd.ShortChannelID } - if d.cfg.Router.IsStaleEdgePolicy( + if d.cfg.Graph.IsStaleEdgePolicy( graphScid, timestamp, upd.ChannelFlags, ) { @@ -2749,7 +2749,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, d.channelMtx.Lock(graphScid.ToUint64()) defer d.channelMtx.Unlock(graphScid.ToUint64()) - chanInfo, e1, e2, err := d.cfg.Router.GetChannelByID(graphScid) + chanInfo, e1, e2, err := d.cfg.Graph.GetChannelByID(graphScid) switch { // No error, break. case err == nil: @@ -2945,7 +2945,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, ExtraOpaqueData: upd.ExtraOpaqueData, } - if err := d.cfg.Router.UpdateEdge(update, ops...); err != nil { + if err := d.cfg.Graph.UpdateEdge(update, ops...); err != nil { if graph.IsError( err, graph.ErrOutdated, graph.ErrIgnored, @@ -3092,7 +3092,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, d.channelMtx.Lock(ann.ShortChannelID.ToUint64()) defer d.channelMtx.Unlock(ann.ShortChannelID.ToUint64()) - chanInfo, e1, e2, err := d.cfg.Router.GetChannelByID( + chanInfo, e1, e2, err := d.cfg.Graph.GetChannelByID( ann.ShortChannelID, ) if err != nil { @@ -3282,7 +3282,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // attest to the bitcoin keys by validating the signatures of // announcement. If proof is valid then we'll populate the channel edge // with it, so we can announce it on peer connect. - err = d.cfg.Router.AddProof(ann.ShortChannelID, &dbProof) + err = d.cfg.Graph.AddProof(ann.ShortChannelID, &dbProof) if err != nil { err := fmt.Errorf("unable add proof to the channel chanID=%v:"+ " %v", ann.ChannelID, err) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 33d87416ac..7cfc7bce8f 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -783,7 +783,7 @@ func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { Timestamp: testTimestamp, }, nil }, - Router: router, + Graph: router, TrickleDelay: trickleDelay, RetransmitTicker: ticker.NewForce(retransmitDelay), RebroadcastInterval: rebroadcastInterval, @@ -1457,7 +1457,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { NotifyWhenOffline: ctx.gossiper.reliableSender.cfg.NotifyWhenOffline, FetchSelfAnnouncement: ctx.gossiper.cfg.FetchSelfAnnouncement, UpdateSelfAnnouncement: ctx.gossiper.cfg.UpdateSelfAnnouncement, - Router: ctx.gossiper.cfg.Router, + Graph: ctx.gossiper.cfg.Graph, TrickleDelay: trickleDelay, RetransmitTicker: ticker.NewForce(retransmitDelay), RebroadcastInterval: rebroadcastInterval, diff --git a/server.go b/server.go index bf57dd418a..3317fb2f28 100644 --- a/server.go +++ b/server.go @@ -1022,7 +1022,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.authGossiper = discovery.New(discovery.Config{ - Router: s.graphBuilder, + Graph: s.graphBuilder, Notifier: s.cc.ChainNotifier, ChainHash: *s.cfg.ActiveNetParams.GenesisHash, Broadcast: s.BroadcastMessage, From 743502f99d196cb51518370de4da5d0da20121e7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 16 Jun 2024 21:09:10 -0400 Subject: [PATCH 16/20] funding: rename from router graph to graph --- funding/manager.go | 64 
++++++++++++++++++++--------------------- funding/manager_test.go | 26 ++++++++--------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/funding/manager.go b/funding/manager.go index 70d5cd9c43..7fd0e9b111 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -629,11 +629,11 @@ const ( // but we still haven't announced the channel to the network. channelReadySent - // addedToRouterGraph is the opening state of a channel if the - // channel has been successfully added to the router graph - // immediately after the channelReady message has been sent, but - // we still haven't announced the channel to the network. - addedToRouterGraph + // addedToGraph is the opening state of a channel if the channel has + // been successfully added to the graph immediately after the + // channelReady message has been sent, but we still haven't announced + // the channel to the network. + addedToGraph ) func (c channelOpeningState) String() string { @@ -642,8 +642,8 @@ func (c channelOpeningState) String() string { return "markedOpen" case channelReadySent: return "channelReadySent" - case addedToRouterGraph: - return "addedToRouterGraph" + case addedToGraph: + return "addedToGraph" default: return "unknown" } @@ -1039,9 +1039,9 @@ func (f *Manager) reservationCoordinator() { // advanceFundingState will advance the channel through the steps after the // funding transaction is broadcasted, up until the point where the channel is // ready for operation. This includes waiting for the funding transaction to -// confirm, sending channel_ready to the peer, adding the channel to the -// router graph, and announcing the channel. The updateChan can be set non-nil -// to get OpenStatusUpdates. +// confirm, sending channel_ready to the peer, adding the channel to the graph, +// and announcing the channel. The updateChan can be set non-nil to get +// OpenStatusUpdates. // // NOTE: This MUST be run as a goroutine. func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, @@ -1152,7 +1152,7 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel, return nil // channelReady was sent to peer, but the channel was not added to the - // router graph and the channel announcement was not sent. + // graph and the channel announcement was not sent. case channelReadySent: // We must wait until we've received the peer's channel_ready // before sending a channel_update according to BOLT#07. @@ -1183,7 +1183,7 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel, // The channel was added to the Router's topology, but the channel // announcement was not sent. - case addedToRouterGraph: + case addedToGraph: if channel.IsZeroConf() { // If this is a zero-conf channel, then we will wait // for it to be confirmed before announcing it to the @@ -3377,15 +3377,15 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( return fwdMinHTLC, fwdMaxHTLC } -// addToRouterGraph sends a ChannelAnnouncement and a ChannelUpdate to the -// gossiper so that the channel is added to the Router's internal graph. +// addToGraph sends a ChannelAnnouncement and a ChannelUpdate to the +// gossiper so that the channel is added to the graph builder's internal graph. // These announcement messages are NOT broadcasted to the greater network, // only to the channel counter party. The proofs required to announce the // channel to the greater network will be created and sent in annAfterSixConfs. // The peerAlias is used for zero-conf channels to give the counter-party a // ChannelUpdate they understand. 
ourPolicy may be set for various // option-scid-alias channels to re-use the same policy. -func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, +func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, ourPolicy *models.ChannelEdgePolicy) error { @@ -3454,8 +3454,8 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, // annAfterSixConfs broadcasts the necessary channel announcement messages to // the network after 6 confs. Should be called after the channelReady message -// is sent and the channel is added to the router graph (channelState is -// 'addedToRouterGraph') and the channel is ready to be used. This is the last +// is sent and the channel is added to the graph (channelState is +// 'addedToGraph') and the channel is ready to be used. This is the last // step in the channel opening process, and the opening state will be deleted // from the database if successful. func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, @@ -3566,7 +3566,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, } // We'll delete the edge and add it again via - // addToRouterGraph. This is because the peer may have + // addToGraph. This is because the peer may have // sent us a ChannelUpdate with an alias and we don't // want to relay this. ourPolicy, err := f.cfg.DeleteAliasEdge(baseScid) @@ -3576,12 +3576,12 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, err) } - err = f.addToRouterGraph( + err = f.addToGraph( completeChan, &baseScid, nil, ourPolicy, ) if err != nil { return fmt.Errorf("failed to re-add to "+ - "router graph: %v", err) + "graph: %v", err) } } @@ -3605,9 +3605,9 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, return nil } -// waitForZeroConfChannel is called when the state is addedToRouterGraph with +// waitForZeroConfChannel is called when the state is addedToGraph with // a zero-conf channel. This will wait for the real confirmation, add the -// confirmed SCID to the router graph, and then announce after six confs. +// confirmed SCID to the graph, and then announce after six confs. func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { // First we'll check whether the channel is confirmed on-chain. If it // is already confirmed, the chainntnfs subsystem will return with the @@ -3662,15 +3662,15 @@ func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { } // We'll need to update the graph with the new ShortChannelID - // via an addToRouterGraph call. We don't pass in the peer's + // via an addToGraph call. We don't pass in the peer's // alias since we'll be using the confirmed SCID from now on // regardless if it's public or not. - err = f.addToRouterGraph( + err = f.addToGraph( c, &confChan.shortChanID, nil, ourPolicy, ) if err != nil { return fmt.Errorf("failed adding confirmed zero-conf "+ - "SCID to router graph: %v", err) + "SCID to graph: %v", err) } } @@ -3972,7 +3972,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen // handleChannelReadyReceived is called once the remote's channelReady message // is received and processed. 
At this stage, we must have sent out our // channelReady message, once the remote's channelReady is processed, the -// channel is now active, thus we change its state to `addedToRouterGraph` to +// channel is now active, thus we change its state to `addedToGraph` to // let the channel start handling routing. func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, scid *lnwire.ShortChannelID, pendingChanID [32]byte, @@ -4004,9 +4004,9 @@ func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, peerAlias = &foundAlias } - err := f.addToRouterGraph(channel, scid, peerAlias, nil) + err := f.addToGraph(channel, scid, peerAlias, nil) if err != nil { - return fmt.Errorf("failed adding to router graph: %w", err) + return fmt.Errorf("failed adding to graph: %w", err) } // As the channel is now added to the ChannelRouter's topology, the @@ -4014,15 +4014,15 @@ func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, // moved to the last state (actually deleted from the database) after // the channel is finally announced. err = f.saveChannelOpeningState( - &channel.FundingOutpoint, addedToRouterGraph, scid, + &channel.FundingOutpoint, addedToGraph, scid, ) if err != nil { return fmt.Errorf("error setting channel state to"+ - " addedToRouterGraph: %w", err) + " addedToGraph: %w", err) } log.Debugf("Channel(%v) with ShortChanID %v: successfully "+ - "added to router graph", chanID, scid) + "added to graph", chanID, scid) // Give the caller a final update notifying them that the channel is fundingPoint := channel.FundingOutpoint @@ -4347,7 +4347,7 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, } // We only send the channel proof announcement and the node announcement - // because addToRouterGraph previously sent the ChannelAnnouncement and + // because addToGraph previously sent the ChannelAnnouncement and // the ChannelUpdate announcement messages. The channel proof and node // announcements are broadcast to the greater network. 
errChan := f.cfg.SendAnnouncement(ann.chanProof) diff --git a/funding/manager_test.go b/funding/manager_test.go index 69e23eeb83..9db175ec39 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -1140,13 +1140,13 @@ func assertChannelReadySent(t *testing.T, alice, bob *testNode, assertDatabaseState(t, bob, fundingOutPoint, channelReadySent) } -func assertAddedToRouterGraph(t *testing.T, alice, bob *testNode, +func assertAddedToGraph(t *testing.T, alice, bob *testNode, fundingOutPoint *wire.OutPoint) { t.Helper() - assertDatabaseState(t, alice, fundingOutPoint, addedToRouterGraph) - assertDatabaseState(t, bob, fundingOutPoint, addedToRouterGraph) + assertDatabaseState(t, alice, fundingOutPoint, addedToGraph) + assertDatabaseState(t, bob, fundingOutPoint, addedToGraph) } // assertChannelAnnouncements checks that alice and bob both sends the expected @@ -1523,7 +1523,7 @@ func testNormalWorkflow(t *testing.T, chanType *lnwire.ChannelType) { assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) // Check that the state machine is updated accordingly - assertAddedToRouterGraph(t, alice, bob, fundingOutPoint) + assertAddedToGraph(t, alice, bob, fundingOutPoint) // The funding transaction is now confirmed, wait for the // OpenStatusUpdate_ChanOpen update @@ -1877,7 +1877,7 @@ func TestFundingManagerRestartBehavior(t *testing.T) { assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) // Check that the state machine is updated accordingly - assertAddedToRouterGraph(t, alice, bob, fundingOutPoint) + assertAddedToGraph(t, alice, bob, fundingOutPoint) // Next, we check that Alice sends the announcement signatures // on restart after six confirmations. Bob should as expected send @@ -2042,7 +2042,7 @@ func TestFundingManagerOfflinePeer(t *testing.T) { assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) // Check that the state machine is updated accordingly - assertAddedToRouterGraph(t, alice, bob, fundingOutPoint) + assertAddedToGraph(t, alice, bob, fundingOutPoint) // The funding transaction is now confirmed, wait for the // OpenStatusUpdate_ChanOpen update @@ -2501,7 +2501,7 @@ func TestFundingManagerReceiveChannelReadyTwice(t *testing.T) { assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) // Check that the state machine is updated accordingly - assertAddedToRouterGraph(t, alice, bob, fundingOutPoint) + assertAddedToGraph(t, alice, bob, fundingOutPoint) // The funding transaction is now confirmed, wait for the // OpenStatusUpdate_ChanOpen update @@ -2594,7 +2594,7 @@ func TestFundingManagerRestartAfterChanAnn(t *testing.T) { assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) // Check that the state machine is updated accordingly - assertAddedToRouterGraph(t, alice, bob, fundingOutPoint) + assertAddedToGraph(t, alice, bob, fundingOutPoint) // The funding transaction is now confirmed, wait for the // OpenStatusUpdate_ChanOpen update @@ -2698,7 +2698,7 @@ func TestFundingManagerRestartAfterReceivingChannelReady(t *testing.T) { assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) // Check that the state machine is updated accordingly - assertAddedToRouterGraph(t, alice, bob, fundingOutPoint) + assertAddedToGraph(t, alice, bob, fundingOutPoint) // Notify that six confirmations has been reached on funding // transaction. @@ -2912,9 +2912,9 @@ func TestFundingManagerPrivateRestart(t *testing.T) { // announcements. 
assertChannelAnnouncements(t, alice, bob, capacity, nil, nil, nil, nil) - // Note: We don't check for the addedToRouterGraph state because in + // Note: We don't check for the addedToGraph state because in // the private channel mode, the state is quickly changed from - // addedToRouterGraph to deleted from the database since the public + // addedToGraph to deleted from the database since the public // announcement phase is skipped. // The funding transaction is now confirmed, wait for the @@ -4563,8 +4563,8 @@ func testZeroConf(t *testing.T, chanType *lnwire.ChannelType) { // We'll now wait for the OpenStatusUpdate_ChanOpen update. waitForOpenUpdate(t, updateChan) - // Assert that both Alice & Bob are in the addedToRouterGraph state. - assertAddedToRouterGraph(t, alice, bob, fundingOp) + // Assert that both Alice & Bob are in the addedToGraph state. + assertAddedToGraph(t, alice, bob, fundingOp) // We'll now restart Alice's funding manager and assert that the tx // is rebroadcast. From fe34d62eb1776f8c2cb488753bb8eaece4b6d773 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 18 Jun 2024 12:34:25 -0700 Subject: [PATCH 17/20] graph+routing: address linter errors This is done in a separate commit so as to keep the original code-move commit mostly a pure code move. --- discovery/gossiper.go | 4 +- graph/builder.go | 113 +++++++++++++++++----------- graph/builder_test.go | 91 ++++++++++++++--------- graph/log.go | 2 +- graph/notifications.go | 1 - graph/notifications_test.go | 42 ++++------- routing/pathfind_test.go | 6 +- routing/router_test.go | 143 +++++++++++++++--------------------- rpcserver.go | 7 +- server.go | 1 + 10 files changed, 217 insertions(+), 193 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index c3031011dc..4805369ad5 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -2200,7 +2200,9 @@ func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. - err = graph.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate) + err = graph.ValidateChannelUpdateAnn( + d.selfKey, info.Capacity, chanUpdate, + ) if err != nil { return nil, nil, fmt.Errorf("generated invalid channel "+ "update sig: %v", err) diff --git a/graph/builder.go b/graph/builder.go index 4a3445cfc4..264a2aacd9 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -222,7 +222,7 @@ func (b *Builder) Start() error { // channels from the graph based on their spentness, but whether they // are considered zombies or not. We will start zombie pruning after a // small delay, to avoid slowing down startup of lnd. - if b.cfg.AssumeChannelValid { + if b.cfg.AssumeChannelValid { //nolint:nestif time.AfterFunc(b.cfg.FirstTimePruneDelay, func() { select { case <-b.quit: @@ -256,6 +256,7 @@ func (b *Builder) Start() error { if err != nil && !errors.Is( err, channeldb.ErrGraphNoEdgesFound, ) { + return err } @@ -290,7 +291,9 @@ func (b *Builder) Start() error { // from the graph in order to ensure we maintain a tight graph // of "useful" nodes. err = b.cfg.Graph.PruneGraphNodes() - if err != nil && err != channeldb.ErrGraphNodesNotFound { + if err != nil && + !errors.Is(err, channeldb.ErrGraphNodesNotFound) { + return err } } @@ -344,8 +347,8 @@ func (b *Builder) syncGraphWithChain() error { switch { // If the graph has never been pruned, or hasn't fully been // created yet, then we don't treat this as an explicit error. 
- case err == channeldb.ErrGraphNeverPruned: - case err == channeldb.ErrGraphNotFound: + case errors.Is(err, channeldb.ErrGraphNeverPruned): + case errors.Is(err, channeldb.ErrGraphNotFound): default: return err } @@ -355,7 +358,6 @@ func (b *Builder) syncGraphWithChain() error { pruneHeight, pruneHash) switch { - // If the graph has never been pruned, then we can exit early as this // entails it's being created for the first time and hasn't seen any // block or created channels. @@ -388,34 +390,40 @@ func (b *Builder) syncGraphWithChain() error { } pruneHash, pruneHeight, err = b.cfg.Graph.PruneTip() - if err != nil { - switch { - // If at this point the graph has never been pruned, we - // can exit as this entails we are back to the point - // where it hasn't seen any block or created channels, - // alas there's nothing left to prune. - case err == channeldb.ErrGraphNeverPruned: - return nil - case err == channeldb.ErrGraphNotFound: - return nil - default: - return err - } + switch { + // If at this point the graph has never been pruned, we can exit + // as this entails we are back to the point where it hasn't seen + // any block or created channels, alas there's nothing left to + // prune. + case errors.Is(err, channeldb.ErrGraphNeverPruned): + return nil + + case errors.Is(err, channeldb.ErrGraphNotFound): + return nil + + case err != nil: + return err + + default: } - mainBlockHash, err = b.cfg.Chain.GetBlockHash(int64(pruneHeight)) + + mainBlockHash, err = b.cfg.Chain.GetBlockHash( + int64(pruneHeight), + ) if err != nil { return err } } - log.Infof("Syncing channel graph from height=%v (hash=%v) to height=%v "+ - "(hash=%v)", pruneHeight, pruneHash, bestHeight, bestHash) + log.Infof("Syncing channel graph from height=%v (hash=%v) to "+ + "height=%v (hash=%v)", pruneHeight, pruneHash, bestHeight, + bestHash) // If we're not yet caught up, then we'll walk forward in the chain // pruning the channel graph with each new block that hasn't yet been // consumed by the channel graph. var spentOutputs []*wire.OutPoint - for nextHeight := pruneHeight + 1; nextHeight <= uint32(bestHeight); nextHeight++ { + for nextHeight := pruneHeight + 1; nextHeight <= uint32(bestHeight); nextHeight++ { //nolint:lll // Break out of the rescan early if a shutdown has been // requested, otherwise long rescans will block the daemon from // shutting down promptly. @@ -462,6 +470,7 @@ func (b *Builder) syncGraphWithChain() error { log.Infof("Graph pruning complete: %v channels were closed since "+ "height %v", len(closedChans), pruneHeight) + return nil } @@ -615,7 +624,11 @@ func (b *Builder) pruneZombieChans() error { } for _, u := range oldEdges { - filterPruneChans(u.Info, u.Policy1, u.Policy2) + err = filterPruneChans(u.Info, u.Policy1, u.Policy2) + if err != nil { + return fmt.Errorf("error filtering channels to "+ + "prune: %w", err) + } } log.Infof("Pruning %v zombie channels", len(chansToPrune)) @@ -640,7 +653,7 @@ func (b *Builder) pruneZombieChans() error { // With the channels pruned, we'll also attempt to prune any nodes that // were a part of them. err = b.cfg.Graph.PruneGraphNodes() - if err != nil && err != channeldb.ErrGraphNodesNotFound { + if err != nil && !errors.Is(err, channeldb.ErrGraphNodesNotFound) { return fmt.Errorf("unable to prune graph nodes: %w", err) } @@ -761,7 +774,6 @@ func (b *Builder) networkHandler() { } for { - // If there are stats, resume the statTicker. 
if !b.stats.Empty() { b.statTicker.Resume() @@ -793,12 +805,14 @@ func (b *Builder) networkHandler() { // Since this block is stale, we update our best height // to the previous block. - blockHeight := uint32(chainUpdate.Height) + blockHeight := chainUpdate.Height atomic.StoreUint32(&b.bestHeight, blockHeight-1) // Update the channel graph to reflect that this block // was disconnected. - _, err := b.cfg.Graph.DisconnectBlockAtHeight(blockHeight) + _, err := b.cfg.Graph.DisconnectBlockAtHeight( + blockHeight, + ) if err != nil { log.Errorf("unable to prune graph with stale "+ "block: %v", err) @@ -836,7 +850,9 @@ func (b *Builder) networkHandler() { "height=%v, got height=%v", currentHeight+1, chainUpdate.Height) - err := b.getMissingBlocks(currentHeight, chainUpdate) + err := b.getMissingBlocks( + currentHeight, chainUpdate, + ) if err != nil { log.Errorf("unable to retrieve missing"+ "blocks: %v", err) @@ -1136,6 +1152,8 @@ func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, // channel/edge update network update. If the update didn't affect the internal // state of the draft due to either being out of date, invalid, or redundant, // then error is returned. +// +//nolint:funlen func (b *Builder) processUpdate(msg interface{}, op ...batch.SchedulerOption) error { @@ -1166,7 +1184,9 @@ func (b *Builder) processUpdate(msg interface{}, _, _, exists, isZombie, err := b.cfg.Graph.HasChannelEdge( msg.ChannelID, ) - if err != nil && err != channeldb.ErrGraphNoEdgesFound { + if err != nil && + !errors.Is(err, channeldb.ErrGraphNoEdgesFound) { + return errors.Errorf("unable to check for edge "+ "existence: %v", err) } @@ -1188,7 +1208,8 @@ func (b *Builder) processUpdate(msg interface{}, // ChannelAnnouncement from the gossiper. scid := lnwire.NewShortChanIDFromInt(msg.ChannelID) if b.cfg.AssumeChannelValid || b.cfg.IsAlias(scid) { - if err := b.cfg.Graph.AddChannelEdge(msg, op...); err != nil { + err := b.cfg.Graph.AddChannelEdge(msg, op...) + if err != nil { return fmt.Errorf("unable to add edge: %w", err) } log.Tracef("New channel discovered! Link "+ @@ -1206,6 +1227,8 @@ func (b *Builder) processUpdate(msg interface{}, channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) fundingTx, err := b.fetchFundingTxWrapper(&channelID) if err != nil { + //nolint:lll + // // In order to ensure we don't erroneously mark a // channel as a zombie due to an RPC failure, we'll // attempt to string match for the relevant errors. @@ -1253,13 +1276,15 @@ func (b *Builder) processUpdate(msg interface{}, // formed. If this check fails, then this channel either // doesn't exist, or isn't the one that was meant to be created // according to the passed channel proofs. - fundingPoint, err := chanvalidate.Validate(&chanvalidate.Context{ - Locator: &chanvalidate.ShortChanIDChanLocator{ - ID: channelID, + fundingPoint, err := chanvalidate.Validate( + &chanvalidate.Context{ + Locator: &chanvalidate.ShortChanIDChanLocator{ + ID: channelID, + }, + MultiSigPkScript: fundingPkScript, + FundingTx: fundingTx, }, - MultiSigPkScript: fundingPkScript, - FundingTx: fundingTx, - }) + ) if err != nil { // Mark the edge as a zombie so we won't try to // re-validate it on start up. 
@@ -1336,16 +1361,20 @@ func (b *Builder) processUpdate(msg interface{}, edge1Timestamp, edge2Timestamp, exists, isZombie, err := b.cfg.Graph.HasChannelEdge(msg.ChannelID) - if err != nil && err != channeldb.ErrGraphNoEdgesFound { + if err != nil && !errors.Is( + err, channeldb.ErrGraphNoEdgesFound, + ) { + return errors.Errorf("unable to check for edge "+ "existence: %v", err) - } // If the channel is marked as a zombie in our database, and // we consider this a stale update, then we should not apply the // policy. - isStaleUpdate := time.Since(msg.LastUpdate) > b.cfg.ChannelPruneExpiry + isStaleUpdate := time.Since(msg.LastUpdate) > + b.cfg.ChannelPruneExpiry + if isZombie && isStaleUpdate { return newErrf(ErrIgnored, "ignoring stale update "+ "(flags=%v|%v) for zombie chan_id=%v", @@ -1368,7 +1397,6 @@ func (b *Builder) processUpdate(msg interface{}, // that edge. If this message has a timestamp not strictly // newer than what we already know of we can exit early. switch { - // A flag set of 0 indicates this is an announcement for the // "first" node in the channel. case msg.ChannelFlags&lnwire.ChanUpdateDirection == 0: @@ -1448,7 +1476,7 @@ func (b *Builder) fetchFundingTxWrapper(chanID *lnwire.ShortChannelID) ( // short channel ID. // // TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to -// later use getblocktxn) +// later use getblocktxn). func (b *Builder) fetchFundingTx( chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { @@ -1702,6 +1730,7 @@ func (b *Builder) AddProof(chanID lnwire.ShortChannelID, } info.AuthProof = proof + return b.cfg.Graph.UpdateChannelEdge(info) } @@ -1739,6 +1768,7 @@ func (b *Builder) IsKnownEdge(chanID lnwire.ShortChannelID) bool { _, _, exists, isZombie, _ := b.cfg.Graph.HasChannelEdge( chanID.ToUint64(), ) + return exists || isZombie } @@ -1754,7 +1784,6 @@ func (b *Builder) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, if err != nil { log.Debugf("Check stale edge policy got error: %v", err) return false - } // If we know of the edge as a zombie, then we'll make some additional diff --git a/graph/builder_test.go b/graph/builder_test.go index d3e25d2aab..e3a3bd0913 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -182,7 +182,7 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { } // TestWakeUpOnStaleBranch tests that upon startup of the ChannelRouter, if the -// the chain previously reflected in the channel graph is stale (overtaken by a +// chain previously reflected in the channel graph is stale (overtaken by a // longer chain), the channel router will prune the graph for any channels // confirmed on the stale chain, and resync to the main chain. func TestWakeUpOnStaleBranch(t *testing.T) { @@ -216,7 +216,6 @@ func TestWakeUpOnStaleBranch(t *testing.T) { block.Transactions = append(block.Transactions, fundingTx) chanID1 = chanID.ToUint64() - } ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) @@ -418,7 +417,6 @@ func TestDisconnectedBlocks(t *testing.T) { block.Transactions = append(block.Transactions, fundingTx) chanID1 = chanID.ToUint64() - } ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) @@ -633,7 +631,9 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { } // The router should now be aware of the channel we created above. 
- _, _, hasChan, isZombie, err := ctx.graph.HasChannelEdge(chanID1.ToUint64()) + _, _, hasChan, isZombie, err := ctx.graph.HasChannelEdge( + chanID1.ToUint64(), + ) if err != nil { t.Fatalf("error looking for edge: %v", chanID1) } @@ -713,7 +713,9 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { // At this point, the channel that was pruned should no longer be known // by the router. - _, _, hasChan, isZombie, err = ctx.graph.HasChannelEdge(chanID1.ToUint64()) + _, _, hasChan, isZombie, err = ctx.graph.HasChannelEdge( + chanID1.ToUint64(), + ) if err != nil { t.Fatalf("error looking for edge: %v", chanID1) } @@ -833,20 +835,23 @@ func TestPruneChannelGraphStaleEdges(t *testing.T) { // All of the channels should exist before pruning them. assertChannelsPruned(t, ctx.graph, testChannels) - // Proceed to prune the channels - only the last one should be pruned. + // Proceed to prune the channels - only the last one should be + // pruned. if err := ctx.builder.pruneZombieChans(); err != nil { t.Fatalf("unable to prune zombie channels: %v", err) } - // We expect channels that have either both edges stale, or one edge - // stale with both known. + // We expect channels that have either both edges stale, or one + // edge stale with both known. var prunedChannels []uint64 if strictPruning { prunedChannels = []uint64{2, 5, 7} } else { prunedChannels = []uint64{2, 7} } - assertChannelsPruned(t, ctx.graph, testChannels, prunedChannels...) + assertChannelsPruned( + t, ctx.graph, testChannels, prunedChannels..., + ) } } @@ -1387,7 +1392,9 @@ func TestBlockDifferenceFix(t *testing.T) { err := wait.NoError(func() error { // Then router height should be updated to the latest block. - if atomic.LoadUint32(&ctx.builder.bestHeight) != newBlockHeight { + if atomic.LoadUint32(&ctx.builder.bestHeight) != + newBlockHeight { + return fmt.Errorf("height should have been updated "+ "to %v, instead got %v", newBlockHeight, ctx.builder.bestHeight) @@ -1589,7 +1596,10 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } err = graph.AddChannelEdge(&edgeInfo) - if err != nil && err != channeldb.ErrEdgeAlreadyExist { + if err != nil && !errors.Is( + err, channeldb.ErrEdgeAlreadyExist, + ) { + return nil, err } @@ -1601,17 +1611,27 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } edgePolicy := &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), - ChannelFlags: channelFlags, - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: edge.Expiry, - MinHTLC: lnwire.MilliSatoshi(edge.MinHTLC), - MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), - FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), - FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), - ToNode: targetNode, + SigBytes: testSig.Serialize(), + MessageFlags: lnwire.ChanUpdateMsgFlags( + edge.MessageFlags, + ), + ChannelFlags: channelFlags, + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: edge.Expiry, + MinHTLC: lnwire.MilliSatoshi( + edge.MinHTLC, + ), + MaxHTLC: lnwire.MilliSatoshi( + edge.MaxHTLC, + ), + FeeBaseMSat: lnwire.MilliSatoshi( + edge.FeeBaseMsat, + ), + FeeProportionalMillionths: lnwire.MilliSatoshi( + edge.FeeRate, + ), + ToNode: targetNode, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -1652,7 +1672,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( // testGraph is the struct which corresponds to the JSON format used to encode // graphs within 
the files in the testdata directory. // -// TODO(roasbeef): add test graph auto-generator +// TODO(roasbeef): add test graph auto-generator. type testGraph struct { Info []string `json:"info"` Nodes []testNode `json:"nodes"` @@ -1788,13 +1808,14 @@ type testChannelPolicy struct { Features *lnwire.FeatureVector } -// createTestGraphFromChannels returns a fully populated ChannelGraph based on a set of -// test channels. Additional required information like keys are derived in -// a deterministic way and added to the channel graph. A list of nodes is -// not required and derived from the channel data. The goal is to keep -// instantiating a test channel graph as light weight as possible. +// createTestGraphFromChannels returns a fully populated ChannelGraph based on a +// set of test channels. Additional required information like keys are derived +// in a deterministic way and added to the channel graph. A list of nodes is not +// required and derived from the channel data. The goal is to keep instantiating +// a test channel graph as light weight as possible. func createTestGraphFromChannels(t *testing.T, useCache bool, - testChannels []*testChannel, source string) (*testGraphInstance, error) { + testChannels []*testChannel, source string) (*testGraphInstance, + error) { // We'll use this fake address for the IP address of all the nodes in // our tests. This value isn't needed for path finding so it doesn't @@ -1940,7 +1961,9 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } err = graph.AddChannelEdge(&edgeInfo) - if err != nil && err != channeldb.ErrEdgeAlreadyExist { + if err != nil && + !errors.Is(err, channeldb.ErrEdgeAlreadyExist) { + return nil, err } @@ -1981,7 +2004,8 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, ToNode: node2Vertex, ExtraOpaqueData: getExtraData(node1), } - if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + err := graph.UpdateEdgePolicy(edgePolicy) + if err != nil { return nil, err } } @@ -2011,12 +2035,13 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, ToNode: node1Vertex, ExtraOpaqueData: getExtraData(node2), } - if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + err := graph.UpdateEdgePolicy(edgePolicy) + if err != nil { return nil, err } } - channelID++ + channelID++ //nolint:ineffassign } return &testGraphInstance{ diff --git a/graph/log.go b/graph/log.go index 2bd55297a0..cd31dae11c 100644 --- a/graph/log.go +++ b/graph/log.go @@ -18,7 +18,7 @@ func init() { } // DisableLog disables all library log output. Logging output is disabled by -// by default until UseLogger is called. +// default until UseLogger is called. func DisableLog() { UseLogger(btclog.Disabled) } diff --git a/graph/notifications.go b/graph/notifications.go index 36f4e09a97..90748b05af 100644 --- a/graph/notifications.go +++ b/graph/notifications.go @@ -117,7 +117,6 @@ type topologyClient struct { // notifyTopologyChange notifies all registered clients of a new change in // graph topology in a non-blocking. func (b *Builder) notifyTopologyChange(topologyDiff *TopologyChange) { - // notifyClient is a helper closure that will send topology updates to // the given client. 
notifyClient := func(clientID uint64, client *topologyClient) bool { diff --git a/graph/notifications_test.go b/graph/notifications_test.go index 290eec0e2a..09ebf1211b 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -55,13 +55,19 @@ var ( timeout = time.Second * 5 - testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") - testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") - testRScalar = new(btcec.ModNScalar) - testSScalar = new(btcec.ModNScalar) - _ = testRScalar.SetByteSlice(testRBytes) - _ = testSScalar.SetByteSlice(testSBytes) - testSig = ecdsa.NewSignature(testRScalar, testSScalar) + testRBytes, _ = hex.DecodeString( + "8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf03" + + "14f882d7", + ) + testSBytes, _ = hex.DecodeString( + "299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a2" + + "92d24ae", + ) + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) testAuthProof = models.ChannelAuthProof{ NodeSig1Bytes: testSig.Serialize(), @@ -1027,22 +1033,6 @@ type testCtx struct { notifier *lnmock.ChainNotifier } -func (c *testCtx) getChannelIDFromAlias(t *testing.T, a, b string) uint64 { - vertexA, ok := c.aliases[a] - require.True(t, ok, "cannot find aliases for %s", a) - - vertexB, ok := c.aliases[b] - require.True(t, ok, "cannot find aliases for %s", b) - - channelIDMap, ok := c.channelIDs[vertexA] - require.True(t, ok, "cannot find channelID map %s(%s)", vertexA, a) - - channelID, ok := channelIDMap[vertexB] - require.True(t, ok, "cannot find channelID using %s(%s)", vertexB, b) - - return channelID -} - func createTestCtxSingleNode(t *testing.T, startingHeight uint32) *testCtx { @@ -1127,8 +1117,8 @@ type testGraphInstance struct { graphBackend kvdb.Backend // aliasMap is a map from a node's alias to its public key. This type is - // provided in order to allow easily look up from the human memorable alias - // to an exact node's public key. + // provided in order to allow easily look up from the human memorable + // alias to an exact node's public key. aliasMap map[string]route.Vertex // privKeyMap maps a node alias to its private key. This is used to be @@ -1201,7 +1191,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, } t.Cleanup(func() { - graphBuilder.Stop() + require.NoError(t, graphBuilder.Stop()) }) return ctx diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index e4912d988b..9c69cba5c9 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2288,7 +2288,8 @@ func TestPathFindSpecExample(t *testing.T) { // parameters. 
lastHop := route.Hops[1] require.EqualValues(t, amt, lastHop.AmtToForward) - require.EqualValues(t, startingHeight+MinCLTVDelta, lastHop.OutgoingTimeLock) + require.EqualValues(t, startingHeight+MinCLTVDelta, + lastHop.OutgoingTimeLock) } func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, @@ -2297,7 +2298,8 @@ func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, require.Len(t, path, len(nodeAliases)) for i, hop := range path { - require.Equal(t, aliasMap[nodeAliases[i]], hop.policy.ToNodePubKey()) + require.Equal(t, aliasMap[nodeAliases[i]], + hop.policy.ToNodePubKey()) } } diff --git a/routing/router_test.go b/routing/router_test.go index 824d6aed9a..1749a6dabc 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -59,8 +59,6 @@ var ( priv2, _ = btcec.NewPrivateKey() bitcoinKey2 = priv2.PubKey() - - timeout = time.Second * 5 ) type testCtx struct { @@ -194,7 +192,7 @@ func createTestNode() (*channeldb.LightningNode, error) { LastUpdate: time.Unix(updateTime, 0), Addresses: testAddrs, Color: color.RGBA{1, 2, 3, 0}, - Alias: "kek" + string(pub[:]), + Alias: "kek" + string(pub), AuthSigBytes: testSig.Serialize(), Features: testFeatures, } @@ -308,7 +306,6 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { // the more costly path (through pham nuwen). ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( // TODO(roasbeef): temp node failure @@ -607,26 +604,29 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to Son goku. This will be a fee related error, so // it should only cause the edge to be pruned after the second attempt. - ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( - func(firstHop lnwire.ShortChannelID) ([32]byte, error) { + dispatcher, ok := ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld) //nolint:lll + require.True(t, ok) - roasbeefSongoku := lnwire.NewShortChanIDFromInt( - roasbeefSongokuChanID, + dispatcher.setPaymentResult(func(firstHop lnwire.ShortChannelID) ( + [32]byte, error) { + + roasbeefSongoku := lnwire.NewShortChanIDFromInt( + roasbeefSongokuChanID, + ) + if firstHop == roasbeefSongoku { + return [32]byte{}, htlcswitch.NewForwardingError( + // Within our error, we'll add a + // channel update which is meant to + // reflect the new fee schedule for the + // node/channel. + &lnwire.FailFeeInsufficient{ + Update: errChanUpdate, + }, 1, ) - if firstHop == roasbeefSongoku { - return [32]byte{}, htlcswitch.NewForwardingError( - // Within our error, we'll add a - // channel update which is meant to - // reflect the new fee schedule for the - // node/channel. - &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, - }, 1, - ) - } + } - return preImage, nil - }) + return preImage, nil + }) // Send off the payment request to the router, route through phamnuwen // should've been selected as a fall back and succeeded correctly. @@ -1211,12 +1211,8 @@ func TestFindPathFeeWeighting(t *testing.T) { // The route that was chosen should be exactly one hop, and should be // directly to luoji. 
- if len(path) != 1 { - t.Fatalf("expected path length of 1, instead was: %v", len(path)) - } - if path[0].policy.ToNodePubKey() != ctx.aliases["luoji"] { - t.Fatalf("wrong node: %v", path[0].policy.ToNodePubKey()) - } + require.Len(t, path, 1) + require.Equal(t, ctx.aliases["luoji"], path[0].policy.ToNodePubKey()) } // TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket @@ -1228,9 +1224,7 @@ func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { sessionKey, _ := btcec.NewPrivateKey() emptyRoute := &route.Route{} _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) - if err != route.ErrNoRouteHopsProvided { - t.Fatalf("expected empty hops error: instead got: %v", err) - } + require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) } // TestUnknownErrorSource tests that if the source of an error is unknown, all @@ -1270,7 +1264,9 @@ func TestUnknownErrorSource(t *testing.T) { }, 4), } - testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") + testGraph, err := createTestGraphFromChannels( + t, true, testChannels, "a", + ) require.NoError(t, err, "unable to create graph") const startingBlockHeight = 101 @@ -1284,20 +1280,23 @@ func TestUnknownErrorSource(t *testing.T) { // We'll modify the SendToSwitch method so that it simulates hop b as a // node that returns an unparsable failure if approached via the a->b // channel. - ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( - func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - - // If channel a->b is used, return an error without - // source and message. The sender won't know the origin - // of the error. - if firstHop.ToUint64() == 1 { - return [32]byte{}, - htlcswitch.ErrUnreadableFailureMessage - } + dispatcher, ok := ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld) //nolint:lll + require.True(t, ok) + + dispatcher.setPaymentResult(func(firstHop lnwire.ShortChannelID) ( + [32]byte, error) { + + // If channel a->b is used, return an error without + // source and message. The sender won't know the origin + // of the error. + if firstHop.ToUint64() == 1 { + return [32]byte{}, + htlcswitch.ErrUnreadableFailureMessage + } - // Otherwise the payment succeeds. - return lntypes.Preimage{}, nil - }) + // Otherwise the payment succeeds. + return lntypes.Preimage{}, nil + }) // Send off the payment request to the router. The expectation is that // the route a->b->c is tried first. An unreadable faiure is returned @@ -1308,19 +1307,22 @@ func TestUnknownErrorSource(t *testing.T) { payment.paymentHash) // Next we modify payment result to return an unknown failure. - ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( - func(firstHop lnwire.ShortChannelID) ([32]byte, error) { + dispatcher, ok = ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld) //nolint:lll + require.True(t, ok) - // If channel a->b is used, simulate that the failure - // couldn't be decoded (FailureMessage is nil). - if firstHop.ToUint64() == 2 { - return [32]byte{}, - htlcswitch.NewUnknownForwardingError(1) - } + dispatcher.setPaymentResult(func(firstHop lnwire.ShortChannelID) ( + [32]byte, error) { - // Otherwise the payment succeeds. - return lntypes.Preimage{}, nil - }) + // If channel a->b is used, simulate that the failure + // couldn't be decoded (FailureMessage is nil). + if firstHop.ToUint64() == 2 { + return [32]byte{}, + htlcswitch.NewUnknownForwardingError(1) + } + + // Otherwise the payment succeeds. 
+ return lntypes.Preimage{}, nil + }) // Send off the payment request to the router. We expect the payment to // fail because both routes have been pruned. @@ -2353,7 +2355,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { ) node1Bytes := priv1.PubKey().SerializeCompressed() node2Bytes := connectNode - if bytes.Compare(node1Bytes[:], node2Bytes[:]) == -1 { + if bytes.Compare(node1Bytes, node2Bytes[:]) == -1 { pubKey1 = priv1.PubKey() pubKey2 = connectNodeKey } else { @@ -2558,35 +2560,6 @@ func (m *mockChain) GetBestBlock() (*chainhash.Hash, int32, error) { return &blockHash, m.bestHeight, nil } -func (m *mockChain) setBestBlock(height int32) { - m.Lock() - defer m.Unlock() - - m.bestHeight = height -} - -func (m *mockChain) addUtxo(op wire.OutPoint, out *wire.TxOut) { - m.Lock() - m.utxos[op] = *out - m.Unlock() -} - -func (m *mockChain) delUtxo(op wire.OutPoint) { - m.Lock() - delete(m.utxos, op) - m.Unlock() -} - -func (m *mockChain) addBlock(block *wire.MsgBlock, height uint32, nonce uint32) { - m.Lock() - block.Header.Nonce = nonce - hash := block.Header.BlockHash() - m.blocks[hash] = block - m.blockIndex[height] = hash - m.blockHeightIndex[hash] = height - m.Unlock() -} - func createChannelEdge(bitcoinKey1, bitcoinKey2 []byte, chanValue btcutil.Amount, fundingHeight uint32) (*wire.MsgTx, *wire.OutPoint, *lnwire.ShortChannelID, error) { diff --git a/rpcserver.go b/rpcserver.go index 1ba82015e7..465035ea55 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6666,7 +6666,8 @@ func (r *rpcServer) SubscribeChannelGraph(req *lnrpc.GraphTopologySubscription, // marshallTopologyChange performs a mapping from the topology change struct // returned by the router to the form of notifications expected by the current // gRPC service. -func marshallTopologyChange(topChange *graph.TopologyChange) *lnrpc.GraphTopologyUpdate { +func marshallTopologyChange( + topChange *graph.TopologyChange) *lnrpc.GraphTopologyUpdate { // encodeKey is a simple helper function that converts a live public // key into a hex-encoded version of the compressed serialization for @@ -6677,7 +6678,9 @@ func marshallTopologyChange(topChange *graph.TopologyChange) *lnrpc.GraphTopolog nodeUpdates := make([]*lnrpc.NodeUpdate, len(topChange.NodeUpdates)) for i, nodeUpdate := range topChange.NodeUpdates { - nodeAddrs := make([]*lnrpc.NodeAddress, 0, len(nodeUpdate.Addresses)) + nodeAddrs := make( + []*lnrpc.NodeAddress, 0, len(nodeUpdate.Addresses), + ) for _, addr := range nodeUpdate.Addresses { nodeAddr := &lnrpc.NodeAddress{ Network: addr.Network(), diff --git a/server.go b/server.go index 3317fb2f28..7ae6ed21e7 100644 --- a/server.go +++ b/server.go @@ -1060,6 +1060,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, IsStillZombieChannel: s.graphBuilder.IsZombieChannel, }, nodeKeyDesc) + //nolint:lll s.localChanMgr = &localchans.Manager{ ForAllOutgoingChannels: s.graphBuilder.ForAllOutgoingChannels, PropagateChanPolicyUpdate: s.authGossiper.PropagateChanPolicyUpdate, From 90dff730ce294a171bfa9f52a4f336554049166b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 15 Jul 2024 15:00:10 +0200 Subject: [PATCH 18/20] graph: updated builder to use atomic ints Instead of relying on devs to remember that they must only be accessed atomically. 
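As illustration of the pattern this commit applies, here is a minimal, self-contained sketch (the type and field names mirror the diff below, but this is not lnd's actual code): Go's sync/atomic wrapper types carry the "atomic access only" contract in the type system, so a plain, non-atomic read or write of the counter no longer compiles:

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    // builder mirrors the shape of the change: the old fields were plain
    // integers annotated "To be used atomically."; with atomic.Uint32 and
    // atomic.Uint64 that contract can no longer be forgotten.
    type builder struct {
        ntfnClientCounter atomic.Uint64
        bestHeight        atomic.Uint32
    }

    func main() {
        var b builder

        // Before: atomic.StoreUint32(&b.bestHeight, 500_000)
        b.bestHeight.Store(500_000)

        // Before: atomic.AddUint64(&b.ntfnClientCounter, 1)
        clientID := b.ntfnClientCounter.Add(1)

        // Before: atomic.LoadUint32(&b.bestHeight)
        fmt.Println(b.bestHeight.Load(), clientID)
    }

The behaviour is unchanged; the difference is that Load/Store/Add become methods on the field rather than free functions taking its address, as the diff below shows.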
---
 graph/builder.go       | 18 +++++++++---------
 graph/builder_test.go  |  7 ++-----
 graph/notifications.go |  3 +--
 3 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/graph/builder.go b/graph/builder.go
index 264a2aacd9..2436f31767 100644
--- a/graph/builder.go
+++ b/graph/builder.go
@@ -116,8 +116,8 @@ type Builder struct {
 	started atomic.Bool
 	stopped atomic.Bool
 
-	ntfnClientCounter uint64 // To be used atomically.
-	bestHeight        uint32 // To be used atomically.
+	ntfnClientCounter atomic.Uint64
+	bestHeight        atomic.Uint32
 
 	cfg *Config
 
@@ -278,7 +278,7 @@ func (b *Builder) Start() error {
 		if err != nil {
 			return err
 		}
-		b.bestHeight = uint32(bestHeight)
+		b.bestHeight.Store(uint32(bestHeight))
 
 		// Before we begin normal operation of the router, we first need
 		// to synchronize the channel graph to the latest state of the
@@ -340,7 +340,7 @@ func (b *Builder) syncGraphWithChain() error {
 	if err != nil {
 		return err
 	}
-	b.bestHeight = uint32(bestHeight)
+	b.bestHeight.Store(uint32(bestHeight))
 
 	pruneHash, pruneHeight, err := b.cfg.Graph.PruneTip()
 	if err != nil {
@@ -806,7 +806,7 @@ func (b *Builder) networkHandler() {
 			// Since this block is stale, we update our best height
 			// to the previous block.
 			blockHeight := chainUpdate.Height
-			atomic.StoreUint32(&b.bestHeight, blockHeight-1)
+			b.bestHeight.Store(blockHeight - 1)
 
 			// Update the channel graph to reflect that this block
 			// was disconnected.
@@ -834,7 +834,7 @@ func (b *Builder) networkHandler() {
 			// directly to the end of our main chain. If not, then
 			// we've somehow missed some blocks. Here we'll catch
 			// up the chain with the latest blocks.
-			currentHeight := atomic.LoadUint32(&b.bestHeight)
+			currentHeight := b.bestHeight.Load()
 			switch {
 			case chainUpdate.Height == currentHeight+1:
 				err := b.updateGraphWithClosedChannels(
@@ -991,7 +991,7 @@ func (b *Builder) updateGraphWithClosedChannels(
 	// of the chain tip.
 	blockHeight := chainUpdate.Height
 
-	atomic.StoreUint32(&b.bestHeight, blockHeight)
+	b.bestHeight.Store(blockHeight)
 	log.Infof("Pruning channel graph using block %v (height=%v)",
 		chainUpdate.Hash, blockHeight)
 
@@ -1342,7 +1342,7 @@ func (b *Builder) processUpdate(msg interface{},
 			},
 		}
 		err = b.cfg.ChainView.UpdateFilter(
-			filterUpdate, atomic.LoadUint32(&b.bestHeight),
+			filterUpdate, b.bestHeight.Load(),
 		)
 		if err != nil {
 			return errors.Errorf("unable to update chain "+
@@ -1658,7 +1658,7 @@ func (b *Builder) CurrentBlockHeight() (uint32, error) {
 // is synced to. This can differ from the above chain height if the goroutine
 // responsible for processing the blocks isn't yet up to speed.
 func (b *Builder) SyncedHeight() uint32 {
-	return atomic.LoadUint32(&b.bestHeight)
+	return b.bestHeight.Load()
 }
 
 // GetChannelByID returns the channel by the channel id.
diff --git a/graph/builder_test.go b/graph/builder_test.go
index e3a3bd0913..600bd86344 100644
--- a/graph/builder_test.go
+++ b/graph/builder_test.go
@@ -11,7 +11,6 @@ import (
 	"net"
 	"os"
 	"strings"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -1392,12 +1391,10 @@ func TestBlockDifferenceFix(t *testing.T) {
 	err := wait.NoError(func() error {
 		// Then router height should be updated to the latest block.
-		if atomic.LoadUint32(&ctx.builder.bestHeight) !=
-			newBlockHeight {
-
+		if ctx.builder.bestHeight.Load() != newBlockHeight {
 			return fmt.Errorf("height should have been updated "+
 				"to %v, instead got %v", newBlockHeight,
-				ctx.builder.bestHeight)
+				ctx.builder.bestHeight.Load())
 		}
 
 		return nil
diff --git a/graph/notifications.go b/graph/notifications.go
index 90748b05af..14ea3d127d 100644
--- a/graph/notifications.go
+++ b/graph/notifications.go
@@ -5,7 +5,6 @@ import (
 	"image/color"
 	"net"
 	"sync"
-	"sync/atomic"
 
 	"github.com/btcsuite/btcd/btcec/v2"
 	"github.com/btcsuite/btcd/btcutil"
@@ -65,7 +64,7 @@ func (b *Builder) SubscribeTopology() (*TopologyClient, error) {
 
 	// We'll first atomically obtain the next ID for this client from the
 	// incrementing client ID counter.
-	clientID := atomic.AddUint64(&b.ntfnClientCounter, 1)
+	clientID := b.ntfnClientCounter.Add(1)
 
 	log.Debugf("New graph topology client subscription, client %v",
 		clientID)

From d1c54d74a82944053a658234090bdc36d0ce864a Mon Sep 17 00:00:00 2001
From: Elle Mouton
Date: Mon, 15 Jul 2024 15:01:28 +0200
Subject: [PATCH 19/20] routing: close graph session if getBandwidthHints fails

Ensure that the graph session used during pathfinding is properly
closed if the call to getBandwidthHints fails. (A self-contained sketch
of this close-on-error pattern follows the final patch in this series.)
---
 routing/payment_session.go | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/routing/payment_session.go b/routing/payment_session.go
index 84f2135d79..a464bd93ed 100644
--- a/routing/payment_session.go
+++ b/routing/payment_session.go
@@ -294,6 +294,12 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
 	// attempt, because concurrent payments may change balances.
 	bandwidthHints, err := p.getBandwidthHints(graph)
 	if err != nil {
+		// Close routing graph session.
+		if graphErr := closeGraph(); graphErr != nil {
+			log.Errorf("could not close graph session: %v",
+				graphErr)
+		}
+
 		return nil, err
 	}

From b112e10bf2b43dbbe2598ebf4f708de4e39ca983 Mon Sep 17 00:00:00 2001
From: Elle Mouton
Date: Fri, 12 Jul 2024 12:30:23 +0200
Subject: [PATCH 20/20] docs: update release notes

Also move an incorrectly placed entry from 18.2 to 18.3.
---
 docs/release-notes/release-notes-0.18.2.md | 4 ----
 docs/release-notes/release-notes-0.18.3.md | 9 +++++++++
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/docs/release-notes/release-notes-0.18.2.md b/docs/release-notes/release-notes-0.18.2.md
index 072d28669b..3be21d93ac 100644
--- a/docs/release-notes/release-notes-0.18.2.md
+++ b/docs/release-notes/release-notes-0.18.2.md
@@ -40,10 +40,6 @@
 ## BOLT Spec Updates
 ## Testing
 ## Database
-
-* [Fixed](https://github.com/lightningnetwork/lnd/pull/8854) pagination issues
-  in SQL invoicedb queries.
-
 ## Code Health
 ## Tooling and Documentation
 
diff --git a/docs/release-notes/release-notes-0.18.3.md b/docs/release-notes/release-notes-0.18.3.md
index 9245903954..95f1961bfe 100644
--- a/docs/release-notes/release-notes-0.18.3.md
+++ b/docs/release-notes/release-notes-0.18.3.md
@@ -100,7 +100,16 @@
   invoice database. Invoices with incorrect expiry values will be updated to
   24-hour expiry, which is the default behavior in LND.
 
+* [Fixed](https://github.com/lightningnetwork/lnd/pull/8854) pagination issues
+  in SQL invoicedb queries.
+
 ## Code Health
+
+* [Move graph building and
+  maintaining](https://github.com/lightningnetwork/lnd/pull/8848) duties from
+  the `routing.ChannelRouter` to the new `graph.Builder` sub-system and also
+  remove the `channeldb.ChannelGraph` pointer from the `ChannelRouter`.
+
 ## Tooling and Documentation
 
 # Contributors (Alphabetical Order)
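
The change in [PATCH 19/20] above is an instance of a general
close-on-error pattern: a function that acquires a resource and then
fails partway through must run the cleanup on the early-return path as
well. A minimal, self-contained sketch of that pattern follows; the
`session` type and `getBandwidthHints` stub are hypothetical stand-ins,
not lnd's actual API:

package main

import (
	"errors"
	"fmt"
	"log"
)

// session stands in for the routing graph session that RequestRoute
// opens before pathfinding.
type session struct{}

func (s *session) close() error { return nil }

// getBandwidthHints stands in for the call that can fail after the
// session has already been opened.
func getBandwidthHints(s *session) error {
	return errors.New("bandwidth hints unavailable")
}

func requestRoute() error {
	sess := &session{}

	if err := getBandwidthHints(sess); err != nil {
		// Close the session before propagating the error rather
		// than returning and leaking it. The close error is
		// logged instead of returned so that the original
		// failure is preserved.
		if closeErr := sess.close(); closeErr != nil {
			log.Printf("could not close graph session: %v",
				closeErr)
		}

		return err
	}

	return sess.close()
}

func main() {
	fmt.Println(requestRoute()) // bandwidth hints unavailable
}

A defer-based cleanup is the more common Go idiom; the sketch mirrors
the patch's explicit close, which suits code where the success path
manages the session's lifetime separately.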