From 05f9379ddef36536e810a0a6ea6b29d261527323 Mon Sep 17 00:00:00 2001 From: hopeyen Date: Mon, 21 Oct 2024 17:35:27 -0700 Subject: [PATCH] feat: add on-chain write lock --- core/meterer/onchain_state.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index d5523a9b3b..b351966197 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -3,6 +3,7 @@ package meterer import ( "context" "errors" + "sync" "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigenda/core/eth" @@ -26,6 +27,8 @@ type OnchainPaymentState struct { ActiveReservations map[string]core.ActiveReservation OnDemandPayments map[string]core.OnDemandPayment OnDemandQuorumNumbers []uint8 + ReservationsLock sync.RWMutex + PaymentsLock sync.RWMutex } func NewOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (OnchainPaymentState, error) { @@ -54,6 +57,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, return err } + pcs.ReservationsLock.Lock() accountIDs := make([]string, 0, len(pcs.ActiveReservations)) for accountID := range pcs.ActiveReservations { accountIDs = append(accountIDs, accountID) @@ -64,7 +68,9 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, return err } pcs.ActiveReservations = activeReservations + pcs.ReservationsLock.Unlock() + pcs.PaymentsLock.Lock() accountIDs = make([]string, 0, len(pcs.OnDemandPayments)) for accountID := range pcs.OnDemandPayments { accountIDs = append(accountIDs, accountID) @@ -75,6 +81,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, return err } pcs.OnDemandPayments = onDemandPayments + pcs.PaymentsLock.Unlock() return nil } @@ -94,7 +101,9 @@ func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Contex return core.ActiveReservation{}, errors.New("payment not found") } + pcs.ReservationsLock.Lock() pcs.ActiveReservations[accountID] = res + pcs.ReservationsLock.Unlock() return res, nil } @@ -113,7 +122,9 @@ func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, return core.OnDemandPayment{}, errors.New("payment not found") } + pcs.PaymentsLock.Lock() pcs.OnDemandPayments[accountID] = res + pcs.PaymentsLock.Unlock() return res, nil }