diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs
index 1391ae55..15823563 100644
--- a/cas_client/src/remote_client.rs
+++ b/cas_client/src/remote_client.rs
@@ -4,7 +4,7 @@ use anyhow::anyhow;
 use bytes::Buf;
 use cas::key::Key;
 use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse};
-use reqwest::{StatusCode, Url};
+use reqwest::{header::{HeaderMap, HeaderValue}, StatusCode, Url};
 use serde::{de::DeserializeOwned, Serialize};
 
 use bytes::Bytes;
@@ -82,9 +82,9 @@ impl Client for RemoteClient {
 }
 
 impl RemoteClient {
-    pub async fn from_config(endpoint: String) -> Self {
-        Self {
-            client: CASAPIClient::new(&endpoint),
+    pub async fn from_config(endpoint: String, token: Option<String>) -> Self {
+        Self {
+            client: CASAPIClient::new(&endpoint, token)
         }
     }
 }
@@ -93,20 +93,22 @@
 pub struct CASAPIClient {
     client: reqwest::Client,
     endpoint: String,
+    token: Option<String>,
 }
 
 impl Default for CASAPIClient {
     fn default() -> Self {
-        Self::new(CAS_ENDPOINT)
+        Self::new(CAS_ENDPOINT, None)
     }
 }
 
 impl CASAPIClient {
-    pub fn new(endpoint: &str) -> Self {
+    pub fn new(endpoint: &str, token: Option<String>) -> Self {
         let client = reqwest::Client::builder().build().unwrap();
         Self {
             client,
             endpoint: endpoint.to_string(),
+            token
         }
     }
 
@@ -222,12 +224,17 @@
     /// Reconstruct the file
     async fn reconstruct_file(&self, file_id: &MerkleHash) -> Result<QueryReconstructionResponse> {
         let url = Url::parse(&format!(
-            "{}/reconstruction/{}",
-            self.endpoint,
+            "{}/reconstruction/{}",
+            self.endpoint,
             file_id.hex()
         ))?;
 
-        let response = self.client.get(url).send().await?;
+        let mut headers = HeaderMap::new();
+        if let Some(tok) = &self.token {
+            headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap());
+        }
+
+        let response = self.client.get(url).headers(headers).send().await?;
         let response_body = response.bytes().await?;
         let response_parsed: QueryReconstructionResponse =
             serde_json::from_reader(response_body.reader())?;
@@ -290,7 +297,7 @@ mod tests {
     #[tokio::test]
     async fn test_basic_put() {
         // Arrange
-        let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string()).await;
+        let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), None).await;
         let prefix = PREFIX_DEFAULT;
         let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true);
 
diff --git a/cas_types/src/lib.rs b/cas_types/src/lib.rs
index fc21d52d..b8f5b9a3 100644
--- a/cas_types/src/lib.rs
+++ b/cas_types/src/lib.rs
@@ -22,7 +22,9 @@ pub struct CASReconstructionTerm {
     pub hash: HexMerkleHash,
     pub unpacked_length: u32,
     pub range: Range,
-    pub range_start_offset: u32,
+    // TODO: disabling until https://github.com/huggingface-internal/xetcas/pull/31/files
+    // is merged.
+    // pub range_start_offset: u32,
     pub url: String,
 }
 
diff --git a/data/src/bin/example.rs b/data/src/bin/example.rs
index 2c44af81..43e9b195 100644
--- a/data/src/bin/example.rs
+++ b/data/src/bin/example.rs
@@ -74,8 +74,7 @@ fn default_clean_config() -> Result<TranslatorConfig> {
         cas_storage_config: StorageConfig {
             endpoint: Endpoint::FileSystem(path.join("xorbs")),
             auth: Auth {
-                user_id: "".into(),
-                login_id: "".into(),
+                token: None,
             },
             prefix: "default".into(),
             cache_config: Some(CacheConfig {
@@ -88,8 +87,7 @@ fn default_clean_config() -> Result<TranslatorConfig> {
         shard_storage_config: StorageConfig {
             endpoint: Endpoint::FileSystem(path.join("xorbs")),
             auth: Auth {
-                user_id: "".into(),
-                login_id: "".into(),
+                token: None,
             },
             prefix: "default-merkledb".into(),
             cache_config: Some(CacheConfig {
@@ -123,8 +121,7 @@ fn default_smudge_config() -> Result<TranslatorConfig> {
         cas_storage_config: StorageConfig {
             endpoint: Endpoint::FileSystem(path.join("xorbs")),
             auth: Auth {
-                user_id: "".into(),
-                login_id: "".into(),
+                token: None,
             },
             prefix: "default".into(),
             cache_config: Some(CacheConfig {
@@ -137,8 +134,7 @@ fn default_smudge_config() -> Result<TranslatorConfig> {
         shard_storage_config: StorageConfig {
             endpoint: Endpoint::FileSystem(path.join("xorbs")),
             auth: Auth {
-                user_id: "".into(),
-                login_id: "".into(),
+                token: None,
             },
             prefix: "default-merkledb".into(),
             cache_config: Some(CacheConfig {
@@ -234,7 +230,7 @@ async fn smudge(mut reader: impl Read, mut writer: impl Write) -> Result<()> {
     let translator = PointerFileTranslator::new(default_smudge_config()?).await?;
 
     translator
-        .smudge_file_from_pointer(&pointer_file, &mut writer, None)
+        .smudge_file_from_pointer(&pointer_file, &mut writer, None, None, None)
         .await?;
 
     Ok(())
diff --git a/data/src/cas_interface.rs b/data/src/cas_interface.rs
index e1158e2f..83709326 100644
--- a/data/src/cas_interface.rs
+++ b/data/src/cas_interface.rs
@@ -36,9 +36,6 @@ pub(crate) async fn create_cas_client(
         unreachable!();
     };
 
-    // Auth info.
-    let _user_id = &cas_storage_config.auth.user_id;
-    let _auth = &cas_storage_config.auth.login_id;
 
     // Usage tracking.
     let _repo_paths = maybe_repo_info
@@ -49,7 +46,7 @@ pub(crate) async fn create_cas_client(
 
     // Raw remote client.
     let remote_client = Arc::new(
-        RemoteClient::from_config(endpoint.to_string()).await,
+        RemoteClient::from_config(endpoint.to_string(), cas_storage_config.auth.token.clone()).await,
     );
 
     // Try add in caching capability.
diff --git a/data/src/configurations.rs b/data/src/configurations.rs
index 51217e8d..a9e76bb8 100644
--- a/data/src/configurations.rs
+++ b/data/src/configurations.rs
@@ -11,8 +11,7 @@ pub enum Endpoint {
 
 #[derive(Debug)]
 pub struct Auth {
-    pub user_id: String,
-    pub login_id: String,
+    pub token: Option<String>,
 }
 
 #[derive(Debug)]
diff --git a/data/src/data_processing.rs b/data/src/data_processing.rs
index 5ffee70f..4c46091c 100644
--- a/data/src/data_processing.rs
+++ b/data/src/data_processing.rs
@@ -313,8 +313,10 @@ impl PointerFileTranslator {
         pointer: &PointerFile,
         writer: &mut impl std::io::Write,
         range: Option<(usize, usize)>,
+        endpoint: Option<String>,
+        token: Option<String>,
     ) -> Result<()> {
-        self.smudge_file_from_hash(&pointer.hash()?, writer, range)
+        self.smudge_file_from_hash(&pointer.hash()?, writer, range, endpoint, token)
             .await
     }
 
@@ -323,13 +325,21 @@ impl PointerFileTranslator {
         file_id: &MerkleHash,
         writer: &mut impl std::io::Write,
         _range: Option<(usize, usize)>,
+        endpoint: Option<String>,
+        token: Option<String>,
     ) -> Result<()> {
         let endpoint = match &self.config.cas_storage_config.endpoint {
-            Endpoint::Server(endpoint) => endpoint.clone(),
+            Endpoint::Server(config_endpoint) => {
+                if let Some(endpoint) = endpoint {
+                    endpoint
+                } else {
+                    config_endpoint.clone()
+                }
+            },
             Endpoint::FileSystem(_) => panic!("aaaaaaaa no server"),
         };
 
-        let rc = CASAPIClient::new(&endpoint);
+        let rc = CASAPIClient::new(&endpoint, token);
 
         rc.write_file(file_id, writer).await?;
 
diff --git a/data/src/shard_interface.rs b/data/src/shard_interface.rs
index b49a994c..56d86089 100644
--- a/data/src/shard_interface.rs
+++ b/data/src/shard_interface.rs
@@ -39,7 +39,7 @@ pub async fn create_shard_client(
 ) -> Result<Arc<dyn ShardClientInterface>> {
     info!("Shard endpoint = {:?}", shard_storage_config.endpoint);
     let client: Arc<dyn ShardClientInterface> = match &shard_storage_config.endpoint {
-        Server(endpoint) => Arc::new(HttpShardClient::new(endpoint)),
+        Server(endpoint) => Arc::new(HttpShardClient::new(endpoint, shard_storage_config.auth.token.clone())),
        FileSystem(path) => Arc::new(LocalShardClient::new(path).await?),
     };
 
diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock
index b5cb214f..1cefd30e 100644
--- a/hf_xet/Cargo.lock
+++ b/hf_xet/Cargo.lock
@@ -346,9 +346,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
 
 [[package]]
 name = "bytes"
-version = "1.7.1"
+version = "1.7.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
+checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
 
 [[package]]
 name = "bytestream"
@@ -440,7 +440,10 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "bincode",
+ "bytes",
+ "cas_types",
  "http 1.1.0",
+ "lz4_flex",
  "merklehash",
  "tempfile",
  "tracing",
@@ -1910,6 +1913,15 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "lz4_flex"
+version = "0.11.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5"
+dependencies = [
+ "twox-hash",
+]
+
 [[package]]
 name = "matchers"
 version = "0.1.0"
@@ -4262,6 +4274,16 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
+[[package]]
+name = "twox-hash"
+version = "1.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
+dependencies = [
+ "cfg-if 1.0.0",
"static_assertions", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/hf_xet/src/config.rs b/hf_xet/src/config.rs index e8ac6c42..b9cec5b5 100644 --- a/hf_xet/src/config.rs +++ b/hf_xet/src/config.rs @@ -5,7 +5,7 @@ use data::{DEFAULT_BLOCK_SIZE, errors}; pub const SMALL_FILE_THRESHOLD: usize = 1; -pub fn default_config(endpoint: String) -> errors::Result { +pub fn default_config(endpoint: String, token: Option) -> errors::Result { let path = current_dir()?.join(".xet"); fs::create_dir_all(&path)?; @@ -14,8 +14,7 @@ pub fn default_config(endpoint: String) -> errors::Result { cas_storage_config: StorageConfig { endpoint: Endpoint::Server(endpoint.clone()), auth: Auth { - user_id: "".into(), - login_id: "".into(), + token: token.clone(), }, prefix: "default".into(), cache_config: Some(CacheConfig { @@ -28,8 +27,7 @@ pub fn default_config(endpoint: String) -> errors::Result { shard_storage_config: StorageConfig { endpoint: Endpoint::Server(endpoint), auth: Auth { - user_id: "".into(), - login_id: "".into(), + token: token, }, prefix: "default-merkledb".into(), cache_config: Some(CacheConfig { diff --git a/hf_xet/src/data_client.rs b/hf_xet/src/data_client.rs index b6cc1957..e10fe857 100644 --- a/hf_xet/src/data_client.rs +++ b/hf_xet/src/data_client.rs @@ -12,16 +12,17 @@ use crate::config::default_config; pub const MAX_CONCURRENT_UPLOADS: usize = 8; // TODO pub const MAX_CONCURRENT_DOWNLOADS: usize = 8; // TODO -const DEFAULT_CAS_ENDPOINT: &str = "https://cas-server.us.dev.moon.huggingface.tech"; +const DEFAULT_CAS_ENDPOINT: &str = "http://localhost:8080"; const READ_BLOCK_SIZE: usize = 1024 * 1024; -pub async fn upload_async(file_paths: Vec) -> errors::Result> { +pub async fn upload_async(file_paths: Vec, endpoint: Option, token: Option) -> errors::Result> { // chunk files // produce Xorbs + Shards // upload shards and xorbs // for each file, return the filehash + let endpoint = endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()); - let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?; + let config = default_config(endpoint, token)?; let processor = Arc::new(PointerFileTranslator::new(config).await?); let processor = &processor; // for all files, clean them, producing pointer files. 
@@ -45,16 +46,20 @@ pub async fn upload_async(file_paths: Vec<String>) -> errors::Result<Vec<PointerFile>> {
 }
 
-pub async fn download_async(pointer_files: Vec<PointerFile>) -> errors::Result<Vec<String>> {
-    let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
+pub async fn download_async(pointer_files: Vec<PointerFile>, endpoint: Option<String>, token: Option<String>) -> errors::Result<Vec<String>> {
+    let config = default_config(endpoint.clone().unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), token.clone())?;
     let processor = Arc::new(PointerFileTranslator::new(config).await?);
     let processor = &processor;
 
     let paths = tokio_par_for_each(
         pointer_files,
         MAX_CONCURRENT_DOWNLOADS,
-        |pointer_file, _| async move {
-            let proc = processor.clone();
-            smudge_file(&proc, &pointer_file).await
+        |pointer_file, _| {
+            let tok = token.clone();
+            let end = endpoint.clone();
+            async move {
+                let proc = processor.clone();
+                smudge_file(&proc, &pointer_file, end.clone(), tok.clone()).await
+            }
         },
     ).await.map_err(|e| match e {
         ParallelError::JoinError => {
@@ -87,13 +92,13 @@ async fn clean_file(processor: &PointerFileTranslator, f: String) -> errors::Result<PointerFile> {
     Ok(pf)
 }
 
-async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result<String> {
+async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile, endpoint: Option<String>, token: Option<String>) -> errors::Result<String> {
     let path = PathBuf::from(pointer_file.path());
     if let Some(parent_dir) = path.parent() {
         fs::create_dir_all(parent_dir)?;
     }
     let mut f = File::create(&path)?;
-    proc.smudge_file_from_pointer(&pointer_file, &mut f, None).await?;
+    proc.smudge_file_from_pointer(&pointer_file, &mut f, None, endpoint, token).await?;
     Ok(pointer_file.path().to_string())
 }
 
@@ -122,7 +127,7 @@ mod tests {
         let pointers = vec![
             PointerFile::init_from_info("/tmp/foo.rs", "6999733a46030e67f6f020651c91442ace735572458573df599106e54646867c", 4203),
         ];
-        let paths = download_async(pointers).await.unwrap();
+        let paths = download_async(pointers, Some("http://localhost:8080".to_string()), Some("12345".to_string())).await.unwrap();
         println!("paths: {paths:?}");
     }
 }
diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs
index c28d09cb..8e3dc2b4 100644
--- a/hf_xet/src/lib.rs
+++ b/hf_xet/src/lib.rs
@@ -7,13 +7,13 @@ use pyo3::prelude::*;
 use data::PointerFile;
 
 #[pyfunction]
-#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyPointerFile]")]
-pub fn upload_files(file_paths: Vec<String>) -> PyResult<Vec<PyPointerFile>> {
+#[pyo3(signature = (file_paths, endpoint, token), text_signature = "(file_paths: List[str], endpoint: Optional[str], token: Optional[str]) -> List[PyPointerFile]")]
+pub fn upload_files(file_paths: Vec<String>, endpoint: Option<String>, token: Option<String>) -> PyResult<Vec<PyPointerFile>> {
     Ok(tokio::runtime::Builder::new_multi_thread()
         .enable_all()
         .build()?
         .block_on(async {
-            data_client::upload_async(file_paths).await
+            data_client::upload_async(file_paths, endpoint, token).await
         }).map_err(|e| PyException::new_err(format!("{e:?}")))?
         .into_iter()
         .map(PyPointerFile::from)
@@ -21,15 +21,15 @@ pub fn upload_files(file_paths: Vec<String>) -> PyResult<Vec<PyPointerFile>> {
 }
 
 #[pyfunction]
-#[pyo3(signature = (files), text_signature = "(files: List[PyPointerFile]) -> List[str]")]
-pub fn download_files(files: Vec<PyPointerFile>) -> PyResult<Vec<String>> {
+#[pyo3(signature = (files, endpoint, token), text_signature = "(files: List[PyPointerFile], endpoint: Optional[str], token: Optional[str]) -> List[str]")]
+pub fn download_files(files: Vec<PyPointerFile>, endpoint: Option<String>, token: Option<String>) -> PyResult<Vec<String>> {
     let pfs = files.into_iter().map(PointerFile::from)
         .collect();
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
         .build()?
         .block_on(async move {
-            data_client::download_async(pfs).await
+            data_client::download_async(pfs, endpoint, token).await
         }).map_err(|e| PyException::new_err(format!("{e:?}")))
 }
 
diff --git a/shard_client/src/http_shard_client.rs b/shard_client/src/http_shard_client.rs
index ff4fbf0b..f1bfe5bb 100644
--- a/shard_client/src/http_shard_client.rs
+++ b/shard_client/src/http_shard_client.rs
@@ -11,7 +11,7 @@
 use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo};
 use mdb_shard::shard_dedup_probe::ShardDedupProber;
 use mdb_shard::shard_file_reconstructor::FileReconstructor;
 use merklehash::MerkleHash;
-use reqwest::Url;
+use reqwest::{Url, header::{HeaderMap, HeaderValue}};
 use retry_strategy::RetryStrategy;
 use tracing::warn;
@@ -22,14 +22,16 @@ const BASE_RETRY_DELAY_MS: u64 = 3000;
 #[derive(Debug)]
 pub struct HttpShardClient {
     pub endpoint: String,
+    pub token: Option<String>,
     client: reqwest::Client,
     retry_strategy: RetryStrategy,
 }
 
 impl HttpShardClient {
-    pub fn new(endpoint: &str) -> Self {
+    pub fn new(endpoint: &str, token: Option<String>) -> Self {
         HttpShardClient {
             endpoint: endpoint.into(),
+            token,
             client: reqwest::Client::builder().build().unwrap(),
             // Retry policy: Exponential backoff starting at BASE_RETRY_DELAY_MS and retrying NUM_RETRIES times
             retry_strategy: RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS),
@@ -69,14 +71,23 @@ impl RegistrationClient for HttpShardClient {
         };
 
         let url = Url::parse(&format!("{}/shard/{key}", self.endpoint))?;
+
+        let mut headers = HeaderMap::new();
+        if let Some(tok) = &self.token {
+            headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap());
+        }
+
         let response = self
             .retry_strategy
             .retry(
-                || async {
-                    let url = url.clone();
-                    match force_sync {
-                        true => self.client.put(url).body(shard_data.to_vec()).send().await,
-                        false => self.client.post(url).body(shard_data.to_vec()).send().await,
+                || {
+                    let headers = headers.clone();
+                    async {
+                        let url = url.clone();
+                        match force_sync {
+                            true => self.client.put(url).headers(headers).body(shard_data.to_vec()).send().await,
+                            false => self.client.post(url).headers(headers).body(shard_data.to_vec()).send().await,
+                        }
                     }
                 },
                 is_status_retriable_and_print,
@@ -199,7 +210,7 @@ mod test {
     #[tokio::test]
     #[ignore = "need a local cas_server running"]
     async fn test_local() -> anyhow::Result<()> {
-        let client = HttpShardClient::new("http://localhost:8080");
+        let client = HttpShardClient::new("http://localhost:8080", None);
 
         let path = PathBuf::from("./a7de567477348b23d23b667dba4d63d533c2ba7337cdc4297970bb494ba4699e.mdb");