Skip to content

Commit

Permalink
fix(trie): sparse trie tree masks (#13760)
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhirin authored Jan 10, 2025
1 parent 986c754 commit 69f9e16
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 157 deletions.
2 changes: 1 addition & 1 deletion crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ where
config.prefix_sets,
thread_pool,
)
.with_branch_node_hash_masks(true)
.with_branch_node_masks(true)
.multiproof(proof_targets)?)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/engine/tree/src/tree/trie_updates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ fn branch_nodes_equal(
) -> bool {
if let (Some(task), Some(regular)) = (task.as_ref(), regular.as_ref()) {
task.state_mask == regular.state_mask &&
// We do not compare the tree mask because it is known to be mismatching
task.tree_mask == regular.tree_mask &&
task.hash_mask == regular.hash_mask &&
task.hashes == regular.hashes &&
task.root_hash == regular.root_hash
Expand Down
9 changes: 9 additions & 0 deletions crates/trie/common/src/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub struct MultiProof {
pub account_subtree: ProofNodes,
/// The hash masks of the branch nodes in the account proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// The tree masks of the branch nodes in the account proof.
pub branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
/// Storage trie multiproofs.
pub storages: B256HashMap<StorageMultiProof>,
}
Expand Down Expand Up @@ -115,6 +117,7 @@ impl MultiProof {
self.account_subtree.extend_from(other.account_subtree);

self.branch_node_hash_masks.extend(other.branch_node_hash_masks);
self.branch_node_tree_masks.extend(other.branch_node_tree_masks);

for (hashed_address, storage) in other.storages {
match self.storages.entry(hashed_address) {
Expand All @@ -123,6 +126,7 @@ impl MultiProof {
let entry = entry.get_mut();
entry.subtree.extend_from(storage.subtree);
entry.branch_node_hash_masks.extend(storage.branch_node_hash_masks);
entry.branch_node_tree_masks.extend(storage.branch_node_tree_masks);
}
hash_map::Entry::Vacant(entry) => {
entry.insert(storage);
Expand All @@ -141,6 +145,8 @@ pub struct StorageMultiProof {
pub subtree: ProofNodes,
/// The hash masks of the branch nodes in the storage proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// The tree masks of the branch nodes in the storage proof.
pub branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
}

impl StorageMultiProof {
Expand All @@ -153,6 +159,7 @@ impl StorageMultiProof {
Bytes::from([EMPTY_STRING_CODE]),
)]),
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
}
}

Expand Down Expand Up @@ -398,6 +405,7 @@ mod tests {
root,
subtree: subtree1,
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
},
);

Expand All @@ -412,6 +420,7 @@ mod tests {
root,
subtree: subtree2,
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
},
);

Expand Down
41 changes: 23 additions & 18 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ pub struct ParallelProof<Factory> {
/// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
/// if we have cached nodes for them.
pub prefix_sets: Arc<TriePrefixSetsMut>,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Flag indicating whether to include branch node masks in the proof.
collect_branch_node_masks: bool,
/// Thread pool for local tasks
thread_pool: Arc<rayon::ThreadPool>,
/// Parallel state root metrics.
Expand All @@ -67,16 +67,16 @@ impl<Factory> ParallelProof<Factory> {
nodes_sorted,
state_sorted,
prefix_sets,
collect_branch_node_hash_masks: false,
collect_branch_node_masks: false,
thread_pool,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
}

/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
/// Set the flag indicating whether to include branch node masks in the proof.
pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
self.collect_branch_node_masks = branch_node_masks;
self
}
}
Expand Down Expand Up @@ -137,7 +137,7 @@ where
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;
let collect_masks = self.collect_branch_node_masks;

let (tx, rx) = std::sync::mpsc::sync_channel(1);

Expand Down Expand Up @@ -182,7 +182,7 @@ where
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(collect_masks)
.with_branch_node_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));

Expand Down Expand Up @@ -233,7 +233,7 @@ where
let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
.with_updates(self.collect_branch_node_masks);

// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
Expand Down Expand Up @@ -301,18 +301,23 @@ where
self.metrics.record_state_trie(tracker.finish());

let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
updated_branch_nodes
.iter()
.map(|(path, node)| (path.clone(), node.hash_mask))
.collect(),
updated_branch_nodes
.into_iter()
.map(|(path, node)| (path, node.tree_mask))
.collect(),
)
} else {
HashMap::default()
(HashMap::default(), HashMap::default())
};

Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
}
}

Expand Down
43 changes: 29 additions & 14 deletions crates/trie/sparse/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,14 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.account_node_provider(),
root_node,
None,
None,
self.retain_updates,
)?;

