Skip to content

Commit

Permalink
make callback function for blockContext
Browse files Browse the repository at this point in the history
  • Loading branch information
rabbitprincess committed Dec 5, 2023
1 parent e265ab9 commit 4217719
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 40 deletions.
2 changes: 1 addition & 1 deletion chain/chainhandle.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ func NewTxExecutor(execCtx context.Context, ccc consensus.ChainConsensusCluster,
return ErrInvalidBlockHeader
}
blockSnap := bState.Snapshot()
evmService := evm.NewEVM(bState.GetEvmRoot(), bState.EvmStateDB)
evmService := evm.NewEVM(bState.GetEvmRoot(), cdb, bState.LuaStateDB, bState.EvmStateDB)
err := executeTx(execCtx, ccc, cdb, bState, tx, bi, preloadService, evmService)
if err != nil {
logger.Error().Err(err).Str("hash", base58.Encode(tx.GetHash())).Msg("tx failed")
Expand Down
4 changes: 2 additions & 2 deletions chain/chainhandle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestErrorInExecuteTx(t *testing.T) {
initTest(t, true)
defer deinitTest()
bs := state.NewBlockState(sdb.GetStateDB(), sdb.OpenEvmStateDB(nil))
evmService := evm.NewEVM(bs.GetEvmRoot(), bs.EvmStateDB)
evmService := evm.NewEVM(bs.GetEvmRoot(), nil, bs.LuaStateDB, bs.EvmStateDB)

tx := &types.Tx{}
err := executeTx(nil, nil, nil, bs, types.NewTransaction(tx), newTestBlockInfo(chainID), contract.ChainService, evmService)
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestBasicExecuteTx(t *testing.T) {
initTest(t, true)
defer deinitTest()
bs := state.NewBlockState(sdb.GetStateDB(), sdb.OpenEvmStateDB(nil))
evmService := evm.NewEVM(bs.GetEvmRoot(), bs.EvmStateDB)
evmService := evm.NewEVM(bs.GetEvmRoot(), nil, bs.LuaStateDB, bs.EvmStateDB)

tx := &types.Tx{Body: &types.TxBody{}}

Expand Down
2 changes: 1 addition & 1 deletion chain/chainservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ func (cw *ChainWorker) Receive(context actor.Context) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
logger.Info().Msgf("evm query received for contract %s with payload %s", hex.EncodeToString(msg.Contract), hex.EncodeToString(msg.Queryinfo))
evmService := evm.NewEVMQuery(cw.sdb.EvmRootHash, cw.sdb.OpenEvmStateDB(nil))
evmService := evm.NewEVMQuery(cw.cdb, cw.sdb.EvmRootHash, cw.sdb.OpenNewStateDB(nil), cw.sdb.OpenEvmStateDB(nil))
res, _, err := evmService.Query(nil, msg.Contract, msg.Queryinfo)
context.Respond(message.GetEVMQueryRsp{Result: res, Err: err})
case *message.GetElected:
Expand Down
65 changes: 39 additions & 26 deletions evm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"github.com/aergoio/aergo-lib/log"
"github.com/aergoio/aergo/v2/state"
"github.com/aergoio/aergo/v2/state/ethdb"
"github.com/aergoio/aergo/v2/state/statedb"
"github.com/aergoio/aergo/v2/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/vm"
)
Expand All @@ -15,26 +17,33 @@ var (
logger = log.NewLogger("evm")
)

type EVM struct {
accounts map[common.Address]*state.AccountState
blocks map[uint64][]byte
type ChainAccessor interface {
GetBlockByNo(blockNo types.BlockNo) (*types.Block, error)
GetBestBlock() (*types.Block, error)
}

readonly bool
ethState *ethdb.StateDB
stateRoot common.Hash
type EVM struct {
readonly bool
chainAccessor ChainAccessor
luaState *statedb.StateDB
ethState *ethdb.StateDB
stateRoot common.Hash
}

func NewEVM(prevStateRoot []byte, ethState *ethdb.StateDB) *EVM {
func NewEVM(prevStateRoot []byte, chainAccessor ChainAccessor, luaState *statedb.StateDB, ethState *ethdb.StateDB) *EVM {
return &EVM{
readonly: false,
stateRoot: common.BytesToHash(prevStateRoot),
ethState: ethState,
readonly: false,
chainAccessor: chainAccessor,
stateRoot: common.BytesToHash(prevStateRoot),
luaState: luaState,
ethState: ethState,
}
}

func NewEVMQuery(queryStateRoot []byte, ethState *ethdb.StateDB) *EVM {
func NewEVMQuery(chainAccessor ChainAccessor, queryStateRoot []byte, luaState *statedb.StateDB, ethState *ethdb.StateDB) *EVM {
return &EVM{
readonly: true,
luaState: nil,
stateRoot: common.BytesToHash(queryStateRoot),
ethState: ethState,
}
Expand Down Expand Up @@ -119,33 +128,37 @@ func (e *EVM) Create(ethAddress common.Address, payload []byte) ([]byte, []byte,

func (e *EVM) GetHashFn() vm.GetHashFunc {
return func(n uint64) common.Hash {
blockHash := e.blocks[n]
return common.BytesToHash(blockHash)
block, err := e.chainAccessor.GetBlockByNo(n)
if err != nil {
return common.Hash{}
}
return common.BytesToHash(block.Hash)
}
}

func (e *EVM) TransferFn() vm.TransferFunc {
return func(db vm.StateDB, sender, recipient common.Address, amount *big.Int) {
if senderState := e.accounts[sender]; senderState != nil {
senderState.SubBalance(amount)
} else {
// TODO - get from state
senderAccState, err := state.GetAccountState(e.ethState.GetId(sender), e.luaState, e.ethState)
if err != nil {
panic("impossible") // FIXME
}
receipientAccState, err := state.GetAccountState(e.ethState.GetId(recipient), e.luaState, e.ethState)
if err != nil {
panic("impossible") // FIXME
}
if receipientState := e.accounts[recipient]; receipientState != nil {
receipientState.AddBalance(amount)
} else {
// TODO - get from state
err = state.SendBalance(senderAccState, receipientAccState, amount)
if err != nil {
panic("impossible") // FIXME
}
}
}

func (e *EVM) CanTransferFn() vm.CanTransferFunc {
return func(sdb vm.StateDB, addr common.Address, amount *big.Int) bool {
if state := e.accounts[addr]; state != nil {
return state.Balance().Cmp(amount) >= 0
} else {
// TODO - get from state
accState, err := state.GetAccountState(e.ethState.GetId(addr), e.luaState, e.ethState)
if err != nil {
panic("impossible") // FIXME
}
return false
return accState.Balance().Cmp(amount) >= 0
}
}
2 changes: 1 addition & 1 deletion state/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (as *AccountState) PutState() error {
return err
}
if as.ethStates != nil {
as.ethStates.PutState(as.ethId, new(big.Int).SetBytes(as.newState.Balance), as.newState.Nonce, nil)
as.ethStates.PutState(as.id, as.ethId, new(big.Int).SetBytes(as.newState.Balance), as.newState.Nonce, nil)
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion state/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (sdb *ChainStateDB) Init(dbType string, dataDir string, bestBlock *types.Bl
}

if sdb.ethStore == nil {
dbPath := common.PathMkdirAll(dataDir, "state_evm")
dbPath := common.PathMkdirAll(dataDir, ethdb.StateName)
sdb.ethStore, err = ethdb.NewDB(dbPath, dbType)
if err != nil {
return err
Expand Down
29 changes: 22 additions & 7 deletions state/ethdb/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package ethdb
import (
"math/big"

"github.com/aergoio/aergo/v2/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/trie"
)

const (
StateName = "evm_state"
StateName = "state_evm"
)

type StateDB struct {
Expand Down Expand Up @@ -43,16 +44,30 @@ func (sdb *StateDB) GetStateDB() *state.StateDB {
return sdb.evmStateDB
}

func (sdb *StateDB) PutState(addr common.Address, balance *big.Int, nonce uint64, code []byte) {
func (sdb *StateDB) PutState(id []byte, addr common.Address, balance *big.Int, nonce uint64, code []byte) {
sdb.evmStateDB.SetNonce(addr, nonce)
sdb.evmStateDB.SetBalance(addr, balance)
if len(code) > 0 {
sdb.evmStateDB.SetCode(addr, code)
}

// id must be 33 bytes
idWithCode := make([]byte, types.AddressLength+len(code))
copy(idWithCode, id)
copy(idWithCode[types.AddressLength:], code)

sdb.evmStateDB.SetCode(addr, idWithCode)
}

func (sdb *StateDB) GetState(addr common.Address) (id []byte, balance *big.Int, nonce uint64, code []byte) {
idWithCode := sdb.evmStateDB.GetCode(addr)
id = idWithCode[:types.AddressLength]
balance = sdb.evmStateDB.GetBalance(addr)
nonce = sdb.evmStateDB.GetNonce(addr)
code = idWithCode[types.AddressLength:]
return id, balance, nonce, code
}

func (sdb *StateDB) GetState(addr common.Address) (balance *big.Int, nonce uint64, code []byte) {
return sdb.evmStateDB.GetBalance(addr), sdb.evmStateDB.GetNonce(addr), sdb.evmStateDB.GetCode(addr)
func (sdb *StateDB) GetId(addr common.Address) (id []byte) {
idWithCode := sdb.evmStateDB.GetCode(addr)
return idWithCode[:types.AddressLength]
}

func (sdb *StateDB) Root() []byte {
Expand Down
2 changes: 1 addition & 1 deletion state/ethdb/statedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestState(t *testing.T) {
balance := big.NewInt(100)
nonce := uint64(0)
code := []byte("code")
sdbOld.PutState(addr, balance, nonce, code)
sdbOld.PutState(nil, addr, balance, nonce, code)

newRoot, err := sdbOld.Commit(0)
require.NoError(t, err)
Expand Down

0 comments on commit 4217719

Please sign in to comment.