feat(pubsub): add {.async: (raises).} annotations #1233

Merged · 2 commits · Jan 14, 2025
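
For context on what this PR introduces: with chronos's exception tracking, an {.async: (raises: [...]).} annotation declares exactly which exception types may escape an async proc, and the compiler rejects anything outside that list, so callers can rely on the stated failure set. A minimal illustrative sketch, not taken from this PR (proc names are made up):

import chronos

proc fetchValue(): Future[int] {.async: (raises: [ValueError, CancelledError]).} =
  # Only ValueError and CancelledError may escape this proc;
  # raising anything else is a compile-time error.
  if false:
    raise newException(ValueError, "illustrative failure")
  return 42

proc consume() {.async: (raises: []).} =
  # A raises: [] proc must handle every exception internally,
  # including cancellation.
  try:
    discard await fetchValue()
  except ValueError, CancelledError:
    discard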
12 changes: 9 additions & 3 deletions examples/tutorial_4_gossipsub.nim
@@ -75,7 +75,9 @@ proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
# This procedure will handle one of the node of the network
node.gossip.addValidator(
["metrics"],
proc(topic: string, message: Message): Future[ValidationResult] {.async.} =
proc(
topic: string, message: Message
): Future[ValidationResult] {.async: (raises: []).} =
let decoded = MetricList.decode(message.data)
if decoded.isErr:
return ValidationResult.Reject
@@ -92,8 +94,12 @@ proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
if node.hostname == "John":
node.gossip.subscribe(
"metrics",
proc(topic: string, data: seq[byte]) {.async.} =
echo MetricList.decode(data).tryGet()
proc(topic: string, data: seq[byte]) {.async: (raises: []).} =
let m = MetricList.decode(data)
if m.isErr:
raiseAssert "failed to decode metric"
else:
echo m
,
)
else:
2 changes: 1 addition & 1 deletion examples/tutorial_6_game.nim
@@ -194,7 +194,7 @@ proc networking(g: Game) {.async.} =

gossip.subscribe(
"/tron/matchmaking",
proc(topic: string, data: seq[byte]) {.async.} =
proc(topic: string, data: seq[byte]) {.async: (raises: []).} =
# If we are still looking for an opponent,
# try to match anyone broadcasting its address
if g.peerFound.finished or g.hasCandidate:
4 changes: 2 additions & 2 deletions libp2p/daemon/daemonapi.nim
@@ -162,7 +162,7 @@ type
.}
P2PPubSubCallback* = proc(
api: DaemonAPI, ticket: PubsubTicket, message: PubSubMessage
): Future[bool] {.gcsafe, raises: [CatchableError].}
): Future[bool] {.gcsafe, async: (raises: [CatchableError]).}
Contributor:

btw, to maintain "callback compatibility" in the public API, what we do is introduce a second overload, like so: https://github.com/status-im/nim-chronos/blob/7c5cbf04a6bb2258654f16cb1d51b4304986cfd1/chronos/transports/stream.nim#L121

then the overload that takes the old callback includes a "translation wrapper" and/or is deprecated

Member Author:

Thanks for the suggestion! Will implement this in a subsequent PR.
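
For illustration only, a rough sketch of the overload-plus-translation-wrapper pattern being suggested, as it might look inside daemonapi.nim — nothing here is part of the PR, and the P2PPubSubCallback2 name and wrapper body are assumptions:

type
  # New, exception-tracked callback type (hypothetical name).
  P2PPubSubCallback2* = proc(
    api: DaemonAPI, ticket: PubsubTicket, message: PubSubMessage
  ): Future[bool] {.async: (raises: [CatchableError]).}

proc pubsubSubscribe*(
    api: DaemonAPI, topic: string, handler: P2PPubSubCallback2
): Future[PubsubTicket] {.async: (raises: [CatchableError]).} =
  ## Primary overload taking the new callback type.
  # ... existing subscribe implementation ...
  discard

proc pubsubSubscribe*(
    api: DaemonAPI, topic: string, handler: P2PPubSubCallback
): Future[PubsubTicket] {.
    deprecated: "use the P2PPubSubCallback2 overload",
    async: (raises: [CatchableError])
.} =
  ## Compatibility overload: wrap the old callback and forward.
  proc wrapped(
      api: DaemonAPI, ticket: PubsubTicket, message: PubSubMessage
  ): Future[bool] {.async: (raises: [CatchableError]).} =
    await handler(api, ticket, message)

  await api.pubsubSubscribe(topic, wrapped)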


DaemonError* = object of LPError
DaemonRemoteError* = object of DaemonError
@@ -1296,7 +1296,7 @@ proc pubsubLoop(api: DaemonAPI, ticket: PubsubTicket) {.async.} =

proc pubsubSubscribe*(
api: DaemonAPI, topic: string, handler: P2PPubSubCallback
): Future[PubsubTicket] {.async.} =
): Future[PubsubTicket] {.async: (raises: [CatchableError]).} =
## Subscribe to topic ``topic``.
var transp = await api.newConnection()
try:
96 changes: 39 additions & 57 deletions libp2p/protobuf/minprotobuf.nim
@@ -16,8 +16,6 @@ export results, utility

{.push public.}

const MaxMessageSize = 1'u shl 22

type
ProtoFieldKind* = enum
## Protobuf's field types enum
@@ -39,7 +37,6 @@ type
buffer*: seq[byte]
offset*: int
length*: int
maxSize*: uint

ProtoHeader* = object
wire*: ProtoFieldKind
@@ -63,7 +60,6 @@ type
VarintDecode
MessageIncomplete
BufferOverflow
MessageTooBig
BadWireType
IncorrectBlob
RequiredFieldMissing
@@ -99,11 +95,14 @@ template getProtoHeader*(field: ProtoField): uint64 =
template toOpenArray*(pb: ProtoBuffer): untyped =
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)

template lenu64*(x: untyped): untyped =
uint64(len(x))

template isEmpty*(pb: ProtoBuffer): bool =
len(pb.buffer) - pb.offset <= 0

template isEnough*(pb: ProtoBuffer, length: int): bool =
len(pb.buffer) - pb.offset - length >= 0
template isEnough*(pb: ProtoBuffer, length: uint64): bool =
pb.offset <= len(pb.buffer) and length <= uint64(len(pb.buffer) - pb.offset)

template getPtr*(pb: ProtoBuffer): pointer =
cast[pointer](unsafeAddr pb.buffer[pb.offset])
@@ -127,33 +126,25 @@ proc vsizeof*(field: ProtoField): int {.inline.} =
0

proc initProtoBuffer*(
data: seq[byte], offset = 0, options: set[ProtoFlags] = {}, maxSize = MaxMessageSize
data: seq[byte], offset = 0, options: set[ProtoFlags] = {}
): ProtoBuffer =
## Initialize ProtoBuffer with shallow copy of ``data``.
result.buffer = data
result.offset = offset
result.options = options
result.maxSize = maxSize

proc initProtoBuffer*(
data: openArray[byte],
offset = 0,
options: set[ProtoFlags] = {},
maxSize = MaxMessageSize,
data: openArray[byte], offset = 0, options: set[ProtoFlags] = {}
): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data
result.offset = offset
result.options = options
result.maxSize = maxSize

proc initProtoBuffer*(
options: set[ProtoFlags] = {}, maxSize = MaxMessageSize
): ProtoBuffer =
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeq[byte]()
result.options = options
result.maxSize = maxSize
if WithVarintLength in options:
# Our buffer will start from position 10, so we can store length of buffer
# in [0, 9].
@@ -194,12 +185,12 @@ proc write*[T: ProtoScalar](pb: var ProtoBuffer, field: int, value: T) =
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u32 = cast[uint32](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u64 = cast[uint64](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
@@ -242,12 +233,12 @@ proc writePacked*[T: ProtoScalar](
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u32 = cast[uint32](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
doAssert(pb.isEnough(uint64(sizeof(T))))
let u64 = cast[uint64](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
@@ -268,7 +259,7 @@ proc write*[T: byte | char](pb: var ProtoBuffer, field: int, value: openArray[T]
doAssert(lres.isOk())
pb.offset += length
if len(value) > 0:
doAssert(pb.isEnough(len(value)))
doAssert(pb.isEnough(value.lenu64))
copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
pb.offset += len(value)

@@ -327,13 +318,13 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
else:
err(ProtoError.VarintDecode)
of ProtoFieldKind.Fixed32:
if data.isEnough(sizeof(uint32)):
if data.isEnough(uint64(sizeof(uint32))):
data.offset += sizeof(uint32)
ok()
else:
err(ProtoError.VarintDecode)
of ProtoFieldKind.Fixed64:
if data.isEnough(sizeof(uint64)):
if data.isEnough(uint64(sizeof(uint64))):
data.offset += sizeof(uint64)
ok()
else:
@@ -343,14 +334,11 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
var bsize = 0'u64
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)):
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageIncomplete)
if data.isEnough(bsize):
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageTooBig)
err(ProtoError.MessageIncomplete)
else:
err(ProtoError.VarintDecode)
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
@@ -382,15 +370,15 @@ proc getValue[T: ProtoScalar](
err(ProtoError.VarintDecode)
elif T is float32:
doAssert(header.wire == ProtoFieldKind.Fixed32)
if data.isEnough(sizeof(float32)):
if data.isEnough(uint64(sizeof(float32))):
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
data.offset += sizeof(float32)
ok()
else:
err(ProtoError.MessageIncomplete)
elif T is float64:
doAssert(header.wire == ProtoFieldKind.Fixed64)
if data.isEnough(sizeof(float64)):
if data.isEnough(uint64(sizeof(float64))):
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
data.offset += sizeof(float64)
ok()
@@ -410,22 +398,19 @@ proc getValue[T: byte | char](
outLength = 0
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)):
outLength = int(bsize)
if len(outBytes) >= int(bsize):
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
# Buffer overflow should not be critical failure
data.offset += int(bsize)
err(ProtoError.BufferOverflow)
if data.isEnough(bsize):
outLength = int(bsize)
if len(outBytes) >= int(bsize):
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageIncomplete)
# Buffer overflow should not be critical failure
data.offset += int(bsize)
err(ProtoError.BufferOverflow)
else:
err(ProtoError.MessageTooBig)
err(ProtoError.MessageIncomplete)
else:
err(ProtoError.VarintDecode)

@@ -439,17 +424,14 @@ proc getValue[T: seq[byte] | string](

if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)):
outBytes.setLen(bsize)
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageIncomplete)
if data.isEnough(bsize):
outBytes.setLen(bsize)
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ok()
else:
err(ProtoError.MessageTooBig)
err(ProtoError.MessageIncomplete)
else:
err(ProtoError.VarintDecode)

10 changes: 7 additions & 3 deletions libp2p/protocols/pubsub/floodsub.nim
@@ -101,10 +101,12 @@ method unsubscribePeer*(f: FloodSub, peer: PeerId) =

procCall PubSub(f).unsubscribePeer(peer)

method rpcHandler*(f: FloodSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
method rpcHandler*(
f: FloodSub, peer: PubSubPeer, data: seq[byte]
) {.async: (raises: [CancelledError, PeerMessageDecodeError, PeerRateLimitError]).} =
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
raise newException(CatchableError, "Peer msg couldn't be decoded")
raise newException(PeerMessageDecodeError, "Peer msg couldn't be decoded")

trace "decoded msg from peer", peer, payload = rpcMsg.shortLog
# trigger hooks
@@ -192,7 +194,9 @@ method init*(f: FloodSub) =
f.handler = handler
f.codec = FloodSubCodec

method publish*(f: FloodSub, topic: string, data: seq[byte]): Future[int] {.async.} =
method publish*(
f: FloodSub, topic: string, data: seq[byte]
): Future[int] {.async: (raises: [LPError]).} =
# base returns always 0
discard await procCall PubSub(f).publish(topic, data)

32 changes: 21 additions & 11 deletions libp2p/protocols/pubsub/gossipsub.nim
@@ -385,7 +385,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =

proc validateAndRelay(
g: GossipSub, msg: Message, msgId: MessageId, saltedId: SaltedId, peer: PubSubPeer
) {.async.} =
) {.async: (raises: []).} =
try:
template topic(): string =
msg.topic
@@ -511,7 +511,9 @@ proc messageOverhead(g: GossipSub, msg: RPCMsg, msgSize: int): int =

msgSize - payloadSize - controlSize

proc rateLimit*(g: GossipSub, peer: PubSubPeer, overhead: int) {.async.} =
proc rateLimit*(
g: GossipSub, peer: PubSubPeer, overhead: int
) {.async: (raises: [PeerRateLimitError]).} =
peer.overheadRateLimitOpt.withValue(overheadRateLimit):
if not overheadRateLimit.tryConsume(overhead):
libp2p_gossipsub_peers_rate_limit_hits.inc(labelValues = [peer.getAgent()])
@@ -524,7 +526,9 @@ proc rateLimit*(g: GossipSub, peer: PubSubPeer, overhead: int) {.async.} =
PeerRateLimitError, "Peer disconnected because it's above rate limit."
)

method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
method rpcHandler*(
g: GossipSub, peer: PubSubPeer, data: seq[byte]
) {.async: (raises: [CancelledError, PeerMessageDecodeError, PeerRateLimitError]).} =
let msgSize = data.len
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
@@ -534,7 +538,7 @@ method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
# TODO evaluate behaviour penalty values
peer.behaviourPenalty += 0.1

raise newException(CatchableError, "Peer msg couldn't be decoded")
raise newException(PeerMessageDecodeError, "Peer msg couldn't be decoded")

when defined(libp2p_expensive_metrics):
for m in rpcMsg.messages:
@@ -682,7 +686,9 @@ method onTopicSubscription*(g: GossipSub, topic: string, subscribed: bool) =
# Send unsubscribe (in reverse order to sub/graft)
procCall PubSub(g).onTopicSubscription(topic, subscribed)

method publish*(g: GossipSub, topic: string, data: seq[byte]): Future[int] {.async.} =
method publish*(
g: GossipSub, topic: string, data: seq[byte]
): Future[int] {.async: (raises: [LPError]).} =
logScope:
topic

@@ -792,7 +798,9 @@ method publish*(g: GossipSub, topic: string, data: seq[byte]): Future[int] {.async.} =
trace "Published message to peers", peers = peers.len
return peers.len

proc maintainDirectPeer(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} =
proc maintainDirectPeer(
g: GossipSub, id: PeerId, addrs: seq[MultiAddress]
) {.async: (raises: [CancelledError]).} =
if id notin g.peers:
trace "Attempting to dial a direct peer", peer = id
if g.switch.isConnected(id):
@@ -808,11 +816,13 @@ proc maintainDirectPeer(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} =
except CatchableError as exc:
debug "Direct peer error dialing", description = exc.msg

proc addDirectPeer*(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} =
proc addDirectPeer*(
g: GossipSub, id: PeerId, addrs: seq[MultiAddress]
) {.async: (raises: [CancelledError]).} =
g.parameters.directPeers[id] = addrs
await g.maintainDirectPeer(id, addrs)

proc maintainDirectPeers(g: GossipSub) {.async.} =
proc maintainDirectPeers(g: GossipSub) {.async: (raises: [CancelledError]).} =
heartbeat "GossipSub DirectPeers", 1.minutes:
for id, addrs in g.parameters.directPeers:
await g.addDirectPeer(id, addrs)
@@ -846,9 +856,9 @@ method stop*(g: GossipSub): Future[void] {.async: (raises: [], raw: true).} =
return fut

# stop heartbeat interval
g.directPeersLoop.cancel()
g.scoringHeartbeatFut.cancel()
g.heartbeatFut.cancel()
g.directPeersLoop.cancelSoon()
g.scoringHeartbeatFut.cancelSoon()
g.heartbeatFut.cancelSoon()
g.heartbeatFut = nil
fut
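
A side note on the cancel() → cancelSoon() change above: chronos provides cancelSoon() to request cancellation without waiting for the future to complete, and cancelAndWait() when the caller needs to await completion; the bare cancel() call has been deprecated in recent chronos in favour of these explicit variants, which is presumably why it is replaced here. A small illustrative sketch (hypothetical proc names):

import chronos

proc worker() {.async.} =
  await sleepAsync(10.seconds)

proc shutdown() {.async.} =
  let fut = worker()
  # Fire-and-forget: request cancellation and return immediately.
  fut.cancelSoon()
  # If completion must be observed, use this instead:
  # await fut.cancelAndWait()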
