Skip to content

Commit

Permalink
add option to customize method rate limit weight (AcalaNetwork#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
ermalkaleci authored Dec 7, 2023
1 parent e9ab79b commit 9a1124f
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 30 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ tracing-serde = "0.1.3"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }

jsonrpsee = { path = "./vendor/jsonrpsee/jsonrpsee", features = ["full"] }
governor = "0.6.0"
governor = { path = "./vendor/governor/governor" }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio", "html_reports"] }
Expand Down
7 changes: 7 additions & 0 deletions benches/bench/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,41 +243,47 @@ fn config() -> Config {
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_FAST_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::SYNC_MEM_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_MEM_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::SYNC_SLOW_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_SLOW_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_INJECT_CALL.to_string(),
Expand All @@ -298,6 +304,7 @@ fn config() -> Config {
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
],
subscriptions: vec![RpcSubscription {
Expand Down
2 changes: 2 additions & 0 deletions benches/bench/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub fn connection_rate_limit(c: &mut Criterion) {
NonZeroU32::new(1000).unwrap(),
Duration::from_millis(1000),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

c.bench_function("rate_limit/connection_rate_limit", |b| {
Expand All @@ -42,6 +43,7 @@ pub fn ip_rate_limit(c: &mut Criterion) {
"::1".to_string(),
std::sync::Arc::new(limiter),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

c.bench_function("rate_limit/ip_rate_limit", |b| {
Expand Down
15 changes: 15 additions & 0 deletions src/config/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ pub struct RpcMethod {

#[serde(default)]
pub delay_ms: Option<u64>,

/// This should not exceed max cell capacity. If it does,
/// method will return error. Burst size is the max cell capacity.
/// If rate limit is not configured, this will be ignored.
/// e.g. if rate limit is configured as 10r per 2s and rate_limit_weight is 10,
/// then only 1 call is allowed per 2s. If rate_limit_weight is 5, then 2 calls
/// are allowed per 2s. If rate_limit_weight is greater than 10, then method will
/// return error "rate limit exceeded".
/// Add this if you want to modify the default value of 1.
#[serde(default = "default_rate_limit_weight")]
pub rate_limit_weight: u32,
}

fn default_rate_limit_weight() -> u32 {
1
}

#[derive(Copy, Clone, Deserialize, Debug)]
Expand Down
34 changes: 27 additions & 7 deletions src/extensions/rate_limit/connection.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
use crate::{extensions::rate_limit::MethodWeights, utils::errors};
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::num::NonZeroU32;
use std::{sync::Arc, time::Duration};
use std::{num::NonZeroU32, sync::Arc, time::Duration};

#[derive(Clone)]
pub struct ConnectionRateLimitLayer {
burst: NonZeroU32,
period: Duration,
jitter: Jitter,
method_weights: MethodWeights,
}

impl ConnectionRateLimitLayer {
pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self {
Self { burst, period, jitter }
pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self {
Self {
burst,
period,
jitter,
method_weights,
}
}
}

impl<S> tower::Layer<S> for ConnectionRateLimitLayer {
type Service = ConnectionRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
ConnectionRateLimit::new(service, self.burst, self.period, self.jitter)
ConnectionRateLimit::new(
service,
self.burst,
self.period,
self.jitter,
self.method_weights.clone(),
)
}
}

Expand All @@ -33,16 +45,18 @@ pub struct ConnectionRateLimit<S> {
service: S,
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl<S> ConnectionRateLimit<S> {
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self {
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self {
let quota = super::build_quota(burst, period);
let limiter = Arc::new(RateLimiter::direct(quota));
Self {
service,
limiter,
jitter,
method_weights,
}
}
}
Expand All @@ -57,9 +71,14 @@ where
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();
let weight = self.method_weights.get(req.method_name());

async move {
limiter.until_ready_with_jitter(jitter).await;
if let Some(n) = NonZeroU32::new(weight) {
if limiter.until_n_ready_with_jitter(n, jitter).await.is_err() {
return MethodResponse::error(req.id, errors::failed("rate limit exceeded"));
}
}
service.call(req).await
}
.boxed()
Expand Down Expand Up @@ -88,6 +107,7 @@ mod tests {
NonZeroU32::new(10).unwrap(),
Duration::from_millis(100),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

let batch = |service: ConnectionRateLimit<MockService>, count: usize, delay| async move {
Expand Down
42 changes: 36 additions & 6 deletions src/extensions/rate_limit/ip.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
use crate::{extensions::rate_limit::MethodWeights, utils::errors};
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultKeyedRateLimiter, Jitter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::sync::Arc;
use std::{num::NonZeroU32, sync::Arc};

#[derive(Clone)]
pub struct IpRateLimitLayer {
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl IpRateLimitLayer {
pub fn new(ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
pub fn new(
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Self {
Self {
ip_addr,
limiter,
jitter,
method_weights,
}
}
}
Expand All @@ -27,7 +35,13 @@ impl<S> tower::Layer<S> for IpRateLimitLayer {
type Service = IpRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
IpRateLimit::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter)
IpRateLimit::new(
service,
self.ip_addr.clone(),
self.limiter.clone(),
self.jitter,
self.method_weights.clone(),
)
}
}

Expand All @@ -37,15 +51,23 @@ pub struct IpRateLimit<S> {
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl<S> IpRateLimit<S> {
pub fn new(service: S, ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
pub fn new(
service: S,
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Self {
Self {
service,
ip_addr,
limiter,
jitter,
method_weights,
}
}
}
Expand All @@ -61,9 +83,17 @@ where
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();

let weight = self.method_weights.get(req.method_name());
async move {
limiter.until_key_ready_with_jitter(&ip_addr, jitter).await;
if let Some(n) = NonZeroU32::new(weight) {
if limiter
.until_key_n_ready_with_jitter(&ip_addr, n, jitter)
.await
.is_err()
{
return MethodResponse::error(req.id, errors::failed("rate limit exceeded"));
}
}
service.call(req).await
}
.boxed()
Expand Down
19 changes: 13 additions & 6 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use super::{Extension, ExtensionRegistry};

mod connection;
mod ip;
mod weight;
mod xff;

pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer};
pub use ip::{IpRateLimit, IpRateLimitLayer};
pub use weight::MethodWeights;
pub use xff::XFF;

#[derive(Deserialize, Debug, Clone, Default)]
Expand Down Expand Up @@ -88,20 +90,25 @@ impl RateLimitBuilder {
}
}
}
pub fn connection_limit(&self) -> Option<ConnectionRateLimitLayer> {
pub fn connection_limit(&self, method_weights: MethodWeights) -> Option<ConnectionRateLimitLayer> {
if let Some(ref rule) = self.config.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let jitter = Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis));
Some(ConnectionRateLimitLayer::new(burst, period, jitter))
Some(ConnectionRateLimitLayer::new(burst, period, jitter, method_weights))
} else {
None
}
}
pub fn ip_limit(&self, remote_ip: String) -> Option<IpRateLimitLayer> {
self.ip_limiter
.as_ref()
.map(|ip_limiter| IpRateLimitLayer::new(remote_ip, ip_limiter.clone(), self.ip_jitter.unwrap_or_default()))
pub fn ip_limit(&self, remote_ip: String, method_weights: MethodWeights) -> Option<IpRateLimitLayer> {
self.ip_limiter.as_ref().map(|ip_limiter| {
IpRateLimitLayer::new(
remote_ip,
ip_limiter.clone(),
self.ip_jitter.unwrap_or_default(),
method_weights,
)
})
}

// whether to use the X-Forwarded-For header to get the remote ip
Expand Down
25 changes: 25 additions & 0 deletions src/extensions/rate_limit/weight.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::config::RpcMethod;
use std::collections::BTreeMap;

#[derive(Clone, Debug, Default)]
pub struct MethodWeights(BTreeMap<String, u32>);

impl MethodWeights {
pub fn add(&mut self, method: &str, weight: u32) {
self.0.insert(method.to_owned(), weight);
}

pub fn get(&self, method: &str) -> u32 {
self.0.get(method).cloned().unwrap_or(1)
}
}

impl MethodWeights {
pub fn from_config(methods: &[RpcMethod]) -> Self {
let mut weights = MethodWeights::default();
for method in methods {
weights.add(&method.method, method.rate_limit_weight);
}
weights
}
}
Loading

0 comments on commit 9a1124f

Please sign in to comment.