Skip to content

Commit

Permalink
feat: pre-check for ip/connection rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
yjhmelody committed Jun 24, 2024
1 parent c93f208 commit 122fb02
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 11 deletions.
9 changes: 5 additions & 4 deletions src/extensions/rate_limit/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{extensions::rate_limit::MethodWeights, utils::errors};
use crate::extensions::rate_limit::MethodWeights;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter};
use jsonrpsee::{
Expand Down Expand Up @@ -75,9 +75,10 @@ where

async move {
if let Some(n) = NonZeroU32::new(weight) {
if limiter.until_n_ready_with_jitter(n, jitter).await.is_err() {
return MethodResponse::error(req.id, errors::failed("rate limit exceeded"));
}
limiter
.until_n_ready_with_jitter(n, jitter)
.await
.expect("check_n have been done during init");
}
service.call(req).await
}
Expand Down
9 changes: 3 additions & 6 deletions src/extensions/rate_limit/ip.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{extensions::rate_limit::MethodWeights, utils::errors};
use crate::extensions::rate_limit::MethodWeights;
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultKeyedRateLimiter, Jitter};
use jsonrpsee::{
Expand Down Expand Up @@ -86,13 +86,10 @@ where
let weight = self.method_weights.get(req.method_name());
async move {
if let Some(n) = NonZeroU32::new(weight) {
if limiter
limiter
.until_key_n_ready_with_jitter(&ip_addr, n, jitter)
.await
.is_err()
{
return MethodResponse::error(req.id, errors::failed("rate limit exceeded"));
}
.expect("check_n have been done during init");
}
service.call(req).await
}
Expand Down
40 changes: 40 additions & 0 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use anyhow::bail;
use governor::{DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter};
use serde::Deserialize;
use std::num::NonZeroU32;
Expand Down Expand Up @@ -90,6 +91,45 @@ impl RateLimitBuilder {
}
}
}

pub fn pre_check_connection(&self, method_weights: &MethodWeights) -> anyhow::Result<()> {
if let Some(ref rule) = self.config.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let quota = build_quota(burst, period);
let limiter = RateLimiter::direct(quota);

for (method, weight) in &method_weights.0 {
if let Some(n) = NonZeroU32::new(*weight) {
if limiter.check_n(n).is_err() {
bail!("`{method}` weight config too big for connection rate limit: {}", n);
}
}
}
}

Ok(())
}

pub fn pre_check_ip(&self, method_weights: &MethodWeights) -> anyhow::Result<()> {
if let Some(ref rule) = self.config.ip {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let quota = build_quota(burst, period);
let limiter = RateLimiter::direct(quota);

for (method, weight) in &method_weights.0 {
if let Some(n) = NonZeroU32::new(*weight) {
if limiter.check_n(n).is_err() {
bail!("`{method}` weight config too big for ip rate limit: {}", n);
}
}
}
}

Ok(())
}

pub fn connection_limit(&self, method_weights: MethodWeights) -> Option<ConnectionRateLimitLayer> {
if let Some(ref rule) = self.config.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/rate_limit/weight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::config::RpcMethod;
use std::collections::BTreeMap;

#[derive(Clone, Debug, Default)]
pub struct MethodWeights(BTreeMap<String, u32>);
pub struct MethodWeights(pub(crate) BTreeMap<String, u32>);

impl MethodWeights {
pub fn add(&mut self, method: &str, weight: u32) {
Expand Down
2 changes: 2 additions & 0 deletions src/extensions/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ impl SubwayServerBuilder {
methods: Methods,
stop_handle: StopHandle,
rpc_metrics: RpcMetrics,
// TODO: this is not cheap.
svc_builder: TowerServiceBuilder<RpcMiddleware, HttpMiddleware>,
rate_limit_builder: Option<Arc<RateLimitBuilder>>,
// TODO: this is not cheap.
rpc_method_weights: MethodWeights,
}

Expand Down
6 changes: 6 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ pub async fn build(config: Config) -> anyhow::Result<SubwayServerHandle> {

let rpc_method_weights = MethodWeights::from_config(&config.rpcs.methods);

// pre-check stage
if let Some(r) = &rate_limit_builder {
r.pre_check_ip(&rpc_method_weights)?;
r.pre_check_connection(&rpc_method_weights)?;
}

let request_timeout_seconds = server_builder.config.request_timeout_seconds;

let metrics = get_rpc_metrics(&extensions_registry).await;
Expand Down

0 comments on commit 122fb02

Please sign in to comment.