diff --git a/Cargo.lock b/Cargo.lock index 75ab028..d942438 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1172,8 +1172,6 @@ dependencies = [ [[package]] name = "governor" version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4" dependencies = [ "cfg-if 1.0.0", "dashmap", diff --git a/Cargo.toml b/Cargo.toml index 026ffa9..46ddcaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ tracing-serde = "0.1.3" tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] } jsonrpsee = { path = "./vendor/jsonrpsee/jsonrpsee", features = ["full"] } -governor = "0.6.0" +governor = { path = "./vendor/governor/governor" } [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio", "html_reports"] } diff --git a/benches/bench/main.rs b/benches/bench/main.rs index 887246d..ea66c99 100644 --- a/benches/bench/main.rs +++ b/benches/bench/main.rs @@ -243,6 +243,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: helpers::ASYNC_FAST_CALL.to_string(), @@ -250,6 +251,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: helpers::SYNC_MEM_CALL.to_string(), @@ -257,6 +259,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: helpers::ASYNC_MEM_CALL.to_string(), @@ -264,6 +267,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: helpers::SYNC_SLOW_CALL.to_string(), @@ -271,6 +275,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: helpers::ASYNC_SLOW_CALL.to_string(), @@ -278,6 +283,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: helpers::ASYNC_INJECT_CALL.to_string(), @@ -298,6 +304,7 @@ fn config() -> Config { response: None, cache: None, delay_ms: None, + rate_limit_weight: 1, }, ], subscriptions: vec![RpcSubscription { diff --git a/benches/bench/rate_limit.rs b/benches/bench/rate_limit.rs index b17bda3..e89678c 100644 --- a/benches/bench/rate_limit.rs +++ b/benches/bench/rate_limit.rs @@ -26,6 +26,7 @@ pub fn connection_rate_limit(c: &mut Criterion) { NonZeroU32::new(1000).unwrap(), Duration::from_millis(1000), Jitter::up_to(Duration::from_millis(10)), + Default::default(), ); c.bench_function("rate_limit/connection_rate_limit", |b| { @@ -42,6 +43,7 @@ pub fn ip_rate_limit(c: &mut Criterion) { "::1".to_string(), std::sync::Arc::new(limiter), Jitter::up_to(Duration::from_millis(10)), + Default::default(), ); c.bench_function("rate_limit/ip_rate_limit", |b| { diff --git a/src/config/rpc.rs b/src/config/rpc.rs index 1c381fe..8598a90 100644 --- a/src/config/rpc.rs +++ b/src/config/rpc.rs @@ -35,6 +35,21 @@ pub struct RpcMethod { #[serde(default)] pub delay_ms: Option, + + /// This should not exceed max cell capacity. If it does, + /// method will return error. Burst size is the max cell capacity. + /// If rate limit is not configured, this will be ignored. + /// e.g. if rate limit is configured as 10r per 2s and rate_limit_weight is 10, + /// then only 1 call is allowed per 2s. If rate_limit_weight is 5, then 2 calls + /// are allowed per 2s. If rate_limit_weight is greater than 10, then method will + /// return error "rate limit exceeded". + /// Add this if you want to modify the default value of 1. + #[serde(default = "default_rate_limit_weight")] + pub rate_limit_weight: u32, +} + +fn default_rate_limit_weight() -> u32 { + 1 } #[derive(Copy, Clone, Deserialize, Debug)] diff --git a/src/extensions/rate_limit/connection.rs b/src/extensions/rate_limit/connection.rs index 06c155d..0bd74d4 100644 --- a/src/extensions/rate_limit/connection.rs +++ b/src/extensions/rate_limit/connection.rs @@ -1,22 +1,28 @@ +use crate::{extensions::rate_limit::MethodWeights, utils::errors}; use futures::{future::BoxFuture, FutureExt}; use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter}; use jsonrpsee::{ server::{middleware::rpc::RpcServiceT, types::Request}, MethodResponse, }; -use std::num::NonZeroU32; -use std::{sync::Arc, time::Duration}; +use std::{num::NonZeroU32, sync::Arc, time::Duration}; #[derive(Clone)] pub struct ConnectionRateLimitLayer { burst: NonZeroU32, period: Duration, jitter: Jitter, + method_weights: MethodWeights, } impl ConnectionRateLimitLayer { - pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { - Self { burst, period, jitter } + pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self { + Self { + burst, + period, + jitter, + method_weights, + } } } @@ -24,7 +30,13 @@ impl tower::Layer for ConnectionRateLimitLayer { type Service = ConnectionRateLimit; fn layer(&self, service: S) -> Self::Service { - ConnectionRateLimit::new(service, self.burst, self.period, self.jitter) + ConnectionRateLimit::new( + service, + self.burst, + self.period, + self.jitter, + self.method_weights.clone(), + ) } } @@ -33,16 +45,18 @@ pub struct ConnectionRateLimit { service: S, limiter: Arc, jitter: Jitter, + method_weights: MethodWeights, } impl ConnectionRateLimit { - pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { + pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self { let quota = super::build_quota(burst, period); let limiter = Arc::new(RateLimiter::direct(quota)); Self { service, limiter, jitter, + method_weights, } } } @@ -57,9 +71,14 @@ where let jitter = self.jitter; let service = self.service.clone(); let limiter = self.limiter.clone(); + let weight = self.method_weights.get(req.method_name()); async move { - limiter.until_ready_with_jitter(jitter).await; + if let Some(n) = NonZeroU32::new(weight) { + if limiter.until_n_ready_with_jitter(n, jitter).await.is_err() { + return MethodResponse::error(req.id, errors::failed("rate limit exceeded")); + } + } service.call(req).await } .boxed() @@ -88,6 +107,7 @@ mod tests { NonZeroU32::new(10).unwrap(), Duration::from_millis(100), Jitter::up_to(Duration::from_millis(10)), + Default::default(), ); let batch = |service: ConnectionRateLimit, count: usize, delay| async move { diff --git a/src/extensions/rate_limit/ip.rs b/src/extensions/rate_limit/ip.rs index 365739a..4274e02 100644 --- a/src/extensions/rate_limit/ip.rs +++ b/src/extensions/rate_limit/ip.rs @@ -1,24 +1,32 @@ +use crate::{extensions::rate_limit::MethodWeights, utils::errors}; use futures::{future::BoxFuture, FutureExt}; use governor::{DefaultKeyedRateLimiter, Jitter}; use jsonrpsee::{ server::{middleware::rpc::RpcServiceT, types::Request}, MethodResponse, }; -use std::sync::Arc; +use std::{num::NonZeroU32, sync::Arc}; #[derive(Clone)] pub struct IpRateLimitLayer { ip_addr: String, limiter: Arc>, jitter: Jitter, + method_weights: MethodWeights, } impl IpRateLimitLayer { - pub fn new(ip_addr: String, limiter: Arc>, jitter: Jitter) -> Self { + pub fn new( + ip_addr: String, + limiter: Arc>, + jitter: Jitter, + method_weights: MethodWeights, + ) -> Self { Self { ip_addr, limiter, jitter, + method_weights, } } } @@ -27,7 +35,13 @@ impl tower::Layer for IpRateLimitLayer { type Service = IpRateLimit; fn layer(&self, service: S) -> Self::Service { - IpRateLimit::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter) + IpRateLimit::new( + service, + self.ip_addr.clone(), + self.limiter.clone(), + self.jitter, + self.method_weights.clone(), + ) } } @@ -37,15 +51,23 @@ pub struct IpRateLimit { ip_addr: String, limiter: Arc>, jitter: Jitter, + method_weights: MethodWeights, } impl IpRateLimit { - pub fn new(service: S, ip_addr: String, limiter: Arc>, jitter: Jitter) -> Self { + pub fn new( + service: S, + ip_addr: String, + limiter: Arc>, + jitter: Jitter, + method_weights: MethodWeights, + ) -> Self { Self { service, ip_addr, limiter, jitter, + method_weights, } } } @@ -61,9 +83,17 @@ where let jitter = self.jitter; let service = self.service.clone(); let limiter = self.limiter.clone(); - + let weight = self.method_weights.get(req.method_name()); async move { - limiter.until_key_ready_with_jitter(&ip_addr, jitter).await; + if let Some(n) = NonZeroU32::new(weight) { + if limiter + .until_key_n_ready_with_jitter(&ip_addr, n, jitter) + .await + .is_err() + { + return MethodResponse::error(req.id, errors::failed("rate limit exceeded")); + } + } service.call(req).await } .boxed() diff --git a/src/extensions/rate_limit/mod.rs b/src/extensions/rate_limit/mod.rs index 3d7f7bf..ba987b9 100644 --- a/src/extensions/rate_limit/mod.rs +++ b/src/extensions/rate_limit/mod.rs @@ -7,10 +7,12 @@ use super::{Extension, ExtensionRegistry}; mod connection; mod ip; +mod weight; mod xff; pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer}; pub use ip::{IpRateLimit, IpRateLimitLayer}; +pub use weight::MethodWeights; pub use xff::XFF; #[derive(Deserialize, Debug, Clone, Default)] @@ -88,20 +90,25 @@ impl RateLimitBuilder { } } } - pub fn connection_limit(&self) -> Option { + pub fn connection_limit(&self, method_weights: MethodWeights) -> Option { if let Some(ref rule) = self.config.connection { let burst = NonZeroU32::new(rule.burst).unwrap(); let period = Duration::from_secs(rule.period_secs); let jitter = Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis)); - Some(ConnectionRateLimitLayer::new(burst, period, jitter)) + Some(ConnectionRateLimitLayer::new(burst, period, jitter, method_weights)) } else { None } } - pub fn ip_limit(&self, remote_ip: String) -> Option { - self.ip_limiter - .as_ref() - .map(|ip_limiter| IpRateLimitLayer::new(remote_ip, ip_limiter.clone(), self.ip_jitter.unwrap_or_default())) + pub fn ip_limit(&self, remote_ip: String, method_weights: MethodWeights) -> Option { + self.ip_limiter.as_ref().map(|ip_limiter| { + IpRateLimitLayer::new( + remote_ip, + ip_limiter.clone(), + self.ip_jitter.unwrap_or_default(), + method_weights, + ) + }) } // whether to use the X-Forwarded-For header to get the remote ip diff --git a/src/extensions/rate_limit/weight.rs b/src/extensions/rate_limit/weight.rs new file mode 100644 index 0000000..bf7d33c --- /dev/null +++ b/src/extensions/rate_limit/weight.rs @@ -0,0 +1,25 @@ +use crate::config::RpcMethod; +use std::collections::BTreeMap; + +#[derive(Clone, Debug, Default)] +pub struct MethodWeights(BTreeMap); + +impl MethodWeights { + pub fn add(&mut self, method: &str, weight: u32) { + self.0.insert(method.to_owned(), weight); + } + + pub fn get(&self, method: &str) -> u32 { + self.0.get(method).cloned().unwrap_or(1) + } +} + +impl MethodWeights { + pub fn from_config(methods: &[RpcMethod]) -> Self { + let mut weights = MethodWeights::default(); + for method in methods { + weights.add(&method.method, method.rate_limit_weight); + } + weights + } +} diff --git a/src/extensions/server/mod.rs b/src/extensions/server/mod.rs index 823e04a..f052f0a 100644 --- a/src/extensions/server/mod.rs +++ b/src/extensions/server/mod.rs @@ -16,7 +16,7 @@ use tower::ServiceBuilder; use tower_http::cors::{AllowOrigin, CorsLayer}; use super::{Extension, ExtensionRegistry}; -use crate::extensions::rate_limit::{RateLimitBuilder, XFF}; +use crate::extensions::rate_limit::{MethodWeights, RateLimitBuilder, XFF}; mod proxy_get_request; use proxy_get_request::{ProxyGetRequestLayer, ProxyGetRequestMethod}; @@ -97,6 +97,7 @@ impl SubwayServerBuilder { pub async fn build>>>( &self, rate_limit_builder: Option>, + rpc_method_weights: MethodWeights, rpc_module_builder: impl FnOnce() -> Fut, ) -> anyhow::Result<(SocketAddr, ServerHandle)> { let config = self.config.clone(); @@ -128,6 +129,7 @@ impl SubwayServerBuilder { let rpc_module = rpc_module.clone(); let stop_handle = stop_handle.clone(); let rate_limit_builder = rate_limit_builder.clone(); + let rpc_method_weights = rpc_method_weights.clone(); async move { // service_fn handle each request @@ -142,8 +144,16 @@ impl SubwayServerBuilder { } let rpc_middleware = RpcServiceBuilder::new() - .option_layer(rate_limit_builder.as_ref().and_then(|r| r.ip_limit(socket_ip))) - .option_layer(rate_limit_builder.as_ref().and_then(|r| r.connection_limit())); + .option_layer( + rate_limit_builder + .as_ref() + .and_then(|r| r.ip_limit(socket_ip, rpc_method_weights.clone())), + ) + .option_layer( + rate_limit_builder + .as_ref() + .and_then(|r| r.connection_limit(rpc_method_weights.clone())), + ); let service_builder = ServerBuilder::default() .set_rpc_middleware(rpc_middleware) diff --git a/src/middlewares/methods/cache.rs b/src/middlewares/methods/cache.rs index 9f9704c..4a3eba1 100644 --- a/src/middlewares/methods/cache.rs +++ b/src/middlewares/methods/cache.rs @@ -348,6 +348,7 @@ mod tests { params: vec![], response: None, delay_ms: None, + rate_limit_weight: 1, }, &ext, ) @@ -365,6 +366,7 @@ mod tests { params: vec![], response: None, delay_ms: None, + rate_limit_weight: 1, }, &ext, ) @@ -382,6 +384,7 @@ mod tests { params: vec![], response: None, delay_ms: None, + rate_limit_weight: 1, }, &ext, ) @@ -396,6 +399,7 @@ mod tests { params: vec![], response: None, delay_ms: None, + rate_limit_weight: 1, }, &ext, ) diff --git a/src/server.rs b/src/server.rs index 4745d70..0994de3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,12 +9,14 @@ use jsonrpsee::{ use opentelemetry::trace::FutureExt as _; use serde_json::json; -use crate::utils::TypeRegistryRef; use crate::{ config::Config, - extensions::{rate_limit::RateLimitBuilder, server::SubwayServerBuilder}, + extensions::{ + rate_limit::{MethodWeights, RateLimitBuilder}, + server::SubwayServerBuilder, + }, middlewares::{factory, CallRequest, Middlewares, SubscriptionRequest}, - utils::{errors, telemetry}, + utils::{errors, telemetry, TypeRegistryRef}, }; // TODO: https://github.com/paritytech/jsonrpsee/issues/985 @@ -45,11 +47,13 @@ pub async fn build(config: Config) -> anyhow::Result { let rate_limit_builder = extensions_registry.read().await.get::(); + let rpc_method_weights = MethodWeights::from_config(&config.rpcs.methods); + let request_timeout_seconds = server_builder.config.request_timeout_seconds; let registry = extensions_registry.clone(); let (addr, handle) = server_builder - .build(rate_limit_builder, move || async move { + .build(rate_limit_builder, rpc_method_weights, move || async move { let mut module = RpcModule::new(()); let tracer = telemetry::Tracer::new("server"); @@ -258,6 +262,7 @@ mod tests { cache: None, response: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: TIMEOUT.to_string(), @@ -265,6 +270,7 @@ mod tests { cache: None, response: None, delay_ms: None, + rate_limit_weight: 1, }, RpcMethod { method: CRAZY.to_string(), @@ -272,6 +278,7 @@ mod tests { cache: None, response: None, delay_ms: None, + rate_limit_weight: 1, }, ], subscriptions: vec![], diff --git a/vendor/governor b/vendor/governor index 921b986..870dfb0 160000 --- a/vendor/governor +++ b/vendor/governor @@ -1 +1 @@ -Subproject commit 921b986371f1eff03b6236dacd659b02738f3a66 +Subproject commit 870dfb073892b5637f988aaa25a322fdc3d3ef59