diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 870c46d1..17a9c47d 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -106,7 +106,18 @@ impl Default for CASAPIClient { impl CASAPIClient { pub fn new(endpoint: &str, token: Option) -> Self { - let client = reqwest::Client::builder().build().unwrap(); + let mut headers = HeaderMap::new(); + if let Some(tok) = &token { + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(), + ); + } + + let client = reqwest::Client::builder() + .default_headers(headers) + .build() + .unwrap(); Self { client, endpoint: endpoint.to_string(), @@ -116,12 +127,7 @@ impl CASAPIClient { pub async fn exists(&self, key: &Key) -> Result { let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?; - let response = self - .client - .head(url) - .headers(self.request_headers()) - .send() - .await?; + let response = self.client.head(url).send().await?; match response.status() { StatusCode::OK => Ok(true), StatusCode::NOT_FOUND => Ok(false), @@ -133,12 +139,7 @@ impl CASAPIClient { pub async fn get_length(&self, key: &Key) -> Result> { let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?; - let response = self - .client - .head(url) - .headers(self.request_headers()) - .send() - .await?; + let response = self.client.head(url).send().await?; let status = response.status(); if status == StatusCode::NOT_FOUND { return Ok(None); @@ -189,13 +190,7 @@ impl CASAPIClient { writer.set_position(0); let data = writer.into_inner(); - let response = self - .client - .post(url) - .headers(self.request_headers()) - .body(data) - .send() - .await?; + let response = self.client.post(url).body(data).send().await?; let response_body = response.bytes().await?; let response_parsed: UploadXorbResponse = serde_json::from_reader(response_body.reader())?; @@ -247,12 +242,7 @@ impl CASAPIClient { file_id.hex() ))?; - let response = self - .client - .get(url) - .headers(self.request_headers()) - .send() - .await?; + let response = self.client.get(url).send().await?; let response_body = response.bytes().await?; let response_parsed: QueryReconstructionResponse = serde_json::from_reader(response_body.reader())?; @@ -262,29 +252,13 @@ impl CASAPIClient { pub async fn shard_query_chunk(&self, key: &Key) -> Result { let url = Url::parse(&format!("{}/chunk/{key}", self.endpoint))?; - let response = self - .client - .get(url) - .headers(self.request_headers()) - .send() - .await?; + let response = self.client.get(url).send().await?; let response_body = response.bytes().await?; let response_parsed: QueryChunkResponse = serde_json::from_reader(response_body.reader())?; Ok(response_parsed) } - fn request_headers(&self) -> HeaderMap { - let mut headers = HeaderMap::new(); - if let Some(tok) = &self.token { - headers.insert( - "Authorization", - HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(), - ); - } - headers - } - async fn post_json(&self, url: Url, request_body: &ReqT) -> Result where ReqT: Serialize, diff --git a/shard_client/src/http_shard_client.rs b/shard_client/src/http_shard_client.rs index f1bfe5bb..60006918 100644 --- a/shard_client/src/http_shard_client.rs +++ b/shard_client/src/http_shard_client.rs @@ -11,7 +11,10 @@ use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDB use mdb_shard::shard_dedup_probe::ShardDedupProber; use mdb_shard::shard_file_reconstructor::FileReconstructor; use merklehash::MerkleHash; -use reqwest::{Url, header::{HeaderMap, HeaderValue}}; +use reqwest::{ + header::{HeaderMap, HeaderValue}, + Url, +}; use retry_strategy::RetryStrategy; use tracing::warn; @@ -29,10 +32,20 @@ pub struct HttpShardClient { impl HttpShardClient { pub fn new(endpoint: &str, token: Option) -> Self { + let mut headers = HeaderMap::new(); + if let Some(tok) = &token { + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(), + ); + } HttpShardClient { endpoint: endpoint.into(), token, - client: reqwest::Client::builder().build().unwrap(), + client: reqwest::Client::builder() + .default_headers(headers) + .build() + .unwrap(), // Retry policy: Exponential backoff starting at BASE_RETRY_DELAY_MS and retrying NUM_RETRIES times retry_strategy: RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS), } @@ -71,23 +84,15 @@ impl RegistrationClient for HttpShardClient { }; let url = Url::parse(&format!("{}/shard/{key}", self.endpoint))?; - - let mut headers = HeaderMap::new(); - if let Some(tok) = &self.token { - headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap()); - } let response = self .retry_strategy .retry( - || { - let headers = headers.clone(); - async { - let url = url.clone(); - match force_sync { - true => self.client.put(url).headers(headers).body(shard_data.to_vec()).send().await, - false => self.client.post(url).headers(headers).body(shard_data.to_vec()).send().await, - } + || async { + let url = url.clone(); + match force_sync { + true => self.client.put(url).body(shard_data.to_vec()).send().await, + false => self.client.post(url).body(shard_data.to_vec()).send().await, } }, is_status_retriable_and_print,