Skip to content

Commit

Permalink
Make server logic vector clock-aware
Browse files Browse the repository at this point in the history
  • Loading branch information
richardhuaaa committed Sep 10, 2024
1 parent b5409cf commit 9ab9f05
Show file tree
Hide file tree
Showing 22 changed files with 426 additions and 510 deletions.
18 changes: 9 additions & 9 deletions contracts/src/Nodes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ All nodes on the network periodically check this contract to determine which nod
contract Nodes is ERC721, Ownable {
constructor() ERC721("XMTP Node Operator", "XMTP") Ownable(msg.sender) {}

// uint16 counter so that we cannot create more than 65k IDs
// uint32 counter so that we cannot create more than max IDs
// The ERC721 standard expects the tokenID to be uint256 for standard methods unfortunately
uint16 private _nodeIdCounter;
uint32 private _nodeIdCounter;

// A node, as stored in the internal mapping
struct Node {
Expand All @@ -26,7 +26,7 @@ contract Nodes is ERC721, Ownable {
}

struct NodeWithId {
uint16 nodeId;
uint32 nodeId;
Node node;
}

Expand All @@ -42,8 +42,8 @@ contract Nodes is ERC721, Ownable {
address to,
bytes calldata signingKeyPub,
string calldata httpAddress
) public onlyOwner returns (uint16) {
uint16 nodeId = _nodeIdCounter;
) public onlyOwner returns (uint32) {
uint32 nodeId = _nodeIdCounter;
_mint(to, nodeId);
_nodes[nodeId] = Node(signingKeyPub, httpAddress, true);
_emitNodeUpdate(nodeId);
Expand Down Expand Up @@ -101,7 +101,7 @@ contract Nodes is ERC721, Ownable {
Get a list of healthy nodes with their ID and metadata
*/
function healthyNodes() public view returns (NodeWithId[] memory) {
uint16 totalNodeCount = _nodeIdCounter;
uint32 totalNodeCount = _nodeIdCounter;
uint256 healthyCount = 0;

// First, count the number of healthy nodes
Expand All @@ -116,7 +116,7 @@ contract Nodes is ERC721, Ownable {
uint256 currentIndex = 0;

// Populate the array with healthy nodes
for (uint16 i = 0; i < totalNodeCount; i++) {
for (uint32 i = 0; i < totalNodeCount; i++) {
if (_nodeExists(i) && _nodes[i].isHealthy) {
healthyNodesList[currentIndex] = NodeWithId({
nodeId: i,
Expand All @@ -133,10 +133,10 @@ contract Nodes is ERC721, Ownable {
Get all nodes regardless of their health status
*/
function allNodes() public view returns (NodeWithId[] memory) {
uint16 totalNodeCount = _nodeIdCounter;
uint32 totalNodeCount = _nodeIdCounter;
NodeWithId[] memory allNodesList = new NodeWithId[](totalNodeCount);

for (uint16 i = 0; i < totalNodeCount; i++) {
for (uint32 i = 0; i < totalNodeCount; i++) {
allNodesList[i] = NodeWithId({nodeId: i, node: _nodes[i]});
}

Expand Down
2 changes: 1 addition & 1 deletion dev/test
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -e

ulimit -n 2048

go test -timeout 3s `go list ./... | grep -v -e 'pkg/abis' -e 'pkg/config' -e 'pkg/proto' -e 'pkg/mock' -e 'pkg/testing'` "$@"
go test -timeout 10s `go list ./... | grep -v -e 'pkg/abis' -e 'pkg/config' -e 'pkg/proto' -e 'pkg/mock' -e 'pkg/testing'` "$@"

if [ -n "${RACE:-}" ]; then
echo
Expand Down
22 changes: 11 additions & 11 deletions pkg/abis/nodes.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/api/publishWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type PublishWorker struct {
notifier chan<- bool
registrant *registrant.Registrant
store *sql.DB
subscription db.DBSubscription[queries.StagedOriginatorEnvelope]
subscription db.DBSubscription[queries.StagedOriginatorEnvelope, int64]
}

func StartPublishWorker(
Expand Down
27 changes: 19 additions & 8 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
)

const (
maxRequestedRows int32 = 1000
maxRequestedRows uint32 = 1000
maxVectorClockLength int = 100
)

type Service struct {
Expand Down Expand Up @@ -94,9 +95,11 @@ func (s *Service) queryReqToDBParams(
req *message_api.QueryEnvelopesRequest,
) (*queries.SelectGatewayEnvelopesParams, error) {
params := queries.SelectGatewayEnvelopesParams{
Topic: []byte{},
OriginatorNodeID: sql.NullInt32{},
RowLimit: db.NullInt32(maxRequestedRows),
Topic: nil,
OriginatorNodeID: sql.NullInt32{},
RowLimit: sql.NullInt32{},
CursorNodeIds: nil,
CursorSequenceIds: nil,
}

query := req.GetQuery()
Expand All @@ -112,11 +115,19 @@ func (s *Service) queryReqToDBParams(
default:
}

// TODO(rich): Handle last_seen properly
vc := query.GetLastSeen().GetNodeIdToSequenceId()
if len(vc) > maxVectorClockLength {
return nil, status.Errorf(
codes.InvalidArgument,
"vector clock length exceeds maximum of %d",
maxVectorClockLength,
)
}
db.SetVectorClock(&params, vc)

limit := int32(req.GetLimit())
limit := req.GetLimit()
if limit > 0 && limit <= maxRequestedRows {
params.RowLimit = db.NullInt32(limit)
params.RowLimit = db.NullInt32(int32(limit))
}

return &params, nil
Expand Down Expand Up @@ -194,7 +205,7 @@ func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]b
return nil, status.Errorf(codes.InvalidArgument, "missing target topic")
}

// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Verify all originators have synced past `last_seen`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

Expand Down
10 changes: 2 additions & 8 deletions pkg/api/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ func TestMissingTopicOnPublish(t *testing.T) {
func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams {
db_rows := []queries.InsertGatewayEnvelopeParams{
{
// Auto-generated ID: 1
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: []byte("topicA"),
Expand All @@ -139,7 +138,6 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
),
},
{
// Auto-generated ID: 2
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: []byte("topicA"),
Expand All @@ -149,7 +147,6 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
),
},
{
// Auto-generated ID: 3
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: []byte("topicB"),
Expand All @@ -159,7 +156,6 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
),
},
{
// Auto-generated ID: 4
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: []byte("topicB"),
Expand All @@ -169,7 +165,6 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
),
},
{
// Auto-generated ID: 5
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: []byte("topicA"),
Expand Down Expand Up @@ -256,7 +251,6 @@ func TestQueryEnvelopesByTopic(t *testing.T) {
}

func TestQueryEnvelopesFromLastSeen(t *testing.T) {
t.Skip("Not implemented yet")
svc, db, cleanup := newTestService(t)
defer cleanup()
db_rows := setupQueryTest(t, db)
Expand All @@ -266,13 +260,13 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) {
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Filter: nil,
LastSeen: &message_api.VectorClock{},
LastSeen: &message_api.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 2}},
},
Limit: 0,
},
)
require.NoError(t, err)
checkRowsMatchProtos(t, db_rows, []int{}, resp.GetEnvelopes())
checkRowsMatchProtos(t, db_rows, []int{1, 3, 4}, resp.GetEnvelopes())
}

func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
Expand Down
55 changes: 29 additions & 26 deletions pkg/db/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import (
"go.uber.org/zap"
)

type PollableDBQuery[ValueType any] func(ctx context.Context, lastSeenID int64, numRows int32) (results []ValueType, lastID int64, err error)
type PollableDBQuery[ValueType any, CursorType any] func(
ctx context.Context,
lastSeen CursorType,
numRows int32,
) (results []ValueType, nextCursor CursorType, err error)

// Poll whenever notified, or at an interval if not notified
type PollingOptions struct {
Expand All @@ -17,33 +21,33 @@ type PollingOptions struct {
NumRows int32
}

type DBSubscription[ValueType any] struct {
ctx context.Context
log *zap.Logger
lastSeenID int64
options PollingOptions
query PollableDBQuery[ValueType]
updates chan<- []ValueType
type DBSubscription[ValueType any, CursorType any] struct {
ctx context.Context
log *zap.Logger
lastSeen CursorType
options PollingOptions
query PollableDBQuery[ValueType, CursorType]
updates chan<- []ValueType
}

func NewDBSubscription[ValueType any](
func NewDBSubscription[ValueType any, CursorType any](
ctx context.Context,
log *zap.Logger,
query PollableDBQuery[ValueType],
lastSeenID int64,
query PollableDBQuery[ValueType, CursorType],
lastSeen CursorType,
options PollingOptions,
) *DBSubscription[ValueType] {
return &DBSubscription[ValueType]{
ctx: ctx,
log: log,
lastSeenID: lastSeenID,
options: options,
query: query,
updates: nil,
) *DBSubscription[ValueType, CursorType] {
return &DBSubscription[ValueType, CursorType]{
ctx: ctx,
log: log,
lastSeen: lastSeen,
options: options,
query: query,
updates: nil,
}
}

func (s *DBSubscription[ValueType]) Start() (<-chan []ValueType, error) {
func (s *DBSubscription[ValueType, CursorType]) Start() (<-chan []ValueType, error) {
if s.updates != nil {
return nil, fmt.Errorf("Already started")
}
Expand Down Expand Up @@ -75,25 +79,24 @@ func (s *DBSubscription[ValueType]) Start() (<-chan []ValueType, error) {
return updates, nil
}

func (s *DBSubscription[ValueType]) poll() {
func (s *DBSubscription[ValueType, CursorType]) poll() {
// Repeatedly query page by page until no more results
for {
results, lastID, err := s.query(s.ctx, s.lastSeenID, s.options.NumRows)
results, lastID, err := s.query(s.ctx, s.lastSeen, s.options.NumRows)
if s.ctx.Err() != nil {
break
} else if err != nil {
s.log.Error(
"Error querying for DB subscription",
zap.Error(err),
zap.Int64("lastSeenID", s.lastSeenID),
zap.Any("lastSeen", s.lastSeen),
zap.Int32("numRows", s.options.NumRows),
)
// Did not update lastSeenID; will retry on next poll
// Did not update lastSeen; will retry on next poll
break
} else if len(results) == 0 {
break
}
s.lastSeenID = lastID
s.lastSeen = lastID
s.updates <- results
if int32(len(results)) < s.options.NumRows {
break
Expand Down
Loading

0 comments on commit 9ab9f05

Please sign in to comment.