From 0cd63cdf4b605dcf736f8d36b736137a1217a8a2 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Fri, 24 Jan 2025 19:28:02 +0100 Subject: [PATCH] test: reenable `test_state_root_task` test (#13911) Co-authored-by: Federico Gimenez Co-authored-by: Federico Gimenez --- crates/engine/tree/src/tree/mod.rs | 5 +- crates/engine/tree/src/tree/root.rs | 285 +++++++++++++++++----------- 2 files changed, 173 insertions(+), 117 deletions(-) diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs index ef5b3183af40..ea0832d0c9f3 100644 --- a/crates/engine/tree/src/tree/mod.rs +++ b/crates/engine/tree/src/tree/mod.rs @@ -594,10 +594,7 @@ where ) -> Self { let (incoming_tx, incoming) = std::sync::mpsc::channel(); - // The thread pool requires at least 2 threads as it contains a long running sparse trie - // task. - let num_threads = - std::thread::available_parallelism().map_or(2, |num| (num.get() / 2).max(2)); + let num_threads = root::thread_pool_size(); let state_root_task_pool = Arc::new( rayon::ThreadPoolBuilder::new() diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 5929895d0fa7..3b012802521f 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -27,7 +27,7 @@ use reth_trie_sparse::{ }; use revm_primitives::{keccak256, EvmState, B256}; use std::{ - collections::BTreeMap, + collections::{BTreeMap, VecDeque}, sync::{ mpsc::{self, channel, Receiver, Sender}, Arc, @@ -39,6 +39,16 @@ use tracing::{debug, error, trace}; /// The level below which the sparse trie hashes are calculated in [`update_sparse_trie`]. const SPARSE_TRIE_INCREMENTAL_LEVEL: usize = 2; +/// Determines the size of the thread pool to be used in [`StateRootTask`]. +/// It should be at least three, one for multiproof calculations plus two to be +/// used internally in [`StateRootTask`]. +/// +/// NOTE: this value can be greater than the available cores in the host, it +/// represents the maximum number of threads that can be handled by the pool. +pub(crate) fn thread_pool_size() -> usize { + std::thread::available_parallelism().map_or(3, |num| (num.get() / 2).max(3)) +} + /// Outcome of the state root computation, including the state root itself with /// the trie updates and the total time spent. #[derive(Debug)] @@ -296,6 +306,129 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState { hashed_state } +/// Input parameters for spawning a multiproof calculation. +#[derive(Debug)] +struct MultiproofInput { + config: StateRootConfig, + hashed_state_update: HashedPostState, + proof_targets: MultiProofTargets, + proof_sequence_number: u64, + state_root_message_sender: Sender, + source: ProofFetchSource, +} + +/// Manages concurrent multiproof calculations. +/// Takes care of not having more calculations in flight than a given thread +/// pool size, further calculation requests are queued and spawn later, after +/// availability has been signaled. +#[derive(Debug)] +struct MultiproofManager { + /// Maximum number of concurrent calculations. + max_concurrent: usize, + /// Currently running calculations. + inflight: usize, + /// Queued calculations. + pending: VecDeque>, + /// Thread pool to spawn multiproof calculations. + thread_pool: Arc, +} + +impl MultiproofManager +where + Factory: DatabaseProviderFactory + + StateCommitmentProvider + + Clone + + Send + + Sync + + 'static, +{ + /// Creates a new [`MultiproofManager`]. + fn new(thread_pool: Arc, thread_pool_size: usize) -> Self { + // we keep 2 threads to be used internally by [`StateRootTask`] + let max_concurrent = thread_pool_size.saturating_sub(2); + debug_assert!(max_concurrent != 0); + Self { + thread_pool, + max_concurrent, + inflight: 0, + pending: VecDeque::with_capacity(max_concurrent), + } + } + + /// Spawns a new multiproof calculation or enqueues it for later if + /// `max_concurrent` are already inflight. + fn spawn_or_queue(&mut self, input: MultiproofInput) { + if self.inflight >= self.max_concurrent { + self.pending.push_back(input); + return; + } + + self.spawn_multiproof(input); + } + + /// Signals that a multiproof calculation has finished and there's room to + /// spawn a new calculation if needed. + fn on_calculation_complete(&mut self) { + self.inflight = self.inflight.saturating_sub(1); + + if let Some(input) = self.pending.pop_front() { + self.spawn_multiproof(input); + } + } + + /// Spawns a multiproof calculation. + fn spawn_multiproof(&mut self, input: MultiproofInput) { + let MultiproofInput { + config, + hashed_state_update, + proof_targets, + proof_sequence_number, + state_root_message_sender, + source, + } = input; + let thread_pool = self.thread_pool.clone(); + + self.thread_pool.spawn(move || { + trace!( + target: "engine::root", + proof_sequence_number, + ?proof_targets, + "Starting multiproof calculation", + ); + let start = Instant::now(); + let result = calculate_multiproof(thread_pool, config, proof_targets.clone()); + trace!( + target: "engine::root", + proof_sequence_number, + elapsed = ?start.elapsed(), + "Multiproof calculated", + ); + + match result { + Ok(proof) => { + let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( + Box::new(ProofCalculated { + sequence_number: proof_sequence_number, + update: SparseTrieUpdate { + state: hashed_state_update, + targets: proof_targets, + multiproof: proof, + }, + source, + }), + )); + } + Err(error) => { + let _ = state_root_message_sender + .send(StateRootMessage::ProofCalculationError(error)); + } + } + }); + + self.inflight += 1; + } +} + /// Standalone task that receives a transaction state stream and updates relevant /// data structures to calculate state root. /// @@ -316,8 +449,10 @@ pub struct StateRootTask { fetched_proof_targets: MultiProofTargets, /// Proof sequencing handler. proof_sequencer: ProofSequencer, - /// Reference to the shared thread pool for parallel proof generation + /// Reference to the shared thread pool for parallel proof generation. thread_pool: Arc, + /// Manages calculation of multiproofs. + multiproof_manager: MultiproofManager, } impl StateRootTask @@ -338,7 +473,8 @@ where tx, fetched_proof_targets: Default::default(), proof_sequencer: ProofSequencer::new(), - thread_pool, + thread_pool: thread_pool.clone(), + multiproof_manager: MultiproofManager::new(thread_pool, thread_pool_size()), } } @@ -397,99 +533,34 @@ where } /// Handles request for proof prefetch. - fn on_prefetch_proof( - config: StateRootConfig, - targets: MultiProofTargets, - fetched_proof_targets: &mut MultiProofTargets, - proof_sequence_number: u64, - state_root_message_sender: Sender, - thread_pool: Arc, - ) { - extend_multi_proof_targets_ref(fetched_proof_targets, &targets); - - Self::spawn_multiproof( - config, - Default::default(), - targets, - proof_sequence_number, - state_root_message_sender, - thread_pool, - ProofFetchSource::Prefetch, - ); + fn on_prefetch_proof(&mut self, targets: MultiProofTargets) { + extend_multi_proof_targets_ref(&mut self.fetched_proof_targets, &targets); + + self.multiproof_manager.spawn_or_queue(MultiproofInput { + config: self.config.clone(), + hashed_state_update: Default::default(), + proof_targets: targets, + proof_sequence_number: self.proof_sequencer.next_sequence(), + state_root_message_sender: self.tx.clone(), + source: ProofFetchSource::Prefetch, + }); } /// Handles state updates. /// /// Returns proof targets derived from the state update. - fn on_state_update( - config: StateRootConfig, - update: EvmState, - fetched_proof_targets: &mut MultiProofTargets, - proof_sequence_number: u64, - state_root_message_sender: Sender, - thread_pool: Arc, - ) { + fn on_state_update(&mut self, update: EvmState, proof_sequence_number: u64) { let hashed_state_update = evm_state_to_hashed_post_state(update); + let proof_targets = get_proof_targets(&hashed_state_update, &self.fetched_proof_targets); + extend_multi_proof_targets_ref(&mut self.fetched_proof_targets, &proof_targets); - let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets); - extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets); - - Self::spawn_multiproof( - config, + self.multiproof_manager.spawn_or_queue(MultiproofInput { + config: self.config.clone(), hashed_state_update, proof_targets, proof_sequence_number, - state_root_message_sender, - thread_pool, - ProofFetchSource::StateUpdate, - ); - } - - fn spawn_multiproof( - config: StateRootConfig, - hashed_state_update: HashedPostState, - proof_targets: MultiProofTargets, - proof_sequence_number: u64, - state_root_message_sender: Sender, - thread_pool: Arc, - source: ProofFetchSource, - ) { - // Dispatch proof gathering for this state update - thread_pool.clone().spawn(move || { - trace!( - target: "engine::root", - proof_sequence_number, - ?proof_targets, - "Starting multiproof calculation", - ); - let start = Instant::now(); - let result = calculate_multiproof(thread_pool, config, proof_targets.clone()); - trace!( - target: "engine::root", - proof_sequence_number, - elapsed = ?start.elapsed(), - "Multiproof calculated", - ); - - match result { - Ok(proof) => { - let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( - Box::new(ProofCalculated { - sequence_number: proof_sequence_number, - update: SparseTrieUpdate { - state: hashed_state_update, - targets: proof_targets, - multiproof: proof, - }, - source, - }), - )); - } - Err(error) => { - let _ = state_root_message_sender - .send(StateRootMessage::ProofCalculationError(error)); - } - } + state_root_message_sender: self.tx.clone(), + source: ProofFetchSource::StateUpdate, }); } @@ -526,24 +597,20 @@ where let mut last_update_time = None; loop { + trace!(target: "engine::root", "entering main channel receiving loop"); match self.rx.recv() { Ok(message) => match message { StateRootMessage::PrefetchProofs(targets) => { + trace!(target: "engine::root", "processing StateRootMessage::PrefetchProofs"); debug!( target: "engine::root", len = targets.len(), "Prefetching proofs" ); - Self::on_prefetch_proof( - self.config.clone(), - targets, - &mut self.fetched_proof_targets, - self.proof_sequencer.next_sequence(), - self.tx.clone(), - self.thread_pool.clone(), - ); + self.on_prefetch_proof(targets); } StateRootMessage::StateUpdate(update) => { + trace!(target: "engine::root", "processing StateRootMessage::StateUpdate"); if updates_received == 0 { first_update_time = Some(Instant::now()); debug!(target: "engine::root", "Started state root calculation"); @@ -557,23 +624,19 @@ where total_updates = updates_received, "Received new state update" ); - Self::on_state_update( - self.config.clone(), - update, - &mut self.fetched_proof_targets, - self.proof_sequencer.next_sequence(), - self.tx.clone(), - self.thread_pool.clone(), - ); + let next_sequence = self.proof_sequencer.next_sequence(); + self.on_state_update(update, next_sequence); } StateRootMessage::FinishedStateUpdates => { - trace!(target: "engine::root", "Finished state updates"); + trace!(target: "engine::root", "processing StateRootMessage::FinishedStateUpdates"); updates_finished = true; } StateRootMessage::ProofCalculated(proof_calculated) => { + trace!(target: "engine::root", "processing StateRootMessage::ProofCalculated"); if proof_calculated.is_from_state_update() { proofs_processed += 1; } + debug!( target: "engine::root", sequence = proof_calculated.sequence_number, @@ -581,6 +644,8 @@ where "Processing calculated proof" ); + self.multiproof_manager.on_calculation_complete(); + if let Some(combined_update) = self.on_proof(proof_calculated.sequence_number, proof_calculated.update) { @@ -599,6 +664,7 @@ where } } StateRootMessage::RootCalculated { state_root, trie_updates, iterations } => { + trace!(target: "engine::root", "processing StateRootMessage::RootCalculated"); let total_time = first_update_time.expect("first update time should be set").elapsed(); let time_from_last_update = @@ -694,7 +760,7 @@ where let elapsed = update_sparse_trie(&mut trie, update).map_err(|e| { ParallelStateRootError::Other(format!("could not calculate state root: {e:?}")) })?; - trace!(target: "engine::root", ?elapsed, "Root calculation completed"); + trace!(target: "engine::root", ?elapsed, num_iterations, "Root calculation completed"); } debug!(target: "engine::root", num_iterations, "All proofs processed, ending calculation"); @@ -853,7 +919,6 @@ mod tests { }; use std::sync::Arc; - #[allow(dead_code)] fn convert_revm_to_reth_account(revm_account: &RevmAccount) -> RethAccount { RethAccount { balance: revm_account.info.balance, @@ -866,7 +931,6 @@ mod tests { } } - #[allow(dead_code)] fn create_mock_state_updates(num_accounts: usize, updates_per_account: usize) -> Vec { let mut rng = generators::rng(); let all_addresses: Vec
= (0..num_accounts).map(|_| rng.gen()).collect(); @@ -910,9 +974,7 @@ mod tests { updates } - // TODO: re-enable test once gh worker hang is figured out. - // #[test] - #[allow(dead_code)] + #[test] fn test_state_root_task() { reth_tracing::init_test_tracing(); @@ -973,10 +1035,7 @@ mod tests { prefix_sets: Arc::new(input.prefix_sets), }; - // The thread pool requires at least 2 threads as it contains a long running sparse trie - // task. - let num_threads = - std::thread::available_parallelism().map_or(2, |num| (num.get() / 2).max(2)); + let num_threads = thread_pool_size(); let state_root_task_pool = rayon::ThreadPoolBuilder::new() .num_threads(num_threads)