Skip to content

Commit

Permalink
feat: remove support for API key max_budget (#983)
Browse files Browse the repository at this point in the history
  • Loading branch information
Theodus authored Dec 2, 2024
1 parent 79d0843 commit 9284635
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 40 deletions.
13 changes: 2 additions & 11 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use std::{
};

use anyhow::{anyhow, bail, ensure};
use ordered_float::NotNan;
use serde::Deserialize;
use serde_with::serde_as;
use thegraph_core::SubgraphId;
use tokio::sync::watch;

Expand All @@ -15,7 +13,6 @@ pub struct AuthSettings {
pub key: String,
pub user: String,
pub authorized_subgraphs: Vec<SubgraphId>,
pub budget_usd: Option<NotNan<f64>>,
}

impl AuthSettings {
Expand All @@ -32,15 +29,11 @@ impl AuthSettings {
}
}

#[serde_as]
#[derive(Clone, Debug, Default, Deserialize)]
pub struct APIKey {
pub struct ApiKey {
pub key: String,
pub user_address: String,
pub query_status: QueryStatus,
#[serde_as(as = "Option<serde_with::TryFromInto<f64>>")]
#[serde(rename = "max_budget")]
pub max_budget_usd: Option<NotNan<f64>>,
#[serde(default)]
pub subgraphs: Vec<SubgraphId>,
#[serde(default)]
Expand All @@ -61,7 +54,7 @@ pub struct AuthContext {
/// This is used to disable the payment requirement on testnets. If `true`, all queries will be
/// checked for the `query_status` of their API key.
pub payment_required: bool,
pub api_keys: watch::Receiver<HashMap<String, APIKey>>,
pub api_keys: watch::Receiver<HashMap<String, ApiKey>>,
pub special_api_keys: Arc<HashSet<String>>,
}

Expand All @@ -77,7 +70,6 @@ impl AuthContext {
key: token.to_string(),
user: String::new(),
authorized_subgraphs: vec![],
budget_usd: None,
});
}

Expand Down Expand Up @@ -109,7 +101,6 @@ impl AuthContext {
key: api_key.key.clone(),
user: api_key.user_address.clone(),
authorized_subgraphs: api_key.subgraphs.clone(),
budget_usd: api_key.max_budget_usd,
})
}
}
Expand Down
19 changes: 3 additions & 16 deletions src/client_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,9 @@ pub async fn handle_query(
let client_request: QueryBody =
serde_json::from_reader(payload.reader()).map_err(|err| Error::BadQuery(err.into()))?;

// Calculate the budget for the query
let grt_per_usd = *ctx.grt_per_usd.borrow();
let one_grt = NotNan::new(1e18).unwrap();
let budget = {
let mut budget = *(ctx.budgeter.query_fees_target.0 * grt_per_usd * one_grt) as u128;
if let Some(user_budget_usd) = auth.budget_usd {
// Security: Consumers can and will set their budget to unreasonably high values.
// This `.min` prevents the budget from being set far beyond what it would be
// automatically. The reason this is important is that sometimes queries are
// subsidized, and we would be at-risk to allow arbitrarily high values.
let max_budget = budget * 10;

budget = (*(user_budget_usd * grt_per_usd * one_grt) as u128).min(max_budget);
}
budget
};
let budget = *(ctx.budgeter.query_fees_target.0 * grt_per_usd * one_grt) as u128;

let (tx, mut rx) = mpsc::channel(1);
tokio::spawn(
Expand Down Expand Up @@ -838,7 +825,7 @@ mod tests {
use tower::ServiceExt;

use crate::{
auth::{APIKey, AuthContext, AuthSettings},
auth::{ApiKey, AuthContext, AuthSettings},
middleware::{legacy_auth_adapter, RequireAuthorizationLayer},
};

Expand All @@ -852,7 +839,7 @@ mod tests {
if let Some(key) = key {
ctx.api_keys = watch::channel(HashMap::from([(
key.into(),
APIKey {
ApiKey {
key: key.into(),
..Default::default()
},
Expand Down
4 changes: 2 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use thegraph_core::{
};
use url::Url;

use crate::{auth::APIKey, network::subgraph_client::TrustedIndexer};
use crate::{auth::ApiKey, network::subgraph_client::TrustedIndexer};

/// The Graph Gateway configuration.
#[serde_as]
Expand Down Expand Up @@ -93,7 +93,7 @@ pub enum ApiKeys {
special: Vec<String>,
},
/// Fixed set of API keys
Fixed(Vec<APIKey>),
Fixed(Vec<ApiKey>),
}

#[derive(Deserialize)]
Expand Down
6 changes: 2 additions & 4 deletions src/middleware/require_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,11 @@ mod tests {
use headers::{Authorization, ContentType, HeaderMapExt};
use http_body_util::BodyExt;
use hyper::http;
use ordered_float::NotNan;
use tokio::sync::watch;
use tokio_test::assert_ready_ok;

use super::{AuthContext, AuthSettings, RequireAuthorizationLayer};
use crate::auth::APIKey;
use crate::auth::ApiKey;

fn test_auth_ctx(key: Option<&str>) -> AuthContext {
let mut ctx = AuthContext {
Expand All @@ -190,9 +189,8 @@ mod tests {
if let Some(key) = key {
ctx.api_keys = watch::channel(HashMap::from([(
key.into(),
APIKey {
ApiKey {
key: key.into(),
max_budget_usd: Some(NotNan::new(1e3).unwrap()),
..Default::default()
},
)]))
Expand Down
12 changes: 5 additions & 7 deletions src/subgraph_studio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ use tokio::{
};
use url::Url;

use crate::auth::{APIKey, QueryStatus};
use crate::auth::{ApiKey, QueryStatus};

pub async fn api_keys(
client: reqwest::Client,
url: Url,
auth: String,
) -> watch::Receiver<HashMap<String, APIKey>> {
) -> watch::Receiver<HashMap<String, ApiKey>> {
let (tx, mut rx) = watch::channel(Default::default());
let mut client = Client { client, url, auth };
tokio::spawn(async move {
Expand Down Expand Up @@ -44,7 +44,7 @@ struct Client {
}

impl Client {
async fn fetch_api_keys(&mut self) -> anyhow::Result<HashMap<String, APIKey>> {
async fn fetch_api_keys(&mut self) -> anyhow::Result<HashMap<String, ApiKey>> {
#[derive(Deserialize, Debug)]
struct _ApiKeys {
api_keys: Vec<_ApiKey>,
Expand All @@ -54,7 +54,6 @@ impl Client {
key: String,
user_address: String,
query_status: QueryStatus,
max_budget: Option<f64>,
#[serde(default)]
subgraphs: Vec<String>,
#[serde(default)]
Expand All @@ -73,12 +72,11 @@ impl Client {
.api_keys
.into_iter()
.map(|api_key| {
let api_key = APIKey {
let api_key = ApiKey {
key: api_key.key,
user_address: api_key.user_address,
query_status: api_key.query_status,
domains: api_key.domains,
max_budget_usd: api_key.max_budget.and_then(|b| b.try_into().ok()),
subgraphs: api_key
.subgraphs
.into_iter()
Expand All @@ -87,7 +85,7 @@ impl Client {
};
(api_key.key.clone(), api_key)
})
.collect::<HashMap<String, APIKey>>();
.collect::<HashMap<String, ApiKey>>();

tracing::info!(api_keys = api_keys.len());
Ok(api_keys)
Expand Down

0 comments on commit 9284635

Please sign in to comment.