diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 2c5986ad73db..7630857009b9 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -44,15 +44,10 @@ from chia.data_layer.util.merkle_blob import ( KVId, MerkleBlob, - NodeMetadata, RawInternalMerkleNode, RawLeafMerkleNode, TreeIndex, - null_parent, - pack_raw_node, - undefined_index, ) -from chia.data_layer.util.merkle_blob import NodeType as NodeTypeMerkleBlob from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2 from chia.util.lru_cache import LRUCache @@ -230,10 +225,44 @@ async def insert_into_data_store_from_file( node_hash = leaf_hash(serialized_node.value1, serialized_node.value2) terminal_nodes[node_hash] = (kid, vid) - merkle_blob = MerkleBlob(blob=bytearray()) - if root_hash is not None: - await self.build_blob_from_nodes(internal_nodes, terminal_nodes, root_hash, merkle_blob, store_id) + missing_hashes: list[bytes32] = [] + merkle_blob_queries: dict[bytes32, list[int]] = defaultdict(list) + + for _, (left, right) in internal_nodes.items(): + for node_hash in (left, right): + if node_hash not in internal_nodes and node_hash not in terminal_nodes: + missing_hashes.append(node_hash) + + async with self.db_wrapper.reader() as reader: + for node_hash in missing_hashes: + cursor = await reader.execute( + "SELECT root_hash, idx FROM nodes WHERE hash = ? AND store_id = ?", + ( + node_hash, + store_id, + ), + ) + row = await cursor.fetchone() + if row is None: + raise Exception(f"Unknown hash {node_hash.hex()}") + + root_hash_blob = row["root_hash"] + index = row["idx"] + merkle_blob_queries[bytes32(root_hash_blob)].append(index) + + for root_hash_blob, indexes in merkle_blob_queries.items(): + merkle_blob = await self.get_merkle_blob(root_hash_blob, read_only=True) + for index in indexes: + nodes = merkle_blob.get_nodes_with_indexes(index=index) + index_to_hash = {index: bytes32(node.hash) for index, node in nodes} + for _, node in nodes: + if isinstance(node, RawLeafMerkleNode): + terminal_nodes[bytes32(node.hash)] = (node.key, node.value) + elif isinstance(node, RawInternalMerkleNode): + internal_nodes[bytes32(node.hash)] = (index_to_hash[node.left], index_to_hash[node.right]) + + merkle_blob = MerkleBlob.from_node_list(internal_nodes, terminal_nodes, root_hash) await self.insert_root_from_merkle_blob(merkle_blob, store_id, Status.COMMITTED) await self.add_node_hashes(store_id) @@ -525,74 +554,6 @@ async def add_node_hashes(self, store_id: bytes32) -> None: if hash not in existing_hashes: await self.add_node_hash(store_id, hash, root.node_hash, root.generation, index) - async def build_blob_from_nodes( - self, - internal_nodes: dict[bytes32, tuple[bytes32, bytes32]], - terminal_nodes: dict[bytes32, tuple[KVId, KVId]], - node_hash: bytes32, - merkle_blob: MerkleBlob, - store_id: bytes32, - ) -> TreeIndex: - if node_hash not in terminal_nodes and node_hash not in internal_nodes: - async with self.db_wrapper.reader() as reader: - cursor = await reader.execute( - "SELECT root_hash, idx FROM nodes WHERE hash = ? AND store_id = ?", - ( - node_hash, - store_id, - ), - ) - - row = await cursor.fetchone() - if row is None: - raise Exception(f"Unknown hash {node_hash}") - - root_hash = row["root_hash"] - index = row["idx"] - - other_merkle_blob = await self.get_merkle_blob(root_hash, read_only=True) - nodes = other_merkle_blob.get_nodes_with_indexes(index=index) - index_to_hash = {index: bytes32(node.hash) for index, node in nodes} - for _, node in nodes: - if isinstance(node, RawLeafMerkleNode): - terminal_nodes[bytes32(node.hash)] = (node.key, node.value) - elif isinstance(node, RawInternalMerkleNode): - internal_nodes[bytes32(node.hash)] = (index_to_hash[node.left], index_to_hash[node.right]) - - index = merkle_blob.get_new_index() - if node_hash in terminal_nodes: - kid, vid = terminal_nodes[node_hash] - merkle_blob.insert_entry_to_blob( - index, - NodeMetadata(type=NodeTypeMerkleBlob.leaf, dirty=False).pack() - + pack_raw_node(RawLeafMerkleNode(node_hash, null_parent, kid, vid)), - ) - elif node_hash in internal_nodes: - merkle_blob.insert_entry_to_blob( - index, - NodeMetadata(type=NodeTypeMerkleBlob.internal, dirty=False).pack() - + pack_raw_node( - RawInternalMerkleNode( - node_hash, - null_parent, - undefined_index, - undefined_index, - ) - ), - ) - left_hash, right_hash = internal_nodes[node_hash] - left_index = await self.build_blob_from_nodes( - internal_nodes, terminal_nodes, left_hash, merkle_blob, store_id - ) - right_index = await self.build_blob_from_nodes( - internal_nodes, terminal_nodes, right_hash, merkle_blob, store_id - ) - for child_index in (left_index, right_index): - merkle_blob.update_entry(index=child_index, parent=index) - merkle_blob.update_entry(index=index, left=left_index, right=right_index) - - return TreeIndex(index) - async def _insert_root( self, store_id: bytes32, diff --git a/chia/data_layer/util/merkle_blob.py b/chia/data_layer/util/merkle_blob.py index d488259cde56..29a5950edef2 100644 --- a/chia/data_layer/util/merkle_blob.py +++ b/chia/data_layer/util/merkle_blob.py @@ -49,6 +49,62 @@ def __post_init__(self) -> None: self.last_allocated_index = TreeIndex(len(self.blob) // spacing) self.free_indexes = self.get_free_indexes() + @classmethod + def from_node_list( + cls: type[MerkleBlob], + internal_nodes: dict[bytes32, tuple[bytes32, bytes32]], + terminal_nodes: dict[bytes32, tuple[KVId, KVId]], + root_hash: Optional[bytes32], + ) -> MerkleBlob: + merkle_blob = cls(blob=bytearray()) + + if root_hash is None: + if internal_nodes or terminal_nodes: + raise Exception("Nodes must be empty when root_hash is None") + else: + merkle_blob.build_blob_from_node_list(internal_nodes, terminal_nodes, root_hash) + + return merkle_blob + + def build_blob_from_node_list( + self, + internal_nodes: dict[bytes32, tuple[bytes32, bytes32]], + terminal_nodes: dict[bytes32, tuple[KVId, KVId]], + node_hash: bytes32, + ) -> TreeIndex: + if node_hash not in terminal_nodes and node_hash not in internal_nodes: + raise Exception(f"Unknown hash {node_hash.hex()}") + + index = self.get_new_index() + if node_hash in terminal_nodes: + kid, vid = terminal_nodes[node_hash] + self.insert_entry_to_blob( + index, + NodeMetadata(type=NodeType.leaf, dirty=False).pack() + + pack_raw_node(RawLeafMerkleNode(node_hash, null_parent, kid, vid)), + ) + elif node_hash in internal_nodes: + self.insert_entry_to_blob( + index, + NodeMetadata(type=NodeType.internal, dirty=False).pack() + + pack_raw_node( + RawInternalMerkleNode( + node_hash, + null_parent, + undefined_index, + undefined_index, + ) + ), + ) + left_hash, right_hash = internal_nodes[node_hash] + left_index = self.build_blob_from_node_list(internal_nodes, terminal_nodes, left_hash) + right_index = self.build_blob_from_node_list(internal_nodes, terminal_nodes, right_hash) + for child_index in (left_index, right_index): + self.update_entry(index=child_index, parent=index) + self.update_entry(index=index, left=left_index, right=right_index) + + return TreeIndex(index) + def get_new_index(self) -> TreeIndex: if len(self.free_indexes) == 0: self.last_allocated_index = TreeIndex(self.last_allocated_index + 1)