Skip to content

Commit

Permalink
feat: trie set
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Jan 9, 2025
1 parent 9d29905 commit ae3ee40
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 4 deletions.
24 changes: 24 additions & 0 deletions cairo/ethereum/cancun/fork_types.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from starkware.cairo.common.alloc import alloc

from ethereum_types.bytes import Bytes20, Bytes32, Bytes256, Bytes, BytesStruct, HashedBytes32
from ethereum.utils.bytes import Bytes__eq__
from ethereum_types.numeric import Uint, U256, U256Struct, bool
from ethereum.crypto.hash import Hash32

Expand Down Expand Up @@ -117,3 +118,26 @@ func EMPTY_ACCOUNT() -> Account {
tempvar account = Account(value=new AccountStruct(nonce=Uint(0), balance=balance, code=code));
return account;
}

func Account__eq__(a: Account, b: Account) -> bool {
if (a.value.nonce.value != b.value.nonce.value) {
tempvar res = bool(0);
return res;
}
if (a.value.balance.value.low != b.value.balance.value.low) {
tempvar res = bool(0);
return res;
}
if (a.value.balance.value.high != b.value.balance.value.high) {
tempvar res = bool(0);
return res;
}
if (a.value.code.value.len != b.value.code.value.len) {
tempvar res = bool(0);
return res;
}

let code_eq = Bytes__eq__(a.value.code, b.value.code);

return code_eq;
}
61 changes: 60 additions & 1 deletion cairo/ethereum/cancun/trie.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ from starkware.cairo.common.cairo_builtins import KeccakBuiltin
from starkware.cairo.common.memcpy import memcpy

from src.utils.bytes import uint256_to_bytes32_little
from src.utils.dict import hashdict_read
from src.utils.dict import hashdict_read, hashdict_delete_if_present_bytes, hashdict_write
from ethereum.crypto.hash import keccak256
from ethereum.utils.numeric import min, is_zero
from ethereum.rlp import encode, _encode_bytes, _encode
from ethereum.utils.numeric import U256__eq__
from ethereum_types.numeric import U256, Uint, bool, U256Struct
from ethereum_types.bytes import (
Bytes,
Expand All @@ -30,6 +31,7 @@ from ethereum_types.bytes import (
from ethereum.cancun.blocks import Receipt, Withdrawal
from ethereum.cancun.fork_types import (
Account,
Account__eq__,
AccountStruct,
Address,
Bytes32U256DictAccess,
Expand Down Expand Up @@ -390,6 +392,63 @@ func trie_get_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U
return res;
}

func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}(
key: Address, value: Account
) {
let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*);

let default = cast(trie.value.default.value, felt);
let is_default = Account__eq__(value, trie.value.default);

with dict_ptr {
let (keys) = alloc();
assert [keys] = key.value;
if (is_default.value != 0) {
hashdict_delete_if_present_bytes(1, keys);
tempvar dict_ptr = dict_ptr;
} else {
hashdict_write(1, keys, cast(value.value, felt));
tempvar dict_ptr = dict_ptr;
}
}
let new_dict_ptr = cast(dict_ptr, AddressAccountDictAccess*);
tempvar mapping = MappingAddressAccount(
new MappingAddressAccountStruct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
);
tempvar trie = TrieAddressAccount(
new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping)
);
return ();
}

func trie_set_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U256}(
key: Bytes32, value: U256
) {
let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*);

let default = cast(trie.value.default.value, felt);
let is_default = U256__eq__(value, trie.value.default);

with dict_ptr {
if (is_default.value != 0) {
hashdict_delete_if_present_bytes(2, cast(key.value, felt*));
tempvar dict_ptr = dict_ptr;
} else {
hashdict_write(2, cast(key.value, felt*), cast(value.value, felt));
tempvar dict_ptr = dict_ptr;
}
}
let new_dict_ptr = cast(dict_ptr, Bytes32U256DictAccess*);
tempvar mapping = MappingBytes32U256(
new MappingBytes32U256Struct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
);
tempvar trie = TrieBytes32U256(
new TrieBytes32U256Struct(trie.value.secured, trie.value.default, mapping)
);
return ();
}

