Skip to content

Commit

Permalink
feat: add global rate limit config (#58)
Browse files Browse the repository at this point in the history
* feat: add global rate limit config

* fmt

* add validation for global
  • Loading branch information
yjhmelody authored Jul 25, 2024
1 parent ba8cdf0 commit 2673993
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 16 deletions.
12 changes: 12 additions & 0 deletions configs/demo_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ extensions:
max_connections: 2000
max_batch_size: 10
cors: all
# these are for demo purpose only, please adjust to your needs.
# It's recommend to use big `period_secs` to reduce CPU costs.
rate_limit:
connection: # 10 RPC requests per 10 second per connection
burst: 10
period_secs: 10
ip: # 20 RPC requests per 10 seconds per ip
burst: 20
period_secs: 10
global: # 40 RPC requests per 10 seconds for global request
burst: 40
period_secs: 10
whitelist:
eth_call:
# allow 0x01 to create contract.
Expand Down
12 changes: 12 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,18 @@ pub async fn validate(config: &Config) -> Result<(), anyhow::Error> {
}
}
}

if let Some(ref rule) = rate_limit.global {
for method in &config.rpcs.methods {
if method.rate_limit_weight > rule.burst {
bail!(
"`{}` rate_limit_weight is too big for global: {}",
method.method,
method.rate_limit_weight,
);
}
}
}
}

// since endpoints connection test is async
Expand Down
139 changes: 139 additions & 0 deletions src/extensions/rate_limit/global.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use crate::extensions::rate_limit::MethodWeights;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::{num::NonZeroU32, sync::Arc};

#[derive(Clone)]
pub struct GlobalRateLimitLayer {
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl GlobalRateLimitLayer {
pub fn new(limiter: Arc<DefaultDirectRateLimiter>, jitter: Jitter, method_weights: MethodWeights) -> Self {
Self {
limiter,
jitter,
method_weights,
}
}
}

impl<S> tower::Layer<S> for GlobalRateLimitLayer {
type Service = GlobalRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
GlobalRateLimit::new(service, self.limiter.clone(), self.jitter, self.method_weights.clone())
}
}

