Skip to content

Commit

Permalink
feat: support non-blocking mode for rate limit (#60)
Browse files Browse the repository at this point in the history
* feat: support no-blocking mode for rate limit

* remove duplicated error data

* fix

* no_blocking => non_blocking

* improve by suggestions

* use blocking option instead of non_blocking

* fix default blocking option

---------

Co-authored-by: koushiro <[email protected]>
  • Loading branch information
yjhmelody and koushiro authored Jul 25, 2024
1 parent d762758 commit 4093c6c
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 27 deletions.
25 changes: 15 additions & 10 deletions configs/demo_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@ extensions:
cors: all
# these are for demo purpose only, please adjust to your needs.
# It's recommend to use big `period_secs` to reduce CPU costs.
#rate_limit:
#connection: # 10 RPC requests per 10 second per connection
# burst: 10
# period_secs: 10
#ip: # 20 RPC requests per 10 seconds per ip
# burst: 20
# period_secs: 10
#global: # 40 RPC requests per 10 seconds for global request
# burst: 40
# period_secs: 10
# rate_limit:
# connection: # 10 RPC requests per 10 second per connection
# burst: 10
# period_secs: 10
# # default: true
# blocking: false # false by default
# ip: # 20 RPC requests per 10 seconds per ip
# burst: 2
# period_secs: 10
# blocking: false # false by default
# global: # 40 RPC requests per 10 seconds for global request
# burst: 1
# period_secs: 10
# blocking: false # false by default
# use_xff: true # whether to use the X-Forwarded-For header to get the remote ip, false by default
whitelist:
eth_call:
# allow 0x01 to create contract.
Expand Down
35 changes: 30 additions & 5 deletions src/extensions/rate_limit/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::extensions::rate_limit::MethodWeights;
use crate::utils::errors;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter};
use jsonrpsee::{
Expand All @@ -13,6 +14,7 @@ pub struct ConnectionRateLimitLayer {
period: Duration,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl ConnectionRateLimitLayer {
Expand All @@ -22,8 +24,14 @@ impl ConnectionRateLimitLayer {
period,
jitter,
method_weights,
blocking: false,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<S> tower::Layer<S> for ConnectionRateLimitLayer {
Expand All @@ -37,6 +45,7 @@ impl<S> tower::Layer<S> for ConnectionRateLimitLayer {
self.jitter,
self.method_weights.clone(),
)
.blocking(self.blocking)
}
}

Expand All @@ -46,19 +55,26 @@ pub struct ConnectionRateLimit<S> {
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl<S> ConnectionRateLimit<S> {
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self {
let quota = super::build_quota(burst, period);
let limiter = Arc::new(RateLimiter::direct(quota));
Self {
blocking: false,
service,
limiter,
jitter,
method_weights,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<'a, S> RpcServiceT<'a> for ConnectionRateLimit<S>
Expand All @@ -72,13 +88,21 @@ where
let service = self.service.clone();
let limiter = self.limiter.clone();
let weight = self.method_weights.get(req.method_name());
let blocking = self.blocking;

async move {
if let Some(n) = NonZeroU32::new(weight) {
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
if blocking {
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
} else {
match limiter.check_n(n).expect("check_n have been done during init") {
Ok(_) => {}
Err(_) => return MethodResponse::error(req.id, errors::reached_rate_limit()),
}
}
}
service.call(req).await
}
Expand Down Expand Up @@ -110,7 +134,8 @@ mod tests {
Duration::from_millis(100),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);
)
.blocking(true);

let batch = |service: ConnectionRateLimit<MockService>, count: usize, delay| async move {
tokio::time::sleep(Duration::from_millis(delay)).await;
Expand Down
35 changes: 30 additions & 5 deletions src/extensions/rate_limit/global.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::extensions::rate_limit::MethodWeights;
use crate::utils::errors;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter};
use jsonrpsee::{
Expand All @@ -12,6 +13,7 @@ pub struct GlobalRateLimitLayer {
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl GlobalRateLimitLayer {
Expand All @@ -20,15 +22,22 @@ impl GlobalRateLimitLayer {
limiter,
jitter,
method_weights,
blocking: false,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<S> tower::Layer<S> for GlobalRateLimitLayer {
type Service = GlobalRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
GlobalRateLimit::new(service, self.limiter.clone(), self.jitter, self.method_weights.clone())
.blocking(self.blocking)
}
}

Expand All @@ -38,6 +47,7 @@ pub struct GlobalRateLimit<S> {
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl<S> GlobalRateLimit<S> {
Expand All @@ -52,8 +62,14 @@ impl<S> GlobalRateLimit<S> {
limiter,
jitter,
method_weights,
blocking: false,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<'a, S> RpcServiceT<'a> for GlobalRateLimit<S>
Expand All @@ -67,13 +83,21 @@ where
let service = self.service.clone();
let limiter = self.limiter.clone();
let weight = self.method_weights.get(req.method_name());
let blocking = self.blocking;

async move {
if let Some(n) = NonZeroU32::new(weight) {
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
if blocking {
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
} else {
match limiter.check_n(n).expect("check_n have been done during init") {
Ok(_) => {}
Err(_) => return MethodResponse::error(req.id, errors::reached_rate_limit()),
}
}
}
service.call(req).await
}
Expand Down Expand Up @@ -112,7 +136,8 @@ mod tests {
limiter,
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);
)
.blocking(true);

let batch = |service: GlobalRateLimit<MockService>, count: usize, delay| async move {
tokio::time::sleep(Duration::from_millis(delay)).await;
Expand Down
36 changes: 32 additions & 4 deletions src/extensions/rate_limit/ip.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::extensions::rate_limit::MethodWeights;
use crate::utils::errors;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultKeyedRateLimiter, Jitter};
use jsonrpsee::{
Expand All @@ -13,6 +14,7 @@ pub struct IpRateLimitLayer {
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl IpRateLimitLayer {
Expand All @@ -27,8 +29,14 @@ impl IpRateLimitLayer {
limiter,
jitter,
method_weights,
blocking: false,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<S> tower::Layer<S> for IpRateLimitLayer {
Expand All @@ -42,6 +50,7 @@ impl<S> tower::Layer<S> for IpRateLimitLayer {
self.jitter,
self.method_weights.clone(),
)
.blocking(self.blocking)
}
}

Expand All @@ -52,6 +61,7 @@ pub struct IpRateLimit<S> {
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
blocking: bool,
}

impl<S> IpRateLimit<S> {
Expand All @@ -68,8 +78,14 @@ impl<S> IpRateLimit<S> {
limiter,
jitter,
method_weights,
blocking: false,
}
}

pub fn blocking(mut self, blocking: bool) -> Self {
self.blocking = blocking;
self
}
}

impl<'a, S> RpcServiceT<'a> for IpRateLimit<S>
Expand All @@ -84,12 +100,24 @@ where
let service = self.service.clone();
let limiter = self.limiter.clone();
let weight = self.method_weights.get(req.method_name());
let blocking = self.blocking;

async move {
if let Some(n) = NonZeroU32::new(weight) {
limiter
.until_key_n_ready_with_jitter(&ip_addr, n, jitter)
.await
.expect("check_n have been done during init");
if blocking {
limiter
.until_key_n_ready_with_jitter(&ip_addr, n, jitter)
.await
.expect("check_n have been done during init");
} else {
match limiter
.check_key_n(&ip_addr, n)
.expect("check_n have been done during init")
{
Ok(_) => {}
Err(_) => return MethodResponse::error(req.id, errors::reached_rate_limit()),
}
}
}
service.call(req).await
}
Expand Down
Loading

0 comments on commit 4093c6c

Please sign in to comment.