Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: trie_set #382

Merged
merged 7 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cairo/ethereum/cancun/fork_types.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from starkware.cairo.common.alloc import alloc

from ethereum_types.bytes import Bytes20, Bytes32, Bytes256, Bytes, BytesStruct, HashedBytes32
from ethereum.utils.bytes import Bytes__eq__
from ethereum_types.numeric import Uint, U256, U256Struct, bool
from ethereum.crypto.hash import Hash32

Expand Down Expand Up @@ -117,3 +118,26 @@ func EMPTY_ACCOUNT() -> Account {
tempvar account = Account(value=new AccountStruct(nonce=Uint(0), balance=balance, code=code));
return account;
}

func Account__eq__(a: Account, b: Account) -> bool {
Eikix marked this conversation as resolved.
Show resolved Hide resolved
if (a.value.nonce.value != b.value.nonce.value) {
tempvar res = bool(0);
return res;
}
if (a.value.balance.value.low != b.value.balance.value.low) {
tempvar res = bool(0);
return res;
}
if (a.value.balance.value.high != b.value.balance.value.high) {
tempvar res = bool(0);
return res;
}
if (a.value.code.value.len != b.value.code.value.len) {
tempvar res = bool(0);
return res;
}

let code_eq = Bytes__eq__(a.value.code, b.value.code);

return code_eq;
}
69 changes: 68 additions & 1 deletion cairo/ethereum/cancun/trie.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ from starkware.cairo.common.cairo_builtins import KeccakBuiltin
from starkware.cairo.common.memcpy import memcpy

from src.utils.bytes import uint256_to_bytes32_little
from src.utils.dict import hashdict_read
from src.utils.dict import hashdict_read, hashdict_write
from ethereum.crypto.hash import keccak256
from ethereum.utils.numeric import min, is_zero
from ethereum.rlp import encode, _encode_bytes, _encode
from ethereum.utils.numeric import U256__eq__
from ethereum_types.numeric import U256, Uint, bool, U256Struct
from ethereum_types.bytes import (
Bytes,
Expand All @@ -30,6 +31,7 @@ from ethereum_types.bytes import (
from ethereum.cancun.blocks import Receipt, Withdrawal
from ethereum.cancun.fork_types import (
Account,
Account__eq__,
AccountStruct,
Address,
Bytes32U256DictAccess,
Expand Down Expand Up @@ -390,6 +392,71 @@ func trie_get_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U
return res;
}

func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}(
key: Address, value: Account
) {
let dict_ptr_start = cast(trie.value._data.value.dict_ptr_start, DictAccess*);
let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*);

let is_default = Account__eq__(value, trie.value.default);

with dict_ptr_start, dict_ptr {
let (keys) = alloc();
assert [keys] = key.value;

if (is_default.value != 0) {
Eikix marked this conversation as resolved.
Show resolved Hide resolved
hashdict_write(1, keys, 0);
tempvar dict_ptr_start = dict_ptr_start;
tempvar dict_ptr = dict_ptr;
tempvar poseidon_ptr = poseidon_ptr;
} else {
hashdict_write(1, keys, cast(value.value, felt));
tempvar dict_ptr_start = dict_ptr_start;
tempvar dict_ptr = dict_ptr;
tempvar poseidon_ptr = poseidon_ptr;
}
}
let new_dict_ptr = cast(dict_ptr, AddressAccountDictAccess*);
tempvar mapping = MappingAddressAccount(
new MappingAddressAccountStruct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
);
tempvar trie = TrieAddressAccount(
new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping)
);
return ();
}

func trie_set_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U256}(
key: Bytes32, value: U256
) {
let dict_ptr_start = cast(trie.value._data.value.dict_ptr_start, DictAccess*);
let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*);

let is_default = U256__eq__(value, trie.value.default);

with dict_ptr_start, dict_ptr {
if (is_default.value != 0) {
hashdict_write(2, cast(key.value, felt*), 0);
enitrat marked this conversation as resolved.
Show resolved Hide resolved
tempvar dict_ptr_start = dict_ptr_start;
tempvar dict_ptr = dict_ptr;
tempvar poseidon_ptr = poseidon_ptr;
} else {
hashdict_write(2, cast(key.value, felt*), cast(value.value, felt));
tempvar dict_ptr_start = dict_ptr_start;
tempvar dict_ptr = dict_ptr;
tempvar poseidon_ptr = poseidon_ptr;
}
}
let new_dict_ptr = cast(dict_ptr, Bytes32U256DictAccess*);
tempvar mapping = MappingBytes32U256(
new MappingBytes32U256Struct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
);
tempvar trie = TrieBytes32U256(
new TrieBytes32U256Struct(trie.value.secured, trie.value.default, mapping)
);
return ();
}

