From e9ab79b1e42c5af871548dd2adc52057361c8ebf Mon Sep 17 00:00:00 2001 From: Ermal Kaleci Date: Tue, 5 Dec 2023 09:12:32 +0100 Subject: [PATCH] ip rate limit (#146) * ip rate limit * rename * update config * add benches * update config * rename * xff * update config * fmt * update --- benches/bench/main.rs | 4 + benches/bench/rate_limit.rs | 52 +++++++ config.yml | 13 +- src/extensions/rate_limit/connection.rs | 115 +++++++++++++++ src/extensions/rate_limit/ip.rs | 71 +++++++++ src/extensions/rate_limit/mod.rs | 188 +++++++++--------------- src/extensions/rate_limit/xff.rs | 45 ++++++ src/extensions/server/mod.rs | 127 ++++++++++------ src/server.rs | 8 +- 9 files changed, 453 insertions(+), 170 deletions(-) create mode 100644 benches/bench/rate_limit.rs create mode 100644 src/extensions/rate_limit/connection.rs create mode 100644 src/extensions/rate_limit/ip.rs create mode 100644 src/extensions/rate_limit/xff.rs diff --git a/benches/bench/main.rs b/benches/bench/main.rs index 2be9aee..887246d 100644 --- a/benches/bench/main.rs +++ b/benches/bench/main.rs @@ -6,6 +6,8 @@ use pprof::criterion::{Output, PProfProfiler}; use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime as TokioRuntime; +mod rate_limit; + use helpers::{ client::{rpc_params, ws_client, ws_handshake, ClientT, HeaderMap, SubscriptionClientT}, ASYNC_INJECT_CALL, KIB, SUB_METHOD_NAME, UNSUB_METHOD_NAME, @@ -70,6 +72,7 @@ criterion_group!( config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); targets = AsyncBencher::websocket_benches_inject ); + criterion_main!( sync_benches, sync_benches_mid, @@ -79,6 +82,7 @@ criterion_main!( async_benches_slow, subscriptions, async_benches_inject, + rate_limit::rate_limit_benches, ); const SERVER_ONE_ENDPOINT: &str = "127.0.0.1:9955"; diff --git a/benches/bench/rate_limit.rs b/benches/bench/rate_limit.rs new file mode 100644 index 0000000..b17bda3 --- /dev/null +++ b/benches/bench/rate_limit.rs @@ -0,0 +1,52 @@ +use criterion::{criterion_group, Criterion}; +use futures_util::future::BoxFuture; +use futures_util::FutureExt; +use governor::Jitter; +use governor::RateLimiter; +use jsonrpsee::server::middleware::rpc::RpcServiceT; +use jsonrpsee::types::{Request, ResponsePayload}; +use jsonrpsee::MethodResponse; +use std::num::NonZeroU32; +use std::time::Duration; +use subway::extensions::rate_limit::{build_quota, ConnectionRateLimit, IpRateLimit}; + +#[derive(Clone)] +struct MockService; +impl RpcServiceT<'static> for MockService { + type Future = BoxFuture<'static, MethodResponse>; + + fn call(&self, req: Request<'static>) -> Self::Future { + async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed() + } +} + +pub fn connection_rate_limit(c: &mut Criterion) { + let rate_limit = ConnectionRateLimit::new( + MockService, + NonZeroU32::new(1000).unwrap(), + Duration::from_millis(1000), + Jitter::up_to(Duration::from_millis(10)), + ); + + c.bench_function("rate_limit/connection_rate_limit", |b| { + b.iter(|| rate_limit.call(Request::new("test".into(), None, jsonrpsee::types::Id::Number(1)))) + }); +} + +pub fn ip_rate_limit(c: &mut Criterion) { + let burst = NonZeroU32::new(1000).unwrap(); + let quota = build_quota(burst, Duration::from_millis(1000)); + let limiter = RateLimiter::keyed(quota); + let rate_limit = IpRateLimit::new( + MockService, + "::1".to_string(), + std::sync::Arc::new(limiter), + Jitter::up_to(Duration::from_millis(10)), + ); + + c.bench_function("rate_limit/ip_rate_limit", |b| { + b.iter(|| rate_limit.call(Request::new("test".into(), None, jsonrpsee::types::Id::Number(1)))) + }); +} + +criterion_group!(rate_limit_benches, connection_rate_limit, ip_rate_limit); diff --git a/config.yml b/config.yml index d0ce7fb..c284bf2 100644 --- a/config.yml +++ b/config.yml @@ -23,9 +23,16 @@ extensions: - path: /liveness method: chain_getBlockHash cors: all - rate_limit: # 20 requests per second per connection - burst: 20 - period_secs: 1 + rate_limit: # these are for demo purpose only, please adjust to your needs + connection: # 20 RPC requests per second per connection + burst: 20 + period_secs: 1 + ip: # 500 RPC requests per 10 seconds per ip + burst: 500 + period_secs: 10 + # use X-Forwarded-For header to get real ip, if available (e.g. behind a load balancer). + # WARNING: Use with caution, as this xff header can be forged. + use_xff: true # default is false middlewares: methods: diff --git a/src/extensions/rate_limit/connection.rs b/src/extensions/rate_limit/connection.rs new file mode 100644 index 0000000..06c155d --- /dev/null +++ b/src/extensions/rate_limit/connection.rs @@ -0,0 +1,115 @@ +use futures::{future::BoxFuture, FutureExt}; +use governor::{DefaultDirectRateLimiter, Jitter, RateLimiter}; +use jsonrpsee::{ + server::{middleware::rpc::RpcServiceT, types::Request}, + MethodResponse, +}; +use std::num::NonZeroU32; +use std::{sync::Arc, time::Duration}; + +#[derive(Clone)] +pub struct ConnectionRateLimitLayer { + burst: NonZeroU32, + period: Duration, + jitter: Jitter, +} + +impl ConnectionRateLimitLayer { + pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { + Self { burst, period, jitter } + } +} + +impl tower::Layer for ConnectionRateLimitLayer { + type Service = ConnectionRateLimit; + + fn layer(&self, service: S) -> Self::Service { + ConnectionRateLimit::new(service, self.burst, self.period, self.jitter) + } +} + +#[derive(Clone)] +pub struct ConnectionRateLimit { + service: S, + limiter: Arc, + jitter: Jitter, +} + +impl ConnectionRateLimit { + pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { + let quota = super::build_quota(burst, period); + let limiter = Arc::new(RateLimiter::direct(quota)); + Self { + service, + limiter, + jitter, + } + } +} + +impl<'a, S> RpcServiceT<'a> for ConnectionRateLimit +where + S: RpcServiceT<'a> + Send + Sync + Clone + 'static, +{ + type Future = BoxFuture<'a, MethodResponse>; + + fn call(&self, req: Request<'a>) -> Self::Future { + let jitter = self.jitter; + let service = self.service.clone(); + let limiter = self.limiter.clone(); + + async move { + limiter.until_ready_with_jitter(jitter).await; + service.call(req).await + } + .boxed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use jsonrpsee::types::{Id, ResponsePayload}; + + #[derive(Clone)] + struct MockService; + impl RpcServiceT<'static> for MockService { + type Future = BoxFuture<'static, MethodResponse>; + + fn call(&self, req: Request<'static>) -> Self::Future { + async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed() + } + } + + #[tokio::test] + async fn rate_limit_works() { + let service = ConnectionRateLimit::new( + MockService, + NonZeroU32::new(10).unwrap(), + Duration::from_millis(100), + Jitter::up_to(Duration::from_millis(10)), + ); + + let batch = |service: ConnectionRateLimit, count: usize, delay| async move { + tokio::time::sleep(Duration::from_millis(delay)).await; + let calls = (1..=count) + .map(|id| service.call(Request::new("test".into(), None, Id::Number(id as u64)))) + .collect::>(); + let results = futures::future::join_all(calls).await; + assert_eq!(results.iter().filter(|r| r.is_success()).count(), count); + }; + + let start = tokio::time::Instant::now(); + // background task to make calls + let batch1 = tokio::spawn(batch(service.clone(), 30, 0)); + let batch2 = tokio::spawn(batch(service.clone(), 40, 200)); + let batch3 = tokio::spawn(batch(service.clone(), 20, 300)); + batch1.await.unwrap(); + batch2.await.unwrap(); + batch3.await.unwrap(); + let duration = start.elapsed().as_millis(); + println!("duration: {} ms", duration); + // should take between 800..900 millis. each 100ms period handles 10 calls + assert!(duration > 800 && duration < 900); + } +} diff --git a/src/extensions/rate_limit/ip.rs b/src/extensions/rate_limit/ip.rs new file mode 100644 index 0000000..365739a --- /dev/null +++ b/src/extensions/rate_limit/ip.rs @@ -0,0 +1,71 @@ +use futures::{future::BoxFuture, FutureExt}; +use governor::{DefaultKeyedRateLimiter, Jitter}; +use jsonrpsee::{ + server::{middleware::rpc::RpcServiceT, types::Request}, + MethodResponse, +}; +use std::sync::Arc; + +#[derive(Clone)] +pub struct IpRateLimitLayer { + ip_addr: String, + limiter: Arc>, + jitter: Jitter, +} + +impl IpRateLimitLayer { + pub fn new(ip_addr: String, limiter: Arc>, jitter: Jitter) -> Self { + Self { + ip_addr, + limiter, + jitter, + } + } +} + +impl tower::Layer for IpRateLimitLayer { + type Service = IpRateLimit; + + fn layer(&self, service: S) -> Self::Service { + IpRateLimit::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter) + } +} + +#[derive(Clone)] +pub struct IpRateLimit { + service: S, + ip_addr: String, + limiter: Arc>, + jitter: Jitter, +} + +impl IpRateLimit { + pub fn new(service: S, ip_addr: String, limiter: Arc>, jitter: Jitter) -> Self { + Self { + service, + ip_addr, + limiter, + jitter, + } + } +} + +impl<'a, S> RpcServiceT<'a> for IpRateLimit +where + S: RpcServiceT<'a> + Send + Sync + Clone + 'static, +{ + type Future = BoxFuture<'a, MethodResponse>; + + fn call(&self, req: Request<'a>) -> Self::Future { + let ip_addr = self.ip_addr.clone(); + let jitter = self.jitter; + let service = self.service.clone(); + let limiter = self.limiter.clone(); + + async move { + limiter.until_key_ready_with_jitter(&ip_addr, jitter).await; + service.call(req).await + } + .boxed() + } +} diff --git a/src/extensions/rate_limit/mod.rs b/src/extensions/rate_limit/mod.rs index 0e57fd9..3d7f7bf 100644 --- a/src/extensions/rate_limit/mod.rs +++ b/src/extensions/rate_limit/mod.rs @@ -1,17 +1,28 @@ -use futures::{future::BoxFuture, FutureExt}; -use governor::{DefaultDirectRateLimiter, Jitter, Quota, RateLimiter}; -use jsonrpsee::{ - server::{middleware::rpc::RpcServiceT, types::Request}, - MethodResponse, -}; +use governor::{DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter}; use serde::Deserialize; use std::num::NonZeroU32; use std::{sync::Arc, time::Duration}; use super::{Extension, ExtensionRegistry}; -#[derive(Deserialize, Debug, Copy, Clone, Default)] +mod connection; +mod ip; +mod xff; + +pub use connection::{ConnectionRateLimit, ConnectionRateLimitLayer}; +pub use ip::{IpRateLimit, IpRateLimitLayer}; +pub use xff::XFF; + +#[derive(Deserialize, Debug, Clone, Default)] pub struct RateLimitConfig { + pub ip: Option, + pub connection: Option, + #[serde(default)] + pub use_xff: bool, +} + +#[derive(Deserialize, Debug, Clone, Default)] +pub struct Rule { // burst is the maximum number of requests that can be made in a period pub burst: u32, // period is the period of time in which the burst is allowed @@ -34,6 +45,8 @@ fn default_jitter_up_to_millis() -> u64 { pub struct RateLimitBuilder { config: RateLimitConfig, + ip_jitter: Option, + ip_limiter: Option>>, } #[async_trait::async_trait] @@ -41,130 +54,65 @@ impl Extension for RateLimitBuilder { type Config = RateLimitConfig; async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { - Ok(Self::new(*config)) + Ok(Self::new(config.clone())) } } impl RateLimitBuilder { pub fn new(config: RateLimitConfig) -> Self { - assert!(config.burst > 0, "burst must be greater than 0"); - assert!(config.period_secs > 0, "period_secs must be greater than 0"); - Self { config } - } - pub fn build(&self) -> RateLimit { - let burst = NonZeroU32::new(self.config.burst).unwrap(); - let period = Duration::from_secs(self.config.period_secs); - let jitter = Jitter::up_to(Duration::from_millis(self.config.jitter_up_to_millis)); - RateLimit::new(burst, period, jitter) - } -} - -#[derive(Clone)] -pub struct RateLimit { - burst: NonZeroU32, - period: Duration, - jitter: Jitter, -} - -impl RateLimit { - pub fn new(burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { - Self { burst, period, jitter } - } -} - -impl tower::Layer for RateLimit { - type Service = ConnectionRateLimit; - - fn layer(&self, service: S) -> Self::Service { - ConnectionRateLimit::new(service, self.burst, self.period, self.jitter) - } -} - -#[derive(Clone)] -pub struct ConnectionRateLimit { - service: S, - limiter: Arc, - jitter: Jitter, -} - -impl ConnectionRateLimit { - pub fn new(service: S, burst: NonZeroU32, period: Duration, jitter: Jitter) -> Self { - let replenish_interval_ns = period.as_nanos() / (burst.get() as u128); - let quota = Quota::with_period(Duration::from_nanos(replenish_interval_ns as u64)) - .unwrap() - .allow_burst(burst); - let limiter = Arc::new(RateLimiter::direct(quota)); - Self { - service, - limiter, - jitter, + // make sure all rules are valid + if let Some(ref rule) = config.ip { + assert!(rule.burst > 0, "burst must be greater than 0"); + assert!(rule.period_secs > 0, "period_secs must be greater than 0"); + } + if let Some(ref rule) = config.connection { + assert!(rule.burst > 0, "burst must be greater than 0"); + assert!(rule.period_secs > 0, "period_secs must be greater than 0"); } - } -} - -impl<'a, S> RpcServiceT<'a> for ConnectionRateLimit -where - S: RpcServiceT<'a> + Send + Sync + Clone + 'static, -{ - type Future = BoxFuture<'a, MethodResponse>; - - fn call(&self, req: Request<'a>) -> Self::Future { - let jitter = self.jitter; - let service = self.service.clone(); - let limiter = self.limiter.clone(); - async move { - limiter.until_ready_with_jitter(jitter).await; - service.call(req).await + if let Some(ref rule) = config.ip { + let burst = NonZeroU32::new(rule.burst).unwrap(); + let quota = build_quota(burst, Duration::from_secs(rule.period_secs)); + let ip_limiter = Some(Arc::new(RateLimiter::keyed(quota))); + let ip_jitter = Some(Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis))); + Self { + config, + ip_jitter, + ip_limiter, + } + } else { + Self { + config, + ip_jitter: None, + ip_limiter: None, + } } - .boxed() } -} - -#[cfg(test)] -mod tests { - use super::*; - use jsonrpsee::types::{Id, ResponsePayload}; - - #[derive(Clone)] - struct MockService; - impl RpcServiceT<'static> for MockService { - type Future = BoxFuture<'static, MethodResponse>; - - fn call(&self, req: Request<'static>) -> Self::Future { - async move { MethodResponse::response(req.id, ResponsePayload::result("ok"), 1024) }.boxed() + pub fn connection_limit(&self) -> Option { + if let Some(ref rule) = self.config.connection { + let burst = NonZeroU32::new(rule.burst).unwrap(); + let period = Duration::from_secs(rule.period_secs); + let jitter = Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis)); + Some(ConnectionRateLimitLayer::new(burst, period, jitter)) + } else { + None } } + pub fn ip_limit(&self, remote_ip: String) -> Option { + self.ip_limiter + .as_ref() + .map(|ip_limiter| IpRateLimitLayer::new(remote_ip, ip_limiter.clone(), self.ip_jitter.unwrap_or_default())) + } - #[tokio::test] - async fn rate_limit_works() { - let service = ConnectionRateLimit::new( - MockService, - NonZeroU32::new(10).unwrap(), - Duration::from_millis(100), - Jitter::up_to(Duration::from_millis(10)), - ); - - let batch = |service: ConnectionRateLimit, count: usize, delay| async move { - tokio::time::sleep(Duration::from_millis(delay)).await; - let calls = (1..=count) - .map(|id| service.call(Request::new("test".into(), None, Id::Number(id as u64)))) - .collect::>(); - let results = futures::future::join_all(calls).await; - assert_eq!(results.iter().filter(|r| r.is_success()).count(), count); - }; - - let start = tokio::time::Instant::now(); - // background task to make calls - let batch1 = tokio::spawn(batch(service.clone(), 30, 0)); - let batch2 = tokio::spawn(batch(service.clone(), 40, 200)); - let batch3 = tokio::spawn(batch(service.clone(), 20, 300)); - batch1.await.unwrap(); - batch2.await.unwrap(); - batch3.await.unwrap(); - let duration = start.elapsed().as_millis(); - println!("duration: {} ms", duration); - // should take between 800..900 millis. each 100ms period handles 10 calls - assert!(duration > 800 && duration < 900); + // whether to use the X-Forwarded-For header to get the remote ip + pub fn use_xff(&self) -> bool { + self.config.use_xff } } + +pub fn build_quota(burst: NonZeroU32, period: Duration) -> Quota { + let replenish_interval_ns = period.as_nanos() / (burst.get() as u128); + Quota::with_period(Duration::from_nanos(replenish_interval_ns as u64)) + .unwrap() + .allow_burst(burst) +} diff --git a/src/extensions/rate_limit/xff.rs b/src/extensions/rate_limit/xff.rs new file mode 100644 index 0000000..6b44d75 --- /dev/null +++ b/src/extensions/rate_limit/xff.rs @@ -0,0 +1,45 @@ +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; + +pub trait XFF { + fn xxf_ip(&self) -> Option; +} +impl XFF for http::Request { + fn xxf_ip(&self) -> Option { + let xff = self.headers().get("x-forwarded-for")?; + let xff = xff.to_str().ok()?; + let xff = xff.split(',').next()?; + let addr = IpAddr::from_str(xff) + .ok() + .or(SocketAddr::from_str(xff).map(|x| x.ip()).ok())?; + Some(addr.to_string()) + } +} + +#[test] +fn test_xff() { + let cases = vec![ + ("", None), + ("foo,bar", None), + ("1.2.3.4:1234,foo,bar", Some("1.2.3.4")), + ("203.0.113.195, 70.41.3.18, 150.172.238.178", Some("203.0.113.195")), + ("203.0.113.195", Some("203.0.113.195")), + ("[::1]:1234,foo,bar", Some("::1")), + ( + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + Some("2001:db8:85a3:8d3:1319:8a2e:370:7348"), + ), + ( + "[2001:db8::1a2b:3c4d]:41237, 198.51.100.100:26321", + Some("2001:db8::1a2b:3c4d"), + ), + ]; + + for (xff, ip) in cases { + let req = http::Request::builder() + .header("X-Forwarded-For", xff) + .body(()) + .unwrap(); + assert_eq!(req.xxf_ip().as_deref(), ip); + } +} diff --git a/src/extensions/server/mod.rs b/src/extensions/server/mod.rs index 1ce5caa..823e04a 100644 --- a/src/extensions/server/mod.rs +++ b/src/extensions/server/mod.rs @@ -1,19 +1,25 @@ -use std::{future::Future, net::SocketAddr}; - use async_trait::async_trait; use http::header::HeaderValue; +use hyper::server::conn::AddrStream; +use hyper::service::Service; +use hyper::service::{make_service_fn, service_fn}; use jsonrpsee::server::{ - middleware::rpc::RpcServiceBuilder, RandomStringIdProvider, RpcModule, ServerBuilder, ServerHandle, + middleware::rpc::RpcServiceBuilder, stop_channel, RandomStringIdProvider, RpcModule, ServerBuilder, ServerHandle, }; +use jsonrpsee::Methods; +use serde::ser::StdError; use serde::Deserialize; +use std::str::FromStr; +use std::sync::Arc; +use std::{future::Future, net::SocketAddr}; +use tower::ServiceBuilder; use tower_http::cors::{AllowOrigin, CorsLayer}; -use super::{rate_limit::RateLimit, Extension, ExtensionRegistry}; -use proxy_get_request::ProxyGetRequestLayer; - -use self::proxy_get_request::ProxyGetRequestMethod; +use super::{Extension, ExtensionRegistry}; +use crate::extensions::rate_limit::{RateLimitBuilder, XFF}; mod proxy_get_request; +use proxy_get_request::{ProxyGetRequestLayer, ProxyGetRequestMethod}; pub struct SubwayServerBuilder { pub config: ServerConfig, @@ -90,40 +96,79 @@ impl SubwayServerBuilder { pub async fn build>>>( &self, - rate_limit: Option, - builder: impl FnOnce() -> Fut, + rate_limit_builder: Option>, + rpc_module_builder: impl FnOnce() -> Fut, ) -> anyhow::Result<(SocketAddr, ServerHandle)> { - let rpc_middleware = RpcServiceBuilder::new().option_layer(rate_limit); - - let service_builder = tower::ServiceBuilder::new() - .layer(cors_layer(self.config.cors.clone()).expect("Invalid CORS config")) - .layer( - ProxyGetRequestLayer::new( - self.config - .http_methods - .iter() - .map(|m| ProxyGetRequestMethod { - path: m.path.clone(), - method: m.method.clone(), - }) - .collect(), - ) - .expect("Invalid health config"), - ); - - let server = ServerBuilder::default() - .set_rpc_middleware(rpc_middleware) - .set_http_middleware(service_builder) - .max_connections(self.config.max_connections) - .set_id_provider(RandomStringIdProvider::new(16)) - .build((self.config.listen_address.as_str(), self.config.port)) - .await?; - - let module = builder().await?; - - let addr = server.local_addr()?; - let server = server.start(module); - - Ok((addr, server)) + let config = self.config.clone(); + + let (stop_handle, server_handle) = stop_channel(); + let handle = stop_handle.clone(); + let rpc_module = rpc_module_builder().await?; + + // make_service handle each connection + let make_service = make_service_fn(move |socket: &AddrStream| { + let socket_ip = socket.remote_addr().ip().to_string(); + + let http_middleware: ServiceBuilder<_> = tower::ServiceBuilder::new() + .layer(cors_layer(config.cors.clone()).expect("Invalid CORS config")) + .layer( + ProxyGetRequestLayer::new( + config + .http_methods + .iter() + .map(|m| ProxyGetRequestMethod { + path: m.path.clone(), + method: m.method.clone(), + }) + .collect(), + ) + .expect("Invalid health config"), + ); + + let rpc_module = rpc_module.clone(); + let stop_handle = stop_handle.clone(); + let rate_limit_builder = rate_limit_builder.clone(); + + async move { + // service_fn handle each request + Ok::<_, Box>(service_fn(move |req| { + let mut socket_ip = socket_ip.clone(); + let methods: Methods = rpc_module.clone().into(); + let stop_handle = stop_handle.clone(); + let http_middleware = http_middleware.clone(); + + if let Some(true) = rate_limit_builder.as_ref().map(|r| r.use_xff()) { + socket_ip = req.xxf_ip().unwrap_or(socket_ip); + } + + let rpc_middleware = RpcServiceBuilder::new() + .option_layer(rate_limit_builder.as_ref().and_then(|r| r.ip_limit(socket_ip))) + .option_layer(rate_limit_builder.as_ref().and_then(|r| r.connection_limit())); + + let service_builder = ServerBuilder::default() + .set_rpc_middleware(rpc_middleware) + .set_http_middleware(http_middleware) + .max_connections(config.max_connections) + .set_id_provider(RandomStringIdProvider::new(16)) + .to_service_builder(); + + let mut service = service_builder.build(methods, stop_handle); + service.call(req) + })) + } + }); + + let ip_addr = std::net::IpAddr::from_str(&self.config.listen_address)?; + let addr = SocketAddr::new(ip_addr, self.config.port); + + let server = hyper::Server::bind(&addr).serve(make_service); + let addr = server.local_addr(); + + tokio::spawn(async move { + let graceful = server.with_graceful_shutdown(async move { handle.shutdown().await }); + graceful.await.unwrap() + }); + + Ok((addr, server_handle)) } } diff --git a/src/server.rs b/src/server.rs index 0bf7a30..4745d70 100644 --- a/src/server.rs +++ b/src/server.rs @@ -43,17 +43,13 @@ pub async fn build(config: Config) -> anyhow::Result { .get::() .expect("Server extension not found"); - let rate_limit = extensions_registry - .read() - .await - .get::() - .map(|b| b.build()); + let rate_limit_builder = extensions_registry.read().await.get::(); let request_timeout_seconds = server_builder.config.request_timeout_seconds; let registry = extensions_registry.clone(); let (addr, handle) = server_builder - .build(rate_limit, move || async move { + .build(rate_limit_builder, move || async move { let mut module = RpcModule::new(()); let tracer = telemetry::Tracer::new("server");