func common_prefix_length(a: Bytes, b: Bytes) -> felt {
alloc_locals;
local result;
Expand Down
58 changes: 58 additions & 0 deletions cairo/ethereum/utils/bytes.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from ethereum_types.bytes import Bytes, BytesStruct
from ethereum_types.numeric import bool
from ethereum.utils.numeric import is_zero
from starkware.cairo.common.math import assert_not_equal

func Bytes__eq__(_self: Bytes, other: Bytes) -> bool {
if (_self.value.len != other.value.len) {
tempvar res = bool(0);
return res;
}

// Case diff: we can let the prover do the work of iterating over the bytes,
// return the first different byte index, and assert in cairo that the a[index] != b[index]
tempvar is_diff;
tempvar diff_index;
%{
self_bytes = b''.join([memory[ids._self.value.data + i].to_bytes(1, "little") for i in range(ids._self.value.len)])
other_bytes = b''.join([memory[ids.other.value.data + i].to_bytes(1, "little") for i in range(ids.other.value.len)])
diff_index = next((i for i, (b_self, b_other) in enumerate(zip(self_bytes, other_bytes)) if b_self != b_other), None)
if diff_index is not None:
ids.is_diff = 1
ids.diff_index = diff_index
else:
# No differences found in common prefix. Lengths were checked before
ids.is_diff = 0
ids.diff_index = 0
%}

if (is_diff == 1) {
// Assert that the bytes are different at the first different index
assert_not_equal(_self.value.data[diff_index], other.value.data[diff_index]);
tempvar res = bool(1);
return res;
}

// Case equal: we need to iterate over all keys in cairo, because the prover might not have been honest
// about the first different byte index.
tempvar i = 0;

loop:
let index = [ap - 1];
let self_value = cast([fp - 4], BytesStruct*);
let other_value = cast([fp - 3], BytesStruct*);

let is_end = is_zero(index - self_value.len);
tempvar res = bool(1);
jmp end if is_end != 0;

let is_eq = is_zero(self_value.data[index] - other_value.data[index]);

tempvar i = i + 1;
jmp loop if is_eq != 0;
tempvar res = bool(0);

end:
let res = bool([ap - 1]);
return res;
}
11 changes: 10 additions & 1 deletion cairo/ethereum/utils/numeric.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from starkware.cairo.common.math_cmp import is_le, is_not_zero
from starkware.cairo.common.uint256 import uint256_reverse_endian
from ethereum_types.numeric import Uint, U256, U256Struct
from ethereum_types.numeric import Uint, U256, U256Struct, bool
from ethereum_types.bytes import Bytes32, Bytes32Struct
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin

Expand Down Expand Up @@ -139,3 +139,12 @@ func U256_to_le_bytes(value: U256) -> Bytes32 {
tempvar res = Bytes32(value.value);
return res;
}

func U256__eq__(a: U256, b: U256) -> bool {
if (a.value.low == b.value.low and a.value.high == b.value.high) {
tempvar res = bool(1);
return res;
}
tempvar res = bool(0);
return res;
}
4 changes: 3 additions & 1 deletion cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
// The data layout defined in this file are coherent with the Cairo arg generation process defined in args_gen.py and Cairo serialization process in serde.py

from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.math import assert_not_equal
from starkware.cairo.common.uint256 import Uint256
from ethereum_types.numeric import U128
from src.utils.utils import Helpers
from ethereum_types.numeric import U128, bool

// Bytes types
struct Bytes0 {
Expand Down
84 changes: 84 additions & 0 deletions cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,87 @@ func hashdict_read{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
let dict_ptr = dict_ptr + DictAccess.SIZE;
return (value=value);
}

// A wrapper around dict_write that hashes the key before accessing the dictionary if the key
// does not fit in a felt.
// @param key_len: The number of felt values used to represent the key.
// @param key: The key to access the dictionary.
// @param new_value: The value to write to the dictionary.
func hashdict_write{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
key_len: felt, key: felt*, new_value: felt
) {
alloc_locals;
local felt_key;
if (key_len == 1) {
assert felt_key = key[0];
tempvar poseidon_ptr = poseidon_ptr;
} else {
let (felt_key_) = poseidon_hash_many(key_len, key);
assert felt_key = felt_key_;
tempvar poseidon_ptr = poseidon_ptr;
}
%{
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
ids.dict_ptr.prev_value = dict_tracker.data[preimage]
dict_tracker.data[preimage] = ids.new_value
%}
dict_ptr.key = felt_key;
dict_ptr.new_value = new_value;
let dict_ptr = dict_ptr + DictAccess.SIZE;
return ();
}

// A special dict function that writes a special value to the dictionary to represent a deleted
// value if that key is present, or writes the new value if the key is not present.
// @param byte_length: The number of bytes in the key.
// @param key_len: The number of felt values used to represent the key.
// @param key: The key to access the dictionary.
// @param new_value: The value to write to the dictionary.
func hashdict_delete_if_present_bytes{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
key_len: felt, key: felt*
) {
alloc_locals;
local felt_key;
if (key_len == 1) {
assert felt_key = key[0];
tempvar poseidon_ptr = poseidon_ptr;
} else {
let (felt_key_) = poseidon_hash_many(key_len, key);
assert felt_key = felt_key_;
tempvar poseidon_ptr = poseidon_ptr;
}

tempvar is_deleted: felt;
%{
from tests.utils.hints import DELETED_KEY_FLAG
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
if preimage in dict_tracker.data:
# Deleting the key means writing a special value to the dictionary to represent a deleted value.
ids.dict_ptr.prev_value = dict_tracker.data[preimage]
dict_tracker.data[preimage] = DELETED_KEY_FLAG
ids.is_deleted = 1
else:
# Nothing to do.
ids.dict_ptr.prev_value = 0
ids.is_deleted = 0
%}
dict_ptr.key = felt_key;

if (is_deleted != 0) {
// TODO: is it fine to return an empty pointer, which is the same as representing an absent value,
// or should we return a special value?
tempvar ptr_zero = cast(0, DictAccess*);
dict_ptr.new_value = ptr_zero;
let dict_ptr = dict_ptr + DictAccess.SIZE;
return ();
}

tempvar prev_value = dict_ptr.prev_value;
dict_ptr.new_value = prev_value;
let dict_ptr = dict_ptr + DictAccess.SIZE;
return ();
}
9 changes: 8 additions & 1 deletion cairo/tests/ethereum/cancun/test_fork_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import pytest
from hypothesis import given

from ethereum.cancun.fork_types import EMPTY_ACCOUNT
from ethereum.cancun.fork_types import EMPTY_ACCOUNT, Account

pytestmark = pytest.mark.python_vm


class TestForkTypes:
def test_account_default(self, cairo_run):
assert EMPTY_ACCOUNT == cairo_run("EMPTY_ACCOUNT")

@given(account_a=..., account_b=...)
def test_account_eq(self, cairo_run, account_a: Account, account_b: Account):
assert (account_a == account_b) == cairo_run(
"Account__eq__", account_a, account_b
)
17 changes: 17 additions & 0 deletions cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
nibble_list_to_compact,
patricialize,
trie_get,
trie_set,
)
from tests.utils.assertion import sequence_equal
from tests.utils.errors import cairo_error
Expand Down Expand Up @@ -169,3 +170,19 @@ def test_trie_get_TrieBytes32U256(
result_py = trie_get(trie, key)
assert result_cairo == result_py
assert trie_cairo == trie

@given(trie=..., key=..., value=...)
def test_trie_set_TrieAddressAccount(
self, cairo_run, trie: Trie[Address, Account], key: Address, value: Account
):
cairo_trie = cairo_run("trie_set_TrieAddressAccount", trie, key, value)
trie_set(trie, key, value)
assert cairo_trie == trie

@given(trie=..., key=..., value=...)
def test_trie_set_TrieBytes32U256(
self, cairo_run, trie: Trie[Bytes32, U256], key: Bytes32, value: U256
):
cairo_trie = cairo_run("trie_set_TrieBytes32U256", trie, key, value)
trie_set(trie, key, value)
assert cairo_trie == trie
11 changes: 11 additions & 0 deletions cairo/tests/ethereum/utils/test_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from ethereum_types.bytes import Bytes
from hypothesis import given

pytestmark = pytest.mark.python_vm


class TestBytes:
@given(a=..., b=...)
def test_Bytes__eq__(self, cairo_run, a: Bytes, b: Bytes):
assert (a == b) == cairo_run("Bytes__eq__", a, b)
4 changes: 4 additions & 0 deletions cairo/tests/ethereum/utils/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ def test_U256_to_le_bytes(self, cairo_run, value: U256):
expected = value.to_le_bytes32()
result = cairo_run("U256_to_le_bytes", value)
assert result == expected

@given(a=..., b=...)
def test_U256__eq__(self, cairo_run, a: U256, b: U256):
assert (a == b) == cairo_run("U256__eq__", a, b)
1 change: 1 addition & 0 deletions cairo/tests/utils/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tests.utils.helpers import flatten

MaybeRelocatable = Union[int, Relocatable]
DELETED_KEY_FLAG = object()


def debug_info(program):
Expand Down
Loading

0 comments on commit ae3ee40

Please sign in to comment.