-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ip rate limit * rename * update config * add benches * update config * rename * xff * update config * fmt * update
- Loading branch information
1 parent
43cbd17
commit e9ab79b
Showing
9 changed files
with
453 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use criterion::{criterion_group, Criterion}; | ||
use futures_util::future::BoxFuture; | ||
use futures_util::FutureExt; | ||
use governor::Jitter; | ||
use governor::RateLimiter; | ||
use jsonrpsee::server::middleware::rpc::RpcServiceT; | ||
use jsonrpsee::types::{Request, ResponsePayload}; | ||
use jsonrpsee::MethodResponse; | ||
use std::num::NonZeroU32; | ||
use std::time::Duration; | ||
use subway::extensions::rate_limit::{build_quota, ConnectionRateLimit, IpRateLimit}; | ||
|
||
#[derive(Clone)] | ||
struct MockService; | ||
impl RpcServiceT<'static> for MockService { | ||
type Future = BoxFuture<'static, MethodResponse>; | ||
|
||
fn call(&self, req: Request<'static>) -> Self::Future { | ||
async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed() | ||
} | ||
} | ||
|
||
pub fn connection_rate_limit(c: &mut Criterion) { | ||
let rate_limit = ConnectionRateLimit::new( | ||
MockService, | ||
NonZeroU32::new(1000).unwrap(), | ||
Duration::from_millis(1000), | ||
Jitter::up_to(Duration::from_millis(10)), | ||
); | ||
|
||
c.bench_function("rate_limit/connection_rate_limit", |b| { | ||
b.iter(|| rate_limit.call(Request::new("test".into(), None, jsonrpsee::types::Id::Number(1)))) | ||
}); | ||
} | ||
|
||
pub fn ip_rate_limit(c: &mut Criterion) { | ||
let burst = NonZeroU32::new(1000).unwrap(); | ||
let quota = build_quota(burst, Duration::from_millis(1000)); | ||
let limiter = RateLimiter::keyed(quota); | ||
let rate_limit = IpRateLimit::new( | ||
MockService, | ||
"::1".to_string(), | ||
std::sync::Arc::new(limiter), | ||
Jitter::up_to(Duration::from_millis(10)), | ||
); | ||
|
||
c.bench_function("rate_limit/ip_rate_limit", |b| { | ||
b.iter(|| rate_limit.call(Request::new("test".into(), None, jsonrpsee::types::Id::Number(1)))) | ||
}); | ||
} | ||
|
||
criterion_group!(rate_limit_benches, connection_rate_limit, ip_rate_limit); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
use futures::{future::BoxFuture, FutureExt}; | ||
use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter}; | ||
use jsonrpsee::{ | ||
server::{middleware::rpc::RpcServiceT, types::Request}, | ||
MethodResponse, | ||
}; | ||
use std::num::NonZeroU32; | ||
use std::{sync::Arc, time::Duration}; | ||
|
||
#[derive(Clone)] | ||
pub struct ConnectionRateLimitLayer { | ||
burst: NonZeroU32, | ||
period: Duration, | ||
jitter: Jitter, | ||
} | ||
|
||
impl ConnectionRateLimitLayer { | ||
pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { | ||
Self { burst, period, jitter } | ||
} | ||
} | ||
|
||
impl<S> tower::Layer<S> for ConnectionRateLimitLayer { | ||
type Service = ConnectionRateLimit<S>; | ||
|
||
fn layer(&self, service: S) -> Self::Service { | ||
ConnectionRateLimit::new(service, self.burst, self.period, self.jitter) | ||
} | ||
} | ||
|
||
#[derive(Clone)] | ||
pub struct ConnectionRateLimit<S> { | ||
service: S, | ||
limiter: Arc<DefaultDirectRateLimiter>, | ||
jitter: Jitter, | ||
} | ||
|
||
impl<S> ConnectionRateLimit<S> { | ||
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { | ||
let quota = super::build_quota(burst, period); | ||
let limiter = Arc::new(RateLimiter::direct(quota)); | ||
Self { | ||
service, | ||
limiter, | ||
jitter, | ||
} | ||
} | ||
} | ||
|
||
impl<'a, S> RpcServiceT<'a> for ConnectionRateLimit<S> | ||
where | ||
S: RpcServiceT<'a> + Send + Sync + Clone + 'static, | ||
{ | ||
type Future = BoxFuture<'a, MethodResponse>; | ||
|
||
fn call(&self, req: Request<'a>) -> Self::Future { | ||
let jitter = self.jitter; | ||
let service = self.service.clone(); | ||
let limiter = self.limiter.clone(); | ||
|
||
async move { | ||
limiter.until_ready_with_jitter(jitter).await; | ||
service.call(req).await | ||
} | ||
.boxed() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use jsonrpsee::types::{Id, ResponsePayload}; | ||
|
||
#[derive(Clone)] | ||
struct MockService; | ||
impl RpcServiceT<'static> for MockService { | ||
type Future = BoxFuture<'static, MethodResponse>; | ||
|
||
fn call(&self, req: Request<'static>) -> Self::Future { | ||
async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed() | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn rate_limit_works() { | ||
let service = ConnectionRateLimit::new( | ||
MockService, | ||
NonZeroU32::new(10).unwrap(), | ||
Duration::from_millis(100), | ||
Jitter::up_to(Duration::from_millis(10)), | ||
); | ||
|
||
let batch = |service: ConnectionRateLimit<MockService>, count: usize, delay| async move { | ||
tokio::time::sleep(Duration::from_millis(delay)).await; | ||
let calls = (1..=count) | ||
.map(|id| service.call(Request::new("test".into(), None, Id::Number(id as u64)))) | ||
.collect::<Vec<_>>(); | ||
let results = futures::future::join_all(calls).await; | ||
assert_eq!(results.iter().filter(|r| r.is_success()).count(), count); | ||
}; | ||
|
||
let start = tokio::time::Instant::now(); | ||
// background task to make calls | ||
let batch1 = tokio::spawn(batch(service.clone(), 30, 0)); | ||
let batch2 = tokio::spawn(batch(service.clone(), 40, 200)); | ||
let batch3 = tokio::spawn(batch(service.clone(), 20, 300)); | ||
batch1.await.unwrap(); | ||
batch2.await.unwrap(); | ||
batch3.await.unwrap(); | ||
let duration = start.elapsed().as_millis(); | ||
println!("duration: {} ms", duration); | ||
// should take between 800..900 millis. each 100ms period handles 10 calls | ||
assert!(duration > 800 && duration < 900); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
use futures::{future::BoxFuture, FutureExt}; | ||
use governor::{DefaultKeyedRateLimiter, Jitter}; | ||
use jsonrpsee::{ | ||
server::{middleware::rpc::RpcServiceT, types::Request}, | ||
MethodResponse, | ||
}; | ||
use std::sync::Arc; | ||
|
||
#[derive(Clone)] | ||
pub struct IpRateLimitLayer { | ||
ip_addr: String, | ||
limiter: Arc<DefaultKeyedRateLimiter<String>>, | ||
jitter: Jitter, | ||
} | ||
|
||
impl IpRateLimitLayer { | ||
pub fn new(ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self { | ||
Self { | ||
ip_addr, | ||
limiter, | ||
jitter, | ||
} | ||
} | ||
} | ||
|
||
impl<S> tower::Layer<S> for IpRateLimitLayer { | ||
type Service = IpRateLimit<S>; | ||
|
||
fn layer(&self, service: S) -> Self::Service { | ||
IpRateLimit::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter) | ||
} | ||
} | ||
|
||
#[derive(Clone)] | ||
pub struct IpRateLimit<S> { | ||
service: S, | ||
ip_addr: String, | ||
limiter: Arc<DefaultKeyedRateLimiter<String>>, | ||
jitter: Jitter, | ||
} | ||
|
||
impl<S> IpRateLimit<S> { | ||
pub fn new(service: S, ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self { | ||
Self { | ||
service, | ||
ip_addr, | ||
limiter, | ||
jitter, | ||
} | ||
} | ||
} | ||
|
||
impl<'a, S> RpcServiceT<'a> for IpRateLimit<S> | ||
where | ||
S: RpcServiceT<'a> + Send + Sync + Clone + 'static, | ||
{ | ||
type Future = BoxFuture<'a, MethodResponse>; | ||
|
||
fn call(&self, req: Request<'a>) -> Self::Future { | ||
let ip_addr = self.ip_addr.clone(); | ||
let jitter = self.jitter; | ||
let service = self.service.clone(); | ||
let limiter = self.limiter.clone(); | ||
|
||
async move { | ||
limiter.until_key_ready_with_jitter(&ip_addr, jitter).await; | ||
service.call(req).await | ||
} | ||
.boxed() | ||
} | ||
} |
Oops, something went wrong.