Skip to content

Commit

Permalink
refactor: symbolLength consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Oct 29, 2024
1 parent e20e9b6 commit 6797360
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 57 deletions.
3 changes: 1 addition & 2 deletions api/clients/accountant.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Accountant struct {
minNumSymbols uint32

// local accounting
// contains 3 bins; index 0 for current bin, 1 for next bin, 2 for overflowed bin
// contains 3 bins; circular wrapping of indices
binRecords []BinRecord
usageLock sync.Mutex
cumulativePayment *big.Int
Expand Down Expand Up @@ -62,7 +62,6 @@ func NewAccountant(reservation core.ActiveReservation, onDemand core.OnDemandPay
func (a *Accountant) BlobPaymentInfo(ctx context.Context, dataLength uint64) (uint32, *big.Int, error) {
now := time.Now().Unix()
currentBinIndex := meterer.GetBinIndex(uint64(now), a.reservationWindow)
// index := time.Now().Unix() / int64(a.reservationWindow)

a.usageLock.Lock()
defer a.usageLock.Unlock()
Expand Down
46 changes: 23 additions & 23 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,22 @@ func (m *Meterer) Start(ctx context.Context, updateInterval time.Duration) {

// MeterRequest validates a blob header and adds it to the meterer's state
// TODO: return error if there's a rejection (with reasoning) or internal error (should be very rare)
func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata, blobLength uint, quorumNumbers []uint8) error {
func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata, numSymbols uint, quorumNumbers []uint8) error {
// Validate against the payment method
if header.CumulativePayment.Sign() == 0 {
reservation, err := m.ChainPaymentState.GetActiveReservationByAccount(ctx, header.AccountID)
if err != nil {
return fmt.Errorf("failed to get active reservation by account: %w", err)
}
if err := m.ServeReservationRequest(ctx, header, &reservation, blobLength, quorumNumbers); err != nil {
if err := m.ServeReservationRequest(ctx, header, &reservation, numSymbols, quorumNumbers); err != nil {
return fmt.Errorf("invalid reservation: %w", err)
}
} else {
onDemandPayment, err := m.ChainPaymentState.GetOnDemandPaymentByAccount(ctx, header.AccountID)
if err != nil {
return fmt.Errorf("failed to get on-demand payment by account: %w", err)
}
if err := m.ServeOnDemandRequest(ctx, header, &onDemandPayment, blobLength, quorumNumbers); err != nil {
if err := m.ServeOnDemandRequest(ctx, header, &onDemandPayment, numSymbols, quorumNumbers); err != nil {
return fmt.Errorf("invalid on-demand request: %w", err)
}
}
Expand All @@ -94,7 +94,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata,
}

// ServeReservationRequest handles the rate limiting logic for incoming requests
func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, blobLength uint, quorumNumbers []uint8) error {
func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint, quorumNumbers []uint8) error {
if err := m.ValidateQuorum(quorumNumbers, reservation.QuorumNumbers); err != nil {
return fmt.Errorf("invalid quorum for reservation: %w", err)
}
Expand All @@ -103,7 +103,7 @@ func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.Payme
}

// Update bin usage atomically and check against reservation's data rate as the bin limit
if err := m.IncrementBinUsage(ctx, header, reservation, blobLength); err != nil {
if err := m.IncrementBinUsage(ctx, header, reservation, numSymbols); err != nil {
return fmt.Errorf("bin overflows: %w", err)
}

Expand Down Expand Up @@ -142,9 +142,9 @@ func (m *Meterer) ValidateBinIndex(header core.PaymentMetadata, reservation *cor
}

// IncrementBinUsage increments the bin usage atomically and checks for overflow
func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, blobLength uint) error {
numSymbols := m.SymbolsCharged(blobLength)
newUsage, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.BinIndex), uint64(numSymbols))
func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint) error {
symbolsCharged := m.SymbolsCharged(numSymbols)
newUsage, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.BinIndex), uint64(symbolsCharged))
if err != nil {
return fmt.Errorf("failed to increment bin usage: %w", err)
}
Expand Down Expand Up @@ -176,7 +176,7 @@ func GetBinIndex(timestamp uint64, binInterval uint32) uint32 {
// ServeOnDemandRequest handles the rate limiting logic for incoming requests
// On-demand requests doesn't have additional quorum settings and should only be
// allowed by ETH and EIGEN quorums
func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, blobLength uint, headerQuorums []uint8) error {
func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, numSymbols uint, headerQuorums []uint8) error {
quorumNumbers, err := m.ChainPaymentState.GetOnDemandQuorumNumbers(ctx)
if err != nil {
return fmt.Errorf("failed to get on-demand quorum numbers: %w", err)
Expand All @@ -186,13 +186,13 @@ func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentM
return fmt.Errorf("invalid quorum for On-Demand Request: %w", err)
}
// update blob header to use the miniumum chargeable size
symbolsCharged := m.SymbolsCharged(blobLength)
symbolsCharged := m.SymbolsCharged(numSymbols)
err = m.OffchainStore.AddOnDemandPayment(ctx, header, symbolsCharged)
if err != nil {
return fmt.Errorf("failed to update cumulative payment: %w", err)
}
// Validate payments attached
err = m.ValidatePayment(ctx, header, onDemandPayment, blobLength)
err = m.ValidatePayment(ctx, header, onDemandPayment, numSymbols)
if err != nil {
// No tolerance for incorrect payment amounts; no rollbacks
return fmt.Errorf("invalid on-demand payment: %w", err)
Expand All @@ -214,44 +214,44 @@ func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentM
// ValidatePayment checks if the provided payment header is valid against the local accounting
// prevPmt is the largest cumulative payment strictly less than PaymentMetadata.cumulativePayment if exists
// nextPmt is the smallest cumulative payment strictly greater than PaymentMetadata.cumulativePayment if exists
// nextPmtDataLength is the dataLength of corresponding to nextPmt if exists
// prevPmt + PaymentMetadata.DataLength * m.FixedFeePerByte
// nextPmtnumSymbols is the numSymbols of corresponding to nextPmt if exists
// prevPmt + PaymentMetadata.numSymbols * m.FixedFeePerByte
// <= PaymentMetadata.CumulativePayment
// <= nextPmt - nextPmtDataLength * m.FixedFeePerByte > nextPmt
func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, blobLength uint) error {
// <= nextPmt - nextPmtnumSymbols * m.FixedFeePerByte > nextPmt
func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, numSymbols uint) error {
if header.CumulativePayment.Cmp(onDemandPayment.CumulativePayment) > 0 {
return fmt.Errorf("request claims a cumulative payment greater than the on-chain deposit")
}

prevPmt, nextPmt, nextPmtDataLength, err := m.OffchainStore.GetRelevantOnDemandRecords(ctx, header.AccountID, header.CumulativePayment) // zero if DNE
prevPmt, nextPmt, nextPmtnumSymbols, err := m.OffchainStore.GetRelevantOnDemandRecords(ctx, header.AccountID, header.CumulativePayment) // zero if DNE
if err != nil {
return fmt.Errorf("failed to get relevant on-demand records: %w", err)
}
// the current request must increment cumulative payment by a magnitude sufficient to cover the blob size
if prevPmt+m.PaymentCharged(blobLength) > header.CumulativePayment.Uint64() {
if prevPmt+m.PaymentCharged(numSymbols) > header.CumulativePayment.Uint64() {
return fmt.Errorf("insufficient cumulative payment increment")
}
// the current request must not break the payment magnitude for the next payment if the two requests were delivered out-of-order
if nextPmt != 0 && header.CumulativePayment.Uint64()+m.PaymentCharged(uint(nextPmtDataLength)) > nextPmt {
if nextPmt != 0 && header.CumulativePayment.Uint64()+m.PaymentCharged(uint(nextPmtnumSymbols)) > nextPmt {
return fmt.Errorf("breaking cumulative payment invariants")
}
// check passed: blob can be safely inserted into the set of payments
return nil
}

// PaymentCharged returns the chargeable price for a given data length
func (m *Meterer) PaymentCharged(dataLength uint) uint64 {
return uint64(m.SymbolsCharged(dataLength)) * uint64(m.ChainPaymentState.GetPricePerSymbol())
func (m *Meterer) PaymentCharged(numSymbols uint) uint64 {
return uint64(m.SymbolsCharged(numSymbols)) * uint64(m.ChainPaymentState.GetPricePerSymbol())
}

// SymbolsCharged returns the number of symbols charged for a given data length
// being at least MinNumSymbols or the nearest rounded-up multiple of MinNumSymbols.
func (m *Meterer) SymbolsCharged(dataLength uint) uint32 {
if dataLength <= uint(m.ChainPaymentState.GetMinNumSymbols()) {
func (m *Meterer) SymbolsCharged(numSymbols uint) uint32 {
if numSymbols <= uint(m.ChainPaymentState.GetMinNumSymbols()) {
return m.ChainPaymentState.GetMinNumSymbols()
}
// Round up to the nearest multiple of MinNumSymbols
return uint32(core.RoundUpDivide(uint(dataLength), uint(m.ChainPaymentState.GetMinNumSymbols()))) * m.ChainPaymentState.GetMinNumSymbols()
return uint32(core.RoundUpDivide(uint(numSymbols), uint(m.ChainPaymentState.GetMinNumSymbols()))) * m.ChainPaymentState.GetMinNumSymbols()
}

// ValidateBinIndex checks if the provided bin index is valid
Expand Down
48 changes: 24 additions & 24 deletions core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ func TestMetererReservations(t *testing.T) {
assert.ErrorContains(t, err, "invalid bin index for reservation")

// test bin usage metering
dataLength := uint(20)
symbolLength := uint(20)
requiredLength := uint(21) // 21 should be charged for length of 20 since minNumSymbols is 3
for i := 0; i < 9; i++ {
header = createPaymentHeader(binIndex, 0, accountID2)
err = mt.MeterRequest(ctx, *header, dataLength, quoromNumbers)
err = mt.MeterRequest(ctx, *header, symbolLength, quoromNumbers)
assert.NoError(t, err)
item, err := dynamoClient.GetItem(ctx, reservationTableName, commondynamodb.Key{
"AccountID": &types.AttributeValueMemberS{Value: accountID2},
Expand Down Expand Up @@ -296,20 +296,20 @@ func TestMetererOnDemand(t *testing.T) {
assert.Equal(t, 1, len(result))

// test duplicated cumulative payments
dataLength := uint(100)
priceCharged := mt.PaymentCharged(dataLength)
symbolLength := uint(100)
priceCharged := mt.PaymentCharged(symbolLength)
assert.Equal(t, uint64(102*mt.ChainPaymentState.GetPricePerSymbol()), priceCharged)
header = createPaymentHeader(binIndex, priceCharged, accountID2)
err = mt.MeterRequest(ctx, *header, dataLength, quorumNumbers)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.NoError(t, err)
header = createPaymentHeader(binIndex, priceCharged, accountID2)
err = mt.MeterRequest(ctx, *header, dataLength, quorumNumbers)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.ErrorContains(t, err, "exact payment already exists")

// test valid payments
for i := 1; i < 9; i++ {
header = createPaymentHeader(binIndex, uint64(priceCharged)*uint64(i+1), accountID2)
err = mt.MeterRequest(ctx, *header, dataLength, quorumNumbers)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.NoError(t, err)
}

Expand All @@ -320,10 +320,10 @@ func TestMetererOnDemand(t *testing.T) {

// test insufficient increment in cumulative payment
previousCumulativePayment := uint64(priceCharged) * uint64(9)
dataLength = uint(2)
priceCharged = mt.PaymentCharged(dataLength)
symbolLength = uint(2)
priceCharged = mt.PaymentCharged(symbolLength)
header = createPaymentHeader(binIndex, previousCumulativePayment+priceCharged-1, accountID2)
err = mt.MeterRequest(ctx, *header, dataLength, quorumNumbers)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.ErrorContains(t, err, "invalid on-demand payment: insufficient cumulative payment increment")
previousCumulativePayment = previousCumulativePayment + priceCharged

Expand Down Expand Up @@ -355,42 +355,42 @@ func TestMetererOnDemand(t *testing.T) {
func TestMeterer_paymentCharged(t *testing.T) {
tests := []struct {
name string
dataLength uint
symbolLength uint
pricePerSymbol uint32
minNumSymbols uint32
expected uint64
}{
{
name: "Data length equal to min chargeable size",
dataLength: 1024,
symbolLength: 1024,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1024,
},
{
name: "Data length less than min chargeable size",
dataLength: 512,
symbolLength: 512,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1024,
},
{
name: "Data length greater than min chargeable size",
dataLength: 2048,
symbolLength: 2048,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 2048,
},
{
name: "Large data length",
dataLength: 1 << 20, // 1 MB
symbolLength: 1 << 20, // 1 MB
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1 << 20,
},
{
name: "Price not evenly divisible by min chargeable size",
dataLength: 1536,
symbolLength: 1536,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 2048,
Expand All @@ -405,7 +405,7 @@ func TestMeterer_paymentCharged(t *testing.T) {
m := &meterer.Meterer{
ChainPaymentState: paymentChainState,
}
result := m.PaymentCharged(tt.dataLength)
result := m.PaymentCharged(tt.symbolLength)
assert.Equal(t, tt.expected, result)
})
}
Expand All @@ -414,37 +414,37 @@ func TestMeterer_paymentCharged(t *testing.T) {
func TestMeterer_symbolsCharged(t *testing.T) {
tests := []struct {
name string
dataLength uint
symbolLength uint
minNumSymbols uint32
expected uint32
}{
{
name: "Data length equal to min number of symobols",
dataLength: 1024,
symbolLength: 1024,
minNumSymbols: 1024,
expected: 1024,
},
{
name: "Data length less than min number of symbols",
dataLength: 512,
symbolLength: 512,
minNumSymbols: 1024,
expected: 1024,
},
{
name: "Data length greater than min number of symbols",
dataLength: 2048,
symbolLength: 2048,
minNumSymbols: 1024,
expected: 2048,
},
{
name: "Large data length",
dataLength: 1 << 20, // 1 MB
symbolLength: 1 << 20, // 1 MB
minNumSymbols: 1024,
expected: 1 << 20,
},
{
name: "Very small data length",
dataLength: 16,
symbolLength: 16,
minNumSymbols: 1024,
expected: 1024,
},
Expand All @@ -457,7 +457,7 @@ func TestMeterer_symbolsCharged(t *testing.T) {
m := &meterer.Meterer{
ChainPaymentState: paymentChainState,
}
result := m.SymbolsCharged(tt.dataLength)
result := m.SymbolsCharged(tt.symbolLength)
assert.Equal(t, tt.expected, result)
})
}
Expand Down
6 changes: 3 additions & 3 deletions disperser/apiserver/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestDispersePaidBlob(t *testing.T) {

dispersalServer := newTestServer(transactor, t.Name())

data := make([]byte, 1024*encoding.BYTES_PER_SYMBOL)
data := make([]byte, 1024)
_, err := rand.Read(data)
assert.NoError(t, err)

Expand All @@ -63,7 +63,7 @@ func TestDispersePaidBlob(t *testing.T) {
pm := pbcommon.PaymentHeader{
AccountId: signer.GetAccountID(),
BinIndex: 0,
CumulativePayment: big.NewInt(int64(int(symbolLength) * i)).Bytes(),
CumulativePayment: big.NewInt(int64(int(symbolLength) * i * encoding.BYTES_PER_SYMBOL)).Bytes(),
}
sig, err := signer.SignBlobPayment(&pm)
assert.NoError(t, err)
Expand All @@ -82,7 +82,7 @@ func TestDispersePaidBlob(t *testing.T) {
pm := pbcommon.PaymentHeader{
AccountId: signer.GetAccountID(),
BinIndex: 0,
CumulativePayment: big.NewInt(int64(symbolLength*3) - 1).Bytes(),
CumulativePayment: big.NewInt(int64(symbolLength*3)*encoding.BYTES_PER_SYMBOL - 1).Bytes(),
}
sig, err := signer.SignBlobPayment(&pm)
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion disperser/apiserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal
panic("failed to make initial query to the on-chain state")
}

mockState.On("GetPricePerSymbol").Return(uint32(1), nil)
mockState.On("GetPricePerSymbol").Return(uint32(encoding.BYTES_PER_SYMBOL), nil)
mockState.On("GetMinNumSymbols").Return(uint32(1), nil)
mockState.On("GetGlobalSymbolsPerSecond").Return(uint64(4096), nil)
mockState.On("GetRequiredQuorumNumbers").Return([]uint8{0, 1}, nil)
Expand Down
9 changes: 5 additions & 4 deletions inabox/tests/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/auth"
"github.com/Layr-Labs/eigenda/disperser"
"github.com/Layr-Labs/eigenda/encoding"
"github.com/ethereum/go-ethereum/crypto"

"github.com/Layr-Labs/eigenda/encoding/utils/codec"
Expand Down Expand Up @@ -46,8 +47,8 @@ var _ = Describe("Inabox Integration", func() {

Expect(disp).To(Not(BeNil()))

singleBlobSize := uint32(128)
data := make([]byte, singleBlobSize)
blobLength := uint32(4)
data := make([]byte, blobLength*encoding.BYTES_PER_SYMBOL)
_, err = rand.Read(data)
Expect(err).To(BeNil())

Expand All @@ -59,8 +60,8 @@ var _ = Describe("Inabox Integration", func() {
reservationBytesLimit := 1024
paymentLimit := 512
// TODO: payment calculation unit consistency
for i := 0; i < (int(reservationBytesLimit+paymentLimit))/int(singleBlobSize); i++ {
blobStatus, key, err := disp.DisperseBlob(ctx, paddedData, []uint8{0})
for i := 0; i < (int(reservationBytesLimit+paymentLimit))/int(blobLength); i++ {
blobStatus, key, err := disp.DispersePaidBlob(ctx, paddedData, []uint8{0})
Expect(err).To(BeNil())
Expect(key).To(Not(BeNil()))
Expect(blobStatus).To(Not(BeNil()))
Expand Down

0 comments on commit 6797360

Please sign in to comment.