From bdb347c92af5f3dabaf33db1fc0ace425ae23375 Mon Sep 17 00:00:00 2001 From: yjh Date: Thu, 17 Oct 2024 10:23:14 +0800 Subject: [PATCH] feat: support rate limit config for every rpc (#72) --- configs/demo_config.yml | 7 + src/extensions/rate_limit/global.rs | 2 +- src/extensions/rate_limit/method.rs | 228 ++++++++++++++++++++++++++++ src/extensions/rate_limit/mod.rs | 69 +++++++-- src/extensions/rate_limit/weight.rs | 3 +- src/extensions/server/mod.rs | 5 + 6 files changed, 302 insertions(+), 12 deletions(-) create mode 100644 src/extensions/rate_limit/method.rs diff --git a/configs/demo_config.yml b/configs/demo_config.yml index 0bf34f3..f20cdfb 100644 --- a/configs/demo_config.yml +++ b/configs/demo_config.yml @@ -35,6 +35,13 @@ extensions: # these are for demo purpose only, please adjust to your needs. # It's recommend to use big `period_secs` to reduce CPU costs. # rate_limit: +# methods: +# bursts: +# # Alloc 300 burst to this method +# eth_sendRawTransaction: 300 +# eth_call: 300 +# period_secs: 10 +# blocking: false # false by default # connection: # 1000 weight per 10 second per connection # burst: 1000 # period_secs: 10 diff --git a/src/extensions/rate_limit/global.rs b/src/extensions/rate_limit/global.rs index 1047a49..2fc3ea6 100644 --- a/src/extensions/rate_limit/global.rs +++ b/src/extensions/rate_limit/global.rs @@ -159,6 +159,6 @@ mod tests { let duration = start.elapsed().as_millis(); println!("duration: {} ms", duration); // should take between 800..900 millis. each 100ms period handles 10 calls - assert!(duration > 800 && duration < 900); + assert!(duration > 800 && duration < 1000); } } diff --git a/src/extensions/rate_limit/method.rs b/src/extensions/rate_limit/method.rs new file mode 100644 index 0000000..33e3027 --- /dev/null +++ b/src/extensions/rate_limit/method.rs @@ -0,0 +1,228 @@ +use crate::extensions::rate_limit::MethodWeights; +use crate::utils::errors; +use futures::{future::BoxFuture, FutureExt}; +use governor::{DefaultDirectRateLimiter, Jitter}; +use jsonrpsee::{ + server::{middleware::rpc::RpcServiceT, types::Request}, + MethodResponse, +}; +use std::collections::HashMap; +use std::{num::NonZeroU32, sync::Arc}; + +#[derive(Clone)] +pub struct MethodRateLimitLayer { + limiters: HashMap>, + jitter: Jitter, + method_weights: MethodWeights, + blocking: bool, +} + +impl MethodRateLimitLayer { + pub fn new( + limiters: HashMap>, + jitter: Jitter, + method_weights: MethodWeights, + ) -> Option { + if limiters.is_empty() { + return None; + } + Some(Self { + limiters, + jitter, + method_weights, + blocking: false, + }) + } + + pub fn blocking(mut self, blocking: bool) -> Self { + self.blocking = blocking; + self + } +} + +impl tower::Layer for MethodRateLimitLayer { + type Service = MethodRateLimit; + + fn layer(&self, service: S) -> Self::Service { + MethodRateLimit::new(service, self.limiters.clone(), self.jitter, self.method_weights.clone()) + .blocking(self.blocking) + } +} + +#[derive(Clone)] +pub struct MethodRateLimit { + service: S, + limiters: HashMap>, + jitter: Jitter, + method_weights: MethodWeights, + blocking: bool, +} + +impl MethodRateLimit { + pub fn new( + service: S, + limiters: HashMap>, + jitter: Jitter, + method_weights: MethodWeights, + ) -> Self { + Self { + service, + limiters, + jitter, + method_weights, + blocking: false, + } + } + + pub fn blocking(mut self, blocking: bool) -> Self { + self.blocking = blocking; + self + } +} + +impl<'a, S> RpcServiceT<'a> for MethodRateLimit +where + S: RpcServiceT<'a> + Send + Sync + Clone + 'static, +{ + type Future = BoxFuture<'a, MethodResponse>; + + fn call(&self, req: Request<'a>) -> Self::Future { + let method_name = req.method_name(); + // no config for this method so no limit for it. + let Some(limiter) = self.limiters.get(method_name).cloned() else { + let service = self.service.clone(); + return async move { service.call(req).await }.boxed(); + }; + let jitter = self.jitter; + let service = self.service.clone(); + let weight = self.method_weights.get(method_name); + let blocking = self.blocking; + + async move { + if let Some(n) = NonZeroU32::new(weight) { + if blocking { + limiter + .until_n_ready_with_jitter(n, jitter) + .await + .expect("check_n have been done during init"); + } else { + match limiter.check_n(n).expect("check_n have been done during init") { + Ok(_) => {} + Err(_) => return MethodResponse::error(req.id, errors::reached_rate_limit()), + } + } + } + service.call(req).await + } + .boxed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{RpcMethod, RpcSubscription, RpcSubscriptionMethod}; + use crate::extensions::rate_limit::build_quota; + use governor::RateLimiter; + use jsonrpsee::types::Id; + use jsonrpsee::ResponsePayload; + use std::time::Duration; + + #[derive(Clone)] + struct MockService; + impl RpcServiceT<'static> for MockService { + type Future = BoxFuture<'static, MethodResponse>; + + fn call(&self, req: Request<'static>) -> Self::Future { + async move { MethodResponse::response(req.id, ResponsePayload::success("ok"), 1024) }.boxed() + } + } + + #[tokio::test] + async fn method_rate_limit_works() { + let burst = NonZeroU32::new(10).unwrap(); + let period = Duration::from_millis(100); + let quota = build_quota(burst, period); + let mut limiters = HashMap::new(); + limiters.insert("test".to_string(), Arc::new(RateLimiter::direct(quota))); + limiters.insert("subscribe".to_string(), Arc::new(RateLimiter::direct(quota))); + let weights = MethodWeights::from_config( + &[RpcMethod { + method: "test".to_string(), + rate_limit_weight: 10, + ..Default::default() + }], + &[RpcSubscription { + subscribe: RpcSubscriptionMethod { + method: "subscribe".to_string(), + rate_limit_weight: 10, + }, + unsubscribe: RpcSubscriptionMethod { + method: "unsubscribe".to_string(), + rate_limit_weight: 10, + }, + name: "subscribe".to_string(), + merge_strategy: None, + }], + ); + + let service = MethodRateLimit::new(MockService, limiters, Jitter::up_to(Duration::from_millis(1)), weights) + .blocking(true); + + let batch = |method_name: String, service: MethodRateLimit, count: usize, delay| async move { + tokio::time::sleep(Duration::from_millis(delay)).await; + let calls = (1..=count) + .map(|id| service.call(Request::new(method_name.clone().into(), None, Id::Number(id as u64)))) + .collect::>(); + let results = futures::future::join_all(calls).await; + assert_eq!(results.iter().filter(|r| r.is_success()).count(), count); + }; + + // limited by 10 weights + { + let start = tokio::time::Instant::now(); + tokio::spawn(batch("test".to_string(), service.clone(), 1, 0)) + .await + .unwrap(); + let duration = start.elapsed().as_millis(); + assert!(duration < 100, "actual duration: {}", duration); + + let start = tokio::time::Instant::now(); + tokio::spawn(batch("test".to_string(), service.clone(), 2, 0)) + .await + .unwrap(); + let duration = start.elapsed().as_millis(); + assert!(duration > 100, "actual duration: {}", duration); + + let start = tokio::time::Instant::now(); + tokio::spawn(batch("subscribe".to_string(), service.clone(), 1, 0)) + .await + .unwrap(); + let duration = start.elapsed().as_millis(); + assert!(duration < 100, "actual duration: {}", duration); + + let start = tokio::time::Instant::now(); + tokio::spawn(batch("subscribe".to_string(), service.clone(), 2, 0)) + .await + .unwrap(); + let duration = start.elapsed().as_millis(); + assert!(duration > 100, "actual duration: {}", duration); + } + + // limited by default 1 weight. + { + let start = tokio::time::Instant::now(); + tokio::spawn(batch("foo".to_string(), service.clone(), 5, 0)) + .await + .unwrap(); + let duration = start.elapsed().as_millis(); + assert!(duration < 100, "actual duration: {}", duration); + + tokio::spawn(batch("foo".to_string(), service.clone(), 6, 0)) + .await + .unwrap(); + let duration = start.elapsed().as_millis(); + assert!(duration < 100, "actual duration: {}", duration); + } + } +} diff --git a/src/extensions/rate_limit/mod.rs b/src/extensions/rate_limit/mod.rs index a68cfb9..25f1793 100644 --- a/src/extensions/rate_limit/mod.rs +++ b/src/extensions/rate_limit/mod.rs @@ -1,16 +1,19 @@ use super::{Extension, ExtensionRegistry}; use governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter}; use serde::Deserialize; +use std::collections::HashMap; use std::num::NonZeroU32; use std::{sync::Arc, time::Duration}; mod connection; mod global; mod ip; +mod method; mod weight; mod xff; use crate::extensions::rate_limit::global::GlobalRateLimitLayer; +use crate::extensions::rate_limit::method::MethodRateLimitLayer; pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer}; pub use ip::{IpRateLimit, IpRateLimitLayer}; pub use weight::MethodWeights; @@ -19,12 +22,33 @@ pub use xff::XFF; #[derive(Deserialize, Debug, Clone, Default)] pub struct RateLimitConfig { pub ip: Option, + pub methods: Option, pub connection: Option, pub global: Option, #[serde(default)] pub use_xff: bool, } +/// The rule for every method. +/// +/// Only work for method rate limit. +#[derive(Deserialize, Debug, Clone, Default)] +pub struct MethodsRule { + /// The burst for every rpc. + pub bursts: HashMap, + /// Return the responses with delay instead of returning a rate limit jsonrpc error directly if true. + #[serde(default)] + pub blocking: bool, + /// period is the period of time in which the burst is allowed + #[serde(default = "default_period_secs")] + pub period_secs: u64, + // jitter_millis is the maximum amount of jitter to add to the rate limit + // this is to prevent a thundering herd problem https://en.wikipedia.org/wiki/Thundering_herd_problem + // e.g. if jitter_up_to_millis is 1000, then additional delay of random(0, 1000) milliseconds will be added + #[serde(default = "default_jitter_up_to_millis")] + pub jitter_up_to_millis: u64, +} + #[derive(Deserialize, Debug, Clone, Default)] pub struct Rule { /// burst is the maximum number of requests that can be made in a period @@ -60,6 +84,10 @@ pub struct RateLimitBuilder { global_jitter: Option, global_limiter: Option>, global_blocking: bool, + + method_jitter: Option, + method_blocking: bool, + method_limiters: HashMap>, } #[async_trait::async_trait] @@ -105,6 +133,20 @@ impl RateLimitBuilder { global_blocking = rule.blocking; } + let mut method_jitter = None; + let mut method_blocking = false; + let mut method_limiters = HashMap::new(); + if let Some(ref methods_rule) = config.methods { + for (method, burst) in methods_rule.bursts.iter() { + let burst = NonZeroU32::new(*burst).unwrap(); + let quota = build_quota(burst, Duration::from_secs(methods_rule.period_secs)); + let method_limiter = Arc::new(DefaultDirectRateLimiter::direct(quota)); + method_limiters.insert(method.clone(), method_limiter); + } + method_jitter = Some(Jitter::up_to(Duration::from_millis(methods_rule.jitter_up_to_millis))); + method_blocking = methods_rule.blocking; + } + Self { config, @@ -115,6 +157,10 @@ impl RateLimitBuilder { global_jitter, global_limiter, global_blocking, + + method_jitter, + method_limiters, + method_blocking, } } @@ -130,10 +176,10 @@ impl RateLimitBuilder { } pub fn ip_limit(&self, remote_ip: String, method_weights: MethodWeights) -> Option { - self.ip_limiter.as_ref().map(|ip_limiter| { + self.ip_limiter.as_ref().map(|limiter| { IpRateLimitLayer::new( remote_ip, - ip_limiter.clone(), + limiter.clone(), self.ip_jitter.unwrap_or_default(), method_weights, ) @@ -141,14 +187,19 @@ impl RateLimitBuilder { }) } + pub fn method_limits(&self, method_weights: MethodWeights) -> Option { + MethodRateLimitLayer::new( + self.method_limiters.clone(), + self.method_jitter.unwrap_or_default(), + method_weights, + ) + .map(|layer| layer.blocking(self.method_blocking)) + } + pub fn global_limit(&self, method_weights: MethodWeights) -> Option { - self.global_limiter.as_ref().map(|global_limiter| { - GlobalRateLimitLayer::new( - global_limiter.clone(), - self.global_jitter.unwrap_or_default(), - method_weights, - ) - .blocking(self.global_blocking) + self.global_limiter.as_ref().map(|limiter| { + GlobalRateLimitLayer::new(limiter.clone(), self.global_jitter.unwrap_or_default(), method_weights) + .blocking(self.global_blocking) }) } diff --git a/src/extensions/rate_limit/weight.rs b/src/extensions/rate_limit/weight.rs index 89540ac..1c9a17d 100644 --- a/src/extensions/rate_limit/weight.rs +++ b/src/extensions/rate_limit/weight.rs @@ -1,6 +1,7 @@ use crate::config::{RpcMethod, RpcSubscription}; use std::{collections::BTreeMap, sync::Arc}; +/// The weights for every rpc request. #[derive(Clone, Debug, Default)] pub struct MethodWeights(Arc>); @@ -8,9 +9,7 @@ impl MethodWeights { pub fn get(&self, method: &str) -> u32 { self.0.get(method).cloned().unwrap_or(1) } -} -impl MethodWeights { pub fn from_config(methods: &[RpcMethod], subscriptions: &[RpcSubscription]) -> Self { let mut weights = BTreeMap::default(); for method in methods { diff --git a/src/extensions/server/mod.rs b/src/extensions/server/mod.rs index 821bdcf..a0cc4cd 100644 --- a/src/extensions/server/mod.rs +++ b/src/extensions/server/mod.rs @@ -261,6 +261,11 @@ impl SubwayServerBuilder { .as_ref() .and_then(|r| r.ip_limit(socket_ip, rpc_method_weights.clone())), ) + .option_layer( + rate_limit_builder + .as_ref() + .and_then(|r| r.method_limits(rpc_method_weights.clone())), + ) .option_layer( rate_limit_builder .as_ref()