Skip to content

Commit

Permalink
ip rate limit (#146)
Browse files Browse the repository at this point in the history
* ip rate limit

* rename

* update config

* add benches

* update config

* rename

* xff

* update config

* fmt

* update
  • Loading branch information
ermalkaleci authored Dec 5, 2023
1 parent 43cbd17 commit e9ab79b
Show file tree
Hide file tree
Showing 9 changed files with 453 additions and 170 deletions.
4 changes: 4 additions & 0 deletions benches/bench/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use pprof::criterion::{Output, PProfProfiler};
use std::{sync::Arc, time::Duration};
use tokio::runtime::Runtime as TokioRuntime;

mod rate_limit;

use helpers::{
client::{rpc_params, ws_client, ws_handshake, ClientT, HeaderMap, SubscriptionClientT},
ASYNC_INJECT_CALL, KIB, SUB_METHOD_NAME, UNSUB_METHOD_NAME,
Expand Down Expand Up @@ -70,6 +72,7 @@ criterion_group!(
config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = AsyncBencher::websocket_benches_inject
);

criterion_main!(
sync_benches,
sync_benches_mid,
Expand All @@ -79,6 +82,7 @@ criterion_main!(
async_benches_slow,
subscriptions,
async_benches_inject,
rate_limit::rate_limit_benches,
);

const SERVER_ONE_ENDPOINT: &str = "127.0.0.1:9955";
Expand Down
52 changes: 52 additions & 0 deletions benches/bench/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use criterion::{criterion_group, Criterion};
use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use governor::Jitter;
use governor::RateLimiter;
use jsonrpsee::server::middleware::rpc::RpcServiceT;
use jsonrpsee::types::{Request, ResponsePayload};
use jsonrpsee::MethodResponse;
use std::num::NonZeroU32;
use std::time::Duration;
use subway::extensions::rate_limit::{build_quota, ConnectionRateLimit, IpRateLimit};

#[derive(Clone)]
struct MockService;
impl RpcServiceT<'static> for MockService {
type Future = BoxFuture<'static, MethodResponse>;

fn call(&self, req: Request<'static>) -> Self::Future {
async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed()
}
}

pub fn connection_rate_limit(c: &mut Criterion) {
let rate_limit = ConnectionRateLimit::new(
MockService,
NonZeroU32::new(1000).unwrap(),
Duration::from_millis(1000),
Jitter::up_to(Duration::from_millis(10)),
);

c.bench_function("rate_limit/connection_rate_limit", |b| {
b.iter(|| rate_limit.call(Request::new("test".into(), None, jsonrpsee::types::Id::Number(1))))
});
}

pub fn ip_rate_limit(c: &mut Criterion) {
let burst = NonZeroU32::new(1000).unwrap();
let quota = build_quota(burst, Duration::from_millis(1000));
let limiter = RateLimiter::keyed(quota);
let rate_limit = IpRateLimit::new(
MockService,
"::1".to_string(),
std::sync::Arc::new(limiter),
Jitter::up_to(Duration::from_millis(10)),
);

c.bench_function("rate_limit/ip_rate_limit", |b| {
b.iter(|| rate_limit.call(Request::new("test".into(), None, jsonrpsee::types::Id::Number(1))))
});
}

criterion_group!(rate_limit_benches, connection_rate_limit, ip_rate_limit);
13 changes: 10 additions & 3 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ extensions:
- path: /liveness
method: chain_getBlockHash
cors: all
rate_limit: # 20 requests per second per connection
burst: 20
period_secs: 1
rate_limit: # these are for demo purpose only, please adjust to your needs
connection: # 20 RPC requests per second per connection
burst: 20
period_secs: 1
ip: # 500 RPC requests per 10 seconds per ip
burst: 500
period_secs: 10
# use X-Forwarded-For header to get real ip, if available (e.g. behind a load balancer).
# WARNING: Use with caution, as this xff header can be forged.
use_xff: true # default is false

middlewares:
methods:
Expand Down
115 changes: 115 additions & 0 deletions src/extensions/rate_limit/connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::num::NonZeroU32;
use std::{sync::Arc, time::Duration};

#[derive(Clone)]
pub struct ConnectionRateLimitLayer {
burst: NonZeroU32,
period: Duration,
jitter: Jitter,
}

impl ConnectionRateLimitLayer {
pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self {
Self { burst, period, jitter }
}
}

impl<S> tower::Layer<S> for ConnectionRateLimitLayer {
type Service = ConnectionRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
ConnectionRateLimit::new(service, self.burst, self.period, self.jitter)
}
}

#[derive(Clone)]
pub struct ConnectionRateLimit<S> {
service: S,
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
}

impl<S> ConnectionRateLimit<S> {
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self {
let quota = super::build_quota(burst, period);
let limiter = Arc::new(RateLimiter::direct(quota));
Self {
service,
limiter,
jitter,
}
}
}

impl<'a, S> RpcServiceT<'a> for ConnectionRateLimit<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;

fn call(&self, req: Request<'a>) -> Self::Future {
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();

async move {
limiter.until_ready_with_jitter(jitter).await;
service.call(req).await
}
.boxed()
}
}

