
feat(engine): integrate state root task and comment it #13265

Merged · 13 commits · Dec 17, 2024
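The core API change, detailed in `crates/engine/tree/src/tree/root.rs` below, is that `StateRootConfig` no longer carries an `Arc<TrieInput>`; it now stores the trie data pre-sorted behind `Arc`s and gains a `new_from_input` constructor. A minimal usage sketch follows; `build_config` is a hypothetical helper and the `reth_engine_tree::tree::root` import path is an assumption, while the constructor and types are the ones used in this diff.

```rust
use reth_provider::providers::ConsistentDbView;
use reth_trie::TrieInput;
// Path assumed for this sketch; `StateRootConfig` is defined in tree/root.rs in this PR.
use reth_engine_tree::tree::root::StateRootConfig;

/// Hypothetical helper: builds the task config from a consistent DB view,
/// mirroring the updated bench setup below.
fn build_config<Factory>(consistent_view: ConsistentDbView<Factory>) -> StateRootConfig<Factory> {
    // An empty overlay here; the engine builds the real one via `compute_trie_input`.
    let trie_input = TrieInput::from_state(Default::default());
    // `new_from_input` sorts `input.nodes` and `input.state` once and wraps them
    // (plus the prefix sets) in `Arc`s, so later clones by proof workers are cheap.
    StateRootConfig::new_from_input(consistent_view, trie_input)
}
```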
18 changes: 8 additions & 10 deletions crates/engine/tree/benches/state_root_task.rs
@@ -22,7 +22,6 @@ use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256,
};
use std::sync::Arc;

#[derive(Debug, Clone)]
struct BenchParams {
@@ -137,16 +136,15 @@ fn bench_state_root(c: &mut Criterion) {
let state_updates = create_bench_state_updates(params);
setup_provider(&factory, &state_updates).expect("failed to setup provider");

let trie_input = Arc::new(TrieInput::from_state(Default::default()));

let config = StateRootConfig {
consistent_view: ConsistentDbView::new(factory, None),
input: trie_input,
};
let trie_input = TrieInput::from_state(Default::default());
let config = StateRootConfig::new_from_input(
ConsistentDbView::new(factory, None),
trie_input,
);
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let prefix_sets = Arc::new(config.input.prefix_sets.clone());
let nodes_sorted = config.nodes_sorted.clone();
let state_sorted = config.state_sorted.clone();
let prefix_sets = config.prefix_sets.clone();

(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
},
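With the new config, the bench setup returns cheap `Arc` clones of `nodes_sorted`, `state_sorted`, and `prefix_sets` instead of re-sorting the `TrieInput` itself. These handles exist to layer the in-memory overlay on top of the database cursors, which is how the engine and `root.rs` changes below consume them. A hedged sketch of that layering; `build_overlay_factories` is a hypothetical helper, the `reth_db_api::transaction::DbTx` path is an assumption, and the factory constructors are the ones called in this diff.

```rust
use reth_db_api::transaction::DbTx; // assumed path for the database transaction trait
use reth_trie::{
    hashed_cursor::HashedPostStateCursorFactory, trie_cursor::InMemoryTrieCursorFactory,
    updates::TrieUpdatesSorted, HashedPostStateSorted,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};

/// Hypothetical helper: layers the sorted in-memory trie nodes and hashed state
/// over the database cursors.
fn build_overlay_factories<TX: DbTx>(
    tx: &TX,
    nodes_sorted: &TrieUpdatesSorted,
    state_sorted: &HashedPostStateSorted,
) {
    // Trie cursors consult the sorted in-memory updates first, then fall back to the DB.
    let trie_cursor_factory =
        InMemoryTrieCursorFactory::new(DatabaseTrieCursorFactory::new(tx), nodes_sorted);
    // Hashed account/storage cursors get the same overlay treatment.
    let hashed_cursor_factory =
        HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), state_sorted);
    // In the PR these feed `ProofBlindedProviderFactory` and `Proof` (see root.rs below).
    let _ = (trie_cursor_factory, hashed_cursor_factory);
}
```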
108 changes: 82 additions & 26 deletions crates/engine/tree/src/tree/mod.rs
@@ -2224,13 +2224,47 @@ where

let exec_time = Instant::now();

// TODO: create StateRootTask with the receiving end of a channel and
// pass the sending end of the channel to the state hook.
let noop_state_hook = |_state: &EvmState| {};
let persistence_not_in_progress = !self.persistence_state.in_progress();

// TODO: uncomment to use StateRootTask

// let (state_root_handle, state_hook) = if persistence_not_in_progress {
// let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
//
// let state_root_config = StateRootConfig::new_from_input(
// consistent_view.clone(),
// self.compute_trie_input(consistent_view, block.header().parent_hash())
// .map_err(ParallelStateRootError::into)?,
// );
//
// let provider_ro = consistent_view.provider_ro()?;
// let nodes_sorted = state_root_config.nodes_sorted.clone();
// let state_sorted = state_root_config.state_sorted.clone();
// let prefix_sets = state_root_config.prefix_sets.clone();
// let blinded_provider_factory = ProofBlindedProviderFactory::new(
// InMemoryTrieCursorFactory::new(
// DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
// &nodes_sorted,
// ),
// HashedPostStateCursorFactory::new(
// DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
// &state_sorted,
// ),
// prefix_sets,
// );
//
// let state_root_task = StateRootTask::new(state_root_config,
// blinded_provider_factory); let state_hook = state_root_task.state_hook();
// (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)
// } else {
// (None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
// };
let state_hook = Box::new(|_state: &EvmState| {});

let output = self.metrics.executor.execute_metered(
executor,
(&block, U256::MAX).into(),
Box::new(noop_state_hook),
state_hook,
)?;

trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");
@@ -2253,33 +2287,47 @@

trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
let root_time = Instant::now();
let mut state_root_result = None;

// TODO: switch to calculate state root using `StateRootTask`.

// We attempt to compute state root in parallel if we are currently not persisting anything
// to database. This is safe, because the database state cannot change until we
// finish parallel computation. It is important that nothing is being persisted as
// we are computing in parallel, because we initialize a different database transaction
// per thread and it might end up with a different view of the database.
let persistence_in_progress = self.persistence_state.in_progress();
if !persistence_in_progress {
state_root_result = match self
.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
{
Ok((state_root, trie_output)) => Some((state_root, trie_output)),
let state_root_result = if persistence_not_in_progress {
// TODO: uncomment to use StateRootTask

// if let Some(state_root_handle) = state_root_handle {
// match state_root_handle.wait_for_result() {
// Ok((task_state_root, task_trie_updates)) => {
// info!(
// target: "engine::tree",
// block = ?sealed_block.num_hash(),
// ?task_state_root,
// "State root task finished"
// );
// }
// Err(error) => {
// info!(target: "engine::tree", ?error, "Failed to wait for state root task
// result"); }
// }
// }

match self.compute_state_root_parallel(block.header().parent_hash(), &hashed_state) {
Ok(result) => Some(result),
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
None
}
Err(error) => return Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
};
}
}
} else {
None
};

let (state_root, trie_output) = if let Some(result) = state_root_result {
result
} else {
debug!(target: "engine::tree", block=?sealed_block.num_hash(), persistence_in_progress, "Failed to compute state root in parallel");
debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
state_provider.state_root_with_updates(hashed_state.clone())?
};

@@ -2344,14 +2392,25 @@ where
parent_hash: B256,
hashed_state: &HashedPostState,
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
// TODO: when we switch to calculate state root using `StateRootTask` this
// method can be still useful to calculate the required `TrieInput` to
// create the task.
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;

let mut input = self.compute_trie_input(consistent_view.clone(), parent_hash)?;
// Extend with block we are validating root for.
input.append_ref(hashed_state);

ParallelStateRoot::new(consistent_view, input).incremental_root_with_updates()
}

/// Computes the trie input at the provided parent hash.
fn compute_trie_input(
&self,
consistent_view: ConsistentDbView<P>,
parent_hash: B256,
) -> Result<TrieInput, ParallelStateRootError> {
let mut input = TrieInput::default();

if let Some((historical, blocks)) = self.state.tree_state.blocks_by_hash(parent_hash) {
debug!(target: "engine::tree", %parent_hash, %historical, "Calculating state root in parallel, parent found in memory");
debug!(target: "engine::tree", %parent_hash, %historical, "Parent found in memory");
// Retrieve revert state for historical block.
let revert_state = consistent_view.revert_state(historical)?;
input.append(revert_state);
Expand All @@ -2362,15 +2421,12 @@ where
}
} else {
// The block attaches to canonical persisted parent.
debug!(target: "engine::tree", %parent_hash, "Calculating state root in parallel, parent found in disk");
debug!(target: "engine::tree", %parent_hash, "Parent found on disk");
let revert_state = consistent_view.revert_state(parent_hash)?;
input.append(revert_state);
}

// Extend with block we are validating root for.
input.append_ref(hashed_state);

ParallelStateRoot::new(consistent_view, input).incremental_root_with_updates()
Ok(input)
}

/// Handles an error that occurred while inserting a block.
@@ -2648,7 +2704,7 @@ mod tests {
use reth_primitives::{Block, BlockExt, EthPrimitives};
use reth_provider::test_utils::MockEthProvider;
use reth_rpc_types_compat::engine::{block_to_payload_v1, payload::block_to_payload_v3};
use reth_trie::updates::TrieUpdates;
use reth_trie::{updates::TrieUpdates, HashedPostState};
use std::{
str::FromStr,
sync::mpsc::{channel, Sender},
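Because the disabled integration is kept as line comments, rustfmt wraps some of it awkwardly (for example the `StateRootTask::new(...)` and `wait_for_result()` lines above). For readability, the two `// TODO: uncomment to use StateRootTask` regions render to roughly the following. This code stays disabled in the PR, and `self`, `scope`, `block`, and `sealed_block` come from the surrounding engine method, so the snippet is not standalone.

```rust
// First region: spawn the task and obtain the state hook (disabled in this PR).
let (state_root_handle, state_hook) = if persistence_not_in_progress {
    let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;

    let state_root_config = StateRootConfig::new_from_input(
        consistent_view.clone(),
        self.compute_trie_input(consistent_view, block.header().parent_hash())
            .map_err(ParallelStateRootError::into)?,
    );

    let provider_ro = consistent_view.provider_ro()?;
    let nodes_sorted = state_root_config.nodes_sorted.clone();
    let state_sorted = state_root_config.state_sorted.clone();
    let prefix_sets = state_root_config.prefix_sets.clone();
    let blinded_provider_factory = ProofBlindedProviderFactory::new(
        InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &nodes_sorted,
        ),
        HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &state_sorted,
        ),
        prefix_sets,
    );

    let state_root_task = StateRootTask::new(state_root_config, blinded_provider_factory);
    let state_hook = state_root_task.state_hook();
    (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)
} else {
    (None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
};

// Second region: once execution finishes, wait for the task's result (also disabled).
if let Some(state_root_handle) = state_root_handle {
    match state_root_handle.wait_for_result() {
        Ok((task_state_root, task_trie_updates)) => {
            info!(
                target: "engine::tree",
                block = ?sealed_block.num_hash(),
                ?task_state_root,
                "State root task finished"
            );
        }
        Err(error) => {
            info!(target: "engine::tree", ?error, "Failed to wait for state root task result");
        }
    }
}
```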
72 changes: 52 additions & 20 deletions crates/engine/tree/src/tree/root.rs
@@ -10,10 +10,15 @@ use reth_provider::{
StateCommitmentProvider,
};
use reth_trie::{
proof::Proof, updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof,
MultiProofTargets, Nibbles, TrieInput,
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::Proof,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
TrieInput,
};
use reth_trie_db::DatabaseProof;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
@@ -72,12 +77,31 @@ impl StateRootHandle {
}

/// Common configuration for state root tasks
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StateRootConfig<Factory> {
/// View over the state in the database.
pub consistent_view: ConsistentDbView<Factory>,
/// Latest trie input.
pub input: Arc<TrieInput>,
/// The sorted collection of cached in-memory intermediate trie nodes that
/// can be reused for computation.
pub nodes_sorted: Arc<TrieUpdatesSorted>,
/// The sorted in-memory overlay hashed state.
pub state_sorted: Arc<HashedPostStateSorted>,
/// The collection of prefix sets for the computation. Since the prefix sets _always_
/// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
/// if we have cached nodes for them.
pub prefix_sets: Arc<TriePrefixSetsMut>,
}

impl<Factory> StateRootConfig<Factory> {
/// Creates a new state root config from the consistent view and the trie input.
pub fn new_from_input(consistent_view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
Self {
consistent_view,
nodes_sorted: Arc::new(input.nodes.into_sorted()),
state_sorted: Arc::new(input.state.into_sorted()),
prefix_sets: Arc::new(input.prefix_sets),
}
}
}

/// Messages used internally by the state root task
@@ -321,8 +345,7 @@ where
/// Returns proof targets derived from the state update.
fn on_state_update(
scope: &rayon::Scope<'env>,
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
config: StateRootConfig<Factory>,
update: EvmState,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
Expand All @@ -335,7 +358,7 @@ where

// Dispatch proof gathering for this state update
scope.spawn(move |_| {
let provider = match view.provider_ro() {
let provider = match config.consistent_view.provider_ro() {
Ok(provider) => provider,
Err(error) => {
error!(target: "engine::root", ?error, "Could not get provider");
@@ -346,11 +369,18 @@
};

// TODO: replace with parallel proof
let result = Proof::overlay_multiproof(
provider.tx_ref(),
input.as_ref().clone(),
proof_targets.clone(),
);
let result = Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets.clone());
match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Expand Down Expand Up @@ -472,8 +502,7 @@ where
);
Self::on_state_update(
scope,
self.config.consistent_view.clone(),
self.config.input.clone(),
self.config.clone(),
update,
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
Expand Down Expand Up @@ -859,13 +888,16 @@ mod tests {
}
}

let input = TrieInput::from_state(hashed_state);
let nodes_sorted = Arc::new(input.nodes.clone().into_sorted());
let state_sorted = Arc::new(input.state.clone().into_sorted());
let config = StateRootConfig {
consistent_view: ConsistentDbView::new(factory, None),
input: Arc::new(TrieInput::from_state(hashed_state)),
nodes_sorted: nodes_sorted.clone(),
state_sorted: state_sorted.clone(),
prefix_sets: Arc::new(input.prefix_sets),
};
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
Expand All @@ -875,7 +907,7 @@ mod tests {
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
Arc::new(config.input.prefix_sets.clone()),
config.prefix_sets.clone(),
);
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
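The other substantive change in `on_state_update` is that the single `Proof::overlay_multiproof(tx, input, targets)` call is replaced by a builder chain over the config's pre-sorted fields, presumably so each spawned proof job reuses the shared sorted data instead of cloning and re-sorting the whole `TrieInput`. A hedged sketch of that chain in isolation; `compute_overlay_multiproof` is a hypothetical helper, the `DbTx` import path is an assumption, the channel send is reduced to comments, and the builder methods are the ones used in the diff.

```rust
use reth_db_api::transaction::DbTx; // assumed path for the database transaction trait
use reth_trie::{
    hashed_cursor::HashedPostStateCursorFactory, prefix_set::TriePrefixSetsMut, proof::Proof,
    trie_cursor::InMemoryTrieCursorFactory, updates::TrieUpdatesSorted, HashedPostStateSorted,
    MultiProofTargets,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};

/// Hypothetical helper: computes a multiproof over the database view plus the sorted
/// in-memory overlay, mirroring the body of the proof job spawned in `on_state_update`.
fn compute_overlay_multiproof<TX: DbTx>(
    tx: &TX,
    nodes_sorted: &TrieUpdatesSorted,
    state_sorted: &HashedPostStateSorted,
    prefix_sets: &TriePrefixSetsMut,
    proof_targets: MultiProofTargets,
) {
    let result = Proof::from_tx(tx)
        .with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(tx),
            nodes_sorted,
        ))
        .with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(tx),
            state_sorted,
        ))
        .with_prefix_sets_mut(prefix_sets.clone())
        .with_branch_node_hash_masks(true)
        .multiproof(proof_targets);

    match result {
        // In the task this is sent back as `StateRootMessage::ProofCalculated`.
        Ok(_proof) => {}
        // Errors are reported to the task as proof calculation failures.
        Err(_error) => {}
    }
}
```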