Skip to content

Commit

Permalink
feat(pubsub): add {.async: (raises).} annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-ramos committed Jan 11, 2025
1 parent 1fa30f0 commit e1b48e5
Show file tree
Hide file tree
Showing 19 changed files with 312 additions and 249 deletions.
12 changes: 9 additions & 3 deletions examples/tutorial_4_gossipsub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
# This procedure will handle one of the node of the network
node.gossip.addValidator(
["metrics"],
proc(topic: string, message: Message): Future[ValidationResult] {.async.} =
proc(
topic: string, message: Message
): Future[ValidationResult] {.async: (raises: []).} =
let decoded = MetricList.decode(message.data)
if decoded.isErr:
return ValidationResult.Reject
Expand All @@ -92,8 +94,12 @@ proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
if node.hostname == "John":
node.gossip.subscribe(
"metrics",
proc(topic: string, data: seq[byte]) {.async.} =
echo MetricList.decode(data).tryGet()
proc(topic: string, data: seq[byte]) {.async: (raises: []).} =
let m = MetricList.decode(data)
if m.isErr:
raiseAssert "failed to decode metric"
else:
echo m
,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorial_6_game.nim
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ proc networking(g: Game) {.async.} =

gossip.subscribe(
"/tron/matchmaking",
proc(topic: string, data: seq[byte]) {.async.} =
proc(topic: string, data: seq[byte]) {.async: (raises: []).} =
# If we are still looking for an opponent,
# try to match anyone broadcasting its address
if g.peerFound.finished or g.hasCandidate:
Expand Down
4 changes: 2 additions & 2 deletions libp2p/daemon/daemonapi.nim
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ type
.}
P2PPubSubCallback* = proc(
api: DaemonAPI, ticket: PubsubTicket, message: PubSubMessage
): Future[bool] {.gcsafe, raises: [CatchableError].}
): Future[bool] {.gcsafe, async: (raises: [CatchableError]).}

DaemonError* = object of LPError
DaemonRemoteError* = object of DaemonError
Expand Down Expand Up @@ -1296,7 +1296,7 @@ proc pubsubLoop(api: DaemonAPI, ticket: PubsubTicket) {.async.} =

proc pubsubSubscribe*(
api: DaemonAPI, topic: string, handler: P2PPubSubCallback
): Future[PubsubTicket] {.async.} =
): Future[PubsubTicket] {.async: (raises: [CatchableError]).} =
## Subscribe to topic ``topic``.
var transp = await api.newConnection()
try:
Expand Down
96 changes: 39 additions & 57 deletions libp2p/protobuf/minprotobuf.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ export results, utility

{.push public.}

const MaxMessageSize = 1'u shl 22

type
ProtoFieldKind* = enum
## Protobuf's field types enum
Expand All @@ -39,7 +37,6 @@ type
buffer*: seq[byte]
offset*: int
length*: int
maxSize*: uint

ProtoHeader* = object
wire*: ProtoFieldKind
Expand All @@ -63,7 +60,6 @@ type
VarintDecode
MessageIncomplete
BufferOverflow
MessageTooBig
BadWireType
IncorrectBlob
RequiredFieldMissing
Expand Down Expand Up @@ -99,11 +95,14 @@ template getProtoHeader*(field: ProtoField): uint64 =
template toOpenArray*(pb: ProtoBuffer): untyped =
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)

template lenu64*(x: untyped): untyped =
uint64(len(x))

template isEmpty*(pb: ProtoBuffer): bool =
len(pb.buffer) - pb.offset <= 0

template isEnough*(pb: ProtoBuffer, length: int): bool =
len(pb.buffer) - pb.offset - length >= 0
template isEnough*(pb: ProtoBuffer, length: uint64): bool =
pb.offset <= len(pb.buffer) and length <= uint64(len(pb.buffer) - pb.offset)

template getPtr*(pb: ProtoBuffer): pointer =
cast[pointer](unsafeAddr pb.buffer[pb.offset])
Expand All @@ -127,33 +126,25 @@ proc vsizeof*(field: ProtoField): int {.inline.} =
0

proc initProtoBuffer*(
data: seq[byte], offset = 0, options: set[ProtoFlags] = {}, maxSize = MaxMessageSize
data: seq[byte], offset = 0, options: set[ProtoFlags] = {}
): ProtoBuffer =
## Initialize ProtoBuffer with shallow copy of ``data``.
result.buffer = data
result.offset = offset
result.options = options
result.maxSize = maxSize

proc initProtoBuffer*(
data: openArray[byte],
offset = 0,
options: set[ProtoFlags] = {},
maxSize = MaxMessageSize,
data: openArray[byte], offset = 0, options: set[ProtoFlags] = {}
): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data
result.offset = offset
result.options = options
result.maxSize = maxSize