#[cfg(test)]
mod tests {
use super::*;
use jsonrpsee::types::{Id, ResponsePayload};

#[derive(Clone)]
struct MockService;
impl RpcServiceT<'static> for MockService {
type Future = BoxFuture<'static, MethodResponse>;

fn call(&self, req: Request<'static>) -> Self::Future {
async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed()
}
}

#[tokio::test]
async fn rate_limit_works() {
let service = ConnectionRateLimit::new(
MockService,
NonZeroU32::new(10).unwrap(),
Duration::from_millis(100),
Jitter::up_to(Duration::from_millis(10)),
);

let batch = |service: ConnectionRateLimit<MockService>, count: usize, delay| async move {
tokio::time::sleep(Duration::from_millis(delay)).await;
let calls = (1..=count)
.map(|id| service.call(Request::new("test".into(), None, Id::Number(id as u64))))
.collect::<Vec<_>>();
let results = futures::future::join_all(calls).await;
assert_eq!(results.iter().filter(|r| r.is_success()).count(), count);
};

let start = tokio::time::Instant::now();
// background task to make calls
let batch1 = tokio::spawn(batch(service.clone(), 30, 0));
let batch2 = tokio::spawn(batch(service.clone(), 40, 200));
let batch3 = tokio::spawn(batch(service.clone(), 20, 300));
batch1.await.unwrap();
batch2.await.unwrap();
batch3.await.unwrap();
let duration = start.elapsed().as_millis();
println!("duration: {} ms", duration);
// should take between 800..900 millis. each 100ms period handles 10 calls
assert!(duration > 800 && duration < 900);
}
}
71 changes: 71 additions & 0 deletions src/extensions/rate_limit/ip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultKeyedRateLimiter, Jitter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::sync::Arc;

#[derive(Clone)]
pub struct IpRateLimitLayer {
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
}

impl IpRateLimitLayer {
pub fn new(ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
Self {
ip_addr,
limiter,
jitter,
}
}
}

impl<S> tower::Layer<S> for IpRateLimitLayer {
type Service = IpRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
IpRateLimit::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter)
}
}

#[derive(Clone)]
pub struct IpRateLimit<S> {
service: S,
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
}

impl<S> IpRateLimit<S> {
pub fn new(service: S, ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
Self {
service,
ip_addr,
limiter,
jitter,
}
}
}

impl<'a, S> RpcServiceT<'a> for IpRateLimit<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;

fn call(&self, req: Request<'a>) -> Self::Future {
let ip_addr = self.ip_addr.clone();
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();

async move {
limiter.until_key_ready_with_jitter(&ip_addr, jitter).await;
service.call(req).await
}
.boxed()
}
}
Loading

0 comments on commit e9ab79b

Please sign in to comment.