diff --git a/common/aws/dynamodb/client.go b/common/aws/dynamodb/client.go index 53587495df..eeff546e88 100644 --- a/common/aws/dynamodb/client.go +++ b/common/aws/dynamodb/client.go @@ -255,6 +255,15 @@ func (c *Client) Query(ctx context.Context, tableName string, keyCondition strin return response.Items, nil } +// QueryWithInput is a wrapper for the Query function that allows for a custom query input +func (c *Client) QueryWithInput(ctx context.Context, input *dynamodb.QueryInput) ([]Item, error) { + response, err := c.dynamoClient.Query(ctx, input) + if err != nil { + return nil, err + } + return response.Items, nil +} + // QueryIndexCount returns the count of the items in the index that match the given key func (c *Client) QueryIndexCount(ctx context.Context, tableName string, indexName string, keyCondition string, expAttributeValues ExpressionValues) (int32, error) { response, err := c.dynamoClient.Query(ctx, &dynamodb.QueryInput{ diff --git a/common/aws/dynamodb/client_test.go b/common/aws/dynamodb/client_test.go index a67db4e6c6..8b7cee2580 100644 --- a/common/aws/dynamodb/client_test.go +++ b/common/aws/dynamodb/client_test.go @@ -661,3 +661,91 @@ func TestQueryIndexOrderWithLimit(t *testing.T) { assert.NoError(t, err) assert.Len(t, queryResult, 30) // Should return all items } + +func TestQueryWithInput(t *testing.T) { + tableName := "ProcessingQueryWithInput" + createTable(t, tableName) + + ctx := context.Background() + numItems := 30 + items := make([]commondynamodb.Item, numItems) + for i := 0; i < numItems; i++ { + requestedAt := time.Now().Add(-time.Duration(i) * time.Minute).Unix() + items[i] = commondynamodb.Item{ + "MetadataKey": &types.AttributeValueMemberS{Value: fmt.Sprintf("key%d", i)}, + "BlobKey": &types.AttributeValueMemberS{Value: fmt.Sprintf("blob%d", i)}, + "BlobSize": &types.AttributeValueMemberN{Value: "123"}, + "BlobStatus": &types.AttributeValueMemberN{Value: "0"}, + "RequestedAt": &types.AttributeValueMemberN{Value: strconv.FormatInt(requestedAt, 10)}, + } + } + unprocessed, err := dynamoClient.PutItems(ctx, tableName, items) + assert.NoError(t, err) + assert.Len(t, unprocessed, 0) + + // Test forward order with limit + queryInput := &dynamodb.QueryInput{ + TableName: aws.String(tableName), + IndexName: aws.String("StatusIndex"), + KeyConditionExpression: aws.String("BlobStatus = :status"), + ExpressionAttributeValues: commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, + ScanIndexForward: aws.Bool(true), + Limit: aws.Int32(10), + } + queryResult, err := dynamoClient.QueryWithInput(ctx, queryInput) + assert.NoError(t, err) + assert.Len(t, queryResult, 10) + // Check if the items are in ascending order + for i := 0; i < len(queryResult)-1; i++ { + assert.True(t, queryResult[i]["RequestedAt"].(*types.AttributeValueMemberN).Value <= queryResult[i+1]["RequestedAt"].(*types.AttributeValueMemberN).Value) + } + + // Test reverse order with limit + queryInput = &dynamodb.QueryInput{ + TableName: aws.String(tableName), + IndexName: aws.String("StatusIndex"), + KeyConditionExpression: aws.String("BlobStatus = :status"), + ExpressionAttributeValues: commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, + ScanIndexForward: aws.Bool(false), + Limit: aws.Int32(10), + } + queryResult, err = dynamoClient.QueryWithInput(ctx, queryInput) + assert.NoError(t, err) + assert.Len(t, queryResult, 10) + // Check if the items are in descending order + for i := 0; i < len(queryResult)-1; i++ { + assert.True(t, queryResult[i]["RequestedAt"].(*types.AttributeValueMemberN).Value >= queryResult[i+1]["RequestedAt"].(*types.AttributeValueMemberN).Value) + } + + // Test with a smaller limit + queryInput = &dynamodb.QueryInput{ + TableName: aws.String(tableName), + IndexName: aws.String("StatusIndex"), + KeyConditionExpression: aws.String("BlobStatus = :status"), + ExpressionAttributeValues: commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, + Limit: aws.Int32(5), + } + queryResult, err = dynamoClient.QueryWithInput(ctx, queryInput) + assert.NoError(t, err) + assert.Len(t, queryResult, 5) + + // Test with a limit larger than the number of items + queryInput = &dynamodb.QueryInput{ + TableName: aws.String(tableName), + IndexName: aws.String("StatusIndex"), + KeyConditionExpression: aws.String("BlobStatus = :status"), + ExpressionAttributeValues: commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, + Limit: aws.Int32(50), + } + queryResult, err = dynamoClient.QueryWithInput(ctx, queryInput) + assert.NoError(t, err) + assert.Len(t, queryResult, 30) // Should return all items +} diff --git a/core/data.go b/core/data.go index aeaf082bd2..21e24330c8 100644 --- a/core/data.go +++ b/core/data.go @@ -222,6 +222,14 @@ type Blob struct { Data []byte } +func (b *Blob) GetQuorumNumbers() []uint8 { + quorumNumbers := make([]uint8, 0, len(b.RequestHeader.SecurityParams)) + for _, sp := range b.RequestHeader.SecurityParams { + quorumNumbers = append(quorumNumbers, sp.QuorumID) + } + return quorumNumbers +} + // BlobAuthHeader contains the data that a user must sign to authenticate a blob request. // Signing the combination of the Nonce and the BlobCommitments prohibits the disperser from // using the signature to charge the user for a different blob or for dispersing the same blob @@ -482,22 +490,24 @@ type PaymentMetadata struct { BinIndex uint32 // TODO: we are thinking the contract can use uint128 for cumulative payment, // but the definition on v2 uses uint64. Double check with team. - CumulativePayment uint64 + CumulativePayment *big.Int } // Hash returns the Keccak256 hash of the PaymentMetadata func (pm *PaymentMetadata) Hash() []byte { // Create a byte slice to hold the serialized data - data := make([]byte, 0, len(pm.AccountID)+12) + data := make([]byte, 0, len(pm.AccountID)+4+pm.CumulativePayment.BitLen()/8+1) + // Append AccountID data = append(data, []byte(pm.AccountID)...) + // Append BinIndex binIndexBytes := make([]byte, 4) binary.BigEndian.PutUint32(binIndexBytes, pm.BinIndex) data = append(data, binIndexBytes...) - paymentBytes := make([]byte, 8) - binary.BigEndian.PutUint64(paymentBytes, pm.CumulativePayment) + // Append CumulativePayment + paymentBytes := pm.CumulativePayment.Bytes() data = append(data, paymentBytes...) return crypto.Keccak256(data) @@ -506,12 +516,12 @@ func (pm *PaymentMetadata) Hash() []byte { // OperatorInfo contains information about an operator which is stored on the blockchain state, // corresponding to a particular quorum type ActiveReservation struct { - DataRate uint64 // Bandwidth per reservation bin + SymbolsPerSec uint64 // reserve number of symbols per second StartTimestamp uint64 // Unix timestamp that's valid for basically eternity EndTimestamp uint64 - QuorumNumbers []uint8 - QuorumSplit []byte // ordered mapping of quorum number to payment split; on-chain validation should ensure split <= 100 + QuorumNumbers []uint8 // allowed quorums + QuorumSplit []byte // ordered mapping of quorum number to payment split; on-chain validation should ensure split <= 100 } type OnDemandPayment struct { diff --git a/core/eth/tx.go b/core/eth/tx.go index ccf85c91dd..1386a63734 100644 --- a/core/eth/tx.go +++ b/core/eth/tx.go @@ -763,16 +763,26 @@ func (t *Transactor) GetRequiredQuorumNumbers(ctx context.Context, blockNumber u return requiredQuorums, nil } -func (t *Transactor) GetActiveReservations(ctx context.Context, blockNumber uint32) (map[string]core.ActiveReservation, error) { +func (t *Transactor) GetActiveReservations(ctx context.Context, blockNumber uint32, accountIDs []string) (map[string]core.ActiveReservation, error) { // contract is not implemented yet return map[string]core.ActiveReservation{}, nil } -func (t *Transactor) GetOnDemandPayments(ctx context.Context, blockNumber uint32) (map[string]core.OnDemandPayment, error) { +func (t *Transactor) GetActiveReservationByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.ActiveReservation, error) { + // contract is not implemented yet + return core.ActiveReservation{}, nil +} + +func (t *Transactor) GetOnDemandPayments(ctx context.Context, blockNumber uint32, accountIDs []string) (map[string]core.OnDemandPayment, error) { // contract is not implemented yet return map[string]core.OnDemandPayment{}, nil } +func (t *Transactor) GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.OnDemandPayment, error) { + // contract is not implemented yet + return core.OnDemandPayment{}, nil +} + func (t *Transactor) updateContractBindings(blsOperatorStateRetrieverAddr, eigenDAServiceManagerAddr gethcommon.Address) error { contractEigenDAServiceManager, err := eigendasrvmg.NewContractEigenDAServiceManager(eigenDAServiceManagerAddr, t.EthClient) diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index e0361d63e9..d0f508c117 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -1,19 +1,23 @@ package meterer import ( + "context" + "fmt" + "slices" "time" + "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigensdk-go/logging" ) // Config contains network parameters that should be published on-chain. We currently configure these params through disperser env vars. type Config struct { - // GlobalBytesPerSecond is the rate limit in bytes per second for on-demand payments - GlobalBytesPerSecond uint64 - // MinChargeableSize is the minimum size of a chargeable unit in bytes, used as a floor for on-demand payments - MinChargeableSize uint32 - // PricePerChargeable is the price per chargeable unit in gwei, used for on-demand payments - PricePerChargeable uint32 + // GlobalSymbolsPerSecond rate limit in symbols per second for on-demand payments + GlobalSymbolsPerSecond uint64 + // MinNumSymbols is the minimum number of symbols charged, round up for all smaller requests (must be in power of 2) + MinNumSymbols uint32 + // PricePerSymbol is the price per symbol in gwei, used for on-demand payments + PricePerSymbol uint32 // ReservationWindow is the duration of all reservations in seconds, used to calculate bin indices ReservationWindow uint32 @@ -26,7 +30,6 @@ type Config struct { // payments information is valid. type Meterer struct { Config - // ChainState reads on-chain payment state periodically and cache it in memory ChainState OnchainPayment // OffchainStore uses DynamoDB to track metering and used to validate requests @@ -40,8 +43,7 @@ func NewMeterer( paymentChainState OnchainPayment, offchainStore OffchainStore, logger logging.Logger, -) (*Meterer, error) { - // TODO: create a separate thread to pull from the chain and update chain state +) *Meterer { return &Meterer{ Config: config, @@ -49,5 +51,240 @@ func NewMeterer( OffchainStore: offchainStore, logger: logger.With("component", "Meterer"), - }, nil + } +} + +// Start starts to periodically refreshing the on-chain state +func (m *Meterer) Start(ctx context.Context) { + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := m.ChainState.RefreshOnchainPaymentState(ctx, nil); err != nil { + m.logger.Error("Failed to refresh on-chain state", "error", err) + } + case <-ctx.Done(): + return + } + } + }() +} + +// MeterRequest validates a blob header and adds it to the meterer's state +// TODO: return error if there's a rejection (with reasoning) or internal error (should be very rare) +func (m *Meterer) MeterRequest(ctx context.Context, blob core.Blob, header core.PaymentMetadata) error { + headerQuorums := blob.GetQuorumNumbers() + // Validate against the payment method + if header.CumulativePayment.Sign() == 0 { + reservation, err := m.ChainState.GetActiveReservationByAccount(ctx, header.AccountID) + if err != nil { + return fmt.Errorf("failed to get active reservation by account: %w", err) + } + if err := m.ServeReservationRequest(ctx, header, &reservation, blob.RequestHeader.BlobAuthHeader.Length, headerQuorums); err != nil { + return fmt.Errorf("invalid reservation: %w", err) + } + } else { + onDemandPayment, err := m.ChainState.GetOnDemandPaymentByAccount(ctx, header.AccountID) + if err != nil { + return fmt.Errorf("failed to get on-demand payment by account: %w", err) + } + if err := m.ServeOnDemandRequest(ctx, header, &onDemandPayment, blob.RequestHeader.BlobAuthHeader.Length, headerQuorums); err != nil { + return fmt.Errorf("invalid on-demand request: %w", err) + } + } + + return nil +} + +// ServeReservationRequest handles the rate limiting logic for incoming requests +func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, blobLength uint, quorumNumbers []uint8) error { + if err := m.ValidateQuorum(quorumNumbers, reservation.QuorumNumbers); err != nil { + return fmt.Errorf("invalid quorum for reservation: %w", err) + } + if !m.ValidateBinIndex(header, reservation) { + return fmt.Errorf("invalid bin index for reservation") + } + + // Update bin usage atomically and check against reservation's data rate as the bin limit + if err := m.IncrementBinUsage(ctx, header, reservation, blobLength); err != nil { + return fmt.Errorf("bin overflows: %w", err) + } + + return nil +} + +// ValidateQuorums ensures that the quorums listed in the blobHeader are present within allowedQuorums +// Note: A reservation that does not utilize all of the allowed quorums will be accepted. However, it +// will still charge against all of the allowed quorums. A on-demand requrests require and only allow +// the ETH and EIGEN quorums. +func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8) error { + if len(headerQuorums) == 0 { + return fmt.Errorf("no quorum params in blob header") + } + + // check that all the quorum ids are in ActiveReservation's + for _, q := range headerQuorums { + if !slices.Contains(allowedQuorums, q) { + // fail the entire request if there's a quorum number mismatch + return fmt.Errorf("quorum number mismatch: %d", q) + } + } + return nil +} + +// ValidateBinIndex checks if the provided bin index is valid +func (m *Meterer) ValidateBinIndex(header core.PaymentMetadata, reservation *core.ActiveReservation) bool { + now := uint64(time.Now().Unix()) + currentBinIndex := GetBinIndex(now, m.ReservationWindow) + // Valid bin indexes are either the current bin or the previous bin + if (header.BinIndex != currentBinIndex && header.BinIndex != (currentBinIndex-1)) || (GetBinIndex(reservation.StartTimestamp, m.ReservationWindow) > header.BinIndex || header.BinIndex > GetBinIndex(reservation.EndTimestamp, m.ReservationWindow)) { + return false + } + return true +} + +// IncrementBinUsage increments the bin usage atomically and checks for overflow +func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, blobLength uint) error { + numSymbols := m.SymbolsCharged(blobLength) + newUsage, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.BinIndex), uint64(numSymbols)) + if err != nil { + return fmt.Errorf("failed to increment bin usage: %w", err) + } + + // metered usage stays within the bin limit + usageLimit := m.GetReservationBinLimit(reservation) + if newUsage <= usageLimit { + return nil + } else if newUsage-uint64(numSymbols) >= usageLimit { + // metered usage before updating the size already exceeded the limit + return fmt.Errorf("bin has already been filled") + } + if newUsage <= 2*usageLimit && header.BinIndex+2 <= GetBinIndex(reservation.EndTimestamp, m.ReservationWindow) { + _, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.BinIndex+2), newUsage-usageLimit) + if err != nil { + return err + } + return nil + } + return fmt.Errorf("overflow usage exceeds bin limit") +} + +// GetBinIndex returns the current bin index by chunking time by the bin interval; +// bin interval used by the disperser should be public information +func GetBinIndex(timestamp uint64, binInterval uint32) uint32 { + return uint32(timestamp) / binInterval +} + +// ServeOnDemandRequest handles the rate limiting logic for incoming requests +// On-demand requests doesn't have additional quorum settings and should only be +// allowed by ETH and EIGEN quorums +func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, blobLength uint, headerQuorums []uint8) error { + quorumNumbers, err := m.ChainState.GetOnDemandQuorumNumbers(ctx) + if err != nil { + return fmt.Errorf("failed to get on-demand quorum numbers: %w", err) + } + + if err := m.ValidateQuorum(headerQuorums, quorumNumbers); err != nil { + return fmt.Errorf("invalid quorum for On-Demand Request: %w", err) + } + // update blob header to use the miniumum chargeable size + symbolsCharged := m.SymbolsCharged(blobLength) + err = m.OffchainStore.AddOnDemandPayment(ctx, header, symbolsCharged) + if err != nil { + return fmt.Errorf("failed to update cumulative payment: %w", err) + } + // Validate payments attached + err = m.ValidatePayment(ctx, header, onDemandPayment, blobLength) + if err != nil { + // No tolerance for incorrect payment amounts; no rollbacks + return fmt.Errorf("invalid on-demand payment: %w", err) + } + + // Update bin usage atomically and check against bin capacity + if err := m.IncrementGlobalBinUsage(ctx, uint64(symbolsCharged)); err != nil { + //TODO: conditionally remove the payment based on the error type (maybe if the error is store-op related) + err := m.OffchainStore.RemoveOnDemandPayment(ctx, header.AccountID, header.CumulativePayment) + if err != nil { + return err + } + return fmt.Errorf("failed global rate limiting") + } + + return nil +} + +// ValidatePayment checks if the provided payment header is valid against the local accounting +// prevPmt is the largest cumulative payment strictly less than PaymentMetadata.cumulativePayment if exists +// nextPmt is the smallest cumulative payment strictly greater than PaymentMetadata.cumulativePayment if exists +// nextPmtDataLength is the dataLength of corresponding to nextPmt if exists +// prevPmt + PaymentMetadata.DataLength * m.FixedFeePerByte +// <= PaymentMetadata.CumulativePayment +// <= nextPmt - nextPmtDataLength * m.FixedFeePerByte > nextPmt +func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, blobLength uint) error { + if header.CumulativePayment.Cmp(onDemandPayment.CumulativePayment) > 0 { + return fmt.Errorf("request claims a cumulative payment greater than the on-chain deposit") + } + + prevPmt, nextPmt, nextPmtDataLength, err := m.OffchainStore.GetRelevantOnDemandRecords(ctx, header.AccountID, header.CumulativePayment) // zero if DNE + if err != nil { + return fmt.Errorf("failed to get relevant on-demand records: %w", err) + } + // the current request must increment cumulative payment by a magnitude sufficient to cover the blob size + if prevPmt+m.PaymentCharged(blobLength) > header.CumulativePayment.Uint64() { + return fmt.Errorf("insufficient cumulative payment increment") + } + // the current request must not break the payment magnitude for the next payment if the two requests were delivered out-of-order + if nextPmt != 0 && header.CumulativePayment.Uint64()+m.PaymentCharged(uint(nextPmtDataLength)) > nextPmt { + return fmt.Errorf("breaking cumulative payment invariants") + } + // check passed: blob can be safely inserted into the set of payments + return nil +} + +// PaymentCharged returns the chargeable price for a given data length +func (m *Meterer) PaymentCharged(dataLength uint) uint64 { + return uint64(m.SymbolsCharged(dataLength)) * uint64(m.PricePerSymbol) +} + +// SymbolsCharged returns the number of symbols charged for a given data length +// being at least MinNumSymbols or the nearest rounded-up multiple of MinNumSymbols. +func (m *Meterer) SymbolsCharged(dataLength uint) uint32 { + if dataLength <= uint(m.MinNumSymbols) { + return m.MinNumSymbols + } + // Round up to the nearest multiple of MinNumSymbols + return uint32(core.RoundUpDivide(uint(dataLength), uint(m.MinNumSymbols))) * m.MinNumSymbols +} + +// ValidateBinIndex checks if the provided bin index is valid +func (m *Meterer) ValidateGlobalBinIndex(header core.PaymentMetadata) (uint32, error) { + // Deterministic function: local clock -> index (1second intervals) + currentBinIndex := uint32(time.Now().Unix()) + + // Valid bin indexes are either the current bin or the previous bin (allow this second or prev sec) + if header.BinIndex != currentBinIndex && header.BinIndex != (currentBinIndex-1) { + return 0, fmt.Errorf("invalid bin index for on-demand request") + } + return currentBinIndex, nil +} + +// IncrementBinUsage increments the bin usage atomically and checks for overflow +func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged uint64) error { + globalIndex := uint64(time.Now().Unix()) + newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalIndex, symbolsCharged) + if err != nil { + return fmt.Errorf("failed to increment global bin usage: %w", err) + } + if newUsage > m.GlobalSymbolsPerSecond { + return fmt.Errorf("global bin usage overflows") + } + return nil +} + +// GetReservationBinLimit returns the bin limit for a given reservation +func (m *Meterer) GetReservationBinLimit(reservation *core.ActiveReservation) uint64 { + return reservation.SymbolsPerSec * uint64(m.ReservationWindow) } diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 64ef0be08f..5883642f22 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -1,36 +1,51 @@ package meterer_test import ( - "crypto/ecdsa" + "context" + "errors" "fmt" + "math/big" "os" + "strconv" "testing" "time" "github.com/Layr-Labs/eigenda/common" commonaws "github.com/Layr-Labs/eigenda/common/aws" commondynamodb "github.com/Layr-Labs/eigenda/common/aws/dynamodb" + "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigenda/core/meterer" "github.com/Layr-Labs/eigenda/core/mock" + "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/inabox/deploy" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" + testifymock "github.com/stretchr/testify/mock" "github.com/Layr-Labs/eigensdk-go/logging" ) var ( - dockertestPool *dockertest.Pool - dockertestResource *dockertest.Resource - dynamoClient *commondynamodb.Client - clientConfig commonaws.ClientConfig - privateKey1 *ecdsa.PrivateKey - privateKey2 *ecdsa.PrivateKey - mt *meterer.Meterer - - deployLocalStack bool - localStackPort = "4566" - paymentChainState = &mock.MockOnchainPaymentState{} + dockertestPool *dockertest.Pool + dockertestResource *dockertest.Resource + dynamoClient *commondynamodb.Client + clientConfig commonaws.ClientConfig + accountID1 string + account1Reservations core.ActiveReservation + account1OnDemandPayments core.OnDemandPayment + accountID2 string + account2Reservations core.ActiveReservation + account2OnDemandPayments core.OnDemandPayment + mt *meterer.Meterer + + deployLocalStack bool + localStackPort = "4566" + paymentChainState = &mock.MockOnchainPaymentState{} + ondemandTableName = "ondemand-meterer-test" + reservationTableName = "reservations-meterer-test" + globalReservationTableName = "global-reservation-meterer-test" ) func TestMain(m *testing.M) { @@ -76,12 +91,12 @@ func setup(_ *testing.M) { panic("failed to create dynamodb client") } - privateKey1, err = crypto.GenerateKey() + privateKey1, err := crypto.GenerateKey() if err != nil { teardown() panic("failed to generate private key") } - privateKey2, err = crypto.GenerateKey() + privateKey2, err := crypto.GenerateKey() if err != nil { teardown() panic("failed to generate private key") @@ -89,34 +104,42 @@ func setup(_ *testing.M) { logger = logging.NewNoopLogger() config := meterer.Config{ - PricePerChargeable: 1, - MinChargeableSize: 1, - GlobalBytesPerSecond: 1000, - ReservationWindow: 60, - ChainReadTimeout: 3 * time.Second, + PricePerSymbol: 2, + MinNumSymbols: 3, + GlobalSymbolsPerSecond: 1009, + ReservationWindow: 1, + ChainReadTimeout: 3 * time.Second, } - err = meterer.CreateReservationTable(clientConfig, "reservations") + err = meterer.CreateReservationTable(clientConfig, reservationTableName) if err != nil { teardown() panic("failed to create reservation table") } - err = meterer.CreateOnDemandTable(clientConfig, "ondemand") + err = meterer.CreateOnDemandTable(clientConfig, ondemandTableName) if err != nil { teardown() panic("failed to create ondemand table") } - err = meterer.CreateGlobalReservationTable(clientConfig, "global") + err = meterer.CreateGlobalReservationTable(clientConfig, globalReservationTableName) if err != nil { teardown() panic("failed to create global reservation table") } + now := uint64(time.Now().Unix()) + accountID1 = crypto.PubkeyToAddress(privateKey1.PublicKey).Hex() + accountID2 = crypto.PubkeyToAddress(privateKey2.PublicKey).Hex() + account1Reservations = core.ActiveReservation{SymbolsPerSec: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplit: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}} + account2Reservations = core.ActiveReservation{SymbolsPerSec: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplit: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} + account1OnDemandPayments = core.OnDemandPayment{CumulativePayment: big.NewInt(3864)} + account2OnDemandPayments = core.OnDemandPayment{CumulativePayment: big.NewInt(2000)} + store, err := meterer.NewOffchainStore( clientConfig, - "reservations", - "ondemand", - "global", + reservationTableName, + ondemandTableName, + globalReservationTableName, logger, ) @@ -126,18 +149,14 @@ func setup(_ *testing.M) { } // add some default sensible configs - mt, err = meterer.NewMeterer( + mt = meterer.NewMeterer( config, paymentChainState, store, logging.NewNoopLogger(), // metrics.NewNoopMetrics(), ) - - if err != nil { - teardown() - panic("failed to create meterer") - } + mt.Start(context.Background()) } func teardown() { @@ -145,3 +164,320 @@ func teardown() { deploy.PurgeDockertestResources(dockertestPool, dockertestResource) } } + +func TestMetererReservations(t *testing.T) { + ctx := context.Background() + binIndex := meterer.GetBinIndex(uint64(time.Now().Unix()), mt.ReservationWindow) + quoromNumbers := []uint8{0, 1} + paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == accountID1 + })).Return(account1Reservations, nil) + paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == accountID2 + })).Return(account2Reservations, nil) + paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(core.ActiveReservation{}, errors.New("reservation not found")) + + // test invalid quorom ID + blob, header := createMetererInput(1, 0, 1000, []uint8{0, 1, 2}, accountID1) + err := mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "quorum number mismatch") + + // overwhelming bin overflow for empty bins + blob, header = createMetererInput(binIndex-1, 0, 10, quoromNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.NoError(t, err) + // overwhelming bin overflow for empty bins + blob, header = createMetererInput(binIndex-1, 0, 1000, quoromNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "overflow usage exceeds bin limit") + + // test non-existent account + unregisteredUser, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + blob, header = createMetererInput(1, 0, 1000, []uint8{0, 1, 2}, crypto.PubkeyToAddress(unregisteredUser.PublicKey).Hex()) + assert.NoError(t, err) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found") + + // test invalid bin index + blob, header = createMetererInput(binIndex, 0, 2000, quoromNumbers, accountID1) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "invalid bin index for reservation") + + // test bin usage metering + dataLength := uint(20) + requiredLength := uint(21) // 21 should be charged for length of 20 since minNumSymbols is 3 + for i := 0; i < 9; i++ { + blob, header = createMetererInput(binIndex, 0, dataLength, quoromNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.NoError(t, err) + item, err := dynamoClient.GetItem(ctx, reservationTableName, commondynamodb.Key{ + "AccountID": &types.AttributeValueMemberS{Value: accountID2}, + "BinIndex": &types.AttributeValueMemberN{Value: strconv.Itoa(int(binIndex))}, + }) + assert.NoError(t, err) + assert.Equal(t, accountID2, item["AccountID"].(*types.AttributeValueMemberS).Value) + assert.Equal(t, strconv.Itoa(int(binIndex)), item["BinIndex"].(*types.AttributeValueMemberN).Value) + assert.Equal(t, strconv.Itoa((i+1)*int(requiredLength)), item["BinUsage"].(*types.AttributeValueMemberN).Value) + + } + // first over flow is allowed + blob, header = createMetererInput(binIndex, 0, 25, quoromNumbers, accountID2) + assert.NoError(t, err) + err = mt.MeterRequest(ctx, *blob, *header) + assert.NoError(t, err) + overflowedBinIndex := binIndex + 2 + item, err := dynamoClient.GetItem(ctx, reservationTableName, commondynamodb.Key{ + "AccountID": &types.AttributeValueMemberS{Value: accountID2}, + "BinIndex": &types.AttributeValueMemberN{Value: strconv.Itoa(int(overflowedBinIndex))}, + }) + assert.NoError(t, err) + assert.Equal(t, accountID2, item["AccountID"].(*types.AttributeValueMemberS).Value) + assert.Equal(t, strconv.Itoa(int(overflowedBinIndex)), item["BinIndex"].(*types.AttributeValueMemberN).Value) + // 25 rounded up to the nearest multiple of minNumSymbols - (200-21*9) = 16 + assert.Equal(t, strconv.Itoa(int(16)), item["BinUsage"].(*types.AttributeValueMemberN).Value) + + // second over flow + blob, header = createMetererInput(binIndex, 0, 1, quoromNumbers, accountID2) + assert.NoError(t, err) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "bin has already been filled") +} + +func TestMetererOnDemand(t *testing.T) { + ctx := context.Background() + quorumNumbers := []uint8{0, 1} + binIndex := uint32(0) // this field doesn't matter for on-demand payments wrt global rate limit + + paymentChainState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == accountID1 + })).Return(account1OnDemandPayments, nil) + paymentChainState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == accountID2 + })).Return(account2OnDemandPayments, nil) + paymentChainState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(core.OnDemandPayment{}, errors.New("payment not found")) + paymentChainState.On("GetOnDemandQuorumNumbers", testifymock.Anything).Return(quorumNumbers, nil) + + // test unregistered account + unregisteredUser, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + blob, header := createMetererInput(binIndex, 2, 1000, quorumNumbers, crypto.PubkeyToAddress(unregisteredUser.PublicKey).Hex()) + assert.NoError(t, err) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "failed to get on-demand payment by account: payment not found") + + // test invalid quorom ID + blob, header = createMetererInput(binIndex, 1, 1000, []uint8{0, 1, 2}, accountID1) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "invalid quorum for On-Demand Request") + + // test insufficient cumulative payment + blob, header = createMetererInput(binIndex, 1, 2000, quorumNumbers, accountID1) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "insufficient cumulative payment increment") + // No rollback after meter request + result, err := dynamoClient.Query(ctx, ondemandTableName, "AccountID = :account", commondynamodb.ExpressionValues{ + ":account": &types.AttributeValueMemberS{ + Value: accountID1, + }}) + assert.NoError(t, err) + assert.Equal(t, 1, len(result)) + + // test duplicated cumulative payments + dataLength := uint(100) + priceCharged := mt.PaymentCharged(dataLength) + assert.Equal(t, uint64(102*mt.PricePerSymbol), priceCharged) + blob, header = createMetererInput(binIndex, priceCharged, dataLength, quorumNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.NoError(t, err) + blob, header = createMetererInput(binIndex, priceCharged, dataLength, quorumNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "exact payment already exists") + + // test valid payments + for i := 1; i < 9; i++ { + blob, header = createMetererInput(binIndex, uint64(priceCharged)*uint64(i+1), dataLength, quorumNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.NoError(t, err) + } + + // test cumulative payment on-chain constraint + blob, header = createMetererInput(binIndex, 2023, 1, quorumNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "invalid on-demand payment: request claims a cumulative payment greater than the on-chain deposit") + + // test insufficient increment in cumulative payment + previousCumulativePayment := uint64(priceCharged) * uint64(9) + dataLength = uint(2) + priceCharged = mt.PaymentCharged(dataLength) + blob, header = createMetererInput(binIndex, previousCumulativePayment+priceCharged-1, dataLength, quorumNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "invalid on-demand payment: insufficient cumulative payment increment") + previousCumulativePayment = previousCumulativePayment + priceCharged + + // test cannot insert cumulative payment in out of order + blob, header = createMetererInput(binIndex, mt.PaymentCharged(50), 50, quorumNumbers, accountID2) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "invalid on-demand payment: breaking cumulative payment invariants") + + numPrevRecords := 12 + result, err = dynamoClient.Query(ctx, ondemandTableName, "AccountID = :account", commondynamodb.ExpressionValues{ + ":account": &types.AttributeValueMemberS{ + Value: accountID2, + }}) + assert.NoError(t, err) + assert.Equal(t, numPrevRecords, len(result)) + // test failed global rate limit (previously payment recorded: 2, global limit: 1009) + fmt.Println("need ", previousCumulativePayment+mt.PaymentCharged(1010)) + blob, header = createMetererInput(binIndex, previousCumulativePayment+mt.PaymentCharged(1010), 1010, quorumNumbers, accountID1) + err = mt.MeterRequest(ctx, *blob, *header) + assert.ErrorContains(t, err, "failed global rate limiting") + // Correct rollback + result, err = dynamoClient.Query(ctx, ondemandTableName, "AccountID = :account", commondynamodb.ExpressionValues{ + ":account": &types.AttributeValueMemberS{ + Value: accountID2, + }}) + assert.NoError(t, err) + assert.Equal(t, numPrevRecords, len(result)) +} + +func TestMeterer_paymentCharged(t *testing.T) { + tests := []struct { + name string + dataLength uint + pricePerSymbol uint32 + minNumSymbols uint32 + expected uint64 + }{ + { + name: "Data length equal to min chargeable size", + dataLength: 1024, + pricePerSymbol: 1, + minNumSymbols: 1024, + expected: 1024, + }, + { + name: "Data length less than min chargeable size", + dataLength: 512, + pricePerSymbol: 2, + minNumSymbols: 1024, + expected: 2048, + }, + { + name: "Data length greater than min chargeable size", + dataLength: 2048, + pricePerSymbol: 1, + minNumSymbols: 1024, + expected: 2048, + }, + { + name: "Large data length", + dataLength: 1 << 20, // 1 MB + pricePerSymbol: 1, + minNumSymbols: 1024, + expected: 1 << 20, + }, + { + name: "Price not evenly divisible by min chargeable size", + dataLength: 1536, + pricePerSymbol: 1, + minNumSymbols: 1024, + expected: 2048, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &meterer.Meterer{ + Config: meterer.Config{ + PricePerSymbol: tt.pricePerSymbol, + MinNumSymbols: tt.minNumSymbols, + }, + } + result := m.PaymentCharged(tt.dataLength) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestMeterer_symbolsCharged(t *testing.T) { + tests := []struct { + name string + dataLength uint + minNumSymbols uint32 + expected uint32 + }{ + { + name: "Data length equal to min chargeable size", + dataLength: 1024, + minNumSymbols: 1024, + expected: 1024, + }, + { + name: "Data length less than min chargeable size", + dataLength: 512, + minNumSymbols: 1024, + expected: 1024, + }, + { + name: "Data length greater than min chargeable size", + dataLength: 2048, + minNumSymbols: 1024, + expected: 2048, + }, + { + name: "Large data length", + dataLength: 1 << 20, // 1 MB + minNumSymbols: 1024, + expected: 1 << 20, + }, + { + name: "Very small data length", + dataLength: 16, + minNumSymbols: 1024, + expected: 1024, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &meterer.Meterer{ + Config: meterer.Config{ + MinNumSymbols: tt.minNumSymbols, + }, + } + result := m.SymbolsCharged(tt.dataLength) + assert.Equal(t, tt.expected, result) + }) + } +} + +func createMetererInput(binIndex uint32, cumulativePayment uint64, dataLength uint, quorumNumbers []uint8, accountID string) (blob *core.Blob, header *core.PaymentMetadata) { + sp := make([]*core.SecurityParam, len(quorumNumbers)) + for i, quorumID := range quorumNumbers { + sp[i] = &core.SecurityParam{ + QuorumID: quorumID, + } + } + blob = &core.Blob{ + RequestHeader: core.BlobRequestHeader{ + BlobAuthHeader: core.BlobAuthHeader{ + AccountID: accountID2, + BlobCommitments: encoding.BlobCommitments{ + Length: dataLength, + }, + }, + SecurityParams: sp, + }, + } + header = &core.PaymentMetadata{ + AccountID: accountID, + BinIndex: binIndex, + CumulativePayment: big.NewInt(int64(cumulativePayment)), + } + return blob, header +} diff --git a/core/meterer/offchain_store.go b/core/meterer/offchain_store.go index d253a1b7e2..08b6ed414f 100644 --- a/core/meterer/offchain_store.go +++ b/core/meterer/offchain_store.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/big" "strconv" "time" @@ -11,6 +12,8 @@ import ( commondynamodb "github.com/Layr-Labs/eigenda/common/aws/dynamodb" "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigensdk-go/logging" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) @@ -138,7 +141,7 @@ func (s *OffchainStore) AddOnDemandPayment(ctx context.Context, paymentMetadata result, err := s.dynamoClient.GetItem(ctx, s.onDemandTableName, commondynamodb.Item{ "AccountID": &types.AttributeValueMemberS{Value: paymentMetadata.AccountID}, - "CumulativePayments": &types.AttributeValueMemberN{Value: strconv.FormatUint(paymentMetadata.CumulativePayment, 10)}, + "CumulativePayments": &types.AttributeValueMemberN{Value: paymentMetadata.CumulativePayment.String()}, }, ) if err != nil { @@ -150,7 +153,7 @@ func (s *OffchainStore) AddOnDemandPayment(ctx context.Context, paymentMetadata err = s.dynamoClient.PutItem(ctx, s.onDemandTableName, commondynamodb.Item{ "AccountID": &types.AttributeValueMemberS{Value: paymentMetadata.AccountID}, - "CumulativePayments": &types.AttributeValueMemberN{Value: strconv.FormatUint(paymentMetadata.CumulativePayment, 10)}, + "CumulativePayments": &types.AttributeValueMemberN{Value: paymentMetadata.CumulativePayment.String()}, "DataLength": &types.AttributeValueMemberN{Value: strconv.FormatUint(uint64(symbolsCharged), 10)}, }, ) @@ -162,11 +165,11 @@ func (s *OffchainStore) AddOnDemandPayment(ctx context.Context, paymentMetadata } // RemoveOnDemandPayment removes a specific payment from the list for a specific account -func (s *OffchainStore) RemoveOnDemandPayment(ctx context.Context, accountID string, payment uint64) error { +func (s *OffchainStore) RemoveOnDemandPayment(ctx context.Context, accountID string, payment *big.Int) error { err := s.dynamoClient.DeleteItem(ctx, s.onDemandTableName, commondynamodb.Key{ "AccountID": &types.AttributeValueMemberS{Value: accountID}, - "CumulativePayments": &types.AttributeValueMemberN{Value: strconv.FormatUint(payment, 10)}, + "CumulativePayments": &types.AttributeValueMemberN{Value: payment.String()}, }, ) @@ -179,21 +182,22 @@ func (s *OffchainStore) RemoveOnDemandPayment(ctx context.Context, accountID str // GetRelevantOnDemandRecords gets previous cumulative payment, next cumulative payment, blob size of next payment // The queries are done sequentially instead of one-go for efficient querying and would not cause race condition errors for honest requests -func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountID string, cumulativePayment uint64) (uint64, uint64, uint32, error) { +func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountID string, cumulativePayment *big.Int) (uint64, uint64, uint32, error) { // Fetch the largest entry smaller than the given cumulativePayment - smallerResult, err := s.dynamoClient.QueryIndexOrderWithLimit(ctx, s.onDemandTableName, "AccountIDIndex", - "AccountID = :account AND CumulativePayments < :cumulativePayment", - commondynamodb.ExpressionValues{ + queryInput := &dynamodb.QueryInput{ + TableName: aws.String(s.onDemandTableName), + KeyConditionExpression: aws.String("AccountID = :account AND CumulativePayments < :cumulativePayment"), + ExpressionAttributeValues: commondynamodb.ExpressionValues{ ":account": &types.AttributeValueMemberS{Value: accountID}, - ":cumulativePayment": &types.AttributeValueMemberN{Value: strconv.FormatUint(cumulativePayment, 10)}, + ":cumulativePayment": &types.AttributeValueMemberN{Value: cumulativePayment.String()}, }, - false, // Retrieve results in descending order for the largest smaller amount - 1, - ) + ScanIndexForward: aws.Bool(false), + Limit: aws.Int32(1), + } + smallerResult, err := s.dynamoClient.QueryWithInput(ctx, queryInput) if err != nil { return 0, 0, 0, fmt.Errorf("failed to query smaller payments for account: %w", err) } - var prevPayment uint64 if len(smallerResult) > 0 { prevPayment, err = strconv.ParseUint(smallerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64) @@ -203,15 +207,17 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI } // Fetch the smallest entry larger than the given cumulativePayment - largerResult, err := s.dynamoClient.QueryIndexOrderWithLimit(ctx, s.onDemandTableName, "AccountIDIndex", - "AccountID = :account AND CumulativePayments > :cumulativePayment", - commondynamodb.ExpressionValues{ + queryInput = &dynamodb.QueryInput{ + TableName: aws.String(s.onDemandTableName), + KeyConditionExpression: aws.String("AccountID = :account AND CumulativePayments > :cumulativePayment"), + ExpressionAttributeValues: commondynamodb.ExpressionValues{ ":account": &types.AttributeValueMemberS{Value: accountID}, - ":cumulativePayment": &types.AttributeValueMemberN{Value: strconv.FormatUint(cumulativePayment, 10)}, + ":cumulativePayment": &types.AttributeValueMemberN{Value: cumulativePayment.String()}, }, - true, // Retrieve results in ascending order for the smallest greater amount - 1, - ) + ScanIndexForward: aws.Bool(true), + Limit: aws.Int32(1), + } + largerResult, err := s.dynamoClient.QueryWithInput(ctx, queryInput) if err != nil { return 0, 0, 0, fmt.Errorf("failed to query the next payment for account: %w", err) } diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index bdbe8d26e9..4aed9f6e99 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -3,6 +3,7 @@ package meterer import ( "context" "errors" + "sync" "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigenda/core/eth" @@ -12,60 +13,77 @@ import ( // OnchainPaymentState is an interface for getting information about the current chain state for payments. type OnchainPayment interface { - GetCurrentBlockNumber(ctx context.Context) (uint32, error) - CurrentOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (OnchainPaymentState, error) - GetActiveReservations(ctx context.Context, blockNumber uint32) (map[string]core.ActiveReservation, error) - GetActiveReservationsByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.ActiveReservation, error) - GetOnDemandPayments(ctx context.Context, blockNumber uint32) (map[string]core.OnDemandPayment, error) - GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.OnDemandPayment, error) + RefreshOnchainPaymentState(ctx context.Context, tx *eth.Transactor) error + GetActiveReservations(ctx context.Context) (map[string]core.ActiveReservation, error) + GetActiveReservationByAccount(ctx context.Context, accountID string) (core.ActiveReservation, error) + GetOnDemandPayments(ctx context.Context) (map[string]core.OnDemandPayment, error) + GetOnDemandPaymentByAccount(ctx context.Context, accountID string) (core.OnDemandPayment, error) + GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) } type OnchainPaymentState struct { tx *eth.Transactor - ActiveReservations map[string]core.ActiveReservation - OnDemandPayments map[string]core.OnDemandPayment + ActiveReservations map[string]core.ActiveReservation + OnDemandPayments map[string]core.OnDemandPayment + OnDemandQuorumNumbers []uint8 + ReservationsLock sync.RWMutex + OnDemandLocks sync.RWMutex } func NewOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (OnchainPaymentState, error) { - activeReservations, onDemandPayments, err := CurrentOnchainPaymentState(ctx, tx) + blockNumber, err := tx.GetCurrentBlockNumber(ctx) + if err != nil { + return OnchainPaymentState{}, err + } + + quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber) if err != nil { - return OnchainPaymentState{tx: tx}, err + return OnchainPaymentState{}, err } return OnchainPaymentState{ - tx: tx, - ActiveReservations: activeReservations, - OnDemandPayments: onDemandPayments, + tx: tx, + ActiveReservations: make(map[string]core.ActiveReservation), + OnDemandPayments: make(map[string]core.OnDemandPayment), + OnDemandQuorumNumbers: quorumNumbers, }, nil } -// CurrentOnchainPaymentState returns the current onchain payment state (TODO: can optimize based on contract interface) -func CurrentOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (map[string]core.ActiveReservation, map[string]core.OnDemandPayment, error) { +// RefreshOnchainPaymentState returns the current onchain payment state (TODO: can optimize based on contract interface) +func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Transactor) error { blockNumber, err := tx.GetCurrentBlockNumber(ctx) if err != nil { - return nil, nil, err + return err } - activeReservations, err := tx.GetActiveReservations(ctx, blockNumber) - if err != nil { - return nil, nil, err + pcs.ReservationsLock.Lock() + accountIDs := make([]string, 0, len(pcs.ActiveReservations)) + for accountID := range pcs.ActiveReservations { + accountIDs = append(accountIDs, accountID) } - onDemandPayments, err := tx.GetOnDemandPayments(ctx, blockNumber) + activeReservations, err := tx.GetActiveReservations(ctx, blockNumber, accountIDs) if err != nil { - return nil, nil, err + return err } + pcs.ActiveReservations = activeReservations + pcs.ReservationsLock.Unlock() - return activeReservations, onDemandPayments, nil -} + pcs.OnDemandLocks.Lock() + accountIDs = make([]string, 0, len(pcs.OnDemandPayments)) + for accountID := range pcs.OnDemandPayments { + accountIDs = append(accountIDs, accountID) + } -func (pcs *OnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (uint32, error) { - blockNumber, err := pcs.tx.GetCurrentBlockNumber(ctx) + onDemandPayments, err := tx.GetOnDemandPayments(ctx, blockNumber, accountIDs) if err != nil { - return 0, err + return err } - return blockNumber, nil + pcs.OnDemandPayments = onDemandPayments + pcs.OnDemandLocks.Unlock() + + return nil } func (pcs *OnchainPaymentState) GetActiveReservations(ctx context.Context, blockNumber uint) (map[string]core.ActiveReservation, error) { @@ -73,11 +91,20 @@ func (pcs *OnchainPaymentState) GetActiveReservations(ctx context.Context, block } // GetActiveReservationByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation -func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, blockNumber uint, accountID string) (*core.ActiveReservation, error) { +func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.ActiveReservation, error) { if reservation, ok := pcs.ActiveReservations[accountID]; ok { - return &reservation, nil + return reservation, nil } - return nil, errors.New("reservation not found") + // pulls the chain state + res, err := pcs.tx.GetActiveReservationByAccount(ctx, blockNumber, accountID) + if err != nil { + return core.ActiveReservation{}, errors.New("payment not found") + } + + pcs.ReservationsLock.Lock() + pcs.ActiveReservations[accountID] = res + pcs.ReservationsLock.Unlock() + return res, nil } func (pcs *OnchainPaymentState) GetOnDemandPayments(ctx context.Context, blockNumber uint) (map[string]core.OnDemandPayment, error) { @@ -85,9 +112,22 @@ func (pcs *OnchainPaymentState) GetOnDemandPayments(ctx context.Context, blockNu } // GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment -func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint, accountID string) (*core.OnDemandPayment, error) { +func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.OnDemandPayment, error) { if payment, ok := pcs.OnDemandPayments[accountID]; ok { - return &payment, nil + return payment, nil + } + // pulls the chain state + res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, blockNumber, accountID) + if err != nil { + return core.OnDemandPayment{}, errors.New("payment not found") } - return nil, errors.New("payment not found") + + pcs.OnDemandLocks.Lock() + pcs.OnDemandPayments[accountID] = res + pcs.OnDemandLocks.Unlock() + return res, nil +} + +func (pcs *OnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context, blockNumber uint32) ([]uint8, error) { + return pcs.tx.GetRequiredQuorumNumbers(ctx, blockNumber) } diff --git a/core/meterer/onchain_state_test.go b/core/meterer/onchain_state_test.go index 8684034b26..fb3265bf0c 100644 --- a/core/meterer/onchain_state_test.go +++ b/core/meterer/onchain_state_test.go @@ -7,7 +7,6 @@ import ( "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigenda/core/eth" - "github.com/Layr-Labs/eigenda/core/meterer" "github.com/Layr-Labs/eigenda/core/mock" "github.com/stretchr/testify/assert" testifymock "github.com/stretchr/testify/mock" @@ -15,7 +14,7 @@ import ( var ( dummyActiveReservation = core.ActiveReservation{ - DataRate: 100, + SymbolsPerSec: 100, StartTimestamp: 1000, EndTimestamp: 2000, QuorumSplit: []byte{50, 50}, @@ -25,28 +24,13 @@ var ( } ) -func TestGetCurrentOnchainPaymentState(t *testing.T) { +func TestRefreshOnchainPaymentState(t *testing.T) { mockState := &mock.MockOnchainPaymentState{} ctx := context.Background() - mockState.On("CurrentOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(meterer.OnchainPaymentState{ - ActiveReservations: map[string]core.ActiveReservation{ - "account1": dummyActiveReservation, - }, - OnDemandPayments: map[string]core.OnDemandPayment{ - "account1": dummyOnDemandPayment, - }, - }, nil) - - state, err := mockState.CurrentOnchainPaymentState(ctx, ð.Transactor{}) + mockState.On("RefreshOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(nil) + + err := mockState.RefreshOnchainPaymentState(ctx, ð.Transactor{}) assert.NoError(t, err) - assert.Equal(t, meterer.OnchainPaymentState{ - ActiveReservations: map[string]core.ActiveReservation{ - "account1": dummyActiveReservation, - }, - OnDemandPayments: map[string]core.OnDemandPayment{ - "account1": dummyOnDemandPayment, - }, - }, state) } func TestGetCurrentBlockNumber(t *testing.T) { @@ -66,7 +50,7 @@ func TestGetActiveReservations(t *testing.T) { } mockState.On("GetActiveReservations", testifymock.Anything, testifymock.Anything).Return(expectedReservations, nil) - reservations, err := mockState.GetActiveReservations(ctx, 1000) + reservations, err := mockState.GetActiveReservations(ctx) assert.NoError(t, err) assert.Equal(t, expectedReservations, reservations) } @@ -74,9 +58,9 @@ func TestGetActiveReservations(t *testing.T) { func TestGetActiveReservationByAccount(t *testing.T) { mockState := &mock.MockOnchainPaymentState{} ctx := context.Background() - mockState.On("GetActiveReservationsByAccount", testifymock.Anything, testifymock.Anything, testifymock.Anything).Return(dummyActiveReservation, nil) + mockState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(dummyActiveReservation, nil) - reservation, err := mockState.GetActiveReservationsByAccount(ctx, 1000, "account1") + reservation, err := mockState.GetActiveReservationByAccount(ctx, "account1") assert.NoError(t, err) assert.Equal(t, dummyActiveReservation, reservation) } @@ -89,7 +73,7 @@ func TestGetOnDemandPayments(t *testing.T) { } mockState.On("GetOnDemandPayments", testifymock.Anything, testifymock.Anything).Return(expectedPayments, nil) - payments, err := mockState.GetOnDemandPayments(ctx, 1000) + payments, err := mockState.GetOnDemandPayments(ctx) assert.NoError(t, err) assert.Equal(t, expectedPayments, payments) } @@ -100,7 +84,17 @@ func TestGetOnDemandPaymentByAccount(t *testing.T) { accountID := "account1" mockState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.Anything, testifymock.Anything).Return(dummyOnDemandPayment, nil) - payment, err := mockState.GetOnDemandPaymentByAccount(ctx, 1000, accountID) + payment, err := mockState.GetOnDemandPaymentByAccount(ctx, accountID) assert.NoError(t, err) assert.Equal(t, dummyOnDemandPayment, payment) } + +func TestGetOnDemandQuorumNumbers(t *testing.T) { + mockState := &mock.MockOnchainPaymentState{} + ctx := context.Background() + mockState.On("GetOnDemandQuorumNumbers", testifymock.Anything, testifymock.Anything).Return([]uint8{0, 1}, nil) + + quorumNumbers, err := mockState.GetOnDemandQuorumNumbers(ctx) + assert.NoError(t, err) + assert.Equal(t, []uint8{0, 1}, quorumNumbers) +} diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index 3e973b42fe..a69b155d4d 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -24,16 +24,12 @@ func (m *MockOnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (ui return value, args.Error(1) } -func (m *MockOnchainPaymentState) CurrentOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (meterer.OnchainPaymentState, error) { +func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Transactor) error { args := m.Called() - var value meterer.OnchainPaymentState - if args.Get(0) != nil { - value = args.Get(0).(meterer.OnchainPaymentState) - } - return value, args.Error(1) + return args.Error(0) } -func (m *MockOnchainPaymentState) GetActiveReservations(ctx context.Context, blockNumber uint32) (map[string]core.ActiveReservation, error) { +func (m *MockOnchainPaymentState) GetActiveReservations(ctx context.Context) (map[string]core.ActiveReservation, error) { args := m.Called() var value map[string]core.ActiveReservation if args.Get(0) != nil { @@ -42,8 +38,8 @@ func (m *MockOnchainPaymentState) GetActiveReservations(ctx context.Context, blo return value, args.Error(1) } -func (m *MockOnchainPaymentState) GetActiveReservationsByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.ActiveReservation, error) { - args := m.Called() +func (m *MockOnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID string) (core.ActiveReservation, error) { + args := m.Called(ctx, accountID) var value core.ActiveReservation if args.Get(0) != nil { value = args.Get(0).(core.ActiveReservation) @@ -51,7 +47,7 @@ func (m *MockOnchainPaymentState) GetActiveReservationsByAccount(ctx context.Con return value, args.Error(1) } -func (m *MockOnchainPaymentState) GetOnDemandPayments(ctx context.Context, blockNumber uint32) (map[string]core.OnDemandPayment, error) { +func (m *MockOnchainPaymentState) GetOnDemandPayments(ctx context.Context) (map[string]core.OnDemandPayment, error) { args := m.Called() var value map[string]core.OnDemandPayment if args.Get(0) != nil { @@ -60,11 +56,20 @@ func (m *MockOnchainPaymentState) GetOnDemandPayments(ctx context.Context, block return value, args.Error(1) } -func (m *MockOnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.OnDemandPayment, error) { - args := m.Called() +func (m *MockOnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID string) (core.OnDemandPayment, error) { + args := m.Called(ctx, accountID) var value core.OnDemandPayment if args.Get(0) != nil { value = args.Get(0).(core.OnDemandPayment) } return value, args.Error(1) } + +func (m *MockOnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) { + args := m.Called() + var value []uint8 + if args.Get(0) != nil { + value = args.Get(0).([]uint8) + } + return value, args.Error(1) +} diff --git a/core/mock/tx.go b/core/mock/tx.go index 85669abe43..d0c54794ca 100644 --- a/core/mock/tx.go +++ b/core/mock/tx.go @@ -201,14 +201,26 @@ func (t *MockTransactor) PubkeyHashToOperator(ctx context.Context, operatorId co return result.(gethcommon.Address), args.Error(1) } -func (t *MockTransactor) GetActiveReservations(ctx context.Context, blockNumber uint32) (map[string]core.ActiveReservation, error) { +func (t *MockTransactor) GetActiveReservations(ctx context.Context, blockNumber uint32, accountIDs []string) (map[string]core.ActiveReservation, error) { args := t.Called() result := args.Get(0) return result.(map[string]core.ActiveReservation), args.Error(1) } -func (t *MockTransactor) GetOnDemandPayments(ctx context.Context, blockNumber uint32) (map[string]core.OnDemandPayment, error) { +func (t *MockTransactor) GetActiveReservationByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.ActiveReservation, error) { + args := t.Called() + result := args.Get(0) + return result.(core.ActiveReservation), args.Error(1) +} + +func (t *MockTransactor) GetOnDemandPayments(ctx context.Context, blockNumber uint32, accountIDs []string) (map[string]core.OnDemandPayment, error) { args := t.Called() result := args.Get(0) return result.(map[string]core.OnDemandPayment), args.Error(1) } + +func (t *MockTransactor) GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint32, accountID string) (core.OnDemandPayment, error) { + args := t.Called() + result := args.Get(0) + return result.(core.OnDemandPayment), args.Error(1) +} diff --git a/core/tx.go b/core/tx.go index ad05caade3..e61c5a7bc0 100644 --- a/core/tx.go +++ b/core/tx.go @@ -148,8 +148,14 @@ type Transactor interface { GetRequiredQuorumNumbers(ctx context.Context, blockNumber uint32) ([]QuorumID, error) // GetActiveReservations returns active reservations (end timestamp > current timestamp) - GetActiveReservations(ctx context.Context, blockNumber uint32) (map[string]ActiveReservation, error) + GetActiveReservations(ctx context.Context, blockNumber uint32, accountIDs []string) (map[string]ActiveReservation, error) + + // GetActiveReservations returns active reservations (end timestamp > current timestamp) + GetActiveReservationByAccount(ctx context.Context, blockNumber uint32, accountID string) (ActiveReservation, error) + + // GetOnDemandPayments returns all on-demand payments + GetOnDemandPayments(ctx context.Context, blockNumber uint32, accountIDs []string) (map[string]OnDemandPayment, error) // GetOnDemandPayments returns all on-demand payments - GetOnDemandPayments(ctx context.Context, blockNumber uint32) (map[string]OnDemandPayment, error) + GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint32, accountID string) (OnDemandPayment, error) }