Skip to content

Commit

Permalink
refactor(graph-gateway): use origin header extractor (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
LNSD authored Jan 23, 2024
1 parent 6d5f001 commit f1d05f4
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 128 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion graph-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ version = "17.0.0"
alloy-primitives.workspace = true
alloy-sol-types = "0.6.0"
anyhow.workspace = true
axum.workspace = true
axum = { workspace = true, features = ["headers"] }
chrono = { version = "0.4", default-features = false, features = ["clock"] }
cost-model = { git = "https://github.com/graphprotocol/agora", rev = "3367d76" }
eventuals = "0.6.7"
Expand Down
137 changes: 10 additions & 127 deletions graph-gateway/src/client_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,23 @@ use alloy_primitives::{Address, BlockNumber, U256};
use alloy_sol_types::Eip712Domain;
use anyhow::anyhow;
use axum::extract::OriginalUri;
use axum::http::Uri;
use axum::{
body::Bytes,
extract::{Path, State},
http::{header, HeaderMap, Response, StatusCode},
Extension,
http::{HeaderMap, Response, StatusCode},
Extension, TypedHeader,
};
use cost_model::{Context as AgoraContext, CostModel};
use eventuals::Ptr;
use futures::future::join_all;
use headers::ContentType;
use headers::{ContentType, Origin};
use prost::bytes::Buf;
use rand::{rngs::SmallRng, SeedableRng as _};
use serde::Deserialize;
use serde_json::value::RawValue;
use thegraph::types::{attestation, BlockPointer, DeploymentId, SubgraphId};
use tokio::sync::mpsc;
use toolshed::buffer_queue::QueueWriter;
use toolshed::url::Url;
use tracing::Instrument;

use gateway_common::utils::http_ext::HttpBuilderExt;
Expand Down Expand Up @@ -54,11 +52,13 @@ use crate::unattestable_errors::{miscategorized_attestable, miscategorized_unatt
use self::attestation_header::GraphAttestation;
use self::auth::AuthToken;
use self::context::Context;
use self::l2_forwarding::forward_request_to_l2;

mod attestation_header;
pub mod auth;
pub mod context;
mod graphql;
mod l2_forwarding;
pub mod legacy_auth_adapter;
pub mod query_id;
pub mod query_tracing;
Expand All @@ -75,6 +75,7 @@ pub async fn handle_query(
Extension(auth): Extension<AuthToken>,
OriginalUri(original_uri): OriginalUri,
Path(params): Path<BTreeMap<String, String>>,
origin: Option<TypedHeader<Origin>>,
headers: HeaderMap,
payload: Bytes,
) -> Response<String> {
Expand All @@ -85,7 +86,7 @@ pub async fn handle_query(

let resolved_deployments = resolve_subgraph_deployments(&ctx.network, &params).await;

// This is very useful for investigating gateway logs in production.
// This is very useful for investigating gateway logs in production
let selector = match &resolved_deployments {
Ok((_, Some(subgraph))) => subgraph.id.to_string(),
Ok((deployments, None)) => deployments
Expand Down Expand Up @@ -123,13 +124,9 @@ pub async fn handle_query(
}
}

// TODO(LNSD): Move origin header parsing to an extractor.
let domain = headers
.get(header::ORIGIN)
.and_then(|v| v.to_str().ok())
.and_then(|v| Some(v.parse::<Url>().ok()?.host_str()?.to_string()))
.unwrap_or("".to_string());

let domain = origin
.map(|origin| origin.hostname().to_string())
.unwrap_or_default();
let result = match resolved_deployments {
Ok((deployments, _)) => {
handle_client_query_inner(&ctx, deployments, payload, auth, domain)
Expand Down Expand Up @@ -177,67 +174,6 @@ pub async fn handle_query(
}
}

async fn forward_request_to_l2(
client: &reqwest::Client,
l2_url: &Url,
original_path: &Uri,
headers: HeaderMap,
payload: Bytes,
l2_subgraph_id: Option<SubgraphId>,
) -> Response<String> {
// We originally attempted to proxy the user's request, but that resulted in a lot of strange
// behavior from Cloudflare that was too difficult to debug.
let l2_path = l2_request_path(original_path, l2_subgraph_id);
let l2_url = l2_url.join(&l2_path).unwrap();
tracing::info!(%l2_url, %original_path);
let headers = headers
.into_iter()
.filter_map(|(k, v)| Some((k?, v)))
.filter(|(k, _)| [header::CONTENT_TYPE, header::AUTHORIZATION, header::ORIGIN].contains(k))
.collect();
let response = match client
.post(l2_url)
.headers(headers)
.body(payload)
.send()
.await
.and_then(|response| response.error_for_status())
{
Ok(response) => response,
Err(err) => {
return graphql::error_response(Error::Internal(anyhow!("L2 gateway error: {err}")))
}
};
let status = response.status();
if !status.is_success() {
return graphql::error_response(Error::Internal(anyhow!("L2 gateway error: {status}")));
}
let body = match response.text().await {
Ok(body) => body,
Err(err) => {
return graphql::error_response(Error::Internal(anyhow!("L2 gateway error: {err}")))
}
};
Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(body)
.unwrap()
}

fn l2_request_path(original_path: &Uri, l2_subgraph_id: Option<SubgraphId>) -> String {
let mut path = original_path.path().to_string();
let subgraph_prefix = "subgraphs/id/";
let subgraph_start = path.find(subgraph_prefix);
// rewrite path of subgraph queries to the L2 subgraph ID, conserving version constraint
if let (Some(l2_subgraph_id), Some(replace_start)) = (l2_subgraph_id, subgraph_start) {
let replace_start = replace_start + subgraph_prefix.len();
let replace_end = path.find('^').unwrap_or(path.len());
path.replace_range(replace_start..replace_end, &l2_subgraph_id.to_string());
}
path
}

async fn resolve_subgraph_deployments(
network: &GraphNetwork,
params: &BTreeMap<String, String>,
Expand Down Expand Up @@ -859,59 +795,6 @@ async fn optimistic_query(
mod tests {
use super::*;

mod l2_gateway_forwarding {
use super::*;

#[tokio::test]
async fn l2_request_path() {
let deployment: DeploymentId = "QmdveVMs7nAvdBPxNoaMMAYgNcuSroneMctZDnZUgbPPP3"
.parse()
.unwrap();
let l1_subgraph: SubgraphId = "EMRitnR1t3drKrDQSmJMSmHBPB2sGotgZE12DzWNezDn"
.parse()
.unwrap();
let l2_subgraph: SubgraphId = "CVHoVSrdiiYvLcH4wocDCazJ1YuixHZ1SKt34UWmnQcC"
.parse()
.unwrap();

// test deployment route
let mut original = format!("/api/deployments/id/{deployment}").parse().unwrap();
let mut expected = format!("/api/deployments/id/{deployment}");
assert_eq!(
expected,
super::l2_request_path(&original, Some(l2_subgraph))
);

// test subgraph route
original = format!("/api/subgraphs/id/{l1_subgraph}").parse().unwrap();
expected = format!("/api/subgraphs/id/{l2_subgraph}");
assert_eq!(
expected,
super::l2_request_path(&original, Some(l2_subgraph))
);

// test subgraph route with API key prefix
original = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l1_subgraph}")
.parse()
.unwrap();
expected = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l2_subgraph}");
assert_eq!(
expected,
super::l2_request_path(&original, Some(l2_subgraph))
);

// test subgraph route with version constraint
original = format!("/api/subgraphs/id/{l1_subgraph}^0.0.1")
.parse()
.unwrap();
expected = format!("/api/subgraphs/id/{l2_subgraph}^0.0.1");
assert_eq!(
expected,
super::l2_request_path(&original, Some(l2_subgraph))
);
}
}

mod require_req_auth {
use std::collections::HashMap;
use std::sync::Arc;
Expand Down
118 changes: 118 additions & 0 deletions graph-gateway/src/client_query/l2_forwarding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use alloy_primitives::bytes::Bytes;
use anyhow::anyhow;
use axum::http::{header, HeaderMap, Response, Uri};
use thegraph::types::SubgraphId;
use toolshed::url::Url;

use gateway_framework::errors::Error;

pub async fn forward_request_to_l2(
client: &reqwest::Client,
l2_url: &Url,
original_path: &Uri,
headers: HeaderMap,
payload: Bytes,
l2_subgraph_id: Option<SubgraphId>,
) -> Response<String> {
// We originally attempted to proxy the user's request, but that resulted in a lot of strange
// behavior from Cloudflare that was too difficult to debug.
let l2_path = l2_request_path(original_path, l2_subgraph_id);
let l2_url = l2_url.join(&l2_path).unwrap();
tracing::info!(%l2_url, %original_path);
let headers = headers
.into_iter()
.filter_map(|(k, v)| Some((k?, v)))
.filter(|(k, _)| [header::CONTENT_TYPE, header::AUTHORIZATION, header::ORIGIN].contains(k))
.collect();
let response = match client
.post(l2_url)
.headers(headers)
.body(payload)
.send()
.await
.and_then(|response| response.error_for_status())
{
Ok(response) => response,
Err(err) => {
return crate::client_query::graphql::error_response(Error::Internal(anyhow!(
"L2 gateway error: {err}"
)))
}
};
let status = response.status();
if !status.is_success() {
return crate::client_query::graphql::error_response(Error::Internal(anyhow!(
"L2 gateway error: {status}"
)));
}
let body = match response.text().await {
Ok(body) => body,
Err(err) => {
return crate::client_query::graphql::error_response(Error::Internal(anyhow!(
"L2 gateway error: {err}"
)))
}
};
Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(body)
.unwrap()
}

fn l2_request_path(original_path: &Uri, l2_subgraph_id: Option<SubgraphId>) -> String {
let mut path = original_path.path().to_string();
let subgraph_prefix = "subgraphs/id/";
let subgraph_start = path.find(subgraph_prefix);
// rewrite path of subgraph queries to the L2 subgraph ID, conserving version constraint
if let (Some(l2_subgraph_id), Some(replace_start)) = (l2_subgraph_id, subgraph_start) {
let replace_start = replace_start + subgraph_prefix.len();
let replace_end = path.find('^').unwrap_or(path.len());
path.replace_range(replace_start..replace_end, &l2_subgraph_id.to_string());
}
path
}

#[cfg(test)]
mod tests {
use thegraph::types::{DeploymentId, SubgraphId};

use super::l2_request_path;

#[test]
fn test_l2_request_path() {
let deployment: DeploymentId = "QmdveVMs7nAvdBPxNoaMMAYgNcuSroneMctZDnZUgbPPP3"
.parse()
.unwrap();
let l1_subgraph: SubgraphId = "EMRitnR1t3drKrDQSmJMSmHBPB2sGotgZE12DzWNezDn"
.parse()
.unwrap();
let l2_subgraph: SubgraphId = "CVHoVSrdiiYvLcH4wocDCazJ1YuixHZ1SKt34UWmnQcC"
.parse()
.unwrap();

// test deployment route
let mut original = format!("/api/deployments/id/{deployment}").parse().unwrap();
let mut expected = format!("/api/deployments/id/{deployment}");
assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph)));

// test subgraph route
original = format!("/api/subgraphs/id/{l1_subgraph}").parse().unwrap();
expected = format!("/api/subgraphs/id/{l2_subgraph}");
assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph)));

// test subgraph route with API key prefix
original = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l1_subgraph}")
.parse()
.unwrap();
expected = format!("/api/deadbeefdeadbeefdeadbeefdeadbeef/subgraphs/id/{l2_subgraph}");
assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph)));

// test subgraph route with version constraint
original = format!("/api/subgraphs/id/{l1_subgraph}^0.0.1")
.parse()
.unwrap();
expected = format!("/api/subgraphs/id/{l2_subgraph}^0.0.1");
assert_eq!(expected, l2_request_path(&original, Some(l2_subgraph)));
}
}

0 comments on commit f1d05f4

Please sign in to comment.