From d630c264d630c1d1bff23a877537a057e0d5fdbc Mon Sep 17 00:00:00 2001 From: Ermal Kaleci Date: Thu, 11 Apr 2024 12:14:02 +0200 Subject: [PATCH] rename --- README.md | 2 +- src/extensions/mod.rs | 4 +- src/extensions/validate/mod.rs | 33 -------------- src/extensions/validator/mod.rs | 67 +++++++++++++++++++++++++++++ src/middlewares/methods/validate.rs | 49 +++------------------ 5 files changed, 77 insertions(+), 78 deletions(-) delete mode 100644 src/extensions/validate/mod.rs create mode 100644 src/extensions/validator/mod.rs diff --git a/README.md b/README.md index 803bace..bbeaffc 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ NOTE: Keep in mind that if you place `validate` middleware before `inject_params Ignored methods can be defined in extension config: ```yml extensions: - validate: + validator: ignore_methods: - system_health - system_name diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index ebc94b1..4959199 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -16,7 +16,7 @@ pub mod merge_subscription; pub mod rate_limit; pub mod server; pub mod telemetry; -pub mod validate; +pub mod validator; #[async_trait] pub trait Extension: Sized { @@ -139,5 +139,5 @@ define_all_extensions! { server: server::SubwayServerBuilder, event_bus: event_bus::EventBus, rate_limit: rate_limit::RateLimitBuilder, - validate: validate::Validate, + validator: validator::Validator, } diff --git a/src/extensions/validate/mod.rs b/src/extensions/validate/mod.rs deleted file mode 100644 index 9b31229..0000000 --- a/src/extensions/validate/mod.rs +++ /dev/null @@ -1,33 +0,0 @@ -use async_trait::async_trait; -use serde::Deserialize; - -use super::{Extension, ExtensionRegistry}; - -#[derive(Default)] -pub struct Validate { - pub config: ValidateConfig, -} - -#[derive(Deserialize, Default, Debug, Clone)] -pub struct ValidateConfig { - pub ignore_methods: Vec, -} - -#[async_trait] -impl Extension for Validate { - type Config = ValidateConfig; - - async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { - Ok(Self::new(config.clone())) - } -} - -impl Validate { - pub fn new(config: ValidateConfig) -> Self { - Self { config } - } - - pub fn ignore(&self, method: &String) -> bool { - self.config.ignore_methods.contains(method) - } -} diff --git a/src/extensions/validator/mod.rs b/src/extensions/validator/mod.rs new file mode 100644 index 0000000..00ee5e8 --- /dev/null +++ b/src/extensions/validator/mod.rs @@ -0,0 +1,67 @@ +use crate::extensions::client::Client; +use crate::middlewares::{CallRequest, CallResult}; +use crate::utils::errors; +use async_trait::async_trait; +use serde::Deserialize; +use std::sync::Arc; + +use super::{Extension, ExtensionRegistry}; + +#[derive(Default)] +pub struct Validator { + pub config: ValidateConfig, +} + +#[derive(Deserialize, Default, Debug, Clone)] +pub struct ValidateConfig { + pub ignore_methods: Vec, +} + +#[async_trait] +impl Extension for Validator { + type Config = ValidateConfig; + + async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { + Ok(Self::new(config.clone())) + } +} + +impl Validator { + pub fn new(config: ValidateConfig) -> Self { + Self { config } + } + + pub fn ignore(&self, method: &String) -> bool { + self.config.ignore_methods.contains(method) + } + + pub fn validate(&self, client: Arc, request: CallRequest, response: CallResult) { + tokio::spawn(async move { + let healthy_endpoints = client.endpoints().iter().filter(|x| x.health().score() > 0); + futures::future::join_all(healthy_endpoints.map(|endpoint| async { + let expected = endpoint + .request( + &request.method, + request.params.clone(), + std::time::Duration::from_secs(30), + ) + .await + .map_err(errors::map_error); + + if response != expected { + let request = serde_json::to_string_pretty(&request).unwrap_or_default(); + let actual = match &response { + Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), + Err(e) => e.to_string() + }; + let expected = match &expected { + Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), + Err(e) => e.to_string() + }; + let endpoint_url = endpoint.url(); + tracing::error!("Response mismatch for request:\n{request}\nSubway response:\n{actual}\nEndpoint {endpoint_url} response:\n{expected}"); + } + })).await; + }); + } +} diff --git a/src/middlewares/methods/validate.rs b/src/middlewares/methods/validate.rs index bd0e6be..24e1971 100644 --- a/src/middlewares/methods/validate.rs +++ b/src/middlewares/methods/validate.rs @@ -1,21 +1,20 @@ use async_trait::async_trait; use std::sync::Arc; -use crate::utils::errors; use crate::{ - extensions::{client::Client, validate::Validate}, + extensions::{client::Client, validator::Validator}, middlewares::{CallRequest, CallResult, Middleware, MiddlewareBuilder, NextFn, RpcMethod}, utils::{TypeRegistry, TypeRegistryRef}, }; pub struct ValidateMiddleware { - validate: Arc, + validator: Arc, client: Arc, } impl ValidateMiddleware { - pub fn new(validate: Arc, client: Arc) -> Self { - Self { validate, client } + pub fn new(validator: Arc, client: Arc) -> Self { + Self { validator, client } } } @@ -25,7 +24,7 @@ impl MiddlewareBuilder for ValidateMiddlewar _method: &RpcMethod, extensions: &TypeRegistryRef, ) -> Option>> { - let validate = extensions.read().await.get::().unwrap_or_default(); + let validate = extensions.read().await.get::().unwrap_or_default(); let client = extensions .read() @@ -44,44 +43,10 @@ impl Middleware for ValidateMiddleware { context: TypeRegistry, next: NextFn, ) -> CallResult { - let client = self.client.clone(); let result = next(request.clone(), context).await; - let actual = result.clone(); - - if self.validate.ignore(&request.method) { - return result; + if !self.validator.ignore(&request.method) { + self.validator.validate(self.client.clone(), request, result.clone()); } - - if let Err(err) = tokio::spawn(async move { - let healthy_endpoints = client.endpoints().iter().filter(|x| x.health().score() > 0); - futures::future::join_all(healthy_endpoints.map(|endpoint| async { - let expected = endpoint - .request( - &request.method, - request.params.clone(), - std::time::Duration::from_secs(30), - ) - .await - .map_err(errors::map_error); - - if actual != expected { - let request = serde_json::to_string_pretty(&request).unwrap_or_default(); - let actual = match &actual { - Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), - Err(e) => e.to_string() - }; - let expected = match &expected { - Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), - Err(e) => e.to_string() - }; - let endpoint_url = endpoint.url(); - tracing::error!("Response mismatch for request:\n{request}\nSubway response:\n{actual}\nEndpoint {endpoint_url} response:\n{expected}"); - } - })).await; - }).await { - tracing::error!("Validate task failed: {err:?}"); - } - result } }