proc initProtoBuffer*(
options: set[ProtoFlags] = {}, maxSize = MaxMessageSize
): ProtoBuffer =
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeq[byte]()
result.options = options
result.maxSize = maxSize
if WithVarintLength in options:
# Our buffer will start from position 10, so we can store length of buffer
# in [0, 9].
Expand Down Expand Up @@ -194,12 +185,12 @@ proc write*[T: ProtoScalar](pb: var ProtoBuffer, field: int, value: T) =
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u32 = cast[uint32](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u64 = cast[uint64](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
Expand Down Expand Up @@ -242,12 +233,12 @@ proc writePacked*[T: ProtoScalar](
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u32 = cast[uint32](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u64 = cast[uint64](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
Expand All @@ -268,7 +259,7 @@ proc write*[T: byte | char](pb: var ProtoBuffer, field: int, value: openArray[T]
doAssert(lres.isOk())
pb.offset += length
if len(value) > 0:
doAssert(pb.isEnough(len(value)))
doAssert(pb.isEnough(value.lenu64))
copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
pb.offset += len(value)

Expand Down Expand Up @@ -327,13 +318,13 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
else:
err(ProtoError.VarintDecode)
of ProtoFieldKind.Fixed32:
if data.isEnough(sizeof(uint32)):
if data.isEnough(uint64(sizeof(uint32))):
data.offset += sizeof(uint32)
ok()
else:
err(ProtoError.VarintDecode)
of ProtoFieldKind.Fixed64:
if data.isEnough(sizeof(uint64)):
if data.isEnough(uint64(sizeof(uint64))):
data.offset += sizeof(uint64)
ok()
else:
Expand All @@ -343,14 +334,11 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
var bsize = 0'u64
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)):
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageIncomplete)
if data.isEnough(bsize):
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageTooBig)
err(ProtoError.MessageIncomplete)
else:
err(ProtoError.VarintDecode)
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
Expand Down Expand Up @@ -382,15 +370,15 @@ proc getValue[T: ProtoScalar](
err(ProtoError.VarintDecode)
elif T is float32:
doAssert(header.wire == ProtoFieldKind.Fixed32)
if data.isEnough(sizeof(float32)):
if data.isEnough(uint64(sizeof(float32))):
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
data.offset += sizeof(float32)
ok()
else:
err(ProtoError.MessageIncomplete)
elif T is float64:
doAssert(header.wire == ProtoFieldKind.Fixed64)
if data.isEnough(sizeof(float64)):
if data.isEnough(uint64(sizeof(float64))):
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
data.offset += sizeof(float64)
ok()
Expand All @@ -410,22 +398,19 @@ proc getValue[T: byte | char](
outLength = 0
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)):
outLength = int(bsize)
if len(outBytes) >= int(bsize):
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
# Buffer overflow should not be critical failure
data.offset += int(bsize)
err(ProtoError.BufferOverflow)
if data.isEnough(bsize):
outLength = int(bsize)
if len(outBytes) >= int(bsize):
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageIncomplete)
# Buffer overflow should not be critical failure
data.offset += int(bsize)
err(ProtoError.BufferOverflow)
else:
err(ProtoError.MessageTooBig)
err(ProtoError.MessageIncomplete)
else:
err(ProtoError.VarintDecode)

Expand All @@ -439,17 +424,14 @@ proc getValue[T: seq[byte] | string](

if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)):
outBytes.setLen(bsize)
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageIncomplete)
if data.isEnough(bsize):
outBytes.setLen(bsize)
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageTooBig)
err(ProtoError.MessageIncomplete)
else:
err(ProtoError.VarintDecode)

Expand Down
10 changes: 7 additions & 3 deletions libp2p/protocols/pubsub/floodsub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ method unsubscribePeer*(f: FloodSub, peer: PeerId) =

procCall PubSub(f).unsubscribePeer(peer)

