Skip to content

Commit

Permalink
V2 Retrieval Isolation
Browse files Browse the repository at this point in the history
- Updates operator socket format to host:v1DispersalPort;v1RetrievalPort;v2DispersalPort;V2RetrievalPort
- Refactors parseOperatorSocket to handle v2 retrieval socket with strict port validation
- Adds ValidatePort() to core utils
- Adds stricter validation for operator sockets. Valid formats:
    - host:v1DispersalPort;v1RetrievalPort
    - host:v1DispersalPort;v1RetrievalPort;v2DispersalPort;V2RetrievalPort
- Adds host validation supporting both FQDN and IP addr
- Registers node for the new V2 retrieval service
  • Loading branch information
supriya-premkumar committed Jan 29, 2025
1 parent fefad3a commit 85c70a9
Show file tree
Hide file tree
Showing 24 changed files with 265 additions and 88 deletions.
4 changes: 2 additions & 2 deletions api/clients/node_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (c client) GetBlobHeader(
blobIndex uint32,
) (*core.BlobHeader, *merkletree.Proof, error) {
conn, err := grpc.NewClient(
core.OperatorSocket(socket).GetRetrievalSocket(),
core.OperatorSocket(socket).GetV1RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
Expand Down Expand Up @@ -86,7 +86,7 @@ func (c client) GetChunks(
chunksChan chan RetrievedChunks,
) {
conn, err := grpc.NewClient(
core.OperatorSocket(opInfo.Socket).GetRetrievalSocket(),
core.OperatorSocket(opInfo.Socket).GetV1RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion api/clients/v2/retrieval_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ func (r *retrievalClient) getChunksFromOperator(
chunksChan chan clients.RetrievedChunks,
) {
conn, err := grpc.NewClient(
core.OperatorSocket(opInfo.Socket).GetRetrievalSocket(),
//TODO: Verify if this should point to V2RetrievalSocket
core.OperatorSocket(opInfo.Socket).GetV2RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
defer func() {
Expand Down
5 changes: 4 additions & 1 deletion core/mock/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type PrivateOperatorInfo struct {
DispersalPort string
RetrievalPort string
V2DispersalPort string
V2RetrievalPort string
}

type PrivateOperatorState struct {
Expand Down Expand Up @@ -140,7 +141,8 @@ func (d *ChainDataMock) GetTotalOperatorStateWithQuorums(ctx context.Context, bl
dispersalPort := fmt.Sprintf("3%03v", 2*i)
retrievalPort := fmt.Sprintf("3%03v", 2*i+1)
v2DispersalPort := fmt.Sprintf("3%03v", 2*i+2)
socket := core.MakeOperatorSocket(host, dispersalPort, retrievalPort, v2DispersalPort)
v2RetrievalPort := fmt.Sprintf("3%03v", 2*i+3)
socket := core.MakeOperatorSocket(host, dispersalPort, retrievalPort, v2DispersalPort, v2RetrievalPort)

indexed := &core.IndexedOperatorInfo{
Socket: string(socket),
Expand All @@ -161,6 +163,7 @@ func (d *ChainDataMock) GetTotalOperatorStateWithQuorums(ctx context.Context, bl
DispersalPort: dispersalPort,
RetrievalPort: retrievalPort,
V2DispersalPort: v2DispersalPort,
V2RetrievalPort: v2RetrievalPort,
}

indexedOperators[id] = indexed
Expand Down
18 changes: 13 additions & 5 deletions core/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,25 +528,33 @@ func decode(data []byte, obj any) error {
}

func (s OperatorSocket) GetV1DispersalSocket() string {
ip, v1DispersalPort, _, _, err := ParseOperatorSocket(string(s))
ip, v1DispersalPort, _, _, _, err := ParseOperatorSocket(string(s))
if err != nil {
return ""
}
return fmt.Sprintf("%s:%s", ip, v1DispersalPort)
}

func (s OperatorSocket) GetV2DispersalSocket() string {
ip, _, _, v2DispersalPort, err := ParseOperatorSocket(string(s))
ip, _, _, v2DispersalPort, _, err := ParseOperatorSocket(string(s))
if err != nil || v2DispersalPort == "" {
return ""
}
return fmt.Sprintf("%s:%s", ip, v2DispersalPort)
}

func (s OperatorSocket) GetRetrievalSocket() string {
ip, _, retrievalPort, _, err := ParseOperatorSocket(string(s))
func (s OperatorSocket) GetV1RetrievalSocket() string {
ip, _, v1retrievalPort, _, _, err := ParseOperatorSocket(string(s))
if err != nil {
return ""
}
return fmt.Sprintf("%s:%s", ip, retrievalPort)
return fmt.Sprintf("%s:%s", ip, v1retrievalPort)
}

func (s OperatorSocket) GetV2RetrievalSocket() string {
ip, _, _, _, v2RetrievalPort, err := ParseOperatorSocket(string(s))
if err != nil || v2RetrievalPort == "" {
return ""
}
return fmt.Sprintf("%s:%s", ip, v2RetrievalPort)
}
101 changes: 79 additions & 22 deletions core/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,36 +195,37 @@ func TestHashPubKeyG1(t *testing.T) {
}

func TestParseOperatorSocket(t *testing.T) {
operatorSocket := "localhost:1234;5678;9999"
host, dispersalPort, retrievalPort, v2DispersalPort, err := core.ParseOperatorSocket(operatorSocket)
operatorSocket := "localhost:1234;5678;9999;10001"
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err := core.ParseOperatorSocket(operatorSocket)
assert.NoError(t, err)
assert.Equal(t, "localhost", host)
assert.Equal(t, "1234", dispersalPort)
assert.Equal(t, "5678", retrievalPort)
assert.Equal(t, "1234", v1DispersalPort)
assert.Equal(t, "5678", v1RetrievalPort)
assert.Equal(t, "9999", v2DispersalPort)
assert.Equal(t, "10001", v2RetrievalPort)

host, dispersalPort, retrievalPort, v2DispersalPort, err = core.ParseOperatorSocket("localhost:1234;5678")
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, _, err = core.ParseOperatorSocket("localhost:1234;5678")
assert.NoError(t, err)
assert.Equal(t, "localhost", host)
assert.Equal(t, "1234", dispersalPort)
assert.Equal(t, "5678", retrievalPort)
assert.Equal(t, "1234", v1DispersalPort)
assert.Equal(t, "5678", v1RetrievalPort)
assert.Equal(t, "", v2DispersalPort)

_, _, _, _, err = core.ParseOperatorSocket("localhost;1234;5678")
_, _, _, _, _, err = core.ParseOperatorSocket("localhost;1234;5678")
assert.NotNil(t, err)
assert.ErrorContains(t, err, "invalid socket address format")
assert.ErrorContains(t, err, "invalid host address format")

_, _, _, _, err = core.ParseOperatorSocket("localhost:12345678")
_, _, _, _, _, err = core.ParseOperatorSocket("localhost:12345678")
assert.NotNil(t, err)
assert.ErrorContains(t, err, "invalid socket address format")
assert.ErrorContains(t, err, "invalid v1 dispersal port format")

_, _, _, _, err = core.ParseOperatorSocket("localhost1234;5678")
_, _, _, _, _, err = core.ParseOperatorSocket("localhost1234;5678")
assert.NotNil(t, err)
assert.ErrorContains(t, err, "invalid socket address format")
assert.ErrorContains(t, err, "invalid host address format")
}

func TestGetV1DispersalSocket(t *testing.T) {
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999")
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999;1025")
socket := operatorSocket.GetV1DispersalSocket()
assert.Equal(t, "localhost:1234", socket)

Expand All @@ -234,28 +235,84 @@ func TestGetV1DispersalSocket(t *testing.T) {

operatorSocket = core.OperatorSocket("localhost:1234;5678;")
socket = operatorSocket.GetV1DispersalSocket()
assert.Equal(t, "localhost:1234", socket)
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234")
socket = operatorSocket.GetV1DispersalSocket()
assert.Equal(t, "", socket)
}

func TestGetRetrievalSocket(t *testing.T) {
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999")
socket := operatorSocket.GetRetrievalSocket()
func TestGetV1RetrievalSocket(t *testing.T) {
// Valid v1/v2 socket
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999;10001")
socket := operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "localhost:5678", socket)

// Valid v1 socket
operatorSocket = core.OperatorSocket("localhost:1234;5678")
socket = operatorSocket.GetRetrievalSocket()
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "localhost:5678", socket)

// Invalid socket testcases
operatorSocket = core.OperatorSocket("localhost:1234;5678;9999;10001;")
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;5678;")
socket = operatorSocket.GetRetrievalSocket()
assert.Equal(t, "localhost:5678", socket)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;1234;5678;")
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;:;5678;")
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;;;")
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234")
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)
}

func TestGetV2RetrievalSocket(t *testing.T) {
// Valid v1/v2 socket
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999;10001")
socket := operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "localhost:10001", socket)

// Invalid v2 socket
operatorSocket = core.OperatorSocket("localhost:1234;5678")
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

// Invalid socket testcases
operatorSocket = core.OperatorSocket("localhost:1234;5678;9999;10001;")
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;5678;")
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;1234;5678;")
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;:;5678;")
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;;;")
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234")
socket = operatorSocket.GetRetrievalSocket()
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)
}

Expand Down
71 changes: 45 additions & 26 deletions core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math/big"
"net"
"slices"
"strings"
)
Expand All @@ -19,48 +20,66 @@ func (s OperatorSocket) String() string {
return string(s)
}

func MakeOperatorSocket(nodeIP, dispersalPort, retrievalPort, v2DispersalPort string) OperatorSocket {
if v2DispersalPort == "" {
func MakeOperatorSocket(nodeIP, dispersalPort, retrievalPort, v2DispersalPort, v2RetrievalPort string) OperatorSocket {
//TODO: Add config checks for invalid v1/v2 configs -- for v1 both v2 ports must be empty and for v2 both ports must be valid, reject any other combinations
if v2DispersalPort == "" && v2RetrievalPort == "" {
return OperatorSocket(fmt.Sprintf("%s:%s;%s", nodeIP, dispersalPort, retrievalPort))
}
return OperatorSocket(fmt.Sprintf("%s:%s;%s;%s", nodeIP, dispersalPort, retrievalPort, v2DispersalPort))
return OperatorSocket(fmt.Sprintf("%s:%s;%s;%s;%s", nodeIP, dispersalPort, retrievalPort, v2DispersalPort, v2RetrievalPort))
}

type StakeAmount = *big.Int

func ParseOperatorSocket(socket string) (host string, dispersalPort string, retrievalPort string, v2DispersalPort string, err error) {
s := strings.Split(socket, ";")
func ParseOperatorSocket(socket string) (host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort string, err error) {

if len(s) == 2 {
// no v2 dispersal port
retrievalPort = s[1]
s = strings.Split(s[0], ":")
if len(s) != 2 {
err = fmt.Errorf("invalid socket address format: %s", socket)
return
}
host = s[0]
dispersalPort = s[1]
s := strings.Split(socket, ";")

host, v1DispersalPort, err = net.SplitHostPort(s[0])
if _, err = net.LookupHost(host); err != nil {
//Invalid host
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err =
"", "", "", "", "",
fmt.Errorf("invalid host address format in %s: it must specify valid IP or host name (ex. 0.0.0.0:32004;32005;32006;32007)", socket)
return
}
if err = ValidatePort(v1DispersalPort); err != nil {
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err =
"", "", "", "", "",
fmt.Errorf("invalid v1 dispersal port format in %s: it must specify valid v1 dispersal port (ex. 0.0.0.0:32004;32005;32006;32007)", socket)
return
}

if len(s) == 3 {
// all ports specified
switch len(s) {
case 4:
v2DispersalPort = s[2]
retrievalPort = s[1]
if err = ValidatePort(v2DispersalPort); err != nil {
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err =
"", "", "", "", "",
fmt.Errorf("invalid v2 dispersal port format in %s: it must specify valid v2 dispersal port (ex. 0.0.0.0:32004;32005;32006;32007)", socket)
}

s = strings.Split(s[0], ":")
if len(s) != 2 {
err = fmt.Errorf("invalid socket address format: %s", socket)
return
v2RetrievalPort = s[3]
if err = ValidatePort(v2RetrievalPort); err != nil {
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err =
"", "", "", "", "",
fmt.Errorf("invalid v2 retrieval port format in %s: it must specify valid v2 retrieval port (ex. 0.0.0.0:32004;32005;32006;32007)", socket)
}
fallthrough
case 2:
// V1 Parsing
v1RetrievalPort = s[1]
if err = ValidatePort(v1RetrievalPort); err != nil {
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err =
"", "", "", "", "",
fmt.Errorf("invalid v1 retrieval port format in %s: it must specify valid v1 retrieval port (ex. 0.0.0.0:32004;32005;32006;32007)", socket)
}
host = s[0]
dispersalPort = s[1]
return
default:
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err =
"", "", "", "", "",
fmt.Errorf("invalid socket address format %s: it must specify v1 dispersal/retrieval ports, or v2 dispersal/retrieval ports (ex. 0.0.0.0:32004;32005;32006;32007)", socket)
return
}

return "", "", "", "", fmt.Errorf("invalid socket address format %s: it must specify dispersal port, retrieval port, and/or v2 dispersal port (ex. 0.0.0.0:32004;32005;32006)", socket)
}

// OperatorInfo contains information about an operator which is stored on the blockchain state,
Expand Down
14 changes: 14 additions & 0 deletions core/utils.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package core

import (
"fmt"
"math"
"math/big"
"strconv"

"golang.org/x/exp/constraints"
)
Expand All @@ -23,3 +25,15 @@ func NextPowerOf2[T constraints.Integer](d T) T {
nextPower := math.Ceil(math.Log2(float64(d)))
return T(math.Pow(2.0, nextPower))
}

func ValidatePort(portStr string) error {
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("port is not a valid number: %v", err)
}

if port < 1 || port > 65535 {
return fmt.Errorf("port number out of valid range (1-65535)")
}
return err
}
2 changes: 1 addition & 1 deletion disperser/common/semver/semver.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func ScanOperators(operators map[core.OperatorID]*core.IndexedOperatorInfo, oper
operatorSocket := core.OperatorSocket(operators[operatorId].Socket)
var socket string
if useRetrievalSocket {
socket = operatorSocket.GetRetrievalSocket()
socket = operatorSocket.GetV1RetrievalSocket()
} else {
socket = operatorSocket.GetV1DispersalSocket()
}
Expand Down
2 changes: 1 addition & 1 deletion disperser/controller/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (d *Dispatcher) HandleBatch(ctx context.Context) (chan core.SigningMessage,
for opID, op := range state.IndexedOperators {
opID := opID
op := op
host, _, _, v2DispersalPort, err := core.ParseOperatorSocket(op.Socket)
host, _, _, v2DispersalPort, _, err := core.ParseOperatorSocket(op.Socket)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse operator socket (%s): %w", op.Socket, err)
}
Expand Down
Loading

0 comments on commit 85c70a9

Please sign in to comment.