func common_prefix_length(a: Bytes, b: Bytes) -> felt {
alloc_locals;
local result;
Expand Down
60 changes: 60 additions & 0 deletions cairo/ethereum/utils/bytes.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from ethereum_types.bytes import Bytes, BytesStruct
from ethereum_types.numeric import bool
from ethereum.utils.numeric import is_zero
from starkware.cairo.common.math import assert_not_equal

func Bytes__eq__(_self: Bytes, other: Bytes) -> bool {
if (_self.value.len != other.value.len) {
tempvar res = bool(0);
return res;
}

// Case diff: we can let the prover do the work of iterating over the bytes,
// return the first different byte index, and assert in cairo that the a[index] != b[index]
tempvar is_diff;
tempvar diff_index;
%{
self_bytes = b''.join([memory[ids._self.value.data + i].to_bytes(1, "little") for i in range(ids._self.value.len)])
other_bytes = b''.join([memory[ids.other.value.data + i].to_bytes(1, "little") for i in range(ids.other.value.len)])
diff_index = next((i for i, (b_self, b_other) in enumerate(zip(self_bytes, other_bytes)) if b_self != b_other), None)
if diff_index is not None:
ids.is_diff = 1
ids.diff_index = diff_index
else:
# No differences found in common prefix. Lengths were checked before
ids.is_diff = 0
ids.diff_index = 0
%}

if (is_diff == 1) {
// Assert that the bytes are different at the first different index
with_attr error_message("Bytes__eq__: bytes at provided index are equal") {
assert_not_equal(_self.value.data[diff_index], other.value.data[diff_index]);
}
tempvar res = bool(0);
return res;
}

// Case equal: we need to iterate over all keys in cairo, because the prover might not have been honest
// about the first different byte index.
tempvar i = 0;

loop:
let index = [ap - 1];
let self_value = cast([fp - 4], BytesStruct*);
let other_value = cast([fp - 3], BytesStruct*);

let is_end = is_zero(index - self_value.len);
tempvar res = bool(1);
jmp end if is_end != 0;

let is_eq = is_zero(self_value.data[index] - other_value.data[index]);

tempvar i = i + 1;
jmp loop if is_eq != 0;
tempvar res = bool(0);

end:
let res = bool([ap - 1]);
return res;
}
11 changes: 10 additions & 1 deletion cairo/ethereum/utils/numeric.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from starkware.cairo.common.math_cmp import is_le, is_not_zero
from starkware.cairo.common.uint256 import uint256_reverse_endian
from ethereum_types.numeric import Uint, U256, U256Struct
from ethereum_types.numeric import Uint, U256, U256Struct, bool
from ethereum_types.bytes import Bytes32, Bytes32Struct
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin

Expand Down Expand Up @@ -139,3 +139,12 @@ func U256_to_le_bytes(value: U256) -> Bytes32 {
tempvar res = Bytes32(value.value);
return res;
}

func U256__eq__(a: U256, b: U256) -> bool {
if (a.value.low == b.value.low and a.value.high == b.value.high) {
tempvar res = bool(1);
return res;
}
tempvar res = bool(0);
return res;
}
4 changes: 3 additions & 1 deletion cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
// The data layout defined in this file are coherent with the Cairo arg generation process defined in args_gen.py and Cairo serialization process in serde.py

from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.math import assert_not_equal
from starkware.cairo.common.uint256 import Uint256
from ethereum_types.numeric import U128
from src.utils.utils import Helpers
from ethereum_types.numeric import U128, bool

// Bytes types
struct Bytes0 {
Expand Down
36 changes: 34 additions & 2 deletions cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.squash_dict import squash_dict
from starkware.cairo.common.uint256 import Uint256

from ethereum_types.numeric import U256
from ethereum_types.numeric import U256, U256Struct
from ethereum_types.bytes import Bytes32
from ethereum.cancun.fork_types import Address
from ethereum.utils.numeric import U256__eq__
from ethereum.cancun.fork_types import Address, Account, AccountStruct, Account__eq__

from src.utils.maths import unsigned_div_rem

Expand Down Expand Up @@ -81,3 +82,34 @@ func hashdict_read{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
let dict_ptr = dict_ptr + DictAccess.SIZE;
return (value=value);
}

// A wrapper around dict_write that hashes the key before accessing the dictionary if the key
// does not fit in a felt.
// @param key_len: The number of felt values used to represent the key.
// @param key: The key to access the dictionary.
// @param new_value: The value to write to the dictionary.
func hashdict_write{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
key_len: felt, key: felt*, new_value: felt
) {
alloc_locals;
local felt_key;
if (key_len == 1) {
assert felt_key = key[0];
tempvar poseidon_ptr = poseidon_ptr;
} else {
let (felt_key_) = poseidon_hash_many(key_len, key);
assert felt_key = felt_key_;
tempvar poseidon_ptr = poseidon_ptr;
}
%{
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
ids.dict_ptr.prev_value = dict_tracker.data[preimage]
dict_tracker.data[preimage] = ids.new_value
%}
dict_ptr.key = felt_key;
dict_ptr.new_value = new_value;
let dict_ptr = dict_ptr + DictAccess.SIZE;
return ();
}
Empty file.
9 changes: 8 additions & 1 deletion cairo/tests/ethereum/cancun/test_fork_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import pytest
from hypothesis import given

from ethereum.cancun.fork_types import EMPTY_ACCOUNT
from ethereum.cancun.fork_types import EMPTY_ACCOUNT, Account

pytestmark = pytest.mark.python_vm


class TestForkTypes:
def test_account_default(self, cairo_run):
assert EMPTY_ACCOUNT == cairo_run("EMPTY_ACCOUNT")

@given(account_a=..., account_b=...)
def test_account_eq(self, cairo_run, account_a: Account, account_b: Account):
assert (account_a == account_b) == cairo_run(
"Account__eq__", account_a, account_b
)
17 changes: 17 additions & 0 deletions cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
nibble_list_to_compact,
patricialize,
trie_get,
trie_set,
)
from tests.utils.assertion import sequence_equal
from tests.utils.errors import cairo_error
Expand Down Expand Up @@ -169,3 +170,19 @@ def test_trie_get_TrieBytes32U256(
result_py = trie_get(trie, key)
assert result_cairo == result_py
assert trie_cairo == trie

@given(trie=..., key=..., value=...)
def test_trie_set_TrieAddressAccount(
self, cairo_run, trie: Trie[Address, Account], key: Address, value: Account
):
cairo_trie = cairo_run("trie_set_TrieAddressAccount", trie, key, value)
trie_set(trie, key, value)
assert cairo_trie == trie

@given(trie=..., key=..., value=...)
def test_trie_set_TrieBytes32U256(
self, cairo_run, trie: Trie[Bytes32, U256], key: Bytes32, value: U256
):
cairo_trie = cairo_run("trie_set_TrieBytes32U256", trie, key, value)
trie_set(trie, key, value)
assert cairo_trie == trie
11 changes: 11 additions & 0 deletions cairo/tests/ethereum/utils/test_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from ethereum_types.bytes import Bytes
from hypothesis import given

pytestmark = pytest.mark.python_vm


class TestBytes:
@given(a=..., b=...)
def test_Bytes__eq__(self, cairo_run, a: Bytes, b: Bytes):
assert (a == b) == cairo_run("Bytes__eq__", a, b)
4 changes: 4 additions & 0 deletions cairo/tests/ethereum/utils/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ def test_U256_to_le_bytes(self, cairo_run, value: U256):
expected = value.to_le_bytes32()
result = cairo_run("U256_to_le_bytes", value)
assert result == expected

@given(a=..., b=...)
def test_U256__eq__(self, cairo_run, a: U256, b: U256):
assert (a == b) == cairo_run("U256__eq__", a, b)
15 changes: 14 additions & 1 deletion cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ def serialize_pointers(self, path: Tuple[str, ...], ptr):
output[name] = member_ptr
return output

def is_pointer_wrapper(self, path: Tuple[str, ...]) -> bool:
"""Returns whether the type is a wrapper to a pointer."""
members = get_struct_definition(self.program, path).members
if len(members) != 1:
return False
return isinstance(list(members.values())[0].cairo_type, TypePointer)

def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
"""
Recursively serialize a Cairo instance, returning the corresponding Python instance.
Expand Down Expand Up @@ -305,6 +312,9 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
tracker_data = self.dict_manager.trackers[dict_ptr.segment_index].data
if isinstance(cairo_key_type, TypeFelt):
for key, value in tracker_data.items():
# We skip serialization of null pointers, but serialize values equal to zero
if value == 0 and self.is_pointer_wrapper(value_type.scope.path):
enitrat marked this conversation as resolved.
Show resolved Hide resolved
continue
# Reconstruct the original key from the preimage
if python_key_type in [
Bytes32,
Expand All @@ -325,6 +335,8 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
hashed_key = poseidon_hash_many(key)
preimage = bytes(list(key))
serialized_dict[preimage] = dict_data[hashed_key]
else:
raise ValueError(f"Unsupported key type: {python_key_type}")

elif get_origin(python_key_type) is tuple:
# If the key is a tuple, we're in the case of a Set[Tuple[Address, Bytes32]]]
Expand All @@ -348,8 +360,9 @@ def key_transform(k):

serialized_dict = {
key_transform(k): dict_data[key_transform(k)]
for k in tracker_data
for k, v in tracker_data.items()
if key_transform(k) in dict_data
and not (v == 0 and self.is_pointer_wrapper(value_type.scope.path))
}

if origin_cls is set:
Expand Down
Loading
Loading