diff --git a/src/main.rs b/src/main.rs index f6679748..b48bb327 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,9 +42,7 @@ use client_query::context::Context; use config::{ApiKeys, BlocklistEntry, ExchangeRateProvider}; use indexer_client::IndexerClient; use indexing_performance::IndexingPerformance; -use middleware::{ - legacy_auth_adapter, RequestTracingLayer, RequireAuthorizationLayer, SetRequestIdLayer, -}; +use middleware::{legacy_auth_adapter, RequestTracingLayer, RequireAuthorizationLayer}; use network::{indexer_blocklist, subgraph_client::Client as SubgraphClient}; use prometheus::{self, Encoder as _}; use receipts::ReceiptSigner; @@ -209,13 +207,8 @@ async fn main() { .allow_headers(cors::Any) .allow_methods([http::Method::OPTIONS, http::Method::POST]), ) - // Set up the query tracing span - .layer(RequestTracingLayer) - // Set the query ID on the request - .layer(SetRequestIdLayer::new(format!("{:?}", signer_address))) - // Handle legacy in-path auth, and convert it into a header + .layer(RequestTracingLayer::new(format!("{:?}", signer_address))) .layer(axum::middleware::from_fn(legacy_auth_adapter)) - // Require the query to be authorized .layer(RequireAuthorizationLayer::new(auth_service)), ); diff --git a/src/middleware.rs b/src/middleware.rs index a3199ec9..39de9f7b 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,9 +1,7 @@ mod legacy_auth; -mod request_id; mod request_tracing; mod require_auth; pub use legacy_auth::legacy_auth_adapter; -pub use request_id::{RequestId, SetRequestIdLayer}; -pub use request_tracing::RequestTracingLayer; +pub use request_tracing::{RequestId, RequestTracingLayer}; pub use require_auth::RequireAuthorizationLayer; diff --git a/src/middleware/request_id.rs b/src/middleware/request_id.rs deleted file mode 100644 index b347bafa..00000000 --- a/src/middleware/request_id.rs +++ /dev/null @@ -1,259 +0,0 @@ -use std::{ - sync::{atomic, atomic::AtomicU64, Arc}, - task::{Context, Poll}, -}; - -use axum::http::{HeaderName, HeaderValue, Request, Response}; -use tower::Service; - -/// Cloudflare Ray ID header name. -static CLOUDFLARE_RAY_ID: HeaderName = HeaderName::from_static("cf-ray"); - -/// An identifier for a request. -#[derive(Clone)] -pub struct RequestId(pub String); - -impl RequestId { - #[cfg(test)] - /// Create a new [`RequestId`] from a string. - pub fn new(id: impl Into) -> Self { - Self(id.into()) - } - - /// Create a new [`RequestId`] from a HeaderValue. - /// - /// If the header value is invalid, an empty string will be used. - pub fn from_header_value(value: &HeaderValue) -> Self { - Self(value.to_str().unwrap_or_default().to_string()) - } - - /// Create a new [`RequestId`] from the Gateway ID and a counter. - pub fn new_from_gateway_id_and_count(gateway_id: &str, counter: u64) -> Self { - Self(format!("{}-{:x}", gateway_id, counter)) - } -} - -impl AsRef for RequestId { - fn as_ref(&self) -> &str { - &self.0 - } -} - -impl std::fmt::Display for RequestId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - -impl std::fmt::Debug for RequestId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - -/// Set request IDs on ingoing requests. -/// -/// If the request has a `cf-ray` header, it will be used as the request ID. Otherwise, a new request ID -/// derived from the gateway ID and a counter will be used. -/// -/// The middleware inserts the request ID into the request extensions. -#[derive(Clone, Debug)] -pub struct SetRequestId { - inner: S, - gateway_id: String, - counter: Arc, -} - -impl SetRequestId { - /// Create a new [`SetRequestId] middleware. - pub fn new(inner: S, gateway_id: String, counter: Arc) -> Self { - Self { - inner, - gateway_id, - counter, - } - } -} - -impl Service> for SetRequestId -where - S: Service, Response = Response>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - if req.extensions().get::().is_none() { - let request_id = if let Some(ray_id) = req.headers().get(&CLOUDFLARE_RAY_ID) { - RequestId::from_header_value(ray_id) - } else { - let request_count = self.counter.fetch_add(1, atomic::Ordering::Relaxed); - RequestId::new_from_gateway_id_and_count(&self.gateway_id, request_count) - }; - - // Set the request ID on the current span. The request tracing middleware sets the span's request_id - // field to field::Empty. We set it here to the actual request ID. - tracing::span::Span::current() - .record("request_id", tracing::field::display(&request_id)); - - // Set the request ID on the request extensions - req.extensions_mut().insert(request_id); - } - - self.inner.call(req) - } -} - -/// Set request id extensions. -/// -/// This layer applies the [`SetRequestId`] middleware. -#[derive(Clone, Debug)] -pub struct SetRequestIdLayer { - gateway_id: String, - counter: Arc, -} - -impl SetRequestIdLayer { - /// Create a new [`SetRequestIdLayer`]. - pub fn new(gateway_id: impl Into) -> Self { - Self { - gateway_id: gateway_id.into(), - counter: Arc::new(AtomicU64::new(0)), - } - } -} - -impl tower::layer::Layer for SetRequestIdLayer { - type Service = SetRequestId; - - fn layer(&self, inner: S) -> Self::Service { - SetRequestId::new(inner, self.gateway_id.clone(), self.counter.clone()) - } -} - -#[cfg(test)] -mod tests { - use hyper::http; - use tower::{Service, ServiceBuilder, ServiceExt}; - - use super::{RequestId, SetRequestIdLayer}; - - #[tokio::test] - async fn cf_ray_header_is_present() { - //* Given - let gateway_id = "test-gateway"; - - let (mock_svc, mut handle) = - tower_test::mock::pair::, http::Response<&str>>(); - let mut svc = ServiceBuilder::new() - .layer(SetRequestIdLayer::new(gateway_id)) - .service(mock_svc); - - let req = http::Request::builder() - .header("cf-ray", "test-cf-ray") - .body("test") - .unwrap(); - - //* When - // The service must be ready before calling it - svc.ready().await.expect("service is ready"); - tokio::spawn(svc.call(req)); - - let (r, _) = handle - .next_request() - .await - .expect("service received a request"); - - //* Then - assert_eq!(r.headers().get("cf-ray").unwrap(), "test-cf-ray"); - assert_eq!( - r.extensions().get::().unwrap().as_ref(), - "test-cf-ray" - ); - } - - #[tokio::test] - async fn auto_incrementing_id() { - //* Given - let gateway_id = "fe3c0304-7383-48f4-9f3a-fc0cb37f55ba"; - - let (mock_svc, mut handle) = - tower_test::mock::pair::, http::Response<&str>>(); - let mut svc = ServiceBuilder::new() - .layer(SetRequestIdLayer::new(gateway_id)) - .service(mock_svc); - - let req1 = http::Request::builder().body("test").unwrap(); - let req2 = http::Request::builder().body("test").unwrap(); - - //* When - // The service must be ready before calling it - svc.ready().await.expect("service is ready"); - tokio::spawn(svc.call(req1)); - - // Wait for the service to be ready again before calling it - svc.ready().await.expect("service is ready"); - tokio::spawn(svc.call(req2)); - - let (r1, _) = handle - .next_request() - .await - .expect("service received a request"); - let (r2, _) = handle - .next_request() - .await - .expect("service received a request"); - - //* Then - assert_eq!( - r1.extensions().get::().unwrap().as_ref(), - "fe3c0304-7383-48f4-9f3a-fc0cb37f55ba-0" - ); - assert_eq!( - r2.extensions().get::().unwrap().as_ref(), - "fe3c0304-7383-48f4-9f3a-fc0cb37f55ba-1" - ); - } - - #[tokio::test] - async fn request_id_extension_is_already_present() { - //* Given - let gateway_id = "unique-gateway-id"; - - let (mock_svc, mut handle) = - tower_test::mock::pair::, http::Response<&str>>(); - let mut svc = ServiceBuilder::new() - .layer(SetRequestIdLayer::new(gateway_id)) - .service(mock_svc); - - let expected_request_id = "fe3c0304-7383-48f4-9f3a-fc0cb37f55ba-0"; - let req = { - let mut req = http::Request::builder().body("test").unwrap(); - req.extensions_mut() - .insert(RequestId::new(expected_request_id)); - req - }; - - //* When - // The service must be ready before calling it - svc.ready().await.expect("service is ready"); - tokio::spawn(svc.call(req)); - - let (r, _) = handle - .next_request() - .await - .expect("service received a request"); - - //* Then - assert_eq!( - r.extensions().get::().unwrap().as_ref(), - expected_request_id - ); - } -} diff --git a/src/middleware/request_tracing.rs b/src/middleware/request_tracing.rs index 35c6c3d7..88c9e771 100644 --- a/src/middleware/request_tracing.rs +++ b/src/middleware/request_tracing.rs @@ -1,4 +1,10 @@ -use std::task::{Context, Poll}; +use std::{ + sync::{ + atomic::{self, AtomicU64}, + Arc, + }, + task::{Context, Poll}, +}; use axum::http::{Request, Response}; use tower::Service; @@ -7,6 +13,9 @@ use tracing::{ instrument::{Instrument, Instrumented}, }; +#[derive(Clone)] +pub struct RequestId(pub String); + /// Middleware that instruments client query request with a tracing span. /// /// This middleware instruments the request future with a span: @@ -20,6 +29,8 @@ use tracing::{ #[derive(Debug, Clone)] pub struct RequestTracing { inner: S, + gateway_id: String, + counter: Arc, } impl Service> for RequestTracing @@ -34,25 +45,47 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { + const CLOUDFLARE_RAY_ID: http::HeaderName = http::HeaderName::from_static("cf-ray"); + let request_id = if let Some(ray_id) = req.headers().get(&CLOUDFLARE_RAY_ID) { + ray_id.to_str().unwrap_or_default().to_string() + } else { + let request_count = self.counter.fetch_add(1, atomic::Ordering::Relaxed); + format!("{}-{:x}", self.gateway_id, request_count) + }; + + req.extensions_mut().insert(RequestId(request_id.clone())); self.inner.call(req).instrument(tracing::info_span!( "client_request", - request_id = field::Empty, + request_id, selector = field::Empty, )) } } -/// A layer that applies the [`RequestTracing`] middleware. -/// -/// See [`RequestTracing`] for more details. #[derive(Debug, Clone)] -pub struct RequestTracingLayer; +pub struct RequestTracingLayer { + gateway_id: String, + counter: Arc, +} + +impl RequestTracingLayer { + pub fn new(gateway_id: String) -> Self { + Self { + gateway_id, + counter: Arc::new(AtomicU64::new(0)), + } + } +} impl tower::layer::Layer for RequestTracingLayer { type Service = RequestTracing; fn layer(&self, inner: S) -> Self::Service { - RequestTracing { inner } + RequestTracing { + inner, + gateway_id: self.gateway_id.clone(), + counter: self.counter.clone(), + } } }