method rpcHandler*(f: FloodSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
method rpcHandler*(
f: FloodSub, peer: PubSubPeer, data: seq[byte]
) {.async: (raises: [CancelledError, PeerMessageDecodeError, PeerRateLimitError]).} =
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
raise newException(CatchableError, "Peer msg couldn't be decoded")
raise newException(PeerMessageDecodeError, "Peer msg couldn't be decoded")

trace "decoded msg from peer", peer, payload = rpcMsg.shortLog
# trigger hooks
Expand Down Expand Up @@ -192,7 +194,9 @@ method init*(f: FloodSub) =
f.handler = handler
f.codec = FloodSubCodec

method publish*(f: FloodSub, topic: string, data: seq[byte]): Future[int] {.async.} =
method publish*(
f: FloodSub, topic: string, data: seq[byte]
): Future[int] {.async: (raises: [LPError]).} =
# base returns always 0
discard await procCall PubSub(f).publish(topic, data)

Expand Down
32 changes: 21 additions & 11 deletions libp2p/protocols/pubsub/gossipsub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =

proc validateAndRelay(
g: GossipSub, msg: Message, msgId: MessageId, saltedId: SaltedId, peer: PubSubPeer
) {.async.} =
) {.async: (raises: []).} =
try:
template topic(): string =
msg.topic
Expand Down Expand Up @@ -511,7 +511,9 @@ proc messageOverhead(g: GossipSub, msg: RPCMsg, msgSize: int): int =

msgSize - payloadSize - controlSize

proc rateLimit*(g: GossipSub, peer: PubSubPeer, overhead: int) {.async.} =
proc rateLimit*(
g: GossipSub, peer: PubSubPeer, overhead: int
) {.async: (raises: [PeerRateLimitError]).} =
peer.overheadRateLimitOpt.withValue(overheadRateLimit):
if not overheadRateLimit.tryConsume(overhead):
libp2p_gossipsub_peers_rate_limit_hits.inc(labelValues = [peer.getAgent()])
Expand All @@ -524,7 +526,9 @@ proc rateLimit*(g: GossipSub, peer: PubSubPeer, overhead: int) {.async.} =
PeerRateLimitError, "Peer disconnected because it's above rate limit."
)

method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
method rpcHandler*(
g: GossipSub, peer: PubSubPeer, data: seq[byte]
) {.async: (raises: [CancelledError, PeerMessageDecodeError, PeerRateLimitError]).} =
let msgSize = data.len
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
Expand All @@ -534,7 +538,7 @@ method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
# TODO evaluate behaviour penalty values
peer.behaviourPenalty += 0.1

raise newException(CatchableError, "Peer msg couldn't be decoded")
raise newException(PeerMessageDecodeError, "Peer msg couldn't be decoded")

when defined(libp2p_expensive_metrics):
for m in rpcMsg.messages:
Expand Down Expand Up @@ -682,7 +686,9 @@ method onTopicSubscription*(g: GossipSub, topic: string, subscribed: bool) =
# Send unsubscribe (in reverse order to sub/graft)
procCall PubSub(g).onTopicSubscription(topic, subscribed)

method publish*(g: GossipSub, topic: string, data: seq[byte]): Future[int] {.async.} =
method publish*(
g: GossipSub, topic: string, data: seq[byte]
): Future[int] {.async: (raises: [LPError]).} =
logScope:
topic

Expand Down Expand Up @@ -792,7 +798,9 @@ method publish*(g: GossipSub, topic: string, data: seq[byte]): Future[int] {.asy
trace "Published message to peers", peers = peers.len
return peers.len

proc maintainDirectPeer(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} =
proc maintainDirectPeer(
g: GossipSub, id: PeerId, addrs: seq[MultiAddress]
) {.async: (raises: [CancelledError]).} =
if id notin g.peers:
trace "Attempting to dial a direct peer", peer = id
if g.switch.isConnected(id):
Expand All @@ -808,11 +816,13 @@ proc maintainDirectPeer(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.as
except CatchableError as exc:
debug "Direct peer error dialing", description = exc.msg

proc addDirectPeer*(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} =
proc addDirectPeer*(
g: GossipSub, id: PeerId, addrs: seq[MultiAddress]
) {.async: (raises: [CancelledError]).} =
g.parameters.directPeers[id] = addrs
await g.maintainDirectPeer(id, addrs)

proc maintainDirectPeers(g: GossipSub) {.async.} =
proc maintainDirectPeers(g: GossipSub) {.async: (raises: [CancelledError]).} =
heartbeat "GossipSub DirectPeers", 1.minutes:
for id, addrs in g.parameters.directPeers:
await g.addDirectPeer(id, addrs)
Expand Down Expand Up @@ -846,9 +856,9 @@ method stop*(g: GossipSub): Future[void] {.async: (raises: [], raw: true).} =
return fut

# stop heartbeat interval
g.directPeersLoop.cancel()
g.scoringHeartbeatFut.cancel()
g.heartbeatFut.cancel()
g.directPeersLoop.cancelSoon()
g.scoringHeartbeatFut.cancelSoon()
g.heartbeatFut.cancelSoon()
g.heartbeatFut = nil
fut

Expand Down
Loading

0 comments on commit e1b48e5

Please sign in to comment.