Skip to content

Commit

Permalink
Update p2p proto definitons (#1255)
Browse files Browse the repository at this point in the history
and make each request use seperate pids
  • Loading branch information
omerfirmak authored Sep 26, 2023
1 parent ddc7f6a commit 5a6b72e
Show file tree
Hide file tree
Showing 23 changed files with 4,676 additions and 2,179 deletions.
115 changes: 31 additions & 84 deletions p2p/starknet/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ import (
type NewStreamFunc func(ctx context.Context, pids ...protocol.ID) (network.Stream, error)

type Client struct {
newStream NewStreamFunc
protocolID protocol.ID
log utils.Logger
newStream NewStreamFunc
network utils.Network
log utils.Logger
}

func NewClient(newStream NewStreamFunc, protocolID protocol.ID, log utils.Logger) *Client {
func NewClient(newStream NewStreamFunc, snNetwork utils.Network, log utils.Logger) *Client {
return &Client{
newStream: newStream,
protocolID: protocolID,
log: log,
newStream: newStream,
network: snNetwork,
log: log,
}
}

func (c *Client) sendAndCloseWrite(stream network.Stream, req proto.Message) error {
func sendAndCloseWrite(stream network.Stream, req proto.Message) error {
reqBytes, err := proto.Marshal(req)
if err != nil {
return err
Expand All @@ -39,101 +39,48 @@ func (c *Client) sendAndCloseWrite(stream network.Stream, req proto.Message) err
return stream.CloseWrite()
}

func (c *Client) receiveInto(stream network.Stream, res proto.Message) error {
func receiveInto(stream network.Stream, res proto.Message) error {
return protodelim.UnmarshalFrom(&byteReader{stream}, res)
}

func (c *Client) sendAndReceiveInto(ctx context.Context, req, res proto.Message) error {
stream, err := c.newStream(ctx, c.protocolID)
if err != nil {
return err
}
defer stream.Close() // todo: dont ignore close errors

if err = c.sendAndCloseWrite(stream, req); err != nil {
return err
}

return c.receiveInto(stream, res)
}

func (c *Client) GetBlocks(ctx context.Context, req *spec.GetBlocks) (Stream[*spec.BlockHeader], error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetBlocks{
GetBlocks: req,
},
}

stream, err := c.newStream(ctx, c.protocolID)
func requestAndReceiveStream[ReqT proto.Message, ResT proto.Message](ctx context.Context,
newStream NewStreamFunc, protocolID protocol.ID, req ReqT,
) (Stream[ResT], error) {
stream, err := newStream(ctx, protocolID)
if err != nil {
return nil, err
}
if err := c.sendAndCloseWrite(stream, &wrappedReq); err != nil {
if err := sendAndCloseWrite(stream, req); err != nil {
return nil, err
}

return func() (*spec.BlockHeader, bool) {
var res spec.BlockHeader
if err := c.receiveInto(stream, &res); err != nil {
return func() (ResT, bool) {
var zero ResT
res := zero.ProtoReflect().New().Interface()
if err := receiveInto(stream, res); err != nil {
stream.Close() // todo: dont ignore close errors
return nil, false
return zero, false
}
return &res, true
return res.(ResT), true
}, nil
}

func (c *Client) GetSignatures(ctx context.Context, req *spec.GetSignatures) (*spec.Signatures, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetSignatures{
GetSignatures: req,
},
}

var res spec.Signatures
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestBlockHeaders(ctx context.Context, req *spec.BlockHeadersRequest) (Stream[*spec.BlockHeadersResponse], error) {
return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse](ctx, c.newStream, BlockHeadersPID(c.network), req)
}

func (c *Client) GetEvents(ctx context.Context, req *spec.GetEvents) (*spec.Events, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetEvents{
GetEvents: req,
},
}

var res spec.Events
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestBlockBodies(ctx context.Context, req *spec.BlockBodiesRequest) (Stream[*spec.BlockBodiesResponse], error) {
return requestAndReceiveStream[*spec.BlockBodiesRequest, *spec.BlockBodiesResponse](ctx, c.newStream, BlockBodiesPID(c.network), req)
}

func (c *Client) GetReceipts(ctx context.Context, req *spec.GetReceipts) (*spec.Receipts, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetReceipts{
GetReceipts: req,
},
}

var res spec.Receipts
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (Stream[*spec.EventsResponse], error) {
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network), req)
}

func (c *Client) GetTransactions(ctx context.Context, req *spec.GetTransactions) (*spec.Transactions, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetTransactions{
GetTransactions: req,
},
}
func (c *Client) RequestReceipts(ctx context.Context, req *spec.ReceiptsRequest) (Stream[*spec.ReceiptsResponse], error) {
return requestAndReceiveStream[*spec.ReceiptsRequest, *spec.ReceiptsResponse](ctx, c.newStream, ReceiptsPID(c.network), req)
}

var res spec.Transactions
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (Stream[*spec.TransactionsResponse], error) {
return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse](ctx, c.newStream, TransactionsPID(c.network), req)
}
173 changes: 90 additions & 83 deletions p2p/starknet/handlers.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
//go:generate protoc --go_out=./ --proto_path=./ --go_opt=Mp2p/proto/requests.proto=./spec --go_opt=Mp2p/proto/transaction.proto=./spec --go_opt=Mp2p/proto/state.proto=./spec --go_opt=Mp2p/proto/snapshot.proto=./spec --go_opt=Mp2p/proto/receipt.proto=./spec --go_opt=Mp2p/proto/mempool.proto=./spec --go_opt=Mp2p/proto/event.proto=./spec --go_opt=Mp2p/proto/block.proto=./spec --go_opt=Mp2p/proto/common.proto=./spec p2p/proto/transaction.proto p2p/proto/state.proto p2p/proto/snapshot.proto p2p/proto/common.proto p2p/proto/block.proto p2p/proto/event.proto p2p/proto/receipt.proto p2p/proto/requests.proto
//go:generate protoc --go_out=./ --proto_path=./ --go_opt=Mp2p/proto/transaction.proto=./spec --go_opt=Mp2p/proto/state.proto=./spec --go_opt=Mp2p/proto/snapshot.proto=./spec --go_opt=Mp2p/proto/receipt.proto=./spec --go_opt=Mp2p/proto/mempool.proto=./spec --go_opt=Mp2p/proto/event.proto=./spec --go_opt=Mp2p/proto/block.proto=./spec --go_opt=Mp2p/proto/common.proto=./spec p2p/proto/transaction.proto p2p/proto/state.proto p2p/proto/snapshot.proto p2p/proto/common.proto p2p/proto/block.proto p2p/proto/event.proto p2p/proto/receipt.proto
package starknet

import (
"bytes"
"errors"
"fmt"
"sync"

"github.com/NethermindEth/juno/adapters/core2p2p"
"github.com/NethermindEth/juno/adapters/p2p2core"
"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/p2p/starknet/spec"
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/network"
Expand Down Expand Up @@ -43,133 +38,145 @@ func getBuffer() *bytes.Buffer {
return buffer
}

func (h *Handler) StreamHandler(stream network.Stream) {
func streamHandler[ReqT proto.Message](stream network.Stream,
reqHandler func(req ReqT) (Stream[proto.Message], error), log utils.SimpleLogger,
) {
defer func() {
if err := stream.Close(); err != nil {
h.log.Debugw("Error closing stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
log.Debugw("Error closing stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
}
}()

buffer := getBuffer()
defer bufferPool.Put(buffer)

if _, err := buffer.ReadFrom(stream); err != nil {
h.log.Debugw("Error reading from stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
log.Debugw("Error reading from stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
return
}

var req spec.Request
if err := proto.Unmarshal(buffer.Bytes(), &req); err != nil {
h.log.Debugw("Error unmarshalling message", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
var zero ReqT
req := zero.ProtoReflect().New().Interface()
if err := proto.Unmarshal(buffer.Bytes(), req); err != nil {
log.Debugw("Error unmarshalling message", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
return
}

response, err := h.reqHandler(&req)
response, err := reqHandler(req.(ReqT))
if err != nil {
h.log.Debugw("Error handling request", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err, "request", req.String())
log.Debugw("Error handling request", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
return
}

for msg, valid := response(); valid; msg, valid = response() {
if _, err := protodelim.MarshalTo(stream, msg); err != nil { // todo: figure out if we need buffered io here
h.log.Debugw("Error writing response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
log.Debugw("Error writing response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
}
}
}

func (h *Handler) reqHandler(req *spec.Request) (Stream[proto.Message], error) {
var singleResponse proto.Message
var err error
switch typedReq := req.GetReq().(type) {
case *spec.Request_GetBlocks:
return h.HandleGetBlocks(typedReq.GetBlocks)
case *spec.Request_GetSignatures:
singleResponse, err = h.HandleGetSignatures(typedReq.GetSignatures)
case *spec.Request_GetEvents:
singleResponse, err = h.HandleGetEvents(typedReq.GetEvents)
case *spec.Request_GetReceipts:
singleResponse, err = h.HandleGetReceipts(typedReq.GetReceipts)
case *spec.Request_GetTransactions:
singleResponse, err = h.HandleGetTransactions(typedReq.GetTransactions)
default:
return nil, fmt.Errorf("unhandled request %T", typedReq)
}
func (h *Handler) BlockHeadersHandler(stream network.Stream) {
streamHandler[*spec.BlockHeadersRequest](stream, h.onBlockHeadersRequest, h.log)
}

if err != nil {
return nil, err
}
return StaticStream[proto.Message](singleResponse), nil
func (h *Handler) BlockBodiesHandler(stream network.Stream) {
streamHandler[*spec.BlockBodiesRequest](stream, h.onBlockBodiesRequest, h.log)
}

func (h *Handler) EventsHandler(stream network.Stream) {
streamHandler[*spec.EventsRequest](stream, h.onEventsRequest, h.log)
}

func (h *Handler) HandleGetBlocks(req *spec.GetBlocks) (Stream[proto.Message], error) {
func (h *Handler) ReceiptsHandler(stream network.Stream) {
streamHandler[*spec.ReceiptsRequest](stream, h.onReceiptsRequest, h.log)
}

func (h *Handler) TransactionsHandler(stream network.Stream) {
streamHandler[*spec.TransactionsRequest](stream, h.onTransactionsRequest, h.log)
}

func (h *Handler) onBlockHeadersRequest(req *spec.BlockHeadersRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
count := uint32(0)
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.BlockHeader{
State: &spec.Merkle{
NLeaves: count - 1,
return &spec.BlockHeadersResponse{
Part: []*spec.BlockHeadersResponsePart{
{
HeaderMessage: &spec.BlockHeadersResponsePart_Header{
Header: &spec.BlockHeader{
Number: count - 1,
},
},
},
},
}, true
}, nil
}

func (h *Handler) HandleGetSignatures(req *spec.GetSignatures) (*spec.Signatures, error) {
func (h *Handler) onBlockBodiesRequest(req *spec.BlockBodiesRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
return &spec.Signatures{
Id: req.Id,
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.BlockBodiesResponse{
Id: &spec.BlockID{
Number: count - 1,
},
}, true
}, nil
}

func (h *Handler) HandleGetEvents(req *spec.GetEvents) (*spec.Events, error) {
block, err := h.blockByID(req.Id)
if err != nil {
return nil, err
}

var result spec.Events
for _, receipt := range block.Receipts {
for _, ev := range receipt.Events {
event := &spec.Event{
FromAddress: core2p2p.AdaptFelt(ev.From),
Keys: utils.Map(ev.Keys, core2p2p.AdaptFelt),
Data: utils.Map(ev.Data, core2p2p.AdaptFelt),
}

result.Events = append(result.Events, event)
func (h *Handler) onEventsRequest(req *spec.EventsRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
}

return &result, nil
count++
return &spec.EventsResponse{
Id: &spec.BlockID{
Number: count - 1,
},
}, true
}, nil
}

func (h *Handler) HandleGetReceipts(req *spec.GetReceipts) (*spec.Receipts, error) {
func (h *Handler) onReceiptsRequest(req *spec.ReceiptsRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
magic := 37
return &spec.Receipts{
Receipts: make([]*spec.Receipt, magic),
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.ReceiptsResponse{
Id: &spec.BlockID{
Number: count - 1,
},
}, true
}, nil
}

func (h *Handler) HandleGetTransactions(req *spec.GetTransactions) (*spec.Transactions, error) {
func (h *Handler) onTransactionsRequest(req *spec.TransactionsRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
magic := 1337
return &spec.Transactions{
Transactions: make([]*spec.Transaction, magic),
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.TransactionsResponse{
Id: &spec.BlockID{
Number: count - 1,
},
}, true
}, nil
}

func (h *Handler) blockByID(id *spec.BlockID) (*core.Block, error) {
switch {
case id == nil:
return nil, errors.New("block id is nil")
case id.Hash != nil:
hash := p2p2core.AdaptHash(id.Hash)
return h.bcReader.BlockByHash(hash)
default:
return h.bcReader.BlockByNumber(id.Height)
}
}
Loading

0 comments on commit 5a6b72e

Please sign in to comment.