Skip to content

Commit

Permalink
Split msgs in iwant response if bigger than limit (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomrsantos authored Oct 2, 2023
1 parent 61929ae commit 7587181
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 65 deletions.
2 changes: 1 addition & 1 deletion libp2p/protocols/pubsub/gossipsub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ method init*(g: GossipSub) =
g.codecs &= GossipSubCodec
g.codecs &= GossipSubCodec_10

method onNewPeer(g: GossipSub, peer: PubSubPeer) =
method onNewPeer*(g: GossipSub, peer: PubSubPeer) =
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
# Make sure stats and peer information match, even when reloading peer stats
# from a previous connection
Expand Down
47 changes: 43 additions & 4 deletions libp2p/protocols/pubsub/pubsubpeer.nim
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ proc sendEncoded*(p: PubSubPeer, msg: seq[byte]) {.raises: [], async.} =
return

if msg.len > p.maxMessageSize:
info "trying to send a too big for pubsub", maxSize=p.maxMessageSize, msgSize=msg.len
info "trying to send a msg too big for pubsub", maxSize=p.maxMessageSize, msgSize=msg.len
return

if p.sendConn == nil:
Expand All @@ -269,9 +269,42 @@ proc sendEncoded*(p: PubSubPeer, msg: seq[byte]) {.raises: [], async.} =

await conn.close() # This will clean up the send connection

proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [].} =
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
iterator splitRPCMsg(peer: PubSubPeer, rpcMsg: RPCMsg, maxSize: int, anonymize: bool): seq[byte] =
## This iterator takes an `RPCMsg` and sequentially repackages its Messages into new `RPCMsg` instances.
## Each new `RPCMsg` accumulates Messages until reaching the specified `maxSize`. If a single Message
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.

var currentRPCMsg = rpcMsg
currentRPCMsg.messages = newSeq[Message]()

var currentSize = byteSize(currentRPCMsg)

for msg in rpcMsg.messages:
let msgSize = byteSize(msg)

# Check if adding the next message will exceed maxSize
if float(currentSize + msgSize) * 1.1 > float(maxSize): # Guessing 10% protobuf overhead
if currentRPCMsg.messages.len == 0:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
continue # Skip this message

trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
yield encodeRpcMsg(currentRPCMsg, anonymize)
currentRPCMsg = RPCMsg()
currentSize = 0

currentRPCMsg.messages.add(msg)
currentSize += msgSize

# Check if there is a non-empty currentRPCMsg left to be added
if currentSize > 0 and currentRPCMsg.messages.len > 0:
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
yield encodeRpcMsg(currentRPCMsg, anonymize)
else:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)

proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [].} =
# When sending messages, we take care to re-encode them with the right
# anonymization flag to ensure that we're not penalized for sending invalid
# or malicious data on the wire - in particular, re-encoding protects against
Expand All @@ -289,7 +322,13 @@ proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [].} =
sendMetrics(msg)
encodeRpcMsg(msg, anonymize)

asyncSpawn p.sendEncoded(encoded)
if encoded.len > p.maxMessageSize and msg.messages.len > 1:
for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
asyncSpawn p.sendEncoded(encodedSplitMsg)
else:
# If the message size is within limits, send it as is
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
asyncSpawn p.sendEncoded(encoded)

proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool =
for sentIHave in p.sentIHaves.mitems():
Expand Down
82 changes: 56 additions & 26 deletions libp2p/protocols/pubsub/rpc/messages.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

{.push raises: [].}

