diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 790d24e948d7..93cac7b435ed 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -748,7 +748,7 @@ where config.prefix_sets, thread_pool, ) - .with_branch_node_hash_masks(true) + .with_branch_node_masks(true) .multiproof(proof_targets)?) } diff --git a/crates/engine/tree/src/tree/trie_updates.rs b/crates/engine/tree/src/tree/trie_updates.rs index ea78aca13b87..576f0c742647 100644 --- a/crates/engine/tree/src/tree/trie_updates.rs +++ b/crates/engine/tree/src/tree/trie_updates.rs @@ -198,7 +198,7 @@ fn branch_nodes_equal( ) -> bool { if let (Some(task), Some(regular)) = (task.as_ref(), regular.as_ref()) { task.state_mask == regular.state_mask && - // We do not compare the tree mask because it is known to be mismatching + task.tree_mask == regular.tree_mask && task.hash_mask == regular.hash_mask && task.hashes == regular.hashes && task.root_hash == regular.root_hash diff --git a/crates/trie/common/src/proofs.rs b/crates/trie/common/src/proofs.rs index 54171710761e..2e64ef39728c 100644 --- a/crates/trie/common/src/proofs.rs +++ b/crates/trie/common/src/proofs.rs @@ -29,6 +29,8 @@ pub struct MultiProof { pub account_subtree: ProofNodes, /// The hash masks of the branch nodes in the account proof. pub branch_node_hash_masks: HashMap, + /// The tree masks of the branch nodes in the account proof. + pub branch_node_tree_masks: HashMap, /// Storage trie multiproofs. pub storages: B256HashMap, } @@ -115,6 +117,7 @@ impl MultiProof { self.account_subtree.extend_from(other.account_subtree); self.branch_node_hash_masks.extend(other.branch_node_hash_masks); + self.branch_node_tree_masks.extend(other.branch_node_tree_masks); for (hashed_address, storage) in other.storages { match self.storages.entry(hashed_address) { @@ -123,6 +126,7 @@ impl MultiProof { let entry = entry.get_mut(); entry.subtree.extend_from(storage.subtree); entry.branch_node_hash_masks.extend(storage.branch_node_hash_masks); + entry.branch_node_tree_masks.extend(storage.branch_node_tree_masks); } hash_map::Entry::Vacant(entry) => { entry.insert(storage); @@ -141,6 +145,8 @@ pub struct StorageMultiProof { pub subtree: ProofNodes, /// The hash masks of the branch nodes in the storage proof. pub branch_node_hash_masks: HashMap, + /// The tree masks of the branch nodes in the storage proof. + pub branch_node_tree_masks: HashMap, } impl StorageMultiProof { @@ -153,6 +159,7 @@ impl StorageMultiProof { Bytes::from([EMPTY_STRING_CODE]), )]), branch_node_hash_masks: HashMap::default(), + branch_node_tree_masks: HashMap::default(), } } @@ -398,6 +405,7 @@ mod tests { root, subtree: subtree1, branch_node_hash_masks: HashMap::default(), + branch_node_tree_masks: HashMap::default(), }, ); @@ -412,6 +420,7 @@ mod tests { root, subtree: subtree2, branch_node_hash_masks: HashMap::default(), + branch_node_tree_masks: HashMap::default(), }, ); diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 31df5f232879..f7716ee13161 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -44,8 +44,8 @@ pub struct ParallelProof { /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here, /// if we have cached nodes for them. pub prefix_sets: Arc, - /// Flag indicating whether to include branch node hash masks in the proof. - collect_branch_node_hash_masks: bool, + /// Flag indicating whether to include branch node masks in the proof. + collect_branch_node_masks: bool, /// Thread pool for local tasks thread_pool: Arc, /// Parallel state root metrics. @@ -67,16 +67,16 @@ impl ParallelProof { nodes_sorted, state_sorted, prefix_sets, - collect_branch_node_hash_masks: false, + collect_branch_node_masks: false, thread_pool, #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics::default(), } } - /// Set the flag indicating whether to include branch node hash masks in the proof. - pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self { - self.collect_branch_node_hash_masks = branch_node_hash_masks; + /// Set the flag indicating whether to include branch node masks in the proof. + pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self { + self.collect_branch_node_masks = branch_node_masks; self } } @@ -137,7 +137,7 @@ where let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); let trie_nodes_sorted = self.nodes_sorted.clone(); let hashed_state_sorted = self.state_sorted.clone(); - let collect_masks = self.collect_branch_node_hash_masks; + let collect_masks = self.collect_branch_node_masks; let (tx, rx) = std::sync::mpsc::sync_channel(1); @@ -182,7 +182,7 @@ where hashed_address, ) .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned())) - .with_branch_node_hash_masks(collect_masks) + .with_branch_node_masks(collect_masks) .storage_multiproof(target_slots) .map_err(|e| ParallelStateRootError::Other(e.to_string())); @@ -233,7 +233,7 @@ where let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect(); let mut hash_builder = HashBuilder::default() .with_proof_retainer(retainer) - .with_updates(self.collect_branch_node_hash_masks); + .with_updates(self.collect_branch_node_masks); // Initialize all storage multiproofs as empty. // Storage multiproofs for non empty tries will be overwritten if necessary. @@ -301,18 +301,23 @@ where self.metrics.record_state_trie(tracker.finish()); let account_subtree = hash_builder.take_proof_nodes(); - let branch_node_hash_masks = if self.collect_branch_node_hash_masks { - hash_builder - .updated_branch_nodes - .unwrap_or_default() - .into_iter() - .map(|(path, node)| (path, node.hash_mask)) - .collect() + let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks { + let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default(); + ( + updated_branch_nodes + .iter() + .map(|(path, node)| (path.clone(), node.hash_mask)) + .collect(), + updated_branch_nodes + .into_iter() + .map(|(path, node)| (path, node.tree_mask)) + .collect(), + ) } else { - HashMap::default() + (HashMap::default(), HashMap::default()) }; - Ok(MultiProof { account_subtree, branch_node_hash_masks, storages }) + Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages }) } } diff --git a/crates/trie/sparse/src/state.rs b/crates/trie/sparse/src/state.rs index b95cb62c7e65..505a326c0bf4 100644 --- a/crates/trie/sparse/src/state.rs +++ b/crates/trie/sparse/src/state.rs @@ -155,13 +155,14 @@ impl SparseStateTrie { self.provider_factory.account_node_provider(), root_node, None, + None, self.retain_updates, )?; // Reveal the remaining proof nodes. for (path, bytes) in proof { let node = TrieNode::decode(&mut &bytes[..])?; - trie.reveal_node(path, node, None)?; + trie.reveal_node(path, node, None, None)?; } // Mark leaf path as revealed. @@ -196,13 +197,14 @@ impl SparseStateTrie { self.provider_factory.storage_node_provider(account), root_node, None, + None, self.retain_updates, )?; // Reveal the remaining proof nodes. for (path, bytes) in proof { let node = TrieNode::decode(&mut &bytes[..])?; - trie.reveal_node(path, node, None)?; + trie.reveal_node(path, node, None, None)?; } // Mark leaf path as revealed. @@ -227,20 +229,24 @@ impl SparseStateTrie { self.provider_factory.account_node_provider(), root_node, multiproof.branch_node_hash_masks.get(&Nibbles::default()).copied(), + multiproof.branch_node_tree_masks.get(&Nibbles::default()).copied(), self.retain_updates, )?; // Reveal the remaining proof nodes. for (path, bytes) in account_nodes { let node = TrieNode::decode(&mut &bytes[..])?; - let hash_mask = if let TrieNode::Branch(_) = node { - multiproof.branch_node_hash_masks.get(&path).copied() + let (hash_mask, tree_mask) = if let TrieNode::Branch(_) = node { + ( + multiproof.branch_node_hash_masks.get(&path).copied(), + multiproof.branch_node_tree_masks.get(&path).copied(), + ) } else { - None + (None, None) }; - trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, "Revealing account node"); - trie.reveal_node(path, node, hash_mask)?; + trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, ?tree_mask, "Revealing account node"); + trie.reveal_node(path, node, hash_mask, tree_mask)?; } } @@ -254,20 +260,24 @@ impl SparseStateTrie { self.provider_factory.storage_node_provider(account), root_node, storage_subtree.branch_node_hash_masks.get(&Nibbles::default()).copied(), + storage_subtree.branch_node_tree_masks.get(&Nibbles::default()).copied(), self.retain_updates, )?; // Reveal the remaining proof nodes. for (path, bytes) in nodes { let node = TrieNode::decode(&mut &bytes[..])?; - let hash_mask = if let TrieNode::Branch(_) = node { - storage_subtree.branch_node_hash_masks.get(&path).copied() + let (hash_mask, tree_mask) = if let TrieNode::Branch(_) = node { + ( + storage_subtree.branch_node_hash_masks.get(&path).copied(), + storage_subtree.branch_node_tree_masks.get(&path).copied(), + ) } else { - None + (None, None) }; - trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, "Revealing storage node"); - trie.reveal_node(path, node, hash_mask)?; + trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, ?tree_mask, "Revealing storage node"); + trie.reveal_node(path, node, hash_mask, tree_mask)?; } } } @@ -348,6 +358,7 @@ impl SparseStateTrie { self.provider_factory.storage_node_provider(account), trie_node, None, + None, self.retain_updates, )?; } else { @@ -355,7 +366,7 @@ impl SparseStateTrie { storage_trie_entry .as_revealed_mut() .ok_or(SparseTrieErrorKind::Blind)? - .reveal_node(path, trie_node, None)?; + .reveal_node(path, trie_node, None, None)?; } } else if path.is_empty() { // Handle special state root node case. @@ -363,6 +374,7 @@ impl SparseStateTrie { self.provider_factory.account_node_provider(), trie_node, None, + None, self.retain_updates, )?; } else { @@ -370,7 +382,7 @@ impl SparseStateTrie { self.state .as_revealed_mut() .ok_or(SparseTrieErrorKind::Blind)? - .reveal_node(path, trie_node, None)?; + .reveal_node(path, trie_node, None, None)?; } } @@ -668,6 +680,7 @@ mod tests { Nibbles::from_nibbles([0x1]), TrieMask::new(0b00), )]), + branch_node_tree_masks: HashMap::default(), storages: HashMap::from_iter([ ( address_1, @@ -675,6 +688,7 @@ mod tests { root, subtree: storage_proof_nodes.clone(), branch_node_hash_masks: storage_branch_node_hash_masks.clone(), + branch_node_tree_masks: HashMap::default(), }, ), ( @@ -683,6 +697,7 @@ mod tests { root, subtree: storage_proof_nodes, branch_node_hash_masks: storage_branch_node_hash_masks, + branch_node_tree_masks: HashMap::default(), }, ), ]), diff --git a/crates/trie/sparse/src/trie.rs b/crates/trie/sparse/src/trie.rs index 7ff0e40e1a21..b7cba834567f 100644 --- a/crates/trie/sparse/src/trie.rs +++ b/crates/trie/sparse/src/trie.rs @@ -60,9 +60,16 @@ impl SparseTrie { &mut self, root: TrieNode, hash_mask: Option, + tree_mask: Option, retain_updates: bool, ) -> SparseTrieResult<&mut RevealedSparseTrie> { - self.reveal_root_with_provider(Default::default(), root, hash_mask, retain_updates) + self.reveal_root_with_provider( + Default::default(), + root, + hash_mask, + tree_mask, + retain_updates, + ) } } @@ -100,6 +107,7 @@ impl

SparseTrie

{ provider: P, root: TrieNode, hash_mask: Option, + tree_mask: Option, retain_updates: bool, ) -> SparseTrieResult<&mut RevealedSparseTrie

> { if self.is_blind() { @@ -107,6 +115,7 @@ impl

SparseTrie

{ provider, root, hash_mask, + tree_mask, retain_updates, )?)) } @@ -163,6 +172,8 @@ pub struct RevealedSparseTrie

{ nodes: HashMap, /// All branch node hash masks. branch_node_hash_masks: HashMap, + /// All branch node tree masks. + branch_node_tree_masks: HashMap, /// All leaf values. values: HashMap>, /// Prefix set. @@ -178,6 +189,7 @@ impl

fmt::Debug for RevealedSparseTrie

{ f.debug_struct("RevealedSparseTrie") .field("nodes", &self.nodes) .field("branch_hash_masks", &self.branch_node_hash_masks) + .field("branch_tree_masks", &self.branch_node_tree_masks) .field("values", &self.values) .field("prefix_set", &self.prefix_set) .field("updates", &self.updates) @@ -192,6 +204,7 @@ impl Default for RevealedSparseTrie { provider: Default::default(), nodes: HashMap::from_iter([(Nibbles::default(), SparseNode::Empty)]), branch_node_hash_masks: HashMap::default(), + branch_node_tree_masks: HashMap::default(), values: HashMap::default(), prefix_set: PrefixSetMut::default(), updates: None, @@ -205,19 +218,21 @@ impl RevealedSparseTrie { pub fn from_root( node: TrieNode, hash_mask: Option, + tree_mask: Option, retain_updates: bool, ) -> SparseTrieResult { let mut this = Self { provider: Default::default(), nodes: HashMap::default(), branch_node_hash_masks: HashMap::default(), + branch_node_tree_masks: HashMap::default(), values: HashMap::default(), prefix_set: PrefixSetMut::default(), rlp_buf: Vec::new(), updates: None, } .with_updates(retain_updates); - this.reveal_node(Nibbles::default(), node, hash_mask)?; + this.reveal_node(Nibbles::default(), node, hash_mask, tree_mask)?; Ok(this) } } @@ -228,19 +243,21 @@ impl

RevealedSparseTrie

{ provider: P, node: TrieNode, hash_mask: Option, + tree_mask: Option, retain_updates: bool, ) -> SparseTrieResult { let mut this = Self { provider, nodes: HashMap::default(), branch_node_hash_masks: HashMap::default(), + branch_node_tree_masks: HashMap::default(), values: HashMap::default(), prefix_set: PrefixSetMut::default(), rlp_buf: Vec::new(), updates: None, } .with_updates(retain_updates); - this.reveal_node(Nibbles::default(), node, hash_mask)?; + this.reveal_node(Nibbles::default(), node, hash_mask, tree_mask)?; Ok(this) } @@ -250,6 +267,7 @@ impl

RevealedSparseTrie

{ provider, nodes: self.nodes, branch_node_hash_masks: self.branch_node_hash_masks, + branch_node_tree_masks: self.branch_node_tree_masks, values: self.values, prefix_set: self.prefix_set, updates: self.updates, @@ -286,6 +304,7 @@ impl

RevealedSparseTrie

{ path: Nibbles, node: TrieNode, hash_mask: Option, + tree_mask: Option, ) -> SparseTrieResult<()> { // If the node is already revealed and it's not a hash node, do nothing. if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) { @@ -295,6 +314,9 @@ impl

RevealedSparseTrie

{ if let Some(hash_mask) = hash_mask { self.branch_node_hash_masks.insert(path.clone(), hash_mask); } + if let Some(tree_mask) = tree_mask { + self.branch_node_tree_masks.insert(path.clone(), tree_mask); + } match node { TrieNode::EmptyRoot => { @@ -321,7 +343,10 @@ impl

RevealedSparseTrie

{ // Memoize the hash of a previously blinded node in a new branch // node. hash: Some(*hash), - store_in_db_trie: None, + store_in_db_trie: Some( + hash_mask.is_some_and(|mask| !mask.is_empty()) || + tree_mask.is_some_and(|mask| !mask.is_empty()), + ), }); } // Branch node already exists, or an extension node was placed where a @@ -433,7 +458,7 @@ impl

RevealedSparseTrie

{ return Ok(()) } - self.reveal_node(path, TrieNode::decode(&mut &child[..])?, None) + self.reveal_node(path, TrieNode::decode(&mut &child[..])?, None, None) } /// Traverse trie nodes down to the leaf node and collect all nodes along the path. @@ -627,22 +652,20 @@ impl

RevealedSparseTrie

{ let mut prefix_set_contains = |path: &Nibbles| *is_in_prefix_set.get_or_insert_with(|| prefix_set.contains(path)); - let (rlp_node, calculated, node_type) = match self.nodes.get_mut(&path).unwrap() { - SparseNode::Empty => { - (RlpNode::word_rlp(&EMPTY_ROOT_HASH), false, SparseNodeType::Empty) - } - SparseNode::Hash(hash) => (RlpNode::word_rlp(hash), false, SparseNodeType::Hash), + let (rlp_node, node_type) = match self.nodes.get_mut(&path).unwrap() { + SparseNode::Empty => (RlpNode::word_rlp(&EMPTY_ROOT_HASH), SparseNodeType::Empty), + SparseNode::Hash(hash) => (RlpNode::word_rlp(hash), SparseNodeType::Hash), SparseNode::Leaf { key, hash } => { let mut path = path.clone(); path.extend_from_slice_unchecked(key); if let Some(hash) = hash.filter(|_| !prefix_set_contains(&path)) { - (RlpNode::word_rlp(&hash), false, SparseNodeType::Leaf) + (RlpNode::word_rlp(&hash), SparseNodeType::Leaf) } else { let value = self.values.get(&path).unwrap(); self.rlp_buf.clear(); let rlp_node = LeafNodeRef { key, value }.rlp(&mut self.rlp_buf); *hash = rlp_node.as_hash(); - (rlp_node, true, SparseNodeType::Leaf) + (rlp_node, SparseNodeType::Leaf) } } SparseNode::Extension { key, hash } => { @@ -651,22 +674,20 @@ impl

RevealedSparseTrie

{ if let Some(hash) = hash.filter(|_| !prefix_set_contains(&path)) { ( RlpNode::word_rlp(&hash), - false, SparseNodeType::Extension { store_in_db_trie: true }, ) } else if buffers.rlp_node_stack.last().is_some_and(|e| e.0 == child_path) { - let (_, child, _, node_type) = buffers.rlp_node_stack.pop().unwrap(); + let (_, child, child_node_type) = buffers.rlp_node_stack.pop().unwrap(); self.rlp_buf.clear(); let rlp_node = ExtensionNodeRef::new(key, &child).rlp(&mut self.rlp_buf); *hash = rlp_node.as_hash(); ( rlp_node, - true, SparseNodeType::Extension { // Inherit the `store_in_db_trie` flag from the child node, which is // always the branch node - store_in_db_trie: node_type.store_in_db_trie(), + store_in_db_trie: child_node_type.store_in_db_trie(), }, ) } else { @@ -682,7 +703,6 @@ impl

RevealedSparseTrie

{ buffers.rlp_node_stack.push(( path, RlpNode::word_rlp(&hash), - false, SparseNodeType::Branch { store_in_db_trie }, )); continue @@ -710,8 +730,7 @@ impl

RevealedSparseTrie

{ let mut hashes = Vec::new(); for (i, child_path) in buffers.branch_child_buf.iter().enumerate() { if buffers.rlp_node_stack.last().is_some_and(|e| &e.0 == child_path) { - let (_, child, calculated, node_type) = - buffers.rlp_node_stack.pop().unwrap(); + let (_, child, child_node_type) = buffers.rlp_node_stack.pop().unwrap(); // Update the masks only if we need to retain trie updates if retain_updates { @@ -720,13 +739,16 @@ impl

RevealedSparseTrie

{ // Determine whether we need to set trie mask bit. let should_set_tree_mask_bit = + // A blinded node has the tree mask bit set + ( + child_node_type.is_hash() && + self.branch_node_tree_masks + .get(&path) + .is_some_and(|mask| mask.is_bit_set(last_child_nibble)) + ) || // A branch or an extension node explicitly set the // `store_in_db_trie` flag - node_type.store_in_db_trie() || - // Set the flag according to whether a child node was - // pre-calculated (`calculated = false`), meaning that it wasn't - // in the database - !calculated; + child_node_type.store_in_db_trie(); if should_set_tree_mask_bit { tree_mask.set_bit(last_child_nibble); } @@ -735,8 +757,8 @@ impl

RevealedSparseTrie

{ // is a blinded node that has its hash mask bit set according to the // database, set the hash mask bit and save the hash. let hash = child.as_hash().filter(|_| { - node_type.is_branch() || - (node_type.is_hash() && + child_node_type.is_branch() || + (child_node_type.is_hash() && self.branch_node_hash_masks .get(&path) .is_some_and(|mask| { @@ -806,14 +828,10 @@ impl

RevealedSparseTrie

{ }; *store_in_db_trie = Some(store_in_db_trie_value); - ( - rlp_node, - true, - SparseNodeType::Branch { store_in_db_trie: store_in_db_trie_value }, - ) + (rlp_node, SparseNodeType::Branch { store_in_db_trie: store_in_db_trie_value }) } }; - buffers.rlp_node_stack.push((path, rlp_node, calculated, node_type)); + buffers.rlp_node_stack.push((path, rlp_node, node_type)); } debug_assert_eq!(buffers.rlp_node_stack.len(), 1); @@ -894,7 +912,7 @@ impl RevealedSparseTrie

{ // remove or do nothing, so // we can safely ignore the hash mask here and // pass `None`. - self.reveal_node(current.clone(), decoded, None)?; + self.reveal_node(current.clone(), decoded, None, None)?; } } } @@ -1046,7 +1064,7 @@ impl RevealedSparseTrie

{ // We'll never have to update the revealed branch node, only remove // or do nothing, so we can safely ignore the hash mask here and // pass `None`. - self.reveal_node(child_path.clone(), decoded, None)?; + self.reveal_node(child_path.clone(), decoded, None, None)?; } } @@ -1251,7 +1269,7 @@ struct RlpNodeBuffers { /// Stack of paths we need rlp nodes for and whether the path is in the prefix set. path_stack: Vec<(Nibbles, Option)>, /// Stack of rlp nodes - rlp_node_stack: Vec<(Nibbles, RlpNode, bool, SparseNodeType)>, + rlp_node_stack: Vec<(Nibbles, RlpNode, SparseNodeType)>, /// Reusable branch child path branch_child_buf: SmallVec<[Nibbles; 16]>, /// Reusable branch value stack @@ -1336,7 +1354,8 @@ mod tests { state: impl IntoIterator + Clone, destroyed_accounts: B256HashSet, proof_targets: impl IntoIterator, - ) -> (B256, TrieUpdates, ProofNodes, HashMap) { + ) -> (B256, TrieUpdates, ProofNodes, HashMap, HashMap) + { let mut account_rlp = Vec::new(); let mut hash_builder = HashBuilder::default() @@ -1383,12 +1402,19 @@ mod tests { .iter() .map(|(path, node)| (path.clone(), node.hash_mask)) .collect(); + let branch_node_tree_masks = hash_builder + .updated_branch_nodes + .clone() + .unwrap_or_default() + .iter() + .map(|(path, node)| (path.clone(), node.tree_mask)) + .collect(); let mut trie_updates = TrieUpdates::default(); let removed_keys = node_iter.walker.take_removed_keys(); trie_updates.finalize(hash_builder, removed_keys, destroyed_accounts); - (root, trie_updates, proof_nodes, branch_node_hash_masks) + (root, trie_updates, proof_nodes, branch_node_hash_masks, branch_node_tree_masks) } /// Assert that the sparse trie nodes and the proof nodes from the hash builder are equal. @@ -1450,7 +1476,7 @@ mod tests { account_rlp }; - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder([(key.clone(), value())], Default::default(), [key.clone()]); let mut sparse = RevealedSparseTrie::default().with_updates(true); @@ -1475,7 +1501,7 @@ mod tests { account_rlp }; - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( paths.iter().cloned().zip(std::iter::repeat_with(value)), Default::default(), @@ -1504,7 +1530,7 @@ mod tests { account_rlp }; - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( paths.iter().cloned().zip(std::iter::repeat_with(value)), Default::default(), @@ -1541,7 +1567,7 @@ mod tests { account_rlp }; - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( paths.iter().sorted_unstable().cloned().zip(std::iter::repeat_with(value)), Default::default(), @@ -1579,7 +1605,7 @@ mod tests { account_rlp }; - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( paths.iter().cloned().zip(std::iter::repeat_with(|| old_value)), Default::default(), @@ -1597,7 +1623,7 @@ mod tests { assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( paths.iter().cloned().zip(std::iter::repeat_with(|| new_value)), Default::default(), @@ -1871,7 +1897,7 @@ mod tests { )); let mut sparse = - RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), false) + RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), None, false) .unwrap(); // Reveal a branch node and one of its children @@ -1879,8 +1905,8 @@ mod tests { // Branch (Mask = 11) // ├── 0 -> Hash (Path = 0) // └── 1 -> Leaf (Path = 1) - sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01))).unwrap(); - sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap(); + sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01)), None).unwrap(); + sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None, None).unwrap(); // Removing a blinded leaf should result in an error assert_matches!( @@ -1904,7 +1930,7 @@ mod tests { )); let mut sparse = - RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), false) + RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), None, false) .unwrap(); // Reveal a branch node and one of its children @@ -1912,8 +1938,8 @@ mod tests { // Branch (Mask = 11) // ├── 0 -> Hash (Path = 0) // └── 1 -> Leaf (Path = 1) - sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01))).unwrap(); - sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap(); + sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01)), None).unwrap(); + sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None, None).unwrap(); // Removing a non-existent leaf should be a noop let sparse_old = sparse.clone(); @@ -1951,7 +1977,7 @@ mod tests { // Insert state updates into the hash builder and calculate the root state.extend(update); - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( state.clone(), Default::default(), @@ -1982,7 +2008,7 @@ mod tests { let sparse_root = updated_sparse.root(); let sparse_updates = updated_sparse.take_updates(); - let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) = + let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = run_hash_builder( state.clone(), Default::default(), @@ -2063,24 +2089,29 @@ mod tests { }; // Generate the proof for the root node and initialize the sparse trie with it - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder( - [(key1(), value()), (key3(), value())], - Default::default(), - [Nibbles::default()], - ); + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = + run_hash_builder( + [(key1(), value()), (key3(), value())], + Default::default(), + [Nibbles::default()], + ); let mut sparse = RevealedSparseTrie::from_root( TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(), branch_node_hash_masks.get(&Nibbles::default()).copied(), + branch_node_tree_masks.get(&Nibbles::default()).copied(), false, ) .unwrap(); // Generate the proof for the first key and reveal it in the sparse trie - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = run_hash_builder([(key1(), value()), (key3(), value())], Default::default(), [key1()]); for (path, node) in hash_builder_proof_nodes.nodes_sorted() { let hash_mask = branch_node_hash_masks.get(&path).copied(); - sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap(); + let tree_mask = branch_node_tree_masks.get(&path).copied(); + sparse + .reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask) + .unwrap(); } // Check that the branch node exists with only two nibbles set @@ -2099,11 +2130,14 @@ mod tests { ); // Generate the proof for the third key and reveal it in the sparse trie - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = run_hash_builder([(key1(), value()), (key3(), value())], Default::default(), [key3()]); for (path, node) in hash_builder_proof_nodes.nodes_sorted() { let hash_mask = branch_node_hash_masks.get(&path).copied(); - sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap(); + let tree_mask = branch_node_tree_masks.get(&path).copied(); + sparse + .reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask) + .unwrap(); } // Check that nothing changed in the branch node @@ -2114,7 +2148,7 @@ mod tests { // Generate the nodes for the full trie with all three key using the hash builder, and // compare them to the sparse trie - let (_, _, hash_builder_proof_nodes, _) = run_hash_builder( + let (_, _, hash_builder_proof_nodes, _, _) = run_hash_builder( [(key1(), value()), (key2(), value()), (key3(), value())], Default::default(), [key1(), key2(), key3()], @@ -2141,28 +2175,34 @@ mod tests { let value = || Account::default(); // Generate the proof for the root node and initialize the sparse trie with it - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder( - [(key1(), value()), (key2(), value()), (key3(), value())], - Default::default(), - [Nibbles::default()], - ); + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = + run_hash_builder( + [(key1(), value()), (key2(), value()), (key3(), value())], + Default::default(), + [Nibbles::default()], + ); let mut sparse = RevealedSparseTrie::from_root( TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(), branch_node_hash_masks.get(&Nibbles::default()).copied(), + branch_node_tree_masks.get(&Nibbles::default()).copied(), false, ) .unwrap(); // Generate the proof for the children of the root branch node and reveal it in the sparse // trie - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder( - [(key1(), value()), (key2(), value()), (key3(), value())], - Default::default(), - [key1(), Nibbles::from_nibbles_unchecked([0x01])], - ); + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = + run_hash_builder( + [(key1(), value()), (key2(), value()), (key3(), value())], + Default::default(), + [key1(), Nibbles::from_nibbles_unchecked([0x01])], + ); for (path, node) in hash_builder_proof_nodes.nodes_sorted() { let hash_mask = branch_node_hash_masks.get(&path).copied(); - sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap(); + let tree_mask = branch_node_tree_masks.get(&path).copied(); + sparse + .reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask) + .unwrap(); } // Check that the branch node exists @@ -2181,14 +2221,18 @@ mod tests { ); // Generate the proof for the third key and reveal it in the sparse trie - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder( - [(key1(), value()), (key2(), value()), (key3(), value())], - Default::default(), - [key2()], - ); + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = + run_hash_builder( + [(key1(), value()), (key2(), value()), (key3(), value())], + Default::default(), + [key2()], + ); for (path, node) in hash_builder_proof_nodes.nodes_sorted() { let hash_mask = branch_node_hash_masks.get(&path).copied(); - sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap(); + let tree_mask = branch_node_tree_masks.get(&path).copied(); + sparse + .reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask) + .unwrap(); } // Check that nothing changed in the extension node @@ -2219,14 +2263,16 @@ mod tests { }; // Generate the proof for the root node and initialize the sparse trie with it - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder( - [(key1(), value()), (key2(), value())], - Default::default(), - [Nibbles::default()], - ); + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = + run_hash_builder( + [(key1(), value()), (key2(), value())], + Default::default(), + [Nibbles::default()], + ); let mut sparse = RevealedSparseTrie::from_root( TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(), branch_node_hash_masks.get(&Nibbles::default()).copied(), + branch_node_tree_masks.get(&Nibbles::default()).copied(), false, ) .unwrap(); @@ -2247,11 +2293,14 @@ mod tests { ); // Generate the proof for the first key and reveal it in the sparse trie - let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = + let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) = run_hash_builder([(key1(), value()), (key2(), value())], Default::default(), [key1()]); for (path, node) in hash_builder_proof_nodes.nodes_sorted() { let hash_mask = branch_node_hash_masks.get(&path).copied(); - sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap(); + let tree_mask = branch_node_tree_masks.get(&path).copied(); + sparse + .reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask) + .unwrap(); } // Check that the branch node wasn't overwritten by the extension node in the proof @@ -2345,7 +2394,7 @@ mod tests { account_rlp }; - let (hash_builder_root, hash_builder_updates, _, _) = run_hash_builder( + let (hash_builder_root, hash_builder_updates, _, _, _) = run_hash_builder( [(key1(), value()), (key2(), value())], Default::default(), [Nibbles::default()], diff --git a/crates/trie/trie/src/proof/mod.rs b/crates/trie/trie/src/proof/mod.rs index 165c27c82e2e..5c632b5cecaa 100644 --- a/crates/trie/trie/src/proof/mod.rs +++ b/crates/trie/trie/src/proof/mod.rs @@ -33,8 +33,8 @@ pub struct Proof { hashed_cursor_factory: H, /// A set of prefix sets that have changes. prefix_sets: TriePrefixSetsMut, - /// Flag indicating whether to include branch node hash masks in the proof. - collect_branch_node_hash_masks: bool, + /// Flag indicating whether to include branch node masks in the proof. + collect_branch_node_masks: bool, } impl Proof { @@ -44,7 +44,7 @@ impl Proof { trie_cursor_factory: t, hashed_cursor_factory: h, prefix_sets: TriePrefixSetsMut::default(), - collect_branch_node_hash_masks: false, + collect_branch_node_masks: false, } } @@ -54,7 +54,7 @@ impl Proof { trie_cursor_factory, hashed_cursor_factory: self.hashed_cursor_factory, prefix_sets: self.prefix_sets, - collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, + collect_branch_node_masks: self.collect_branch_node_masks, } } @@ -64,7 +64,7 @@ impl Proof { trie_cursor_factory: self.trie_cursor_factory, hashed_cursor_factory, prefix_sets: self.prefix_sets, - collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, + collect_branch_node_masks: self.collect_branch_node_masks, } } @@ -74,9 +74,9 @@ impl Proof { self } - /// Set the flag indicating whether to include branch node hash masks in the proof. - pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self { - self.collect_branch_node_hash_masks = branch_node_hash_masks; + /// Set the flag indicating whether to include branch node masks in the proof. + pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self { + self.collect_branch_node_masks = branch_node_masks; self } } @@ -117,7 +117,7 @@ where let retainer = targets.keys().map(Nibbles::unpack).collect(); let mut hash_builder = HashBuilder::default() .with_proof_retainer(retainer) - .with_updates(self.collect_branch_node_hash_masks); + .with_updates(self.collect_branch_node_masks); // Initialize all storage multiproofs as empty. // Storage multiproofs for non empty tries will be overwritten if necessary. @@ -144,7 +144,7 @@ where hashed_address, ) .with_prefix_set_mut(storage_prefix_set) - .with_branch_node_hash_masks(self.collect_branch_node_hash_masks) + .with_branch_node_masks(self.collect_branch_node_masks) .storage_multiproof(proof_targets.unwrap_or_default())?; // Encode account @@ -164,18 +164,23 @@ where } let _ = hash_builder.root(); let account_subtree = hash_builder.take_proof_nodes(); - let branch_node_hash_masks = if self.collect_branch_node_hash_masks { - hash_builder - .updated_branch_nodes - .unwrap_or_default() - .into_iter() - .map(|(path, node)| (path, node.hash_mask)) - .collect() + let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks { + let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default(); + ( + updated_branch_nodes + .iter() + .map(|(path, node)| (path.clone(), node.hash_mask)) + .collect(), + updated_branch_nodes + .into_iter() + .map(|(path, node)| (path, node.tree_mask)) + .collect(), + ) } else { - HashMap::default() + (HashMap::default(), HashMap::default()) }; - Ok(MultiProof { account_subtree, branch_node_hash_masks, storages }) + Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages }) } } @@ -190,8 +195,8 @@ pub struct StorageProof { hashed_address: B256, /// The set of storage slot prefixes that have changed. prefix_set: PrefixSetMut, - /// Flag indicating whether to include branch node hash masks in the proof. - collect_branch_node_hash_masks: bool, + /// Flag indicating whether to include branch node masks in the proof. + collect_branch_node_masks: bool, } impl StorageProof { @@ -207,7 +212,7 @@ impl StorageProof { hashed_cursor_factory: h, hashed_address, prefix_set: PrefixSetMut::default(), - collect_branch_node_hash_masks: false, + collect_branch_node_masks: false, } } @@ -218,7 +223,7 @@ impl StorageProof { hashed_cursor_factory: self.hashed_cursor_factory, hashed_address: self.hashed_address, prefix_set: self.prefix_set, - collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, + collect_branch_node_masks: self.collect_branch_node_masks, } } @@ -229,7 +234,7 @@ impl StorageProof { hashed_cursor_factory, hashed_address: self.hashed_address, prefix_set: self.prefix_set, - collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, + collect_branch_node_masks: self.collect_branch_node_masks, } } @@ -239,9 +244,9 @@ impl StorageProof { self } - /// Set the flag indicating whether to include branch node hash masks in the proof. - pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self { - self.collect_branch_node_hash_masks = branch_node_hash_masks; + /// Set the flag indicating whether to include branch node masks in the proof. + pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self { + self.collect_branch_node_masks = branch_node_masks; self } } @@ -282,7 +287,7 @@ where let retainer = ProofRetainer::from_iter(target_nibbles); let mut hash_builder = HashBuilder::default() .with_proof_retainer(retainer) - .with_updates(self.collect_branch_node_hash_masks); + .with_updates(self.collect_branch_node_masks); let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor); while let Some(node) = storage_node_iter.try_next()? { match node { @@ -300,17 +305,22 @@ where let root = hash_builder.root(); let subtree = hash_builder.take_proof_nodes(); - let branch_node_hash_masks = if self.collect_branch_node_hash_masks { - hash_builder - .updated_branch_nodes - .unwrap_or_default() - .into_iter() - .map(|(path, node)| (path, node.hash_mask)) - .collect() + let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks { + let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default(); + ( + updated_branch_nodes + .iter() + .map(|(path, node)| (path.clone(), node.hash_mask)) + .collect(), + updated_branch_nodes + .into_iter() + .map(|(path, node)| (path, node.tree_mask)) + .collect(), + ) } else { - HashMap::default() + (HashMap::default(), HashMap::default()) }; - Ok(StorageMultiProof { root, subtree, branch_node_hash_masks }) + Ok(StorageMultiProof { root, subtree, branch_node_hash_masks, branch_node_tree_masks }) } }