diff --git a/README.md b/README.md index 3de9e1e..bbeaffc 100644 --- a/README.md +++ b/README.md @@ -75,4 +75,27 @@ It's also possible to run individual benchmarks by: ``` cargo bench --bench bench ws_round_trip -``` \ No newline at end of file +``` + +## Validate Middleware + +This middleware will intercept all method request/responses and compare the result directly with healthy endpoint responses. +This is useful for debugging to make sure the returned values are as expected. +You can enable validate middleware on your config file. +```yml +middlewares: + methods: + - validate +``` +NOTE: Keep in mind that if you place `validate` middleware before `inject_params` you may get false positive errors because the request will not be the same. + +Ignored methods can be defined in extension config: +```yml +extensions: + validator: + ignore_methods: + - system_health + - system_name + - system_version + - author_pendingExtrinsics +``` diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 0e20ffb..9b7a5e6 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -31,6 +31,7 @@ mod tests; const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client"); pub struct Client { + endpoints: Vec>, sender: tokio::sync::mpsc::Sender, rotation_notify: Arc, retries: u32, @@ -187,6 +188,7 @@ impl Client { let rotation_notify = Arc::new(Notify::new()); let rotation_notify_bg = rotation_notify.clone(); + let endpoints_ = endpoints.clone(); let background_task = tokio::spawn(async move { let request_backoff_counter = Arc::new(AtomicU32::new(0)); @@ -395,6 +397,7 @@ impl Client { }); Ok(Self { + endpoints: endpoints_, sender: message_tx, rotation_notify, retries: retries.unwrap_or(3), @@ -406,6 +409,10 @@ impl Client { Self::new(endpoints, None, None, None, None) } + pub fn endpoints(&self) -> &Vec> { + self.endpoints.as_ref() + } + pub async fn request(&self, method: &str, params: Vec) -> CallResult { async move { let (tx, rx) = tokio::sync::oneshot::channel(); diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 2cf1880..4959199 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -16,6 +16,7 @@ pub mod merge_subscription; pub mod rate_limit; pub mod server; pub mod telemetry; +pub mod validator; #[async_trait] pub trait Extension: Sized { @@ -138,4 +139,5 @@ define_all_extensions! { server: server::SubwayServerBuilder, event_bus: event_bus::EventBus, rate_limit: rate_limit::RateLimitBuilder, + validator: validator::Validator, } diff --git a/src/extensions/validator/mod.rs b/src/extensions/validator/mod.rs new file mode 100644 index 0000000..00ee5e8 --- /dev/null +++ b/src/extensions/validator/mod.rs @@ -0,0 +1,67 @@ +use crate::extensions::client::Client; +use crate::middlewares::{CallRequest, CallResult}; +use crate::utils::errors; +use async_trait::async_trait; +use serde::Deserialize; +use std::sync::Arc; + +use super::{Extension, ExtensionRegistry}; + +#[derive(Default)] +pub struct Validator { + pub config: ValidateConfig, +} + +#[derive(Deserialize, Default, Debug, Clone)] +pub struct ValidateConfig { + pub ignore_methods: Vec, +} + +#[async_trait] +impl Extension for Validator { + type Config = ValidateConfig; + + async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { + Ok(Self::new(config.clone())) + } +} + +impl Validator { + pub fn new(config: ValidateConfig) -> Self { + Self { config } + } + + pub fn ignore(&self, method: &String) -> bool { + self.config.ignore_methods.contains(method) + } + + pub fn validate(&self, client: Arc, request: CallRequest, response: CallResult) { + tokio::spawn(async move { + let healthy_endpoints = client.endpoints().iter().filter(|x| x.health().score() > 0); + futures::future::join_all(healthy_endpoints.map(|endpoint| async { + let expected = endpoint + .request( + &request.method, + request.params.clone(), + std::time::Duration::from_secs(30), + ) + .await + .map_err(errors::map_error); + + if response != expected { + let request = serde_json::to_string_pretty(&request).unwrap_or_default(); + let actual = match &response { + Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), + Err(e) => e.to_string() + }; + let expected = match &expected { + Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), + Err(e) => e.to_string() + }; + let endpoint_url = endpoint.url(); + tracing::error!("Response mismatch for request:\n{request}\nSubway response:\n{actual}\nEndpoint {endpoint_url} response:\n{expected}"); + } + })).await; + }); + } +} diff --git a/src/middlewares/factory.rs b/src/middlewares/factory.rs index d92da85..2ac12bc 100644 --- a/src/middlewares/factory.rs +++ b/src/middlewares/factory.rs @@ -30,6 +30,7 @@ pub async fn create_method_middleware( "block_tag" => block_tag::BlockTagMiddleware::build(method, extensions).await, "inject_params" => inject_params::InjectParamsMiddleware::build(method, extensions).await, "delay" => delay::DelayMiddleware::build(method, extensions).await, + "validate" => validate::ValidateMiddleware::build(method, extensions).await, #[cfg(test)] "crazy" => testing::CrazyMiddleware::build(method, extensions).await, _ => panic!("Unknown method middleware: {}", name), diff --git a/src/middlewares/methods/mod.rs b/src/middlewares/methods/mod.rs index 20be5e1..420c290 100644 --- a/src/middlewares/methods/mod.rs +++ b/src/middlewares/methods/mod.rs @@ -4,6 +4,7 @@ pub mod delay; pub mod inject_params; pub mod response; pub mod upstream; +pub mod validate; #[cfg(test)] pub mod testing; diff --git a/src/middlewares/methods/validate.rs b/src/middlewares/methods/validate.rs new file mode 100644 index 0000000..24e1971 --- /dev/null +++ b/src/middlewares/methods/validate.rs @@ -0,0 +1,52 @@ +use async_trait::async_trait; +use std::sync::Arc; + +use crate::{ + extensions::{client::Client, validator::Validator}, + middlewares::{CallRequest, CallResult, Middleware, MiddlewareBuilder, NextFn, RpcMethod}, + utils::{TypeRegistry, TypeRegistryRef}, +}; + +pub struct ValidateMiddleware { + validator: Arc, + client: Arc, +} + +impl ValidateMiddleware { + pub fn new(validator: Arc, client: Arc) -> Self { + Self { validator, client } + } +} + +#[async_trait] +impl MiddlewareBuilder for ValidateMiddleware { + async fn build( + _method: &RpcMethod, + extensions: &TypeRegistryRef, + ) -> Option>> { + let validate = extensions.read().await.get::().unwrap_or_default(); + + let client = extensions + .read() + .await + .get::() + .expect("Client extension not found"); + Some(Box::new(ValidateMiddleware::new(validate, client))) + } +} + +#[async_trait] +impl Middleware for ValidateMiddleware { + async fn call( + &self, + request: CallRequest, + context: TypeRegistry, + next: NextFn, + ) -> CallResult { + let result = next(request.clone(), context).await; + if !self.validator.ignore(&request.method) { + self.validator.validate(self.client.clone(), request, result.clone()); + } + result + } +} diff --git a/src/middlewares/mod.rs b/src/middlewares/mod.rs index 03871b9..27656b8 100644 --- a/src/middlewares/mod.rs +++ b/src/middlewares/mod.rs @@ -6,6 +6,7 @@ use jsonrpsee::{ PendingSubscriptionSink, }; use opentelemetry::trace::FutureExt as _; +use serde::Serialize; use std::{ fmt::{Debug, Formatter}, sync::Arc, @@ -20,7 +21,7 @@ pub mod factory; pub mod methods; pub mod subscriptions; -#[derive(Debug)] +#[derive(Clone, Debug, Serialize)] /// Represents a RPC request made to a middleware function. pub struct CallRequest { pub method: String,