Skip to content
This repository has been archived by the owner on Jan 9, 2025. It is now read-only.

Commit

Permalink
Use model in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter committed Oct 1, 2024
1 parent 534a34e commit 0339d7f
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 62 deletions.
10 changes: 5 additions & 5 deletions kakarot_scripts/utils/kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from eth_abi import decode
from eth_abi.exceptions import InsufficientDataBytes
from eth_account import Account as EvmAccount
from eth_account.typed_transactions import TypedTransaction
from eth_keys import keys
from eth_utils import keccak
from eth_utils.address import to_checksum_address
Expand Down Expand Up @@ -48,7 +47,8 @@
from kakarot_scripts.utils.starknet import wait_for_transaction
from kakarot_scripts.utils.uint256 import int_to_uint256
from tests.utils.constants import TRANSACTION_GAS_LIMIT
from tests.utils.helpers import pack_calldata, rlp_encode_signed_data
from tests.utils.helpers import pack_calldata
from tests.utils.models import TransactionModel

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -596,10 +596,10 @@ async def eth_send_transaction(
"data": data,
}

typed_transaction = TypedTransaction.from_dict(payload)
typed_transaction = TransactionModel.model_validate(payload)

evm_tx = EvmAccount.sign_transaction(
typed_transaction.as_dict(),
typed_transaction.model_dump(),
hex(evm_account.signer.private_key),
)

Expand All @@ -610,7 +610,7 @@ async def eth_send_transaction(
)
return receipt, [], receipt.status, receipt.gasUsed

encoded_unsigned_tx = rlp_encode_signed_data(typed_transaction.as_dict())
encoded_unsigned_tx = typed_transaction.encode()
packed_encoded_unsigned_tx = pack_calldata(bytes(encoded_unsigned_tx))
return await send_starknet_transaction(
evm_account,
Expand Down
9 changes: 5 additions & 4 deletions tests/src/kakarot/accounts/test_account_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from kakarot_scripts.utils.uint256 import int_to_uint256
from tests.utils.constants import CHAIN_ID, TRANSACTIONS
from tests.utils.errors import cairo_error
from tests.utils.helpers import generate_random_private_key, rlp_encode_signed_data
from tests.utils.helpers import generate_random_private_key
from tests.utils.hints import patch_hint
from tests.utils.models import TransactionModel
from tests.utils.syscall_handler import SyscallHandler

CHAIN_ID_OFFSET = 35
Expand Down Expand Up @@ -376,7 +377,7 @@ def test_should_raise_invalid_signature_for_invalid_chain_id_when_tx_type0_not_p
"chainId": CHAIN_ID,
"data": b"",
}
tx_data = list(rlp_encode_signed_data(transaction))
tx_data = TransactionModel.model_validate(transaction).encode()
private_key = generate_random_private_key()
address = int(private_key.public_key.to_checksum_address(), 16)
signed = Account.sign_transaction(transaction, private_key)
Expand Down Expand Up @@ -454,7 +455,7 @@ def test_pass_all_transactions_types(self, cairo_run, seed, transaction):
address = int(private_key.public_key.to_checksum_address(), 16)
signed = Account.sign_transaction(transaction, private_key)
signature = [*int_to_uint256(signed.r), *int_to_uint256(signed.s), signed.v]
tx_data = list(rlp_encode_signed_data(transaction))
tx_data = TransactionModel.model_validate(transaction).encode()

with (
SyscallHandler.patch("Account_evm_address", address),
Expand Down Expand Up @@ -490,7 +491,7 @@ def test_should_pass_all_data_len(self, cairo_run, bytecode):
"chainId": CHAIN_ID,
"data": bytecode,
}
tx_data = list(rlp_encode_signed_data(transaction))
tx_data = TransactionModel.model_validate(transaction).encode()

