Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the auth to the default headers instead of to each call #28

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 17 additions & 43 deletions cas_client/src/remote_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,18 @@ impl Default for CASAPIClient {

impl CASAPIClient {
pub fn new(endpoint: &str, token: Option<String>) -> Self {
let client = reqwest::Client::builder().build().unwrap();
let mut headers = HeaderMap::new();
if let Some(tok) = &token {
headers.insert(
reqwest::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(),
);
}

let client = reqwest::Client::builder()
.default_headers(headers)
.build()
.unwrap();
Self {
client,
endpoint: endpoint.to_string(),
Expand All @@ -116,12 +127,7 @@ impl CASAPIClient {

pub async fn exists(&self, key: &Key) -> Result<bool> {
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let response = self
.client
.head(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.head(url).send().await?;
match response.status() {
StatusCode::OK => Ok(true),
StatusCode::NOT_FOUND => Ok(false),
Expand All @@ -133,12 +139,7 @@ impl CASAPIClient {

pub async fn get_length(&self, key: &Key) -> Result<Option<u64>> {
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let response = self
.client
.head(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.head(url).send().await?;
let status = response.status();
if status == StatusCode::NOT_FOUND {
return Ok(None);
Expand Down Expand Up @@ -189,13 +190,7 @@ impl CASAPIClient {
writer.set_position(0);
let data = writer.into_inner();

let response = self
.client
.post(url)
.headers(self.request_headers())
.body(data)
.send()
.await?;
let response = self.client.post(url).body(data).send().await?;
let response_body = response.bytes().await?;
let response_parsed: UploadXorbResponse = serde_json::from_reader(response_body.reader())?;

Expand Down Expand Up @@ -247,12 +242,7 @@ impl CASAPIClient {
file_id.hex()
))?;

let response = self
.client
.get(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.get(url).send().await?;
let response_body = response.bytes().await?;
let response_parsed: QueryReconstructionResponse =
serde_json::from_reader(response_body.reader())?;
Expand All @@ -262,29 +252,13 @@ impl CASAPIClient {

pub async fn shard_query_chunk(&self, key: &Key) -> Result<QueryChunkResponse> {
let url = Url::parse(&format!("{}/chunk/{key}", self.endpoint))?;
let response = self
.client
.get(url)
.headers(self.request_headers())
.send()
.await?;
let response = self.client.get(url).send().await?;
let response_body = response.bytes().await?;
let response_parsed: QueryChunkResponse = serde_json::from_reader(response_body.reader())?;

Ok(response_parsed)
}

fn request_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Some(tok) = &self.token {
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(),
);
}
headers
}

async fn post_json<ReqT, RespT>(&self, url: Url, request_body: &ReqT) -> Result<RespT>
where
ReqT: Serialize,
Expand Down
35 changes: 20 additions & 15 deletions shard_client/src/http_shard_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDB
use mdb_shard::shard_dedup_probe::ShardDedupProber;
use mdb_shard::shard_file_reconstructor::FileReconstructor;
use merklehash::MerkleHash;
use reqwest::{Url, header::{HeaderMap, HeaderValue}};
use reqwest::{
header::{HeaderMap, HeaderValue},
Url,
};
use retry_strategy::RetryStrategy;
use tracing::warn;

Expand All @@ -29,10 +32,20 @@ pub struct HttpShardClient {

impl HttpShardClient {
pub fn new(endpoint: &str, token: Option<String>) -> Self {
let mut headers = HeaderMap::new();
if let Some(tok) = &token {
headers.insert(
reqwest::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(),
);
}
HttpShardClient {
endpoint: endpoint.into(),
token,
client: reqwest::Client::builder().build().unwrap(),
client: reqwest::Client::builder()
.default_headers(headers)
.build()
.unwrap(),
// Retry policy: Exponential backoff starting at BASE_RETRY_DELAY_MS and retrying NUM_RETRIES times
retry_strategy: RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS),
}
Expand Down Expand Up @@ -71,23 +84,15 @@ impl RegistrationClient for HttpShardClient {
};

let url = Url::parse(&format!("{}/shard/{key}", self.endpoint))?;

let mut headers = HeaderMap::new();
if let Some(tok) = &self.token {
headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap());
}

let response = self
.retry_strategy
.retry(
|| {
let headers = headers.clone();
async {
let url = url.clone();
match force_sync {
true => self.client.put(url).headers(headers).body(shard_data.to_vec()).send().await,
false => self.client.post(url).headers(headers).body(shard_data.to_vec()).send().await,
}
|| async {
let url = url.clone();
match force_sync {
true => self.client.put(url).body(shard_data.to_vec()).send().await,
false => self.client.post(url).body(shard_data.to_vec()).send().await,
}
},
is_status_retriable_and_print,
Expand Down