import options, sequtils
import options, sequtils, sugar
import "../../.."/[
peerid,
routing_record,
Expand All @@ -18,6 +18,14 @@ import "../../.."/[

export options

proc expectedFields[T](t: typedesc[T], existingFieldNames: seq[string]) {.raises: [CatchableError].} =
var fieldNames: seq[string]
for name, _ in fieldPairs(T()):
fieldNames &= name
if fieldNames != existingFieldNames:
fieldNames.keepIf(proc(it: string): bool = it notin existingFieldNames)
raise newException(CatchableError, $T & " fields changed, please search for and revise all relevant procs. New fields: " & $fieldNames)

type
PeerInfoMsg* = object
peerId*: PeerId
Expand Down Expand Up @@ -117,31 +125,53 @@ func shortLog*(m: RPCMsg): auto =
control: m.control.get(ControlMessage()).shortLog
)

static: expectedFields(PeerInfoMsg, @["peerId", "signedPeerRecord"])
proc byteSize(peerInfo: PeerInfoMsg): int =
peerInfo.peerId.len + peerInfo.signedPeerRecord.len

static: expectedFields(SubOpts, @["subscribe", "topic"])
proc byteSize(subOpts: SubOpts): int =
1 + subOpts.topic.len # 1 byte for the bool

static: expectedFields(Message, @["fromPeer", "data", "seqno", "topicIds", "signature", "key"])
proc byteSize*(msg: Message): int =
var total = 0
total += msg.fromPeer.len
total += msg.data.len
total += msg.seqno.len
total += msg.signature.len
total += msg.key.len
for topicId in msg.topicIds:
total += topicId.len
return total
msg.fromPeer.len + msg.data.len + msg.seqno.len +
msg.signature.len + msg.key.len + msg.topicIds.foldl(a + b.len, 0)

proc byteSize*(msgs: seq[Message]): int =
msgs.mapIt(byteSize(it)).foldl(a + b, 0)

proc byteSize*(ihave: seq[ControlIHave]): int =
var total = 0
for item in ihave:
total += item.topicId.len
for msgId in item.messageIds:
total += msgId.len
return total

proc byteSize*(iwant: seq[ControlIWant]): int =
var total = 0
for item in iwant:
for msgId in item.messageIds:
total += msgId.len
return total
msgs.foldl(a + b.byteSize, 0)

static: expectedFields(ControlIHave, @["topicId", "messageIds"])
proc byteSize(controlIHave: ControlIHave): int =
controlIHave.topicId.len + controlIHave.messageIds.foldl(a + b.len, 0)

proc byteSize*(ihaves: seq[ControlIHave]): int =
ihaves.foldl(a + b.byteSize, 0)

static: expectedFields(ControlIWant, @["messageIds"])
proc byteSize(controlIWant: ControlIWant): int =
controlIWant.messageIds.foldl(a + b.len, 0)

proc byteSize*(iwants: seq[ControlIWant]): int =
iwants.foldl(a + b.byteSize, 0)

static: expectedFields(ControlGraft, @["topicId"])
proc byteSize(controlGraft: ControlGraft): int =
controlGraft.topicId.len

static: expectedFields(ControlPrune, @["topicId", "peers", "backoff"])
proc byteSize(controlPrune: ControlPrune): int =
controlPrune.topicId.len + controlPrune.peers.foldl(a + b.byteSize, 0) + 8 # 8 bytes for uint64

static: expectedFields(ControlMessage, @["ihave", "iwant", "graft", "prune", "idontwant"])
proc byteSize(control: ControlMessage): int =
control.ihave.foldl(a + b.byteSize, 0) + control.iwant.foldl(a + b.byteSize, 0) +
control.graft.foldl(a + b.byteSize, 0) + control.prune.foldl(a + b.byteSize, 0) +
control.idontwant.foldl(a + b.byteSize, 0)

static: expectedFields(RPCMsg, @["subscriptions", "messages", "control", "ping", "pong"])
proc byteSize*(rpc: RPCMsg): int =
result = rpc.subscriptions.foldl(a + b.byteSize, 0) + byteSize(rpc.messages) +
rpc.ping.len + rpc.pong.len
rpc.control.withValue(ctrl):
result += ctrl.byteSize
162 changes: 139 additions & 23 deletions tests/pubsub/testgossipinternal.nim
Original file line number Diff line number Diff line change
@@ -1,42 +1,31 @@
include ../../libp2p/protocols/pubsub/gossipsub
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

{.used.}

import std/[options, deques]
import std/[options, deques, sequtils, enumerate, algorithm]
import stew/byteutils
import ../../libp2p/builders
import ../../libp2p/errors
import ../../libp2p/crypto/crypto
import ../../libp2p/stream/bufferstream
import ../../libp2p/protocols/pubsub/[pubsub, gossipsub, mcache, mcache, peertable]
import ../../libp2p/protocols/pubsub/rpc/[message, messages]
import ../../libp2p/switch
import ../../libp2p/muxers/muxer
import ../../libp2p/protocols/pubsub/rpc/protobuf
import utils

import ../helpers

type
TestGossipSub = ref object of GossipSub

proc noop(data: seq[byte]) {.async, gcsafe.} = discard

proc getPubSubPeer(p: TestGossipSub, peerId: PeerId): PubSubPeer =
proc getConn(): Future[Connection] =
p.switch.dial(peerId, GossipSubCodec)

let pubSubPeer = PubSubPeer.new(peerId, getConn, nil, GossipSubCodec, 1024 * 1024, Opt.some(TokenBucket.new(1024, 500.milliseconds)))
debug "created new pubsub peer", peerId

p.peers[peerId] = pubSubPeer

onNewPeer(p, pubSubPeer)
pubSubPeer

proc randomPeerId(): PeerId =
try:
PeerId.init(PrivateKey.random(ECDSA, rng[]).get()).tryGet()
except CatchableError as exc:
raise newException(Defect, exc.msg)

const MsgIdSuccess = "msg id gen success"

suite "GossipSub internal":
Expand Down Expand Up @@ -826,3 +815,130 @@ suite "GossipSub internal":

await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()

proc setupTest(): Future[tuple[gossip0: GossipSub, gossip1: GossipSub, receivedMessages: ref HashSet[seq[byte]]]] {.async.} =
let
nodes = generateNodes(2, gossip = true, verifySignature = false)
discard await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start()
)