#[derive(Clone)]
pub struct GlobalRateLimit<S> {
service: S,
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl<S> GlobalRateLimit<S> {
pub fn new(
service: S,
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Self {
Self {
service,
limiter,
jitter,
method_weights,
}
}
}

impl<'a, S> RpcServiceT<'a> for GlobalRateLimit<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;

fn call(&self, req: Request<'a>) -> Self::Future {
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();
let weight = self.method_weights.get(req.method_name());

async move {
if let Some(n) = NonZeroU32::new(weight) {
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
}
service.call(req).await
}
.boxed()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::extensions::rate_limit::build_quota;
use governor::RateLimiter;
use jsonrpsee::types::Id;
use jsonrpsee::ResponsePayload;
use std::time::Duration;

#[derive(Clone)]
struct MockService;
impl RpcServiceT<'static> for MockService {
type Future = BoxFuture<'static, MethodResponse>;

fn call(&self, req: Request<'static>) -> Self::Future {
async move { MethodResponse::response(req.id, ResponsePayload::success("ok"), 1024) }.boxed()
}
}

#[tokio::test]
async fn rate_limit_works() {
let burst = NonZeroU32::new(10).unwrap();
let period = Duration::from_millis(100);
let quota = build_quota(burst, period);
let limiter = Arc::new(RateLimiter::direct(quota));

let service = GlobalRateLimit::new(
MockService,
limiter,
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

let batch = |service: GlobalRateLimit<MockService>, count: usize, delay| async move {
tokio::time::sleep(Duration::from_millis(delay)).await;
let calls = (1..=count)
.map(|id| service.call(Request::new("test".into(), None, Id::Number(id as u64))))
.collect::<Vec<_>>();
let results = futures::future::join_all(calls).await;
assert_eq!(results.iter().filter(|r| r.is_success()).count(), count);
};

let start = tokio::time::Instant::now();
// background task to make calls
let batch1 = tokio::spawn(batch(service.clone(), 30, 0));
let batch2 = tokio::spawn(batch(service.clone(), 40, 200));
let batch3 = tokio::spawn(batch(service.clone(), 20, 300));
batch1.await.unwrap();
batch2.await.unwrap();
batch3.await.unwrap();
let duration = start.elapsed().as_millis();
println!("duration: {} ms", duration);
// should take between 800..900 millis. each 100ms period handles 10 calls
assert!(duration > 800 && duration < 900);
}
}
57 changes: 41 additions & 16 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use governor::{DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter};
use super::{Extension, ExtensionRegistry};
use governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter};
use serde::Deserialize;
use std::num::NonZeroU32;
use std::{sync::Arc, time::Duration};

use super::{Extension, ExtensionRegistry};

mod connection;
mod global;
mod ip;
mod weight;
mod xff;

use crate::extensions::rate_limit::global::GlobalRateLimitLayer;
pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer};
pub use ip::{IpRateLimit, IpRateLimitLayer};
pub use weight::MethodWeights;
Expand All @@ -19,6 +20,7 @@ pub use xff::XFF;
pub struct RateLimitConfig {
pub ip: Option<Rule>,
pub connection: Option<Rule>,
pub global: Option<Rule>,
#[serde(default)]
pub use_xff: bool,
}
Expand Down Expand Up @@ -47,8 +49,12 @@ fn default_jitter_up_to_millis() -> u64 {

pub struct RateLimitBuilder {
config: RateLimitConfig,

ip_jitter: Option<Jitter>,
ip_limiter: Option<Arc<DefaultKeyedRateLimiter<String>>>,

global_jitter: Option<Jitter>,
global_limiter: Option<Arc<DefaultDirectRateLimiter>>,
}

#[async_trait::async_trait]
Expand All @@ -72,22 +78,30 @@ impl RateLimitBuilder {
assert!(rule.period_secs > 0, "period_secs must be greater than 0");
}

let mut ip_limiter = None;
let mut ip_jitter = None;
if let Some(ref rule) = config.ip {
let burst = NonZeroU32::new(rule.burst).unwrap();
let quota = build_quota(burst, Duration::from_secs(rule.period_secs));
let ip_limiter = Some(Arc::new(RateLimiter::keyed(quota)));
let ip_jitter = Some(Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis)));
Self {
config,
ip_jitter,
ip_limiter,
}
} else {
Self {
config,
ip_jitter: None,
ip_limiter: None,
}
ip_limiter = Some(Arc::new(RateLimiter::keyed(quota)));
ip_jitter = Some(Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis)));
}

let mut global_limiter = None;
let mut global_jitter = None;
if let Some(ref rule) = config.global {
let burst = NonZeroU32::new(rule.burst).unwrap();
let quota = build_quota(burst, Duration::from_secs(rule.period_secs));
global_limiter = Some(Arc::new(DefaultDirectRateLimiter::direct(quota)));
global_jitter = Some(Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis)));
}

Self {
config,
ip_jitter,
ip_limiter,
global_jitter,
global_limiter,
}
}

Expand All @@ -101,6 +115,7 @@ impl RateLimitBuilder {
None
}
}

pub fn ip_limit(&self, remote_ip: String, method_weights: MethodWeights) -> Option<IpRateLimitLayer> {
self.ip_limiter.as_ref().map(|ip_limiter| {
IpRateLimitLayer::new(
Expand All @@ -112,6 +127,16 @@ impl RateLimitBuilder {
})
}

pub fn global_limit(&self, method_weights: MethodWeights) -> Option<GlobalRateLimitLayer> {
self.global_limiter.as_ref().map(|global_limiter| {
GlobalRateLimitLayer::new(
global_limiter.clone(),
self.global_jitter.unwrap_or_default(),
method_weights,
)
})
}

// whether to use the X-Forwarded-For header to get the remote ip
pub fn use_xff(&self) -> bool {
self.config.use_xff
Expand Down
5 changes: 5 additions & 0 deletions src/extensions/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ impl SubwayServerBuilder {
async move {
let rpc_middleware =
RpcServiceBuilder::new()
.option_layer(
rate_limit_builder
.as_ref()
.and_then(|r| r.global_limit(rpc_method_weights.clone())),
)
.option_layer(
rate_limit_builder
.as_ref()
Expand Down

0 comments on commit 2673993

Please sign in to comment.