Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to customize method rate limit weight #149

Merged
merged 3 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ tracing-serde = "0.1.3"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }

jsonrpsee = { path = "./vendor/jsonrpsee/jsonrpsee", features = ["full"] }
governor = "0.6.0"
governor = { path = "./vendor/governor/governor" }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio", "html_reports"] }
Expand Down
7 changes: 7 additions & 0 deletions benches/bench/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,41 +243,47 @@ fn config() -> Config {
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_FAST_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::SYNC_MEM_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_MEM_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::SYNC_SLOW_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_SLOW_CALL.to_string(),
params: vec![],
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
RpcMethod {
method: helpers::ASYNC_INJECT_CALL.to_string(),
Expand All @@ -298,6 +304,7 @@ fn config() -> Config {
response: None,
cache: None,
delay_ms: None,
rate_limit_weight: 1,
},
],
subscriptions: vec![RpcSubscription {
Expand Down
2 changes: 2 additions & 0 deletions benches/bench/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub fn connection_rate_limit(c: &mut Criterion) {
NonZeroU32::new(1000).unwrap(),
Duration::from_millis(1000),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

c.bench_function("rate_limit/connection_rate_limit", |b| {
Expand All @@ -42,6 +43,7 @@ pub fn ip_rate_limit(c: &mut Criterion) {
"::1".to_string(),
std::sync::Arc::new(limiter),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

c.bench_function("rate_limit/ip_rate_limit", |b| {
Expand Down
15 changes: 15 additions & 0 deletions src/config/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ pub struct RpcMethod {

#[serde(default)]
pub delay_ms: Option<u64>,

/// This should not exceed max cell capacity. If it does,
/// method will return error. Burst size is the max cell capacity.
/// If rate limit is not configured, this will be ignored.
/// e.g. if rate limit is configured as 10r per 2s and rate_limit_weight is 10,
/// then only 1 call is allowed per 2s. If rate_limit_weight is 5, then 2 calls
/// are allowed per 2s. If rate_limit_weight is greater than 10, then method will
/// return error "rate limit exceeded".
/// Add this if you want to modify the default value of 1.
#[serde(default = "default_rate_limit_weight")]
pub rate_limit_weight: u32,
}

fn default_rate_limit_weight() -> u32 {
1
}

#[derive(Copy, Clone, Deserialize, Debug)]
Expand Down
34 changes: 27 additions & 7 deletions src/extensions/rate_limit/connection.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
use crate::{extensions::rate_limit::MethodWeights, utils::errors};
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::num::NonZeroU32;
use std::{sync::Arc, time::Duration};
use std::{num::NonZeroU32, sync::Arc, time::Duration};

#[derive(Clone)]
pub struct ConnectionRateLimitLayer {
burst: NonZeroU32,
period: Duration,
jitter: Jitter,
method_weights: MethodWeights,
}

impl ConnectionRateLimitLayer {
pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self {
Self { burst, period, jitter }
pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self {
Self {
burst,
period,
jitter,
method_weights,
}
}
}

impl<S> tower::Layer<S> for ConnectionRateLimitLayer {
type Service = ConnectionRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
ConnectionRateLimit::new(service, self.burst, self.period, self.jitter)
ConnectionRateLimit::new(
service,
self.burst,
self.period,
self.jitter,
self.method_weights.clone(),
)
}
}

Expand All @@ -33,16 +45,18 @@ pub struct ConnectionRateLimit<S> {
service: S,
limiter: Arc<DefaultDirectRateLimiter>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl<S> ConnectionRateLimit<S> {
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self {
pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter, method_weights: MethodWeights) -> Self {
let quota = super::build_quota(burst, period);
let limiter = Arc::new(RateLimiter::direct(quota));
Self {
service,
limiter,
jitter,
method_weights,
}
}
}
Expand All @@ -57,9 +71,14 @@ where
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();
let weight = self.method_weights.get(req.method_name());

async move {
limiter.until_ready_with_jitter(jitter).await;
if let Some(n) = NonZeroU32::new(weight) {
if limiter.until_n_ready_with_jitter(n, jitter).await.is_err() {
return MethodResponse::error(req.id, errors::failed("rate limit exceeded"));
}
}
service.call(req).await
}
.boxed()
Expand Down Expand Up @@ -88,6 +107,7 @@ mod tests {
NonZeroU32::new(10).unwrap(),
Duration::from_millis(100),
Jitter::up_to(Duration::from_millis(10)),
Default::default(),
);

let batch = |service: ConnectionRateLimit<MockService>, count: usize, delay| async move {
Expand Down
42 changes: 36 additions & 6 deletions src/extensions/rate_limit/ip.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
use crate::{extensions::rate_limit::MethodWeights, utils::errors};
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultKeyedRateLimiter, Jitter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
};
use std::sync::Arc;
use std::{num::NonZeroU32, sync::Arc};

#[derive(Clone)]
pub struct IpRateLimitLayer {
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl IpRateLimitLayer {
pub fn new(ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
pub fn new(
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Self {
Self {
ip_addr,
limiter,
jitter,
method_weights,
}
}
}
Expand All @@ -27,7 +35,13 @@ impl<S> tower::Layer<S> for IpRateLimitLayer {
type Service = IpRateLimit<S>;

fn layer(&self, service: S) -> Self::Service {
IpRateLimit::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter)
IpRateLimit::new(
service,
self.ip_addr.clone(),
self.limiter.clone(),
self.jitter,
self.method_weights.clone(),
)
}
}

Expand All @@ -37,15 +51,23 @@ pub struct IpRateLimit<S> {
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
}

impl<S> IpRateLimit<S> {
pub fn new(service: S, ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
pub fn new(
service: S,
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
method_weights: MethodWeights,
) -> Self {
Self {
service,
ip_addr,
limiter,
jitter,
method_weights,
}
}
}
Expand All @@ -61,9 +83,17 @@ where
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();

let weight = self.method_weights.get(req.method_name());
async move {
limiter.until_key_ready_with_jitter(&ip_addr, jitter).await;
if let Some(n) = NonZeroU32::new(weight) {
if limiter
.until_key_n_ready_with_jitter(&ip_addr, n, jitter)
.await
.is_err()
{
return MethodResponse::error(req.id, errors::failed("rate limit exceeded"));
}
}
service.call(req).await
}
.boxed()
Expand Down
19 changes: 13 additions & 6 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use super::{Extension, ExtensionRegistry};

mod connection;
mod ip;
mod weight;
mod xff;

pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer};
pub use ip::{IpRateLimit, IpRateLimitLayer};
pub use weight::MethodWeights;
pub use xff::XFF;

#[derive(Deserialize, Debug, Clone, Default)]
Expand Down Expand Up @@ -88,20 +90,25 @@ impl RateLimitBuilder {
}
}
}
pub fn connection_limit(&self) -> Option<ConnectionRateLimitLayer> {
pub fn connection_limit(&self, method_weights: MethodWeights) -> Option<ConnectionRateLimitLayer> {
if let Some(ref rule) = self.config.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let jitter = Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis));
Some(ConnectionRateLimitLayer::new(burst, period, jitter))
Some(ConnectionRateLimitLayer::new(burst, period, jitter, method_weights))
} else {
None
}
}
pub fn ip_limit(&self, remote_ip: String) -> Option<IpRateLimitLayer> {
self.ip_limiter
.as_ref()
.map(|ip_limiter| IpRateLimitLayer::new(remote_ip, ip_limiter.clone(), self.ip_jitter.unwrap_or_default()))
pub fn ip_limit(&self, remote_ip: String, method_weights: MethodWeights) -> Option<IpRateLimitLayer> {
self.ip_limiter.as_ref().map(|ip_limiter| {
IpRateLimitLayer::new(
remote_ip,
ip_limiter.clone(),
self.ip_jitter.unwrap_or_default(),
method_weights,
)
})
}

// whether to use the X-Forwarded-For header to get the remote ip
Expand Down
25 changes: 25 additions & 0 deletions src/extensions/rate_limit/weight.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::config::RpcMethod;
use std::collections::BTreeMap;

#[derive(Clone, Debug, Default)]
pub struct MethodWeights(BTreeMap<String, u32>);

impl MethodWeights {
pub fn add(&mut self, method: &str, weight: u32) {
self.0.insert(method.to_owned(), weight);
}

pub fn get(&self, method: &str) -> u32 {
self.0.get(method).cloned().unwrap_or(1)
}
}

impl MethodWeights {
pub fn from_config(methods: &[RpcMethod]) -> Self {
let mut weights = MethodWeights::default();
for method in methods {
weights.add(&method.method, method.rate_limit_weight);
}
weights
}
}
Loading