streamable for merkle blob serialization (#19154)
- **streamable for merkle blob serialization**
- **wip**

### Purpose:

### Current Behavior:

### New Behavior:

### Testing Notes:

altendky authored Jan 28, 2025
2 parents 137bd80 + 0a68277 commit 4363032
Showing 2 changed files with 229 additions and 210 deletions.
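
The test changes below track the move from struct-based packing to streamable serialization: `NodeMetadata` is now serialized with `bytes(...)` instead of `.pack()`, node fields use sized types (`bytes32`, `uint32`, `int64`), the parent index becomes an `Optional` (replacing the old `null_parent` sentinel), and `data_size` grows from 52 to 53 bytes for the optional-parent presence flag. A minimal sketch of the new leaf layout, reconstructed from the `leaf_reference_blob` constant in the updated test (the constant and comments here are illustrative, not the chia_rs data-layer API):

```python
# Sketch of the 53-byte leaf node layout encoded by the updated tests.
# Reconstructed from leaf_reference_blob in the diff below; DATA_SIZE and the
# field comments are illustrative, not taken from the chia_rs module itself.
import itertools

DATA_SIZE = 53  # was 52 under the old struct-based packing

counter = itertools.count()
leaf_blob = bytes(next(counter) for _ in range(32))  # 32-byte node hash
leaf_blob += bytes([1])                              # optional-parent flag (1 = present)
leaf_blob += bytes(next(counter) for _ in range(4))  # parent TreeIndex, big-endian uint32
leaf_blob += bytes(next(counter) for _ in range(8))  # KeyId, big-endian int64
leaf_blob += bytes(next(counter) for _ in range(8))  # ValueId, big-endian int64
leaf_blob += bytes(DATA_SIZE - len(leaf_blob))       # zero padding up to data_size
assert len(leaf_blob) == DATA_SIZE
```

Internal nodes follow the same pattern in the test's `internal_reference_blob` (hash, optional parent, then two 4-byte child indexes) and are zero-padded so every node occupies the same fixed 53 bytes.
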
195 changes: 97 additions & 98 deletions chia/_tests/core/data_layer/test_merkle_blob.py
@@ -1,8 +1,8 @@
from __future__ import annotations

import hashlib
import struct
from dataclasses import astuple, dataclass
import itertools
from dataclasses import dataclass
from random import Random
from typing import Generic, Protocol, TypeVar, final

@@ -28,14 +28,14 @@
ValueId,
data_size,
metadata_size,
null_parent,
pack_raw_node,
raw_node_classes,
raw_node_type_to_class,
spacing,
unpack_raw_node,
)
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import int64, uint32

pytestmark = pytest.mark.data_layer

@@ -64,23 +64,6 @@ def raw_node_class_fixture(request: SubRequest) -> RawMerkleNodeProtocol:
return request.param # type: ignore[no-any-return]


class_to_structs: dict[type[object], struct.Struct] = {
NodeMetadata: NodeMetadata.struct,
**{cls: cls.struct for cls in raw_node_classes},
}


@pytest.fixture(
name="class_struct",
scope="session",
params=class_to_structs.values(),
ids=[cls.__name__ for cls in class_to_structs.keys()],
)
def class_struct_fixture(request: SubRequest) -> RawMerkleNodeProtocol:
# https://github.com/pytest-dev/pytest/issues/8763
return request.param # type: ignore[no-any-return]


def test_raw_node_class_types_are_unique() -> None:
assert len(raw_node_type_to_class) == len(raw_node_classes)

@@ -90,32 +73,46 @@ def test_metadata_size_not_changed() -> None:


def test_data_size_not_changed() -> None:
assert data_size == 52


def test_raw_node_struct_sizes(raw_node_class: RawMerkleNodeProtocol) -> None:
assert raw_node_class.struct.size == data_size


def test_all_big_endian(class_struct: struct.Struct) -> None:
assert class_struct.format.startswith(">")
assert data_size == 53


# TODO: check all struct types against attribute types

RawMerkleNodeT = TypeVar("RawMerkleNodeT", bound=RawMerkleNodeProtocol)


reference_blob = bytes(range(data_size))
counter = itertools.count()
# hash
internal_reference_blob = bytes([next(counter) for _ in range(32)])
# optional parent
internal_reference_blob += bytes([1])
internal_reference_blob += bytes([next(counter) for _ in range(4)])
# left
internal_reference_blob += bytes([next(counter) for _ in range(4)])
# right
internal_reference_blob += bytes([next(counter) for _ in range(4)])
internal_reference_blob += bytes(0 for _ in range(data_size - len(internal_reference_blob)))
assert len(internal_reference_blob) == data_size

counter = itertools.count()
# hash
leaf_reference_blob = bytes([next(counter) for _ in range(32)])
# optional parent
leaf_reference_blob += bytes([1])
leaf_reference_blob += bytes([next(counter) for _ in range(4)])
# key
leaf_reference_blob += bytes([next(counter) for _ in range(8)])
# value
leaf_reference_blob += bytes([next(counter) for _ in range(8)])
leaf_reference_blob += bytes(0 for _ in range(data_size - len(leaf_reference_blob)))
assert len(leaf_reference_blob) == data_size


@final
@dataclass
class RawNodeFromBlobCase(Generic[RawMerkleNodeT]):
raw: RawMerkleNodeT
blob_to_unpack: bytes = reference_blob
packed_blob_reference_leaf: bytes = reference_blob
packed_blob_reference_internal: bytes = reference_blob[:44] + bytes([0] * 8)
packed: bytes

marks: Marks = ()

@@ -127,101 +124,104 @@ def id(self) -> str:
reference_raw_nodes: list[DataCase] = [
RawNodeFromBlobCase(
raw=RawInternalMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0x20212223),
left=TreeIndex(0x24252627),
right=TreeIndex(0x28292A2B),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0x20212223)),
left=TreeIndex(uint32(0x24252627)),
right=TreeIndex(uint32(0x28292A2B)),
),
packed=internal_reference_blob,
),
RawNodeFromBlobCase(
raw=RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0x20212223),
key=KeyId(KeyOrValueId(0x2425262728292A2B)),
value=ValueId(KeyOrValueId(0x2C2D2E2F30313233)),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0x20212223)),
key=KeyId(KeyOrValueId(int64(0x2425262728292A2B))),
value=ValueId(KeyOrValueId(int64(0x2C2D2E2F30313233))),
),
packed=leaf_reference_blob,
),
]


@datacases(*reference_raw_nodes)
def test_raw_node_from_blob(case: RawNodeFromBlobCase[RawMerkleNodeProtocol]) -> None:
node = unpack_raw_node(
index=TreeIndex(0),
index=TreeIndex(uint32(0)),
metadata=NodeMetadata(type=case.raw.type, dirty=False),
data=case.blob_to_unpack,
data=case.packed,
)
assert node == case.raw


@datacases(*reference_raw_nodes)
def test_raw_node_to_blob(case: RawNodeFromBlobCase[RawMerkleNodeProtocol]) -> None:
blob = pack_raw_node(case.raw)
expected_blob = (
case.packed_blob_reference_leaf
if isinstance(case.raw, RawLeafMerkleNode)
else case.packed_blob_reference_internal
)

assert blob == expected_blob
assert blob == case.packed


def test_merkle_blob_one_leaf_loads() -> None:
# TODO: need to persist reference data
leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=null_parent,
key=KeyId(KeyOrValueId(0x0405060708090A0B)),
value=ValueId(KeyOrValueId(0x0405060708090A1B)),
hash=bytes32(range(32)),
parent=None,
key=KeyId(KeyOrValueId(int64(0x0405060708090A0B))),
value=ValueId(KeyOrValueId(int64(0x0405060708090A1B))),
)
blob = bytearray(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(leaf))
blob = bytearray(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(leaf))

merkle_blob = MerkleBlob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(0)) == leaf
assert merkle_blob.get_raw_node(TreeIndex(uint32(0))) == leaf


def test_merkle_blob_two_leafs_loads() -> None:
# TODO: break this test down into some reusable data and multiple tests
# TODO: need to persist reference data
root = RawInternalMerkleNode(
hash=bytes(range(32)),
parent=null_parent,
left=TreeIndex(1),
right=TreeIndex(2),
hash=bytes32(range(32)),
parent=None,
left=TreeIndex(uint32(1)),
right=TreeIndex(uint32(2)),
)
left_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
key=KeyId(KeyOrValueId(0x0405060708090A0B)),
value=ValueId(KeyOrValueId(0x0405060708090A1B)),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0)),
key=KeyId(KeyOrValueId(int64(0x0405060708090A0B))),
value=ValueId(KeyOrValueId(int64(0x0405060708090A1B))),
)
right_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
key=KeyId(KeyOrValueId(0x1415161718191A1B)),
value=ValueId(KeyOrValueId(0x1415161718191A2B)),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0)),
key=KeyId(KeyOrValueId(int64(0x1415161718191A1B))),
value=ValueId(KeyOrValueId(int64(0x1415161718191A2B))),
)
blob = bytearray()
blob.extend(NodeMetadata(type=NodeType.internal, dirty=True).pack() + pack_raw_node(root))
blob.extend(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(left_leaf))
blob.extend(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(right_leaf))
blob.extend(bytes(NodeMetadata(type=NodeType.internal, dirty=True)) + pack_raw_node(root))
blob.extend(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(left_leaf))
blob.extend(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(right_leaf))

merkle_blob = MerkleBlob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(0)) == root
assert merkle_blob.get_raw_node(root.left) == left_leaf
assert merkle_blob.get_raw_node(root.right) == right_leaf
assert merkle_blob.get_raw_node(left_leaf.parent) == root
assert merkle_blob.get_raw_node(right_leaf.parent) == root

assert merkle_blob.get_lineage_with_indexes(TreeIndex(0)) == [(0, root)]
assert merkle_blob.get_lineage_with_indexes(root.left) == [(1, left_leaf), (0, root)]
assert merkle_blob.get_raw_node(TreeIndex(uint32(0))) == root
assert merkle_blob.get_raw_node(TreeIndex(root.left)) == left_leaf
assert merkle_blob.get_raw_node(TreeIndex(root.right)) == right_leaf
assert left_leaf.parent is not None
assert merkle_blob.get_raw_node(TreeIndex(left_leaf.parent)) == root
assert right_leaf.parent is not None
assert merkle_blob.get_raw_node(TreeIndex(right_leaf.parent)) == root

assert merkle_blob.get_lineage_with_indexes(TreeIndex(uint32(0))) == [(0, root)]
expected: list[tuple[TreeIndex, RawMerkleNodeProtocol]] = [
(TreeIndex(uint32(1)), left_leaf),
(TreeIndex(uint32(0)), root),
]
assert merkle_blob.get_lineage_with_indexes(TreeIndex(root.left)) == expected

merkle_blob.calculate_lazy_hashes()
son_hash = bytes32(range(32))
root_hash = internal_hash(son_hash, son_hash)
expected_node = InternalNode(root_hash, son_hash, son_hash)
assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x0405060708090A0B))) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x1415161718191A1B))) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(int64(0x0405060708090A0B)))) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(int64(0x1415161718191A1B)))) == [expected_node]


def generate_kvid(seed: int) -> tuple[KeyId, ValueId]:
@@ -230,7 +230,7 @@ def generate_kvid(seed: int) -> tuple[KeyId, ValueId]:
for offset in range(2):
seed_bytes = (2 * seed + offset).to_bytes(8, byteorder="big", signed=True)
hash_obj = hashlib.sha256(seed_bytes)
hash_int = int.from_bytes(hash_obj.digest()[:8], byteorder="big", signed=True)
hash_int = int64.from_bytes(hash_obj.digest()[:8])
kv_ids.append(KeyOrValueId(hash_int))

return KeyId(kv_ids[0]), ValueId(kv_ids[1])
@@ -381,26 +381,25 @@ def test_proof_of_inclusion_merkle_blob() -> None:
assert proof_of_inclusion.valid()


@pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(-1), TreeIndex(1), TreeIndex(null_parent)])
@pytest.mark.parametrize(argnames="index", argvalues=[-1, 1, None])
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob = MerkleBlob(blob=bytearray())
merkle_blob.insert(
KeyId(KeyOrValueId(0x1415161718191A1B)), ValueId(KeyOrValueId(0x1415161718191A1B)), bytes(range(12, data_size))
KeyId(KeyOrValueId(int64(0x1415161718191A1B))),
ValueId(KeyOrValueId(int64(0x1415161718191A1B))),
bytes32(range(12, 12 + 32)),
)

with pytest.raises(InvalidIndexError):
if index is None:
expected = (InvalidIndexError, TypeError)
else:
expected = (InvalidIndexError, chia_rs.datalayer.BlockIndexOutOfBoundsError)

with pytest.raises(expected):
merkle_blob.get_raw_node(index)

with pytest.raises(InvalidIndexError):
merkle_blob.get_metadata(index)


@pytest.mark.parametrize(argnames="cls", argvalues=raw_node_classes)
def test_as_tuple_matches_dataclasses_astuple(cls: type[RawMerkleNodeProtocol], seeded_random: Random) -> None:
raw_bytes = bytes(seeded_random.getrandbits(8) for _ in range(cls.struct.size))
raw_node = cls(*cls.struct.unpack(raw_bytes))
# TODO: try again to indicate that the RawMerkleNodeProtocol requires the dataclass interface
assert raw_node.as_tuple() == astuple(raw_node) # type: ignore[call-overload]
merkle_blob._get_metadata(index)


def test_helper_methods(merkle_blob_type: MerkleBlobCallable) -> None:
@@ -413,7 +412,7 @@ def test_helper_methods(merkle_blob_type: MerkleBlobCallable) -> None:
merkle_blob.insert(key, value, hash)
assert not merkle_blob.empty()
assert merkle_blob.get_root_hash() is not None
assert merkle_blob.get_root_hash() == merkle_blob.get_hash_at_index(TreeIndex(0))
assert merkle_blob.get_root_hash() == merkle_blob.get_hash_at_index(TreeIndex(uint32(0)))

merkle_blob.delete(key)
assert merkle_blob.empty()
@@ -470,8 +469,8 @@ def test_get_nodes(merkle_blob_type: MerkleBlobCallable) -> None:
all_nodes = merkle_blob.get_nodes_with_indexes()
for index, node in all_nodes:
if isinstance(node, (RawInternalMerkleNode, chia_rs.datalayer.InternalNode)):
left = merkle_blob.get_raw_node(node.left)
right = merkle_blob.get_raw_node(node.right)
left = merkle_blob.get_raw_node(TreeIndex(node.left))
right = merkle_blob.get_raw_node(TreeIndex(node.right))
assert left.parent == index
assert right.parent == index
assert bytes32(node.hash) == internal_hash(bytes32(left.hash), bytes32(right.hash))
@@ -487,7 +486,7 @@ def test_get_nodes(merkle_blob_type: MerkleBlobCallable) -> None:


def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None:
HASH = bytes(range(12, 44))
HASH = bytes32(range(12, 12 + 32))

import pathlib

@@ -501,6 +500,6 @@ def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None:
total_time = 0.0
for i in range(100000):
start = time.monotonic()
merkle_blob.insert(KeyId(KeyOrValueId(i)), ValueId(KeyOrValueId(i)), HASH)
merkle_blob.insert(KeyId(KeyOrValueId(int64(i))), ValueId(KeyOrValueId(int64(i))), HASH)
end = time.monotonic()
total_time += end - start
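
Since the reference blobs above pin down the byte order and field offsets, the reverse direction can be sketched with only the standard library. This is a hedged illustration: `parse_leaf_blob` and its return shape are invented for the sketch and are not part of the chia_rs data-layer API, which only exposes the `unpack_raw_node` / `MerkleBlob` interfaces exercised in the tests above.

```python
# Decode a 53-byte leaf block laid out like leaf_reference_blob above.
# Offsets and big-endian order come from the test constants; handling of an
# absent parent (flag byte 0) is an assumption, since the reference blobs only
# cover the present case.
from typing import Optional

DATA_SIZE = 53

def parse_leaf_blob(data: bytes) -> tuple[bytes, Optional[int], int, int]:
    assert len(data) == DATA_SIZE
    node_hash = data[0:32]                                    # 32-byte hash
    parent: Optional[int] = None
    if data[32] == 1:                                         # presence flag
        parent = int.from_bytes(data[33:37], "big")           # TreeIndex (uint32)
    key = int.from_bytes(data[37:45], "big", signed=True)     # KeyId (int64)
    value = int.from_bytes(data[45:53], "big", signed=True)   # ValueId (int64)
    return node_hash, parent, key, value

# Check against the same byte pattern the test builds for its leaf reference:
blob = bytes(range(32)) + bytes([1]) + bytes(range(32, 52))
node_hash, parent, key, value = parse_leaf_blob(blob)
assert parent == 0x20212223
assert key == 0x2425262728292A2B and value == 0x2C2D2E2F30313233
```
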