diff --git a/core/types/block.libevm_test.go b/core/types/block.libevm_test.go index 89a007b6aadbf..7ddf0ce6a3e27 100644 --- a/core/types/block.libevm_test.go +++ b/core/types/block.libevm_test.go @@ -17,15 +17,18 @@ package types_test import ( + "bytes" "encoding/json" "errors" "fmt" "io" + "math/big" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ava-labs/libevm/common" . "github.com/ava-labs/libevm/core/types" "github.com/ava-labs/libevm/crypto" "github.com/ava-labs/libevm/libevm/ethtest" @@ -199,3 +202,83 @@ func TestHeaderHooks(t *testing.T) { } }) } + +type stubBodyHooks struct { + required uint32 + optional uint32 + rlpFields rlp.Fields +} + +func (bh *stubBodyHooks) RLPFieldsForEncoding(body *Body) *rlp.Fields { + return &rlp.Fields{ + Required: bh.rlpFields.Required, + Optional: bh.rlpFields.Optional, + } +} + +func (bh *stubBodyHooks) RLPFieldPointersForDecoding(b *Body) *rlp.Fields { + return &rlp.Fields{ + Required: bh.rlpFields.Required, + Optional: bh.rlpFields.Optional, + } +} + +func Test_BodyHooks(t *testing.T) { + TestOnlyClearRegisteredExtras() + t.Cleanup(TestOnlyClearRegisteredExtras) + + extras := RegisterExtras[ + stubHeaderHooks, *stubHeaderHooks, + stubBodyHooks, *stubBodyHooks, + struct{}]() + + t.Run("RLP", func(t *testing.T) { + // Setup body with extras + testTx := NewTransaction(1, common.Address{2}, big.NewInt(3), 4, big.NewInt(5), []byte{6}) + body := &Body{ + Transactions: []*Transaction{testTx}, + Uncles: []*Header{{ParentHash: common.Hash{7}}}, // ignored + Withdrawals: []*Withdrawal{{Amount: 8}}, // ignored + } + bodyExtras := &stubBodyHooks{ + required: 9, + optional: 10, + } + bodyExtras.rlpFields.Required = []any{ + &body.Transactions, + &body.Uncles, + &bodyExtras.required, + } + bodyExtras.rlpFields.Optional = []any{ + // Withdrawals not present + &bodyExtras.optional, + } + extras.Body.Set(body, bodyExtras) + + // Check encoding + wantEncoded := bytes.NewBuffer(nil) + err := bodyExtras.rlpFields.EncodeRLP(wantEncoded) + require.NoError(t, err, "rlpFields.EncodeRLP(buffer)") + + gotEncoded, err := rlp.EncodeToBytes(body) + require.NoError(t, err, "rlp.EncodeToBytes(%T)", body) + assert.Equal(t, wantEncoded.Bytes(), gotEncoded) + + // Check decoding + wantDecoded := &Body{ + Transactions: []*Transaction{testTx}, + } + decodedBodyExtras := &stubBodyHooks{ + required: 9, + optional: 10, + rlpFields: bodyExtras.rlpFields, + } + extras.Body.Set(wantDecoded, decodedBodyExtras) + + gotDecoded := new(Body) + err = rlp.DecodeBytes(gotEncoded, gotDecoded) + require.NoErrorf(t, err, "rlp.DecodeBytes(%#v)", gotEncoded) + + assert.Equal(t, wantDecoded, gotDecoded) + }) +}