From f1d05f4e9cd360b4ded7ab45bede574774507f1c Mon Sep 17 00:00:00 2001 From: Lorenzo Delgado Date: Tue, 23 Jan 2024 23:41:22 +0100 Subject: [PATCH] refactor(graph-gateway): use origin header extractor (#556) --- Cargo.lock | 1 + graph-gateway/Cargo.toml | 2 +- graph-gateway/src/client_query.rs | 137 ++---------------- .../src/client_query/l2_forwarding.rs | 118 +++++++++++++++ 4 files changed, 130 insertions(+), 128 deletions(-) create mode 100644 graph-gateway/src/client_query/l2_forwarding.rs diff --git a/Cargo.lock b/Cargo.lock index f0d61e1f..f4349ac6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -464,6 +464,7 @@ dependencies = [ "bitflags 1.3.2", "bytes", "futures-util", + "headers", "http 0.2.11", "http-body", "hyper", diff --git a/graph-gateway/Cargo.toml b/graph-gateway/Cargo.toml index b7a57e69..432ddb69 100644 --- a/graph-gateway/Cargo.toml +++ b/graph-gateway/Cargo.toml @@ -7,7 +7,7 @@ version = "17.0.0" alloy-primitives.workspace = true alloy-sol-types = "0.6.0" anyhow.workspace = true -axum.workspace = true +axum = { workspace = true, features = ["headers"] } chrono = { version = "0.4", default-features = false, features = ["clock"] } cost-model = { git = "https://github.com/graphprotocol/agora", rev = "3367d76" } eventuals = "0.6.7" diff --git a/graph-gateway/src/client_query.rs b/graph-gateway/src/client_query.rs index 03e67769..55974999 100644 --- a/graph-gateway/src/client_query.rs +++ b/graph-gateway/src/client_query.rs @@ -6,17 +6,16 @@ use alloy_primitives::{Address, BlockNumber, U256}; use alloy_sol_types::Eip712Domain; use anyhow::anyhow; use axum::extract::OriginalUri; -use axum::http::Uri; use axum::{ body::Bytes, extract::{Path, State}, - http::{header, HeaderMap, Response, StatusCode}, - Extension, + http::{HeaderMap, Response, StatusCode}, + Extension, TypedHeader, }; use cost_model::{Context as AgoraContext, CostModel}; use eventuals::Ptr; use futures::future::join_all; -use headers::ContentType; +use headers::{ContentType, Origin}; use prost::bytes::Buf; use rand::{rngs::SmallRng, SeedableRng as _}; use serde::Deserialize; @@ -24,7 +23,6 @@ use serde_json::value::RawValue; use thegraph::types::{attestation, BlockPointer, DeploymentId, SubgraphId}; use tokio::sync::mpsc; use toolshed::buffer_queue::QueueWriter; -use toolshed::url::Url; use tracing::Instrument; use gateway_common::utils::http_ext::HttpBuilderExt; @@ -54,11 +52,13 @@ use crate::unattestable_errors::{miscategorized_attestable, miscategorized_unatt use self::attestation_header::GraphAttestation; use self::auth::AuthToken; use self::context::Context; +use self::l2_forwarding::forward_request_to_l2; mod attestation_header; pub mod auth; pub mod context; mod graphql; +mod l2_forwarding; pub mod legacy_auth_adapter; pub mod query_id; pub mod query_tracing; @@ -75,6 +75,7 @@ pub async fn handle_query( Extension(auth): Extension, OriginalUri(original_uri): OriginalUri, Path(params): Path>, + origin: Option>, headers: HeaderMap, payload: Bytes, ) -> Response { @@ -85,7 +86,7 @@ pub async fn handle_query( let resolved_deployments = resolve_subgraph_deployments(&ctx.network, ¶ms).await; - // This is very useful for investigating gateway logs in production. + // This is very useful for investigating gateway logs in production let selector = match &resolved_deployments { Ok((_, Some(subgraph))) => subgraph.id.to_string(), Ok((deployments, None)) => deployments @@ -123,13 +124,9 @@ pub async fn handle_query( } } - // TODO(LNSD): Move origin header parsing to an extractor. - let domain = headers - .get(header::ORIGIN) - .and_then(|v| v.to_str().ok()) - .and_then(|v| Some(v.parse::().ok()?.host_str()?.to_string())) - .unwrap_or("".to_string()); - + let domain = origin + .map(|origin| origin.hostname().to_string()) + .unwrap_or_default(); let result = match resolved_deployments { Ok((deployments, _)) => { handle_client_query_inner(&ctx, deployments, payload, auth, domain) @@ -177,67 +174,6 @@ pub async fn handle_query( } } -async fn forward_request_to_l2( - client: &reqwest::Client, - l2_url: &Url, - original_path: &Uri, - headers: HeaderMap, - payload: Bytes, - l2_subgraph_id: Option, -) -> Response { - // We originally attempted to proxy the user's request, but that resulted in a lot of strange - // behavior from Cloudflare that was too difficult to debug. - let l2_path = l2_request_path(original_path, l2_subgraph_id); - let l2_url = l2_url.join(&l2_path).unwrap(); - tracing::info!(%l2_url, %original_path); - let headers = headers - .into_iter() - .filter_map(|(k, v)| Some((k?, v))) - .filter(|(k, _)| [header::CONTENT_TYPE, header::AUTHORIZATION, header::ORIGIN].contains(k)) - .collect(); - let response = match client - .post(l2_url) - .headers(headers) - .body(payload) - .send() - .await - .and_then(|response| response.error_for_status()) - { - Ok(response) => response, - Err(err) => { - return graphql::error_response(Error::Internal(anyhow!("L2 gateway error: {err}"))) - } - }; - let status = response.status(); - if !status.is_success() { - return graphql::error_response(Error::Internal(anyhow!("L2 gateway error: {status}"))); - } - let body = match response.text().await { - Ok(body) => body, - Err(err) => { - return graphql::error_response(Error::Internal(anyhow!("L2 gateway error: {err}"))) - } - }; - Response::builder() - .status(status) - .header(header::CONTENT_TYPE, "application/json") - .body(body) - .unwrap() -} - -fn l2_request_path(original_path: &Uri, l2_subgraph_id: Option) -> String { - let mut path = original_path.path().to_string(); - let subgraph_prefix = "subgraphs/id/"; - let subgraph_start = path.find(subgraph_prefix); - // rewrite path of subgraph queries to the L2 subgraph ID, conserving version constraint - if let (Some(l2_subgraph_id), Some(replace_start)) = (l2_subgraph_id, subgraph_start) { - let replace_start = replace_start + subgraph_prefix.len(); - let replace_end = path.find('^').unwrap_or(path.len()); - path.replace_range(replace_start..replace_end, &l2_subgraph_id.to_string()); - } - path -} - async fn resolve_subgraph_deployments( network: &GraphNetwork, params: &BTreeMap, @@ -859,59 +795,6 @@ async fn optimistic_query( mod tests { use super::*; - mod l2_gateway_forwarding { - use super::*; - - #[tokio::test] - async fn l2_request_path() { - let deployment: DeploymentId = "QmdveVMs7nAvdBPxNoaMMAYgNcuSroneMctZDnZUgbPPP3" - .parse() - .unwrap(); - let l1_subgraph: SubgraphId = "EMRitnR1t3drKrDQSmJMSmHBPB2sGotgZE12DzWNezDn" - .parse() - .unwrap(); - let l2_subgraph: SubgraphId = "CVHoVSrdiiYvLcH4wocDCazJ1YuixHZ1SKt34UWmnQcC" - .parse() - .unwrap(); - - // test deployment route - let mut original = format!("/api/deployments/id/{deployment}").parse().unwrap(); - let mut expected = format!("/api/deployments/id/{deployment}"); - assert_eq!( - expected, - super::l2_request_path(&original, Some(l2_subgraph)) - ); - - // test subgraph route - original = format!("/api/subgraphs/id/{l1_subgraph}").parse().unwrap(); - expected = format!("/api/subgraphs/id/{l2_subgraph}"); - assert_eq!( - expected, - super::l2_request_path(&original, Some(l2_subgraph)) - ); - - // test subgraph route with API key prefix - original = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l1_subgraph}") - .parse() - .unwrap(); - expected = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l2_subgraph}"); - assert_eq!( - expected, - super::l2_request_path(&original, Some(l2_subgraph)) - ); - - // test subgraph route with version constraint - original = format!("/api/subgraphs/id/{l1_subgraph}^0.0.1") - .parse() - .unwrap(); - expected = format!("/api/subgraphs/id/{l2_subgraph}^0.0.1"); - assert_eq!( - expected, - super::l2_request_path(&original, Some(l2_subgraph)) - ); - } - } - mod require_req_auth { use std::collections::HashMap; use std::sync::Arc; diff --git a/graph-gateway/src/client_query/l2_forwarding.rs b/graph-gateway/src/client_query/l2_forwarding.rs new file mode 100644 index 00000000..d8d5752f --- /dev/null +++ b/graph-gateway/src/client_query/l2_forwarding.rs @@ -0,0 +1,118 @@ +use alloy_primitives::bytes::Bytes; +use anyhow::anyhow; +use axum::http::{header, HeaderMap, Response, Uri}; +use thegraph::types::SubgraphId; +use toolshed::url::Url; + +use gateway_framework::errors::Error; + +pub async fn forward_request_to_l2( + client: &reqwest::Client, + l2_url: &Url, + original_path: &Uri, + headers: HeaderMap, + payload: Bytes, + l2_subgraph_id: Option, +) -> Response { + // We originally attempted to proxy the user's request, but that resulted in a lot of strange + // behavior from Cloudflare that was too difficult to debug. + let l2_path = l2_request_path(original_path, l2_subgraph_id); + let l2_url = l2_url.join(&l2_path).unwrap(); + tracing::info!(%l2_url, %original_path); + let headers = headers + .into_iter() + .filter_map(|(k, v)| Some((k?, v))) + .filter(|(k, _)| [header::CONTENT_TYPE, header::AUTHORIZATION, header::ORIGIN].contains(k)) + .collect(); + let response = match client + .post(l2_url) + .headers(headers) + .body(payload) + .send() + .await + .and_then(|response| response.error_for_status()) + { + Ok(response) => response, + Err(err) => { + return crate::client_query::graphql::error_response(Error::Internal(anyhow!( + "L2 gateway error: {err}" + ))) + } + }; + let status = response.status(); + if !status.is_success() { + return crate::client_query::graphql::error_response(Error::Internal(anyhow!( + "L2 gateway error: {status}" + ))); + } + let body = match response.text().await { + Ok(body) => body, + Err(err) => { + return crate::client_query::graphql::error_response(Error::Internal(anyhow!( + "L2 gateway error: {err}" + ))) + } + }; + Response::builder() + .status(status) + .header(header::CONTENT_TYPE, "application/json") + .body(body) + .unwrap() +} + +fn l2_request_path(original_path: &Uri, l2_subgraph_id: Option) -> String { + let mut path = original_path.path().to_string(); + let subgraph_prefix = "subgraphs/id/"; + let subgraph_start = path.find(subgraph_prefix); + // rewrite path of subgraph queries to the L2 subgraph ID, conserving version constraint + if let (Some(l2_subgraph_id), Some(replace_start)) = (l2_subgraph_id, subgraph_start) { + let replace_start = replace_start + subgraph_prefix.len(); + let replace_end = path.find('^').unwrap_or(path.len()); + path.replace_range(replace_start..replace_end, &l2_subgraph_id.to_string()); + } + path +} + +#[cfg(test)] +mod tests { + use thegraph::types::{DeploymentId, SubgraphId}; + + use super::l2_request_path; + + #[test] + fn test_l2_request_path() { + let deployment: DeploymentId = "QmdveVMs7nAvdBPxNoaMMAYgNcuSroneMctZDnZUgbPPP3" + .parse() + .unwrap(); + let l1_subgraph: SubgraphId = "EMRitnR1t3drKrDQSmJMSmHBPB2sGotgZE12DzWNezDn" + .parse() + .unwrap(); + let l2_subgraph: SubgraphId = "CVHoVSrdiiYvLcH4wocDCazJ1YuixHZ1SKt34UWmnQcC" + .parse() + .unwrap(); + + // test deployment route + let mut original = format!("/api/deployments/id/{deployment}").parse().unwrap(); + let mut expected = format!("/api/deployments/id/{deployment}"); + assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph))); + + // test subgraph route + original = format!("/api/subgraphs/id/{l1_subgraph}").parse().unwrap(); + expected = format!("/api/subgraphs/id/{l2_subgraph}"); + assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph))); + + // test subgraph route with API key prefix + original = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l1_subgraph}") + .parse() + .unwrap(); + expected = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l2_subgraph}"); + assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph))); + + // test subgraph route with version constraint + original = format!("/api/subgraphs/id/{l1_subgraph}^0.0.1") + .parse() + .unwrap(); + expected = format!("/api/subgraphs/id/{l2_subgraph}^0.0.1"); + assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph))); + } +}