Skip to content

Commit

Permalink
feat: support rate limit config for every rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
yjhmelody committed Oct 16, 2024
1 parent 2cd84fc commit 1d0f9f5
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 12 deletions.
7 changes: 7 additions & 0 deletions configs/demo_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ extensions:
# these are for demo purpose only, please adjust to your needs.
# It's recommend to use big `period_secs` to reduce CPU costs.
# rate_limit:
# methods:
# bursts:
# # Alloc 300 burst to this method
# eth_sendRawTransaction: 300
# eth_call: 300
# period_secs: 10
# blocking: false # false by default
# connection: # 1000 weight per 10 second per connection
# burst: 1000
# period_secs: 10
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/rate_limit/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,6 @@ mod tests {
let duration = start.elapsed().as_millis();
println!("duration: {} ms", duration);
// should take between 800..900 millis. each 100ms period handles 10 calls
assert!(duration > 800 && duration < 900);
assert!(duration > 800 && duration < 1000);
}
}
228 changes: 228 additions & 0 deletions src/extensions/rate_limit/method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
use crate::extensions::rate_limit::MethodWeights;
use crate::utils::errors;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::collections::HashMap;
use std::{num::NonZeroU32, sync::Arc};

#[derive(Clone)]
pub struct MethodRateLimitLayer {
limiters: HashMap<String, Arc<DefaultDirectRateLimiter>>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl MethodRateLimitLayer {
pub fn new(
limiters: HashMap<String, Arc<DefaultDirectRateLimiter>>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Option<Self> {
if limiters.is_empty() {
return None;
}
Some(Self {
limiters,
jitter,
method_weights,
blocking: false,
})
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<S> tower::Layer<S> for MethodRateLimitLayer {
type Service = MethodRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
MethodRateLimit::new(service, self.limiters.clone(), self.jitter, self.method_weights.clone())
.blocking(self.blocking)
}
}

#[derive(Clone)]
pub struct MethodRateLimit<S> {
service: S,
limiters: HashMap<String, Arc<DefaultDirectRateLimiter>>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl<S> MethodRateLimit<S> {
pub fn new(
service: S,
limiters: HashMap<String, Arc<DefaultDirectRateLimiter>>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Self {
Self {
service,
limiters,
jitter,
method_weights,
blocking: false,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<'a, S> RpcServiceT<'a> for MethodRateLimit<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;

fn call(&self, req: Request<'a>) -> Self::Future {
let method_name = req.method_name();
// no config for this method so no limit for it.
let Some(limiter) = self.limiters.get(method_name).cloned() else {
let service = self.service.clone();
return async move { service.call(req).await }.boxed();
};
let jitter = self.jitter;
let service = self.service.clone();
let weight = self.method_weights.get(method_name);
let blocking = self.blocking;

async move {
if let Some(n) = NonZeroU32::new(weight) {
if blocking {
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
} else {
match limiter.check_n(n).expect("check_n have been done during init") {
Ok(_) => {}
Err(_) => return MethodResponse::error(req.id, errors::reached_rate_limit()),
}
}
}
service.call(req).await
}
.boxed()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::config::{RpcMethod, RpcSubscription, RpcSubscriptionMethod};
use crate::extensions::rate_limit::build_quota;
use governor::RateLimiter;
use jsonrpsee::types::Id;
use jsonrpsee::ResponsePayload;
use std::time::Duration;

#[derive(Clone)]
struct MockService;
impl RpcServiceT<'static> for MockService {
type Future = BoxFuture<'static, MethodResponse>;

fn call(&self, req: Request<'static>) -> Self::Future {
async move { MethodResponse::response(req.id, ResponsePayload::success("ok"), 1024) }.boxed()
}
}

#[tokio::test]
async fn method_rate_limit_works() {
let burst = NonZeroU32::new(10).unwrap();
let period = Duration::from_millis(100);
let quota = build_quota(burst, period);
let mut limiters = HashMap::new();
limiters.insert("test".to_string(), Arc::new(RateLimiter::direct(quota)));
limiters.insert("subscribe".to_string(), Arc::new(RateLimiter::direct(quota)));
let weights = MethodWeights::from_config(
&[RpcMethod {
method: "test".to_string(),
rate_limit_weight: 10,
..Default::default()
}],
&[RpcSubscription {
subscribe: RpcSubscriptionMethod {
method: "subscribe".to_string(),
rate_limit_weight: 10,
},
unsubscribe: RpcSubscriptionMethod {
method: "unsubscribe".to_string(),
rate_limit_weight: 10,
},
name: "subscribe".to_string(),
merge_strategy: None,
}],
);

let service = MethodRateLimit::new(MockService, limiters, Jitter::up_to(Duration::from_millis(1)), weights)
.blocking(true);

let batch = |method_name: String, service: MethodRateLimit<MockService>, count: usize, delay| async move {
tokio::time::sleep(Duration::from_millis(delay)).await;
let calls = (1..=count)
.map(|id| service.call(Request::new(method_name.clone().into(), None, Id::Number(id as u64))))
.collect::<Vec<_>>();
let results = futures::future::join_all(calls).await;
assert_eq!(results.iter().filter(|r| r.is_success()).count(), count);
};

// limited by 10 weights
{
let start = tokio::time::Instant::now();
tokio::spawn(batch("test".to_string(), service.clone(), 1, 0))
.await
.unwrap();
let duration = start.elapsed().as_millis();
assert!(duration < 100, "actual duration: {}", duration);

let start = tokio::time::Instant::now();
tokio::spawn(batch("test".to_string(), service.clone(), 2, 0))
.await
.unwrap();
let duration = start.elapsed().as_millis();
assert!(duration > 100, "actual duration: {}", duration);

let start = tokio::time::Instant::now();
tokio::spawn(batch("subscribe".to_string(), service.clone(), 1, 0))
.await
.unwrap();
let duration = start.elapsed().as_millis();
assert!(duration < 100, "actual duration: {}", duration);

let start = tokio::time::Instant::now();
tokio::spawn(batch("subscribe".to_string(), service.clone(), 2, 0))
.await
.unwrap();
let duration = start.elapsed().as_millis();
assert!(duration > 100, "actual duration: {}", duration);
}

// limited by default 1 weight.
{
let start = tokio::time::Instant::now();
tokio::spawn(batch("foo".to_string(), service.clone(), 5, 0))
.await
.unwrap();
let duration = start.elapsed().as_millis();
assert!(duration < 100, "actual duration: {}", duration);

tokio::spawn(batch("foo".to_string(), service.clone(), 6, 0))
.await
.unwrap();
let duration = start.elapsed().as_millis();
assert!(duration < 100, "actual duration: {}", duration);
}
}
}
69 changes: 60 additions & 9 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use super::{Extension, ExtensionRegistry};
use governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter};
use serde::Deserialize;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::{sync::Arc, time::Duration};

mod connection;
mod global;
mod ip;
mod method;
mod weight;
mod xff;

use crate::extensions::rate_limit::global::GlobalRateLimitLayer;
use crate::extensions::rate_limit::method::MethodRateLimitLayer;
pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer};
pub use ip::{IpRateLimit, IpRateLimitLayer};
pub use weight::MethodWeights;
Expand All @@ -19,12 +22,33 @@ pub use xff::XFF;
#[derive(Deserialize, Debug, Clone, Default)]
pub struct RateLimitConfig {
pub ip: Option<Rule>,
pub methods: Option<MethodsRule>,
pub connection: Option<Rule>,
pub global: Option<Rule>,
#[serde(default)]
pub use_xff: bool,
}

/// The rule for every method.
///
/// Only work for method rate limit.
#[derive(Deserialize, Debug, Clone, Default)]
pub struct MethodsRule {
/// The burst for every rpc.
pub bursts: HashMap<String, u32>,
/// Return the responses with delay instead of returning a rate limit jsonrpc error directly if true.
#[serde(default)]
pub blocking: bool,
/// period is the period of time in which the burst is allowed
#[serde(default = "default_period_secs")]
pub period_secs: u64,
// jitter_millis is the maximum amount of jitter to add to the rate limit
// this is to prevent a thundering herd problem https://en.wikipedia.org/wiki/Thundering_herd_problem
// e.g. if jitter_up_to_millis is 1000, then additional delay of random(0, 1000) milliseconds will be added
#[serde(default = "default_jitter_up_to_millis")]
pub jitter_up_to_millis: u64,
}

#[derive(Deserialize, Debug, Clone, Default)]
pub struct Rule {
/// burst is the maximum number of requests that can be made in a period
Expand Down Expand Up @@ -60,6 +84,10 @@ pub struct RateLimitBuilder {
global_jitter: Option<Jitter>,
global_limiter: Option<Arc<DefaultDirectRateLimiter>>,
global_blocking: bool,

method_jitter: Option<Jitter>,
method_blocking: bool,
method_limiters: HashMap<String, Arc<DefaultDirectRateLimiter>>,
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -105,6 +133,20 @@ impl RateLimitBuilder {
global_blocking = rule.blocking;
}

let mut method_jitter = None;
let mut method_blocking = false;
let mut method_limiters = HashMap::new();
if let Some(ref methods_rule) = config.methods {
for (method, burst) in methods_rule.bursts.iter() {
let burst = NonZeroU32::new(*burst).unwrap();
let quota = build_quota(burst, Duration::from_secs(methods_rule.period_secs));
let method_limiter = Arc::new(DefaultDirectRateLimiter::direct(quota));
method_limiters.insert(method.clone(), method_limiter);
}
method_jitter = Some(Jitter::up_to(Duration::from_millis(methods_rule.jitter_up_to_millis)));
method_blocking = methods_rule.blocking;
}

Self {
config,

Expand All @@ -115,6 +157,10 @@ impl RateLimitBuilder {
global_jitter,
global_limiter,
global_blocking,

method_jitter,
method_limiters,
method_blocking,
}
}

Expand All @@ -130,25 +176,30 @@ impl RateLimitBuilder {
}

pub fn ip_limit(&self, remote_ip: String, method_weights: MethodWeights) -> Option<IpRateLimitLayer> {
self.ip_limiter.as_ref().map(|ip_limiter| {
self.ip_limiter.as_ref().map(|limiter| {
IpRateLimitLayer::new(
remote_ip,
ip_limiter.clone(),
limiter.clone(),
self.ip_jitter.unwrap_or_default(),
method_weights,
)
.blocking(self.ip_blocking)
})
}

pub fn method_limits(&self, method_weights: MethodWeights) -> Option<MethodRateLimitLayer> {
MethodRateLimitLayer::new(
self.method_limiters.clone(),
self.method_jitter.unwrap_or_default(),
method_weights,
)
.map(|layer| layer.blocking(self.method_blocking))
}

pub fn global_limit(&self, method_weights: MethodWeights) -> Option<GlobalRateLimitLayer> {
self.global_limiter.as_ref().map(|global_limiter| {
GlobalRateLimitLayer::new(
global_limiter.clone(),
self.global_jitter.unwrap_or_default(),
method_weights,
)
.blocking(self.global_blocking)
self.global_limiter.as_ref().map(|limiter| {
GlobalRateLimitLayer::new(limiter.clone(), self.global_jitter.unwrap_or_default(), method_weights)
.blocking(self.global_blocking)
})
}

Expand Down
Loading

0 comments on commit 1d0f9f5

Please sign in to comment.