From 1bd102941cb8521cdfc357071909c34464655691 Mon Sep 17 00:00:00 2001 From: hopeyen Date: Wed, 9 Oct 2024 18:46:51 -0700 Subject: [PATCH] refactor: test mocks and size/price calculation --- core/auth/payment_metadata.go | 4 +- core/auth/payment_metadata_test.go | 8 +- core/meterer/meterer.go | 23 +++-- core/meterer/meterer_test.go | 134 +++++++++++++++++------------ core/meterer/onchain_state_test.go | 2 +- core/mock/payment_state.go | 4 +- 6 files changed, 104 insertions(+), 71 deletions(-) diff --git a/core/auth/payment_metadata.go b/core/auth/payment_metadata.go index c43b8a3e53..b31ca38e51 100644 --- a/core/auth/payment_metadata.go +++ b/core/auth/payment_metadata.go @@ -21,8 +21,8 @@ type EIP712Signer struct { } // NewEIP712Signer creates a new EIP712Signer instance -func NewEIP712Signer(chainID *big.Int, verifyingContract common.Address) *EIP712Signer { - return &EIP712Signer{ +func NewEIP712Signer(chainID *big.Int, verifyingContract common.Address) EIP712Signer { + return EIP712Signer{ domain: apitypes.TypedDataDomain{ Name: "EigenDA", Version: "1", diff --git a/core/auth/payment_metadata_test.go b/core/auth/payment_metadata_test.go index e9de77024d..0ddec89996 100644 --- a/core/auth/payment_metadata_test.go +++ b/core/auth/payment_metadata_test.go @@ -55,7 +55,7 @@ func TestConstructPaymentMetadata(t *testing.T) { require.NoError(t, err) header, err := auth.ConstructPaymentMetadata( - signer, + &signer, 0, // binIndex 1000, // cumulativePayment 1024, // dataLength @@ -86,7 +86,7 @@ func TestEIP712SignerWithDifferentKeys(t *testing.T) { require.NoError(t, err) header, err := auth.ConstructPaymentMetadata( - signer, + &signer, 0, 1000, 1024, @@ -117,7 +117,7 @@ func TestEIP712SignerWithModifiedHeader(t *testing.T) { require.NoError(t, err) header, err := auth.ConstructPaymentMetadata( - signer, + &signer, 0, 1000, 1024, @@ -152,7 +152,7 @@ func TestEIP712SignerWithDifferentChainID(t *testing.T) { require.NoError(t, err) header, err := auth.ConstructPaymentMetadata( - signer1, + &signer1, 0, 1000, 1024, diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index 04413924fa..b6a925c0a8 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -26,6 +26,10 @@ type Config struct { // ChainReadTimeout is the timeout for reading payment state from chain ChainReadTimeout time.Duration + // ChainID indicate the network in which meterer(payment) is handled + ChainID *big.Int + // VerifyingContract is the address of the PaymentVault contract that verifies signatures + VerifyingContract common.Address } // Meterer handles payment accounting across different accounts. Disperser API server receives requests from clients and each request contains a blob header @@ -39,6 +43,7 @@ type Meterer struct { // OffchainStore uses DynamoDB to track metering and used to validate requests OffchainStore OffchainStore + signer auth.EIP712Signer logger logging.Logger } @@ -49,12 +54,14 @@ func NewMeterer( logger logging.Logger, ) (*Meterer, error) { // TODO: create a separate thread to pull from the chain and update chain state + return &Meterer{ Config: config, ChainState: paymentChainState, OffchainStore: offchainStore, + signer: auth.NewEIP712Signer(config.ChainID, config.VerifyingContract), logger: logger.With("component", "Meterer"), }, nil } @@ -69,6 +76,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata) // Validate against the payment method if header.CumulativePayment == 0 { + fmt.Println("reservation: ", header.AccountID) reservation, err := m.ChainState.GetActiveReservationByAccount(ctx, header.AccountID) if err != nil { return fmt.Errorf("failed to get active reservation by account: %w", err) @@ -93,11 +101,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata) // ValidateSignature checks if the signature is valid against all other fields in the header // Assuming the signature is an eip712 signature func (m *Meterer) ValidateSignature(ctx context.Context, header core.PaymentMetadata) error { - // Create the EIP712Signer - //TODO: update the chainID and verifyingContract - signer := auth.NewEIP712Signer(big.NewInt(17000), common.HexToAddress("0x1234000000000000000000000000000000000000")) - - recoveredAddress, err := signer.RecoverSender(&header) + recoveredAddress, err := m.signer.RecoverSender(&header) if err != nil { return fmt.Errorf("failed to recover sender: %w", err) } @@ -251,9 +255,14 @@ func (m *Meterer) PaymentCharged(dataLength uint32) uint64 { return uint64(m.SymbolsCharged(dataLength)) * uint64(m.PricePerSymbol) } -// SymbolsCharged returns the chargeable data length for a given data length +// SymbolsCharged returns the number of symbols charged for a given data length +// being at least MinNumSymbols or the nearest rounded-up multiple of MinNumSymbols. func (m *Meterer) SymbolsCharged(dataLength uint32) uint32 { - return uint32(max(dataLength, m.MinNumSymbols)) + if dataLength <= m.MinNumSymbols { + return m.MinNumSymbols + } + // Round up to the nearest multiple of MinNumSymbols + return uint32(core.RoundUpDivide(uint(dataLength), uint(m.MinNumSymbols))) * m.MinNumSymbols } // ValidateBinIndex checks if the provided bin index is valid diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 81faa75ed7..341eca1fe6 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -3,6 +3,7 @@ package meterer_test import ( "context" "crypto/ecdsa" + "errors" "fmt" "math/big" "os" @@ -23,19 +24,26 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/assert" + testifymock "github.com/stretchr/testify/mock" "github.com/Layr-Labs/eigensdk-go/logging" ) var ( - dockertestPool *dockertest.Pool - dockertestResource *dockertest.Resource - dynamoClient *commondynamodb.Client - clientConfig commonaws.ClientConfig - privateKey1 *ecdsa.PrivateKey - privateKey2 *ecdsa.PrivateKey - signer *auth.EIP712Signer - mt *meterer.Meterer + dockertestPool *dockertest.Pool + dockertestResource *dockertest.Resource + dynamoClient *commondynamodb.Client + clientConfig commonaws.ClientConfig + privateKey1 *ecdsa.PrivateKey + privateKey2 *ecdsa.PrivateKey + account1 string + account1Reservations core.ActiveReservation + account1OnDemandPayments core.OnDemandPayment + account2 string + account2Reservations core.ActiveReservation + account2OnDemandPayments core.OnDemandPayment + signer auth.EIP712Signer + mt *meterer.Meterer deployLocalStack bool localStackPort = "4566" @@ -49,20 +57,6 @@ func TestMain(m *testing.M) { os.Exit(code) } -// // Mock data initialization method -// func InitializeMockPayments(pcs *meterer.OnchainPaymentState, privateKey1 *ecdsa.PrivateKey, privateKey2 *ecdsa.PrivateKey) { -// // Initialize mock active reservations -// now := uint64(time.Now().Unix()) -// pcs.ActiveReservations.Reservations = map[string]*meterer.ActiveReservation{ -// crypto.PubkeyToAddress(privateKey1.PublicKey).Hex(): {BytesPerSec: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplit: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, -// crypto.PubkeyToAddress(privateKey2.PublicKey).Hex(): {BytesPerSec: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplit: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}}, -// } -// pcs.OnDemandPayments.Payments = map[string]*meterer.OnDemandPayment{ -// crypto.PubkeyToAddress(privateKey1.PublicKey).Hex(): {CumulativePayment: 1500}, -// crypto.PubkeyToAddress(privateKey2.PublicKey).Hex(): {CumulativePayment: 1000}, -// } -// } - func setup(_ *testing.M) { deployLocalStack = !(os.Getenv("DEPLOY_LOCALSTACK") == "false") @@ -99,10 +93,6 @@ func setup(_ *testing.M) { panic("failed to create dynamodb client") } - chainID := big.NewInt(17000) - verifyingContract := gethcommon.HexToAddress("0x1234000000000000000000000000000000000000") - signer = auth.NewEIP712Signer(chainID, verifyingContract) - privateKey1, err = crypto.GenerateKey() if err != nil { teardown() @@ -115,12 +105,20 @@ func setup(_ *testing.M) { } logger = logging.NewNoopLogger() + signer = auth.NewEIP712Signer(big.NewInt(17000), gethcommon.HexToAddress("0x1234000000000000000000000000000000000000")) + + if err != nil { + teardown() + panic("failed to create EIP712 signer") + } config := meterer.Config{ PricePerSymbol: 1, MinNumSymbols: 1, GlobalSymbolsPerSecond: 1000, ReservationWindow: 1, ChainReadTimeout: 3 * time.Second, + ChainID: big.NewInt(17000), + VerifyingContract: gethcommon.HexToAddress("0x1234000000000000000000000000000000000000"), } err = meterer.CreateReservationTable(clientConfig, "reservations") @@ -139,6 +137,16 @@ func setup(_ *testing.M) { panic("failed to create global reservation table") } + now := uint64(time.Now().Unix()) + account1 = crypto.PubkeyToAddress(privateKey1.PublicKey).Hex() + fmt.Println("account1", account1) + account2 = crypto.PubkeyToAddress(privateKey2.PublicKey).Hex() + fmt.Println("account2", account2) + account1Reservations = core.ActiveReservation{SymbolsPerSec: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplit: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}} + account2Reservations = core.ActiveReservation{SymbolsPerSec: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplit: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} + account1OnDemandPayments = core.OnDemandPayment{CumulativePayment: 1500} + account2OnDemandPayments = core.OnDemandPayment{CumulativePayment: 1000} + store, err := meterer.NewOffchainStore( clientConfig, "reservations", @@ -160,8 +168,6 @@ func setup(_ *testing.M) { // metrics.NewNoopMetrics(), ) - // InitializeMockPayments(paymentChainState, privateKey1, privateKey2) - if err != nil { teardown() panic("failed to create meterer") @@ -179,6 +185,13 @@ func TestMetererReservations(t *testing.T) { meterer.CreateReservationTable(clientConfig, "reservations") binIndex := meterer.GetBinIndex(uint64(time.Now().Unix()), mt.ReservationWindow) quoromNumbers := []uint8{0, 1} + paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == account1 + })).Return(account1Reservations, nil) + paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == account2 + })).Return(account2Reservations, nil) + paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(core.ActiveReservation{}, errors.New("reservation not found")) // test invalid signature invalidHeader := &core.PaymentMetadata{ @@ -193,7 +206,9 @@ func TestMetererReservations(t *testing.T) { assert.ErrorContains(t, err, "invalid signature: recovered address") // test invalid quorom ID - header, err := auth.ConstructPaymentMetadata(signer, 1, 0, 1000, []uint8{0, 1, 2}, privateKey1) + header, err := auth.ConstructPaymentMetadata(&signer, 1, 0, 1000, []uint8{0, 1, 2}, privateKey1) + fmt.Println("--- this header test invalid quorum ID ---") + fmt.Println("header", header) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "quorum number mismatch") @@ -203,13 +218,13 @@ func TestMetererReservations(t *testing.T) { if err != nil { t.Fatalf("Failed to generate key: %v", err) } - header, err = auth.ConstructPaymentMetadata(signer, 1, 0, 1000, []uint8{0, 1, 2}, unregisteredUser) + header, err = auth.ConstructPaymentMetadata(&signer, 1, 0, 1000, []uint8{0, 1, 2}, unregisteredUser) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found") // test invalid bin index - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 0, 2000, quoromNumbers, privateKey1) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 0, 2000, quoromNumbers, privateKey1) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "invalid bin index for reservation") @@ -218,7 +233,7 @@ func TestMetererReservations(t *testing.T) { accountID := crypto.PubkeyToAddress(privateKey2.PublicKey).Hex() for i := 0; i < 9; i++ { dataLength := 20 - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 0, uint32(dataLength), quoromNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 0, uint32(dataLength), quoromNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.NoError(t, err) @@ -233,7 +248,7 @@ func TestMetererReservations(t *testing.T) { } // frist over flow is allowed - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 0, 25, quoromNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 0, 25, quoromNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.NoError(t, err) @@ -248,13 +263,13 @@ func TestMetererReservations(t *testing.T) { assert.Equal(t, strconv.Itoa(int(5)), item["BinUsage"].(*types.AttributeValueMemberN).Value) // second over flow - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 0, 1, quoromNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 0, 1, quoromNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "bin has already been filled") // overwhelming bin overflow for empty bins (assuming all previous requests happened within 1 reservation window) - header, err = auth.ConstructPaymentMetadata(signer, binIndex-1, 0, 1000, quoromNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex-1, 0, 1000, quoromNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "overflow usage exceeds bin limit") @@ -267,6 +282,15 @@ func TestMetererOnDemand(t *testing.T) { quorumNumbers := []uint8{0, 1} binIndex := uint32(0) // this field doesn't matter for on-demand payments wrt global rate limit + paymentChainState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == account1 + })).Return(account1OnDemandPayments, nil) + paymentChainState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool { + return account == account2 + })).Return(account2OnDemandPayments, nil) + paymentChainState.On("GetOnDemandPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(core.OnDemandPayment{}, errors.New("payment not found")) + paymentChainState.On("GetOnDemandQuorumNumbers", testifymock.Anything).Return(quorumNumbers, nil) + // test invalid signature invalidHeader := &core.PaymentMetadata{ AccountID: crypto.PubkeyToAddress(privateKey1.PublicKey).Hex(), @@ -284,19 +308,19 @@ func TestMetererOnDemand(t *testing.T) { if err != nil { t.Fatalf("Failed to generate key: %v", err) } - header, err := auth.ConstructPaymentMetadata(signer, 1, 1, 1000, quorumNumbers, unregisteredUser) + header, err := auth.ConstructPaymentMetadata(&signer, 1, 1, 1000, quorumNumbers, unregisteredUser) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "failed to get on-demand payment by account: payment not found") // test invalid quorom ID - header, err = auth.ConstructPaymentMetadata(signer, 1, 1, 1000, []uint8{0, 1, 2}, privateKey1) + header, err = auth.ConstructPaymentMetadata(&signer, 1, 1, 1000, []uint8{0, 1, 2}, privateKey1) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "invalid quorum for On-Demand Request") // test insufficient cumulative payment - header, err = auth.ConstructPaymentMetadata(signer, 0, 1, 2000, quorumNumbers, privateKey1) + header, err = auth.ConstructPaymentMetadata(&signer, 0, 1, 2000, quorumNumbers, privateKey1) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) @@ -310,34 +334,34 @@ func TestMetererOnDemand(t *testing.T) { assert.Equal(t, 1, len(result)) // test duplicated cumulative payments - header, err = auth.ConstructPaymentMetadata(signer, binIndex, uint64(100), 100, quorumNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, uint64(100), 100, quorumNumbers, privateKey2) err = mt.MeterRequest(ctx, *header) assert.NoError(t, err) - header, err = auth.ConstructPaymentMetadata(signer, binIndex, uint64(100), 100, quorumNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, uint64(100), 100, quorumNumbers, privateKey2) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "exact payment already exists") // test valid payments for i := 1; i < 9; i++ { - header, err = auth.ConstructPaymentMetadata(signer, binIndex, uint64(100*(i+1)), 100, quorumNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, uint64(100*(i+1)), 100, quorumNumbers, privateKey2) err = mt.MeterRequest(ctx, *header) assert.NoError(t, err) } // test cumulative payment on-chain constraint - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 1001, 1, quorumNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 1001, 1, quorumNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "invalid on-demand payment: request claims a cumulative payment greater than the on-chain deposit") // test insufficient increment in cumulative payment - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 901, 2, quorumNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 901, 2, quorumNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "invalid on-demand payment: insufficient cumulative payment increment") // test cannot insert cumulative payment in out of order - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 50, 50, quorumNumbers, privateKey2) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 50, 50, quorumNumbers, privateKey2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "invalid on-demand payment: breaking cumulative payment invariants") @@ -350,7 +374,7 @@ func TestMetererOnDemand(t *testing.T) { assert.NoError(t, err) assert.Equal(t, numPrevRecords, len(result)) // test failed global rate limit - header, err = auth.ConstructPaymentMetadata(signer, binIndex, 1002, 1001, quorumNumbers, privateKey1) + header, err = auth.ConstructPaymentMetadata(&signer, binIndex, 1002, 1001, quorumNumbers, privateKey1) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header) assert.ErrorContains(t, err, "failed global rate limiting") @@ -374,37 +398,37 @@ func TestMeterer_paymentCharged(t *testing.T) { { name: "Data length equal to min chargeable size", dataLength: 1024, - pricePerSymbol: 100, + pricePerSymbol: 1, minNumSymbols: 1024, - expected: 100, + expected: 1024, }, { name: "Data length less than min chargeable size", dataLength: 512, - pricePerSymbol: 100, + pricePerSymbol: 2, minNumSymbols: 1024, - expected: 100, + expected: 2048, }, { name: "Data length greater than min chargeable size", dataLength: 2048, - pricePerSymbol: 100, + pricePerSymbol: 1, minNumSymbols: 1024, - expected: 200, + expected: 2048, }, { name: "Large data length", dataLength: 1 << 20, // 1 MB - pricePerSymbol: 100, + pricePerSymbol: 1, minNumSymbols: 1024, - expected: 102400, + expected: 1 << 20, }, { name: "Price not evenly divisible by min chargeable size", dataLength: 1536, - pricePerSymbol: 150, + pricePerSymbol: 1, minNumSymbols: 1024, - expected: 225, + expected: 2048, }, } diff --git a/core/meterer/onchain_state_test.go b/core/meterer/onchain_state_test.go index b485c1284a..3f29f52d5e 100644 --- a/core/meterer/onchain_state_test.go +++ b/core/meterer/onchain_state_test.go @@ -73,7 +73,7 @@ func TestGetActiveReservations(t *testing.T) { func TestGetActiveReservationByAccount(t *testing.T) { mockState := &mock.MockOnchainPaymentState{} ctx := context.Background() - mockState.On("GetActiveReservationsByAccount", testifymock.Anything, testifymock.Anything, testifymock.Anything).Return(dummyActiveReservation, nil) + mockState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(dummyActiveReservation, nil) reservation, err := mockState.GetActiveReservationByAccount(ctx, "account1") assert.NoError(t, err) diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index 8b1c957e0e..9f0131f5b4 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -43,7 +43,7 @@ func (m *MockOnchainPaymentState) GetActiveReservations(ctx context.Context) (ma } func (m *MockOnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID string) (core.ActiveReservation, error) { - args := m.Called() + args := m.Called(ctx, accountID) var value core.ActiveReservation if args.Get(0) != nil { value = args.Get(0).(core.ActiveReservation) @@ -61,7 +61,7 @@ func (m *MockOnchainPaymentState) GetOnDemandPayments(ctx context.Context) (map[ } func (m *MockOnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID string) (core.OnDemandPayment, error) { - args := m.Called() + args := m.Called(ctx, accountID) var value core.OnDemandPayment if args.Get(0) != nil { value = args.Get(0).(core.OnDemandPayment)