Skip to content

Commit

Permalink
Add token refresh to CAS/Shard Client (#30)
Browse files Browse the repository at this point in the history
Co-authored-by: Brian Ronan <[email protected]>
  • Loading branch information
jgodlew and bpronan authored Oct 1, 2024
1 parent d087b27 commit 71c6b62
Show file tree
Hide file tree
Showing 28 changed files with 769 additions and 245 deletions.
19 changes: 19 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions cas_client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ rustls-pemfile = "2.0.0"
hyper-rustls = { version = "0.26.0", features = ["http2"] }
lz4 = "1.24.0"
reqwest = "0.12.7"
reqwest-middleware = "0.3.3"
serde = { version = "1.0.210", features = ["derive"] }
cas_types = { version = "0.1.0", path = "../cas_types" }
url = "2.5.2"
Expand Down
63 changes: 63 additions & 0 deletions cas_client/src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use anyhow::anyhow;
use cas::auth::{AuthConfig, TokenProvider};
use reqwest::header::HeaderValue;
use reqwest::header::AUTHORIZATION;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next};
use std::sync::{Arc, Mutex};

/// AuthMiddleware is a thread-safe middleware that adds a CAS auth token to outbound requests.
/// If the token it holds is expired, it will automatically be refreshed.
pub struct AuthMiddleware {
token_provider: Arc<Mutex<TokenProvider>>,
}

impl AuthMiddleware {
/// Fetches a token from our TokenProvider. This locks the TokenProvider as we might need
/// to refresh the token if it has expired.
///
/// In the common case, this lock is held only to read the underlying token stored
/// in memory. However, in the event of an expired token (e.g. once every 15 min),
/// we will need to hold the lock while making a call to refresh the token
/// (e.g. to a remote service). During this time, no other CAS requests can proceed
/// from this client until the token has been fetched. This is expected/ok since we
/// don't have a valid token and thus any calls would fail.
fn get_token(&self) -> Result<String, anyhow::Error> {
let mut provider = self
.token_provider
.lock()
.map_err(|e| anyhow!("lock error: {e:?}"))?;
provider
.get_valid_token()
.map_err(|e| anyhow!("couldn't get token: {e:?}"))
}
}

impl From<&AuthConfig> for AuthMiddleware {
fn from(cfg: &AuthConfig) -> Self {
Self {
token_provider: Arc::new(Mutex::new(TokenProvider::new(cfg))),
}
}
}

#[async_trait::async_trait]
impl Middleware for AuthMiddleware {
async fn handle(
&self,
mut req: Request,
extensions: &mut hyper::http::Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
let token = self
.get_token()
.map_err(reqwest_middleware::Error::Middleware)?;

let headers = req.headers_mut();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(),
);
next.run(req, extensions).await
}
}
2 changes: 2 additions & 0 deletions cas_client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ pub enum CasClientError {
#[error("Parse Error: {0}")]
ParseError(#[from] url::ParseError),

#[error("ReqwestMiddleware Error: {0}")]
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
#[error("Reqwest Error: {0}")]
ReqwestError(#[from] reqwest::Error),

Expand Down
3 changes: 3 additions & 0 deletions cas_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
#![allow(dead_code)]

pub use crate::error::CasClientError;
pub use auth::AuthMiddleware;
pub use caching_client::{CachingClient, DEFAULT_BLOCK_SIZE};
pub use interface::Client;
pub use local_client::LocalClient;
pub use merklehash::MerkleHash; // re-export since this is required for the client API.
pub use passthrough_staging_client::PassthroughStagingClient;
pub use remote_client::build_reqwest_client;
pub use remote_client::CASAPIClient;
pub use remote_client::RemoteClient;
pub use staging_client::{new_staging_client, new_staging_client_with_progressbar, StagingClient};
pub use staging_trait::{Staging, StagingBypassable};

mod auth;
mod caching_client;
mod cas_connection_pool;
mod client_adapter;
Expand Down
117 changes: 52 additions & 65 deletions cas_client/src/remote_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@ use std::io::{Cursor, Write};

use anyhow::anyhow;
use bytes::Buf;
use cas::key::Key;
use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse};
use reqwest::{
header::{HeaderMap, HeaderValue},
StatusCode, Url,
};
use bytes::Bytes;
use reqwest::{StatusCode, Url};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware};
use serde::{de::DeserializeOwned, Serialize};
use tracing::{debug, warn};

use bytes::Bytes;
use cas::auth::AuthConfig;
use cas::key::Key;
use cas_object::CasObject;
use cas_types::CASReconstructionTerm;
use tracing::{debug, warn};

use crate::{error::Result, CasClientError};

use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse};
use error_printer::OptionPrinter;
use merklehash::MerkleHash;

use crate::Client;
use crate::{error::Result, AuthMiddleware, CasClientError};

