Skip to content

Commit

Permalink
add callbacks on db tx
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaubennassar committed Sep 12, 2024
1 parent a212233 commit d17db59
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 18 deletions.
4 changes: 2 additions & 2 deletions bridgesync/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (p *processor) getLastProcessedBlockWithTx(tx db.DBer) (uint64, error) {
// Reorg triggers a purge and reset process on the processor to leaf it on a state
// as if the last block processed was firstReorgedBlock-1
func (p *processor) Reorg(ctx context.Context, firstReorgedBlock uint64) error {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := db.NewTx(ctx, p.db)
if err != nil {
return err
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func (p *processor) Reorg(ctx context.Context, firstReorgedBlock uint64) error {
// ProcessBlock process the events of the block to build the exit tree
// and updates the last processed block (can be called without events for that purpose)
func (p *processor) ProcessBlock(ctx context.Context, block sync.Block) error {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := db.NewTx(ctx, p.db)
if err != nil {
return err
}
Expand Down
49 changes: 49 additions & 0 deletions db/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package db

import (
"context"
"database/sql"
)

type Tx struct {
*sql.Tx
rollbackCallbacks []func()
commitCallbacks []func()
}

func NewTx(ctx context.Context, db *sql.DB) (*Tx, error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &Tx{
Tx: tx,
}, nil
}

func (s *Tx) AddRollbackCallback(cb func()) {
s.rollbackCallbacks = append(s.rollbackCallbacks, cb)
}
func (s *Tx) AddCommitCallback(cb func()) {
s.commitCallbacks = append(s.commitCallbacks, cb)
}

func (s *Tx) Commit() error {
if err := s.Tx.Commit(); err != nil {
return err
}
for _, cb := range s.commitCallbacks {
cb()
}
return nil
}

func (s *Tx) Rollback() error {
if err := s.Tx.Rollback(); err != nil {
return err
}
for _, cb := range s.rollbackCallbacks {
cb()
}
return nil
}
4 changes: 2 additions & 2 deletions l1infotreesync/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (p *processor) getLastProcessedBlockWithTx(tx db.DBer) (uint64, error) {
// Reorg triggers a purge and reset process on the processor to leaf it on a state
// as if the last block processed was firstReorgedBlock-1
func (p *processor) Reorg(ctx context.Context, firstReorgedBlock uint64) error {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := db.NewTx(ctx, p.db)
if err != nil {
return err
}
Expand Down Expand Up @@ -218,7 +218,7 @@ func (p *processor) Reorg(ctx context.Context, firstReorgedBlock uint64) error {
// ProcessBlock process the events of the block to build the rollup exit tree and the l1 info tree
// and updates the last processed block (can be called without events for that purpose)
func (p *processor) ProcessBlock(ctx context.Context, b sync.Block) error {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := db.NewTx(ctx, p.db)
if err != nil {
return err
}
Expand Down
6 changes: 4 additions & 2 deletions tree/appendonlytree.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"

"github.com/0xPolygon/cdk/db"
"github.com/0xPolygon/cdk/tree/types"
"github.com/ethereum/go-ethereum/common"
)
Expand All @@ -26,7 +27,7 @@ func NewAppendOnlyTree(db *sql.DB, dbPrefix string) *AppendOnlyTree {
}
}

func (t *AppendOnlyTree) AddLeaf(tx *sql.Tx, blockNum, blockPosition uint64, leaf types.Leaf) error {
func (t *AppendOnlyTree) AddLeaf(tx *db.Tx, blockNum, blockPosition uint64, leaf types.Leaf) error {
if int64(leaf.Index) != t.lastIndex+1 {
// rebuild cache
if err := t.initCache(tx); err != nil {
Expand Down Expand Up @@ -72,10 +73,11 @@ func (t *AppendOnlyTree) AddLeaf(tx *sql.Tx, blockNum, blockPosition uint64, lea
return err
}
t.lastIndex++
tx.AddRollbackCallback(func() { t.lastIndex-- })
return nil
}

func (t *AppendOnlyTree) initCache(tx *sql.Tx) error {
func (t *AppendOnlyTree) initCache(tx *db.Tx) error {
siblings := [types.DefaultHeight]common.Hash{}
lastRoot, err := t.getLastRootWithTx(tx)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions tree/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func generateZeroHashes(height uint8) []common.Hash {
return zeroHashes
}

func (t *Tree) storeNodes(tx db.DBer, nodes []types.TreeNode) error {
func (t *Tree) storeNodes(tx *db.Tx, nodes []types.TreeNode) error {
for _, node := range nodes {
if err := meddler.Insert(tx, t.rhtTable, &node); err != nil {
if sqliteErr, ok := db.SQLiteErr(err); ok {
Expand All @@ -167,7 +167,7 @@ func (t *Tree) storeNodes(tx db.DBer, nodes []types.TreeNode) error {
return nil
}

func (t *Tree) storeRoot(tx db.DBer, root types.Root) error {
func (t *Tree) storeRoot(tx *db.Tx, root types.Root) error {
return meddler.Insert(tx, t.rootTable, &root)
}

Expand Down Expand Up @@ -241,7 +241,7 @@ func (t *Tree) GetLeaf(ctx context.Context, index uint32, root common.Hash) (com
}

// Reorg deletes all the data relevant from firstReorgedBlock (includded) and onwards
func (t *Tree) Reorg(tx db.DBer, firstReorgedBlock uint64) error {
func (t *Tree) Reorg(tx *db.Tx, firstReorgedBlock uint64) error {
_, err := tx.Exec(
fmt.Sprintf(`DELETE FROM %s WHERE block_num >= $1`, t.rootTable),
firstReorgedBlock,
Expand Down
16 changes: 8 additions & 8 deletions tree/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ func TestMTAddLeaf(t *testing.T) {
log.Debug("DB created at: ", dbPath)
err := migrations.RunMigrations(dbPath)
require.NoError(t, err)
db, err := db.NewSQLiteDB(dbPath)
treeDB, err := db.NewSQLiteDB(dbPath)
require.NoError(t, err)
_, err = db.Exec(`select * from root`)
_, err = treeDB.Exec(`select * from root`)
require.NoError(t, err)
merkletree := tree.NewAppendOnlyTree(db, "")
merkletree := tree.NewAppendOnlyTree(treeDB, "")

// Add exisiting leaves
tx, err := db.BeginTx(ctx, nil)
tx, err := db.NewTx(ctx, treeDB)
require.NoError(t, err)
for i, leaf := range testVector.ExistingLeaves {
err = merkletree.AddLeaf(tx, uint64(i), 0, types.Leaf{
Expand All @@ -57,7 +57,7 @@ func TestMTAddLeaf(t *testing.T) {
}

// Add new bridge
tx, err = db.BeginTx(ctx, nil)
tx, err = db.NewTx(ctx, treeDB)
require.NoError(t, err)
err = merkletree.AddLeaf(tx, uint64(len(testVector.ExistingLeaves)), 0, types.Leaf{
Index: uint32(len(testVector.ExistingLeaves)),
Expand Down Expand Up @@ -87,11 +87,11 @@ func TestMTGetProof(t *testing.T) {
dbPath := path.Join(t.TempDir(), "file::memory:?cache=shared")
err := migrations.RunMigrations(dbPath)
require.NoError(t, err)
db, err := db.NewSQLiteDB(dbPath)
treeDB, err := db.NewSQLiteDB(dbPath)
require.NoError(t, err)
tre := tree.NewAppendOnlyTree(db, "")
tre := tree.NewAppendOnlyTree(treeDB, "")

tx, err := db.BeginTx(ctx, nil)
tx, err := db.NewTx(ctx, treeDB)
require.NoError(t, err)
for li, leaf := range testVector.Deposits {
err = tre.AddLeaf(tx, uint64(li), 0, types.Leaf{
Expand Down
3 changes: 2 additions & 1 deletion tree/updatabletree.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tree
import (
"database/sql"

"github.com/0xPolygon/cdk/db"
"github.com/0xPolygon/cdk/tree/types"
"github.com/ethereum/go-ethereum/common"
)
Expand All @@ -21,7 +22,7 @@ func NewUpdatableTree(db *sql.DB, dbPrefix string) *UpdatableTree {
return ut
}

func (t *UpdatableTree) UpsertLeaf(tx *sql.Tx, blockNum, blockPosition uint64, leaf types.Leaf) error {
func (t *UpdatableTree) UpsertLeaf(tx *db.Tx, blockNum, blockPosition uint64, leaf types.Leaf) error {
var rootHash common.Hash
root, err := t.getLastRootWithTx(tx)
if err != nil {
Expand Down

0 comments on commit d17db59

Please sign in to comment.