Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plumb endpoint and access token #19

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions cas_client/src/remote_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use cas::key::Key;
use cas_types::{
QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse
};
use reqwest::{StatusCode, Url};
use reqwest::{header::{HeaderMap, HeaderValue}, StatusCode, Url};
use serde::{de::DeserializeOwned, Serialize};

use bytes::Bytes;
Expand Down Expand Up @@ -84,29 +84,30 @@ impl Client for RemoteClient {
}

impl RemoteClient {
pub async fn from_config(endpoint: String) -> Self {
Self { client: CASAPIClient::new(&endpoint) }
pub async fn from_config(endpoint: String, token: Option<String>) -> Self {
Self { client: CASAPIClient::new(&endpoint, token) }
}
}

/// HTTP client for the CAS API server.
///
/// Constructed with an endpoint (and optional access token) via
/// `CASAPIClient::new`; requests such as file reconstruction attach the
/// token as a bearer `Authorization` header when one is present.
#[derive(Debug)]
pub struct CASAPIClient {
// Underlying reqwest HTTP client used for all requests.
client: reqwest::Client,
// Base URL of the CAS server (e.g. "http://localhost:8080").
endpoint: String,
// Optional access token; when `Some`, it is sent as
// `Authorization: Bearer <token>` on API requests.
token: Option<String>,
}

impl Default for CASAPIClient {
fn default() -> Self {
Self::new(CAS_ENDPOINT)
Self::new(CAS_ENDPOINT, None)
}
}

impl CASAPIClient {
pub fn new(endpoint: &str) -> Self {
pub fn new(endpoint: &str, token: Option<String>) -> Self {
let client = reqwest::Client::builder()
.build()
.unwrap();
Self { client, endpoint: endpoint.to_string() }
Self { client, endpoint: endpoint.to_string(), token }
}

pub async fn exists(&self, key: &Key) -> Result<bool> {
Expand Down Expand Up @@ -182,7 +183,6 @@ impl CASAPIClient {

/// Reconstruct a file and write to writer.
pub async fn write_file<W: Write>(&self, file_id: &MerkleHash, writer: &mut W) -> Result<usize> {

// get manifest of xorbs to download
let manifest = self.reconstruct_file(file_id).await?;

Expand Down Expand Up @@ -213,8 +213,13 @@ impl CASAPIClient {
/// Reconstruct the file
async fn reconstruct_file(&self, file_id: &MerkleHash) -> Result<QueryReconstructionResponse> {
let url = Url::parse(&format!("{}/reconstruction/{}", self.endpoint, file_id.hex()))?;

let response = self.client.get(url).send().await?;

let mut headers = HeaderMap::new();
if let Some(tok) = &self.token {
headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap());
}

let response = self.client.get(url).headers(headers).send().await?;
let response_body = response.bytes().await?;
let response_parsed: QueryReconstructionResponse = serde_json::from_reader(response_body.reader())?;

Expand Down Expand Up @@ -279,7 +284,7 @@ mod tests {
#[tokio::test]
async fn test_basic_put() {
// Arrange
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string()).await;
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), None).await;
let prefix = PREFIX_DEFAULT;
let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true);

Expand Down
2 changes: 1 addition & 1 deletion data/src/bin/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async fn smudge(mut reader: impl Read, mut writer: impl Write) -> Result<()> {
let translator = PointerFileTranslator::new(default_smudge_config()?).await?;

translator
.smudge_file_from_pointer(&pointer_file, &mut writer, None)
.smudge_file_from_pointer(&pointer_file, &mut writer, None, None, None)
.await?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion data/src/cas_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub(crate) async fn create_cas_client(

// Raw remote client.
let remote_client = Arc::new(
RemoteClient::from_config(endpoint.to_string()).await,
RemoteClient::from_config(endpoint.to_string(), None).await,
bpronan marked this conversation as resolved.
Show resolved Hide resolved
);

// Try add in caching capability.
Expand Down
1 change: 1 addition & 0 deletions data/src/configurations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct CacheConfig {
#[derive(Debug)]
pub struct StorageConfig {
pub endpoint: Endpoint,
pub token: Option<String>,
pub auth: Auth,
bpronan marked this conversation as resolved.
Show resolved Hide resolved
pub prefix: String,
pub cache_config: Option<CacheConfig>,
Expand Down
16 changes: 13 additions & 3 deletions data/src/data_processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,10 @@ impl PointerFileTranslator {
pointer: &PointerFile,
writer: &mut impl std::io::Write,
range: Option<(usize, usize)>,
endpoint: Option<String>,
token: Option<String>,
) -> Result<()> {
self.smudge_file_from_hash(&pointer.hash()?, writer, range)
self.smudge_file_from_hash(&pointer.hash()?, writer, range, endpoint, token)
.await
}

Expand All @@ -323,13 +325,21 @@ impl PointerFileTranslator {
file_id: &MerkleHash,
writer: &mut impl std::io::Write,
_range: Option<(usize, usize)>,
endpoint: Option<String>,
token: Option<String>,
) -> Result<()> {
let endpoint = match &self.config.cas_storage_config.endpoint {
Endpoint::Server(endpoint) => endpoint.clone(),
Endpoint::Server(config_endpoint) => {
if let Some(endpoint) = endpoint {
endpoint
} else {
config_endpoint.clone()
}
},
Endpoint::FileSystem(_) => panic!("aaaaaaaa no server"),
};

let rc = CASAPIClient::new(&endpoint);
let rc = CASAPIClient::new(&endpoint, token);

rc.write_file(file_id, writer).await?;

Expand Down
2 changes: 1 addition & 1 deletion data/src/shard_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub async fn create_shard_client(
) -> Result<Arc<dyn ShardClientInterface>> {
info!("Shard endpoint = {:?}", shard_storage_config.endpoint);
let client: Arc<dyn ShardClientInterface> = match &shard_storage_config.endpoint {
Server(endpoint) => Arc::new(HttpShardClient::new(endpoint)),
Server(endpoint) => Arc::new(HttpShardClient::new(endpoint, shard_storage_config.token.clone())),
FileSystem(path) => Arc::new(LocalShardClient::new(path).await?),
};

Expand Down
4 changes: 3 additions & 1 deletion hf_xet/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ use data::{DEFAULT_BLOCK_SIZE, errors};

pub const SMALL_FILE_THRESHOLD: usize = 1;

pub fn default_config(endpoint: String) -> errors::Result<TranslatorConfig> {
pub fn default_config(endpoint: String, token: Option<String>) -> errors::Result<TranslatorConfig> {
let path = current_dir()?.join(".xet");
fs::create_dir_all(&path)?;

let translator_config = TranslatorConfig {
file_query_policy: FileQueryPolicy::ServerOnly,
cas_storage_config: StorageConfig {
endpoint: Endpoint::Server(endpoint.clone()),
token: token.clone(),
auth: Auth {
user_id: "".into(),
login_id: "".into(),
Expand All @@ -27,6 +28,7 @@ pub fn default_config(endpoint: String) -> errors::Result<TranslatorConfig> {
},
shard_storage_config: StorageConfig {
endpoint: Endpoint::Server(endpoint),
token: token,
auth: Auth {
user_id: "".into(),
login_id: "".into(),
Expand Down
27 changes: 16 additions & 11 deletions hf_xet/src/data_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@ use crate::config::default_config;
pub const MAX_CONCURRENT_UPLOADS: usize = 8; // TODO
pub const MAX_CONCURRENT_DOWNLOADS: usize = 8; // TODO

const DEFAULT_CAS_ENDPOINT: &str = "https://cas-server.us.dev.moon.huggingface.tech";
const DEFAULT_CAS_ENDPOINT: &str = "http://localhost:8080";
const READ_BLOCK_SIZE: usize = 1024 * 1024;

pub async fn upload_async(file_paths: Vec<String>) -> errors::Result<Vec<PointerFile>> {
pub async fn upload_async(file_paths: Vec<String>, endpoint: Option<String>, token: Option<String>) -> errors::Result<Vec<PointerFile>> {
// chunk files
// produce Xorbs + Shards
// upload shards and xorbs
// for each file, return the filehash
let endpoint = endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.to_string());

let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
let config = default_config(endpoint, token)?;
let processor = Arc::new(PointerFileTranslator::new(config).await?);
let processor = &processor;
// for all files, clean them, producing pointer files.
Expand All @@ -45,16 +46,20 @@ pub async fn upload_async(file_paths: Vec<String>) -> errors::Result<Vec<Pointer
Ok(pointers)
}

pub async fn download_async(pointer_files: Vec<PointerFile>) -> errors::Result<Vec<String>> {
let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
pub async fn download_async(pointer_files: Vec<PointerFile>, endpoint: Option<String>, token: Option<String>) -> errors::Result<Vec<String>> {
let config = default_config(endpoint.clone().unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), token.clone())?;
let processor = Arc::new(PointerFileTranslator::new(config).await?);
let processor = &processor;
let paths = tokio_par_for_each(
pointer_files,
MAX_CONCURRENT_DOWNLOADS,
|pointer_file, _| async move {
let proc = processor.clone();
smudge_file(&proc, &pointer_file).await
|pointer_file, _| {
let tok = token.clone();
let end = endpoint.clone();
async move {
let proc = processor.clone();
smudge_file(&proc, &pointer_file, end.clone(), tok.clone()).await
}
},
).await.map_err(|e| match e {
ParallelError::JoinError => {
Expand Down Expand Up @@ -87,13 +92,13 @@ async fn clean_file(processor: &PointerFileTranslator, f: String) -> errors::Res
Ok(pf)
}

async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result<String> {
async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile, endpoint: Option<String>, token: Option<String>) -> errors::Result<String> {
let path = PathBuf::from(pointer_file.path());
if let Some(parent_dir) = path.parent() {
fs::create_dir_all(parent_dir)?;
}
let mut f = File::create(&path)?;
proc.smudge_file_from_pointer(&pointer_file, &mut f, None).await?;
proc.smudge_file_from_pointer(&pointer_file, &mut f, None, endpoint, token).await?;
Ok(pointer_file.path().to_string())
}

Expand Down Expand Up @@ -122,7 +127,7 @@ mod tests {
let pointers = vec![
PointerFile::init_from_info("/tmp/foo.rs", "6999733a46030e67f6f020651c91442ace735572458573df599106e54646867c", 4203),
];
let paths = download_async(pointers).await.unwrap();
let paths = download_async(pointers, "http://localhost:8080", "12345").await.unwrap();
println!("paths: {paths:?}");
}
}
Expand Down
12 changes: 6 additions & 6 deletions hf_xet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@ use pyo3::prelude::*;
use data::PointerFile;

#[pyfunction]
#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyPointerFile]")]
pub fn upload_files(file_paths: Vec<String>) -> PyResult<Vec<PyPointerFile>> {
#[pyo3(signature = (file_paths, endpoint, token), text_signature = "(file_paths: List[str], endpoint: Optional[str], token: Optional[str]) -> List[str]")]
bpronan marked this conversation as resolved.
Show resolved Hide resolved
pub fn upload_files(file_paths: Vec<String>, endpoint: Option<String>, token: Option<String>) -> PyResult<Vec<PyPointerFile>> {
Ok(tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
data_client::upload_async(file_paths).await
data_client::upload_async(file_paths, endpoint, token).await
}).map_err(|e| PyException::new_err(format!("{e:?}")))?
.into_iter()
.map(PyPointerFile::from)
.collect())
}

#[pyfunction]
#[pyo3(signature = (files), text_signature = "(files: List[PyPointerFile]) -> List[str]")]
pub fn download_files(files: Vec<PyPointerFile>) -> PyResult<Vec<String>> {
#[pyo3(signature = (files, endpoint, token), text_signature = "(files: List[PyPointerFile], endpoint: Optional[str], token: Optional[str]) -> List[str]")]
pub fn download_files(files: Vec<PyPointerFile>, endpoint: Option<String>, token: Option<String>) -> PyResult<Vec<String>> {
let pfs = files.into_iter().map(PointerFile::from)
.collect();
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async move {
data_client::download_async(pfs).await
data_client::download_async(pfs, endpoint, token).await
}).map_err(|e| PyException::new_err(format!("{e:?}")))
}

Expand Down
27 changes: 19 additions & 8 deletions shard_client/src/http_shard_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDB
use mdb_shard::shard_dedup_probe::ShardDedupProber;
use mdb_shard::shard_file_reconstructor::FileReconstructor;
use merklehash::MerkleHash;
use reqwest::Url;
use reqwest::{Url, header::{HeaderMap, HeaderValue}};
use retry_strategy::RetryStrategy;
use tracing::warn;

Expand All @@ -22,14 +22,16 @@ const BASE_RETRY_DELAY_MS: u64 = 3000;
/// HTTP client for the shard service.
///
/// Uploads shard data to `{endpoint}/shard/{key}`, retrying failed
/// requests according to `retry_strategy`. When `token` is `Some`, it is
/// attached as an `Authorization: Bearer <token>` header.
#[derive(Debug)]
pub struct HttpShardClient {
// Base URL of the shard server.
pub endpoint: String,
// Optional access token sent as a bearer `Authorization` header.
pub token: Option<String>,
// Underlying reqwest HTTP client used for all requests.
client: reqwest::Client,
// Exponential-backoff retry policy applied to shard uploads.
retry_strategy: RetryStrategy,
}

impl HttpShardClient {
pub fn new(endpoint: &str) -> Self {
pub fn new(endpoint: &str, token: Option<String>) -> Self {
HttpShardClient {
endpoint: endpoint.into(),
token: token,
client: reqwest::Client::builder().build().unwrap(),
// Retry policy: Exponential backoff starting at BASE_RETRY_DELAY_MS and retrying NUM_RETRIES times
retry_strategy: RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS),
Expand Down Expand Up @@ -69,14 +71,23 @@ impl RegistrationClient for HttpShardClient {
};

let url = Url::parse(&format!("{}/shard/{key}", self.endpoint))?;

let mut headers = HeaderMap::new();
if let Some(tok) = &self.token {
headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap());
}

let response = self
.retry_strategy
.retry(
|| async {
let url = url.clone();
match force_sync {
true => self.client.put(url).body(shard_data.to_vec()).send().await,
false => self.client.post(url).body(shard_data.to_vec()).send().await,
|| {
let headers = headers.clone();
async {
let url = url.clone();
match force_sync {
true => self.client.put(url).headers(headers).body(shard_data.to_vec()).send().await,
false => self.client.post(url).headers(headers).body(shard_data.to_vec()).send().await,
}
}
},
is_status_retriable_and_print,
Expand Down Expand Up @@ -199,7 +210,7 @@ mod test {
#[tokio::test]
#[ignore = "need a local cas_server running"]
async fn test_local() -> anyhow::Result<()> {
let client = HttpShardClient::new("http://localhost:8080");
let client = HttpShardClient::new("http://localhost:8080", None);

let path =
PathBuf::from("./a7de567477348b23d23b667dba4d63d533c2ba7337cdc4297970bb494ba4699e.mdb");
Expand Down