pub const CAS_ENDPOINT: &str = "http://localhost:8080";
pub const PREFIX_DEFAULT: &str = "default";

Expand Down Expand Up @@ -84,44 +83,37 @@ impl Client for RemoteClient {
}

impl RemoteClient {
pub async fn from_config(endpoint: String, token: Option<String>) -> Self {
pub async fn from_config(endpoint: String, auth_config: &Option<AuthConfig>) -> Self {
Self {
client: CASAPIClient::new(&endpoint, token),
client: CASAPIClient::new(&endpoint, auth_config),
}
}
}

#[derive(Debug)]
pub struct CASAPIClient {
client: reqwest::Client,
client: ClientWithMiddleware,
endpoint: String,
token: Option<String>,
}

impl Default for CASAPIClient {
fn default() -> Self {
Self::new(CAS_ENDPOINT, None)
Self::new(CAS_ENDPOINT, &None)
}
}

impl CASAPIClient {
pub fn new(endpoint: &str, token: Option<String>) -> Self {
let client = reqwest::Client::builder().build().unwrap();
pub fn new(endpoint: &str, auth_config: &Option<AuthConfig>) -> Self {
let client = build_reqwest_client(auth_config).unwrap();
Self {
client,
endpoint: endpoint.to_string(),
token,
}
}

pub async fn exists(&self, key: &Key) -> Result<bool> {
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let response = self
.client
.head(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.head(url).send().await?;
match response.status() {
StatusCode::OK => Ok(true),
StatusCode::NOT_FOUND => Ok(false),
Expand All @@ -133,12 +125,7 @@ impl CASAPIClient {

pub async fn get_length(&self, key: &Key) -> Result<Option<u64>> {
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let response = self
.client
.head(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.head(url).send().await?;
let status = response.status();
if status == StatusCode::NOT_FOUND {
return Ok(None);
Expand Down Expand Up @@ -189,13 +176,7 @@ impl CASAPIClient {
writer.set_position(0);
let data = writer.into_inner();

let response = self
.client
.post(url)
.headers(self.request_headers())
.body(data)
.send()
.await?;
let response = self.client.post(url).body(data).send().await?;
let response_body = response.bytes().await?;
let response_parsed: UploadXorbResponse = serde_json::from_reader(response_body.reader())?;

Expand Down Expand Up @@ -247,12 +228,7 @@ impl CASAPIClient {
file_id.hex()
))?;

let response = self
.client
.get(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.get(url).send().await?;
let response_body = response.bytes().await?;
let response_parsed: QueryReconstructionResponse =
serde_json::from_reader(response_body.reader())?;
Expand All @@ -262,29 +238,13 @@ impl CASAPIClient {

pub async fn shard_query_chunk(&self, key: &Key) -> Result<QueryChunkResponse> {
let url = Url::parse(&format!("{}/chunk/{key}", self.endpoint))?;
let response = self
.client
.get(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.get(url).send().await?;
let response_body = response.bytes().await?;
let response_parsed: QueryChunkResponse = serde_json::from_reader(response_body.reader())?;

Ok(response_parsed)
}

fn request_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Some(tok) = &self.token {
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(),
);
}
headers
}

async fn post_json<ReqT, RespT>(&self, url: Url, request_body: &ReqT) -> Result<RespT>
where
ReqT: Serialize,
Expand Down Expand Up @@ -330,22 +290,49 @@ async fn get_one(term: &CASReconstructionTerm) -> Result<Bytes> {
Ok(Bytes::from(sliced))
}

/// builds the client to talk to CAS.
pub fn build_reqwest_client(
auth_config: &Option<AuthConfig>,
) -> std::result::Result<ClientWithMiddleware, reqwest::Error> {
let auth_middleware = auth_config
.as_ref()
.map(AuthMiddleware::from)
.info_none("CAS auth disabled");
let reqwest_client = reqwest::Client::builder().build()?;
Ok(ClientBuilder::new(reqwest_client)
.maybe_with(auth_middleware)
.build())
}

/// Helper trait to allow the reqwest_middleware client to optionally add a middleware.
trait OptionalMiddleware {
fn maybe_with<M: Middleware>(self, middleware: Option<M>) -> Self;
}

impl OptionalMiddleware for ClientBuilder {
fn maybe_with<M: Middleware>(self, middleware: Option<M>) -> Self {
match middleware {
Some(m) => self.with(m),
None => self,
}
}
}

#[cfg(test)]
mod tests {

use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB};
use merklehash::DataHash;
use rand::Rng;
use tracing_test::traced_test;

use super::*;
use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB};
use merklehash::DataHash;

#[ignore]
#[traced_test]
#[tokio::test]
async fn test_basic_put() {
// Arrange
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), None).await;
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), &None).await;
let prefix = PREFIX_DEFAULT;
let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true);

Expand Down
Loading

0 comments on commit 71c6b62

Please sign in to comment.