// Reveal the remaining proof nodes.
for (path, bytes) in proof {
let node = TrieNode::decode(&mut &bytes[..])?;
trie.reveal_node(path, node, None)?;
trie.reveal_node(path, node, None, None)?;
}

// Mark leaf path as revealed.
Expand Down Expand Up @@ -196,13 +197,14 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.storage_node_provider(account),
root_node,
None,
None,
self.retain_updates,
)?;

// Reveal the remaining proof nodes.
for (path, bytes) in proof {
let node = TrieNode::decode(&mut &bytes[..])?;
trie.reveal_node(path, node, None)?;
trie.reveal_node(path, node, None, None)?;
}

// Mark leaf path as revealed.
Expand All @@ -227,20 +229,24 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.account_node_provider(),
root_node,
multiproof.branch_node_hash_masks.get(&Nibbles::default()).copied(),
multiproof.branch_node_tree_masks.get(&Nibbles::default()).copied(),
self.retain_updates,
)?;

// Reveal the remaining proof nodes.
for (path, bytes) in account_nodes {
let node = TrieNode::decode(&mut &bytes[..])?;
let hash_mask = if let TrieNode::Branch(_) = node {
multiproof.branch_node_hash_masks.get(&path).copied()
let (hash_mask, tree_mask) = if let TrieNode::Branch(_) = node {
(
multiproof.branch_node_hash_masks.get(&path).copied(),
multiproof.branch_node_tree_masks.get(&path).copied(),
)
} else {
None
(None, None)
};

trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, "Revealing account node");
trie.reveal_node(path, node, hash_mask)?;
trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, ?tree_mask, "Revealing account node");
trie.reveal_node(path, node, hash_mask, tree_mask)?;
}
}

Expand All @@ -254,20 +260,24 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.storage_node_provider(account),
root_node,
storage_subtree.branch_node_hash_masks.get(&Nibbles::default()).copied(),
storage_subtree.branch_node_tree_masks.get(&Nibbles::default()).copied(),
self.retain_updates,
)?;

// Reveal the remaining proof nodes.
for (path, bytes) in nodes {
let node = TrieNode::decode(&mut &bytes[..])?;
let hash_mask = if let TrieNode::Branch(_) = node {
storage_subtree.branch_node_hash_masks.get(&path).copied()
let (hash_mask, tree_mask) = if let TrieNode::Branch(_) = node {
(
storage_subtree.branch_node_hash_masks.get(&path).copied(),
storage_subtree.branch_node_tree_masks.get(&path).copied(),
)
} else {
None
(None, None)
};

trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, "Revealing storage node");
trie.reveal_node(path, node, hash_mask)?;
trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, ?tree_mask, "Revealing storage node");
trie.reveal_node(path, node, hash_mask, tree_mask)?;
}
}
}
Expand Down Expand Up @@ -348,29 +358,31 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.storage_node_provider(account),
trie_node,
None,
None,
self.retain_updates,
)?;
} else {
// Reveal non-root storage trie node.
storage_trie_entry
.as_revealed_mut()
.ok_or(SparseTrieErrorKind::Blind)?
.reveal_node(path, trie_node, None)?;
.reveal_node(path, trie_node, None, None)?;
}
} else if path.is_empty() {
// Handle special state root node case.
self.state.reveal_root_with_provider(
self.provider_factory.account_node_provider(),
trie_node,
None,
None,
self.retain_updates,
)?;
} else {
// Reveal non-root state trie node.
self.state
.as_revealed_mut()
.ok_or(SparseTrieErrorKind::Blind)?
.reveal_node(path, trie_node, None)?;
.reveal_node(path, trie_node, None, None)?;
}
}

Expand Down Expand Up @@ -668,13 +680,15 @@ mod tests {
Nibbles::from_nibbles([0x1]),
TrieMask::new(0b00),
)]),
branch_node_tree_masks: HashMap::default(),
storages: HashMap::from_iter([
(
address_1,
StorageMultiProof {
root,
subtree: storage_proof_nodes.clone(),
branch_node_hash_masks: storage_branch_node_hash_masks.clone(),
branch_node_tree_masks: HashMap::default(),
},
),
(
Expand All @@ -683,6 +697,7 @@ mod tests {
root,
subtree: storage_proof_nodes,
branch_node_hash_masks: storage_branch_node_hash_masks,
branch_node_tree_masks: HashMap::default(),
},
),
]),
Expand Down
Loading

0 comments on commit 69f9e16

Please sign in to comment.