diff --git a/zebra-rpc/src/server/http_request_compatibility.rs b/zebra-rpc/src/server/http_request_compatibility.rs index 5eb03b1c4fc..b352e0c5478 100644 --- a/zebra-rpc/src/server/http_request_compatibility.rs +++ b/zebra-rpc/src/server/http_request_compatibility.rs @@ -8,12 +8,13 @@ use std::pin::Pin; use futures::{future, FutureExt}; use http_body_util::BodyExt; -use hyper::{body::Bytes, header}; +use hyper::header; use jsonrpsee::{ core::BoxError, server::{HttpBody, HttpRequest, HttpResponse}, }; use jsonrpsee_types::ErrorObject; +use serde::{Deserialize, Serialize}; use tower::Service; use super::cookie::Cookie; @@ -118,24 +119,55 @@ impl HttpRequestMiddleware { } } - /// Remove any "jsonrpc: 1.0" fields in `data`, and return the resulting string. - pub fn remove_json_1_fields(data: String) -> String { - // Replace "jsonrpc = 1.0": - // - at the start or middle of a list, and - // - at the end of a list; - // with no spaces (lightwalletd format), and spaces after separators (example format). - // - // TODO: if we see errors from lightwalletd, make this replacement more accurate: - // - use a partial JSON fragment parser - // - combine the whole request into a single buffer, and use a JSON parser - // - use a regular expression - // - // We could also just handle the exact lightwalletd format, - // by replacing `{"jsonrpc":"1.0",` with `{"jsonrpc":"2.0`. - data.replace("\"jsonrpc\":\"1.0\",", "\"jsonrpc\":\"2.0\",") - .replace("\"jsonrpc\": \"1.0\",", "\"jsonrpc\": \"2.0\",") - .replace(",\"jsonrpc\":\"1.0\"", ",\"jsonrpc\":\"2.0\"") - .replace(", \"jsonrpc\": \"1.0\"", ", \"jsonrpc\": \"2.0\"") + /// Maps whatever JSON-RPC version the client is using to JSON-RPC 2.0. + async fn request_to_json_rpc_2( + request: HttpRequest, + ) -> (JsonRpcVersion, HttpRequest) { + let (parts, body) = request.into_parts(); + let bytes = body + .collect() + .await + .expect("Failed to collect body data") + .to_bytes(); + let (version, bytes) = + if let Ok(request) = serde_json::from_slice::<'_, JsonRpcRequest>(bytes.as_ref()) { + let version = request.version(); + if matches!(version, JsonRpcVersion::Unknown) { + (version, bytes) + } else { + ( + version, + serde_json::to_vec(&request.into_2()).expect("valid").into(), + ) + } + } else { + (JsonRpcVersion::Unknown, bytes) + }; + ( + version, + HttpRequest::from_parts(parts, HttpBody::from(bytes.as_ref().to_vec())), + ) + } + /// Maps JSON-2.0 to whatever JSON-RPC version the client is using. + async fn response_from_json_rpc_2( + version: JsonRpcVersion, + response: HttpResponse, + ) -> HttpResponse { + let (parts, body) = response.into_parts(); + let bytes = body + .collect() + .await + .expect("Failed to collect body data") + .to_bytes(); + let bytes = + if let Ok(response) = serde_json::from_slice::<'_, JsonRpcResponse>(bytes.as_ref()) { + serde_json::to_vec(&response.into_version(version)) + .expect("valid") + .into() + } else { + bytes + }; + HttpResponse::from_parts(parts, HttpBody::from(bytes.as_ref().to_vec())) } } @@ -203,25 +235,103 @@ where Self::insert_or_replace_content_type_header(request.headers_mut()); let mut service = self.service.clone(); - let (parts, body) = request.into_parts(); async move { - let bytes = body - .collect() - .await - .expect("Failed to collect body data") - .to_bytes(); + let (version, request) = Self::request_to_json_rpc_2(request).await; + let response = service.call(request).await.map_err(Into::into)?; + Ok(Self::response_from_json_rpc_2(version, response).await) + } + .boxed() + } +} + +#[derive(Clone, Copy, Debug)] +enum JsonRpcVersion { + /// bitcoind used a mishmash of 1.0, 1.1, and 2.0 for its JSON-RPC. + Bitcoind, + /// lightwalletd uses the above mishmash, but also breaks spec to include a + /// `"jsonrpc": "1.0"` key. + Lightwalletd, + /// The client is indicating strict 2.0 handling. + TwoPointZero, + /// On parse errors we don't modify anything, and let the `jsonrpsee` crate handle it. + Unknown, +} - let data = String::from_utf8_lossy(bytes.as_ref()).to_string(); +/// A version-agnostic JSON-RPC request. +#[derive(Debug, Deserialize, Serialize)] +struct JsonRpcRequest { + #[serde(skip_serializing_if = "Option::is_none")] + jsonrpc: Option, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, +} - // Fix JSON-RPC 1.0 requests. - let data = Self::remove_json_1_fields(data); - let body = HttpBody::from(Bytes::from(data).as_ref().to_vec()); +impl JsonRpcRequest { + fn version(&self) -> JsonRpcVersion { + match (self.jsonrpc.as_deref(), &self.params, &self.id) { + ( + Some("2.0"), + _, + None + | Some( + serde_json::Value::Null + | serde_json::Value::String(_) + | serde_json::Value::Number(_), + ), + ) => JsonRpcVersion::TwoPointZero, + (Some("1.0"), Some(_), Some(_)) => JsonRpcVersion::Lightwalletd, + (None, Some(_), Some(_)) => JsonRpcVersion::Bitcoind, + _ => JsonRpcVersion::Unknown, + } + } - let request = HttpRequest::from_parts(parts, body); + fn into_2(mut self) -> Self { + self.jsonrpc = Some("2.0".into()); + self + } +} +/// A version-agnostic JSON-RPC response. +#[derive(Debug, Deserialize, Serialize)] +struct JsonRpcResponse { + #[serde(skip_serializing_if = "Option::is_none")] + jsonrpc: Option, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + id: serde_json::Value, +} - service.call(request).await.map_err(Into::into) +impl JsonRpcResponse { + fn into_version(mut self, version: JsonRpcVersion) -> Self { + match version { + JsonRpcVersion::Bitcoind => { + self.jsonrpc = None; + self.result = self.result.or(Some(serde_json::Value::Null)); + self.error = self.error.or(Some(serde_json::Value::Null)); + } + JsonRpcVersion::Lightwalletd => { + self.jsonrpc = Some("1.0".into()); + self.result = self.result.or(Some(serde_json::Value::Null)); + self.error = self.error.or(Some(serde_json::Value::Null)); + } + JsonRpcVersion::TwoPointZero => { + // `jsonrpsee` should be returning valid JSON-RPC 2.0 responses. However, + // a valid result of `null` can be parsed into `None` by this parser, so + // we map the result explicitly to `Null` when there is no error. + assert_eq!(self.jsonrpc.as_deref(), Some("2.0")); + if self.error.is_none() { + self.result = self.result.or(Some(serde_json::Value::Null)); + } else { + assert!(self.result.is_none()); + } + } + JsonRpcVersion::Unknown => (), } - .boxed() + self } } diff --git a/zebra-state/src/service/non_finalized_state.rs b/zebra-state/src/service/non_finalized_state.rs index ebcbb2cfd35..91ae30ae23d 100644 --- a/zebra-state/src/service/non_finalized_state.rs +++ b/zebra-state/src/service/non_finalized_state.rs @@ -9,7 +9,7 @@ use std::{ }; use zebra_chain::{ - block::{self, Block}, + block::{self, Block, Hash}, parameters::Network, sprout, transparent, }; @@ -45,6 +45,10 @@ pub struct NonFinalizedState { /// callers should migrate to `chain_iter().next()`. chain_set: BTreeSet>, + /// Blocks that have been invalidated in, and removed from, the non finalized + /// state. + invalidated_blocks: HashMap>>, + // Configuration // /// The configured Zcash network. @@ -92,6 +96,7 @@ impl Clone for NonFinalizedState { Self { chain_set: self.chain_set.clone(), network: self.network.clone(), + invalidated_blocks: self.invalidated_blocks.clone(), #[cfg(feature = "getblocktemplate-rpcs")] should_count_metrics: self.should_count_metrics, @@ -112,6 +117,7 @@ impl NonFinalizedState { NonFinalizedState { chain_set: Default::default(), network: network.clone(), + invalidated_blocks: Default::default(), #[cfg(feature = "getblocktemplate-rpcs")] should_count_metrics: true, #[cfg(feature = "progress-bar")] @@ -264,6 +270,37 @@ impl NonFinalizedState { Ok(()) } + /// Invalidate block with hash `block_hash` and all descendants from the non-finalized state. Insert + /// the new chain into the chain_set and discard the previous. + pub fn invalidate_block(&mut self, block_hash: Hash) { + let Some(chain) = self.find_chain(|chain| chain.contains_block_hash(block_hash)) else { + return; + }; + + let invalidated_blocks = if chain.non_finalized_root_hash() == block_hash { + self.chain_set.remove(&chain); + chain.blocks.values().cloned().collect() + } else { + let (new_chain, invalidated_blocks) = chain + .invalidate_block(block_hash) + .expect("already checked that chain contains hash"); + + // Add the new chain fork or updated chain to the set of recent chains, and + // remove the chain containing the hash of the block from chain set + self.insert_with(Arc::new(new_chain.clone()), |chain_set| { + chain_set.retain(|c| !c.contains_block_hash(block_hash)) + }); + + invalidated_blocks + }; + + self.invalidated_blocks + .insert(block_hash, Arc::new(invalidated_blocks)); + + self.update_metrics_for_chains(); + self.update_metrics_bars(); + } + /// Commit block to the non-finalized state as a new chain where its parent /// is the finalized tip. #[tracing::instrument(level = "debug", skip(self, finalized_state, prepared))] @@ -586,6 +623,11 @@ impl NonFinalizedState { self.chain_set.len() } + /// Return the invalidated blocks. + pub fn invalidated_blocks(&self) -> HashMap>> { + self.invalidated_blocks.clone() + } + /// Return the chain whose tip block hash is `parent_hash`. /// /// The chain can be an existing chain in the non-finalized state, or a freshly diff --git a/zebra-state/src/service/non_finalized_state/chain.rs b/zebra-state/src/service/non_finalized_state/chain.rs index 6ad284a23f5..c7d0d2877c6 100644 --- a/zebra-state/src/service/non_finalized_state/chain.rs +++ b/zebra-state/src/service/non_finalized_state/chain.rs @@ -359,6 +359,26 @@ impl Chain { (block, treestate) } + // Returns the block at the provided height and all of its descendant blocks. + pub fn child_blocks(&self, block_height: &block::Height) -> Vec { + self.blocks + .range(block_height..) + .map(|(_h, b)| b.clone()) + .collect() + } + + // Returns a new chain without the invalidated block or its descendants. + pub fn invalidate_block( + &self, + block_hash: block::Hash, + ) -> Option<(Self, Vec)> { + let block_height = self.height_by_hash(block_hash)?; + let mut new_chain = self.fork(block_hash)?; + new_chain.pop_tip(); + new_chain.last_fork_height = None; + Some((new_chain, self.child_blocks(&block_height))) + } + /// Returns the height of the chain root. pub fn non_finalized_root_height(&self) -> block::Height { self.blocks @@ -1600,7 +1620,7 @@ impl DerefMut for Chain { /// The revert position being performed on a chain. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -enum RevertPosition { +pub(crate) enum RevertPosition { /// The chain root is being reverted via [`Chain::pop_root`], when a block /// is finalized. Root, @@ -1619,7 +1639,7 @@ enum RevertPosition { /// and [`Chain::pop_tip`] functions, and fear that it would be easy to /// introduce bugs when updating them, unless the code was reorganized to keep /// related operations adjacent to each other. -trait UpdateWith { +pub(crate) trait UpdateWith { /// When `T` is added to the chain tip, /// update [`Chain`] cumulative data members to add data that are derived from `T`. fn update_chain_tip_with(&mut self, _: &T) -> Result<(), ValidateContextError>; diff --git a/zebra-state/src/service/non_finalized_state/tests/vectors.rs b/zebra-state/src/service/non_finalized_state/tests/vectors.rs index b489d6f94f0..5b392e4a0b9 100644 --- a/zebra-state/src/service/non_finalized_state/tests/vectors.rs +++ b/zebra-state/src/service/non_finalized_state/tests/vectors.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use zebra_chain::{ amount::NonNegative, - block::{Block, Height}, + block::{self, Block, Height}, history_tree::NonEmptyHistoryTree, parameters::{Network, NetworkUpgrade}, serialization::ZcashDeserializeInto, @@ -216,6 +216,94 @@ fn finalize_pops_from_best_chain_for_network(network: Network) -> Result<()> { Ok(()) } +fn invalidate_block_removes_block_and_descendants_from_chain_for_network( + network: Network, +) -> Result<()> { + let block1: Arc = Arc::new(network.test_block(653599, 583999).unwrap()); + let block2 = block1.make_fake_child().set_work(10); + let block3 = block2.make_fake_child().set_work(1); + + let mut state = NonFinalizedState::new(&network); + let finalized_state = FinalizedState::new( + &Config::ephemeral(), + &network, + #[cfg(feature = "elasticsearch")] + false, + ); + + let fake_value_pool = ValueBalance::::fake_populated_pool(); + finalized_state.set_finalized_value_pool(fake_value_pool); + + state.commit_new_chain(block1.clone().prepare(), &finalized_state)?; + state.commit_block(block2.clone().prepare(), &finalized_state)?; + state.commit_block(block3.clone().prepare(), &finalized_state)?; + + assert_eq!( + state + .best_chain() + .unwrap_or(&Arc::new(Chain::default())) + .blocks + .len(), + 3 + ); + + state.invalidate_block(block2.hash()); + + let post_invalidated_chain = state.best_chain().unwrap(); + + assert_eq!(post_invalidated_chain.blocks.len(), 1); + assert!( + post_invalidated_chain.contains_block_hash(block1.hash()), + "the new modified chain should contain block1" + ); + + assert!( + !post_invalidated_chain.contains_block_hash(block2.hash()), + "the new modified chain should not contain block2" + ); + assert!( + !post_invalidated_chain.contains_block_hash(block3.hash()), + "the new modified chain should not contain block3" + ); + + let invalidated_blocks_state = &state.invalidated_blocks; + assert!( + invalidated_blocks_state.contains_key(&block2.hash()), + "invalidated blocks map should reference the hash of block2" + ); + + let invalidated_blocks_state_descendants = + invalidated_blocks_state.get(&block2.hash()).unwrap(); + + match network { + Network::Mainnet => assert!( + invalidated_blocks_state_descendants + .iter() + .any(|block| block.height == block::Height(653601)), + "invalidated descendants vec should contain block3" + ), + Network::Testnet(_parameters) => assert!( + invalidated_blocks_state_descendants + .iter() + .any(|block| block.height == block::Height(584001)), + "invalidated descendants vec should contain block3" + ), + } + + Ok(()) +} + +#[test] +fn invalidate_block_removes_block_and_descendants_from_chain() -> Result<()> { + let _init_guard = zebra_test::init(); + + for network in Network::iter() { + invalidate_block_removes_block_and_descendants_from_chain_for_network(network)?; + } + + Ok(()) +} + #[test] // This test gives full coverage for `take_chain_if` fn commit_block_extending_best_chain_doesnt_drop_worst_chains() -> Result<()> {