private_key = generate_random_private_key()
address = int(private_key.public_key.to_checksum_address(), 16)
Expand Down
20 changes: 10 additions & 10 deletions tests/src/kakarot/test_kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from kakarot_scripts.ef_tests.fetch import EF_TESTS_PARSED_DIR
from tests.utils.constants import CHAIN_ID, TRANSACTION_GAS_LIMIT, TRANSACTIONS
from tests.utils.errors import cairo_error
from tests.utils.helpers import felt_to_signed_int, rlp_encode_signed_data
from tests.utils.helpers import felt_to_signed_int
from tests.utils.models import TransactionModel
from tests.utils.syscall_handler import SyscallHandler, parse_state

CONTRACT_ADDRESS = 1234
Expand Down Expand Up @@ -520,7 +521,7 @@ def test_should_raise_invalid_chain_id_tx_type_different_from_0(
"accessList": [],
"chainId": 9999,
}
tx_data = list(rlp_encode_signed_data(transaction))
tx_data = TransactionModel.model_validate(transaction).encode()

with cairo_error(message="Invalid chain id"):
cairo_run(
Expand All @@ -533,8 +534,7 @@ def test_should_raise_invalid_chain_id_tx_type_different_from_0(
@pytest.mark.parametrize("tx", TRANSACTIONS)
def test_should_raise_invalid_nonce(self, cairo_run, tx):
# explicitly set the nonce in transaction to be different from the patch
tx = {**tx, "nonce": 0}
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate({**tx, "nonce": 0}).encode()
with cairo_error(message="Invalid nonce"):
cairo_run(
"test__eth_send_raw_unsigned_tx",
Expand All @@ -557,7 +557,7 @@ def test_raise_gas_limit_too_high(self, cairo_run, gas_limit):
"accessList": [],
"chainId": CHAIN_ID,
}
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate(tx).encode()

with cairo_error(message="Gas limit too high"):
cairo_run(
Expand All @@ -581,7 +581,7 @@ def test_raise_max_fee_per_gas_too_high(self, cairo_run, maxFeePerGas):
"accessList": [],
"chainId": CHAIN_ID,
}
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate(tx).encode()

with cairo_error(message="Max fee per gas too high"):
cairo_run(
Expand All @@ -592,7 +592,7 @@ def test_raise_max_fee_per_gas_too_high(self, cairo_run, maxFeePerGas):

@pytest.mark.parametrize("tx", TRANSACTIONS)
def test_raise_transaction_gas_limit_too_high(self, cairo_run, tx):
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate(tx).encode()

with (
SyscallHandler.patch("IAccount.get_nonce", lambda _, __: [tx["nonce"]]),
Expand All @@ -608,7 +608,7 @@ def test_raise_transaction_gas_limit_too_high(self, cairo_run, tx):
@SyscallHandler.patch("Kakarot_base_fee", TRANSACTION_GAS_LIMIT * 10**10)
@pytest.mark.parametrize("tx", TRANSACTIONS)
def test_raise_max_fee_per_gas_too_low(self, cairo_run, tx):
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate(tx).encode()

with (
SyscallHandler.patch("IAccount.get_nonce", lambda _, __: [tx["nonce"]]),
Expand Down Expand Up @@ -646,7 +646,7 @@ def test_raise_max_priority_fee_too_high(
"accessList": [],
"chainId": CHAIN_ID,
}
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate(tx).encode()

with cairo_error(message="Max priority fee greater than max fee per gas"):
cairo_run(
Expand All @@ -660,7 +660,7 @@ def test_raise_max_priority_fee_too_high(
@SyscallHandler.patch("IAccount.get_evm_address", lambda _, __: [0xABDE1])
@pytest.mark.parametrize("tx", TRANSACTIONS)
def test_raise_not_enough_ETH_balance(self, cairo_run, tx):
tx_data = list(rlp_encode_signed_data(tx))
tx_data = TransactionModel.model_validate(tx).encode()

with (
SyscallHandler.patch("IAccount.get_nonce", lambda _, __: [tx["nonce"]]),
Expand Down
18 changes: 8 additions & 10 deletions tests/src/utils/test_eth_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from tests.utils.constants import INVALID_TRANSACTIONS, TRANSACTIONS
from tests.utils.errors import cairo_error
from tests.utils.helpers import flatten_tx_access_list, rlp_encode_signed_data
from tests.utils.helpers import flatten_tx_access_list
from tests.utils.models import TransactionModel


class TestEthTransaction:
Expand Down Expand Up @@ -83,8 +84,8 @@ async def test_should_raise_with_params_overflow(
async def test_should_decode_all_transactions_types(
self, cairo_run, transaction
):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
decoded_tx = cairo_run("test__decode", data=list(encoded_unsigned_tx))
tx = TransactionModel.model_validate(transaction).encode()
decoded_tx = cairo_run("test__decode", data=list(tx))

expected_data = (
"0x" + transaction["data"].hex()
Expand Down Expand Up @@ -113,12 +114,9 @@ async def test_should_decode_all_transactions_types(
async def test_should_panic_on_unsupported_tx_types(
self, cairo_run, transaction
):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx = TransactionModel.model_validate(transaction).encode()
with cairo_error("Kakarot: transaction type not supported"):
cairo_run(
"test__decode",
data=list(encoded_unsigned_tx),
)
cairo_run("test__decode", data=list(tx))

class TestParseAccessList:
@pytest.mark.parametrize("transaction", TRANSACTIONS)
Expand All @@ -144,8 +142,8 @@ def test_should_parse_access_list(self, cairo_run, transaction):
class TestGetTxType:
@pytest.mark.parametrize("transaction", TRANSACTIONS)
def test_should_return_tx_type(self, cairo_run, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx_type = cairo_run("test__get_tx_type", data=list(encoded_unsigned_tx))
tx = TransactionModel.model_validate(transaction).encode()
tx_type = cairo_run("test__get_tx_type", data=list(tx))
assert tx_type == transaction.get("type", 0)

def test_should_raise_when_data_len_is_zero(self, cairo_run):
Expand Down
5 changes: 3 additions & 2 deletions tests/src/utils/test_rlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tests.utils.constants import TRANSACTIONS
from tests.utils.errors import cairo_error
from tests.utils.helpers import rlp_encode_signed_data
from tests.utils.models import TransactionModel


class TestRLP:
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_raise_when_data_contains_extra_bytes(

@pytest.mark.parametrize("transaction", TRANSACTIONS)
def test_should_decode_all_tx_types(self, cairo_run, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
encoded_unsigned_tx = TransactionModel.model_validate(transaction).encode()
if "type" in transaction:
# remove the type info from the encoded RLP
# create bytearray from bytes list and remove the first byte
Expand All @@ -81,3 +81,4 @@ def test_should_decode_all_tx_types(self, cairo_run, transaction):

items = cairo_run("test__decode", data=list(rlp_encoding))
assert items[0] == decode(rlp_encoding)
assert items[0] == decode(rlp_encoding)
31 changes: 0 additions & 31 deletions tests/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import rlp
from eth_abi import encode
from eth_account._utils.transaction_utils import transaction_rpc_to_rlp_structure
from eth_account.typed_transactions import TypedTransaction
from eth_keys import keys
from eth_utils import decode_hex, keccak, to_checksum_address
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
Expand All @@ -20,35 +18,6 @@
)


def rlp_encode_signed_data(tx: dict) -> bytes:
if "type" in tx:
typed_transaction = TypedTransaction.from_dict(tx)

sanitized_transaction = transaction_rpc_to_rlp_structure(
typed_transaction.transaction.dictionary
)

# RPC-structured transaction to rlp-structured transaction
rlp_serializer = (
typed_transaction.transaction.__class__._unsigned_transaction_serializer
)
return [
typed_transaction.transaction_type,
*rlp.encode(rlp_serializer.from_dict(sanitized_transaction)),
]
else:
legacy_tx = [
tx["nonce"],
tx["gasPrice"],
tx["gas"] if "gas" in tx else tx["gasLimit"],
bytes.fromhex(f"{int(tx['to'], 16):040x}"),
tx["value"],
tx["data"],
] + ([tx["chainId"], 0, 0] if "chainId" in tx else [])

return rlp.encode(legacy_tx)


def hex_string_to_bytes_array(h: str):
if len(h) % 2 != 0:
raise ValueError(f"Provided string has an odd length {len(h)}")
Expand Down

0 comments on commit 0339d7f

Please sign in to comment.