From e7a987b4a34a70f8bb9ab720f3a5db176ff1b6e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milo=C5=A1=20=C5=BDivkovi=C4=87?= Date: Sat, 26 Oct 2024 22:08:03 -0400 Subject: [PATCH] Add remaining unit tests for the switch --- tm2/pkg/p2p/discovery/discovery.go | 2 + tm2/pkg/p2p/switch.go | 65 +++-- tm2/pkg/p2p/switch_test.go | 399 ++++++++++++++++++++++++++++- 3 files changed, 441 insertions(+), 25 deletions(-) diff --git a/tm2/pkg/p2p/discovery/discovery.go b/tm2/pkg/p2p/discovery/discovery.go index ab36a737229..132e0855cf8 100644 --- a/tm2/pkg/p2p/discovery/discovery.go +++ b/tm2/pkg/p2p/discovery/discovery.go @@ -187,6 +187,8 @@ func (r *Reactor) handleDiscoveryRequest(peer p2p.Peer) error { peers = make([]*types.NetAddress, 0, len(localPeers)) ) + // TODO exclude private peers + // Shuffle and limit the peers shared shufflePeers(localPeers) diff --git a/tm2/pkg/p2p/switch.go b/tm2/pkg/p2p/switch.go index 963891844f2..477ed5cace2 100644 --- a/tm2/pkg/p2p/switch.go +++ b/tm2/pkg/p2p/switch.go @@ -323,43 +323,60 @@ func (sw *MultiplexSwitch) runRedialLoop(ctx context.Context) { ticker := time.NewTicker(time.Second * 10) defer ticker.Stop() - for { - select { - case <-ctx.Done(): - sw.Logger.Debug("redial crawl context canceled") + // redialFn goes through the persistent peer list + // and dials missing peers + redialFn := func() { + var ( + peers = sw.Peers() + peersToDial = make([]*types.NetAddress, 0) + ) - return - case <-ticker.C: - peers := sw.Peers() + sw.persistentPeers.Range(func(key, value any) bool { + var ( + id = key.(types.ID) + addr = value.(*types.NetAddress) + ) - peersToDial := make([]*types.NetAddress, 0) + // Check if the peer is part of the peer set + // or is scheduled for dialing + if peers.Has(id) || sw.dialQueue.Has(addr) { + return true + } - sw.persistentPeers.Range(func(key, value any) bool { - var ( - id = key.(types.ID) - addr = value.(*types.NetAddress) - ) + peersToDial = append(peersToDial, addr) - // Check if the peer is part of the peer set - // or is scheduled for dialing - if peers.Has(id) || sw.dialQueue.Has(addr) { - return true - } + return true + }) - peersToDial = append(peersToDial, addr) + if len(peersToDial) == 0 { + // No persistent peers are missing + return + } - return true - }) + // Add the peers to the dial queue + sw.DialPeers(peersToDial...) + } + + // Run the initial redial loop on start, + // in case persistent peer connections are not + // active + redialFn() + + for { + select { + case <-ctx.Done(): + sw.Logger.Debug("redial crawl context canceled") - // Add the peers to the dial queue - sw.DialPeers(peersToDial...) + return + case <-ticker.C: + redialFn() } } } // DialPeers adds the peers to the dial queue for async dialing. // To monitor dial progress, subscribe to adequate p2p MultiplexSwitch events -func (sw *MultiplexSwitch) DialPeers(peerAddrs ...*types.NetAddress) { +func (sw *MultiplexSwitch) DialPeers(peerAddrs ...*types.NetAddress) { // TODO remove pointer for _, peerAddr := range peerAddrs { // Check if this is our address if peerAddr.Same(sw.transport.NetAddress()) { diff --git a/tm2/pkg/p2p/switch_test.go b/tm2/pkg/p2p/switch_test.go index 34edd251df3..fd05ded96a1 100644 --- a/tm2/pkg/p2p/switch_test.go +++ b/tm2/pkg/p2p/switch_test.go @@ -2,6 +2,7 @@ package p2p import ( "context" + "net" "sync" "testing" "time" @@ -343,5 +344,401 @@ func TestMultiplexSwitch_DialLoop(t *testing.T) { func TestMultiplexSwitch_AcceptLoop(t *testing.T) { t.Parallel() - // TODO implement + t.Run("inbound limit reached", func(t *testing.T) { + t.Parallel() + + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + var ( + ch = make(chan struct{}, 1) + maxInbound = uint64(10) + + peerRemoved bool + + p = mock.GeneratePeers(t, 1)[0] + + mockTransport = &mockTransport{ + acceptFn: func(_ context.Context, _ PeerBehavior) (Peer, error) { + return p, nil + }, + removeFn: func(removedPeer Peer) { + require.Equal(t, p.ID(), removedPeer.ID()) + + peerRemoved = true + + ch <- struct{}{} + }, + } + + ps = &mockSet{ + numInboundFn: func() uint64 { + return maxInbound + }, + } + + sw = NewSwitch( + mockTransport, + WithMaxInboundPeers(maxInbound), + ) + ) + + // Set the peer set + sw.peers = ps + + // Run the accept loop + go sw.runAcceptLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): + } + + assert.True(t, peerRemoved) + }) + + t.Run("peer accepted", func(t *testing.T) { + t.Parallel() + + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + var ( + ch = make(chan struct{}, 1) + maxInbound = uint64(10) + + peerAdded bool + + p = mock.GeneratePeers(t, 1)[0] + + mockTransport = &mockTransport{ + acceptFn: func(_ context.Context, _ PeerBehavior) (Peer, error) { + return p, nil + }, + } + + ps = &mockSet{ + numInboundFn: func() uint64 { + return maxInbound - 1 // available slot + }, + addFn: func(peer Peer) { + require.Equal(t, p.ID(), peer.ID()) + + peerAdded = true + + ch <- struct{}{} + }, + } + + sw = NewSwitch( + mockTransport, + WithMaxInboundPeers(maxInbound), + ) + ) + + // Set the peer set + sw.peers = ps + + // Run the accept loop + go sw.runAcceptLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): + } + + assert.True(t, peerAdded) + }) +} + +func TestMultiplexSwitch_RedialLoop(t *testing.T) { + t.Parallel() + + t.Run("no peers to dial", func(t *testing.T) { + t.Parallel() + + var ( + ch = make(chan struct{}, 1) + + peersChecked = 0 + peers = mock.GeneratePeers(t, 10) + + ps = &mockSet{ + hasFn: func(id types.ID) bool { + exists := false + for _, p := range peers { + if p.ID() == id { + exists = true + + break + } + } + + require.True(t, exists) + + peersChecked++ + + if peersChecked == len(peers) { + ch <- struct{}{} + } + + return true + }, + } + ) + + // Make sure the peers are the + // switch persistent peers + addrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + addrs = append(addrs, p.NodeInfo().NetAddress) + } + + // Create the switch + sw := NewSwitch( + nil, + WithPersistentPeers(addrs), + ) + + // Set the peer set + sw.peers = ps + + // Run the redial loop + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + go sw.runRedialLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): + } + + assert.Equal(t, len(peers), peersChecked) + }) + + t.Run("missing peers dialed", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 10) + missingPeer = peers[0] + missingAddr = missingPeer.NodeInfo().NetAddress + + peersDialed []types.NetAddress + + mockTransport = &mockTransport{ + dialFn: func( + _ context.Context, + address types.NetAddress, + _ PeerBehavior, + ) (Peer, error) { + peersDialed = append(peersDialed, address) + + if address.Equals(*missingPeer.NodeInfo().NetAddress) { + return missingPeer, nil + } + + return nil, errors.New("invalid dial") + }, + } + ps = &mockSet{ + hasFn: func(id types.ID) bool { + return id != missingPeer.ID() + }, + } + ) + + // Make sure the peers are the + // switch persistent peers + addrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + addrs = append(addrs, p.NodeInfo().NetAddress) + } + + // Create the switch + sw := NewSwitch( + mockTransport, + WithPersistentPeers(addrs), + ) + + // Set the peer set + sw.peers = ps + + // Run the redial loop + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + var wg sync.WaitGroup + + wg.Add(2) + + go func() { + defer wg.Done() + + sw.runRedialLoop(ctx) + }() + + go func() { + defer wg.Done() + + deadline := time.After(5 * time.Second) + + for { + select { + case <-deadline: + return + default: + if !sw.dialQueue.Has(missingAddr) { + continue + } + + cancelFn() + + return + } + } + }() + + wg.Wait() + + require.True(t, sw.dialQueue.Has(missingAddr)) + assert.Equal(t, missingAddr, sw.dialQueue.Peek().Address) + }) +} + +func TestMultiplexSwitch_DialPeers(t *testing.T) { + t.Parallel() + + t.Run("self dial request", func(t *testing.T) { + t.Parallel() + + var ( + addr = types.NetAddress{ + ID: "id", + IP: net.IP{}, + } + + p = mock.GeneratePeers(t, 1)[0] + + mockTransport = &mockTransport{ + netAddressFn: func() types.NetAddress { + return addr + }, + } + ) + + // Make sure the "peer" has the same address + // as the transport (node) + p.NodeInfoFn = func() types.NodeInfo { + return types.NodeInfo{ + NetAddress: &addr, + } + } + + sw := NewSwitch(mockTransport) + + // Dial the peers + sw.DialPeers(p.NodeInfo().NetAddress) + + // Make sure the peer wasn't actually dialed + assert.False(t, sw.dialQueue.Has(p.NodeInfo().NetAddress)) + }) + + t.Run("outbound peer limit reached", func(t *testing.T) { + t.Parallel() + + var ( + maxOutbound = uint64(10) + peers = mock.GeneratePeers(t, 10) + + mockTransport = &mockTransport{ + netAddressFn: func() types.NetAddress { + return types.NetAddress{ + ID: "id", + IP: net.IP{}, + } + }, + } + + ps = &mockSet{ + numOutboundFn: func() uint64 { + return maxOutbound + }, + } + ) + + sw := NewSwitch( + mockTransport, + WithMaxOutboundPeers(maxOutbound), + ) + + // Set the peer set + sw.peers = ps + + // Dial the peers + addrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + addrs = append(addrs, p.NodeInfo().NetAddress) + } + + sw.DialPeers(addrs...) + + // Make sure no peers were dialed + for _, p := range peers { + assert.False(t, sw.dialQueue.Has(p.NodeInfo().NetAddress)) + } + }) + + t.Run("peers dialed", func(t *testing.T) { + t.Parallel() + + var ( + maxOutbound = uint64(1000) + peers = mock.GeneratePeers(t, int(maxOutbound/2)) + + mockTransport = &mockTransport{ + netAddressFn: func() types.NetAddress { + return types.NetAddress{ + ID: "id", + IP: net.IP{}, + } + }, + } + ) + + sw := NewSwitch( + mockTransport, + WithMaxOutboundPeers(10), + ) + + // Dial the peers + addrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + addrs = append(addrs, p.NodeInfo().NetAddress) + } + + sw.DialPeers(addrs...) + + // Make sure peers were dialed + for _, p := range peers { + assert.True(t, sw.dialQueue.Has(p.NodeInfo().NetAddress)) + } + }) }