From 43cbd171ff21464795cfdd24cf1027b6deb4b627 Mon Sep 17 00:00:00 2001 From: Ermal Kaleci Date: Tue, 28 Nov 2023 00:17:44 +0100 Subject: [PATCH] rate limit (#145) * rate limit * clippy * clippy * update * cleanup * make rate_limit an extension * change config * fix * update test * reduce default jitter * Update config.yml --------- Co-authored-by: Xiliang Chen --- Cargo.lock | 83 +++++++++++---- Cargo.toml | 4 +- config.yml | 3 + src/extensions/mod.rs | 2 + src/extensions/rate_limit/mod.rs | 170 +++++++++++++++++++++++++++++++ src/extensions/server/mod.rs | 12 ++- src/server.rs | 10 +- vendor/jsonrpsee | 2 +- 8 files changed, 261 insertions(+), 25 deletions(-) create mode 100644 src/extensions/rate_limit/mod.rs diff --git a/Cargo.lock b/Cargo.lock index d6ba6de..75ab028 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -157,7 +157,18 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" dependencies = [ - "event-listener", + "event-listener 2.5.3", +] + +[[package]] +name = "async-lock" +version = "3.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea8b3453dd7cc96711834b75400d671b73e3656975fa68d9f277163b7f7e316" +dependencies = [ + "event-listener 4.0.0", + "event-listener-strategy", + "pin-project-lite", ] [[package]] @@ -563,6 +574,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "concurrent-queue" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console-api" version = "0.6.0" @@ -851,6 +871,27 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "event-listener" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +dependencies = [ + "event-listener 4.0.0", + "pin-project-lite", +] + [[package]] name = "fake-simd" version = "0.1.2" @@ -1131,6 +1172,8 @@ dependencies = [ [[package]] name = "governor" version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4" dependencies = [ "cfg-if 1.0.0", "dashmap", @@ -1549,7 +1592,7 @@ dependencies = [ [[package]] name = "jsonrpsee" -version = "0.20.3" +version = "0.20.0" dependencies = [ "jsonrpsee-client-transport", "jsonrpsee-core", @@ -1565,7 +1608,7 @@ dependencies = [ [[package]] name = "jsonrpsee-client-transport" -version = "0.20.3" +version = "0.20.0" dependencies = [ "futures-channel", "futures-util", @@ -1586,10 +1629,10 @@ dependencies = [ [[package]] name = "jsonrpsee-core" -version = "0.20.3" +version = "0.20.0" dependencies = [ "anyhow", - "async-lock", + "async-lock 3.1.2", "async-trait", "beef", "futures-timer", @@ -1610,7 +1653,7 @@ dependencies = [ [[package]] name = "jsonrpsee-http-client" -version = "0.20.3" +version = "0.20.0" dependencies = [ "async-trait", "hyper", @@ -1628,7 +1671,7 @@ dependencies = [ [[package]] name = "jsonrpsee-proc-macros" -version = "0.20.3" +version = "0.20.0" dependencies = [ "heck", "proc-macro-crate", @@ -1639,13 +1682,14 @@ dependencies = [ [[package]] name = "jsonrpsee-server" -version = "0.20.3" +version = "0.20.0" dependencies = [ "futures-util", "http", "hyper", "jsonrpsee-core", "jsonrpsee-types", + "pin-project", "route-recognizer", "serde", "serde_json", @@ -1660,7 +1704,7 @@ dependencies = [ [[package]] name = "jsonrpsee-types" -version = "0.20.3" +version = "0.20.0" dependencies = [ "anyhow", "beef", @@ -1672,7 +1716,7 @@ dependencies = [ [[package]] name = "jsonrpsee-wasm-client" -version = "0.20.3" +version = "0.20.0" dependencies = [ "jsonrpsee-client-transport", "jsonrpsee-core", @@ -1681,7 +1725,7 @@ dependencies = [ [[package]] name = "jsonrpsee-ws-client" -version = "0.20.3" +version = "0.20.0" dependencies = [ "http", "jsonrpsee-client-transport", @@ -1879,7 +1923,7 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8017ec3548ffe7d4cef7ac0e12b044c01164a74c0f3119420faeaf13490ad8b" dependencies = [ - "async-lock", + "async-lock 2.8.0", "async-trait", "crossbeam-channel", "crossbeam-epoch", @@ -2157,6 +2201,12 @@ dependencies = [ "url", ] +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.11.2" @@ -2313,11 +2363,10 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro-crate" -version = "1.3.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +checksum = "7e8366a6159044a37876a2b9817124296703c586a5c92e2c53751fa06d8d43e8" dependencies = [ - "once_cell", "toml_edit", ] @@ -3307,9 +3356,9 @@ checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" [[package]] name = "toml_edit" -version = "0.19.15" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" +checksum = "70f427fce4d84c72b5b732388bf4a9f4531b53f74e2887e3ecb2481f68f66d81" dependencies = [ "indexmap 2.1.0", "toml_datetime", diff --git a/Cargo.toml b/Cargo.toml index 622d74a..026ffa9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,8 +43,8 @@ tracing = "0.1.37" tracing-serde = "0.1.3" tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] } -jsonrpsee = { version = "0.20.0", path = "./vendor/jsonrpsee/jsonrpsee", features = ["full"] } -governor = { version = "0.6.0", path = "./vendor/governor/governor" } +jsonrpsee = { path = "./vendor/jsonrpsee/jsonrpsee", features = ["full"] } +governor = "0.6.0" [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio", "html_reports"] } diff --git a/config.yml b/config.yml index 3021803..d0ce7fb 100644 --- a/config.yml +++ b/config.yml @@ -23,6 +23,9 @@ extensions: - path: /liveness method: chain_getBlockHash cors: all + rate_limit: # 20 requests per second per connection + burst: 20 + period_secs: 1 middlewares: methods: diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 2f15fd4..2cf1880 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -13,6 +13,7 @@ pub mod cache; pub mod client; pub mod event_bus; pub mod merge_subscription; +pub mod rate_limit; pub mod server; pub mod telemetry; @@ -136,4 +137,5 @@ define_all_extensions! { eth_api: api::EthApi, server: server::SubwayServerBuilder, event_bus: event_bus::EventBus, + rate_limit: rate_limit::RateLimitBuilder, } diff --git a/src/extensions/rate_limit/mod.rs b/src/extensions/rate_limit/mod.rs new file mode 100644 index 0000000..0e57fd9 --- /dev/null +++ b/src/extensions/rate_limit/mod.rs @@ -0,0 +1,170 @@ +use futures::{future::BoxFuture, FutureExt}; +use governor::{DefaultDirectRateLimiter, Jitter, Quota, RateLimiter}; +use jsonrpsee::{ + server::{middleware::rpc::RpcServiceT, types::Request}, + MethodResponse, +}; +use serde::Deserialize; +use std::num::NonZeroU32; +use std::{sync::Arc, time::Duration}; + +use super::{Extension, ExtensionRegistry}; + +#[derive(Deserialize, Debug, Copy, Clone, Default)] +pub struct RateLimitConfig { + // burst is the maximum number of requests that can be made in a period + pub burst: u32, + // period is the period of time in which the burst is allowed + #[serde(default = "default_period_secs")] + pub period_secs: u64, + // jitter_millis is the maximum amount of jitter to add to the rate limit + // this is to prevent a thundering herd problem https://en.wikipedia.org/wiki/Thundering_herd_problem + // e.g. if jitter_up_to_millis is 1000, then additional delay of random(0, 1000) milliseconds will be added + #[serde(default = "default_jitter_up_to_millis")] + pub jitter_up_to_millis: u64, +} + +fn default_period_secs() -> u64 { + 1 +} + +fn default_jitter_up_to_millis() -> u64 { + 100 +} + +pub struct RateLimitBuilder { + config: RateLimitConfig, +} + +#[async_trait::async_trait] +impl Extension for RateLimitBuilder { + type Config = RateLimitConfig; + + async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { + Ok(Self::new(*config)) + } +} + +impl RateLimitBuilder { + pub fn new(config: RateLimitConfig) -> Self { + assert!(config.burst > 0, "burst must be greater than 0"); + assert!(config.period_secs > 0, "period_secs must be greater than 0"); + Self { config } + } + pub fn build(&self) -> RateLimit { + let burst = NonZeroU32::new(self.config.burst).unwrap(); + let period = Duration::from_secs(self.config.period_secs); + let jitter = Jitter::up_to(Duration::from_millis(self.config.jitter_up_to_millis)); + RateLimit::new(burst, period, jitter) + } +} + +#[derive(Clone)] +pub struct RateLimit { + burst: NonZeroU32, + period: Duration, + jitter: Jitter, +} + +impl RateLimit { + pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { + Self { burst, period, jitter } + } +} + +impl tower::Layer for RateLimit { + type Service = ConnectionRateLimit; + + fn layer(&self, service: S) -> Self::Service { + ConnectionRateLimit::new(service, self.burst, self.period, self.jitter) + } +} + +#[derive(Clone)] +pub struct ConnectionRateLimit { + service: S, + limiter: Arc, + jitter: Jitter, +} + +impl ConnectionRateLimit { + pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { + let replenish_interval_ns = period.as_nanos() / (burst.get() as u128); + let quota = Quota::with_period(Duration::from_nanos(replenish_interval_ns as u64)) + .unwrap() + .allow_burst(burst); + let limiter = Arc::new(RateLimiter::direct(quota)); + Self { + service, + limiter, + jitter, + } + } +} + +impl<'a, S> RpcServiceT<'a> for ConnectionRateLimit +where + S: RpcServiceT<'a> + Send + Sync + Clone + 'static, +{ + type Future = BoxFuture<'a, MethodResponse>; + + fn call(&self, req: Request<'a>) -> Self::Future { + let jitter = self.jitter; + let service = self.service.clone(); + let limiter = self.limiter.clone(); + + async move { + limiter.until_ready_with_jitter(jitter).await; + service.call(req).await + } + .boxed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use jsonrpsee::types::{Id, ResponsePayload}; + + #[derive(Clone)] + struct MockService; + impl RpcServiceT<'static> for MockService { + type Future = BoxFuture<'static, MethodResponse>; + + fn call(&self, req: Request<'static>) -> Self::Future { + async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed() + } + } + + #[tokio::test] + async fn rate_limit_works() { + let service = ConnectionRateLimit::new( + MockService, + NonZeroU32::new(10).unwrap(), + Duration::from_millis(100), + Jitter::up_to(Duration::from_millis(10)), + ); + + let batch = |service: ConnectionRateLimit, count: usize, delay| async move { + tokio::time::sleep(Duration::from_millis(delay)).await; + let calls = (1..=count) + .map(|id| service.call(Request::new("test".into(), None, Id::Number(id as u64)))) + .collect::>(); + let results = futures::future::join_all(calls).await; + assert_eq!(results.iter().filter(|r| r.is_success()).count(), count); + }; + + let start = tokio::time::Instant::now(); + // background task to make calls + let batch1 = tokio::spawn(batch(service.clone(), 30, 0)); + let batch2 = tokio::spawn(batch(service.clone(), 40, 200)); + let batch3 = tokio::spawn(batch(service.clone(), 20, 300)); + batch1.await.unwrap(); + batch2.await.unwrap(); + batch3.await.unwrap(); + let duration = start.elapsed().as_millis(); + println!("duration: {} ms", duration); + // should take between 800..900 millis. each 100ms period handles 10 calls + assert!(duration > 800 && duration < 900); + } +} diff --git a/src/extensions/server/mod.rs b/src/extensions/server/mod.rs index 59ac30f..1ce5caa 100644 --- a/src/extensions/server/mod.rs +++ b/src/extensions/server/mod.rs @@ -2,11 +2,13 @@ use std::{future::Future, net::SocketAddr}; use async_trait::async_trait; use http::header::HeaderValue; -use jsonrpsee::server::{RandomStringIdProvider, RpcModule, ServerBuilder, ServerHandle}; +use jsonrpsee::server::{ + middleware::rpc::RpcServiceBuilder, RandomStringIdProvider, RpcModule, ServerBuilder, ServerHandle, +}; use serde::Deserialize; use tower_http::cors::{AllowOrigin, CorsLayer}; -use super::{Extension, ExtensionRegistry}; +use super::{rate_limit::RateLimit, Extension, ExtensionRegistry}; use proxy_get_request::ProxyGetRequestLayer; use self::proxy_get_request::ProxyGetRequestMethod; @@ -88,8 +90,11 @@ impl SubwayServerBuilder { pub async fn build>>>( &self, + rate_limit: Option, builder: impl FnOnce() -> Fut, ) -> anyhow::Result<(SocketAddr, ServerHandle)> { + let rpc_middleware = RpcServiceBuilder::new().option_layer(rate_limit); + let service_builder = tower::ServiceBuilder::new() .layer(cors_layer(self.config.cors.clone()).expect("Invalid CORS config")) .layer( @@ -107,7 +112,8 @@ impl SubwayServerBuilder { ); let server = ServerBuilder::default() - .set_middleware(service_builder) + .set_rpc_middleware(rpc_middleware) + .set_http_middleware(service_builder) .max_connections(self.config.max_connections) .set_id_provider(RandomStringIdProvider::new(16)) .build((self.config.listen_address.as_str(), self.config.port)) diff --git a/src/server.rs b/src/server.rs index 8660917..0bf7a30 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,7 +12,7 @@ use serde_json::json; use crate::utils::TypeRegistryRef; use crate::{ config::Config, - extensions::server::SubwayServerBuilder, + extensions::{rate_limit::RateLimitBuilder, server::SubwayServerBuilder}, middlewares::{factory, CallRequest, Middlewares, SubscriptionRequest}, utils::{errors, telemetry}, }; @@ -43,11 +43,17 @@ pub async fn build(config: Config) -> anyhow::Result { .get::() .expect("Server extension not found"); + let rate_limit = extensions_registry + .read() + .await + .get::() + .map(|b| b.build()); + let request_timeout_seconds = server_builder.config.request_timeout_seconds; let registry = extensions_registry.clone(); let (addr, handle) = server_builder - .build(move || async move { + .build(rate_limit, move || async move { let mut module = RpcModule::new(()); let tracer = telemetry::Tracer::new("server"); diff --git a/vendor/jsonrpsee b/vendor/jsonrpsee index c23f453..98675a0 160000 --- a/vendor/jsonrpsee +++ b/vendor/jsonrpsee @@ -1 +1 @@ -Subproject commit c23f453bd6b5eb646082c212dc20edfd747eaf37 +Subproject commit 98675a06ee0d386256954c4fa21b3fb02345d045