await nodes[1].switch.connect(nodes[0].switch.peerInfo.peerId, nodes[0].switch.peerInfo.addrs)

var receivedMessages = new(HashSet[seq[byte]])

proc handlerA(topic: string, data: seq[byte]) {.async, gcsafe.} =
receivedMessages[].incl(data)

proc handlerB(topic: string, data: seq[byte]) {.async, gcsafe.} =
discard

nodes[0].subscribe("foobar", handlerA)
nodes[1].subscribe("foobar", handlerB)
await waitSubGraph(nodes, "foobar")

var gossip0: GossipSub = GossipSub(nodes[0])
var gossip1: GossipSub = GossipSub(nodes[1])

return (gossip0, gossip1, receivedMessages)

proc teardownTest(gossip0: GossipSub, gossip1: GossipSub) {.async.} =
await allFuturesThrowing(
gossip0.switch.stop(),
gossip1.switch.stop()
)

proc createMessages(gossip0: GossipSub, gossip1: GossipSub, size1: int, size2: int): tuple[iwantMessageIds: seq[MessageId], sentMessages: HashSet[seq[byte]]] =
var iwantMessageIds = newSeq[MessageId]()
var sentMessages = initHashSet[seq[byte]]()

for i, size in enumerate([size1, size2]):
let data = newSeqWith[byte](size, i.byte)
sentMessages.incl(data)

let msg = Message.init(gossip1.peerInfo.peerId, data, "foobar", some(uint64(i + 1)))
let iwantMessageId = gossip1.msgIdProvider(msg).expect(MsgIdSuccess)
iwantMessageIds.add(iwantMessageId)
gossip1.mcache.put(iwantMessageId, msg)

let peer = gossip1.peers[(gossip0.peerInfo.peerId)]
peer.sentIHaves[^1].incl(iwantMessageId)

return (iwantMessageIds, sentMessages)

asyncTest "e2e - Split IWANT replies when individual messages are below maxSize but combined exceed maxSize":
# This test checks if two messages, each below the maxSize, are correctly split when their combined size exceeds maxSize.
# Expected: Both messages should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()

let messageSize = gossip1.maxMessageSize div 2 + 1
let (iwantMessageIds, sentMessages) = createMessages(gossip0, gossip1, messageSize, messageSize)

gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
ihave: @[ControlIHave(topicId: "foobar", messageIds: iwantMessageIds)]
))))

checkExpiring: receivedMessages[] == sentMessages
check receivedMessages[].len == 2

await teardownTest(gossip0, gossip1)

asyncTest "e2e - Discard IWANT replies when both messages individually exceed maxSize":
# This test checks if two messages, each exceeding the maxSize, are discarded and not sent.
# Expected: No messages should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()

let messageSize = gossip1.maxMessageSize + 10
let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, messageSize, messageSize)

gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)]
))))

await sleepAsync(300.milliseconds)
checkExpiring: receivedMessages[].len == 0

await teardownTest(gossip0, gossip1)

asyncTest "e2e - Process IWANT replies when both messages are below maxSize":
# This test checks if two messages, both below the maxSize, are correctly processed and sent.
# Expected: Both messages should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()
let size1 = gossip1.maxMessageSize div 2
let size2 = gossip1.maxMessageSize div 3
let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, size1, size2)

gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)]
))))

checkExpiring: receivedMessages[] == sentMessages
check receivedMessages[].len == 2

await teardownTest(gossip0, gossip1)

asyncTest "e2e - Split IWANT replies when one message is below maxSize and the other exceeds maxSize":
# This test checks if, when given two messages where one is below maxSize and the other exceeds it, only the smaller message is processed and sent.
# Expected: Only the smaller message should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()
let maxSize = gossip1.maxMessageSize
let size1 = maxSize div 2
let size2 = maxSize + 10
let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, size1, size2)

gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)]
))))

var smallestSet: HashSet[seq[byte]]
let seqs = toSeq(sentMessages)
if seqs[0] < seqs[1]:
smallestSet.incl(seqs[0])
else:
smallestSet.incl(seqs[1])

checkExpiring: receivedMessages[] == smallestSet
check receivedMessages[].len == 1

await teardownTest(gossip0, gossip1)
Loading

0 comments on commit 7587181

Please sign in to comment.