From 80a94c34783a7661cac7dbc751031ab56aeb4925 Mon Sep 17 00:00:00 2001 From: seanses Date: Fri, 27 Sep 2024 13:14:09 -0700 Subject: [PATCH 01/19] checkpoint --- Cargo.lock | 2 +- cas_client/Cargo.toml | 2 +- cas_client/src/caching_client.rs | 324 +--------- cas_client/src/cas_connection_pool.rs | 454 -------------- cas_client/src/client_adapter.rs | 33 - cas_client/src/data_transport.rs | 606 ------------------- cas_client/src/error.rs | 38 +- cas_client/src/interface.rs | 67 +- cas_client/src/lib.rs | 14 +- cas_client/src/local_client.rs | 193 +++--- cas_client/src/passthrough_staging_client.rs | 147 ----- cas_client/src/remote_client.rs | 207 ++----- cas_client/src/staging_client.rs | 584 ------------------ cas_client/src/staging_trait.rs | 58 -- cas_client/src/util.rs | 235 ------- mdb_shard/src/file_structs.rs | 38 +- mdb_shard/src/shard_file_manager.rs | 20 +- mdb_shard/src/shard_format.rs | 40 +- mdb_shard/src/shard_in_memory.rs | 17 +- 19 files changed, 277 insertions(+), 2802 deletions(-) delete mode 100644 cas_client/src/cas_connection_pool.rs delete mode 100644 cas_client/src/client_adapter.rs delete mode 100644 cas_client/src/data_transport.rs delete mode 100644 cas_client/src/passthrough_staging_client.rs delete mode 100644 cas_client/src/staging_client.rs delete mode 100644 cas_client/src/staging_trait.rs delete mode 100644 cas_client/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index 9b06f5a2..a0a451f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -434,6 +434,7 @@ dependencies = [ "clap 2.34.0", "deadpool", "error_printer", + "file_utils", "futures", "http 0.2.12", "http-body-util", @@ -449,7 +450,6 @@ dependencies = [ "opentelemetry-http", "opentelemetry-jaeger", "parutils", - "progress_reporting", "prost", "rand 0.8.5", "rcgen", diff --git a/cas_client/Cargo.toml b/cas_client/Cargo.toml index a570006f..220cde1c 100644 --- a/cas_client/Cargo.toml +++ b/cas_client/Cargo.toml @@ -11,12 +11,12 @@ strict = [] [dependencies] cas_object = {path = "../cas_object"} error_printer = {path = "../error_printer"} +file_utils = {path = "../file_utils"} utils = {path = "../utils"} merkledb = {path = "../merkledb"} merklehash = { path = "../merklehash" } parutils = {path = "../parutils"} retry_strategy = {path = "../retry_strategy"} -progress_reporting = {path = "../progress_reporting"} tonic = {version = "0.10.2", features = ["tls", "tls-roots", "transport"] } prost = "0.12.3" tokio = { version = "1.36", features = ["full"] } diff --git a/cas_client/src/caching_client.rs b/cas_client/src/caching_client.rs index 6aa7eec1..62085d89 100644 --- a/cas_client/src/caching_client.rs +++ b/cas_client/src/caching_client.rs @@ -1,74 +1,14 @@ use crate::error::Result; -use crate::interface::Client; -use crate::{client_adapter::ClientRemoteAdapter, error::CasClientError}; +use crate::interface::*; use async_trait::async_trait; -use cache::{Remote, XorbCache}; -use cas::key::Key; -use error_printer::ErrorPrinter; use merklehash::MerkleHash; -use std::collections::HashMap; -use std::fmt::Debug; -use std::ops::Range; -use std::path::Path; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::{debug, info}; - -pub const DEFAULT_BLOCK_SIZE: u64 = 16 * 1024 * 1024; +use std::io::Write; #[derive(Debug)] -pub struct CachingClient { - client: Arc, - cache: Arc, - xorb_lengths: Arc>>, -} - -impl CachingClient { - /// Create a new caching client. 
- /// client: This is the client object used to satisfy requests - pub fn new( - client: T, - cache_path: &Path, - capacity_bytes: u64, - block_size: u64, - ) -> Result> { - // convert Path to String - let canonical_path = cache_path.canonicalize().map_err(|e| { - CasClientError::ConfigurationError(format!("Error specifying cache path: {e}")) - })?; - - let canonical_string_path = canonical_path.to_str().ok_or_else(|| { - CasClientError::ConfigurationError("Error parsing cache path to UTF-8 path.".to_owned()) - })?; - - let arcclient = Arc::new(client); - let client_remote_arc: Arc = - Arc::new(ClientRemoteAdapter::new(arcclient.clone())); - - let cache = cache::from_config::( - cache::CacheConfig { - cache_dir: canonical_string_path.to_string(), - capacity: capacity_bytes, - block_size, - }, - client_remote_arc, - )?; - - info!( - "Creating CachingClient, path={:?}, byte capacity={}, blocksize={:?}", - cache_path, capacity_bytes, block_size - ); - - Ok(CachingClient { - client: arcclient, - cache, - xorb_lengths: Arc::new(Mutex::new(HashMap::new())), - }) - } -} +pub struct CachingClient {} #[async_trait] -impl Client for CachingClient { +impl UploadClient for CachingClient { async fn put( &self, prefix: &str, @@ -76,253 +16,31 @@ impl Client for CachingClient { data: Vec, chunk_boundaries: Vec, ) -> Result<()> { - // puts write through - debug!( - "CachingClient put to {}/{} of length {} bytes", - prefix, - hash, - data.len() - ); - Ok(self - .client - .put(prefix, hash, data, chunk_boundaries) - .await?) + todo!() } - async fn flush(&self) -> Result<()> { - // forward flush to the underlying client - Ok(self.client.flush().await?) + async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { + todo!() } - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - // get the length, reduce to range read of the entire length. 
- debug!("CachingClient Get of {}/{}", prefix, hash); - let xorb_size = self - .get_length(prefix, hash) - .await - .debug_error("CachingClient Get: get_length reported error")?; - - debug!("CachingClient Get: get_length call succeeded with value {xorb_size}."); - - self.get_object_range(prefix, hash, vec![(0, xorb_size)]) - .await - .map(|mut v| v.swap_remove(0)) - } - - async fn get_object_range( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - debug!( - "CachingClient GetObjectRange of {}/{}: {:?}", - prefix, hash, ranges - ); - let mut ret: Vec> = Vec::new(); - for (start, end) in ranges { - let prefix_str = prefix.to_string(); - ret.push( - self.cache - .fetch_xorb_range( - &Key { - prefix: prefix_str, - hash: *hash, - }, - Range { start, end }, - None, - ) - .await - .warn_error(format!( - "CachingClient Error on GetObjectRange of {}/{}", - prefix, hash - ))?, - ) - } - Ok(ret) - } - - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - debug!("CachingClient GetLength of {}/{}", prefix, hash); - { - // check the xorb length cache - let xorb_lengths = self.xorb_lengths.lock().await; - if let Some(l) = xorb_lengths.get(hash) { - return Ok(*l); - } - // release lock here since get_length may take a while - } - let ret = self.client.get_length(prefix, hash).await; - - if let Ok(l) = ret { - // insert it into the xorb length cache - let mut xorb_lengths = self.xorb_lengths.lock().await; - xorb_lengths.insert(*hash, l); - } - ret + async fn flush(&self) -> Result<()> { + todo!() } } -#[cfg(test)] -mod tests { - use super::DEFAULT_BLOCK_SIZE; - use crate::*; - use std::fs; - use std::path::Path; - use std::sync::Arc; - use tempfile::TempDir; - - fn path_has_files(path: &Path) -> bool { - fs::read_dir(path).unwrap().count() > 0 - } - - #[tokio::test] - async fn test_basic_read_write() { - let client = Arc::new(LocalClient::default()); - let cachedir = TempDir::new().unwrap(); - assert!(!path_has_files(cachedir.path())); - - let client = CachingClient::new(client, cachedir.path(), 100, DEFAULT_BLOCK_SIZE).unwrap(); - - // the root hash of a single chunk is just the hash of the data - let hello = "hello world".as_bytes().to_vec(); - let hello_hash = merklehash::compute_data_hash(&hello[..]); - // write "hello world" - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - - // get length "hello world" - assert_eq!(11, client.get_length("key", &hello_hash).await.unwrap()); - - // read "hello world" - assert_eq!(hello, client.get("key", &hello_hash).await.unwrap()); - - // read range "hello" and "world" - let ranges_to_read: Vec<(u64, u64)> = vec![(0, 5), (6, 11)]; - let expected: Vec> = vec!["hello".as_bytes().to_vec(), "world".as_bytes().to_vec()]; - assert_eq!( - expected, - client - .get_object_range("key", &hello_hash, ranges_to_read) - .await - .unwrap() - ); - // read range "hello" and "world", with truncation for larger offsets - let ranges_to_read: Vec<(u64, u64)> = vec![(0, 5), (6, 20)]; - let expected: Vec> = vec!["hello".as_bytes().to_vec(), "world".as_bytes().to_vec()]; - assert_eq!( - expected, - client - .get_object_range("key", &hello_hash, ranges_to_read) - .await - .unwrap() - ); - // empty read - let ranges_to_read: Vec<(u64, u64)> = vec![(0, 5), (6, 6)]; - let expected: Vec> = vec!["hello".as_bytes().to_vec(), "".as_bytes().to_vec()]; - assert_eq!( - expected, - client - .get_object_range("key", &hello_hash, ranges_to_read) - .await - .unwrap() - ); - 
assert!(path_has_files(cachedir.path())); +#[async_trait] +impl ReconstructionClient for CachingClient { + async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { + todo!() } - #[tokio::test] - async fn test_failures() { - let client = Arc::new(LocalClient::default()); - let cachedir = TempDir::new().unwrap(); - assert!(!path_has_files(cachedir.path())); - - let client = CachingClient::new(client, cachedir.path(), 100, DEFAULT_BLOCK_SIZE).unwrap(); - - let hello = "hello world".as_bytes().to_vec(); - let hello_hash = merklehash::compute_data_hash(&hello[..]); - // write "hello world" - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - // put the same value a second time. This should be ok. - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - - // put the different value with the same hash - // this should fail - assert_eq!( - CasClientError::CasObjectError(cas_object::error::CasObjectError::HashMismatch), - client - .put( - "hellp", - &hello_hash, - "hellp world".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); - // content shorter than the chunk boundaries should fail - assert_eq!( - CasClientError::InvalidArguments, - client - .put( - "key", - &hello_hash, - "hellp wod".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); - - // content longer than the chunk boundaries should fail - assert_eq!( - CasClientError::InvalidArguments, - client - .put( - "key", - &hello_hash, - "hello world again".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); - - // empty writes should fail - assert_eq!( - CasClientError::InvalidArguments, - client - .put("key", &hello_hash, vec![], vec![],) - .await - .unwrap_err() - ); - - // compute a hash of something we do not have in the store - let world = "world".as_bytes().to_vec(); - let world_hash = merklehash::compute_data_hash(&world[..]); - - // get length of non-existant object should fail with XORBNotFound - assert_eq!( - CasClientError::XORBNotFound(world_hash), - client.get_length("key", &world_hash).await.unwrap_err() - ); - - // read of non-existant object should fail with XORBNotFound - assert_eq!( - CasClientError::XORBNotFound(world_hash), - client.get("key", &world_hash).await.unwrap_err() - ); - // read range of non-existant object should fail with XORBNotFound - assert!(client - .get_object_range("key", &world_hash, vec![(0, 5)]) - .await - .is_err()); + async fn get_file_byte_range( + &self, + hash: &MerkleHash, + offset: u64, + length: u64, + writer: &mut Box, + ) -> Result<()> { + todo!() } } diff --git a/cas_client/src/cas_connection_pool.rs b/cas_client/src/cas_connection_pool.rs deleted file mode 100644 index 0d080fd2..00000000 --- a/cas_client/src/cas_connection_pool.rs +++ /dev/null @@ -1,454 +0,0 @@ -use crate::{error::Result, CasClientError}; -use async_trait::async_trait; -use deadpool::{ - managed::{self, Object, PoolConfig, PoolError, Timeouts}, - Runtime, -}; -use std::sync::Arc; -use std::time::Duration; -use std::time::Instant; -use std::{collections::HashMap, marker::PhantomData}; -use tokio::sync::RwLock; -use tokio_retry::strategy::{jitter, ExponentialBackoff}; -use tokio_retry::RetryIf; -use tracing::{debug, error, info}; -use xet_error::Error; - -#[non_exhaustive] -#[derive(Debug, Error)] -pub enum CasConnectionPoolError { - #[error("Invalid Range Read")] - InvalidRange, - #[error("Locking primitives error")] - 
LockCannotBeAcquired, - #[error("Connection pool acquisition error: {0}")] - ConnectionPoolAcquisition(String), - #[error("Connection pool creation error: {0}")] - ConnectionPoolCreation(String), -} - -// Magic number for the number of concurrent connections we -// want to support. -const POOL_SIZE: usize = 16; - -const CONNECTION_RETRIES_ON_TIMEOUT: usize = 3; -const CONNECT_TIMEOUT_MS: u64 = 20000; -const CONNECTION_RETRY_BACKOFF_MS: u64 = 10; -const ASYNC_RUNTIME: Runtime = Runtime::Tokio1; - -/// Container for information required to set up and handle -/// CAS connections (both gRPC and H2) -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CasConnectionConfig { - // this endpoint contains the scheme info (http/https) - // ideally we'd have the scheme separately. - pub endpoint: String, - pub user_id: String, - pub auth: String, - pub repo_paths: String, - pub git_xet_version: String, - pub root_ca: Option>, -} - -impl CasConnectionConfig { - /// creates a new CasConnectionConfig with given endpoint and user_id - pub fn new( - endpoint: String, - user_id: String, - auth: String, - repo_paths: Vec, - git_xet_version: String, - ) -> CasConnectionConfig { - CasConnectionConfig { - endpoint, - user_id, - auth, - repo_paths: serde_json::to_string(&repo_paths).unwrap_or_else(|_| "[]".to_string()), - git_xet_version, - root_ca: None, - } - } - - pub fn with_root_ca>(mut self, root_ca: T) -> Self { - self.root_ca = Some(Arc::new(root_ca.into())); - self - } -} - -/// to be impl'ed by Connection types (DataTransport, GrpcClient)so that -/// connection pool managers could instantiate them using CasConnectionConfig -#[async_trait] -pub trait FromConnectionConfig: Sized { - async fn new_from_connection_config(config: CasConnectionConfig) -> Result; -} - -#[derive(Debug)] -pub struct PoolManager -where - T: ?Sized, -{ - connection_type: PhantomData, - cas_connection_config: CasConnectionConfig, -} - -#[async_trait] -impl managed::Manager for PoolManager -where - T: FromConnectionConfig + Sync + Send, -{ - type Type = T; - type Error = CasClientError; - - async fn create(&self) -> std::result::Result { - // Currently recreating the GrpcClient itself. In my limited testing, - // this gets slightly better overall performance than cloning the prototype. - Ok(T::new_from_connection_config(self.cas_connection_config.clone()).await?) - } - - async fn recycle(&self, _conn: &mut Self::Type) -> managed::RecycleResult { - Ok(()) - } -} - -// A mapping between IP address and connection pool. Each IP maps to a -// CAS instance, and we keep a fixed pool with the data plane connections. -#[derive(Debug)] -pub struct ConnectionPoolMap -where - T: FromConnectionConfig + Send + Sync, -{ - pool_map: RwLock>>>>, - max_pool_size: usize, -} - -impl ConnectionPoolMap -where - T: FromConnectionConfig + Send + Sync, -{ - pub fn new() -> Self { - ConnectionPoolMap { - pool_map: RwLock::new(HashMap::default()), - max_pool_size: POOL_SIZE, - } - } - - pub fn new_with_pool_size(max_pool_size: usize) -> Self { - ConnectionPoolMap { - pool_map: RwLock::new(HashMap::default()), - max_pool_size, - } - } - - // Creates a connection pool for the given IP address. 
- async fn create_pool_for_endpoint_impl( - cas_connection_config: CasConnectionConfig, - max_pool_size: usize, - ) -> std::result::Result>, CasConnectionPoolError> { - let endpoint = cas_connection_config.endpoint.clone(); - - let mgr = PoolManager { - connection_type: PhantomData, - cas_connection_config, - }; - - info!("Creating pool for {endpoint}"); - - let pool = managed::Pool::>::builder(mgr) - .config(PoolConfig { - max_size: max_pool_size, - timeouts: Timeouts { - create: Some(Duration::from_millis(CONNECT_TIMEOUT_MS)), - wait: None, - recycle: Some(Duration::from_millis(0)), - }, - }) - .runtime(ASYNC_RUNTIME) - .build() - .map_err(|e| { - error!( - "Error creating connection pool: {:?} server: {}", - e, endpoint - ); - CasConnectionPoolError::ConnectionPoolCreation(format!("{e:?}")) - })?; - - Ok(pool) - } - - // // Gets a connection object for the given endpoint. This will - // // create a connection pool for the endpoint if none exists already. - pub async fn get_connection_for_config( - &self, - cas_connection_config: CasConnectionConfig, - ) -> std::result::Result>, CasConnectionPoolError> { - let strategy = ExponentialBackoff::from_millis(CONNECTION_RETRY_BACKOFF_MS) - .map(jitter) - .take(CONNECTION_RETRIES_ON_TIMEOUT); - - let endpoint = cas_connection_config.endpoint.clone(); - let pool = self.get_pool_for_config(cas_connection_config).await?; - let result = RetryIf::spawn( - strategy, - || async { - debug!("Trying to get connection for endpoint: {}", endpoint); - pool.get().await - }, - is_pool_connection_error_retriable, - ) - .await - .map_err(|e| { - error!( - "Error acquiring connection for {:?} from pool: {:?} after {} retries", - endpoint, e, CONNECTION_RETRIES_ON_TIMEOUT - ); - CasConnectionPoolError::ConnectionPoolCreation(format!("{e:?}")) - })?; - - Ok(result) - } - - // Utility function to get a connection pool for the given endpoint. If none already - // exists, it will create it and insert it into the map. - async fn get_pool_for_config( - &self, - cas_connection_config: CasConnectionConfig, - ) -> std::result::Result>>, CasConnectionPoolError> { - debug!("Using connection pool"); - - // handle the typical case up front where we are connecting to an - // endpoint that already has its pool initialized. Scopes are meant - // to keep the - { - let now = Instant::now(); - debug!("Acquiring connection map read lock"); - let map = self.pool_map.read().await; - debug!( - "Connection map read lock acquired in {} ms", - now.elapsed().as_millis() - ); - - if let Some(pool) = map.get(cas_connection_config.endpoint.as_str()) { - return Ok(pool.clone()); - }; - } - - let endpoint = cas_connection_config.endpoint.clone(); - // If the connection isn't in the map, create it and insert. - // At worst, we'll briefly have multiple pools overwriting the hashmap, but this - // is needed so as not to carry the lock across an await - let new_pool = Arc::new( - Self::create_pool_for_endpoint_impl(cas_connection_config, self.max_pool_size).await?, - ); - { - let now = Instant::now(); - debug!("Acquiring connection map write lock"); - let mut map = self.pool_map.write().await; - - debug!( - "Connection map write lock acquired in {} ms", - now.elapsed().as_millis() - ); - - map.insert(endpoint, new_pool.clone()); - } - Ok(new_pool) - } - - // Utility to see how many connections are avialable for the IP address. - // This currently requires a lock, so it's not a good idea to use this for - // testing if you should get the lock. 
- pub async fn get_pool_status_for_endpoint( - &self, - ip_address: String, - ) -> Option { - let map = self.pool_map.read().await; - - map.get(&ip_address).map(|s| s.status()) - } -} - -fn is_pool_connection_error_retriable(err: &PoolError) -> bool { - matches!(err, PoolError::Timeout(_)) -} - -impl Default for ConnectionPoolMap -where - T: FromConnectionConfig + Send + Sync, -{ - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use crate::cas_connection_pool::{ConnectionPoolMap, FromConnectionConfig}; - use crate::error::Result; - use async_trait::async_trait; - - use super::CasConnectionConfig; - - const USER_ID: &str = "XET USER"; - const AUTH: &str = "XET SECRET"; - static REPO_PATHS: &str = "/XET/REPO"; - const GIT_XET_VERSION: &str = "0.1.0"; - - #[derive(Debug, Clone)] - struct PoolTestData { - cas_connection_config: CasConnectionConfig, - } - - #[async_trait] - impl FromConnectionConfig for PoolTestData { - async fn new_from_connection_config( - cas_connection_config: CasConnectionConfig, - ) -> Result { - Ok(PoolTestData { - cas_connection_config, - }) - } - } - - #[tokio::test] - async fn test_get_creates_pool() { - let cas_connection_pool = ConnectionPoolMap::::new(); - - let server1 = "foo".to_string(); - let server2 = "bar".to_string(); - - let config1 = CasConnectionConfig::new( - server1.clone(), - USER_ID.to_string(), - AUTH.to_string(), - vec![REPO_PATHS.to_string()], - GIT_XET_VERSION.to_string(), - ); - let config2 = CasConnectionConfig::new( - server2.clone(), - USER_ID.to_string(), - AUTH.to_string(), - vec![REPO_PATHS.to_string()], - GIT_XET_VERSION.to_string(), - ); - - let conn0 = cas_connection_pool - .get_connection_for_config(config1.clone()) - .await - .unwrap(); - assert_eq!(conn0.cas_connection_config, config1.clone()); - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server1.clone()) - .await - .unwrap(); - assert_eq!(stat.size, 1); - assert_eq!(stat.available, 0); - - let conn1 = cas_connection_pool - .get_connection_for_config(config1.clone()) - .await - .unwrap(); - assert_eq!(conn1.cas_connection_config, config1.clone()); - - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server1.clone()) - .await - .unwrap(); - assert_eq!(stat.size, 2); - assert_eq!(stat.available, 0); - - let conn2 = cas_connection_pool - .get_connection_for_config(config1.clone()) - .await - .unwrap(); - assert_eq!(conn2.cas_connection_config, config1.clone()); - - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server1.clone()) - .await - .unwrap(); - assert_eq!(stat.size, 3); - assert_eq!(stat.available, 0); - - let conn3 = cas_connection_pool - .get_connection_for_config(config1.clone()) - .await - .unwrap(); - assert_eq!(conn3.cas_connection_config, config1.clone()); - - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server1.clone()) - .await - .unwrap(); - assert_eq!(stat.size, 4); - assert_eq!(stat.available, 0); - - drop(conn0); - - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server1.clone()) - .await - .unwrap(); - - println!("{}, {}", stat.size, stat.available); - assert_eq!(stat.size, 4); - assert_eq!(stat.available, 1); - - drop(conn1); - - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server1.clone()) - .await - .unwrap(); - - println!("{}, {}", stat.size, stat.available); - assert_eq!(stat.size, 4); - assert_eq!(stat.available, 2); - - // ensure there's no cross pollination between different server strings - let conn0 = cas_connection_pool - 
.get_connection_for_config(config2.clone()) - .await - .unwrap(); - assert_eq!(conn0.cas_connection_config, config2.clone()); - let stat = cas_connection_pool - .get_pool_status_for_endpoint(server2.clone()) - .await - .unwrap(); - assert_eq!(stat.size, 1); - assert_eq!(stat.available, 0); - - // TODO: Add more tests here - } - - #[tokio::test] - async fn test_repo_name_encoding() { - let data: Vec> = vec![ - vec!["user1/repo-😀".to_string(), "user1/répô_123".to_string()], - vec![ - "user2/👾_repo".to_string(), - "user2/üникод".to_string(), - "user2/foobar!@#".to_string(), - ], - vec!["user3/sømè_repo".to_string(), "user3/你好-世界".to_string()], - vec!["user4/✨🌈repo".to_string()], - vec!["user5/Ω≈ç√repo".to_string()], - vec!["user6/42°_repo".to_string()], - vec![ - "user7/äëïöü_repo".to_string(), - "user7/ĀāĒēĪīŌōŪū".to_string(), - ], - ]; - for inner_vec in data { - let config = CasConnectionConfig::new( - "".to_string(), - "".to_string(), - "".to_string(), - inner_vec.clone(), - "".to_string(), - ); - let vec_of_strings: Vec = serde_json::from_str(config.repo_paths.as_ref()) - .expect("Failed to deserialize JSON"); - assert_eq!(vec_of_strings.clone(), inner_vec.clone()); - } - } -} diff --git a/cas_client/src/client_adapter.rs b/cas_client/src/client_adapter.rs deleted file mode 100644 index e1731e50..00000000 --- a/cas_client/src/client_adapter.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::interface::Client; -use async_trait::async_trait; -use cache::Remote; -use cas::key::Key; -use std::fmt::Debug; -use std::ops::Range; - -#[derive(Debug)] -pub struct ClientRemoteAdapter { - client: T, -} -impl ClientRemoteAdapter { - pub fn new(client: T) -> ClientRemoteAdapter { - ClientRemoteAdapter { client } - } -} - -#[async_trait] -impl Remote for ClientRemoteAdapter { - /// Fetches the provided range from the backing storage, returning the contents - /// if they are present. - async fn fetch( - &self, - key: &Key, - range: Range, - ) -> std::result::Result, anyhow::Error> { - Ok(self - .client - .get_object_range(&key.prefix, &key.hash, vec![(range.start, range.end)]) - .await - .map(|mut v| v.swap_remove(0))?) 
- } -} diff --git a/cas_client/src/data_transport.rs b/cas_client/src/data_transport.rs deleted file mode 100644 index 32977878..00000000 --- a/cas_client/src/data_transport.rs +++ /dev/null @@ -1,606 +0,0 @@ -use cas::constants::*; -use std::str::FromStr; -use std::time::Duration; - -use crate::cas_connection_pool::CasConnectionConfig; -use anyhow::{anyhow, Result}; -use cas_object::CompressionScheme; - -use error_printer::ErrorPrinter; -use http_body_util::{BodyExt, Full}; -use hyper::body::Bytes; -use hyper::{ - header::RANGE, - header::{HeaderMap, HeaderName, HeaderValue}, - Method, Request, Response, Version, -}; -use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; -use hyper_util::client::legacy::connect::HttpConnector; -use hyper_util::client::legacy::Client; -use hyper_util::rt::{TokioExecutor, TokioTimer}; -use lazy_static::lazy_static; -use lz4::block::CompressionMode; -use opentelemetry::propagation::Injector; -use retry_strategy::RetryStrategy; -use rustls_pemfile::Item; -use tokio_rustls::rustls; -use tokio_rustls::rustls::pki_types::CertificateDer; -use tracing::{debug, error, info_span, warn, Instrument}; -use xet_error::Error; - -use merklehash::MerkleHash; - -const CAS_CONTENT_ENCODING_HEADER: &str = "xet-cas-content-encoding"; -const CAS_ACCEPT_ENCODING_HEADER: &str = "xet-cas-content-encoding"; -const CAS_INFLATED_SIZE_HEADER: &str = "xet-cas-inflated-size"; - -const HTTP2_POOL_IDLE_TIMEOUT_SECS: u64 = 30; -const HTTP2_KEEPALIVE_MILLIS: u64 = 500; -const HTTP2_WINDOW_SIZE: u32 = 2147418112; -const NUM_RETRIES: usize = 5; -const BASE_RETRY_DELAY_MS: u64 = 3000; - -// in the header value, we will consider -fn multiple_accepted_encoding_header_value(list: Vec) -> String { - let as_strs: Vec<&str> = list.iter().map(Into::into).collect(); - as_strs.join(";").to_string() -} - -lazy_static! { - static ref ACCEPTED_ENCODINGS_HEADER_VALUE: HeaderValue = HeaderValue::from_str( - multiple_accepted_encoding_header_value(vec![ - CompressionScheme::LZ4, - CompressionScheme::None - ]) - .as_str() - ) - .unwrap_or_else(|_| HeaderValue::from_static("")); -} - -pub struct DataTransport { - http2_client: Client, Full>, - retry_strategy: RetryStrategy, - cas_connection_config: CasConnectionConfig, -} - -impl std::fmt::Debug for DataTransport { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DataTransport") - .field("authority", &self.authority()) - .finish() - } -} - -/// This struct is used to wrap the error types which we may -/// retry on. Request Errors (which are triggered if there was a -/// header error building the request) are not retryable. -/// Right now this retries every h2 error. Reading these: -/// - https://docs.rs/h2/latest/h2/struct.Error.html, -/// - https://docs.rs/h2/latest/h2/struct.Reason.html -/// unclear if there is any reason not to retry. 
-#[derive(Error, Debug)] -enum RetryError { - #[error("{0}")] - Hyper(#[from] hyper::Error), - - #[error("{0}")] - HyperLegacy(#[from] hyper_util::client::legacy::Error), - - #[error("Request Error: {0}")] - Request(#[from] anyhow::Error), - - /// Should only be used for non-success errors - #[error("Status Error: {0}")] - Status(hyper::StatusCode), -} -fn is_status_retriable(err: &RetryError) -> bool { - match err { - RetryError::Hyper(_) => true, - RetryError::HyperLegacy(_) => true, - RetryError::Request(_) => false, - RetryError::Status(n) => retry_http_status_code(n), - } -} -fn retry_http_status_code(stat: &hyper::StatusCode) -> bool { - stat.is_server_error() || *stat == hyper::StatusCode::TOO_MANY_REQUESTS -} - -fn is_status_retriable_and_print(err: &RetryError) -> bool { - let ret = is_status_retriable(err); - if ret { - debug!("{}. Retrying...", err); - } - ret -} -fn print_final_retry_error(err: RetryError) -> RetryError { - if is_status_retriable(&err) { - warn!("Many failures {}", err); - } - err -} - -impl DataTransport { - pub fn new( - http2_client: Client, Full>, - retry_strategy: RetryStrategy, - cas_connection_config: CasConnectionConfig, - ) -> Self { - Self { - http2_client, - retry_strategy, - cas_connection_config, - } - } - - /// creates the DataTransport instance for the H2 connection using - /// CasConnectionConfig info, and additional port - pub async fn from_config(cas_connection_config: CasConnectionConfig) -> Result { - debug!( - "Attempting to make HTTP connection with {}", - cas_connection_config.endpoint - ); - let mut builder = Client::builder(TokioExecutor::new()); - builder - .timer(TokioTimer::new()) - .pool_idle_timeout(Duration::from_secs(HTTP2_POOL_IDLE_TIMEOUT_SECS)) - .http2_keep_alive_interval(Duration::from_millis(HTTP2_KEEPALIVE_MILLIS)) - .http2_initial_connection_window_size(HTTP2_WINDOW_SIZE) - .http2_initial_stream_window_size(HTTP2_WINDOW_SIZE) - .http2_only(true); - let root_ca = cas_connection_config - .root_ca - .clone() - .ok_or_else(|| anyhow!("missing server certificate"))?; - let cert = try_from_pem(root_ca.as_bytes())?; - let mut root_store = rustls::RootCertStore::empty(); - root_store.add(cert)?; - let config = rustls::ClientConfig::builder() - // add the CAS certificate to the client's root store - // client does not need to assume identity for authentication - .with_root_certificates(root_store) - .with_no_client_auth(); - - let connector = HttpsConnectorBuilder::new() - .with_tls_config(config) - .https_only() - .enable_http2() - .build(); - let h2_client = builder.build(connector); - let retry_strategy = RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS); - Ok(Self::new(h2_client, retry_strategy, cas_connection_config)) - } - - fn authority(&self) -> &str { - self.cas_connection_config.endpoint.as_str() - } - - fn get_uri(&self, prefix: &str, hash: &MerkleHash) -> String { - let cas_key_string = cas::key::Key { - prefix: prefix.to_string(), - hash: *hash, - } - .to_string(); - if cas_key_string.starts_with('/') { - format!("{}{}", self.authority(), cas_key_string) - } else { - format!("{}/{}", self.authority(), cas_key_string) - } - } - - fn setup_request( - &self, - method: Method, - prefix: &str, - hash: &MerkleHash, - body: Option>, - ) -> Result>> { - let dest = self.get_uri(prefix, hash); - debug!("Calling {} with address: {}", method, dest); - let user_id = self.cas_connection_config.user_id.clone(); - let auth = self.cas_connection_config.auth.clone(); - // let request_id = get_request_id(); - let repo_paths = 
self.cas_connection_config.repo_paths.clone(); - let git_xet_version = self.cas_connection_config.git_xet_version.clone(); - // let cas_protocol_version = CAS_PROTOCOL_VERSION.clone(); - - let mut req = Request::builder() - .method(method.clone()) - .header(USER_ID_HEADER, user_id) - .header(AUTH_HEADER, auth) - //.header(REQUEST_ID_HEADER, request_id) - .header(REPO_PATHS_HEADER, repo_paths) - .header(GIT_XET_VERSION_HEADER, git_xet_version) - //.header(CAS_PROTOCOL_VERSION_HEADER, cas_protocol_version) - .uri(&dest) - .version(Version::HTTP_2); - - if method == Method::GET { - req = req.header( - CAS_ACCEPT_ENCODING_HEADER, - ACCEPTED_ENCODINGS_HEADER_VALUE.clone(), - ); - } - - /* - if trace_forwarding() { - if let Some(headers) = req.headers_mut() { - let mut injector = HeaderInjector(headers); - let propagator = opentelemetry_jaeger::Propagator::new(); - let cur_span = Span::current(); - let ctx = cur_span.context(); - propagator.inject_context(&ctx, &mut injector); - } - } - */ - - let bytes = match body { - None => Bytes::new(), - Some(data) => Bytes::from(data), - }; - req.body(Full::new(bytes)).map_err(|e| anyhow!(e)) - } - - // Single get to the H2 server - pub async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - let resp = self - .retry_strategy - .retry( - || async { - let req = self - .setup_request(Method::GET, prefix, hash, None) - .map_err(RetryError::from)?; - - let resp = self - .http2_client - .request(req) - .instrument(info_span!("transport.h2_get")) - .await - .map_err(|e| { - error!("{e}"); - RetryError::from(e) - })?; - - if retry_http_status_code(&resp.status()) { - return Err(RetryError::Status(resp.status())); - } - Ok(resp) - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error)?; - let status = resp.status(); - if status != hyper::StatusCode::OK { - return Err(anyhow!( - "data get status {} received for URL {}", - status, - self.get_uri(prefix, hash) - )); - } - debug!("Received Response from HTTP2 GET: {}", status); - let (encoding, uncompressed_size) = - get_encoding_info(&resp).unwrap_or((CompressionScheme::None, None)); - // Get the body - let bytes = resp - .collect() - .instrument(info_span!("transport.read_body")) - .await? 
- .to_bytes() - .to_vec(); - let payload_size = bytes.len(); - let bytes = maybe_decode(bytes.as_slice(), encoding, uncompressed_size)?; - debug!( - "GET; encoding: ({}), uncompressed size: ({}), payload ({}) prefix: ({}), hash: ({})", - encoding, - uncompressed_size.unwrap_or_default(), - payload_size, - prefix, - hash - ); - Ok(bytes) - } - - // Single get range to the H2 server - pub async fn get_range( - &self, - prefix: &str, - hash: &MerkleHash, - range: (u64, u64), - ) -> Result> { - let res = self - .retry_strategy - .retry( - || async { - let mut req = self - .setup_request(Method::GET, prefix, hash, None) - .map_err(RetryError::from)?; - let header_value = - HeaderValue::from_str(&format!("bytes={}-{}", range.0, range.1 - 1)) - .map_err(anyhow::Error::from) - .map_err(RetryError::from)?; - req.headers_mut().insert(RANGE, header_value); - - let resp = self - .http2_client - .request(req) - .instrument(info_span!("transport.h2_get_range")) - .await - .map_err(RetryError::from)?; - - if retry_http_status_code(&resp.status()) { - return Err(RetryError::Status(resp.status())); - } - - let status = resp.status(); - if status != hyper::StatusCode::OK { - return Err(RetryError::Request(anyhow!( - "data get_range status {} received for URL {} with range {:?}", - status, - self.get_uri(prefix, hash), - range - ))); - } - debug!("Received Response from HTTP2 GET range: {}", status); - let (encoding, uncompressed_size) = get_encoding_info(&resp).unwrap_or((CompressionScheme::None, None)); - // Get the body - let bytes: Vec = resp - .collect() - .instrument(info_span!("transport.read_body")) - .await? - .to_bytes() - .to_vec(); - let payload_size = bytes.len(); - let bytes = maybe_decode(bytes.as_slice(), encoding, uncompressed_size)?; - debug!("GET RANGE; encoding: ({}), uncompressed size: ({}), payload ({}) prefix: ({}), hash: ({})", encoding, uncompressed_size.unwrap_or_default(), payload_size, prefix, hash); - Ok(bytes.to_vec()) - }, - is_status_retriable_and_print, - ) - .await; - - res.map_err(print_final_retry_error) - .map_err(anyhow::Error::from) - } - - // Single put to the H2 server - pub async fn put( - &self, - prefix: &str, - hash: &MerkleHash, - data: &[u8], - encoding: CompressionScheme, - ) -> Result<()> { - let full_size = data.len(); - let data = maybe_encode(data, encoding)?; - debug!( - "PUT; encoding: ({}), uncompressed size: ({}), payload: ({}), prefix: ({}), hash: ({})", - encoding, - full_size, - data.len(), - prefix, - hash - ); - let resp = self - .retry_strategy - .retry( - || async { - // compression of data to be done here, for now none. 
- let mut req = self - .setup_request(Method::POST, prefix, hash, Some(data.clone())) - .map_err(RetryError::from)?; - let headers = req.headers_mut(); - headers.insert(CAS_INFLATED_SIZE_HEADER, HeaderValue::from(full_size)); - headers.insert( - CAS_CONTENT_ENCODING_HEADER, - HeaderValue::from_static(encoding.into()), - ); - - let resp = self - .http2_client - .request(req) - .instrument(info_span!("transport.h2_put")) - .await - .map_err(RetryError::from)?; - - if retry_http_status_code(&resp.status()) { - return Err(RetryError::Status(resp.status())); - } - Ok(resp) - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error)?; - let status = resp.status(); - if status != hyper::StatusCode::OK { - return Err(anyhow!( - "data put status {} received for URL {}", - status, - self.get_uri(prefix, hash), - )); - } - debug!("Received Response from HTTP2 POST: {}", status); - - Ok(()) - } -} - -fn maybe_decode<'a, T: Into<&'a [u8]>>( - bytes: T, - encoding: CompressionScheme, - uncompressed_size: Option, -) -> Result> { - if let CompressionScheme::LZ4 = encoding { - if uncompressed_size.is_none() { - return Err(anyhow!( - "Missing uncompressed size when attempting to decompress LZ4" - )); - } - return lz4::block::decompress(bytes.into(), uncompressed_size).map_err(|e| anyhow!(e)); - } - Ok(bytes.into().to_vec()) -} - -fn get_encoding_info(response: &Response) -> Option<(CompressionScheme, Option)> { - let headers = response.headers(); - let value = headers.get(CAS_CONTENT_ENCODING_HEADER)?; - let as_str = value.to_str().ok()?; - let compression_scheme = CompressionScheme::from_str(as_str).ok()?; - - let value = headers.get(CAS_INFLATED_SIZE_HEADER)?; - let as_str = value.to_str().ok()?; - let uncompressed_size: Option = as_str.parse().ok(); - Some((compression_scheme, uncompressed_size)) -} - -fn maybe_encode<'a, T: Into<&'a [u8]>>(data: T, encoding: CompressionScheme) -> Result> { - if let CompressionScheme::LZ4 = encoding { - lz4::block::compress(data.into(), Some(CompressionMode::DEFAULT), false) - .log_error("LZ4 compression error") - .map_err(|e| anyhow!(e)) - } else { - // None - Ok(data.into().to_vec()) - } -} - -fn try_from_pem(pem: &[u8]) -> Result { - let (item, _) = rustls_pemfile::read_one_from_slice(pem) - .map_err(|e| { - error!("pem error: {e:?}"); - // rustls_pemfile::Error does not impl std::error::Error - anyhow!("rustls_pemfile error {e:?}") - })? - .ok_or_else(|| anyhow!("failed to parse pem"))?; - match item { - Item::X509Certificate(cert) => Ok(cert), - _ => Err(anyhow!("invalid cert format")), - } -} - -pub struct HeaderInjector<'a>(pub &'a mut HeaderMap); - -impl<'a> Injector for HeaderInjector<'a> { - /// Set a key and value in the HeaderMap. Does nothing if the key or value are not valid inputs. - fn set(&mut self, key: &str, value: String) { - if let Ok(key_header) = HeaderName::try_from(key) { - if let Ok(header_value) = HeaderValue::from_str(&value) { - self.0.insert(key_header, header_value); - } - } - } -} - -#[cfg(test)] -mod tests { - use lazy_static::lazy_static; - use std::vec; - - use super::*; - - // cert to use for testing - lazy_static! 
{ - static ref CERT: rcgen::Certificate = rcgen::generate_simple_self_signed(vec![]).unwrap(); - } - - #[tokio::test] - async fn test_from_config() { - let endpoint = "http://localhost:443"; - let config = CasConnectionConfig { - endpoint: endpoint.to_string(), - user_id: "user".to_string(), - auth: "auth".to_string(), - repo_paths: "repo".to_string(), - git_xet_version: "0.1.0".to_string(), - root_ca: None, - } - .with_root_ca(CERT.serialize_pem().unwrap()); - let dt = DataTransport::from_config(config).await.unwrap(); - assert_eq!(dt.authority(), endpoint); - } - - #[tokio::test] - async fn repo_path_header_test() { - let data: Vec> = vec![ - vec!["user1/repo-😀".to_string(), "user1/répô_123".to_string()], - vec![ - "user2/👾_repo".to_string(), - "user2/üникод".to_string(), - "user2/foobar!@#".to_string(), - ], - vec!["user3/sømè_repo".to_string(), "user3/你好-世界".to_string()], - vec!["user4/✨🌈repo".to_string()], - vec!["user5/Ω≈ç√repo".to_string()], - vec!["user6/42°_repo".to_string()], - vec![ - "user7/äëïöü_repo".to_string(), - "user7/ĀāĒēĪīŌōŪū".to_string(), - ], - ]; - for inner_vec in data { - let config = CasConnectionConfig::new( - "".to_string(), - "".to_string(), - "".to_string(), - inner_vec.clone(), - "".to_string(), - ) - .with_root_ca(CERT.serialize_pem().unwrap()); - let client = DataTransport::from_config(config).await.unwrap(); - let hello = "hello".as_bytes().to_vec(); - let hash = merklehash::compute_data_hash(&hello[..]); - let req = client.setup_request(Method::GET, "", &hash, None).unwrap(); - let repo_paths = req.headers().get(REPO_PATHS_HEADER).unwrap(); - let repo_path_str = String::from_utf8(repo_paths.as_bytes().to_vec()).unwrap(); - let vec_of_strings: Vec = - serde_json::from_str(repo_path_str.as_str()).expect("Failed to deserialize JSON"); - assert_eq!(vec_of_strings, inner_vec); - } - } - - #[tokio::test] - async fn string_headers_test() { - let user_id = "XET USER"; - let auth = "XET AUTH"; - let git_xet_version = "0.1.0"; - - let cas_connection_config = CasConnectionConfig::new( - "".to_string(), - user_id.to_string(), - auth.to_string(), - vec![], - git_xet_version.to_string(), - ) - .with_root_ca(CERT.serialize_pem().unwrap()); - let client = DataTransport::from_config(cas_connection_config) - .await - .unwrap(); - let hash = merklehash::compute_data_hash("test".as_bytes()); - let req = client - .setup_request(Method::POST, "default", &hash, None) - .unwrap(); - let headers = req.headers(); - // gets header value assuming all well, panic if not - let get_header_value = |header: &str| headers.get(header).unwrap().to_str().unwrap(); - - // check against values in config - assert_eq!(get_header_value(GIT_XET_VERSION_HEADER), git_xet_version); - assert_eq!(get_header_value(USER_ID_HEADER), user_id); - } - - #[test] - fn test_multiple_accepted_encoding_header_value() { - let multi = vec![CompressionScheme::LZ4, CompressionScheme::None]; - assert_eq!( - multiple_accepted_encoding_header_value(multi), - "lz4;none".to_string() - ); - - let singular = vec![CompressionScheme::LZ4]; - assert_eq!( - multiple_accepted_encoding_header_value(singular), - "lz4".to_string() - ); - } - -} diff --git a/cas_client/src/error.rs b/cas_client/src/error.rs index 1d84847d..2904790a 100644 --- a/cas_client/src/error.rs +++ b/cas_client/src/error.rs @@ -1,35 +1,16 @@ use cache::CacheError; -use http::uri::InvalidUri; use merklehash::MerkleHash; -use tonic::metadata::errors::InvalidMetadataValue; use xet_error::Error; -use crate::cas_connection_pool::CasConnectionPoolError; - 
#[non_exhaustive] #[derive(Error, Debug)] pub enum CasClientError { - #[error("Tonic RPC error.")] - TonicError, - #[error("CAS Cache Error: {0}")] CacheError(#[from] CacheError), #[error("Configuration Error: {0} ")] ConfigurationError(String), - #[error("URL Parsing Error.")] - URLError(#[from] InvalidUri), - - #[error("Tonic Trasport Error")] - TonicTransportError(#[from] tonic::transport::Error), - - #[error("Metadata error: {0}")] - MetadataParsingError(#[from] InvalidMetadataValue), - - #[error("CAS Connection Pool Error")] - CasConnectionPoolError(#[from] CasConnectionPoolError), - #[error("Invalid Range Read")] InvalidRange, @@ -39,8 +20,8 @@ pub enum CasClientError { #[error("Hash Mismatch")] HashMismatch, - #[error("Internal IO Error: {0}")] - InternalIOError(#[from] std::io::Error), + #[error("IO Error: {0}")] + IOError(#[from] std::io::Error), #[error("Other Internal Error: {0}")] InternalError(anyhow::Error), @@ -48,21 +29,6 @@ pub enum CasClientError { #[error("CAS Hash not found")] XORBNotFound(MerkleHash), - #[error("Data transfer timeout")] - DataTransferTimeout, - - #[error("Client connection error {0}")] - Grpc(#[from] anyhow::Error), - - #[error("Batch Error: {0}")] - BatchError(String), - - #[error("Serialization Error: {0}")] - SerializationError(#[from] bincode::Error), - - #[error("Runtime Error (Temp files): {0}")] - RuntimeErrorTempFileError(#[from] tempfile::PersistError), - #[error("Cas Object Error: {0}")] CasObjectError(#[from] cas_object::error::CasObjectError), diff --git a/cas_client/src/interface.rs b/cas_client/src/interface.rs index b6a328d0..d1cf1312 100644 --- a/cas_client/src/interface.rs +++ b/cas_client/src/interface.rs @@ -1,7 +1,7 @@ use crate::error::Result; use async_trait::async_trait; use merklehash::MerkleHash; -use std::sync::Arc; +use std::{io::Write, sync::Arc}; /// A Client to the CAS (Content Addressed Storage) service to allow storage and /// management of XORBs (Xet Object Remote Block). A XORB represents a collection @@ -9,7 +9,7 @@ use std::sync::Arc; /// producing a Merkle Tree. XORBs in the CAS are identified by a combination of /// a prefix namespacing the XORB and the hash at the root of the Merkle Tree. #[async_trait] -pub trait Client: core::fmt::Debug { +pub trait UploadClient { /// Insert the provided data into the CAS as a XORB indicated by the prefix and hash. /// The hash will be verified on the server-side according to the chunk boundaries. /// Chunk Boundaries must be complete; i.e. the last entry in chunk boundary @@ -27,35 +27,37 @@ pub trait Client: core::fmt::Debug { chunk_boundaries: Vec, ) -> Result<()>; + /// Check if a XORB already exists. + async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result; + /// Clients may do puts in the background. A flush is necessary /// to enforce completion of all puts. If an error occured during any /// background put it will be returned here. async fn flush(&self) -> Result<()>; +} - /// Reads all of the contents for the indicated XORB, returning the data or an error - /// if an issue occurred. - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result>; +/// A Client to the CAS (Content Addressed Storage) service to allow reconstructing a +/// pointer file based on FileID (MerkleHash). +#[async_trait] +pub trait ReconstructionClient { + /// Get a entire file by file hash. + async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()>; - /// Reads the requested ranges for the indicated object. 
Each range is a tuple of - /// start byte (inclusive) to end byte (exclusive). Will return the contents for - /// the ranges if they exist in the order specified. If there are issues fetching - /// any of the ranges, then an Error will be returned. - async fn get_object_range( + /// Get a entire file by file hash at a specific bytes range. + async fn get_file_byte_range( &self, - prefix: &str, hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>>; - - /// Gets the length of the XORB or an error if an issue occurred. - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result; + offset: u64, + length: u64, + writer: &mut Box, + ) -> Result<()>; } /* * If T implements Client, Arc also implements Client */ #[async_trait] -impl Client for Arc { +impl UploadClient for Arc { async fn put( &self, prefix: &str, @@ -66,8 +68,8 @@ impl Client for Arc { (**self).put(prefix, hash, data, chunk_boundaries).await } - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - (**self).get(prefix, hash).await + async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { + (**self).exists(prefix, hash).await } /// Clients may do puts in the background. A flush is necessary @@ -76,17 +78,26 @@ impl Client for Arc { async fn flush(&self) -> Result<()> { (**self).flush().await } +} - async fn get_object_range( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - (**self).get_object_range(prefix, hash, ranges).await +#[async_trait] +impl ReconstructionClient for Arc { + /// Get a entire file by file hash. + async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { + (**self).get_file(hash, writer).await } - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - (**self).get_length(prefix, hash).await + async fn get_file_byte_range( + &self, + hash: &MerkleHash, + offset: u64, + length: u64, + writer: &mut Box, + ) -> Result<()> { + (**self) + .get_file_byte_range(hash, offset, length, writer) + .await } } + +pub trait Client: UploadClient + ReconstructionClient {} diff --git a/cas_client/src/lib.rs b/cas_client/src/lib.rs index 9a14107c..179c9812 100644 --- a/cas_client/src/lib.rs +++ b/cas_client/src/lib.rs @@ -2,25 +2,13 @@ #![allow(dead_code)] pub use crate::error::CasClientError; -pub use caching_client::{CachingClient, DEFAULT_BLOCK_SIZE}; +pub use caching_client::CachingClient; pub use interface::Client; pub use local_client::LocalClient; -pub use merklehash::MerkleHash; // re-export since this is required for the client API. 
-pub use passthrough_staging_client::PassthroughStagingClient; -pub use remote_client::CASAPIClient; pub use remote_client::RemoteClient; -pub use staging_client::{new_staging_client, new_staging_client_with_progressbar, StagingClient}; -pub use staging_trait::{Staging, StagingBypassable}; mod caching_client; -mod cas_connection_pool; -mod client_adapter; -mod data_transport; mod error; mod interface; mod local_client; -mod passthrough_staging_client; mod remote_client; -mod staging_client; -mod staging_trait; -mod util; diff --git a/cas_client/src/local_client.rs b/cas_client/src/local_client.rs index 98f10361..9bb97e80 100644 --- a/cas_client/src/local_client.rs +++ b/cas_client/src/local_client.rs @@ -1,5 +1,7 @@ use crate::error::{CasClientError, Result}; -use crate::interface::Client; +use crate::interface::UploadClient; +use anyhow::anyhow; +use async_trait::async_trait; use cas::key::Key; use cas_object::CasObject; use merklehash::MerkleHash; @@ -7,9 +9,6 @@ use std::fs::{metadata, File}; use std::io::{BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; use tempfile::TempDir; - -use anyhow::anyhow; -use async_trait::async_trait; use tracing::{debug, error, info}; #[derive(Debug)] @@ -125,12 +124,11 @@ impl LocalClient { let _ = std::fs::remove_file(file_path); } - } /// LocalClient is responsible for writing/reading Xorbs on local disk. #[async_trait] -impl Client for LocalClient { +impl UploadClient for LocalClient { async fn put( &self, prefix: &str, @@ -152,31 +150,14 @@ impl Client for LocalClient { // moved hash validation into [CasObject::serialize], so removed from here. - if let Ok(xorb_size) = self.get_length(prefix, hash).await { - if xorb_size > 0 { - info!("{prefix:?}/{hash:?} already exists in Local CAS; returning."); - return Ok(()); - } + if self.exists(prefix, hash).await? { + info!("{prefix:?}/{hash:?} already exists in Local CAS; returning."); + return Ok(()); } let file_path = self.get_path_for_entry(prefix, hash); info!("Writing XORB {prefix}/{hash:?} to local path {file_path:?}"); - if let Ok(metadata) = metadata(&file_path) { - return if metadata.is_file() { - info!("{file_path:?} already exists; returning."); - // if its a file, its ok. we do not overwrite - Ok(()) - } else { - // if its not file we have a problem. - Err(CasClientError::InternalError(anyhow!( - "Attempting to write to {:?}, but {:?} is not a file", - file_path, - file_path - ))) - }; - } - // we prefix with "[PID]." for now. We should be able to do a cleanup // in the future. let tempfile = tempfile::Builder::new() @@ -197,14 +178,14 @@ impl Client for LocalClient { hash, &data, &chunk_boundaries.into_iter().map(|x| x as u32).collect(), - cas_object::CompressionScheme::None + cas_object::CompressionScheme::None, )?; // flush before persisting writer.flush()?; total_bytes_written = bytes_written; } - tempfile.persist(&file_path)?; + tempfile.persist(&file_path).map_err(|e| e.error)?; // attempt to set to readonly // its ok to fail. 
@@ -219,72 +200,117 @@ impl Client for LocalClient { Ok(()) } + async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { + let file_path = self.get_path_for_entry(prefix, hash); + + let res = metadata(&file_path); + + if res.is_err() || !res.unwrap().is_file() { + return Err(CasClientError::InternalError(anyhow!( + "Attempting to write to {:?}, but it is not a file", + file_path + ))); + }; + + match File::open(file_path) { + Ok(file) => { + let mut reader = BufReader::new(file); + CasObject::deserialize(&mut reader)?; + Ok(true) + } + Err(_) => Err(CasClientError::XORBNotFound(*hash)), + } + } + async fn flush(&self) -> Result<()> { Ok(()) } +} - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - let file_path = self.get_path_for_entry(prefix, hash); - let file = File::open(&file_path).map_err(|_| { - if !self.silence_errors { - error!("Unable to find file in local CAS {:?}", file_path); - } - CasClientError::XORBNotFound(*hash) - })?; - - let mut reader = BufReader::new(file); - let cas = CasObject::deserialize(&mut reader)?; - let result = cas.get_all_bytes(&mut reader)?; - Ok(result) +#[cfg(test)] +mod tests_utils { + use super::LocalClient; + use crate::{error::Result, CasClientError}; + use cas_object::CasObject; + use merklehash::MerkleHash; + use std::{fs::File, io::BufReader}; + use tracing::error; + + pub trait TestUtils { + fn get(&self, prefix: &str, hash: &MerkleHash) -> Result>; + fn get_object_range( + &self, + prefix: &str, + hash: &MerkleHash, + ranges: Vec<(u64, u64)>, + ) -> Result>>; + fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result; } - async fn get_object_range( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - // Handle the case where we aren't asked for any real data. - if ranges.len() == 1 && ranges[0].0 == ranges[0].1 { - return Ok(vec![Vec::::new()]); + impl TestUtils for LocalClient { + fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { + let file_path = self.get_path_for_entry(prefix, hash); + let file = File::open(&file_path).map_err(|_| { + if !self.silence_errors { + error!("Unable to find file in local CAS {:?}", file_path); + } + CasClientError::XORBNotFound(*hash) + })?; + + let mut reader = BufReader::new(file); + let cas = CasObject::deserialize(&mut reader)?; + let result = cas.get_all_bytes(&mut reader)?; + Ok(result) } - let file_path = self.get_path_for_entry(prefix, hash); - let file = File::open(&file_path).map_err(|_| { - if !self.silence_errors { - error!("Unable to find file in local CAS {:?}", file_path); + fn get_object_range( + &self, + prefix: &str, + hash: &MerkleHash, + ranges: Vec<(u64, u64)>, + ) -> Result>> { + // Handle the case where we aren't asked for any real data. 
+ if ranges.len() == 1 && ranges[0].0 == ranges[0].1 { + return Ok(vec![Vec::::new()]); } - CasClientError::XORBNotFound(*hash) - })?; - let mut reader = BufReader::new(file); - let cas = CasObject::deserialize(&mut reader)?; + let file_path = self.get_path_for_entry(prefix, hash); + let file = File::open(&file_path).map_err(|_| { + if !self.silence_errors { + error!("Unable to find file in local CAS {:?}", file_path); + } + CasClientError::XORBNotFound(*hash) + })?; + + let mut reader = BufReader::new(file); + let cas = CasObject::deserialize(&mut reader)?; - let mut ret: Vec> = Vec::new(); - for r in ranges { - let data = cas.get_range(&mut reader, r.0 as u32, r.1 as u32)?; - ret.push(data); + let mut ret: Vec> = Vec::new(); + for r in ranges { + let data = cas.get_range(&mut reader, r.0 as u32, r.1 as u32)?; + ret.push(data); + } + Ok(ret) } - Ok(ret) - } - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - let file_path = self.get_path_for_entry(prefix, hash); - match File::open(file_path) { - Ok(file) => { - let mut reader = BufReader::new(file); - let cas = CasObject::deserialize(&mut reader)?; - let length = cas.get_contents_length()?; - Ok(length as u64) + fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { + let file_path = self.get_path_for_entry(prefix, hash); + match File::open(file_path) { + Ok(file) => { + let mut reader = BufReader::new(file); + let cas = CasObject::deserialize(&mut reader)?; + let length = cas.get_contents_length()?; + Ok(length as u64) + } + Err(_) => Err(CasClientError::XORBNotFound(*hash)), } - Err(_) => Err(CasClientError::XORBNotFound(*hash)), } } } #[cfg(test)] mod tests { - + use super::tests_utils::TestUtils; use super::*; use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; use merklehash::{compute_data_hash, DataHash}; @@ -306,7 +332,7 @@ mod tests { .await .is_ok()); - let returned_data = client.get("key", &hash).await.unwrap(); + let returned_data = client.get("key", &hash).unwrap(); assert_eq!(data_again, returned_data); } @@ -320,7 +346,7 @@ mod tests { // Act & Assert assert!(client.put("", &hash, data, chunk_boundaries).await.is_ok()); - let returned_data = client.get("", &hash).await.unwrap(); + let returned_data = client.get("", &hash).unwrap(); assert_eq!(data_again, returned_data); } @@ -336,7 +362,7 @@ mod tests { let ranges: Vec<(u64, u64)> = vec![(0, 100), (100, 1500)]; let ranges_again = ranges.clone(); - let returned_ranges = client.get_object_range("", &hash, ranges).await.unwrap(); + let returned_ranges = client.get_object_range("", &hash, ranges).unwrap(); for idx in 0..returned_ranges.len() { assert_eq!( @@ -355,7 +381,7 @@ mod tests { // Act client.put("", &hash, data, chunk_boundaries).await.unwrap(); - let len = client.get_length("", &hash).await.unwrap(); + let len = client.get_length("", &hash).unwrap(); // Assert assert_eq!(len as usize, gen_length); @@ -368,7 +394,7 @@ mod tests { let (hash, _, _) = gen_dummy_xorb(16, 2048, true); // Act & Assert - let result = client.get("", &hash).await; + let result = client.get("", &hash); assert!(matches!(result, Err(CasClientError::XORBNotFound(_)))); } @@ -459,15 +485,14 @@ mod tests { // get length of non-existant object should fail with XORBNotFound assert_eq!( CasClientError::XORBNotFound(world_hash), - client.get_length("key", &world_hash).await.unwrap_err() + client.get_length("key", &world_hash).unwrap_err() ); // read of non-existant object should fail with XORBNotFound - assert!(client.get("key", 
&world_hash).await.is_err()); + assert!(client.get("key", &world_hash).is_err()); // read range of non-existant object should fail with XORBNotFound assert!(client .get_object_range("key", &world_hash, vec![(0, 5)]) - .await .is_err()); // we can delete non-existant things @@ -481,11 +506,11 @@ mod tests { // now every read of that key should fail assert_eq!( CasClientError::XORBNotFound(hello_hash), - client.get_length("key", &hello_hash).await.unwrap_err() + client.get_length("key", &hello_hash).unwrap_err() ); assert_eq!( CasClientError::XORBNotFound(hello_hash), - client.get("key", &hello_hash).await.unwrap_err() + client.get("key", &hello_hash).unwrap_err() ); } diff --git a/cas_client/src/passthrough_staging_client.rs b/cas_client/src/passthrough_staging_client.rs deleted file mode 100644 index f87a0bdd..00000000 --- a/cas_client/src/passthrough_staging_client.rs +++ /dev/null @@ -1,147 +0,0 @@ -use futures::stream::FuturesUnordered; -use futures::StreamExt; -use std::fmt::Debug; -use std::future::Future; -use std::path::PathBuf; -use std::pin::Pin; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::info; - -use async_trait::async_trait; - -use merklehash::MerkleHash; - -use crate::error::{CasClientError, Result}; -use crate::interface::Client; -use crate::staging_trait::*; - -const PASSTHROUGH_STAGING_MAX_CONCURRENT_UPLOADS: usize = 16; - -type FutureCollectionType = FuturesUnordered> + Send>>>; - -/// The PassthroughStagingClient is a simple wrapper around -/// a Client that provides the trait implementations required for StagingClient -/// All staging operations are no-op. -#[derive(Debug)] -pub struct PassthroughStagingClient { - client: Arc, - put_futures: Mutex, -} - -impl PassthroughStagingClient { - /// Create a new passthrough staging client which wraps any other client. - /// All operations are simply passthrough to the internal client. - /// All staging operations are no-op. - pub fn new(client: Arc) -> PassthroughStagingClient { - PassthroughStagingClient { - client, - put_futures: Mutex::new(FutureCollectionType::new()), - } - } -} - -impl Staging for PassthroughStagingClient {} - -#[async_trait] -impl StagingUpload for PassthroughStagingClient { - /// Upload all staged will upload everything to the remote client. - /// TODO : Caller may need to be wary of a HashMismatch error which will - /// indicate that the local staging environment has been corrupted somehow. - async fn upload_all_staged(&self, _max_concurrent: usize, _retain: bool) -> Result<()> { - Ok(()) - } -} - -#[async_trait] -impl StagingBypassable for PassthroughStagingClient { - async fn put_bypass_stage( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<()> { - self.client.put(prefix, hash, data, chunk_boundaries).await - } -} - -#[async_trait] -impl StagingInspect for PassthroughStagingClient { - async fn list_all_staged(&self) -> Result> { - Ok(vec![]) - } - - async fn get_length_staged(&self, _prefix: &str, hash: &MerkleHash) -> Result { - Ok(Err(CasClientError::XORBNotFound(*hash))?) 
- } - - async fn get_length_remote(&self, prefix: &str, hash: &MerkleHash) -> Result { - let item = self.client.get_length(prefix, hash).await?; - - Ok(item as usize) - } - - fn get_staging_path(&self) -> PathBuf { - PathBuf::default() - } - - fn get_staging_size(&self) -> Result { - Ok(0) - } -} - -#[async_trait] -impl Client for PassthroughStagingClient { - async fn put( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<()> { - let prefix = prefix.to_string(); - let hash = *hash; - let client = self.client.clone(); - let mut put_futures = self.put_futures.lock().await; - while put_futures.len() >= PASSTHROUGH_STAGING_MAX_CONCURRENT_UPLOADS { - if let Some(Err(e)) = put_futures.next().await { - info!("Error occurred with a background CAS upload."); - // a background upload failed. we returning that error here. - return Err(e); - } - } - put_futures.push(Box::pin(async move { - client.put(&prefix, &hash, data, chunk_boundaries).await - })); - Ok(()) - } - async fn flush(&self) -> Result<()> { - let mut put_futures = self.put_futures.lock().await; - while put_futures.len() > 0 { - if let Some(Err(e)) = put_futures.next().await { - info!("Error occurred with a background CAS upload."); - // a background upload failed. we returning that error here. - return Err(e); - } - } - Ok(()) - } - - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - self.client.get(prefix, hash).await - } - - async fn get_object_range( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - self.client.get_object_range(prefix, hash, ranges).await - } - - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - self.client.get_length(prefix, hash).await - } -} diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 870c46d1..9ee9a15a 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -1,36 +1,31 @@ -use std::io::{Cursor, Write}; - +use crate::interface::*; +use crate::{error::Result, CasClientError}; use anyhow::anyhow; +use async_trait::async_trait; use bytes::Buf; -use cas::key::Key; -use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse}; -use reqwest::{ - header::{HeaderMap, HeaderValue}, - StatusCode, Url, -}; -use serde::{de::DeserializeOwned, Serialize}; - use bytes::Bytes; use cas_object::CasObject; -use cas_types::CASReconstructionTerm; -use tracing::{debug, warn}; - -use crate::{error::Result, CasClientError}; - +use cas_types::{CASReconstructionTerm, Key, QueryReconstructionResponse, UploadXorbResponse}; use merklehash::MerkleHash; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::StatusCode; +use reqwest::Url; +use std::io::{Cursor, Write}; +use tracing::{debug, warn}; -use crate::Client; pub const CAS_ENDPOINT: &str = "http://localhost:8080"; pub const PREFIX_DEFAULT: &str = "default"; #[derive(Debug)] pub struct RemoteClient { - client: CASAPIClient, + client: reqwest::Client, + endpoint: String, + token: Option, } // TODO: add retries -#[async_trait::async_trait] -impl Client for RemoteClient { +#[async_trait] +impl UploadClient for RemoteClient { async fn put( &self, prefix: &str, @@ -43,7 +38,7 @@ impl Client for RemoteClient { hash: *hash, }; - let was_uploaded = self.client.upload(&key, &data, chunk_boundaries).await?; + let was_uploaded = self.upload(&key, &data, chunk_boundaries).await?; if !was_uploaded { debug!("{key:?} not inserted into CAS."); @@ -54,67 +49,12 @@ impl Client 
for RemoteClient { Ok(()) } - async fn flush(&self) -> Result<()> { - Ok(()) - } - - async fn get(&self, _prefix: &str, _hash: &merklehash::MerkleHash) -> Result> { - Err(CasClientError::InvalidArguments) - } - - async fn get_object_range( - &self, - _prefix: &str, - _hash: &merklehash::MerkleHash, - _ranges: Vec<(u64, u64)>, - ) -> Result>> { - Err(CasClientError::InvalidArguments) - } - - async fn get_length(&self, prefix: &str, hash: &merklehash::MerkleHash) -> Result { + async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { let key = Key { prefix: prefix.to_string(), hash: *hash, }; - match self.client.get_length(&key).await? { - Some(length) => Ok(length), - None => Err(CasClientError::XORBNotFound(*hash)), - } - } -} -impl RemoteClient { - pub async fn from_config(endpoint: String, token: Option) -> Self { - Self { - client: CASAPIClient::new(&endpoint, token), - } - } -} - -#[derive(Debug)] -pub struct CASAPIClient { - client: reqwest::Client, - endpoint: String, - token: Option, -} - -impl Default for CASAPIClient { - fn default() -> Self { - Self::new(CAS_ENDPOINT, None) - } -} - -impl CASAPIClient { - pub fn new(endpoint: &str, token: Option) -> Self { - let client = reqwest::Client::builder().build().unwrap(); - Self { - client, - endpoint: endpoint.to_string(), - token, - } - } - - pub async fn exists(&self, key: &Key) -> Result { let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?; let response = self .client @@ -131,40 +71,43 @@ impl CASAPIClient { } } - pub async fn get_length(&self, key: &Key) -> Result> { - let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?; - let response = self - .client - .head(url) - .headers(self.request_headers()) - .send() - .await?; - let status = response.status(); - if status == StatusCode::NOT_FOUND { - return Ok(None); - } - if status != StatusCode::OK { - return Err(CasClientError::InternalError(anyhow!( - "unrecognized status code {status}" - ))); + async fn flush(&self) -> Result<()> { + Ok(()) + } +} + +#[async_trait] +impl ReconstructionClient for RemoteClient { + async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { + // get manifest of xorbs to download + let manifest = self.reconstruct_file(hash, None).await?; + + self.reconstruct(manifest, None, writer).await?; + + Ok(()) + } + + async fn get_file_byte_range( + &self, + hash: &MerkleHash, + offset: u64, + length: u64, + writer: &mut Box, + ) -> Result<()> { + todo!() + } +} + +impl Client for RemoteClient {} + +impl RemoteClient { + pub fn new(endpoint: &str, token: Option) -> Self { + let client = reqwest::Client::builder().build().unwrap(); + Self { + client, + endpoint: endpoint.to_string(), + token, } - let hv = match response.headers().get("Content-Length") { - Some(hv) => hv, - None => { - return Err(CasClientError::InternalError(anyhow!( - "HEAD missing content length header" - ))) - } - }; - let length: u64 = hv - .to_str() - .map_err(|_| { - CasClientError::InternalError(anyhow!("HEAD missing content length header")) - })? - .parse() - .map_err(|_| CasClientError::InternalError(anyhow!("failed to parse length")))?; - - Ok(Some(length)) } pub async fn upload( @@ -202,21 +145,10 @@ impl CASAPIClient { Ok(response_parsed.was_inserted) } - /// Reconstruct a file and write to writer. 
- pub async fn write_file( - &self, - file_id: &MerkleHash, - writer: &mut W, - ) -> Result { - // get manifest of xorbs to download - let manifest = self.reconstruct_file(file_id).await?; - - self.reconstruct(manifest, writer).await - } - async fn reconstruct( &self, reconstruction_response: QueryReconstructionResponse, + _byte_range: Option<(u64, u64)>, writer: &mut W, ) -> Result { let info = reconstruction_response.reconstruction; @@ -240,7 +172,11 @@ impl CASAPIClient { } /// Reconstruct the file - async fn reconstruct_file(&self, file_id: &MerkleHash) -> Result { + async fn reconstruct_file( + &self, + file_id: &MerkleHash, + _bytes_range: Option<(u64, u64)>, + ) -> Result { let url = Url::parse(&format!( "{}/reconstruction/{}", self.endpoint, @@ -260,20 +196,6 @@ impl CASAPIClient { Ok(response_parsed) } - pub async fn shard_query_chunk(&self, key: &Key) -> Result { - let url = Url::parse(&format!("{}/chunk/{key}", self.endpoint))?; - let response = self - .client - .get(url) - .headers(self.request_headers()) - .send() - .await?; - let response_body = response.bytes().await?; - let response_parsed: QueryChunkResponse = serde_json::from_reader(response_body.reader())?; - - Ok(response_parsed) - } - fn request_headers(&self) -> HeaderMap { let mut headers = HeaderMap::new(); if let Some(tok) = &self.token { @@ -284,17 +206,6 @@ impl CASAPIClient { } headers } - - async fn post_json(&self, url: Url, request_body: &ReqT) -> Result - where - ReqT: Serialize, - RespT: DeserializeOwned, - { - let body = serde_json::to_vec(request_body)?; - let response = self.client.post(url).body(body).send().await?; - let response_bytes = response.bytes().await?; - serde_json::from_reader(response_bytes.reader()).map_err(CasClientError::SerdeError) - } } async fn get_one(term: &CASReconstructionTerm) -> Result { @@ -345,7 +256,7 @@ mod tests { #[tokio::test] async fn test_basic_put() { // Arrange - let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), None).await; + let rc = RemoteClient::new(CAS_ENDPOINT, None); let prefix = PREFIX_DEFAULT; let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true); diff --git a/cas_client/src/staging_client.rs b/cas_client/src/staging_client.rs deleted file mode 100644 index 30b63713..00000000 --- a/cas_client/src/staging_client.rs +++ /dev/null @@ -1,584 +0,0 @@ -use anyhow::anyhow; -use async_trait::async_trait; -use merklehash::MerkleHash; -use parutils::{tokio_par_for_each, ParallelError}; -use progress_reporting::DataProgressReporter; -use std::fmt::Debug; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::{info, info_span, Instrument}; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -use crate::error::CasClientError; -use crate::interface::Client; -use crate::local_client::LocalClient; -use crate::staging_trait::*; -use crate::PassthroughStagingClient; - -#[derive(Debug)] -pub struct StagingClient { - client: Arc, - staging_client: LocalClient, - progressbar: bool, -} - -impl StagingClient { - /// Create a new staging client which wraps a remote client. - /// - /// stage_path is the staging directory. - /// - /// Reads will check both the staging environment as well as the - /// the remote client. Puts will write to only staging environment - /// until upload `upload_all_staged()` is called. - /// - /// Staging environment is fully persistent and resilient to restarts. 
- pub fn new(client: Arc, stage_path: &Path) -> StagingClient { - StagingClient { - client, - staging_client: LocalClient::new(stage_path, true), // silence warnings=true - progressbar: false, - } - } - - /// Create a new staging client which wraps a remote client. - /// - /// stage_path is the staging directory. - /// - /// Reads will check both the staging environment as well as the - /// the remote client. Puts will write to only staging environment - /// until upload `upload_all_staged()` is called. - /// - /// Staging environment is fully persistent and resilient to restarts. - /// This version of the constructor will display a progressbar to stderr - /// when `upload_all_staged()` is called - pub fn new_with_progressbar( - client: Arc, - stage_path: &Path, - ) -> StagingClient { - StagingClient { - client, - staging_client: LocalClient::new(stage_path, true), // silence warnings=true - progressbar: true, - } - } -} - -fn cas_staging_bypass_is_set() -> bool { - // Returns true if XET_CAS_BYPASS_STAGING is set to something besides "0" - std::env::var_os("XET_CAS_BYPASS_STAGING") - .filter(|v| v != "0") - .is_some() -} - -/// Creates a new staging client wraping a staging directory. -/// If a staging directory is provided, it will be used for staging. -/// Otherwise all queries are passed through to the remote directly -/// using the PassthroughStagingClient. -pub fn new_staging_client( - client: T, - stage_path: Option<&Path>, -) -> Arc { - if let (false, Some(path)) = (cas_staging_bypass_is_set(), stage_path) { - Arc::new(StagingClient::new(Arc::new(client), path)) - } else { - Arc::new(PassthroughStagingClient::new(Arc::new(client))) - } -} - -/// Creates a new staging client wraping a staging directory. -/// If a staging directory is provided, it will be used for staging. -/// Otherwise all queries are passed through to the remote directly -/// using the PassthroughStagingClient. -pub fn new_staging_client_with_progressbar( - client: T, - stage_path: Option<&Path>, -) -> Arc { - if let (false, Some(path)) = (cas_staging_bypass_is_set(), stage_path) { - Arc::new(StagingClient::new_with_progressbar(Arc::new(client), path)) - } else { - Arc::new(PassthroughStagingClient::new(Arc::new(client))) - } -} - -impl Staging for StagingClient {} - -#[async_trait] -impl StagingUpload for StagingClient { - /// Upload all staged will upload everything to the remote client. - /// TODO : Caller may need to be wary of a HashMismatch error which will - /// indicate that the local staging environment has been corrupted somehow. 
- async fn upload_all_staged( - &self, - max_concurrent: usize, - retain: bool, - ) -> Result<(), CasClientError> { - let client = &self.client; - let stage = &self.staging_client; - let entries = stage.get_all_entries()?; - info!( - "XET StagingClient: {} entries to upload to remote.", - entries.len() - ); - - let pb = if self.progressbar && !entries.is_empty() { - let pb = - DataProgressReporter::new("Xet: Uploading data blocks", Some(entries.len()), None); - - pb.register_progress(Some(0), Some(0)); // draw the bar immediately - - Some(Arc::new(Mutex::new(pb))) - } else { - None - }; - let cur_span = info_span!("staging_client.upload_all_staged"); - let ctx = cur_span.context(); - // TODO: This can probably be re-written cleaner with futures::stream - // ex: https://patshaughnessy.net/2020/1/20/downloading-100000-files-using-async-rust - tokio_par_for_each(entries, max_concurrent, |entry, _| { - let pb = pb.clone(); - let span = info_span!("upload_staged_xorb"); - span.set_parent(ctx.clone()); - async move { - // if remote does not have the object - // read the object from staging - // and write the object out to remote - let (cb, val) = stage - .get_detailed(&entry.prefix, &entry.hash) - .instrument(info_span!("read_staged")) - .await?; - let xorb_length = val.len(); - info!( - "Uploading XORB {}/{} of length {}.", - &entry.prefix, - &entry.hash, - val.len() - ); - client.put(&entry.prefix, &entry.hash, val, cb).await?; - - if !retain { - info!( - "Clearing XORB {}/{} from staging area.", - &entry.prefix, &entry.hash, - ); - // Delete it from staging - stage.delete(&entry.prefix, &entry.hash); - } - if let Some(bar) = &pb { - bar.lock() - .await - .register_progress(Some(1), Some(xorb_length)); - } - Ok(()) - } - .instrument(span) - }) - .instrument(cur_span) - .await - .map_err(|e| match e { - ParallelError::JoinError => CasClientError::InternalError(anyhow!("Join Error")), - ParallelError::TaskError(e) => e, - })?; - self.client.flush().await?; - - if let Some(bar) = &pb { - bar.lock().await.finalize(); - } - - Ok(()) - } -} - -#[async_trait] -impl StagingBypassable for StagingClient { - async fn put_bypass_stage( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<(), CasClientError> { - self.client.put(prefix, hash, data, chunk_boundaries).await - } -} - -#[async_trait] -impl StagingInspect for StagingClient { - async fn list_all_staged(&self) -> Result, CasClientError> { - let stage = &self.staging_client; - let items = stage - .get_all_entries()? - .iter() - .map(|item: &cas::key::Key| item.to_string()) - .collect(); - - Ok(items) - } - - async fn get_length_staged( - &self, - prefix: &str, - hash: &MerkleHash, - ) -> Result { - let stage = &self.staging_client; - let item = stage.get_detailed(prefix, hash).await?; - - Ok(item.1.len()) - } - - async fn get_length_remote( - &self, - prefix: &str, - hash: &MerkleHash, - ) -> Result { - let item = self.client.get_length(prefix, hash).await?; - - Ok(item as usize) - } - - fn get_staging_path(&self) -> PathBuf { - self.staging_client.path.clone() - } - - fn get_staging_size(&self) -> Result { - self.staging_client - .path - .read_dir() - .map_err(|x| CasClientError::InternalError(x.into()))? 
- // take only entries which are ok - .filter_map(|x| x.ok()) - // take only entries whose filenames convert into strings - .filter(|x| { - let mut is_ok = false; - if let Ok(name) = x.file_name().into_string() { - if let Some(pos) = name.rfind('.') { - is_ok = MerkleHash::from_hex(&name[(pos + 1)..]).is_ok() - } - } - - is_ok - }) - .try_fold(0, |acc, file| { - let file = file; - let size = match file.metadata() { - Ok(data) => data.len() as usize, - _ => 0, - }; - Ok(acc + size) - }) - } -} - -#[async_trait] -impl Client for StagingClient { - async fn put( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<(), CasClientError> { - self.staging_client - .put(prefix, hash, data, chunk_boundaries) - .instrument(info_span!("staging_client.put")) - .await - } - - async fn flush(&self) -> Result<(), CasClientError> { - // forward flush to the underlying clients - self.staging_client.flush().await?; - self.client.flush().await - } - - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result, CasClientError> { - match self - .staging_client - .get(prefix, hash) - .instrument(info_span!("staging_client.get")) - .await - { - Err(CasClientError::XORBNotFound(_)) => self.client.get(prefix, hash).await, - x => x, - } - } - - async fn get_object_range( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>, CasClientError> { - match self - .staging_client - .get_object_range(prefix, hash, ranges.clone()) - .instrument(info_span!("staging_client.get_range")) - .await - { - Err(CasClientError::XORBNotFound(_)) => { - self.client.get_object_range(prefix, hash, ranges).await - } - x => x, - } - } - - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - match self - .staging_client - .get_length(prefix, hash) - .instrument(info_span!("staging_client.get_length")) - .await - { - Err(CasClientError::XORBNotFound(_)) => self.client.get_length(prefix, hash).await, - x => x, - } - } -} - -#[cfg(test)] -mod tests { - use std::path::Path; - use std::sync::Arc; - - use tempfile::TempDir; - - use crate::staging_client::{StagingClient, StagingUpload}; - use crate::*; - - fn make_staging_client(_client_path: &Path, stage_path: &Path) -> StagingClient { - let client = LocalClient::default(); - StagingClient::new(Arc::new(client), stage_path) - } - - #[tokio::test] - async fn test_general_basic_read_write() { - let localdir = TempDir::new().unwrap(); - let stagedir = TempDir::new().unwrap(); - let client = make_staging_client(localdir.path(), stagedir.path()); - - // the root hash of a single chunk is just the hash of the data - let hello = "hello world".as_bytes().to_vec(); - let hello_hash = merklehash::compute_data_hash(&hello[..]); - // write "hello world" - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - - // get length "hello world" - assert_eq!(11, client.get_length("key", &hello_hash).await.unwrap()); - - // read "hello world" - assert_eq!(hello, client.get("key", &hello_hash).await.unwrap()); - - // read range "hello" and "world" - let ranges_to_read: Vec<(u64, u64)> = vec![(0, 5), (6, 11)]; - let expected: Vec> = vec!["hello".as_bytes().to_vec(), "world".as_bytes().to_vec()]; - assert_eq!( - expected, - client - .get_object_range("key", &hello_hash, ranges_to_read) - .await - .unwrap() - ); - // read range "hello" and "world", with truncation for larger offsets - let ranges_to_read: Vec<(u64, u64)> = vec![(0, 5), (6, 20)]; - let expected: Vec> = 
vec!["hello".as_bytes().to_vec(), "world".as_bytes().to_vec()]; - assert_eq!( - expected, - client - .get_object_range("key", &hello_hash, ranges_to_read) - .await - .unwrap() - ); - // empty read - let ranges_to_read: Vec<(u64, u64)> = vec![(0, 5), (6, 6)]; - let expected: Vec> = vec!["hello".as_bytes().to_vec(), "".as_bytes().to_vec()]; - assert_eq!( - expected, - client - .get_object_range("key", &hello_hash, ranges_to_read) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_general_failures() { - let localdir = TempDir::new().unwrap(); - let stagedir = TempDir::new().unwrap(); - let client = make_staging_client(localdir.path(), stagedir.path()); - - let hello = "hello world".as_bytes().to_vec(); - let hello_hash = merklehash::compute_data_hash(&hello[..]); - // write "hello world" - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - // put the same value a second time. This should be ok. - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - - // put the different value with the same hash - // this should fail - assert_eq!( - CasClientError::CasObjectError(cas_object::error::CasObjectError::HashMismatch), - client - .put( - "hellp", - &hello_hash, - "hellp world".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); - // content shorter than the chunk boundaries should fail - assert_eq!( - CasClientError::InvalidArguments, - client - .put( - "key", - &hello_hash, - "hellp wod".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); - - // content longer than the chunk boundaries should fail - assert_eq!( - CasClientError::InvalidArguments, - client - .put( - "key", - &hello_hash, - "hello world again".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); - - // empty writes should fail - assert_eq!( - CasClientError::InvalidArguments, - client - .put("key", &hello_hash, vec![], vec![],) - .await - .unwrap_err() - ); - - // compute a hash of something we do not have in the store - let world = "world".as_bytes().to_vec(); - let world_hash = merklehash::compute_data_hash(&world[..]); - - // get length of non-existant object should fail with XORBNotFound - assert!(client.get_length("key", &world_hash).await.is_err()); - - // read of non-existant object should fail with XORBNotFound - assert!(client.get("key", &world_hash).await.is_err()); - // read range of non-existant object should fail with XORBNotFound - assert_eq!( - CasClientError::XORBNotFound(world_hash), - client - .get_object_range("key", &world_hash, vec![(0, 5)]) - .await - .unwrap_err() - ); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_staged_read_write() { - let localdir = TempDir::new().unwrap(); - let stagedir = TempDir::new().unwrap(); - let client = make_staging_client(localdir.path(), stagedir.path()); - - // put an object in and make sure it is there - - // the root hash of a single chunk is just the hash of the data - let hello = "hello world".as_bytes().to_vec(); - let hello_hash = merklehash::compute_data_hash(&hello[..]); - // write "hello world" - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - - // get length "hello world" - assert_eq!(11, client.get_length("key", &hello_hash).await.unwrap()); - // read "hello world" - assert_eq!(hello, client.get("key", &hello_hash).await.unwrap()); - - // check that the underlying client does not 
actually have it - assert_eq!( - CasClientError::XORBNotFound(hello_hash), - client.client.get("key", &hello_hash).await.unwrap_err() - ); - - // upload staged - client.upload_all_staged(1, false).await.unwrap(); - - // we can still read it - // get length "hello world" - assert_eq!(11, client.get_length("key", &hello_hash).await.unwrap()); - // read "hello world" - assert_eq!(hello, client.get("key", &hello_hash).await.unwrap()); - - // underlying client has it now - assert_eq!(hello, client.client.get("key", &hello_hash).await.unwrap()); - - // staging client does not - assert_eq!( - CasClientError::XORBNotFound(hello_hash), - client - .staging_client - .get("key", &hello_hash) - .await - .unwrap_err() - ); - } - - #[tokio::test] - async fn test_passthrough() { - let localdir = TempDir::new().unwrap(); - let local = LocalClient::new(localdir.path(), true); - // no staging directory - let client = new_staging_client(local, None); - - // put an object in and make sure it is there - - // the root hash of a single chunk is just the hash of the data - let hello = "hello world".as_bytes().to_vec(); - let hello_hash = merklehash::compute_data_hash(&hello[..]); - // write "hello world" - client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) - .await - .unwrap(); - client.flush().await.unwrap(); - // get length "hello world" - assert_eq!(11, client.get_length("key", &hello_hash).await.unwrap()); - // read "hello world" - assert_eq!(hello, client.get("key", &hello_hash).await.unwrap()); - - // since there is no stage. get_length_staged should fail. - assert_eq!( - CasClientError::XORBNotFound(hello_hash), - client - .get_length_staged("key", &hello_hash) - .await - .unwrap_err() - ); - - // check that the underlying client has it! - // (this is a passthrough!) - // but we can't get it from the stage object (it is now a Box) - // so we make a new local client at the same directory - let local2 = LocalClient::new(localdir.path(), true); - assert_eq!(hello, local2.get("key", &hello_hash).await.unwrap()); - } -} diff --git a/cas_client/src/staging_trait.rs b/cas_client/src/staging_trait.rs deleted file mode 100644 index 728450b9..00000000 --- a/cas_client/src/staging_trait.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::error::CasClientError; -use crate::interface::Client; -use async_trait::async_trait; -use merklehash::MerkleHash; -use std::path::PathBuf; - -#[async_trait] -pub trait StagingUpload { - async fn upload_all_staged( - &self, - max_concurrent: usize, - retain: bool, - ) -> Result<(), CasClientError>; -} - -#[async_trait] -pub trait StagingBypassable { - async fn put_bypass_stage( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<(), CasClientError>; -} - -#[async_trait] -pub trait StagingInspect { - /// Returns a vector of the XORBs in staging - async fn list_all_staged(&self) -> Result, CasClientError>; - - /// Gets the length of the XORB. This is the same as the - /// get_length method on the Client trait, but it forces the check to - /// come from staging only. - async fn get_length_staged( - &self, - prefix: &str, - hash: &MerkleHash, - ) -> Result; - - /// Gets the length of the XORB. This is the same as the - /// get_length method on the Client trait, but this forces the check to - /// come from the remote CAS server. - async fn get_length_remote( - &self, - prefix: &str, - hash: &MerkleHash, - ) -> Result; - - /// Gets the path to the staging directory. 
- fn get_staging_path(&self) -> PathBuf; - - /// Gets the sum of the file sizes of the valid XORBS in staging. - fn get_staging_size(&self) -> Result; -} - -#[async_trait] -pub trait Staging: StagingUpload + StagingInspect + Client + StagingBypassable {} diff --git a/cas_client/src/util.rs b/cas_client/src/util.rs deleted file mode 100644 index 94db4f46..00000000 --- a/cas_client/src/util.rs +++ /dev/null @@ -1,235 +0,0 @@ -// #[cfg(test)] -// pub(crate) mod grpc_mock { -// use std::sync::atomic::{AtomicU16, Ordering}; -// use std::sync::Arc; -// use std::time::Duration; - -// use cas::infra::infra_utils_server::InfraUtils; -// use oneshot::{channel, Receiver}; -// use tokio::sync::oneshot; -// use tokio::sync::oneshot::Sender; -// use tokio::task::JoinHandle; -// use tokio::time::sleep; -// use tonic::transport::{Error, Server}; -// use tonic::{Request, Response, Status}; - -// use crate::cas_connection_pool::CasConnectionConfig; -// use cas::cas::cas_server::{Cas, CasServer}; -// use cas::cas::{ -// GetRangeRequest, GetRangeResponse, GetRequest, GetResponse, HeadRequest, HeadResponse, -// PutCompleteRequest, PutCompleteResponse, PutRequest, PutResponse, -// }; -// use cas::common::{Empty, InitiateRequest, InitiateResponse}; -// use cas::infra::EndpointLoadResponse; -// use retry_strategy::RetryStrategy; - -// const TEST_PORT_START: u16 = 64400; - -// lazy_static::lazy_static! { -// static ref CURRENT_PORT: AtomicU16 = AtomicU16::new(TEST_PORT_START); -// } - -// trait_set::trait_set! { -// pub trait PutFn = Fn(Request) -> Result, Status> + 'static; -// pub trait InitiateFn = Fn(Request) -> Result, Status> + 'static; -// pub trait PutCompleteFn = Fn(Request) -> Result, Status> + 'static; -// pub trait GetFn = Fn(Request) -> Result, Status> + 'static; -// pub trait GetRangeFn = Fn(Request) -> Result, Status> + 'static; -// pub trait HeadFn = Fn(Request) -> Result, Status> + 'static; -// } - -// /// "Mocks" the grpc service for CAS. This is implemented by allowing the test writer -// /// to define the functionality needed for the server and then calling `#start()` to -// /// run the server on some port. A GrpcClient will be returned to test with as well -// /// as a shutdown hook that can be called to shutdown the mock service. 
-// #[derive(Default)] -// pub struct MockService { -// put_fn: Option>, -// initiate_fn: Option>, -// put_complete_fn: Option>, -// get_fn: Option>, -// get_range_fn: Option>, -// head_fn: Option>, -// } - -// impl MockService { -// #[allow(dead_code)] -// pub fn with_initiate(self, f: F) -> Self { -// Self { -// initiate_fn: Some(Arc::new(f)), -// ..self -// } -// } -// #[allow(dead_code)] -// pub fn with_put_complete(self, f: F) -> Self { -// Self { -// put_complete_fn: Some(Arc::new(f)), -// ..self -// } -// } - -// pub fn with_put(self, f: F) -> Self { -// Self { -// put_fn: Some(Arc::new(f)), -// ..self -// } -// } - -// #[allow(dead_code)] -// pub fn with_get(self, f: F) -> Self { -// Self { -// get_fn: Some(Arc::new(f)), -// ..self -// } -// } - -// #[allow(dead_code)] -// pub fn with_get_range(self, f: F) -> Self { -// Self { -// get_range_fn: Some(Arc::new(f)), -// ..self -// } -// } - -// #[allow(dead_code)] -// pub fn with_head(self, f: F) -> Self { -// Self { -// head_fn: Some(Arc::new(f)), -// ..self -// } -// } - -// /* -// pub async fn start(self) -> (ShutdownHook, GrpcClient) { -// self.start_with_retry_strategy(RetryStrategy::new(2, 1)) -// .await -// } - -// pub async fn start_with_retry_strategy( -// self, -// strategy: RetryStrategy, -// ) -> (ShutdownHook, GrpcClient) { -// // Get next port -// let port = CURRENT_PORT.fetch_add(1, Ordering::SeqCst); -// let addr = format!("127.0.0.1:{}", port).parse().unwrap(); - -// // Start up the server -// let (tx, rx) = channel::<()>(); -// let handle = tokio::spawn( -// Server::builder() -// .add_service(CasServer::new(self)) -// .serve_with_shutdown(addr, shutdown(rx)), -// ); -// let shutdown_hook = ShutdownHook::new(tx, handle); - -// // Wait for server to start up -// sleep(Duration::from_millis(10)).await; - -// // Create dedicated client for server -// let endpoint = format!("127.0.0.1:{}", port); -// let user_id = "xet_user".to_string(); -// let auth = "xet_auth".to_string(); -// let repo_paths = vec!["example".to_string()]; -// let version = "0.1.0".to_string(); -// let cas_client = get_client(CasConnectionConfig::new( -// endpoint, user_id, auth, repo_paths, version, -// )) -// .await -// .unwrap(); -// let client = GrpcClient::new("127.0.0.1".to_string(), cas_client, strategy); -// (shutdown_hook, client) -// } - -// */ -// } - -// // Unsafe hacks so that we can dynamically add in overrides to the mock functionality -// // (Fn isn't sync/send). There's probably a better way to do this that isn't so blunt/fragile. 
-// unsafe impl Send for MockService {} -// unsafe impl Sync for MockService {} - -// #[async_trait::async_trait] -// impl InfraUtils for MockService { -// async fn endpoint_load( -// &self, -// _request: Request, -// ) -> Result, Status> { -// unimplemented!() -// } -// async fn initiate( -// &self, -// request: Request, -// ) -> Result, Status> { -// self.initiate_fn.as_ref().unwrap()(request) -// } -// } -// #[async_trait::async_trait] -// impl Cas for MockService { -// async fn initiate( -// &self, -// request: Request, -// ) -> Result, Status> { -// self.initiate_fn.as_ref().unwrap()(request) -// } - -// async fn put(&self, request: Request) -> Result, Status> { -// self.put_fn.as_ref().unwrap()(request) -// } - -// async fn put_complete( -// &self, -// request: Request, -// ) -> Result, Status> { -// self.put_complete_fn.as_ref().unwrap()(request) -// } - -// async fn get(&self, request: Request) -> Result, Status> { -// self.get_fn.as_ref().unwrap()(request) -// } - -// async fn get_range( -// &self, -// request: Request, -// ) -> Result, Status> { -// self.get_range_fn.as_ref().unwrap()(request) -// } - -// async fn head( -// &self, -// request: Request, -// ) -> Result, Status> { -// self.head_fn.as_ref().unwrap()(request) -// } -// } - -// async fn shutdown(rx: Receiver<()>) { -// let _ = rx.await; -// } - -// /// Encapsulates logic to shutdown a running tonic Server. This is done through -// /// sending a message on a channel that the server is listening on for shutdown. -// /// Once the message has been sent, the spawned task is awaited using its JoinHandle. -// /// -// /// TODO: implementing `Drop` with async is difficult and the naïve implementation -// /// ends up blocking the test completion. There is likely some deadlock somewhere. 
-// pub struct ShutdownHook { -// tx: Option>, -// join_handle: Option>>, -// } - -// impl ShutdownHook { -// pub fn new(tx: Sender<()>, join_handle: JoinHandle>) -> Self { -// Self { -// tx: Some(tx), -// join_handle: Some(join_handle), -// } -// } - -// pub async fn async_drop(&mut self) { -// let tx = self.tx.take(); -// let handle = self.join_handle.take(); -// let _ = tx.unwrap().send(()); -// let _ = handle.unwrap().await; -// } -// } -// } diff --git a/mdb_shard/src/file_structs.rs b/mdb_shard/src/file_structs.rs index fd1e0ad9..091f466e 100644 --- a/mdb_shard/src/file_structs.rs +++ b/mdb_shard/src/file_structs.rs @@ -72,35 +72,39 @@ pub struct FileDataSequenceEntry { pub cas_hash: MerkleHash, pub cas_flags: u32, pub unpacked_segment_bytes: u32, - pub chunk_byte_range_start: u32, - pub chunk_byte_range_end: u32, + pub chunk_index_start: u32, + pub chunk_index_end: u32, } impl FileDataSequenceEntry { - pub fn new, I2: TryInto>( + pub fn new, I2: TryInto>( cas_hash: MerkleHash, unpacked_segment_bytes: I1, - chunk_byte_range_start: I2, - chunk_byte_range_end: I2, + chunk_index_start: I2, + chunk_index_end: I2, ) -> Self where >::Error: std::fmt::Debug, - >::Error: std::fmt::Debug, + >::Error: std::fmt::Debug, { Self { cas_hash, cas_flags: MDB_DEFAULT_FILE_FLAG, unpacked_segment_bytes: unpacked_segment_bytes.try_into().unwrap(), - chunk_byte_range_start: chunk_byte_range_start.try_into().unwrap(), - chunk_byte_range_end: chunk_byte_range_end.try_into().unwrap(), + chunk_index_start: chunk_index_start.try_into().unwrap() as u32, + chunk_index_end: chunk_index_end.try_into().unwrap() as u32, } } - pub fn from_cas_entries( + pub fn from_cas_entries>( metadata: &CASChunkSequenceHeader, chunks: &[CASChunkSequenceEntry], - chunk_byte_range_end: u32, - ) -> Self { + chunk_index_start: I1, + chunk_index_end: I1, + ) -> Self + where + >::Error: std::fmt::Debug, + { if chunks.is_empty() { return Self::default(); } @@ -109,8 +113,8 @@ impl FileDataSequenceEntry { cas_hash: metadata.cas_hash, cas_flags: metadata.cas_flags, unpacked_segment_bytes: chunks.iter().map(|sb| sb.unpacked_segment_bytes).sum(), - chunk_byte_range_start: chunks[0].chunk_byte_range_start, - chunk_byte_range_end, + chunk_index_start: chunk_index_start.try_into().unwrap(), + chunk_index_end: chunk_index_end.try_into().unwrap(), } } @@ -123,8 +127,8 @@ impl FileDataSequenceEntry { write_hash(writer, &self.cas_hash)?; write_u32(writer, self.cas_flags)?; write_u32(writer, self.unpacked_segment_bytes)?; - write_u32(writer, self.chunk_byte_range_start)?; - write_u32(writer, self.chunk_byte_range_end)?; + write_u32(writer, self.chunk_index_start)?; + write_u32(writer, self.chunk_index_end)?; } writer.write_all(&buf[..])?; @@ -143,8 +147,8 @@ impl FileDataSequenceEntry { cas_hash: read_hash(reader)?, cas_flags: read_u32(reader)?, unpacked_segment_bytes: read_u32(reader)?, - chunk_byte_range_start: read_u32(reader)?, - chunk_byte_range_end: read_u32(reader)?, + chunk_index_start: read_u32(reader)?, + chunk_index_end: read_u32(reader)?, }) } } diff --git a/mdb_shard/src/shard_file_manager.rs b/mdb_shard/src/shard_file_manager.rs index c23c44e4..d7f9e089 100644 --- a/mdb_shard/src/shard_file_manager.rs +++ b/mdb_shard/src/shard_file_manager.rs @@ -459,6 +459,8 @@ impl ShardFileManager { #[cfg(test)] mod tests { + use std::cmp::min; + use crate::{ cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader}, file_structs::FileDataSequenceHeader, @@ -596,12 +598,8 @@ mod tests { let mut query_hashes_2 = query_hashes_1.clone(); 
query_hashes_2.push(rng_hash(1000000 + i as u64)); - let lb = cas_block.chunks[i].chunk_byte_range_start; - let ub = if i + 3 >= cas_block.chunks.len() { - cas_block.metadata.num_bytes_in_cas - } else { - cas_block.chunks[i + 3].chunk_byte_range_start - }; + let lb = i as u32; + let ub = min(i + 3, cas_block.chunks.len()) as u32; for query_hashes in [&query_hashes_1, &query_hashes_2] { let result_m = mem_shard.chunk_hash_dedup_query(query_hashes).unwrap(); @@ -621,17 +619,11 @@ mod tests { // Make sure the bounds are correct assert_eq!( - ( - result_m.1.chunk_byte_range_start, - result_m.1.chunk_byte_range_end - ), + (result_m.1.chunk_index_start, result_m.1.chunk_index_end), (lb, ub) ); assert_eq!( - ( - result_f.1.chunk_byte_range_start, - result_f.1.chunk_byte_range_end - ), + (result_f.1.chunk_index_start, result_f.1.chunk_index_end), (lb, ub) ); diff --git a/mdb_shard/src/shard_format.rs b/mdb_shard/src/shard_format.rs index 4da188b9..a45b28de 100644 --- a/mdb_shard/src/shard_format.rs +++ b/mdb_shard/src/shard_format.rs @@ -659,15 +659,12 @@ impl MDBShardInfo { } let mut n_bytes = first_chunk.unpacked_segment_bytes; - let chunk_byte_range_start = first_chunk.chunk_byte_range_start; // Read everything else until the CAS block end. let mut end_idx = 0; - let mut chunk_byte_range_end = chunk_byte_range_start; for i in 1.. { if cas_chunk_offset as usize + i == cas_header.num_entries as usize { end_idx = i; - chunk_byte_range_end = cas_header.num_bytes_in_cas; break; } @@ -675,7 +672,6 @@ impl MDBShardInfo { if i == query_hashes.len() || chunk.chunk_hash != query_hashes[i] { end_idx = i; - chunk_byte_range_end = chunk.chunk_byte_range_start; break; } @@ -688,8 +684,8 @@ impl MDBShardInfo { cas_hash: cas_header.cas_hash, cas_flags: cas_header.cas_flags, unpacked_segment_bytes: n_bytes, - chunk_byte_range_start, - chunk_byte_range_end, + chunk_index_start: cas_chunk_offset, + chunk_index_end: cas_chunk_offset + end_idx as u32, }, ))) } @@ -861,17 +857,8 @@ impl MDBShardInfo { }; // Scan the cas entries to get the proper index - let Some(first_chunk_hash) = ('a: { - for e in cas_chunks[*cas_block_index].chunks.iter() { - if e.chunk_byte_range_start == entry.chunk_byte_range_start { - break 'a Some(e.chunk_hash); - } - } - error!("Error: Shard file start in CAS is not on chunk boundary."); - break 'a None; - }) else { - continue; - }; + let first_chunk_hash = + cas_chunks[*cas_block_index].chunks[entry.chunk_index_start as usize].chunk_hash; ret.push(first_chunk_hash); } @@ -881,6 +868,7 @@ impl MDBShardInfo { } pub mod test_routines { + use std::cmp::min; use std::io::{Cursor, Read, Seek}; use std::mem::size_of; @@ -1036,12 +1024,8 @@ pub mod test_routines { let mut query_hashes_2 = query_hashes_1.clone(); query_hashes_2.push(rng_hash(1000000 + i as u64)); - let lb = cas_block.chunks[i].chunk_byte_range_start; - let ub = if i + 3 >= cas_block.chunks.len() { - cas_block.metadata.num_bytes_in_cas - } else { - cas_block.chunks[i + 3].chunk_byte_range_start - }; + let lb = i as u32; + let ub = min(i + 3, cas_block.chunks.len()) as u32; for query_hashes in [&query_hashes_1, &query_hashes_2] { let result_m = mem_shard.chunk_hash_dedup_query(query_hashes).unwrap(); @@ -1060,17 +1044,11 @@ pub mod test_routines { // Make sure the bounds are correct assert_eq!( - ( - result_m.1.chunk_byte_range_start, - result_m.1.chunk_byte_range_end - ), + (result_m.1.chunk_index_start, result_m.1.chunk_index_end), (lb, ub) ); assert_eq!( - ( - result_f.1.chunk_byte_range_start, - result_f.1.chunk_byte_range_end - 
), + (result_f.1.chunk_index_start, result_f.1.chunk_index_end), (lb, ub) ); diff --git a/mdb_shard/src/shard_in_memory.rs b/mdb_shard/src/shard_in_memory.rs index 85a0e15b..5277a476 100644 --- a/mdb_shard/src/shard_in_memory.rs +++ b/mdb_shard/src/shard_in_memory.rs @@ -146,25 +146,23 @@ impl MDBInMemoryShard { return None; } - let (chunk_ref, offset) = match self.chunk_hash_lookup.get(&query_hashes[0]) { + let (chunk_ref, chunk_index_start) = match self.chunk_hash_lookup.get(&query_hashes[0]) { Some(s) => s, None => return None, }; - let offset = *offset as usize; + let chunk_index_start = *chunk_index_start as usize; - let end_byte_offset; let mut query_idx = 0; loop { - if offset + query_idx >= chunk_ref.chunks.len() { - end_byte_offset = chunk_ref.metadata.num_bytes_in_cas; + if chunk_index_start + query_idx >= chunk_ref.chunks.len() { break; } if query_idx >= query_hashes.len() - || chunk_ref.chunks[offset + query_idx].chunk_hash != query_hashes[query_idx] + || chunk_ref.chunks[chunk_index_start + query_idx].chunk_hash + != query_hashes[query_idx] { - end_byte_offset = chunk_ref.chunks[offset + query_idx].chunk_byte_range_start; break; } query_idx += 1; @@ -174,8 +172,9 @@ impl MDBInMemoryShard { query_idx, FileDataSequenceEntry::from_cas_entries( &chunk_ref.metadata, - &chunk_ref.chunks[offset..(offset + query_idx)], - end_byte_offset, + &chunk_ref.chunks[chunk_index_start..(chunk_index_start + query_idx)], + chunk_index_start, + chunk_index_start + query_idx, ), )) } From 0eb6c314c78bcaa3b3d757429012295d2030dd69 Mon Sep 17 00:00:00 2001 From: seanses Date: Fri, 27 Sep 2024 16:21:01 -0700 Subject: [PATCH 02/19] update --- data/src/cas_interface.rs | 181 ++++------------------------- data/src/clean.rs | 14 +-- data/src/data_processing.rs | 89 +------------- data/src/lib.rs | 1 - data/src/remote_shard_interface.rs | 10 +- hf_xet/src/config.rs | 13 ++- hf_xet/src/data_client.rs | 4 +- 7 files changed, 49 insertions(+), 263 deletions(-) diff --git a/data/src/cas_interface.rs b/data/src/cas_interface.rs index 83709326..bb66aa18 100644 --- a/data/src/cas_interface.rs +++ b/data/src/cas_interface.rs @@ -1,170 +1,35 @@ -use super::configurations::{Endpoint::*, RepoInfo, StorageConfig}; -use super::errors::Result; -use crate::constants::MAX_CONCURRENT_DOWNLOADS; -use crate::metrics::FILTER_BYTES_SMUDGED; -use cas_client::{new_staging_client, CachingClient, LocalClient, RemoteClient, Staging}; -use futures::prelude::stream::*; -use merkledb::ObjectRange; -use merklehash::MerkleHash; +use crate::configurations::*; +use crate::errors::Result; +use cas_client::RemoteClient; use std::env::current_dir; +use std::path::Path; use std::sync::Arc; -use tracing::{error, info, info_span}; +use tracing::info; -// Re-export for external configuration suggestion. -pub use cas_client::DEFAULT_BLOCK_SIZE; +pub use cas_client::Client; -pub(crate) async fn create_cas_client( +pub(crate) fn create_cas_client( cas_storage_config: &StorageConfig, - maybe_repo_info: &Option, -) -> Result> { - // Local file system based CAS storage. - if let FileSystem(ref path) = cas_storage_config.endpoint { - info!("Using local CAS with path: {:?}.", path); - let path = match path.is_absolute() { - true => path, - false => ¤t_dir()?.join(path), - }; - let client = LocalClient::new(path, false); - return Ok(new_staging_client( - client, - cas_storage_config.staging_directory.as_deref(), - )); - } - - // Now we are using remote server CAS storage. 
- let Server(ref endpoint) = cas_storage_config.endpoint else { - unreachable!(); - }; - - - // Usage tracking. - let _repo_paths = maybe_repo_info - .as_ref() - .map(|repo_info| &repo_info.repo_paths) - .cloned() - .unwrap_or_default(); - - // Raw remote client. - let remote_client = Arc::new( - RemoteClient::from_config(endpoint.to_string(), cas_storage_config.auth.token.clone()).await, - ); - - // Try add in caching capability. - let maybe_caching_client = cas_storage_config.cache_config.as_ref().and_then(|cache| { - CachingClient::new( - remote_client.clone(), - &cache.cache_directory, - cache.cache_size, - cache.cache_blocksize, - ) - .map_err(|e| error!("Unable to use caching CAS due to: {:?}", &e)) - .ok() - }); - - // If initiating caching was unsuccessful, fall back to only remote client. - match maybe_caching_client { - Some(caching_client) => { - info!( - "Using caching CAS with endpoint {:?}, caching at {:?}.", - &endpoint, - cas_storage_config - .cache_config - .as_ref() - .unwrap() - .cache_directory - ); - - Ok(new_staging_client( - caching_client, - cas_storage_config.staging_directory.as_deref(), - )) - } - None => { - info!("Using non-caching CAS with endpoint: {:?}.", &endpoint); - Ok(new_staging_client( - remote_client, - cas_storage_config.staging_directory.as_deref(), - )) - } + _maybe_repo_info: &Option, +) -> Result> { + match cas_storage_config.endpoint { + Endpoint::Server(ref endpoint) => remote_client(endpoint, &cas_storage_config.auth), + Endpoint::FileSystem(ref path) => local_test_cas_client(path), } } -/** Wrapper to consolidate the logic for retrieving from CAS. - */ -async fn get_from_cas( - cas: &Arc, - prefix: String, - hash: MerkleHash, - ranges: (u64, u64), -) -> Result> { - if ranges.0 == ranges.1 { - return Ok(Vec::new()); - } - let mut query_result = cas.get_object_range(&prefix, &hash, vec![ranges]).await?; - Ok(std::mem::take(&mut query_result[0])) -} +pub(crate) fn remote_client(endpoint: &str, auth: &Auth) -> Result> { + // Raw remote client. + let remote_client = Arc::new(RemoteClient::new(endpoint, auth.token.clone())); -/// Given an Vec describing a series of range of bytes, -/// slice a subrange. This does not check limits and may return shorter -/// results if the slice goes past the end of the range. -pub(crate) fn slice_object_range( - v: &[ObjectRange], - mut start: usize, - mut len: usize, -) -> Vec { - let mut ret: Vec = Vec::new(); - for i in v.iter() { - let ilen = i.end - i.start; - // we have not gotten to the start of the range - if start > 0 && start >= ilen { - // start is still after this range - start -= ilen; - } else { - // either start == 0, or start < packet len. - // Either way, we need some or all of this packet - // and after this packet start must be = 0 - let packet_start = i.start + start; - // the maximum length allowed is how far to end of the packet - // OR the actual slice length requested which ever is shorter. - let max_length_allowed = std::cmp::min(i.end - packet_start, len); - ret.push(ObjectRange { - hash: i.hash, - start: packet_start, - end: packet_start + max_length_allowed, - }); - start = 0; - len -= max_length_allowed; - } - if len == 0 { - break; - } - } - ret + Ok(remote_client) } -/// Writes a collection of chunks from a Vec to a writer. 
-pub(crate) async fn data_from_chunks_to_writer( - cas: &Arc, - prefix: String, - chunks: Vec, - writer: &mut impl std::io::Write, -) -> Result<()> { - let mut bytes_smudged: u64 = 0; - let mut strm = iter(chunks.into_iter().map(|objr| { - let prefix = prefix.clone(); - get_from_cas(cas, prefix, objr.hash, (objr.start as u64, objr.end as u64)) - })) - .buffered(*MAX_CONCURRENT_DOWNLOADS); - - while let Some(buf) = strm.next().await { - let buf = buf?; - bytes_smudged += buf.len() as u64; - let s = info_span!("write_chunk"); - let _ = s.enter(); - writer.write_all(&buf)?; - } - - FILTER_BYTES_SMUDGED.inc_by(bytes_smudged); - - Ok(()) +fn local_test_cas_client(path: &Path) -> Result> { + info!("Using local CAS with path: {:?}.", path); + let _path = match path.is_absolute() { + true => path, + false => ¤t_dir()?.join(path), + }; + unimplemented!() } diff --git a/data/src/clean.rs b/data/src/clean.rs index 90ca681e..52350b32 100644 --- a/data/src/clean.rs +++ b/data/src/clean.rs @@ -1,18 +1,15 @@ +use crate::cas_interface::Client; use crate::chunking::{chunk_target_default, ChunkYieldType}; use crate::configurations::FileQueryPolicy; use crate::constants::MIN_SPACING_BETWEEN_GLOBAL_DEDUP_QUERIES; use crate::data_processing::{register_new_cas_block, CASDataAggregator}; -use crate::errors::{ - DataProcessingError::{self, *}, - Result, -}; +use crate::errors::{DataProcessingError::*, Result}; use crate::metrics::FILTER_BYTES_CLEANED; use crate::remote_shard_interface::RemoteShardInterface; use crate::repo_salt::RepoSalt; use crate::small_file_determination::{is_file_passthrough, is_possible_start_to_text_file}; use crate::PointerFile; -use cas_client::Staging; use lazy_static::lazy_static; use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo}; use mdb_shard::shard_file_reconstructor::FileReconstructor; @@ -64,7 +61,7 @@ pub struct Cleaner { // Utils shard_manager: Arc, remote_shards: Arc, - cas: Arc, + cas: Arc, // External Data global_cas_data: Arc>, @@ -91,7 +88,7 @@ impl Cleaner { repo_salt: Option, shard_manager: Arc, remote_shards: Arc, - cas: Arc, + cas: Arc, cas_data: Arc>, buffer_size: usize, file_name: Option<&Path>, @@ -139,7 +136,8 @@ impl Cleaner { let mut small_file_buffer = self.small_file_buffer.lock().await; if let Some(buffer) = small_file_buffer.take() { - return String::from_utf8(buffer).map_err(DataProcessingError::from); + let small_file = String::from_utf8(buffer)?; + return Ok(small_file); } self.to_pointer_file().await diff --git a/data/src/data_processing.rs b/data/src/data_processing.rs index 4c46091c..0e2363ca 100644 --- a/data/src/data_processing.rs +++ b/data/src/data_processing.rs @@ -1,4 +1,4 @@ -use crate::cas_interface::{create_cas_client, data_from_chunks_to_writer}; +use crate::cas_interface::create_cas_client; use crate::clean::Cleaner; use crate::configurations::*; use crate::constants::MAX_CONCURRENT_UPLOADS; @@ -8,19 +8,17 @@ use crate::remote_shard_interface::RemoteShardInterface; use crate::shard_interface::create_shard_manager; use crate::PointerFile; -use cas_client::{CASAPIClient, Staging}; +use cas_client::Client; use mdb_shard::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfo}; use mdb_shard::file_structs::MDBFileInfo; use mdb_shard::ShardFileManager; use merkledb::aggregate_hashes::cas_node_hash; -use merkledb::ObjectRange; use merklehash::MerkleHash; use std::mem::take; use std::ops::DerefMut; use std::path::Path; use std::sync::Arc; use tokio::sync::Mutex; -use tracing::error; 
#[derive(Default, Debug)] pub struct CASDataAggregator { @@ -55,7 +53,7 @@ pub struct PointerFileTranslator { /* ----- Utils ----- */ shard_manager: Arc, remote_shards: Arc, - cas: Arc, + cas: Arc, /* ----- Deduped data shared across files ----- */ global_cas_data: Arc>, @@ -205,7 +203,7 @@ impl PointerFileTranslator { pub(crate) async fn register_new_cas_block( cas_data: &mut CASDataAggregator, shard_manager: &Arc, - cas: &Arc, + cas: &Arc, cas_prefix: &str, ) -> Result { let cas_hash = cas_node_hash(&cas_data.chunks[..]); @@ -287,36 +285,13 @@ pub(crate) async fn register_new_cas_block( /// Smudge operations impl PointerFileTranslator { - pub async fn derive_blocks(&self, hash: &MerkleHash) -> Result> { - if let Some((file_info, _shard_hash)) = self - .remote_shards - .get_file_reconstruction_info(hash) - .await? - { - Ok(file_info - .segments - .into_iter() - .map(|s| ObjectRange { - hash: s.cas_hash, - start: s.chunk_byte_range_start as usize, - end: s.chunk_byte_range_end as usize, - }) - .collect()) - } else { - error!("File Reconstruction info for hash {hash:?} not found."); - Err(DataProcessingError::HashNotFound) - } - } - pub async fn smudge_file_from_pointer( &self, pointer: &PointerFile, writer: &mut impl std::io::Write, range: Option<(usize, usize)>, - endpoint: Option, - token: Option, ) -> Result<()> { - self.smudge_file_from_hash(&pointer.hash()?, writer, range, endpoint, token) + self.smudge_file_from_hash(&pointer.hash()?, writer, range) .await } @@ -325,61 +300,9 @@ impl PointerFileTranslator { file_id: &MerkleHash, writer: &mut impl std::io::Write, _range: Option<(usize, usize)>, - endpoint: Option, - token: Option, ) -> Result<()> { - let endpoint = match &self.config.cas_storage_config.endpoint { - Endpoint::Server(config_endpoint) => { - if let Some(endpoint) = endpoint { - endpoint - } else { - config_endpoint.clone() - } - }, - Endpoint::FileSystem(_) => panic!("aaaaaaaa no server"), - }; - - let rc = CASAPIClient::new(&endpoint, token); - - rc.write_file(file_id, writer).await?; - - // let blocks = self - // .derive_blocks(file_id) - // .instrument(info_span!("derive_blocks")) - // .await?; - - // let ranged_blocks = match range { - // Some((start, end)) => { - // // we expect callers to validate the range, but just in case, check it anyway. 
- // if end < start { - // let msg = format!( - // "End range value requested ({end}) is less than start range value ({start})" - // ); - // error!(msg); - // return Err(DataProcessingError::ParameterError(msg)); - // } - // slice_object_range(&blocks, start, end - start) - // } - // None => blocks, - // }; - - // self.data_from_chunks_to_writer(ranged_blocks, writer) - // .await?; + self.cas.get_file(file_id, writer).await?; Ok(()) } - - async fn data_from_chunks_to_writer( - &self, - chunks: Vec, - writer: &mut impl std::io::Write, - ) -> Result<()> { - data_from_chunks_to_writer( - &self.cas, - self.config.cas_storage_config.prefix.clone(), - chunks, - writer, - ) - .await - } } diff --git a/data/src/lib.rs b/data/src/lib.rs index 7c2b7f4d..cdf1b417 100644 --- a/data/src/lib.rs +++ b/data/src/lib.rs @@ -14,7 +14,6 @@ mod repo_salt; mod shard_interface; mod small_file_determination; -pub use cas_interface::DEFAULT_BLOCK_SIZE; pub use constants::SMALL_FILE_THRESHOLD; pub use data_processing::PointerFileTranslator; pub use pointer_file::PointerFile; diff --git a/data/src/remote_shard_interface.rs b/data/src/remote_shard_interface.rs index a67f33aa..1369fd86 100644 --- a/data/src/remote_shard_interface.rs +++ b/data/src/remote_shard_interface.rs @@ -1,10 +1,10 @@ use super::configurations::{FileQueryPolicy, StorageConfig}; use super::errors::{DataProcessingError, Result}; use super::shard_interface::{create_shard_client, create_shard_manager}; +use crate::cas_interface::Client; use crate::constants::{FILE_RECONSTRUCTION_CACHE_SIZE, MAX_CONCURRENT_UPLOADS}; use crate::repo_salt::RepoSalt; use cas::singleflight; -use cas_client::Staging; use file_utils::write_all_safe; use lru::LruCache; use mdb_shard::constants::MDB_SHARD_MIN_TARGET_SIZE; @@ -31,7 +31,7 @@ pub struct RemoteShardInterface { pub repo_salt: Option, - pub cas: Option>, + pub cas: Option>, pub shard_manager: Option>, pub shard_client: Option>, pub reconstruction_cache: @@ -55,7 +55,7 @@ impl RemoteShardInterface { file_query_policy: FileQueryPolicy, shard_storage_config: &StorageConfig, shard_manager: Option>, - cas: Option>, + cas: Option>, repo_salt: Option, ) -> Result> { let shard_client = { @@ -93,7 +93,7 @@ impl RemoteShardInterface { })) } - fn cas(&self) -> Result> { + fn cas(&self) -> Result> { let Some(cas) = self.cas.clone() else { // Trigger error and backtrace return Err(DataProcessingError::CASConfigError( @@ -408,7 +408,7 @@ fn is_shard_file(path: &Path) -> bool { // Returns the path to the downloaded file and the number of bytes transferred. // Returns the path to the existing file and 0 (transferred byte) if exists. 
async fn download_shard( - cas: &Arc, + cas: &Arc, prefix: &str, shard_hash: &MerkleHash, dest_dir: &Path, diff --git a/hf_xet/src/config.rs b/hf_xet/src/config.rs index b9cec5b5..2614d683 100644 --- a/hf_xet/src/config.rs +++ b/hf_xet/src/config.rs @@ -1,7 +1,10 @@ +use data::configurations::{ + Auth, CacheConfig, DedupConfig, Endpoint, FileQueryPolicy, RepoInfo, StorageConfig, + TranslatorConfig, +}; +use data::{errors, DEFAULT_BLOCK_SIZE}; use std::env::current_dir; use std::fs; -use data::configurations::{Auth, CacheConfig, DedupConfig, Endpoint, FileQueryPolicy, RepoInfo, StorageConfig, TranslatorConfig}; -use data::{DEFAULT_BLOCK_SIZE, errors}; pub const SMALL_FILE_THRESHOLD: usize = 1; @@ -20,15 +23,13 @@ pub fn default_config(endpoint: String, token: Option) -> errors::Result cache_config: Some(CacheConfig { cache_directory: path.join("cache"), cache_size: 10 * 1024 * 1024 * 1024, // 10 GiB - cache_blocksize: DEFAULT_BLOCK_SIZE, + cache_blocksize: 0, // ignored }), staging_directory: None, }, shard_storage_config: StorageConfig { endpoint: Endpoint::Server(endpoint), - auth: Auth { - token: token, - }, + auth: Auth { token: token }, prefix: "default-merkledb".into(), cache_config: Some(CacheConfig { cache_directory: path.join("shard-cache"), diff --git a/hf_xet/src/data_client.rs b/hf_xet/src/data_client.rs index e10fe857..e5664d6d 100644 --- a/hf_xet/src/data_client.rs +++ b/hf_xet/src/data_client.rs @@ -92,13 +92,13 @@ async fn clean_file(processor: &PointerFileTranslator, f: String) -> errors::Res Ok(pf) } -async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile, endpoint: Option, token: Option) -> errors::Result { +async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result { let path = PathBuf::from(pointer_file.path()); if let Some(parent_dir) = path.parent() { fs::create_dir_all(parent_dir)?; } let mut f = File::create(&path)?; - proc.smudge_file_from_pointer(&pointer_file, &mut f, None, endpoint, token).await?; + proc.smudge_file_from_pointer(&pointer_file, &mut f, None).await?; Ok(pointer_file.path().to_string()) } From 45d0a3e2d8137dcefcb9e6fdb5a060d652d1bf60 Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Thu, 26 Sep 2024 11:36:50 -0700 Subject: [PATCH 03/19] Making old validate hash function private --- cas_object/src/cas_object_format.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index b91774f3..781ebdbb 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -514,7 +514,7 @@ impl CasObject { Ok((cas, total_written_bytes)) } - pub fn validate_root_hash(data: &[u8], chunk_boundaries: &[u32], hash: &MerkleHash) -> bool { + fn validate_root_hash(data: &[u8], chunk_boundaries: &[u32], hash: &MerkleHash) -> bool { // at least 1 chunk, and last entry in chunk boundary must match the length if chunk_boundaries.is_empty() || chunk_boundaries[chunk_boundaries.len() - 1] as usize != data.len() From 8e6e84ab1dab0f2172377bcf5e167b6320578bac Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Fri, 27 Sep 2024 13:27:54 -0700 Subject: [PATCH 04/19] Partial CasObject changes --- cas_object/src/cas_object_format.rs | 233 ++++++++-------------------- 1 file changed, 65 insertions(+), 168 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index 781ebdbb..35463a8c 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs 
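Before the cas_object changes below, a quick illustration of the slimmed-down smudge path from the hunks above: callers no longer thread `endpoint` and `token` through, and the translator reconstructs the file through its configured CAS client. The setup below (how the `PointerFileTranslator` and `PointerFile` are obtained, and the exact error plumbing) is assumed for illustration and is not part of this patch.

use std::fs::File;

use data::{errors, PointerFile, PointerFileTranslator};

// Hypothetical caller mirroring hf_xet's `smudge_file`: only the pointer file,
// a writer, and an optional byte range are passed. A range of `None` means the
// whole file is reconstructed (internally via `cas.get_file(...)`).
async fn smudge_one(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result<()> {
    let mut out = File::create(pointer_file.path())?;
    proc.smudge_file_from_pointer(pointer_file, &mut out, None).await?;
    Ok(())
}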
@@ -2,6 +2,7 @@ use bytes::Buf; use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; use merklehash::{DataHash, MerkleHash}; use tracing::warn; +use core::num; use std::{ cmp::min, io::{Cursor, Error, Read, Seek, Write}, @@ -29,42 +30,28 @@ pub struct CasObjectInfo { pub version: u8, /// 256-bits, 16-bytes, The CAS Hash of this Xorb. - pub cashash: DataHash, + pub cashash: MerkleHash, - /// Total number of chunks in the file. Length of chunk_size_info. + /// Total number of chunks in the Xorb. Length of chunk_byte_offset & chunk_hashes vectors. pub num_chunks: u32, - /// Chunk metadata (start of chunk, length of chunk), length of vector matches num_chunks. - /// This vector is expected to be in order (ex. `chunk[0].start_byte_index == 0`). - /// If uncompressed chunk, then: `chunk[n].start_byte_index == chunk[n-1].uncompressed_cumulative_len`. - /// And the final entry in this vector is a dummy entry to know the final chunk ending byte range. - /// + /// Byte offset marking the beginning of each chunk. Length of vector is num_chunks. + /// + /// To find the end of a chunk chunk[n] last byte is chunk[n+1].chunk_byte_index-1. + /// The final entry in this vector is a dummy entry to know the final chunk ending byte. /// ``` - /// // ex. chunks: [ 0 - 99 | 100 - 199 | 200 - 299 ] - /// // chunk_size_info : < (0,100), (100, 200), (200, 300), (300, 300) > <-- notice extra entry. + /// // ex. chunks: [ 0, 1, 2, 3] + /// // chunk_byte_offset: [ 0, 100, 200, 300, 400] <-- notice extra entry /// ``` - pub chunk_size_info: Vec, + pub chunk_byte_offsets: Vec, + + /// Merklehash for each chunk stored in the Xorb. Length of vector is num_chunks. + pub chunk_hashes: Vec, /// Unused 16-byte buffer to allow for future extensibility. _buffer: [u8; 16], } -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct CasChunkInfo { - /// Starting index of chunk. - /// - /// Ex. `chunk[5]` would start at start_byte_index - /// from the beginning of the XORB. - /// - /// This does include chunk header, to allow for fast range lookups. - pub start_byte_index: u32, - - /// Cumulative length of chunk. - /// - /// Does not include chunk header length, only uncompressed contents. - pub cumulative_uncompressed_len: u32, -} - impl Default for CasObjectInfo { fn default() -> Self { CasObjectInfo { @@ -72,16 +59,17 @@ impl Default for CasObjectInfo { version: CAS_OBJECT_FORMAT_VERSION, cashash: DataHash::default(), num_chunks: 0, - chunk_size_info: Vec::new(), + chunk_byte_offsets: Vec::new(), + chunk_hashes: Vec::new(), _buffer: Default::default(), } } } impl CasObjectInfo { - /// Serialize CasObjectMetadata to provided Writer. + /// Serialize CasObjectInfo to provided Writer. /// - /// Assumes caller has set position of Writer to appropriate location for metadata serialization. + /// Assumes caller has set position of Writer to appropriate location for serialization. 
pub fn serialize(&self, writer: &mut W) -> Result { let mut total_bytes_written = 0; @@ -98,10 +86,12 @@ impl CasObjectInfo { write_bytes(self.cashash.as_bytes())?; write_bytes(&self.num_chunks.to_le_bytes())?; - // write variable field: chunk_size_metadata - for chunk in &self.chunk_size_info { - let chunk_bytes = chunk.as_bytes(); - write_bytes(&chunk_bytes)?; + // write variable field: chunk offsets & hashes + for offset in &self.chunk_byte_offsets { + write_bytes(&offset.to_le_bytes())?; + } + for hash in &self.chunk_hashes { + write_bytes(hash.as_bytes())?; } // write closing metadata @@ -160,11 +150,17 @@ impl CasObjectInfo { read_bytes(&mut num_chunks)?; let num_chunks = u32::from_le_bytes(num_chunks); - let mut chunk_size_info = Vec::with_capacity(num_chunks as usize); + let mut chunk_byte_offsets = Vec::with_capacity(num_chunks as usize); + for _ in 0..num_chunks { + let mut offset = [0u8; 4]; + read_bytes(&mut offset)?; + chunk_byte_offsets.push(u32::from_le_bytes(offset)); + } + let mut chunk_hashes = Vec::with_capacity(num_chunks as usize); for _ in 0..num_chunks { - let mut buf = [0u8; size_of::()]; - read_bytes(&mut buf)?; - chunk_size_info.push(CasChunkInfo::from_bytes(buf)?); + let mut hash = [0u8; 32]; + read_bytes(&mut hash)?; + chunk_hashes.push(DataHash::from(&hash)); } let mut _buffer = [0u8; 16]; @@ -183,7 +179,8 @@ impl CasObjectInfo { version: version[0], cashash, num_chunks, - chunk_size_info, + chunk_byte_offsets, + chunk_hashes, _buffer, }, info_length, @@ -191,22 +188,6 @@ impl CasObjectInfo { } } -impl CasChunkInfo { - pub fn as_bytes(&self) -> [u8; size_of::()] { - let mut serialized_bytes = [0u8; size_of::()]; // 8 bytes, 2 u32 - serialized_bytes[..4].copy_from_slice(&self.start_byte_index.to_le_bytes()); - serialized_bytes[4..].copy_from_slice(&self.cumulative_uncompressed_len.to_le_bytes()); - serialized_bytes - } - - pub fn from_bytes(buf: [u8; 8]) -> Result { - Ok(Self { - start_byte_index: u32::from_le_bytes(buf[..4].try_into().unwrap()), - cumulative_uncompressed_len: u32::from_le_bytes(buf[4..].try_into().unwrap()), - }) - } -} - #[derive(Clone, PartialEq, Eq, Debug)] /// XORB: 16MB data block for storing chunks. /// @@ -265,104 +246,27 @@ impl CasObject { Ok(info_length) } - /// Deserialize the header only. + /// Deserialize the CasObjectInfo struct, the metadata for this Xorb. /// /// This allows the CasObject to be partially constructed, allowing for range reads inside the CasObject. pub fn deserialize(reader: &mut R) -> Result { let (info, info_length) = CasObjectInfo::deserialize(reader)?; Ok(Self { info, info_length }) } - - /// Translate desired range into actual byte range from within Xorb. - /// - /// This function will return a [RangeBoundaryHelper] struct to be able to read - /// a range from the Xorb. This function translates uncompressed ranges into their corresponding - /// Xorb chunk start byte index and Xorb chunk end byte index, along with an offset into that chunk. - /// See example below. - /// - /// Ex. If user requests range bytes 150-250 from a Xorb, and assume the following layout: - /// ``` - /// // chunk: [ 0 | 1 | 2 | 3 ] - /// // uncompressed chunks: [ 0-99 | 100-199 | 200-299 | 300-399 ] - /// // compressed chunks: [ 0-49 | 50-99 | 100-149 | 150-199 ] - /// ``` - /// This function needs to return starting index for chunk 1, with an offset of 50 bytes, and the end - /// index of chunk 2 in order to satisfy the range 150-250. 
- /// ``` - /// // let ranges = cas.get_range_boundaries(150, 250)?; - /// // ranges.compressed_range_start = 50 - /// // ranges.compressed_range_end = 150 - /// // ranges.uncompressed_offset = 50 - /// ``` - /// See [CasObject::get_range] for how these ranges are used. - pub fn get_range_boundaries( - &self, - start: u32, - end: u32, - ) -> Result { - if end < start { - return Err(CasObjectError::InvalidArguments); - } - - if end > self.get_contents_length()? { - return Err(CasObjectError::InvalidArguments); - } - - let chunk_size_info = &self.info.chunk_size_info; - - let mut compressed_range_start = u32::MAX; - let mut compressed_range_end = u32::MAX; - let mut uncompressed_offset = u32::MAX; - - // Enumerate all the chunks in order in the Xorb, but ignore the final one since that is a dummy chunk used to - // get the final byte index of the final content chunk. This allows the (idx + 1) to always be correct. - for (idx, c) in chunk_size_info[..chunk_size_info.len() - 1] - .iter() - .enumerate() - { - // Starting chunk is identified, store the start_byte_index of this chunk. - // compute the offset into the chunk if necessary by subtracting start range from end of - // previous chunk len (idx - 1). - if c.cumulative_uncompressed_len >= start && compressed_range_start == u32::MAX { - compressed_range_start = c.start_byte_index; - uncompressed_offset = if idx == 0 { - start - } else { - start - - chunk_size_info - .get(idx - 1) - .unwrap() - .cumulative_uncompressed_len - } - } - - // Once we find the 1st chunk (in-order) that meets the range query, we find the start_byte_index - // of the next chunk and capture that as compressed_range_end. This uses the dummy chunk entry - // to get the end of the final content chunk. - if c.cumulative_uncompressed_len >= end && compressed_range_end == u32::MAX { - compressed_range_end = chunk_size_info.get(idx + 1).unwrap().start_byte_index; - break; - } - } - - Ok(RangeBoundaryHelper { - compressed_range_start, - compressed_range_end, - uncompressed_offset, - }) - } - + /// Return end value of all chunk contents (byte index prior to header) pub fn get_contents_length(&self) -> Result { - match self.info.chunk_size_info.last() { - Some(c) => Ok(c.cumulative_uncompressed_len), + match self.info.chunk_byte_offsets.last() { + Some(c) => Ok(*c), None => Err(CasObjectError::FormatError(anyhow!( "Cannot retrieve content length" ))), } } - /// Get range of content bytes from Xorb + /// Get range of content bytes from Xorb. + /// + /// The start and end parameters are required to align with chunk boundaries. pub fn get_range( &self, reader: &mut R, @@ -380,10 +284,9 @@ impl CasObject { // let mut data = vec![0u8; (end - start) as usize]; // translate range into chunk bytes to read from xorb directly - let boundary = self.get_range_boundaries(start, end)?; - let chunk_start = boundary.compressed_range_start; - let chunk_end = boundary.compressed_range_end; - let offset = boundary.uncompressed_offset as usize; + let chunk_start = start; + let chunk_end = end; + let offset = 0; // read chunk bytes let mut chunk_data = vec![0u8; (chunk_end - chunk_start) as usize]; @@ -421,14 +324,11 @@ impl CasObject { self.get_range(reader, 0, self.get_contents_length()?) } - /// Helper function to translate CasObjectInfo.chunk_size_info to just return chunk_boundaries. + /// Helper function to translate CasObjectInfo.chunk_byte_offsets to just return chunk boundaries. 
/// - /// This isolates the weirdness about iterating through chunk_size_info and ignoring the final dummy entry. + /// This simplifies getting chunk boundaries by ignoring the dummy chunk in chunk_byte_offsets. fn get_chunk_boundaries(&self) -> Vec { - self.info.chunk_size_info.clone()[..self.info.chunk_size_info.len() - 1] - .iter() - .map(|c| c.cumulative_uncompressed_len) - .collect() + self.info.chunk_byte_offsets.clone()[..self.info.num_chunks as usize].to_vec() } /// Get all the content bytes from a Xorb, and return the chunk boundaries @@ -465,43 +365,37 @@ impl CasObject { let mut cas = CasObject::default(); cas.info.cashash.copy_from_slice(hash.as_slice()); cas.info.num_chunks = chunk_boundaries.len() as u32 + 1; // extra entry for dummy, see [chunk_size_info] for details. - cas.info.chunk_size_info = Vec::with_capacity(cas.info.num_chunks as usize); + cas.info.chunk_byte_offsets = Vec::with_capacity(cas.info.num_chunks as usize); + cas.info.chunk_hashes = Vec::with_capacity(cas.info.num_chunks as usize); let mut total_written_bytes: usize = 0; let mut raw_start_idx = 0; let mut start_idx: u32 = 0; - let mut cumulative_chunk_length: u32 = 0; for boundary in chunk_boundaries { let chunk_boundary: u32 = *boundary; let mut chunk_raw_bytes = Vec::::new(); chunk_raw_bytes .extend_from_slice(&data[raw_start_idx as usize..chunk_boundary as usize]); - let chunk_size = chunk_boundary - raw_start_idx; + + // generate chunk hash and store it + cas.info.chunk_byte_offsets.push(start_idx); + let chunk_hash = merklehash::compute_data_hash(&chunk_raw_bytes); + cas.info.chunk_hashes.push(chunk_hash); // now serialize chunk directly to writer (since chunks come first!) let chunk_written_bytes = serialize_chunk(&chunk_raw_bytes, writer, compression_scheme)?; total_written_bytes += chunk_written_bytes; - let chunk_meta = CasChunkInfo { - start_byte_index: start_idx, - cumulative_uncompressed_len: cumulative_chunk_length + chunk_size, - }; - cas.info.chunk_size_info.push(chunk_meta); - + // update indexes and onto next chunk start_idx += chunk_written_bytes as u32; raw_start_idx = chunk_boundary; - cumulative_chunk_length += chunk_size; } // dummy chunk_info to help with range reads. See [chunk_size_info] for details. - let chunk_meta = CasChunkInfo { - start_byte_index: start_idx, - cumulative_uncompressed_len: cumulative_chunk_length, - }; - cas.info.chunk_size_info.push(chunk_meta); + cas.info.chunk_byte_offsets.push(start_idx); // now that header is ready, write out to writer. let info_length = cas.info.serialize(writer)?; @@ -554,9 +448,9 @@ impl CasObject { let mut cumulative_uncompressed_length: u32 = 0; let mut cumulative_compressed_length: u32 = 0; - if let Some(c) = cas.info.chunk_size_info.first() { - if c.start_byte_index != 0 { - // for 1st chunk verify that its start_byte_index is 0 + if let Some(c) = cas.info.chunk_byte_offsets.first() { + if *c != 0 { + // for 1st chunk verify that its offset is 0 warn!("XORB Validation: Byte 0 does not contain 1st chunk."); return Ok(false); } @@ -564,6 +458,8 @@ impl CasObject { return Err(CasObjectError::FormatError(anyhow!("Invalid Xorb, no chunks"))); } + // TODO: iterate chunks, deseralize chunk for each one, compare stored hash with computed hash + for (idx, c) in cas.info.chunk_size_info[..cas.info.chunk_size_info.len() - 1].iter().enumerate() { // 3. 
verify on each chunk: @@ -637,7 +533,8 @@ mod tests { version: CAS_OBJECT_FORMAT_VERSION, cashash: DataHash::default(), num_chunks: 0, - chunk_size_info: Vec::new(), + chunk_byte_offsets: Vec::new(), + chunk_hashes: Vec::new(), _buffer: [0; 16], }; @@ -664,7 +561,7 @@ mod tests { assert_eq!(c.get_chunk_boundaries().len(), 3); assert_eq!(c.get_chunk_boundaries(), [100, 200, 300]); assert_eq!(c.info.num_chunks, 4); - assert_eq!(c.info.chunk_size_info.len(), c.info.num_chunks as usize); + assert_eq!(c.info.chunk_byte_offsets.len(), c.info.num_chunks as usize); let last_chunk_info = c.info.chunk_size_info[2].clone(); let dummy_chunk_info = c.info.chunk_size_info[3].clone(); From 682ff7a1dd6fe83d6d187626d01d1f0925a2ad6b Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Mon, 30 Sep 2024 13:52:14 -0700 Subject: [PATCH 05/19] Update CasObject format v3 - Removed CasChunkInfo - Now tracking chunk boundary offsets & chunk hashes --- cas_object/src/cas_object_format.rs | 289 +++++++++++----------------- 1 file changed, 113 insertions(+), 176 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index 35463a8c..fcf414fc 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -2,7 +2,6 @@ use bytes::Buf; use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; use merklehash::{DataHash, MerkleHash}; use tracing::warn; -use core::num; use std::{ cmp::min, io::{Cursor, Error, Read, Seek, Write}, @@ -35,15 +34,17 @@ pub struct CasObjectInfo { /// Total number of chunks in the Xorb. Length of chunk_byte_offset & chunk_hashes vectors. pub num_chunks: u32, - /// Byte offset marking the beginning of each chunk. Length of vector is num_chunks. + /// Byte offset marking the boundary of each chunk. Length of vector is num_chunks. /// - /// To find the end of a chunk chunk[n] last byte is chunk[n+1].chunk_byte_index-1. - /// The final entry in this vector is a dummy entry to know the final chunk ending byte. + /// This vector only contains boundaries, so assumes the first chunk starts at offset 0. + /// The final entry in vector is the total length of the chunks. + /// See example below. + /// chunk[n] offset = chunk_boundary_offsets[n-1] /// ``` - /// // ex. chunks: [ 0, 1, 2, 3] - /// // chunk_byte_offset: [ 0, 100, 200, 300, 400] <-- notice extra entry + /// // ex. chunks: [ 0, 1, 2, 3 ] + /// // chunk_boundary_offsets: [ 100, 200, 300, 400] /// ``` - pub chunk_byte_offsets: Vec, + pub chunk_boundary_offsets: Vec, /// Merklehash for each chunk stored in the Xorb. Length of vector is num_chunks. 
pub chunk_hashes: Vec, @@ -59,7 +60,7 @@ impl Default for CasObjectInfo { version: CAS_OBJECT_FORMAT_VERSION, cashash: DataHash::default(), num_chunks: 0, - chunk_byte_offsets: Vec::new(), + chunk_boundary_offsets: Vec::new(), chunk_hashes: Vec::new(), _buffer: Default::default(), } @@ -87,7 +88,7 @@ impl CasObjectInfo { write_bytes(&self.num_chunks.to_le_bytes())?; // write variable field: chunk offsets & hashes - for offset in &self.chunk_byte_offsets { + for offset in &self.chunk_boundary_offsets { write_bytes(&offset.to_le_bytes())?; } for hash in &self.chunk_hashes { @@ -150,13 +151,13 @@ impl CasObjectInfo { read_bytes(&mut num_chunks)?; let num_chunks = u32::from_le_bytes(num_chunks); - let mut chunk_byte_offsets = Vec::with_capacity(num_chunks as usize); + let mut chunk_boundary_offsets = Vec::with_capacity(num_chunks as usize); for _ in 0..num_chunks { let mut offset = [0u8; 4]; read_bytes(&mut offset)?; - chunk_byte_offsets.push(u32::from_le_bytes(offset)); + chunk_boundary_offsets.push(u32::from_le_bytes(offset)); } - let mut chunk_hashes = Vec::with_capacity(num_chunks as usize); + let mut chunk_hashes = Vec::with_capacity(num_chunks as usize); // dummy chunk for _ in 0..num_chunks { let mut hash = [0u8; 32]; read_bytes(&mut hash)?; @@ -179,7 +180,7 @@ impl CasObjectInfo { version: version[0], cashash, num_chunks, - chunk_byte_offsets, + chunk_boundary_offsets, chunk_hashes, _buffer, }, @@ -212,24 +213,6 @@ impl Default for CasObject { } } -/// Helper struct to capture 3-part tuple needed to -/// correctly support range reads across compressed chunks in a Xorb. -/// -/// See docs for [CasObject::get_range_boundaries] for example usage. -pub struct RangeBoundaryHelper { - /// Index for range start in compressed chunks. - /// Guaranteed to be start of a [CASChunkHeader]. - pub compressed_range_start: u32, - - /// Index for range end in compressed chunk. - /// Guaranteed to be end of chunk. - pub compressed_range_end: u32, - - /// Offset into uncompressed chunk. This is necessary for - /// range requests that do not align with chunk boundary. - pub uncompressed_offset: u32, -} - impl CasObject { /// Deserializes only the info length field of the footer to tell the user how many bytes /// make up the info portion of the xorb. 
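A short aside on the boundary-offset layout introduced above: since `chunk_boundary_offsets` stores only the end boundary of each serialized (possibly compressed) chunk, the byte range occupied by chunk `n` falls out directly from the vector. The helper below is a minimal sketch of that arithmetic, not code from this patch.

/// Illustrative sketch: byte range (start..end) of serialized chunk `n` within
/// the Xorb's chunk section. Offsets include the per-chunk headers; chunk 0
/// starts at byte 0, chunk `n` starts at the previous boundary, and the final
/// boundary equals the total serialized length of all chunks.
fn chunk_byte_range(chunk_boundary_offsets: &[u32], n: usize) -> (u32, u32) {
    let start = if n == 0 { 0 } else { chunk_boundary_offsets[n - 1] };
    let end = chunk_boundary_offsets[n];
    (start, end)
}

// For example, with boundaries [100, 200, 300, 400] as in the doc comment above,
// chunk_byte_range(&[100, 200, 300, 400], 2) == (200, 300).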
@@ -256,7 +239,7 @@ impl CasObject { /// Return end value of all chunk contents (byte index prior to header) pub fn get_contents_length(&self) -> Result { - match self.info.chunk_byte_offsets.last() { + match self.info.chunk_boundary_offsets.last() { Some(c) => Ok(*c), None => Err(CasObjectError::FormatError(anyhow!( "Cannot retrieve content length" @@ -280,24 +263,14 @@ impl CasObject { // make sure the end of the range is within the bounds of the xorb let end = min(end, self.get_contents_length()?); - // create return data bytes - // let mut data = vec![0u8; (end - start) as usize]; - - // translate range into chunk bytes to read from xorb directly - let chunk_start = start; - let chunk_end = end; - let offset = 0; - // read chunk bytes - let mut chunk_data = vec![0u8; (chunk_end - chunk_start) as usize]; - reader.seek(std::io::SeekFrom::Start(chunk_start as u64))?; + let mut chunk_data = vec![0u8; (end - start) as usize]; + reader.seek(std::io::SeekFrom::Start(start as u64))?; reader.read_exact(&mut chunk_data)?; // build up result vector by processing these chunks let chunk_contents = self.get_chunk_contents(&chunk_data)?; - let len = (end - start) as usize; - - Ok(chunk_contents[offset..offset + len].to_vec()) + Ok(chunk_contents) } /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec. @@ -317,7 +290,7 @@ impl CasObject { pub fn get_all_bytes(&self, reader: &mut R) -> Result, CasObjectError> { if self.info == Default::default() { return Err(CasObjectError::InternalError(anyhow!( - "Incomplete CasObject, no header" + "Incomplete CasObject, no CasObjectInfo footer." ))); } @@ -326,9 +299,9 @@ impl CasObject { /// Helper function to translate CasObjectInfo.chunk_byte_offsets to just return chunk boundaries. /// - /// This simplifies getting chunk boundaries by ignoring the dummy chunk in chunk_byte_offsets. + /// The final chunk boundary returned is required to be the length of the contents, which is recorded in the dummy chunk. fn get_chunk_boundaries(&self) -> Vec { - self.info.chunk_byte_offsets.clone()[..self.info.num_chunks as usize].to_vec() + self.info.chunk_boundary_offsets.to_vec() } /// Get all the content bytes from a Xorb, and return the chunk boundaries @@ -364,14 +337,13 @@ impl CasObject { let mut cas = CasObject::default(); cas.info.cashash.copy_from_slice(hash.as_slice()); - cas.info.num_chunks = chunk_boundaries.len() as u32 + 1; // extra entry for dummy, see [chunk_size_info] for details. 
- cas.info.chunk_byte_offsets = Vec::with_capacity(cas.info.num_chunks as usize); + cas.info.num_chunks = chunk_boundaries.len() as u32; + cas.info.chunk_boundary_offsets = Vec::with_capacity(cas.info.num_chunks as usize); cas.info.chunk_hashes = Vec::with_capacity(cas.info.num_chunks as usize); let mut total_written_bytes: usize = 0; let mut raw_start_idx = 0; - let mut start_idx: u32 = 0; for boundary in chunk_boundaries { let chunk_boundary: u32 = *boundary; @@ -380,7 +352,6 @@ impl CasObject { .extend_from_slice(&data[raw_start_idx as usize..chunk_boundary as usize]); // generate chunk hash and store it - cas.info.chunk_byte_offsets.push(start_idx); let chunk_hash = merklehash::compute_data_hash(&chunk_raw_bytes); cas.info.chunk_hashes.push(chunk_hash); @@ -388,15 +359,12 @@ impl CasObject { let chunk_written_bytes = serialize_chunk(&chunk_raw_bytes, writer, compression_scheme)?; total_written_bytes += chunk_written_bytes; + cas.info.chunk_boundary_offsets.push(total_written_bytes as u32); // update indexes and onto next chunk - start_idx += chunk_written_bytes as u32; raw_start_idx = chunk_boundary; } - // dummy chunk_info to help with range reads. See [chunk_size_info] for details. - cas.info.chunk_byte_offsets.push(start_idx); - // now that header is ready, write out to writer. let info_length = cas.info.serialize(writer)?; cas.info_length = info_length as u32; @@ -443,49 +411,43 @@ impl CasObject { // 1. deserialize to get Info let cas = CasObject::deserialize(reader)?; - // 2. walk chunks from Info (skip the final dummy chunk) + // 2. walk chunks from Info let mut hash_chunks: Vec = Vec::new(); - let mut cumulative_uncompressed_length: u32 = 0; let mut cumulative_compressed_length: u32 = 0; - if let Some(c) = cas.info.chunk_byte_offsets.first() { - if *c != 0 { - // for 1st chunk verify that its offset is 0 - warn!("XORB Validation: Byte 0 does not contain 1st chunk."); - return Ok(false); - } - } else { - return Err(CasObjectError::FormatError(anyhow!("Invalid Xorb, no chunks"))); - } + let mut start_offset = 0; + // Validate each chunk: iterate chunks, deserialize chunk, compare stored hash with + // computed hash, store chunk hashes for cashash validation + for idx in 0..cas.info.num_chunks { - // TODO: iterate chunks, deseralize chunk for each one, compare stored hash with computed hash - - for (idx, c) in cas.info.chunk_size_info[..cas.info.chunk_size_info.len() - 1].iter().enumerate() { - - // 3. verify on each chunk: - reader.seek(std::io::SeekFrom::Start(c.start_byte_index as u64))?; + // deserialize each chunk + reader.seek(std::io::SeekFrom::Start(start_offset as u64))?; let (data, compressed_chunk_length) = deserialize_chunk(reader)?; let chunk_uncompressed_length = data.len(); - - // 3a. compute hash - hash_chunks.push(Chunk {hash: merklehash::compute_data_hash(&data), length: chunk_uncompressed_length}); - cumulative_uncompressed_length += data.len() as u32; + let chunk_hash = merklehash::compute_data_hash(&data); + hash_chunks.push(Chunk {hash: chunk_hash, length: chunk_uncompressed_length}); + cumulative_compressed_length += compressed_chunk_length as u32; - - // 3b. verify deserialized chunk is expected size from Info object - if cumulative_uncompressed_length != c.cumulative_uncompressed_len { - warn!("XORB Validation: Chunk length does not match Info object."); + + // verify chunk hash + if *cas.info.chunk_hashes.get(idx as usize).unwrap() != chunk_hash { + warn!("XORB Validation: Chunk hash does not match Info object."); return Ok(false); } - // 3c. 
verify start byte index of next chunk matches current byte index + compressed length - if cas.info.chunk_size_info[idx+1].start_byte_index != (c.start_byte_index + compressed_chunk_length as u32) { - warn!("XORB Validation: Chunk start byte index does not match Info object."); + let boundary = *cas.info.chunk_boundary_offsets.get(idx as usize).unwrap(); + + // verify that cas.chunk[n].len + 1 == cas.chunk_boundary_offsets[n] + if (start_offset + compressed_chunk_length as u32) != boundary { + warn!("XORB Validation: Chunk boundary byte index does not match Info object."); return Ok(false); } - } + // set start offset of next chunk as the boundary of the current chunk + start_offset = boundary; + } + // validate that Info/footer begins immediately after final content xorb. // end of for loop completes the content chunks, now should be able to deserialize an Info directly let cur_position = reader.stream_position()? as u32; @@ -533,7 +495,7 @@ mod tests { version: CAS_OBJECT_FORMAT_VERSION, cashash: DataHash::default(), num_chunks: 0, - chunk_byte_offsets: Vec::new(), + chunk_boundary_offsets: Vec::new(), chunk_hashes: Vec::new(), _buffer: [0; 16], }; @@ -556,18 +518,17 @@ mod tests { #[test] fn test_chunk_boundaries_chunk_size_info() { // Arrange - let (c, _cas_data, _raw_data) = build_cas_object(3, 100, false, false); + let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = build_cas_object(3, 100, false, CompressionScheme::None); // Act & Assert assert_eq!(c.get_chunk_boundaries().len(), 3); - assert_eq!(c.get_chunk_boundaries(), [100, 200, 300]); - assert_eq!(c.info.num_chunks, 4); - assert_eq!(c.info.chunk_byte_offsets.len(), c.info.num_chunks as usize); - - let last_chunk_info = c.info.chunk_size_info[2].clone(); - let dummy_chunk_info = c.info.chunk_size_info[3].clone(); - assert_eq!(dummy_chunk_info.cumulative_uncompressed_len, 300); - assert_eq!(dummy_chunk_info.start_byte_index, 324); // 8-byte header, 3 chunks, so 4th chunk should start at byte 324 - assert_eq!(last_chunk_info.cumulative_uncompressed_len, 300); + assert_eq!(c.get_chunk_boundaries(), [108, 216, 324]); + assert_eq!(c.info.num_chunks, 3); + assert_eq!(c.info.chunk_boundary_offsets.len(), c.info.num_chunks as usize); + + let second_chunk_boundary = *c.info.chunk_boundary_offsets.get(1).unwrap(); + let third_chunk_boundary = *c.info.chunk_boundary_offsets.get(2).unwrap(); + assert_eq!(second_chunk_boundary, 216); // 8-byte header, 3 chunks, so 2nd chunk boundary is at byte 216 + assert_eq!(third_chunk_boundary, 324); // 8-byte header, 3 chunks, so 3rd chunk boundary is at byte 324 } fn gen_random_bytes(uncompressed_chunk_size: u32) -> Vec { @@ -577,22 +538,24 @@ mod tests { data } + /// Utility test method for creating a cas object + /// Returns (CasObject, CasObjectInfo serialized, raw data, raw data chunk boundaries) fn build_cas_object( num_chunks: u32, uncompressed_chunk_size: u32, use_random_chunk_size: bool, - use_lz4_compression: bool - ) -> (CasObject, Vec, Vec) { + compression_scheme: CompressionScheme, + ) -> (CasObject, Vec, Vec, Vec) { let mut c = CasObject::default(); - let mut chunk_size_info = Vec::::new(); + let mut chunk_boundary_offsets = Vec::::new(); + let mut chunk_hashes = Vec::::new(); let mut writer = Cursor::new(Vec::::new()); let mut total_bytes = 0; - let mut uncompressed_bytes: u32 = 0; - - let mut data_contents_raw = - Vec::::with_capacity(num_chunks as usize * uncompressed_chunk_size as usize); + let mut chunks: Vec = Vec::new(); + let mut data_contents_raw = Vec::::new(); + let mut 
raw_chunk_boundaries = Vec::::new(); for _idx in 0..num_chunks { let chunk_size: u32 = if use_random_chunk_size { @@ -603,17 +566,14 @@ mod tests { }; let bytes = gen_random_bytes(chunk_size); - let len: u32 = bytes.len() as u32; + + let chunk_hash = merklehash::compute_data_hash(&bytes); + chunks.push(Chunk { hash: chunk_hash, length: bytes.len() }); data_contents_raw.extend_from_slice(&bytes); // build chunk, create ChunkInfo and keep going - let compression_scheme = match use_lz4_compression { - true => CompressionScheme::LZ4, - false => CompressionScheme::None - }; - let bytes_written = serialize_chunk( &bytes, &mut writer, @@ -621,59 +581,36 @@ mod tests { ) .unwrap(); - let chunk_info = CasChunkInfo { - start_byte_index: total_bytes, - cumulative_uncompressed_len: uncompressed_bytes + len, - }; - - chunk_size_info.push(chunk_info); total_bytes += bytes_written as u32; - uncompressed_bytes += len; + + raw_chunk_boundaries.push(data_contents_raw.len() as u32); + chunk_boundary_offsets.push(total_bytes); + chunk_hashes.push(chunk_hash); } - let chunk_info = CasChunkInfo { - start_byte_index: total_bytes, - cumulative_uncompressed_len: uncompressed_bytes, - }; - chunk_size_info.push(chunk_info); + c.info.num_chunks = chunk_boundary_offsets.len() as u32; + c.info.chunk_boundary_offsets = chunk_boundary_offsets; + c.info.chunk_hashes = chunk_hashes; - c.info.num_chunks = chunk_size_info.len() as u32; - c.info.chunk_size_info = chunk_size_info; + let mut db = MerkleMemDB::default(); + let mut staging = db.start_insertion_staging(); + db.add_file(&mut staging, &chunks); + let ret = db.finalize(staging); - c.info.cashash = gen_hash(&data_contents_raw, &c.get_chunk_boundaries()); + c.info.cashash = *ret.hash(); // now serialize info to end Xorb length - let len = c.info.serialize(&mut writer).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let len = c.info.serialize(&mut buf).unwrap(); c.info_length = len as u32; - writer.write_all(&c.info_length.to_le_bytes()).unwrap(); - - (c, writer.get_ref().to_vec(), data_contents_raw) - } - - fn gen_hash(data: &[u8], chunk_boundaries: &[u32]) -> DataHash { - let mut chunks: Vec = Vec::new(); - let mut left_edge: usize = 0; - for i in chunk_boundaries { - let right_edge = *i as usize; - let hash = merklehash::compute_data_hash(&data[left_edge..right_edge]); - let length = right_edge - left_edge; - chunks.push(Chunk { hash, length }); - left_edge = right_edge; - } - - let mut db = MerkleMemDB::default(); - let mut staging = db.start_insertion_staging(); - db.add_file(&mut staging, &chunks); - let ret = db.finalize(staging); - *ret.hash() + (c, writer.get_ref().to_vec(), data_contents_raw, raw_chunk_boundaries) } #[test] fn test_compress_decompress() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(55, 53212, false, true); - let hash = gen_hash(&&raw_data, &c.get_chunk_boundaries()); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(55, 53212, false, CompressionScheme::LZ4); // Act & Assert let mut writer: Cursor> = Cursor::new(Vec::new()); @@ -681,7 +618,7 @@ mod tests { &mut writer, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::LZ4 ) .is_ok()); @@ -693,15 +630,13 @@ mod tests { let c = res.unwrap(); let c_bytes = c.get_all_bytes(&mut reader).unwrap(); - let c_boundaries = c.get_chunk_boundaries(); - let c_hash = gen_hash(&c_bytes, &c_boundaries); let mut writer: Cursor> = Cursor::new(Vec::new()); assert!(CasObject::serialize( &mut writer, - &c_hash, + 
&c.info.cashash, &c_bytes, - &c_boundaries, + &raw_chunk_boundaries, CompressionScheme::None ) .is_ok()); @@ -712,42 +647,44 @@ mod tests { assert!(res.is_ok()); let c2 = res.unwrap(); - assert_eq!(hash, c_hash); - assert_eq!(c.info.cashash, hash); assert_eq!(c2.info.cashash, c.info.cashash); + assert_eq!(c.get_all_bytes(&mut writer), c.get_all_bytes(&mut reader)); + assert!(CasObject::validate_cas_object(&mut reader, &c2.info.cashash).is_ok()); + assert!(CasObject::validate_cas_object(&mut writer, &c.info.cashash).is_ok()); } #[test] fn test_hash_generation_compression() { // Arrange - let (c, cas_data, raw_data) = build_cas_object(55, 53212, false, true); + let (c, cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(55, 53212, false, CompressionScheme::LZ4); // Act & Assert let mut buf: Cursor> = Cursor::new(Vec::new()); assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::LZ4 ) .is_ok()); - assert_eq!(c.info.cashash, gen_hash(&raw_data, &c.get_chunk_boundaries())); - assert_eq!(raw_data, c.get_all_bytes(&mut buf).unwrap()); - assert_eq!(&cas_data, buf.get_ref()); + let serialized_all_bytes = c.get_all_bytes(&mut buf).unwrap(); + + assert_eq!(raw_data, serialized_all_bytes); + assert_eq!(cas_data.len() as u32, c.get_contents_length().unwrap()); } #[test] fn test_basic_serialization_mem() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(3, 100, false, false); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(3, 100, false, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::None ) .is_ok()); @@ -758,14 +695,14 @@ mod tests { #[test] fn test_serialization_deserialization_mem_medium() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(32, 16384, false, false); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 16384, false, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::None ) .is_ok()); @@ -788,14 +725,14 @@ mod tests { #[test] fn test_serialization_deserialization_mem_large_random() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(32, 65536, true, false); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 65536, true, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::None ) .is_ok()); @@ -817,14 +754,14 @@ mod tests { #[test] fn test_serialization_deserialization_file_large_random() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(256, 65536, true, false); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(256, 65536, true, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::None ) .is_ok()); @@ -846,14 +783,14 @@ mod tests { #[test] fn test_basic_mem_lz4() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(1, 8, false, true); + 
let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(1, 8, false, CompressionScheme::LZ4); let mut writer: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut writer, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::LZ4 ) .is_ok()); @@ -874,14 +811,14 @@ mod tests { #[test] fn test_serialization_deserialization_mem_medium_lz4() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(32, 16384, false, true); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 16384, false, CompressionScheme::LZ4); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::LZ4 ) .is_ok()); @@ -904,14 +841,14 @@ mod tests { #[test] fn test_serialization_deserialization_mem_large_random_lz4() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(32, 65536, true, true); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 65536, true, CompressionScheme::LZ4); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut buf, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::LZ4 ) .is_ok()); @@ -933,14 +870,14 @@ mod tests { #[test] fn test_serialization_deserialization_file_large_random_lz4() { // Arrange - let (c, _cas_data, raw_data) = build_cas_object(256, 65536, true, true); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(256, 65536, true, CompressionScheme::LZ4); let mut writer: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( &mut writer, &c.info.cashash, &raw_data, - &c.get_chunk_boundaries(), + &raw_chunk_boundaries, CompressionScheme::LZ4 ) .is_ok()); From e48352810611fc6f54932e57745b7db8e72064e5 Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Mon, 30 Sep 2024 14:03:13 -0700 Subject: [PATCH 06/19] Updating comments --- cas_object/src/cas_object_format.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index fcf414fc..dd8cf1c7 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -18,9 +18,7 @@ const CAS_OBJECT_FORMAT_VERSION: u8 = 0; const CAS_OBJECT_INFO_DEFAULT_LENGTH: u32 = 60; #[derive(Clone, PartialEq, Eq, Debug)] -/// Info struct for [CasObject]. This is stored at the end of the XORB -/// -/// See details here: https://www.notion.so/huggingface2/Introduction-To-XetHub-Storage-Architecture-And-The-Integration-Path-54c3d14c682c4e41beab2364f273fc35?pvs=4#4ffa9b930a6942bd87f054714865375d +/// Info struct for [CasObject]. This is stored at the end of the XORB. pub struct CasObjectInfo { /// CAS identifier: "XETBLOB" pub ident: [u8; 7], @@ -31,7 +29,7 @@ pub struct CasObjectInfo { /// 256-bits, 16-bytes, The CAS Hash of this Xorb. pub cashash: MerkleHash, - /// Total number of chunks in the Xorb. Length of chunk_byte_offset & chunk_hashes vectors. + /// Total number of chunks in the Xorb. Length of chunk_boundary_offsets & chunk_hashes vectors. pub num_chunks: u32, /// Byte offset marking the boundary of each chunk. Length of vector is num_chunks. 
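To make the footer placement these comment updates describe concrete: the serialized `CasObjectInfo` sits at the end of the Xorb, followed by a little-endian `u32` holding its length, so a reader can locate the footer by working backwards from the end of the stream. The crate exposes this through `CasObject::deserialize` and `get_info_length`; the standalone sketch below only spells out the byte arithmetic and is not code from this patch.

use std::io::{Read, Seek, SeekFrom};

/// Illustrative sketch: locate the CasObjectInfo footer of a serialized Xorb.
/// Layout: [serialized chunks][CasObjectInfo (info_length bytes)][info_length: u32 LE].
fn footer_bounds<R: Read + Seek>(reader: &mut R) -> std::io::Result<(u64, u32)> {
    // Read the trailing u32 recording how many bytes the footer occupies.
    reader.seek(SeekFrom::End(-4))?;
    let mut len_bytes = [0u8; 4];
    reader.read_exact(&mut len_bytes)?;
    let info_length = u32::from_le_bytes(len_bytes);

    // The footer itself starts `info_length + 4` bytes before the end.
    let total_len = reader.seek(SeekFrom::End(0))?;
    Ok((total_len - info_length as u64 - 4, info_length))
}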
@@ -87,7 +85,7 @@ impl CasObjectInfo { write_bytes(self.cashash.as_bytes())?; write_bytes(&self.num_chunks.to_le_bytes())?; - // write variable field: chunk offsets & hashes + // write variable field: chunk boundaries & hashes for offset in &self.chunk_boundary_offsets { write_bytes(&offset.to_le_bytes())?; } @@ -157,7 +155,7 @@ impl CasObjectInfo { read_bytes(&mut offset)?; chunk_boundary_offsets.push(u32::from_le_bytes(offset)); } - let mut chunk_hashes = Vec::with_capacity(num_chunks as usize); // dummy chunk + let mut chunk_hashes = Vec::with_capacity(num_chunks as usize); for _ in 0..num_chunks { let mut hash = [0u8; 32]; read_bytes(&mut hash)?; @@ -192,7 +190,17 @@ impl CasObjectInfo { #[derive(Clone, PartialEq, Eq, Debug)] /// XORB: 16MB data block for storing chunks. /// -/// Has header, and a set of functions that interact directly with XORB. +/// Has Info footer, and a set of functions that interact directly with XORB. +/// +/// Physical layout of this object is as follows: +/// [START OF XORB] +/// +/// +/// <..> +/// +/// +/// CasObjectinfo length: u32 +/// [END OF XORB] pub struct CasObject { /// CasObjectInfo block see [CasObjectInfo] for details. pub info: CasObjectInfo, @@ -299,7 +307,7 @@ impl CasObject { /// Helper function to translate CasObjectInfo.chunk_byte_offsets to just return chunk boundaries. /// - /// The final chunk boundary returned is required to be the length of the contents, which is recorded in the dummy chunk. + /// The final chunk boundary returned is required to be the length of the contents. fn get_chunk_boundaries(&self) -> Vec { self.info.chunk_boundary_offsets.to_vec() } From 0d4f2e3d4c03b0b9ee4f1836012c5830e578b685 Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Mon, 30 Sep 2024 15:05:29 -0700 Subject: [PATCH 07/19] Reformatted code, add validation - Moved static functions before methods. - Moved pub functions & methods before non-public functions. - Added CasObjectInfo basic validation on object methods. --- cas_object/src/cas_object_format.rs | 234 +++++++++++++++------------- 1 file changed, 122 insertions(+), 112 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index dd8cf1c7..99cbec9f 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -222,6 +222,7 @@ impl Default for CasObject { } impl CasObject { + /// Deserializes only the info length field of the footer to tell the user how many bytes /// make up the info portion of the xorb. /// @@ -244,90 +245,6 @@ impl CasObject { let (info, info_length) = CasObjectInfo::deserialize(reader)?; Ok(Self { info, info_length }) } - - /// Return end value of all chunk contents (byte index prior to header) - pub fn get_contents_length(&self) -> Result { - match self.info.chunk_boundary_offsets.last() { - Some(c) => Ok(*c), - None => Err(CasObjectError::FormatError(anyhow!( - "Cannot retrieve content length" - ))), - } - } - - /// Get range of content bytes from Xorb. - /// - /// The start and end parameters are required to align with chunk boundaries. 
- pub fn get_range( - &self, - reader: &mut R, - start: u32, - end: u32, - ) -> Result, CasObjectError> { - if end < start { - return Err(CasObjectError::InvalidRange); - } - - // make sure the end of the range is within the bounds of the xorb - let end = min(end, self.get_contents_length()?); - - // read chunk bytes - let mut chunk_data = vec![0u8; (end - start) as usize]; - reader.seek(std::io::SeekFrom::Start(start as u64))?; - reader.read_exact(&mut chunk_data)?; - - // build up result vector by processing these chunks - let chunk_contents = self.get_chunk_contents(&chunk_data)?; - Ok(chunk_contents) - } - - /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec. - fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result, CasObjectError> { - // walk chunk_data, deserialize into Chunks, and then get_bytes() from each of them. - let mut reader = Cursor::new(chunk_data); - let mut res = Vec::::new(); - - while reader.has_remaining() { - let (data, _) = deserialize_chunk(&mut reader)?; - res.extend_from_slice(&data); - } - Ok(res) - } - - /// Get all the content bytes from a Xorb - pub fn get_all_bytes(&self, reader: &mut R) -> Result, CasObjectError> { - if self.info == Default::default() { - return Err(CasObjectError::InternalError(anyhow!( - "Incomplete CasObject, no CasObjectInfo footer." - ))); - } - - self.get_range(reader, 0, self.get_contents_length()?) - } - - /// Helper function to translate CasObjectInfo.chunk_byte_offsets to just return chunk boundaries. - /// - /// The final chunk boundary returned is required to be the length of the contents. - fn get_chunk_boundaries(&self) -> Vec { - self.info.chunk_boundary_offsets.to_vec() - } - - /// Get all the content bytes from a Xorb, and return the chunk boundaries - pub fn get_detailed_bytes( - &self, - reader: &mut R, - ) -> Result<(Vec, Vec), CasObjectError> { - if self.info == Default::default() { - return Err(CasObjectError::InternalError(anyhow!( - "Incomplete CasObject, no header" - ))); - } - - let data = self.get_all_bytes(reader)?; - let chunk_boundaries = self.get_chunk_boundaries(); - - Ok((chunk_boundaries, data)) - } /// Used by LocalClient for generating Cas Object from chunk_boundaries while uploading or downloading blocks. pub fn serialize( @@ -382,32 +299,7 @@ impl CasObject { total_written_bytes += size_of::(); Ok((cas, total_written_bytes)) - } - - fn validate_root_hash(data: &[u8], chunk_boundaries: &[u32], hash: &MerkleHash) -> bool { - // at least 1 chunk, and last entry in chunk boundary must match the length - if chunk_boundaries.is_empty() - || chunk_boundaries[chunk_boundaries.len() - 1] as usize != data.len() - { - return false; - } - - let mut chunks: Vec = Vec::new(); - let mut left_edge: usize = 0; - for i in chunk_boundaries { - let right_edge = *i as usize; - let hash = merklehash::compute_data_hash(&data[left_edge..right_edge]); - let length = right_edge - left_edge; - chunks.push(Chunk { hash, length }); - left_edge = right_edge; - } - - let mut db = MerkleMemDB::default(); - let mut staging = db.start_insertion_staging(); - db.add_file(&mut staging, &chunks); - let ret = db.finalize(staging); - *ret.hash() == *hash - } + } /// Validate CasObject. 
/// Verifies each chunk is valid and correctly represented in CasObjectInfo, along with @@ -481,6 +373,124 @@ impl CasObject { } + /// Return end value of all chunk contents (byte index prior to header) + pub fn get_contents_length(&self) -> Result { + self.validate_cas_object_info()?; + match self.info.chunk_boundary_offsets.last() { + Some(c) => Ok(*c), + None => Err(CasObjectError::FormatError(anyhow!( + "Cannot retrieve content length" + ))), + } + } + + /// Get range of content bytes from Xorb. + /// + /// The start and end parameters are required to align with chunk boundaries. + pub fn get_range( + &self, + reader: &mut R, + start: u32, + end: u32, + ) -> Result, CasObjectError> { + + if end < start { + return Err(CasObjectError::InvalidRange); + } + + self.validate_cas_object_info()?; + + // make sure the end of the range is within the bounds of the xorb + let end = min(end, self.get_contents_length()?); + + // read chunk bytes + let mut chunk_data = vec![0u8; (end - start) as usize]; + reader.seek(std::io::SeekFrom::Start(start as u64))?; + reader.read_exact(&mut chunk_data)?; + + // build up result vector by processing these chunks + let chunk_contents = self.get_chunk_contents(&chunk_data)?; + Ok(chunk_contents) + } + + /// Get all the content bytes from a Xorb + pub fn get_all_bytes(&self, reader: &mut R) -> Result, CasObjectError> { + self.validate_cas_object_info()?; + self.get_range(reader, 0, self.get_contents_length()?) + } + + /// Get all the content bytes from a Xorb, and return the chunk boundaries + pub fn get_detailed_bytes( + &self, + reader: &mut R, + ) -> Result<(Vec, Vec), CasObjectError> { + self.validate_cas_object_info()?; + + let data = self.get_all_bytes(reader)?; + let chunk_boundaries = self.get_chunk_boundaries()?; + + Ok((chunk_boundaries, data)) + } + + /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec. + fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result, CasObjectError> { + self.validate_cas_object_info()?; + + // walk chunk_data, deserialize into Chunks, and then get_bytes() from each of them. + let mut reader = Cursor::new(chunk_data); + let mut res = Vec::::new(); + + while reader.has_remaining() { + let (data, _) = deserialize_chunk(&mut reader)?; + res.extend_from_slice(&data); + } + Ok(res) + } + + /// Helper function to translate CasObjectInfo.chunk_boundary_offsets to just return chunk boundaries. + /// + /// The final chunk boundary returned is required to be the length of the contents. + fn get_chunk_boundaries(&self) -> Result, CasObjectError> { + self.validate_cas_object_info()?; + Ok(self.info.chunk_boundary_offsets.to_vec()) + } + + /// Helper method to verify that info object is complete + fn validate_cas_object_info(&self) -> Result<(), CasObjectError> { + if self.info == Default::default() { + return Err(CasObjectError::InternalError(anyhow!( + "Incomplete CasObject, no CasObjectInfo footer." + ))); + } + Ok(()) + } + + /// Helper method to validate root hash for data block. 
+ fn validate_root_hash(data: &[u8], chunk_boundaries: &[u32], hash: &MerkleHash) -> bool { + // at least 1 chunk, and last entry in chunk boundary must match the length + if chunk_boundaries.is_empty() + || chunk_boundaries[chunk_boundaries.len() - 1] as usize != data.len() + { + return false; + } + + let mut chunks: Vec = Vec::new(); + let mut left_edge: usize = 0; + for i in chunk_boundaries { + let right_edge = *i as usize; + let hash = merklehash::compute_data_hash(&data[left_edge..right_edge]); + let length = right_edge - left_edge; + chunks.push(Chunk { hash, length }); + left_edge = right_edge; + } + + let mut db = MerkleMemDB::default(); + let mut staging = db.start_insertion_staging(); + db.add_file(&mut staging, &chunks); + let ret = db.finalize(staging); + *ret.hash() == *hash + } + } #[cfg(test)] @@ -528,8 +538,8 @@ mod tests { // Arrange let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = build_cas_object(3, 100, false, CompressionScheme::None); // Act & Assert - assert_eq!(c.get_chunk_boundaries().len(), 3); - assert_eq!(c.get_chunk_boundaries(), [108, 216, 324]); + assert_eq!(c.get_chunk_boundaries().unwrap().len(), 3); + assert_eq!(c.get_chunk_boundaries().unwrap(), [108, 216, 324]); assert_eq!(c.info.num_chunks, 3); assert_eq!(c.info.chunk_boundary_offsets.len(), c.info.num_chunks as usize); From 951a464bd41aa5075b04414c8abada3828e2be71 Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Mon, 30 Sep 2024 17:45:54 -0700 Subject: [PATCH 08/19] Basic generate_chunk_range_hash - includes basic unit-test - created validate_cas_object_info method - todos for additional tests to write --- cas_object/src/cas_object_format.rs | 223 ++++++++++++++++++++++------ 1 file changed, 174 insertions(+), 49 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index 99cbec9f..cef0f8a2 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -1,15 +1,17 @@ use bytes::Buf; use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; use merklehash::{DataHash, MerkleHash}; -use tracing::warn; use std::{ - cmp::min, + cmp::{max, min}, io::{Cursor, Error, Read, Seek, Write}, mem::size_of, }; +use tracing::warn; use crate::{ - cas_chunk_format::{deserialize_chunk, serialize_chunk}, error::CasObjectError, CompressionScheme + cas_chunk_format::{deserialize_chunk, serialize_chunk}, + error::CasObjectError, + CompressionScheme, }; use anyhow::anyhow; @@ -26,14 +28,14 @@ pub struct CasObjectInfo { /// Format version, expected to be 0 right now. pub version: u8, - /// 256-bits, 16-bytes, The CAS Hash of this Xorb. + /// 256-bits, 32-bytes, The CAS Hash of this Xorb. pub cashash: MerkleHash, /// Total number of chunks in the Xorb. Length of chunk_boundary_offsets & chunk_hashes vectors. pub num_chunks: u32, /// Byte offset marking the boundary of each chunk. Length of vector is num_chunks. - /// + /// /// This vector only contains boundaries, so assumes the first chunk starts at offset 0. /// The final entry in vector is the total length of the chunks. /// See example below. @@ -191,7 +193,7 @@ impl CasObjectInfo { /// XORB: 16MB data block for storing chunks. /// /// Has Info footer, and a set of functions that interact directly with XORB. 
-/// +/// /// Physical layout of this object is as follows: /// [START OF XORB] /// @@ -222,7 +224,6 @@ impl Default for CasObject { } impl CasObject { - /// Deserializes only the info length field of the footer to tell the user how many bytes /// make up the info portion of the xorb. /// @@ -254,7 +255,6 @@ impl CasObject { chunk_boundaries: &Vec, compression_scheme: CompressionScheme, ) -> Result<(Self, usize), CasObjectError> { - // validate hash against contents if !Self::validate_root_hash(data, chunk_boundaries, hash) { return Err(CasObjectError::HashMismatch); @@ -284,7 +284,9 @@ impl CasObject { let chunk_written_bytes = serialize_chunk(&chunk_raw_bytes, writer, compression_scheme)?; total_written_bytes += chunk_written_bytes; - cas.info.chunk_boundary_offsets.push(total_written_bytes as u32); + cas.info + .chunk_boundary_offsets + .push(total_written_bytes as u32); // update indexes and onto next chunk raw_start_idx = chunk_boundary; @@ -299,15 +301,17 @@ impl CasObject { total_written_bytes += size_of::(); Ok((cas, total_written_bytes)) - } + } /// Validate CasObject. - /// Verifies each chunk is valid and correctly represented in CasObjectInfo, along with + /// Verifies each chunk is valid and correctly represented in CasObjectInfo, along with /// recomputing the hash and validating it matches CasObjectInfo. - /// + /// /// Returns Ok(true) if recomputed hash matches what is passed in. - pub fn validate_cas_object(reader: &mut R, hash: &MerkleHash) -> Result { - + pub fn validate_cas_object( + reader: &mut R, + hash: &MerkleHash, + ) -> Result { // 1. deserialize to get Info let cas = CasObject::deserialize(reader)?; @@ -316,20 +320,22 @@ impl CasObject { let mut cumulative_compressed_length: u32 = 0; let mut start_offset = 0; - // Validate each chunk: iterate chunks, deserialize chunk, compare stored hash with + // Validate each chunk: iterate chunks, deserialize chunk, compare stored hash with // computed hash, store chunk hashes for cashash validation for idx in 0..cas.info.num_chunks { - - // deserialize each chunk + // deserialize each chunk reader.seek(std::io::SeekFrom::Start(start_offset as u64))?; let (data, compressed_chunk_length) = deserialize_chunk(reader)?; let chunk_uncompressed_length = data.len(); let chunk_hash = merklehash::compute_data_hash(&data); - hash_chunks.push(Chunk {hash: chunk_hash, length: chunk_uncompressed_length}); - + hash_chunks.push(Chunk { + hash: chunk_hash, + length: chunk_uncompressed_length, + }); + cumulative_compressed_length += compressed_chunk_length as u32; - + // verify chunk hash if *cas.info.chunk_hashes.get(idx as usize).unwrap() != chunk_hash { warn!("XORB Validation: Chunk hash does not match Info object."); @@ -347,12 +353,14 @@ impl CasObject { // set start offset of next chunk as the boundary of the current chunk start_offset = boundary; } - + // validate that Info/footer begins immediately after final content xorb. // end of for loop completes the content chunks, now should be able to deserialize an Info directly let cur_position = reader.stream_position()? as u32; let expected_position = cumulative_compressed_length; - let expected_from_end_position = reader.seek(std::io::SeekFrom::End(0))? as u32 - cas.info_length - size_of::() as u32; + let expected_from_end_position = reader.seek(std::io::SeekFrom::End(0))? 
as u32 + - cas.info_length + - size_of::() as u32; if cur_position != expected_position || cur_position != expected_from_end_position { warn!("XORB Validation: Content bytes after known chunks in Info object."); return Ok(false); @@ -370,7 +378,60 @@ impl CasObject { } Ok(true) + } + + /// Generate a hash for securing a chunk range. + /// + /// chunk_start_index, chunk_end_index: byte indices for chunks in CasObject. + /// key: additional key incorporated into generating hash. + /// + /// This hash ensures validity of the knowledge of chunks, since ranges are public, + /// this ensures that only users that actually have access to chunks can request them. + pub fn generate_chunk_range_hash( + &self, + chunk_start_index: u32, + chunk_end_index: u32, + key: &[u8], + ) -> Result { + self.validate_cas_object_info()?; + + if chunk_end_index < chunk_start_index + || self.get_contents_length()? > max(chunk_end_index, chunk_start_index) + { + return Err(CasObjectError::InvalidArguments); + } + + // walk chunk boundaries and collect relevant hashes + let mut range_hashes = Vec::::new(); + let mut found_start = chunk_start_index == 0; + let mut found_end = false; + for (idx, boundary) in self.info.chunk_boundary_offsets.iter().enumerate() { + let boundary = *boundary; + if found_start || chunk_start_index == boundary { + found_start = true; + let chunk_hash = self.info.chunk_hashes.get(idx).unwrap(); + range_hashes.push(*chunk_hash); + } + + // if found end then exit loop early + if chunk_end_index == boundary { + found_end = true; + break; + } + } + + if !found_start || !found_end { + return Err(CasObjectError::InternalError(anyhow!("Chunk Range Invalid"))) + } + + // TODO: Make this more robust, currently appends range hashes together, adds key to end + let mut combined : Vec = range_hashes.iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); + combined.extend_from_slice(key); + + // now hash the hashes + key and return + let range_hash = merklehash::compute_data_hash(&combined); + Ok(range_hash) } /// Return end value of all chunk contents (byte index prior to header) @@ -385,7 +446,7 @@ impl CasObject { } /// Get range of content bytes from Xorb. - /// + /// /// The start and end parameters are required to align with chunk boundaries. pub fn get_range( &self, @@ -393,11 +454,10 @@ impl CasObject { start: u32, end: u32, ) -> Result, CasObjectError> { - if end < start { return Err(CasObjectError::InvalidRange); } - + self.validate_cas_object_info()?; // make sure the end of the range is within the bounds of the xorb @@ -462,10 +522,31 @@ impl CasObject { "Incomplete CasObject, no CasObjectInfo footer." ))); } + + if self.info.num_chunks == 0 { + return Err(CasObjectError::FormatError(anyhow!( + "Invalid CasObjectInfo, no chunks in CasObject." + ))); + } + + if self.info.num_chunks != self.info.chunk_boundary_offsets.len() as u32 + || self.info.num_chunks != self.info.chunk_hashes.len() as u32 + { + return Err(CasObjectError::FormatError(anyhow!( + "Invalid CasObjectInfo, num chunks not matching boundaries or hashes." + ))); + } + + if self.info.cashash == DataHash::default() { + return Err(CasObjectError::FormatError(anyhow!( + "Invalid CasObjectInfo, Missing cashash." + ))); + } + Ok(()) } - /// Helper method to validate root hash for data block. + /// Helper method to validate root hash for data block. 
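To make the range-hash construction above concrete: the selected chunk hashes are laid out back to back as raw bytes, the key is appended, and the concatenation is hashed once more. A minimal sketch, assuming merklehash::compute_data_hash and MerkleHash::as_bytes as already used in this file; the standalone helper name is illustrative only.

use merklehash::{compute_data_hash, MerkleHash};

// Illustrative only: mirrors the combination step of generate_chunk_range_hash.
fn combine_range_hashes(range_hashes: &[MerkleHash], key: &[u8]) -> MerkleHash {
    // Concatenate the raw bytes of each chunk hash in the requested range ...
    let mut combined: Vec<u8> = range_hashes
        .iter()
        .flat_map(|h| h.as_bytes().to_vec())
        .collect();
    // ... append the caller-supplied key ...
    combined.extend_from_slice(key);
    // ... and hash the result to get the range hash.
    compute_data_hash(&combined)
}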
fn validate_root_hash(data: &[u8], chunk_boundaries: &[u32], hash: &MerkleHash) -> bool { // at least 1 chunk, and last entry in chunk boundary must match the length if chunk_boundaries.is_empty() @@ -490,7 +571,6 @@ impl CasObject { let ret = db.finalize(staging); *ret.hash() == *hash } - } #[cfg(test)] @@ -536,16 +616,20 @@ mod tests { #[test] fn test_chunk_boundaries_chunk_size_info() { // Arrange - let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = build_cas_object(3, 100, false, CompressionScheme::None); + let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = + build_cas_object(3, 100, false, CompressionScheme::None); // Act & Assert assert_eq!(c.get_chunk_boundaries().unwrap().len(), 3); assert_eq!(c.get_chunk_boundaries().unwrap(), [108, 216, 324]); assert_eq!(c.info.num_chunks, 3); - assert_eq!(c.info.chunk_boundary_offsets.len(), c.info.num_chunks as usize); + assert_eq!( + c.info.chunk_boundary_offsets.len(), + c.info.num_chunks as usize + ); let second_chunk_boundary = *c.info.chunk_boundary_offsets.get(1).unwrap(); let third_chunk_boundary = *c.info.chunk_boundary_offsets.get(2).unwrap(); - assert_eq!(second_chunk_boundary, 216); // 8-byte header, 3 chunks, so 2nd chunk boundary is at byte 216 + assert_eq!(second_chunk_boundary, 216); // 8-byte header, 3 chunks, so 2nd chunk boundary is at byte 216 assert_eq!(third_chunk_boundary, 324); // 8-byte header, 3 chunks, so 3rd chunk boundary is at byte 324 } @@ -556,6 +640,34 @@ mod tests { data } + #[test] + fn test_generate_range_hash_full_range() { + // Arrange + let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = + build_cas_object(3, 100, false, CompressionScheme::None); + let key = [b'K', b'E', b'Y']; + + let mut hashes : Vec = c.info.chunk_hashes.iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); + hashes.extend_from_slice(&key); + let expected_hash = merklehash::compute_data_hash(&hashes); + + // Act & Assert + let range_hash = c.generate_chunk_range_hash(0, 324, &key).unwrap(); + assert_eq!(range_hash, expected_hash); + } + + #[ignore = "not written yet"] + #[test] + fn test_generate_range_hash_partial() { + todo!() + } + + #[ignore = "Not written yet"] + #[test] + fn test_validate_cas_object_info() { + todo!() + } + /// Utility test method for creating a cas object /// Returns (CasObject, CasObjectInfo serialized, raw data, raw data chunk boundaries) fn build_cas_object( @@ -586,18 +698,16 @@ mod tests { let bytes = gen_random_bytes(chunk_size); let chunk_hash = merklehash::compute_data_hash(&bytes); - chunks.push(Chunk { hash: chunk_hash, length: bytes.len() }); + chunks.push(Chunk { + hash: chunk_hash, + length: bytes.len(), + }); data_contents_raw.extend_from_slice(&bytes); // build chunk, create ChunkInfo and keep going - let bytes_written = serialize_chunk( - &bytes, - &mut writer, - compression_scheme, - ) - .unwrap(); + let bytes_written = serialize_chunk(&bytes, &mut writer, compression_scheme).unwrap(); total_bytes += bytes_written as u32; @@ -622,13 +732,19 @@ mod tests { let len = c.info.serialize(&mut buf).unwrap(); c.info_length = len as u32; - (c, writer.get_ref().to_vec(), data_contents_raw, raw_chunk_boundaries) + ( + c, + writer.get_ref().to_vec(), + data_contents_raw, + raw_chunk_boundaries, + ) } #[test] fn test_compress_decompress() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(55, 53212, false, CompressionScheme::LZ4); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(55, 53212, false, CompressionScheme::LZ4); // Act & 
Assert let mut writer: Cursor> = Cursor::new(Vec::new()); @@ -674,7 +790,8 @@ mod tests { #[test] fn test_hash_generation_compression() { // Arrange - let (c, cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(55, 53212, false, CompressionScheme::LZ4); + let (c, cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(55, 53212, false, CompressionScheme::LZ4); // Act & Assert let mut buf: Cursor> = Cursor::new(Vec::new()); assert!(CasObject::serialize( @@ -695,7 +812,8 @@ mod tests { #[test] fn test_basic_serialization_mem() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(3, 100, false, CompressionScheme::None); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(3, 100, false, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -713,7 +831,8 @@ mod tests { #[test] fn test_serialization_deserialization_mem_medium() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 16384, false, CompressionScheme::None); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(32, 16384, false, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -743,7 +862,8 @@ mod tests { #[test] fn test_serialization_deserialization_mem_large_random() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 65536, true, CompressionScheme::None); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(32, 65536, true, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -772,7 +892,8 @@ mod tests { #[test] fn test_serialization_deserialization_file_large_random() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(256, 65536, true, CompressionScheme::None); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(256, 65536, true, CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -801,7 +922,8 @@ mod tests { #[test] fn test_basic_mem_lz4() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(1, 8, false, CompressionScheme::LZ4); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(1, 8, false, CompressionScheme::LZ4); let mut writer: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -825,11 +947,12 @@ mod tests { assert_eq!(c.info.num_chunks, c2.info.num_chunks); assert_eq!(raw_data, bytes_read); } - + #[test] fn test_serialization_deserialization_mem_medium_lz4() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 16384, false, CompressionScheme::LZ4); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(32, 16384, false, CompressionScheme::LZ4); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -859,7 +982,8 @@ mod tests { #[test] fn test_serialization_deserialization_mem_large_random_lz4() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(32, 65536, true, CompressionScheme::LZ4); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(32, 65536, true, CompressionScheme::LZ4); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ 
-888,7 +1012,8 @@ mod tests { #[test] fn test_serialization_deserialization_file_large_random_lz4() { // Arrange - let (c, _cas_data, raw_data, raw_chunk_boundaries) = build_cas_object(256, 65536, true, CompressionScheme::LZ4); + let (c, _cas_data, raw_data, raw_chunk_boundaries) = + build_cas_object(256, 65536, true, CompressionScheme::LZ4); let mut writer: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( From 8b5f344e7c6a1da1088b18b84583195f7cf16c4f Mon Sep 17 00:00:00 2001 From: seanses Date: Tue, 1 Oct 2024 10:07:33 -0700 Subject: [PATCH 09/19] more updates and api changes --- Cargo.lock | 3 +- cas_client/src/caching_client.rs | 4 + cas_client/src/remote_client.rs | 5 +- cas_types/Cargo.toml | 3 +- cas_types/src/error.rs | 8 ++ cas_types/src/key.rs | 26 +++++- cas_types/src/lib.rs | 6 +- data/src/bin/example.rs | 33 +++----- data/src/clean.rs | 50 +++++------ data/src/data_processing.rs | 44 +++++----- data/src/remote_shard_interface.rs | 94 +++------------------ data/src/shard_interface.rs | 9 +- hf_xet/Cargo.lock | 77 +---------------- hf_xet/src/config.rs | 2 +- hf_xet/src/data_client.rs | 99 ++++++++++++---------- mdb_shard/src/file_structs.rs | 7 +- mdb_shard/src/shard_format.rs | 14 ++-- merkledb/src/aggregate_hashes.rs | 4 +- shard_client/Cargo.toml | 2 +- shard_client/src/error.rs | 6 ++ shard_client/src/http_shard_client.rs | 115 ++++++++++++++++++++++---- 21 files changed, 302 insertions(+), 309 deletions(-) create mode 100644 cas_types/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index a0a451f9..48106afc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -502,6 +502,7 @@ dependencies = [ "merklehash", "serde", "serde_repr", + "xet_error", ] [[package]] @@ -3680,6 +3681,7 @@ dependencies = [ "cas_client", "cas_types", "clap 2.34.0", + "file_utils", "heed", "http 0.2.12", "itertools 0.10.5", @@ -3701,7 +3703,6 @@ dependencies = [ "tracing", "tracing-opentelemetry", "url", - "utils", "uuid", "xet_error", ] diff --git a/cas_client/src/caching_client.rs b/cas_client/src/caching_client.rs index 62085d89..d61eba01 100644 --- a/cas_client/src/caching_client.rs +++ b/cas_client/src/caching_client.rs @@ -1,3 +1,5 @@ +#![allow(unused_variables)] + use crate::error::Result; use crate::interface::*; use async_trait::async_trait; @@ -44,3 +46,5 @@ impl ReconstructionClient for CachingClient { todo!() } } + +impl Client for CachingClient {} diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 9ee9a15a..ad56054c 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -87,6 +87,7 @@ impl ReconstructionClient for RemoteClient { Ok(()) } + #[allow(unused_variables)] async fn get_file_byte_range( &self, hash: &MerkleHash, @@ -145,11 +146,11 @@ impl RemoteClient { Ok(response_parsed.was_inserted) } - async fn reconstruct( + async fn reconstruct( &self, reconstruction_response: QueryReconstructionResponse, _byte_range: Option<(u64, u64)>, - writer: &mut W, + writer: &mut Box, ) -> Result { let info = reconstruction_response.reconstruction; let total_len = info.iter().fold(0, |acc, x| acc + x.unpacked_length); diff --git a/cas_types/Cargo.toml b/cas_types/Cargo.toml index f74b8fda..b055fe94 100644 --- a/cas_types/Cargo.toml +++ b/cas_types/Cargo.toml @@ -4,7 +4,8 @@ version = "0.1.0" edition = "2021" [dependencies] +merklehash = { path = "../merklehash" } +xet_error = { path = "../xet_error"} anyhow = "1.0.86" serde = { version = "1.0.208", features = ["derive"] } -merklehash = { path = "../merklehash" } 
serde_repr = "0.1.19" diff --git a/cas_types/src/error.rs b/cas_types/src/error.rs new file mode 100644 index 00000000..9c6ac19c --- /dev/null +++ b/cas_types/src/error.rs @@ -0,0 +1,8 @@ +use xet_error::Error; + +#[non_exhaustive] +#[derive(Error, Debug)] +pub enum CasTypesError { + #[error("Invalid key: {0}")] + InvalidKey(String), +} diff --git a/cas_types/src/key.rs b/cas_types/src/key.rs index 34d7298a..daae2501 100644 --- a/cas_types/src/key.rs +++ b/cas_types/src/key.rs @@ -1,7 +1,10 @@ -use std::fmt::{Display, Formatter}; - +use crate::error::CasTypesError; use merklehash::MerkleHash; use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Display, Formatter}, + str::FromStr, +}; /// A Key indicates a prefixed merkle hash for some data stored in the CAS DB. #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Ord, PartialOrd, Eq, Hash, Clone)] @@ -16,6 +19,25 @@ impl Display for Key { } } +impl FromStr for Key { + type Err = CasTypesError; + + fn from_str(s: &str) -> Result { + let parts = s.rsplit_once('/'); + let Some((prefix, hash)) = parts else { + return Err(CasTypesError::InvalidKey(s.to_owned())); + }; + + let hash = + MerkleHash::from_hex(hash).map_err(|_| CasTypesError::InvalidKey(s.to_owned()))?; + + Ok(Key { + prefix: prefix.to_owned(), + hash, + }) + } +} + mod hex { pub mod serde { use merklehash::MerkleHash; diff --git a/cas_types/src/lib.rs b/cas_types/src/lib.rs index 22d14f65..0b88ccf0 100644 --- a/cas_types/src/lib.rs +++ b/cas_types/src/lib.rs @@ -1,8 +1,8 @@ -use serde_repr::{Deserialize_repr, Serialize_repr}; - use merklehash::MerkleHash; use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +mod error; mod key; pub use key::*; @@ -21,9 +21,11 @@ pub struct Range { pub struct CASReconstructionTerm { pub hash: HexMerkleHash, pub unpacked_length: u32, + // chunk index start and end in a xorb pub range: Range, pub range_start_offset: u32, pub url: String, + // byte index start and end in a xorb pub url_range: Range, } diff --git a/data/src/bin/example.rs b/data/src/bin/example.rs index 43e9b195..160e0c1d 100644 --- a/data/src/bin/example.rs +++ b/data/src/bin/example.rs @@ -1,6 +1,5 @@ use anyhow::Result; use clap::{Args, Parser, Subcommand}; -use data::DEFAULT_BLOCK_SIZE; use data::{configurations::*, SMALL_FILE_THRESHOLD}; use data::{PointerFile, PointerFileTranslator}; use std::env::current_dir; @@ -73,22 +72,18 @@ fn default_clean_config() -> Result { file_query_policy: Default::default(), cas_storage_config: StorageConfig { endpoint: Endpoint::FileSystem(path.join("xorbs")), - auth: Auth { - token: None, - }, + auth: Auth { token: None }, prefix: "default".into(), cache_config: Some(CacheConfig { cache_directory: path.join("cache"), cache_size: 10 * 1024 * 1024 * 1024, // 10 GiB - cache_blocksize: DEFAULT_BLOCK_SIZE, + cache_blocksize: 0, // ignored }), staging_directory: None, }, shard_storage_config: StorageConfig { endpoint: Endpoint::FileSystem(path.join("xorbs")), - auth: Auth { - token: None, - }, + auth: Auth { token: None }, prefix: "default-merkledb".into(), cache_config: Some(CacheConfig { cache_directory: path.join("shard-cache"), @@ -120,22 +115,18 @@ fn default_smudge_config() -> Result { file_query_policy: Default::default(), cas_storage_config: StorageConfig { endpoint: Endpoint::FileSystem(path.join("xorbs")), - auth: Auth { - token: None, - }, + auth: Auth { token: None }, prefix: "default".into(), cache_config: Some(CacheConfig { cache_directory: path.join("cache"), cache_size: 10 * 1024 * 1024 * 
1024, // 10 GiB - cache_blocksize: DEFAULT_BLOCK_SIZE, + cache_blocksize: 0, // ignored }), staging_directory: None, }, shard_storage_config: StorageConfig { endpoint: Endpoint::FileSystem(path.join("xorbs")), - auth: Auth { - token: None, - }, + auth: Auth { token: None }, prefix: "default-merkledb".into(), cache_config: Some(CacheConfig { cache_directory: path.join("shard-cache"), @@ -203,20 +194,22 @@ async fn smudge_file(arg: &SmudgeArg) -> Result<()> { Some(path) => Box::new(File::open(path)?), None => Box::new(std::io::stdin()), }; - let writer = BufWriter::new( + let mut writer: Box = Box::new(BufWriter::new( File::options() .create(true) .write(true) .truncate(true) .open(&arg.dest)?, - ); + )); - smudge(reader, writer).await?; + smudge(reader, &mut writer).await?; + + writer.flush()?; Ok(()) } -async fn smudge(mut reader: impl Read, mut writer: impl Write) -> Result<()> { +async fn smudge(mut reader: impl Read, writer: &mut Box) -> Result<()> { let mut input = String::new(); reader.read_to_string(&mut input)?; @@ -230,7 +223,7 @@ async fn smudge(mut reader: impl Read, mut writer: impl Write) -> Result<()> { let translator = PointerFileTranslator::new(default_smudge_config()?).await?; translator - .smudge_file_from_pointer(&pointer_file, &mut writer, None, None, None) + .smudge_file_from_pointer(&pointer_file, writer, None) .await?; Ok(()) diff --git a/data/src/clean.rs b/data/src/clean.rs index 52350b32..364add8d 100644 --- a/data/src/clean.rs +++ b/data/src/clean.rs @@ -9,7 +9,6 @@ use crate::remote_shard_interface::RemoteShardInterface; use crate::repo_salt::RepoSalt; use crate::small_file_determination::{is_file_passthrough, is_possible_start_to_text_file}; use crate::PointerFile; - use lazy_static::lazy_static; use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo}; use mdb_shard::shard_file_reconstructor::FileReconstructor; @@ -323,8 +322,8 @@ impl Cleaner { }; // Okay, we have something, so go ahead and download it in the background. - debug!("global dedup: {file_name:?} deduplicated by shard {shard_hash}; downloading."); - let Ok(_) = remote_shards.download_and_register_shard(&shard_hash).await.map_err(|e| { + debug!("global dedup: {file_name:?} deduplicated by shard {shard_hash}; registering."); + let Ok(_) = remote_shards.register_local_shard(&shard_hash).await.map_err(|e| { warn!("Error encountered attempting to download and register shard {shard_hash} for deduplication : {e:?}; ignoring."); e }) @@ -386,13 +385,13 @@ impl Cleaner { // start a new entry? if !tracking_info.file_info.is_empty() && tracking_info.file_info.last().unwrap().cas_hash == fse.cas_hash - && tracking_info.file_info.last().unwrap().chunk_byte_range_end - == fse.chunk_byte_range_start + && tracking_info.file_info.last().unwrap().chunk_index_end + == fse.chunk_index_start { // This block is the contiguous continuation of the last entry let last_entry = tracking_info.file_info.last_mut().unwrap(); last_entry.unpacked_segment_bytes += n_bytes as u32; - last_entry.chunk_byte_range_end += n_bytes as u32; + last_entry.chunk_index_end = fse.chunk_index_end; } else { // This block is new tracking_info.file_info.push(fse); @@ -409,8 +408,7 @@ impl Cleaner { let add_new_data; if let Some(idx) = tracking_info.current_cas_block_hashes.get(&chunk.hash) { - let (_, (data_lb, data_ub)) = tracking_info.cas_data.chunks[*idx]; - + let idx = *idx; // This chunk will get the CAS hash updated when the local CAS block // is full and registered. 
let file_info_len = tracking_info.file_info.len(); @@ -421,20 +419,20 @@ impl Cleaner { tracking_info.file_info.push(FileDataSequenceEntry::new( MerkleHash::default(), n_bytes, - data_lb, - data_ub, + idx, + idx + 1, )); add_new_data = false; } else if !tracking_info.file_info.is_empty() && tracking_info.file_info.last().unwrap().cas_hash == MerkleHash::default() - && tracking_info.file_info.last().unwrap().chunk_byte_range_end as usize - == tracking_info.cas_data.data.len() + && tracking_info.file_info.last().unwrap().chunk_index_end as usize + == tracking_info.cas_data.chunks.len() { - // This is the next chunk in the CAS block - // we're building, in which case we can just modify the previous entry. + // This is the next chunk in the CAS block we're building, + // in which case we can just modify the previous entry. let last_entry = tracking_info.file_info.last_mut().unwrap(); last_entry.unpacked_segment_bytes += n_bytes as u32; - last_entry.chunk_byte_range_end += n_bytes as u32; + last_entry.chunk_index_end += 1; add_new_data = true; } else { // This block is unrelated to the previous one. @@ -445,12 +443,12 @@ impl Cleaner { .current_cas_file_info_indices .push(file_info_len); - let cas_data_len = tracking_info.cas_data.data.len(); + let chunk_len = tracking_info.cas_data.chunks.len(); tracking_info.file_info.push(FileDataSequenceEntry::new( MerkleHash::default(), n_bytes, - cas_data_len, - cas_data_len + n_bytes, + chunk_len, + chunk_len + 1, )); add_new_data = true; } @@ -461,12 +459,7 @@ impl Cleaner { tracking_info .current_cas_block_hashes .insert(chunk.hash, cas_data_chunks_len); - - let cas_data_len = tracking_info.cas_data.data.len(); - tracking_info - .cas_data - .chunks - .push((chunk.hash, (cas_data_len, cas_data_len + n_bytes))); + tracking_info.cas_data.chunks.push((chunk.hash, n_bytes)); tracking_info.cas_data.data.extend(bytes); if tracking_info.cas_data.data.len() > TARGET_CAS_BLOCK_SIZE { @@ -547,7 +540,7 @@ impl Cleaner { // Put an accumulated data into the struct-wide cas block for building a future chunk. let mut cas_data_accumulator = self.global_cas_data.lock().await; - let shift = cas_data_accumulator.data.len() as u32; + let shift = cas_data_accumulator.chunks.len() as u32; cas_data_accumulator .data .append(&mut tracking_info.cas_data.data); @@ -560,7 +553,8 @@ impl Cleaner { .file_info .iter() .map(|fi| { - // If it's in this new cas chunk, shift everything. + // Transfering cas chunks from tracking_info.cas_data to cas_data_accumulator, + // shift chunk indices. 
let s = if fi.cas_hash == MerkleHash::default() { shift } else { @@ -568,8 +562,8 @@ impl Cleaner { }; let mut new_fi = fi.clone(); - new_fi.chunk_byte_range_start += s; - new_fi.chunk_byte_range_end += s; + new_fi.chunk_index_start += s; + new_fi.chunk_index_end += s; new_fi }) diff --git a/data/src/data_processing.rs b/data/src/data_processing.rs index 0e2363ca..f3881faf 100644 --- a/data/src/data_processing.rs +++ b/data/src/data_processing.rs @@ -1,7 +1,6 @@ use crate::cas_interface::create_cas_client; use crate::clean::Cleaner; use crate::configurations::*; -use crate::constants::MAX_CONCURRENT_UPLOADS; use crate::errors::*; use crate::metrics::FILTER_CAS_BYTES_PRODUCED; use crate::remote_shard_interface::RemoteShardInterface; @@ -14,6 +13,7 @@ use mdb_shard::file_structs::MDBFileInfo; use mdb_shard::ShardFileManager; use merkledb::aggregate_hashes::cas_node_hash; use merklehash::MerkleHash; +use std::io::Write; use std::mem::take; use std::ops::DerefMut; use std::path::Path; @@ -22,8 +22,11 @@ use tokio::sync::Mutex; #[derive(Default, Debug)] pub struct CASDataAggregator { + /// Bytes of all chunks accumulated in one CAS block concatenated together. pub data: Vec, - pub chunks: Vec<(MerkleHash, (usize, usize))>, + /// Metadata of all chunks accumulated in one CAS block. Each entry is + /// (chunk hash, chunk size). + pub chunks: Vec<(MerkleHash, usize)>, // The file info of files that are still being processed. // As we're building this up, we assume that all files that do not have a size in the header are // not finished yet and thus cannot be uploaded. @@ -59,10 +62,10 @@ pub struct PointerFileTranslator { global_cas_data: Arc>, } -// Constructors +// Constructorscas_data_accumulator impl PointerFileTranslator { pub async fn new(config: TranslatorConfig) -> Result { - let cas_client = create_cas_client(&config.cas_storage_config, &config.repo_info).await?; + let cas_client = create_cas_client(&config.cas_storage_config, &config.repo_info)?; let shard_manager = Arc::new(create_shard_manager(&config.shard_storage_config).await?); @@ -191,10 +194,7 @@ impl PointerFileTranslator { } async fn upload_cas(&self) -> Result<()> { - self.cas - .upload_all_staged(*MAX_CONCURRENT_UPLOADS, false) - .await?; - + // We don't have staging client support yet. 
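With the CASDataAggregator documented above now tracking chunks as (chunk hash, chunk size) pairs, both the per-chunk start offsets and the cumulative chunk boundaries needed when the block is registered (see register_new_cas_block later in this patch) fall out of one running sum. A minimal sketch of that derivation; the function name is illustrative, and merklehash::MerkleHash is the hash type used throughout.

use merklehash::MerkleHash;

// Illustrative only: expand (chunk hash, chunk length) pairs into
// (hash, start offset) entries plus the cumulative end boundaries.
fn offsets_and_boundaries(
    chunks: &[(MerkleHash, usize)],
) -> (Vec<(MerkleHash, usize)>, Vec<u64>) {
    let mut pos = 0usize;
    let mut entries = Vec::with_capacity(chunks.len());
    let mut boundaries = Vec::with_capacity(chunks.len());
    for (hash, len) in chunks {
        entries.push((*hash, pos));  // this chunk starts at the current offset
        pos += *len;                 // advance by the chunk's length
        boundaries.push(pos as u64); // boundary = running total so far
    }
    (entries, boundaries)
}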
Ok(()) } } @@ -232,23 +232,23 @@ pub(crate) async fn register_new_cas_block( let chunks: Vec<_> = cas_data .chunks .iter() - .map(|(h, (bytes_lb, bytes_ub))| { - let size = bytes_ub - bytes_lb; - let result = CASChunkSequenceEntry::new(*h, size, pos); - pos += size; + .map(|(h, len)| { + let result = CASChunkSequenceEntry::new(*h, *len, pos); + pos += *len; result }) .collect(); - let cas_info = MDBCASInfo { metadata, chunks }; - let mut chunk_boundaries: Vec = Vec::with_capacity(cas_data.chunks.len()); - let mut running_sum = 0; - - for (_, s) in cas_data.chunks.iter() { - running_sum += s.1 - s.0; - chunk_boundaries.push(running_sum as u64); - } + pos = 0; + let chunk_boundaries = cas_data + .chunks + .iter() + .map(|(_, len)| { + pos += *len; + pos as u64 + }) + .collect(); if !cas_info.chunks.is_empty() { shard_manager.add_cas_block(cas_info).await?; @@ -288,7 +288,7 @@ impl PointerFileTranslator { pub async fn smudge_file_from_pointer( &self, pointer: &PointerFile, - writer: &mut impl std::io::Write, + writer: &mut Box, range: Option<(usize, usize)>, ) -> Result<()> { self.smudge_file_from_hash(&pointer.hash()?, writer, range) @@ -298,7 +298,7 @@ impl PointerFileTranslator { pub async fn smudge_file_from_hash( &self, file_id: &MerkleHash, - writer: &mut impl std::io::Write, + writer: &mut Box, _range: Option<(usize, usize)>, ) -> Result<()> { self.cas.get_file(file_id, writer).await?; diff --git a/data/src/remote_shard_interface.rs b/data/src/remote_shard_interface.rs index 1369fd86..bfedcb63 100644 --- a/data/src/remote_shard_interface.rs +++ b/data/src/remote_shard_interface.rs @@ -5,7 +5,6 @@ use crate::cas_interface::Client; use crate::constants::{FILE_RECONSTRUCTION_CACHE_SIZE, MAX_CONCURRENT_UPLOADS}; use crate::repo_salt::RepoSalt; use cas::singleflight; -use file_utils::write_all_safe; use lru::LruCache; use mdb_shard::constants::MDB_SHARD_MIN_TARGET_SIZE; use mdb_shard::session_directory::consolidate_shards_in_directory; @@ -21,7 +20,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::Mutex; use tokio::task::JoinHandle; -use tracing::{debug, error, info}; +use tracing::{debug, info}; pub struct RemoteShardInterface { pub file_query_policy: FileQueryPolicy, @@ -271,50 +270,23 @@ impl RemoteShardInterface { Ok(self.get_dedup_shards(&[*chunk_hash], salt).await?.pop()) } - fn download_and_register_shard_background( - &self, - shard_hash: &MerkleHash, - ) -> Result>> { - let hex_key = shard_hash.hex(); - - let prefix = self.shard_prefix.to_owned(); - - let shard_hash = shard_hash.to_owned(); - let shard_downloads_sf = self.shard_downloads.clone(); + pub async fn register_local_shard(&self, shard_hash: &MerkleHash) -> Result<()> { let shard_manager = self.shard_manager()?; - let cas = self.cas()?; - let cache_dir = self.shard_cache_directory()?; - Ok(tokio::spawn(async move { - if shard_manager.shard_is_registered(&shard_hash).await { - info!("download_and_register_shard: Shard {shard_hash:?} is already registered."); - return Ok(()); - } - - shard_downloads_sf - .work(&hex_key, async move { - // Download the shard in question. - let (shard_file, _) = download_shard(&cas, &prefix, &shard_hash, &cache_dir) - .await - .map_err(|e| DataProcessingError::InternalError(format!("{e:?}")))?; - - shard_manager - .register_shards_by_path(&[shard_file], true) - .await?; + // Shard is expired, we need to evit the previous registration. 
+ if shard_manager.shard_is_registered(&shard_hash).await { + info!("register_local_shard: re-register {shard_hash:?}."); + todo!() + } - Ok(()) - }) - .await - .0?; + let shard_file = cache_dir.join(local_shard_name(&shard_hash)); - Ok(()) - })) - } + shard_manager + .register_shards_by_path(&[shard_file], true) + .await?; - pub async fn download_and_register_shard(&self, shard_hash: &MerkleHash) -> Result<()> { - self.download_and_register_shard_background(shard_hash)? - .await? + Ok(()) } pub fn merge_shards( @@ -403,45 +375,3 @@ fn local_shard_name(hash: &MerkleHash) -> PathBuf { fn is_shard_file(path: &Path) -> bool { path.extension().and_then(OsStr::to_str) == Some("mdb") } - -// Download a shard to local cache if not exists. -// Returns the path to the downloaded file and the number of bytes transferred. -// Returns the path to the existing file and 0 (transferred byte) if exists. -async fn download_shard( - cas: &Arc, - prefix: &str, - shard_hash: &MerkleHash, - dest_dir: &Path, -) -> Result<(PathBuf, usize)> { - let shard_name = local_shard_name(shard_hash); - let dest_file = dest_dir.join(&shard_name); - - if dest_file.exists() { - #[cfg(debug_assertions)] - { - MDBShardFile::load_from_file(&dest_file)?.verify_shard_integrity_debug_only(); - } - debug!( - "download_shard: shard file {shard_name:?} already present in local cache, skipping download." - ); - return Ok((dest_file, 0)); - } else { - debug!( - "download_shard: shard file {shard_name:?} does not exist in local cache, downloading from cas." - ); - } - - let bytes: Vec = match cas.get(prefix, shard_hash).await { - Err(e) => { - error!("Error attempting to download shard {prefix}/{shard_hash:?}: {e:?}"); - Err(e)? - } - Ok(data) => data, - }; - - info!("Downloaded shard {prefix}/{shard_hash:?}."); - - write_all_safe(&dest_file, &bytes)?; - - Ok((dest_file, bytes.len())) -} diff --git a/data/src/shard_interface.rs b/data/src/shard_interface.rs index 56d86089..3c48fcf4 100644 --- a/data/src/shard_interface.rs +++ b/data/src/shard_interface.rs @@ -39,7 +39,14 @@ pub async fn create_shard_client( ) -> Result> { info!("Shard endpoint = {:?}", shard_storage_config.endpoint); let client: Arc = match &shard_storage_config.endpoint { - Server(endpoint) => Arc::new(HttpShardClient::new(endpoint, shard_storage_config.auth.token.clone())), + Server(endpoint) => Arc::new(HttpShardClient::new( + endpoint, + shard_storage_config.auth.token.clone(), + shard_storage_config + .cache_config + .as_ref() + .map(|cache| cache.cache_directory.clone()), + )), FileSystem(path) => Arc::new(LocalShardClient::new(path).await?), }; diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 38563790..ed997220 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -397,6 +397,7 @@ dependencies = [ "clap 2.34.0", "deadpool", "error_printer", + "file_utils", "futures", "http 0.2.12", "http-body-util", @@ -412,7 +413,6 @@ dependencies = [ "opentelemetry-http", "opentelemetry-jaeger", "parutils", - "progress_reporting", "prost", "reqwest 0.12.7", "retry_strategy", @@ -459,6 +459,7 @@ dependencies = [ "merklehash", "serde", "serde_repr", + "xet_error", ] [[package]] @@ -669,31 +670,6 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" -[[package]] -name = "crossterm" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" -dependencies = [ 
- "bitflags 1.3.2", - "crossterm_winapi", - "libc", - "mio 0.8.11", - "parking_lot 0.12.3", - "signal-hook", - "signal-hook-mio", - "winapi", -] - -[[package]] -name = "crossterm_winapi" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" -dependencies = [ - "winapi", -] - [[package]] name = "csv-core" version = "0.1.11" @@ -2039,18 +2015,6 @@ dependencies = [ "adler2", ] -[[package]] -name = "mio" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" -dependencies = [ - "libc", - "log", - "wasi", - "windows-sys 0.48.0", -] - [[package]] name = "mio" version = "1.0.2" @@ -2640,18 +2604,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "progress_reporting" -version = "0.14.5" -dependencies = [ - "atty", - "crossterm", - "more-asserts", - "tokio", - "tracing", - "utils", -] - [[package]] name = "prometheus" version = "0.13.4" @@ -3415,6 +3367,7 @@ dependencies = [ "cas_client", "cas_types", "clap 2.34.0", + "file_utils", "heed", "http 0.2.12", "itertools 0.10.5", @@ -3435,7 +3388,6 @@ dependencies = [ "tracing", "tracing-opentelemetry", "url", - "utils", "uuid", "xet_error", ] @@ -3470,27 +3422,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" -dependencies = [ - "libc", - "signal-hook-registry", -] - -[[package]] -name = "signal-hook-mio" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" -dependencies = [ - "libc", - "mio 0.8.11", - "signal-hook", -] - [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -3948,7 +3879,7 @@ dependencies = [ "backtrace", "bytes", "libc", - "mio 1.0.2", + "mio", "parking_lot 0.12.3", "pin-project-lite", "signal-hook-registry", diff --git a/hf_xet/src/config.rs b/hf_xet/src/config.rs index 2614d683..b8140a3c 100644 --- a/hf_xet/src/config.rs +++ b/hf_xet/src/config.rs @@ -2,7 +2,7 @@ use data::configurations::{ Auth, CacheConfig, DedupConfig, Endpoint, FileQueryPolicy, RepoInfo, StorageConfig, TranslatorConfig, }; -use data::{errors, DEFAULT_BLOCK_SIZE}; +use data::errors; use std::env::current_dir; use std::fs; diff --git a/hf_xet/src/data_client.rs b/hf_xet/src/data_client.rs index e5664d6d..3e35b199 100644 --- a/hf_xet/src/data_client.rs +++ b/hf_xet/src/data_client.rs @@ -1,12 +1,12 @@ +use crate::config::default_config; +use data::errors::DataProcessingError; +use data::{errors, PointerFile, PointerFileTranslator}; +use parutils::{tokio_par_for_each, ParallelError}; use std::fs; use std::fs::File; -use std::io::{BufReader, Read}; +use std::io::{BufReader, Read, Write}; use std::path::PathBuf; use std::sync::Arc; -use data::{errors, PointerFile, PointerFileTranslator}; -use data::errors::DataProcessingError; -use parutils::{ParallelError, tokio_par_for_each}; -use crate::config::default_config; /// The maximum git filter protocol packet size pub const MAX_CONCURRENT_UPLOADS: usize = 8; // TODO @@ -15,28 +15,29 @@ pub const MAX_CONCURRENT_DOWNLOADS: usize = 8; // TODO const DEFAULT_CAS_ENDPOINT: 
&str = "http://localhost:8080"; const READ_BLOCK_SIZE: usize = 1024 * 1024; -pub async fn upload_async(file_paths: Vec, endpoint: Option, token: Option) -> errors::Result> { +pub async fn upload_async( + file_paths: Vec, + endpoint: Option, + token: Option, +) -> errors::Result> { // chunk files // produce Xorbs + Shards // upload shards and xorbs // for each file, return the filehash - let endpoint = endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()); - - let config = default_config(endpoint, token)?; + let config = default_config( + endpoint.clone().unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), + token.clone(), + )?; let processor = Arc::new(PointerFileTranslator::new(config).await?); let processor = &processor; // for all files, clean them, producing pointer files. - let pointers = tokio_par_for_each( - file_paths, - MAX_CONCURRENT_UPLOADS, - |f, _| async { - let proc = processor.clone(); - clean_file(&proc, f).await - }, - ).await.map_err(|e| match e { - ParallelError::JoinError => { - DataProcessingError::InternalError("Join error".to_string()) - } + let pointers = tokio_par_for_each(file_paths, MAX_CONCURRENT_UPLOADS, |f, _| async { + let proc = processor.clone(); + clean_file(&proc, f).await + }) + .await + .map_err(|e| match e { + ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()), ParallelError::TaskError(e) => e, })?; @@ -46,25 +47,30 @@ pub async fn upload_async(file_paths: Vec, endpoint: Option, tok Ok(pointers) } -pub async fn download_async(pointer_files: Vec, endpoint: Option, token: Option) -> errors::Result> { - let config = default_config(endpoint.clone().unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), token.clone())?; +pub async fn download_async( + pointer_files: Vec, + endpoint: Option, + token: Option, +) -> errors::Result> { + let config = default_config( + endpoint.clone().unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), + token.clone(), + )?; let processor = Arc::new(PointerFileTranslator::new(config).await?); let processor = &processor; let paths = tokio_par_for_each( pointer_files, MAX_CONCURRENT_DOWNLOADS, |pointer_file, _| { - let tok = token.clone(); - let end = endpoint.clone(); async move { let proc = processor.clone(); - smudge_file(&proc, &pointer_file, end.clone(), tok.clone()).await + smudge_file(&proc, &pointer_file).await } }, - ).await.map_err(|e| match e { - ParallelError::JoinError => { - DataProcessingError::InternalError("Join error".to_string()) - } + ) + .await + .map_err(|e| match e { + ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()), ParallelError::TaskError(e) => e, })?; @@ -88,25 +94,29 @@ async fn clean_file(processor: &PointerFileTranslator, f: String) -> errors::Res } let pf_str = handle.result().await?; - let pf = PointerFile::init_from_string(&pf_str, path.to_str().unwrap()); + let pf = PointerFile::init_from_string(&pf_str, path.to_str().unwrap()); Ok(pf) } -async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result { +async fn smudge_file( + proc: &PointerFileTranslator, + pointer_file: &PointerFile, +) -> errors::Result { let path = PathBuf::from(pointer_file.path()); if let Some(parent_dir) = path.parent() { fs::create_dir_all(parent_dir)?; } - let mut f = File::create(&path)?; - proc.smudge_file_from_pointer(&pointer_file, &mut f, None).await?; + let mut f: Box = Box::new(File::create(&path)?); + proc.smudge_file_from_pointer(&pointer_file, &mut f, None) + .await?; Ok(pointer_file.path().to_string()) } #[cfg(test)] mod 
tests { + use super::*; use std::env::current_dir; use std::fs::canonicalize; - use super::*; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn upload_files() { @@ -115,20 +125,23 @@ mod tests { let abs_path = canonicalize(path).unwrap(); let s = abs_path.to_string_lossy(); - let files = vec![ - s.to_string(), - ]; - let pointers = upload_async(files).await.unwrap(); + let files = vec![s.to_string()]; + let pointers = upload_async(files, "http://localhost:8080", "12345") + .await + .unwrap(); println!("files: {pointers:?}"); } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn download_files() { - let pointers = vec![ - PointerFile::init_from_info("/tmp/foo.rs", "6999733a46030e67f6f020651c91442ace735572458573df599106e54646867c", 4203), - ]; - let paths = download_async(pointers, "http://localhost:8080", "12345").await.unwrap(); + let pointers = vec![PointerFile::init_from_info( + "/tmp/foo.rs", + "6999733a46030e67f6f020651c91442ace735572458573df599106e54646867c", + 4203, + )]; + let paths = download_async(pointers, "http://localhost:8080", "12345") + .await + .unwrap(); println!("paths: {paths:?}"); } } - diff --git a/mdb_shard/src/file_structs.rs b/mdb_shard/src/file_structs.rs index 091f466e..63818773 100644 --- a/mdb_shard/src/file_structs.rs +++ b/mdb_shard/src/file_structs.rs @@ -77,15 +77,14 @@ pub struct FileDataSequenceEntry { } impl FileDataSequenceEntry { - pub fn new, I2: TryInto>( + pub fn new>( cas_hash: MerkleHash, unpacked_segment_bytes: I1, - chunk_index_start: I2, - chunk_index_end: I2, + chunk_index_start: I1, + chunk_index_end: I1, ) -> Self where >::Error: std::fmt::Debug, - >::Error: std::fmt::Debug, { Self { cas_hash, diff --git a/mdb_shard/src/shard_format.rs b/mdb_shard/src/shard_format.rs index a45b28de..24c27032 100644 --- a/mdb_shard/src/shard_format.rs +++ b/mdb_shard/src/shard_format.rs @@ -1,20 +1,18 @@ +use crate::cas_structs::*; use crate::constants::*; use crate::error::{MDBShardError, Result}; +use crate::file_structs::*; use crate::serialization_utils::*; +use crate::shard_in_memory::MDBInMemoryShard; +use crate::shard_version; +use crate::utils::truncate_hash; use merklehash::MerkleHash; - use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::io::{Read, Seek, SeekFrom, Write}; use std::mem::size_of; use std::sync::Arc; -use tracing::{debug, error}; - -use crate::cas_structs::*; -use crate::file_structs::*; -use crate::shard_in_memory::MDBInMemoryShard; -use crate::shard_version; -use crate::utils::truncate_hash; +use tracing::debug; // Same size for FileDataSequenceHeader and FileDataSequenceEntry const MDB_FILE_INFO_ENTRY_SIZE: u64 = (size_of::<[u64; 4]>() + 4 * size_of::()) as u64; diff --git a/merkledb/src/aggregate_hashes.rs b/merkledb/src/aggregate_hashes.rs index 38bfb5c8..f037eb47 100644 --- a/merkledb/src/aggregate_hashes.rs +++ b/merkledb/src/aggregate_hashes.rs @@ -8,7 +8,7 @@ use crate::MerkleNode; use crate::{merkledbbase::MerkleDBBase, MerkleMemDB}; // Given a list of hashes and sizes, compute the aggregate hash for a cas node. -pub fn cas_node_hash(chunks: &[(MerkleHash, (usize, usize))]) -> MerkleHash { +pub fn cas_node_hash(chunks: &[(MerkleHash, usize)]) -> MerkleHash { // Create an ephemeral MDB. 
if chunks.is_empty() { return MerkleHash::default(); @@ -18,7 +18,7 @@ pub fn cas_node_hash(chunks: &[(MerkleHash, (usize, usize))]) -> MerkleHash { let nodes: Vec = chunks .iter() - .map(|(h, (lb, ub))| mdb.maybe_add_node(h, ub - lb, Vec::default()).0) + .map(|(h, len)| mdb.maybe_add_node(h, *len, Vec::default()).0) .collect(); let m = mdb.merge_to_cas(&nodes[..]); diff --git a/shard_client/Cargo.toml b/shard_client/Cargo.toml index f258dcb6..af48d419 100644 --- a/shard_client/Cargo.toml +++ b/shard_client/Cargo.toml @@ -9,7 +9,7 @@ strict = [] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -utils = {path = "../utils"} +file_utils = {path = "../file_utils" } merklehash = {path = "../merklehash"} retry_strategy = {path = "../retry_strategy"} cas_client = {path = "../cas_client"} diff --git a/shard_client/src/error.rs b/shard_client/src/error.rs index b2ca7f24..4cdc9ab8 100644 --- a/shard_client/src/error.rs +++ b/shard_client/src/error.rs @@ -3,6 +3,9 @@ use xet_error::Error; #[non_exhaustive] #[derive(Error, Debug)] pub enum ShardClientError { + #[error("Invalid config: {0}")] + InvalidConfig(String), + #[error("File I/O error")] IOError(#[from] std::io::Error), @@ -23,6 +26,9 @@ pub enum ShardClientError { #[error("Bad endpoint: {0}")] UrlError(#[from] url::ParseError), + + #[error("Invalid Shard Key: {0}")] + InvalidShardKey(String), } // Define our own result type here (this seems to be the standard). diff --git a/shard_client/src/http_shard_client.rs b/shard_client/src/http_shard_client.rs index f1bfe5bb..b4af6acb 100644 --- a/shard_client/src/http_shard_client.rs +++ b/shard_client/src/http_shard_client.rs @@ -1,18 +1,24 @@ use crate::error::{Result, ShardClientError}; use crate::{RegistrationClient, ShardClientInterface}; - use async_trait::async_trait; use bytes::Buf; use cas_types::Key; -use cas_types::{ - QueryChunkResponse, QueryReconstructionResponse, UploadShardResponse, UploadShardResponseType, -}; +use cas_types::{QueryReconstructionResponse, UploadShardResponse, UploadShardResponseType}; +use file_utils::write_all_safe; use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo}; +use mdb_shard::serialization_utils::read_u32; use mdb_shard::shard_dedup_probe::ShardDedupProber; use mdb_shard::shard_file_reconstructor::FileReconstructor; use merklehash::MerkleHash; -use reqwest::{Url, header::{HeaderMap, HeaderValue}}; +use reqwest::{ + header::{HeaderMap, HeaderValue}, + Url, +}; use retry_strategy::RetryStrategy; +use std::io::Read; +use std::path::PathBuf; +use std::str::FromStr; +use tokio::task::JoinSet; use tracing::warn; const NUM_RETRIES: usize = 5; @@ -21,20 +27,26 @@ const BASE_RETRY_DELAY_MS: u64 = 3000; /// Shard Client that uses HTTP for communication. 
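The client below now carries an optional shard cache directory in addition to the endpoint and token; global-dedup responses are written there as .mdb files. A usage sketch mirroring the constructor added in this patch and the test at the bottom of this file; the import path and endpoint value are assumptions for illustration.

// Illustrative only; the import path assumes http_shard_client is public in
// the shard_client crate (the test below imports it via `super`).
use shard_client::http_shard_client::HttpShardClient;
use std::path::PathBuf;

fn build_shard_client() -> HttpShardClient {
    HttpShardClient::new(
        "http://localhost:8080",                 // endpoint (placeholder value)
        None,                                    // no auth token
        Some(PathBuf::from("/tmp/shard-cache")), // cache dir for fetched shards
    )
}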
#[derive(Debug)] pub struct HttpShardClient { - pub endpoint: String, - pub token: Option, + endpoint: String, + token: Option, client: reqwest::Client, retry_strategy: RetryStrategy, + cache_directory: Option, } impl HttpShardClient { - pub fn new(endpoint: &str, token: Option) -> Self { + pub fn new( + endpoint: &str, + token: Option, + shard_cache_directory: Option, + ) -> Self { HttpShardClient { endpoint: endpoint.into(), token, client: reqwest::Client::builder().build().unwrap(), // Retry policy: Exponential backoff starting at BASE_RETRY_DELAY_MS and retrying NUM_RETRIES times retry_strategy: RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS), + cache_directory: shard_cache_directory.clone(), } } } @@ -71,10 +83,13 @@ impl RegistrationClient for HttpShardClient { }; let url = Url::parse(&format!("{}/shard/{key}", self.endpoint))?; - + let mut headers = HeaderMap::new(); if let Some(tok) = &self.token { - headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap()); + headers.insert( + "Authorization", + HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(), + ); } let response = self @@ -85,8 +100,22 @@ impl RegistrationClient for HttpShardClient { async { let url = url.clone(); match force_sync { - true => self.client.put(url).headers(headers).body(shard_data.to_vec()).send().await, - false => self.client.post(url).headers(headers).body(shard_data.to_vec()).send().await, + true => { + self.client + .put(url) + .headers(headers) + .body(shard_data.to_vec()) + .send() + .await + } + false => { + self.client + .post(url) + .headers(headers) + .body(shard_data.to_vec()) + .send() + .await + } } } }, @@ -165,6 +194,11 @@ impl ShardDedupProber for HttpShardClient { _salt: &[u8; 32], ) -> Result> { debug_assert!(chunk_hash.len() == 1); + let Some(shard_cache_dir) = &self.cache_directory else { + return Err(ShardClientError::InvalidConfig( + "cache directory not configured for shard storage".into(), + )); + }; // The API endpoint now only supports non-batched dedup request and // ignores salt. 
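The hunk below replaces the single QueryChunkResponse with a streamed payload: for each shard, a 4-byte little-endian key length and key, then a 4-byte little-endian content length and the shard bytes, repeated until the stream ends. A condensed sketch of just that framing, assuming read_u32 behaves as the little-endian helper imported from mdb_shard::serialization_utils and returns std::io::Result<u32>; error handling is simplified.

use mdb_shard::serialization_utils::read_u32;
use std::io::Read;

// Illustrative only: decode repeated [key_len][key][shard_len][shard] records.
fn read_shard_records<R: Read>(reader: &mut R) -> std::io::Result<Vec<(String, Vec<u8>)>> {
    let mut records = Vec::new();
    loop {
        // Treat a failed length read as end of stream, as the patch does.
        let Ok(key_len) = read_u32(reader) else {
            break;
        };
        let mut key = vec![0u8; key_len as usize];
        reader.read_exact(&mut key)?;
        let key = String::from_utf8(key)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

        let shard_len = read_u32(reader)?;
        let mut shard = vec![0u8; shard_len as usize];
        reader.read_exact(&mut shard)?;

        records.push((key, shard));
    }
    Ok(records)
}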
@@ -187,17 +221,66 @@ impl ShardDedupProber for HttpShardClient { .map_err(|e| ShardClientError::Other(format!("request failed with code {e}")))?; let response_body = response.bytes().await?; - let response_info: QueryChunkResponse = serde_json::from_reader(response_body.reader())?; + let mut reader = response_body.reader(); + + let mut downloaded_shards = vec![]; + + let mut write_join_set = JoinSet::>::new(); + + // Return in the format of + // [ + // { + // key_length: usize, // 4 bytes little endian + // key: [u8; key_length] + // }, + // { + // shard_content_length: usize, // 4 bytes little endian + // shard_content: [u8; shard_content_length] + // } + // ] // Repeat for each shard + loop { + let Ok(key_length) = read_u32(&mut reader) else { + break; + }; + let mut shard_key = vec![0u8; key_length as usize]; + reader.read_exact(&mut shard_key)?; + + let shard_key = String::from_utf8(shard_key) + .map_err(|e| ShardClientError::InvalidShardKey(format!("{e:?}")))?; + let shard_key = Key::from_str(&shard_key) + .map_err(|e| ShardClientError::InvalidShardKey(format!("{e:?}")))?; + downloaded_shards.push(shard_key.hash); + + let shard_content_length = read_u32(&mut reader)?; + let mut shard_content = vec![0u8; shard_content_length as usize]; + reader.read_exact(&mut shard_content)?; + + let file_name = local_shard_name(&shard_key.hash); + let file_path = shard_cache_dir.join(file_name); + write_join_set.spawn(async move { + write_all_safe(&file_path, &shard_content)?; + Ok(()) + }); + } + + while let Some(res) = write_join_set.join_next().await { + res.map_err(|e| ShardClientError::Other(format!("Internal task error: {e:?}")))??; + } - Ok(vec![response_info.shard]) + Ok(downloaded_shards) } } +/// Construct a file name for a MDBShard stored under cache and session dir. 
+fn local_shard_name(hash: &MerkleHash) -> PathBuf { + PathBuf::from(hash.to_string()).with_extension("mdb") +} + impl ShardClientInterface for HttpShardClient {} #[cfg(test)] mod test { - use std::path::PathBuf; + use std::{env, path::PathBuf}; use super::HttpShardClient; use crate::RegistrationClient; @@ -210,7 +293,7 @@ mod test { #[tokio::test] #[ignore = "need a local cas_server running"] async fn test_local() -> anyhow::Result<()> { - let client = HttpShardClient::new("http://localhost:8080", None); + let client = HttpShardClient::new("http://localhost:8080", None, Some(env::current_dir()?)); let path = PathBuf::from("./a7de567477348b23d23b667dba4d63d533c2ba7337cdc4297970bb494ba4699e.mdb"); From 71348d2489c43662253e8cc0984a4356166bb7c9 Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 14:23:36 -0700 Subject: [PATCH 10/19] fix bugs and clean up --- Cargo.lock | 2 + cache/Cargo.toml | 1 + cache/src/error.rs | 2 +- cache/src/lib.rs | 6 +- cache/src/xorb_cache.rs | 4 +- cas_client/src/auth.rs | 2 +- cas_client/src/caching_client.rs | 2 +- cas_client/src/interface.rs | 8 +- cas_client/src/local_client.rs | 182 +++----- cas_client/src/remote_client.rs | 70 +-- cas_object/Cargo.toml | 10 +- cas_object/src/cas_object_format.rs | 439 ++++++++---------- cas_object/src/lib.rs | 1 + data/src/cas_interface.rs | 2 +- data/src/configurations.rs | 2 +- data/src/data_processing.rs | 4 +- data/src/errors.rs | 2 +- data/src/remote_shard_interface.rs | 2 +- hf_xet/Cargo.lock | 3 + hf_xet/src/config.rs | 2 +- hf_xet/src/data_client.rs | 4 +- hf_xet/src/lib.rs | 2 +- hf_xet/src/token_refresh.rs | 4 +- mdb_shard/Cargo.toml | 5 +- mdb_shard/src/cas_structs.rs | 2 +- mdb_shard/src/file_structs.rs | 2 +- ...ation_utils.rs => interpolation_search.rs} | 63 +-- mdb_shard/src/lib.rs | 2 +- mdb_shard/src/set_operations.rs | 2 +- mdb_shard/src/shard_format.rs | 3 +- progress_reporting/src/data_progress.rs | 2 +- shard_client/src/http_shard_client.rs | 4 +- utils/Cargo.toml | 2 +- utils/build.rs | 2 - utils/examples/infra.rs | 6 +- utils/proto/cas.proto | 81 ---- utils/proto/shard.proto | 69 --- utils/src/key.rs | 122 ----- utils/src/lib.rs | 68 +-- utils/src/safeio.rs | 59 --- utils/src/serialization_utils.rs | 60 +++ utils/src/singleflight.rs | 2 +- 42 files changed, 388 insertions(+), 924 deletions(-) rename mdb_shard/src/{serialization_utils.rs => interpolation_search.rs} (86%) delete mode 100644 utils/proto/cas.proto delete mode 100644 utils/proto/shard.proto delete mode 100644 utils/src/key.rs delete mode 100644 utils/src/safeio.rs create mode 100644 utils/src/serialization_utils.rs diff --git a/Cargo.lock b/Cargo.lock index 42ae5f39..3e2ad790 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -400,6 +400,7 @@ dependencies = [ "async-trait", "base64 0.13.1", "byteorder", + "cas_types", "chrono", "lazy_static", "lru", @@ -2104,6 +2105,7 @@ dependencies = [ "tempfile", "tokio", "tracing", + "utils", "uuid", "xet_error", ] diff --git a/cache/Cargo.toml b/cache/Cargo.toml index 9bed6179..8f14a649 100644 --- a/cache/Cargo.toml +++ b/cache/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +cas_types = { path = "../cas_types" } utils = { path = "../utils" } tokio = { version = "1.36", features = ["full"] } async-trait = "0.1.9" diff --git a/cache/src/error.rs b/cache/src/error.rs index 012a0808..8108ff4e 100644 --- a/cache/src/error.rs +++ b/cache/src/error.rs @@ -1,6 +1,6 @@ use 
crate::CacheError::OtherTaskError; -use cas::errors::SingleflightError; use tracing::error; +use utils::errors::SingleflightError; use xet_error::Error; #[non_exhaustive] diff --git a/cache/src/lib.rs b/cache/src/lib.rs index aec5a4bf..efad5ccb 100644 --- a/cache/src/lib.rs +++ b/cache/src/lib.rs @@ -5,12 +5,12 @@ use std::{fmt::Debug, sync::Arc}; use crate::error::Result; pub use block::BlockConverter; -use cas::key::Key; -use cas::singleflight; +use cas_types::Key; pub use disk::DiskCache; pub use error::CacheError; pub use interface::{BlockReadRequest, BlockReader, FileMetadata}; pub use metrics::set_metrics_service_name; +use utils::singleflight; pub use xorb_cache::XorbCacheImpl; mod block; @@ -29,7 +29,7 @@ mod xorb_cache; pub trait Remote: Debug + Sync + Send { async fn fetch( &self, - key: &cas::key::Key, + key: &Key, range: Range, ) -> std::result::Result, anyhow::Error>; } diff --git a/cache/src/xorb_cache.rs b/cache/src/xorb_cache.rs index 24777883..78e18aa3 100644 --- a/cache/src/xorb_cache.rs +++ b/cache/src/xorb_cache.rs @@ -6,8 +6,8 @@ use std::time::SystemTime; use tracing::{debug, info, info_span, warn}; use tracing_futures::Instrument; -use cas::key::Key; -use cas::singleflight; +use cas_types::Key; +use utils::singleflight; use crate::metrics::{ BLOCKS_READ, DATA_READ, READ_ERROR_COUNT, REQUEST_LATENCY_MS, REQUEST_THROUGHPUT, diff --git a/cas_client/src/auth.rs b/cas_client/src/auth.rs index 4aabf161..78e81130 100644 --- a/cas_client/src/auth.rs +++ b/cas_client/src/auth.rs @@ -1,10 +1,10 @@ use anyhow::anyhow; -use cas::auth::{AuthConfig, TokenProvider}; use reqwest::header::HeaderValue; use reqwest::header::AUTHORIZATION; use reqwest::{Request, Response}; use reqwest_middleware::{Middleware, Next}; use std::sync::{Arc, Mutex}; +use utils::auth::{AuthConfig, TokenProvider}; /// AuthMiddleware is a thread-safe middleware that adds a CAS auth token to outbound requests. /// If the token it holds is expired, it will automatically be refreshed. diff --git a/cas_client/src/caching_client.rs b/cas_client/src/caching_client.rs index d61eba01..d4519ff7 100644 --- a/cas_client/src/caching_client.rs +++ b/cas_client/src/caching_client.rs @@ -16,7 +16,7 @@ impl UploadClient for CachingClient { prefix: &str, hash: &MerkleHash, data: Vec, - chunk_boundaries: Vec, + chunk_boundaries: Vec<(MerkleHash, u32)>, ) -> Result<()> { todo!() } diff --git a/cas_client/src/interface.rs b/cas_client/src/interface.rs index d1cf1312..8756dfbf 100644 --- a/cas_client/src/interface.rs +++ b/cas_client/src/interface.rs @@ -11,7 +11,7 @@ use std::{io::Write, sync::Arc}; #[async_trait] pub trait UploadClient { /// Insert the provided data into the CAS as a XORB indicated by the prefix and hash. - /// The hash will be verified on the server-side according to the chunk boundaries. + /// The hash will be verified on the SERVER-side according to the chunk boundaries. /// Chunk Boundaries must be complete; i.e. the last entry in chunk boundary /// must be the length of data. For instance, if data="helloworld" with 2 chunks /// ["hello" "world"], chunk_boundaries should be [5, 10]. @@ -24,7 +24,7 @@ pub trait UploadClient { prefix: &str, hash: &MerkleHash, data: Vec, - chunk_boundaries: Vec, + chunk_and_boundaries: Vec<(MerkleHash, u32)>, ) -> Result<()>; /// Check if a XORB already exists. 
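For the put signature above, chunk_and_boundaries pairs each chunk's hash with its cumulative end offset, so the final offset must equal data.len(). A small sketch building that argument for the "helloworld" example in the doc comment, assuming merklehash::compute_data_hash as used throughout this patch; the helper and the commented-out call are illustrative only.

use merklehash::{compute_data_hash, MerkleHash};

// Illustrative only: "helloworld" split into two chunks, "hello" and "world",
// giving cumulative end offsets 5 and 10 (10 == data.len()).
fn helloworld_chunk_and_boundaries() -> (Vec<u8>, Vec<(MerkleHash, u32)>) {
    let data = b"helloworld".to_vec();
    let hello_hash = compute_data_hash(b"hello");
    let world_hash = compute_data_hash(b"world");
    (data, vec![(hello_hash, 5), (world_hash, 10)])
}

// A call would then look like (for any UploadClient implementation):
//     let (data, chunk_and_boundaries) = helloworld_chunk_and_boundaries();
//     client.put("key", &xorb_hash, data, chunk_and_boundaries).await?;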
@@ -63,9 +63,9 @@ impl UploadClient for Arc { prefix: &str, hash: &MerkleHash, data: Vec, - chunk_boundaries: Vec, + chunk_and_boundaries: Vec<(MerkleHash, u32)>, ) -> Result<()> { - (**self).put(prefix, hash, data, chunk_boundaries).await + (**self).put(prefix, hash, data, chunk_and_boundaries).await } async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { diff --git a/cas_client/src/local_client.rs b/cas_client/src/local_client.rs index 9bb97e80..f52f3b32 100644 --- a/cas_client/src/local_client.rs +++ b/cas_client/src/local_client.rs @@ -2,14 +2,14 @@ use crate::error::{CasClientError, Result}; use crate::interface::UploadClient; use anyhow::anyhow; use async_trait::async_trait; -use cas::key::Key; use cas_object::CasObject; +use cas_types::Key; use merklehash::MerkleHash; use std::fs::{metadata, File}; use std::io::{BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; use tempfile::TempDir; -use tracing::{debug, error, info}; +use tracing::{debug, info}; #[derive(Debug)] pub struct LocalClient { @@ -85,29 +85,6 @@ impl LocalClient { Ok(ret) } - /// A more complete get() which returns both the chunk boundaries as well - /// as the raw data - pub async fn get_detailed( - &self, - prefix: &str, - hash: &MerkleHash, - ) -> Result<(Vec, Vec)> { - let file_path = self.get_path_for_entry(prefix, hash); - - let file = File::open(&file_path).map_err(|_| { - if !self.silence_errors { - error!("Unable to find file in local CAS {:?}", file_path); - } - CasClientError::XORBNotFound(*hash) - })?; - - let mut reader = BufReader::new(file); - let cas = CasObject::deserialize(&mut reader)?; - let (boundaries, data) = cas.get_detailed_bytes(&mut reader)?; - - Ok((boundaries.into_iter().map(|x| x as u64).collect(), data)) - } - /// Deletes an entry pub fn delete(&self, prefix: &str, hash: &MerkleHash) { let file_path = self.get_path_for_entry(prefix, hash); @@ -134,17 +111,15 @@ impl UploadClient for LocalClient { prefix: &str, hash: &MerkleHash, data: Vec, - chunk_boundaries: Vec, + chunk_and_boundaries: Vec<(MerkleHash, u32)>, ) -> Result<()> { // no empty writes - if chunk_boundaries.is_empty() || data.is_empty() { + if chunk_and_boundaries.is_empty() || data.is_empty() { return Err(CasClientError::InvalidArguments); } // last boundary must be end of data - if !chunk_boundaries.is_empty() - && chunk_boundaries[chunk_boundaries.len() - 1] as usize != data.len() - { + if chunk_and_boundaries.last().unwrap().1 as usize != data.len() { return Err(CasClientError::InvalidArguments); } @@ -177,7 +152,7 @@ impl UploadClient for LocalClient { &mut writer, hash, &data, - &chunk_boundaries.into_iter().map(|x| x as u32).collect(), + &chunk_and_boundaries, cas_object::CompressionScheme::None, )?; // flush before persisting @@ -205,7 +180,11 @@ impl UploadClient for LocalClient { let res = metadata(&file_path); - if res.is_err() || !res.unwrap().is_file() { + if res.is_err() { + return Ok(false); + } + + if !res.unwrap().is_file() { return Err(CasClientError::InternalError(anyhow!( "Attempting to write to {:?}, but it is not a file", file_path @@ -242,9 +221,9 @@ mod tests_utils { &self, prefix: &str, hash: &MerkleHash, - ranges: Vec<(u64, u64)>, + ranges: Vec<(u32, u32)>, ) -> Result>>; - fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result; + fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result; } impl TestUtils for LocalClient { @@ -267,7 +246,7 @@ mod tests_utils { &self, prefix: &str, hash: &MerkleHash, - ranges: Vec<(u64, u64)>, + ranges: Vec<(u32, u32)>, ) -> 
Result>> { // Handle the case where we aren't asked for any real data. if ranges.len() == 1 && ranges[0].0 == ranges[0].1 { @@ -286,21 +265,22 @@ mod tests_utils { let cas = CasObject::deserialize(&mut reader)?; let mut ret: Vec> = Vec::new(); + let all_uncompressed_bytes = cas.get_all_bytes(&mut reader)?; for r in ranges { - let data = cas.get_range(&mut reader, r.0 as u32, r.1 as u32)?; + let data = all_uncompressed_bytes[r.0 as usize..r.1 as usize].to_vec(); ret.push(data); } Ok(ret) } - fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { + fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { let file_path = self.get_path_for_entry(prefix, hash); match File::open(file_path) { Ok(file) => { let mut reader = BufReader::new(file); let cas = CasObject::deserialize(&mut reader)?; - let length = cas.get_contents_length()?; - Ok(length as u64) + let length = cas.get_all_bytes(&mut reader)?.len(); + Ok(length as u32) } Err(_) => Err(CasClientError::XORBNotFound(*hash)), } @@ -312,9 +292,9 @@ mod tests_utils { mod tests { use super::tests_utils::TestUtils; use super::*; - use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; - use merklehash::{compute_data_hash, DataHash}; - use rand::Rng; + use cas_object::test_utils::*; + use cas_object::CompressionScheme::LZ4; + use merklehash::compute_data_hash; #[tokio::test] async fn test_basic_put_get() { @@ -322,13 +302,13 @@ mod tests { let client = LocalClient::default(); let data = gen_random_bytes(2048); let hash = compute_data_hash(&data[..]); - let chunk_boundaries = vec![data.len() as u64]; + let chunk_boundaries = data.len() as u32; let data_again = data.clone(); // Act & Assert assert!(client - .put("key", &hash, data, chunk_boundaries) + .put("key", &hash, data, vec![(hash, chunk_boundaries)]) .await .is_ok()); @@ -340,13 +320,17 @@ mod tests { async fn test_basic_put_get_random_medium() { // Arrange let client = LocalClient::default(); - let (hash, data, chunk_boundaries) = gen_dummy_xorb(44, 15633, true); + let (c, _, data, chunk_boundaries) = + build_cas_object(44, ChunkSize::Random(512, 15633), LZ4); let data_again = data.clone(); // Act & Assert - assert!(client.put("", &hash, data, chunk_boundaries).await.is_ok()); + assert!(client + .put("", &c.info.cashash, data, chunk_boundaries) + .await + .is_ok()); - let returned_data = client.get("", &hash).unwrap(); + let returned_data = client.get("", &c.info.cashash).unwrap(); assert_eq!(data_again, returned_data); } @@ -354,15 +338,20 @@ mod tests { async fn test_basic_put_get_range_random_small() { // Arrange let client = LocalClient::default(); - let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 2048, true); + let (c, _, data, chunk_boundaries) = build_cas_object(3, ChunkSize::Random(512, 2048), LZ4); let data_again = data.clone(); // Act & Assert - assert!(client.put("", &hash, data, chunk_boundaries).await.is_ok()); + assert!(client + .put("", &c.info.cashash, data, chunk_boundaries) + .await + .is_ok()); - let ranges: Vec<(u64, u64)> = vec![(0, 100), (100, 1500)]; + let ranges: Vec<(u32, u32)> = vec![(0, 100), (100, 1500)]; let ranges_again = ranges.clone(); - let returned_ranges = client.get_object_range("", &hash, ranges).unwrap(); + let returned_ranges = client + .get_object_range("", &c.info.cashash, ranges) + .unwrap(); for idx in 0..returned_ranges.len() { assert_eq!( @@ -376,12 +365,15 @@ mod tests { async fn test_basic_length() { // Arrange let client = LocalClient::default(); - let (hash, data, chunk_boundaries) = 
gen_dummy_xorb(1, 2048, false); + let (c, _, data, chunk_boundaries) = build_cas_object(1, ChunkSize::Fixed(2048), LZ4); let gen_length = data.len(); // Act - client.put("", &hash, data, chunk_boundaries).await.unwrap(); - let len = client.get_length("", &hash).unwrap(); + client + .put("", &c.info.cashash, data, chunk_boundaries) + .await + .unwrap(); + let len = client.get_length("", &c.info.cashash).unwrap(); // Assert assert_eq!(len as usize, gen_length); @@ -391,7 +383,10 @@ mod tests { async fn test_missing_xorb() { // Arrange let client = LocalClient::default(); - let (hash, _, _) = gen_dummy_xorb(16, 2048, true); + let hash = MerkleHash::from_hex( + "d760aaf4beb07581956e24c847c47f1abd2e419166aa68259035bc412232e9da", + ) + .unwrap(); // Act & Assert let result = client.get("", &hash); @@ -406,13 +401,23 @@ mod tests { let hello_hash = merklehash::compute_data_hash(&hello[..]); // write "hello world" client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) + .put( + "key", + &hello_hash, + hello.clone(), + vec![(hello_hash, hello.len() as u32)], + ) .await .unwrap(); // put the same value a second time. This should be ok. client - .put("key", &hello_hash, hello.clone(), vec![hello.len() as u64]) + .put( + "key", + &hello_hash, + hello.clone(), + vec![(hello_hash, hello.len() as u32)], + ) .await .unwrap(); @@ -427,20 +432,6 @@ mod tests { }] ); - // put the different value with the same hash - // this should fail - assert_eq!( - CasClientError::CasObjectError(cas_object::error::CasObjectError::HashMismatch), - client - .put( - "hellp", - &hello_hash, - "hellp world".as_bytes().to_vec(), - vec![hello.len() as u64], - ) - .await - .unwrap_err() - ); // content shorter than the chunk boundaries should fail assert_eq!( CasClientError::InvalidArguments, @@ -449,7 +440,7 @@ mod tests { "hellp2", &hello_hash, "hellp wod".as_bytes().to_vec(), - vec![hello.len() as u64], + vec![(hello_hash, hello.len() as u32)], ) .await .unwrap_err() @@ -463,7 +454,7 @@ mod tests { "again", &hello_hash, "hello world again".as_bytes().to_vec(), - vec![hello.len() as u64], + vec![(hello_hash, hello.len() as u32)], ) .await .unwrap_err() @@ -534,52 +525,9 @@ mod tests { "key", &final_hash, "helloworld".as_bytes().to_vec(), - vec![5, 10], + vec![(hello_hash, 5), (world_hash, 10)], ) .await .unwrap(); } - - fn gen_dummy_xorb( - num_chunks: u32, - uncompressed_chunk_size: u32, - randomize_chunk_sizes: bool, - ) -> (DataHash, Vec, Vec) { - let mut contents = Vec::new(); - let mut chunks: Vec = Vec::new(); - let mut chunk_boundaries = Vec::with_capacity(num_chunks as usize); - for _idx in 0..num_chunks { - let chunk_size: u32 = if randomize_chunk_sizes { - let mut rng = rand::thread_rng(); - rng.gen_range(1024..=uncompressed_chunk_size) - } else { - uncompressed_chunk_size - }; - - let bytes = gen_random_bytes(chunk_size); - - chunks.push(Chunk { - hash: merklehash::compute_data_hash(&bytes), - length: bytes.len(), - }); - - contents.extend(bytes); - chunk_boundaries.push(contents.len() as u64); - } - - let mut db = MerkleMemDB::default(); - let mut staging = db.start_insertion_staging(); - db.add_file(&mut staging, &chunks); - let ret = db.finalize(staging); - let hash = *ret.hash(); - - (hash, contents, chunk_boundaries) - } - - fn gen_random_bytes(uncompressed_chunk_size: u32) -> Vec { - let mut rng = rand::thread_rng(); - let mut data = vec![0u8; uncompressed_chunk_size as usize]; - rng.fill(&mut data[..]); - data - } } diff --git a/cas_client/src/remote_client.rs 
b/cas_client/src/remote_client.rs index 24be8ce7..3026b6c5 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -5,7 +5,6 @@ use anyhow::anyhow; use async_trait::async_trait; use bytes::Buf; use bytes::Bytes; -use cas::auth::AuthConfig; use cas_object::CasObject; use cas_types::{CASReconstructionTerm, Key, QueryReconstructionResponse, UploadXorbResponse}; use error_printer::OptionPrinter; @@ -14,6 +13,7 @@ use reqwest::{StatusCode, Url}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware}; use std::io::{Cursor, Write}; use tracing::{debug, warn}; +use utils::auth::AuthConfig; pub const CAS_ENDPOINT: &str = "http://localhost:8080"; pub const PREFIX_DEFAULT: &str = "default"; @@ -32,14 +32,14 @@ impl UploadClient for RemoteClient { prefix: &str, hash: &MerkleHash, data: Vec, - chunk_boundaries: Vec, + chunk_and_boundaries: Vec<(MerkleHash, u32)>, ) -> Result<()> { let key = Key { prefix: prefix.to_string(), hash: *hash, }; - let was_uploaded = self.upload(&key, &data, chunk_boundaries).await?; + let was_uploaded = self.upload(&key, &data, chunk_and_boundaries).await?; if !was_uploaded { debug!("{key:?} not inserted into CAS."); @@ -110,7 +110,7 @@ impl RemoteClient { &self, key: &Key, contents: &[u8], - chunk_boundaries: Vec, + chunk_and_boundaries: Vec<(MerkleHash, u32)>, ) -> Result { let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?; @@ -120,7 +120,7 @@ impl RemoteClient { &mut writer, &key.hash, contents, - &chunk_boundaries.into_iter().map(|x| x as u32).collect(), + &chunk_and_boundaries, cas_object::CompressionScheme::LZ4, )?; @@ -245,12 +245,9 @@ impl OptionalMiddleware for ClientBuilder { #[cfg(test)] mod tests { - use rand::Rng; - use tracing_test::traced_test; - use super::*; - use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; - use merklehash::DataHash; + use cas_object::test_utils::*; + use tracing_test::traced_test; #[ignore] #[traced_test] @@ -259,55 +256,18 @@ mod tests { // Arrange let rc = RemoteClient::new(CAS_ENDPOINT, &None); let prefix = PREFIX_DEFAULT; - let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true); + let (c, _, data, chunk_boundaries) = build_cas_object( + 3, + ChunkSize::Random(512, 10248), + cas_object::CompressionScheme::LZ4, + ); // Act - let result = rc.put(prefix, &hash, data, chunk_boundaries).await; + let result = rc + .put(prefix, &c.info.cashash, data, chunk_boundaries) + .await; // Assert assert!(result.is_ok()); } - - fn gen_dummy_xorb( - num_chunks: u32, - uncompressed_chunk_size: u32, - randomize_chunk_sizes: bool, - ) -> (DataHash, Vec, Vec) { - let mut contents = Vec::new(); - let mut chunks: Vec = Vec::new(); - let mut chunk_boundaries = Vec::with_capacity(num_chunks as usize); - for _idx in 0..num_chunks { - let chunk_size: u32 = if randomize_chunk_sizes { - let mut rng = rand::thread_rng(); - rng.gen_range(1024..=uncompressed_chunk_size) - } else { - uncompressed_chunk_size - }; - - let bytes = gen_random_bytes(chunk_size); - - chunks.push(Chunk { - hash: merklehash::compute_data_hash(&bytes), - length: bytes.len(), - }); - - contents.extend(bytes); - chunk_boundaries.push(contents.len() as u64); - } - - let mut db = MerkleMemDB::default(); - let mut staging = db.start_insertion_staging(); - db.add_file(&mut staging, &chunks); - let ret = db.finalize(staging); - let hash = *ret.hash(); - - (hash, contents, chunk_boundaries) - } - - fn gen_random_bytes(uncompressed_chunk_size: u32) -> Vec { - let mut rng = rand::thread_rng(); - let mut 
data = vec![0u8; uncompressed_chunk_size as usize]; - rng.fill(&mut data[..]); - data - } } diff --git a/cas_object/Cargo.toml b/cas_object/Cargo.toml index 29fdaacb..be227693 100644 --- a/cas_object/Cargo.toml +++ b/cas_object/Cargo.toml @@ -4,18 +4,18 @@ version = "0.1.0" edition = "2021" [dependencies] +xet_error = { path = "../xet_error" } +cas_types = { path = "../cas_types" } +merkledb = { path = "../merkledb" } +merklehash = { path = "../merklehash" } anyhow = "1.0.88" bincode = "1.3.3" http = "1.1.0" -merkledb = { path = "../merkledb" } -merklehash = { path = "../merklehash" } tempfile = "3.12.0" tracing = "0.1.40" -xet_error = { path = "../xet_error" } -cas_types = { path = "../cas_types" } lz4_flex = "0.11.3" bytes = "1.7.2" +rand = "0.8.5" [dev-dependencies] -rand = "0.8.5" tempfile = "3.12.0" diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index cef0f8a2..3087967c 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -1,20 +1,19 @@ +use crate::{ + cas_chunk_format::{deserialize_chunk, serialize_chunk}, + error::CasObjectError, + CompressionScheme, +}; +use anyhow::anyhow; use bytes::Buf; use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; -use merklehash::{DataHash, MerkleHash}; +use merklehash::MerkleHash; use std::{ - cmp::{max, min}, + cmp::min, io::{Cursor, Error, Read, Seek, Write}, mem::size_of, }; use tracing::warn; -use crate::{ - cas_chunk_format::{deserialize_chunk, serialize_chunk}, - error::CasObjectError, - CompressionScheme, -}; -use anyhow::anyhow; - const CAS_OBJECT_FORMAT_IDENT: [u8; 7] = [b'X', b'E', b'T', b'B', b'L', b'O', b'B']; const CAS_OBJECT_FORMAT_VERSION: u8 = 0; const CAS_OBJECT_INFO_DEFAULT_LENGTH: u32 = 60; @@ -39,7 +38,7 @@ pub struct CasObjectInfo { /// This vector only contains boundaries, so assumes the first chunk starts at offset 0. /// The final entry in vector is the total length of the chunks. /// See example below. - /// chunk[n] offset = chunk_boundary_offsets[n-1] + /// chunk[n] are bytes in [chunk_boundary_offsets[n-1], chunk_boundary_offsets[n]) /// ``` /// // ex. chunks: [ 0, 1, 2, 3 ] /// // chunk_boundary_offsets: [ 100, 200, 300, 400] @@ -58,7 +57,7 @@ impl Default for CasObjectInfo { CasObjectInfo { ident: CAS_OBJECT_FORMAT_IDENT, version: CAS_OBJECT_FORMAT_VERSION, - cashash: DataHash::default(), + cashash: MerkleHash::default(), num_chunks: 0, chunk_boundary_offsets: Vec::new(), chunk_hashes: Vec::new(), @@ -107,19 +106,6 @@ impl CasObjectInfo { pub fn deserialize(reader: &mut R) -> Result<(Self, u32), CasObjectError> { let mut total_bytes_read: u32 = 0; - // Go to end of Reader and get length, then jump back to it, and read sequentially - // read last 4 bytes to get length - reader.seek(std::io::SeekFrom::End(-(size_of::() as i64)))?; - - let mut info_length = [0u8; 4]; - reader.read_exact(&mut info_length)?; - let info_length = u32::from_le_bytes(info_length); - - // now seek back that many bytes + size of length (u32) and read sequentially. 
- reader.seek(std::io::SeekFrom::End( - -(size_of::() as i64 + info_length as i64), - ))?; - // Helper function to read data and update the byte count let mut read_bytes = |data: &mut [u8]| -> Result<(), CasObjectError> { reader.read_exact(data)?; @@ -143,37 +129,30 @@ impl CasObjectInfo { ))); } - let mut buf = [0u8; 32]; + let mut buf = [0u8; size_of::()]; read_bytes(&mut buf)?; - let cashash = DataHash::from(&buf); + let cashash = MerkleHash::from(&buf); - let mut num_chunks = [0u8; 4]; + let mut num_chunks = [0u8; size_of::()]; read_bytes(&mut num_chunks)?; let num_chunks = u32::from_le_bytes(num_chunks); let mut chunk_boundary_offsets = Vec::with_capacity(num_chunks as usize); for _ in 0..num_chunks { - let mut offset = [0u8; 4]; + let mut offset = [0u8; size_of::()]; read_bytes(&mut offset)?; chunk_boundary_offsets.push(u32::from_le_bytes(offset)); } let mut chunk_hashes = Vec::with_capacity(num_chunks as usize); for _ in 0..num_chunks { - let mut hash = [0u8; 32]; + let mut hash = [0u8; size_of::()]; read_bytes(&mut hash)?; - chunk_hashes.push(DataHash::from(&hash)); + chunk_hashes.push(MerkleHash::from(&hash)); } let mut _buffer = [0u8; 16]; read_bytes(&mut _buffer)?; - // validate that info_length matches what we read off of header - if total_bytes_read != info_length { - return Err(CasObjectError::FormatError(anyhow!( - "Xorb Info Format Error" - ))); - } - Ok(( CasObjectInfo { ident, @@ -184,7 +163,7 @@ impl CasObjectInfo { chunk_hashes, _buffer, }, - info_length, + total_bytes_read, )) } } @@ -243,46 +222,50 @@ impl CasObject { /// /// This allows the CasObject to be partially constructed, allowing for range reads inside the CasObject. pub fn deserialize(reader: &mut R) -> Result { - let (info, info_length) = CasObjectInfo::deserialize(reader)?; + let info_length = Self::get_info_length(reader)?; + + // now seek back that many bytes + size of length (u32) and read sequentially. + reader.seek(std::io::SeekFrom::End( + -(size_of::() as i64 + info_length as i64), + ))?; + + let (info, total_bytes_read) = CasObjectInfo::deserialize(reader)?; + + // validate that info_length matches what we read off of header + if total_bytes_read != info_length { + return Err(CasObjectError::FormatError(anyhow!( + "Xorb Info Format Error" + ))); + } + Ok(Self { info, info_length }) } - /// Used by LocalClient for generating Cas Object from chunk_boundaries while uploading or downloading blocks. + /// Serialize into Cas Object from uncompressed data and chunk boundaries. + /// Assumes correctness from caller: it's the receiver's responsibility to validate a cas object. 
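For orientation, the footer-locating step that `deserialize` now delegates to `get_info_length` (whose body is not shown in this hunk) boils down to reading the trailing little-endian `u32` and seeking back past the footer, as in the code removed above. A minimal standalone sketch, with the error type simplified to `std::io::Error`:

```rust
use std::io::{Read, Seek, SeekFrom};
use std::mem::size_of;

// Read the trailing u32 recording the serialized CasObjectInfo length, then
// position the reader at the first byte of the footer.
fn seek_to_footer<R: Read + Seek>(reader: &mut R) -> std::io::Result<u32> {
    // The last 4 bytes of a xorb hold the footer length.
    reader.seek(SeekFrom::End(-(size_of::<u32>() as i64)))?;
    let mut buf = [0u8; size_of::<u32>()];
    reader.read_exact(&mut buf)?;
    let info_length = u32::from_le_bytes(buf);

    // Skip back over the footer plus the length word itself.
    reader.seek(SeekFrom::End(
        -(size_of::<u32>() as i64 + info_length as i64),
    ))?;
    Ok(info_length)
}
```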
pub fn serialize( writer: &mut W, hash: &MerkleHash, data: &[u8], - chunk_boundaries: &Vec, + chunk_and_boundaries: &[(MerkleHash, u32)], compression_scheme: CompressionScheme, ) -> Result<(Self, usize), CasObjectError> { - // validate hash against contents - if !Self::validate_root_hash(data, chunk_boundaries, hash) { - return Err(CasObjectError::HashMismatch); - } - let mut cas = CasObject::default(); - cas.info.cashash.copy_from_slice(hash.as_slice()); - cas.info.num_chunks = chunk_boundaries.len() as u32; + cas.info.cashash = *hash; + cas.info.num_chunks = chunk_and_boundaries.len() as u32; cas.info.chunk_boundary_offsets = Vec::with_capacity(cas.info.num_chunks as usize); - cas.info.chunk_hashes = Vec::with_capacity(cas.info.num_chunks as usize); + cas.info.chunk_hashes = chunk_and_boundaries.iter().map(|(hash, _)| *hash).collect(); let mut total_written_bytes: usize = 0; let mut raw_start_idx = 0; - for boundary in chunk_boundaries { - let chunk_boundary: u32 = *boundary; - - let mut chunk_raw_bytes = Vec::::new(); - chunk_raw_bytes - .extend_from_slice(&data[raw_start_idx as usize..chunk_boundary as usize]); + for boundary in chunk_and_boundaries { + let chunk_boundary: u32 = boundary.1; - // generate chunk hash and store it - let chunk_hash = merklehash::compute_data_hash(&chunk_raw_bytes); - cas.info.chunk_hashes.push(chunk_hash); + let chunk_raw_bytes = &data[raw_start_idx as usize..chunk_boundary as usize]; // now serialize chunk directly to writer (since chunks come first!) - let chunk_written_bytes = - serialize_chunk(&chunk_raw_bytes, writer, compression_scheme)?; + let chunk_written_bytes = serialize_chunk(chunk_raw_bytes, writer, compression_scheme)?; total_written_bytes += chunk_written_bytes; cas.info .chunk_boundary_offsets @@ -382,50 +365,33 @@ impl CasObject { /// Generate a hash for securing a chunk range. /// - /// chunk_start_index, chunk_end_index: byte indices for chunks in CasObject. + /// chunk_start_index, chunk_end_index: indices for chunks in CasObject. /// key: additional key incorporated into generating hash. /// /// This hash ensures validity of the knowledge of chunks, since ranges are public, - /// this ensures that only users that actually have access to chunks can request them. + /// this ensures that only users that actually have access to chunks can claim them + /// in a file reconstruction entry. pub fn generate_chunk_range_hash( &self, chunk_start_index: u32, chunk_end_index: u32, key: &[u8], - ) -> Result { + ) -> Result { self.validate_cas_object_info()?; - if chunk_end_index < chunk_start_index - || self.get_contents_length()? 
> max(chunk_end_index, chunk_start_index) - { + if chunk_end_index < chunk_start_index || chunk_end_index > self.info.num_chunks { return Err(CasObjectError::InvalidArguments); } - // walk chunk boundaries and collect relevant hashes - let mut range_hashes = Vec::::new(); - let mut found_start = chunk_start_index == 0; - let mut found_end = false; - for (idx, boundary) in self.info.chunk_boundary_offsets.iter().enumerate() { - let boundary = *boundary; - if found_start || chunk_start_index == boundary { - found_start = true; - let chunk_hash = self.info.chunk_hashes.get(idx).unwrap(); - range_hashes.push(*chunk_hash); - } - - // if found end then exit loop early - if chunk_end_index == boundary { - found_end = true; - break; - } - } - - if !found_start || !found_end { - return Err(CasObjectError::InternalError(anyhow!("Chunk Range Invalid"))) - } + // Collect relevant hashes + let range_hashes = + self.info.chunk_hashes[chunk_start_index as usize..chunk_end_index as usize].as_ref(); // TODO: Make this more robust, currently appends range hashes together, adds key to end - let mut combined : Vec = range_hashes.iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); + let mut combined: Vec = range_hashes + .iter() + .flat_map(|hash| hash.as_bytes().to_vec()) + .collect(); combined.extend_from_slice(key); // now hash the hashes + key and return @@ -434,7 +400,7 @@ impl CasObject { Ok(range_hash) } - /// Return end value of all chunk contents (byte index prior to header) + /// Return end offset of all physical chunk contents (byte index at the beginning of footer) pub fn get_contents_length(&self) -> Result { self.validate_cas_object_info()?; match self.info.chunk_boundary_offsets.last() { @@ -445,10 +411,12 @@ impl CasObject { } } - /// Get range of content bytes from Xorb. + /// Get range of content bytes uncompressed from Xorb. + /// + /// start and end are byte indices into the physical layout of a xorb. /// /// The start and end parameters are required to align with chunk boundaries. - pub fn get_range( + fn get_range( &self, reader: &mut R, start: u32, @@ -479,23 +447,8 @@ impl CasObject { self.get_range(reader, 0, self.get_contents_length()?) } - /// Get all the content bytes from a Xorb, and return the chunk boundaries - pub fn get_detailed_bytes( - &self, - reader: &mut R, - ) -> Result<(Vec, Vec), CasObjectError> { - self.validate_cas_object_info()?; - - let data = self.get_all_bytes(reader)?; - let chunk_boundaries = self.get_chunk_boundaries()?; - - Ok((chunk_boundaries, data)) - } - /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec. fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result, CasObjectError> { - self.validate_cas_object_info()?; - // walk chunk_data, deserialize into Chunks, and then get_bytes() from each of them. let mut reader = Cursor::new(chunk_data); let mut res = Vec::::new(); @@ -507,22 +460,28 @@ impl CasObject { Ok(res) } - /// Helper function to translate CasObjectInfo.chunk_boundary_offsets to just return chunk boundaries. - /// - /// The final chunk boundary returned is required to be the length of the contents. - fn get_chunk_boundaries(&self) -> Result, CasObjectError> { + /// Helper function to translate a range of chunk indices to physical byte offset range. 
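Stepping back to the simplified `generate_chunk_range_hash` above: the range hash is just a hash over the selected chunk hashes with the key appended. A self-contained sketch of the same computation, omitting the index validation and using an illustrative function name:

```rust
use merklehash::{compute_data_hash, MerkleHash};

// Concatenate the chunk hashes for [start, end), append the key, and hash
// the result; the bounds checks from the method above are omitted here.
fn range_hash(chunk_hashes: &[MerkleHash], start: usize, end: usize, key: &[u8]) -> MerkleHash {
    let mut combined: Vec<u8> = chunk_hashes[start..end]
        .iter()
        .flat_map(|h| h.as_bytes().to_vec())
        .collect();
    combined.extend_from_slice(key);
    compute_data_hash(&combined)
}
```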
+ fn get_byte_offset( + &self, + chunk_index_start: u32, + chunk_index_end: u32, + ) -> Result<(u32, u32), CasObjectError> { self.validate_cas_object_info()?; - Ok(self.info.chunk_boundary_offsets.to_vec()) + if chunk_index_end <= chunk_index_start || chunk_index_end > self.info.num_chunks { + return Err(CasObjectError::InvalidArguments); + } + + let byte_offset_start = match chunk_index_start { + 0 => 0, + _ => self.info.chunk_boundary_offsets[chunk_index_start as usize - 1], + }; + let byte_offset_end = self.info.chunk_boundary_offsets[chunk_index_end as usize - 1]; + + Ok((byte_offset_start, byte_offset_end)) } /// Helper method to verify that info object is complete fn validate_cas_object_info(&self) -> Result<(), CasObjectError> { - if self.info == Default::default() { - return Err(CasObjectError::InternalError(anyhow!( - "Incomplete CasObject, no CasObjectInfo footer." - ))); - } - if self.info.num_chunks == 0 { return Err(CasObjectError::FormatError(anyhow!( "Invalid CasObjectInfo, no chunks in CasObject." @@ -537,7 +496,7 @@ impl CasObject { ))); } - if self.info.cashash == DataHash::default() { + if self.info.cashash == MerkleHash::default() { return Err(CasObjectError::FormatError(anyhow!( "Invalid CasObjectInfo, Missing cashash." ))); @@ -545,41 +504,104 @@ impl CasObject { Ok(()) } +} - /// Helper method to validate root hash for data block. - fn validate_root_hash(data: &[u8], chunk_boundaries: &[u32], hash: &MerkleHash) -> bool { - // at least 1 chunk, and last entry in chunk boundary must match the length - if chunk_boundaries.is_empty() - || chunk_boundaries[chunk_boundaries.len() - 1] as usize != data.len() - { - return false; - } +pub mod test_utils { + use super::*; + use crate::cas_chunk_format::serialize_chunk; + use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; + use rand::Rng; + + pub fn gen_random_bytes(size: u32) -> Vec { + let mut rng = rand::thread_rng(); + let mut data = vec![0u8; size as usize]; + rng.fill(&mut data[..]); + data + } + + pub enum ChunkSize { + Random(u32, u32), + Fixed(u32), + } + + /// Utility test method for creating a cas object + /// Returns (CasObject, chunks serialized, raw data, raw data chunk boundaries) + #[allow(clippy::type_complexity)] + pub fn build_cas_object( + num_chunks: u32, + chunk_size: ChunkSize, + compression_scheme: CompressionScheme, + ) -> (CasObject, Vec, Vec, Vec<(MerkleHash, u32)>) { + let mut c = CasObject::default(); + + let mut chunk_boundary_offsets = vec![]; + let mut chunk_hashes = vec![]; + let mut writer = Cursor::new(vec![]); - let mut chunks: Vec = Vec::new(); - let mut left_edge: usize = 0; - for i in chunk_boundaries { - let right_edge = *i as usize; - let hash = merklehash::compute_data_hash(&data[left_edge..right_edge]); - let length = right_edge - left_edge; - chunks.push(Chunk { hash, length }); - left_edge = right_edge; + let mut total_bytes = 0; + let mut chunks = vec![]; + let mut data_contents_raw = vec![]; + let mut raw_chunk_boundaries = vec![]; + + for _idx in 0..num_chunks { + let chunk_size: u32 = match chunk_size { + ChunkSize::Random(a, b) => { + let mut rng = rand::thread_rng(); + rng.gen_range(a..=b) + } + ChunkSize::Fixed(size) => size, + }; + + let bytes = gen_random_bytes(chunk_size); + + let chunk_hash = merklehash::compute_data_hash(&bytes); + chunks.push(Chunk { + hash: chunk_hash, + length: bytes.len(), + }); + + data_contents_raw.extend_from_slice(&bytes); + + // build chunk, create ChunkInfo and keep going + + let bytes_written = serialize_chunk(&bytes, &mut 
writer, compression_scheme).unwrap(); + + total_bytes += bytes_written as u32; + + raw_chunk_boundaries.push((chunk_hash, data_contents_raw.len() as u32)); + chunk_boundary_offsets.push(total_bytes); + chunk_hashes.push(chunk_hash); } + c.info.num_chunks = chunk_boundary_offsets.len() as u32; + c.info.chunk_boundary_offsets = chunk_boundary_offsets; + c.info.chunk_hashes = chunk_hashes; + let mut db = MerkleMemDB::default(); let mut staging = db.start_insertion_staging(); db.add_file(&mut staging, &chunks); let ret = db.finalize(staging); - *ret.hash() == *hash + + c.info.cashash = *ret.hash(); + + // now serialize info to end Xorb length + let mut buf = Cursor::new(Vec::new()); + let len = c.info.serialize(&mut buf).unwrap(); + c.info_length = len as u32; + + ( + c, + writer.get_ref().to_vec(), + data_contents_raw, + raw_chunk_boundaries, + ) } } #[cfg(test)] mod tests { - + use super::test_utils::*; use super::*; - use crate::cas_chunk_format::serialize_chunk; - use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; - use rand::Rng; use std::io::Cursor; #[test] @@ -591,7 +613,7 @@ mod tests { let expected_default = CasObjectInfo { ident: CAS_OBJECT_FORMAT_IDENT, version: CAS_OBJECT_FORMAT_VERSION, - cashash: DataHash::default(), + cashash: MerkleHash::default(), num_chunks: 0, chunk_boundary_offsets: Vec::new(), chunk_hashes: Vec::new(), @@ -614,13 +636,11 @@ mod tests { } #[test] - fn test_chunk_boundaries_chunk_size_info() { + fn test_uncompressed_cas_object() { // Arrange let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = - build_cas_object(3, 100, false, CompressionScheme::None); + build_cas_object(3, ChunkSize::Fixed(100), CompressionScheme::None); // Act & Assert - assert_eq!(c.get_chunk_boundaries().unwrap().len(), 3); - assert_eq!(c.get_chunk_boundaries().unwrap(), [108, 216, 324]); assert_eq!(c.info.num_chunks, 3); assert_eq!( c.info.chunk_boundary_offsets.len(), @@ -628,123 +648,44 @@ mod tests { ); let second_chunk_boundary = *c.info.chunk_boundary_offsets.get(1).unwrap(); - let third_chunk_boundary = *c.info.chunk_boundary_offsets.get(2).unwrap(); assert_eq!(second_chunk_boundary, 216); // 8-byte header, 3 chunks, so 2nd chunk boundary is at byte 216 + + let third_chunk_boundary = *c.info.chunk_boundary_offsets.get(2).unwrap(); assert_eq!(third_chunk_boundary, 324); // 8-byte header, 3 chunks, so 3rd chunk boundary is at byte 324 - } - fn gen_random_bytes(uncompressed_chunk_size: u32) -> Vec { - let mut rng = rand::thread_rng(); - let mut data = vec![0u8; uncompressed_chunk_size as usize]; - rng.fill(&mut data[..]); - data + let byte_offset_range = c.get_byte_offset(0, 1).unwrap(); + assert_eq!(byte_offset_range, (0, 108)); + + let byte_offset_range = c.get_byte_offset(1, 3).unwrap(); + assert_eq!(byte_offset_range, (108, 324)); } #[test] fn test_generate_range_hash_full_range() { // Arrange let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = - build_cas_object(3, 100, false, CompressionScheme::None); + build_cas_object(3, ChunkSize::Fixed(100), CompressionScheme::None); let key = [b'K', b'E', b'Y']; - let mut hashes : Vec = c.info.chunk_hashes.iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); + let mut hashes: Vec = c + .info + .chunk_hashes + .iter() + .flat_map(|hash| hash.as_bytes().to_vec()) + .collect(); hashes.extend_from_slice(&key); let expected_hash = merklehash::compute_data_hash(&hashes); // Act & Assert - let range_hash = c.generate_chunk_range_hash(0, 324, &key).unwrap(); + let range_hash = c.generate_chunk_range_hash(0, 3, 
&key).unwrap(); assert_eq!(range_hash, expected_hash); } - #[ignore = "not written yet"] - #[test] - fn test_generate_range_hash_partial() { - todo!() - } - - #[ignore = "Not written yet"] - #[test] - fn test_validate_cas_object_info() { - todo!() - } - - /// Utility test method for creating a cas object - /// Returns (CasObject, CasObjectInfo serialized, raw data, raw data chunk boundaries) - fn build_cas_object( - num_chunks: u32, - uncompressed_chunk_size: u32, - use_random_chunk_size: bool, - compression_scheme: CompressionScheme, - ) -> (CasObject, Vec, Vec, Vec) { - let mut c = CasObject::default(); - - let mut chunk_boundary_offsets = Vec::::new(); - let mut chunk_hashes = Vec::::new(); - let mut writer = Cursor::new(Vec::::new()); - - let mut total_bytes = 0; - let mut chunks: Vec = Vec::new(); - let mut data_contents_raw = Vec::::new(); - let mut raw_chunk_boundaries = Vec::::new(); - - for _idx in 0..num_chunks { - let chunk_size: u32 = if use_random_chunk_size { - let mut rng = rand::thread_rng(); - rng.gen_range(512..=uncompressed_chunk_size) - } else { - uncompressed_chunk_size - }; - - let bytes = gen_random_bytes(chunk_size); - - let chunk_hash = merklehash::compute_data_hash(&bytes); - chunks.push(Chunk { - hash: chunk_hash, - length: bytes.len(), - }); - - data_contents_raw.extend_from_slice(&bytes); - - // build chunk, create ChunkInfo and keep going - - let bytes_written = serialize_chunk(&bytes, &mut writer, compression_scheme).unwrap(); - - total_bytes += bytes_written as u32; - - raw_chunk_boundaries.push(data_contents_raw.len() as u32); - chunk_boundary_offsets.push(total_bytes); - chunk_hashes.push(chunk_hash); - } - - c.info.num_chunks = chunk_boundary_offsets.len() as u32; - c.info.chunk_boundary_offsets = chunk_boundary_offsets; - c.info.chunk_hashes = chunk_hashes; - - let mut db = MerkleMemDB::default(); - let mut staging = db.start_insertion_staging(); - db.add_file(&mut staging, &chunks); - let ret = db.finalize(staging); - - c.info.cashash = *ret.hash(); - - // now serialize info to end Xorb length - let mut buf = Cursor::new(Vec::new()); - let len = c.info.serialize(&mut buf).unwrap(); - c.info_length = len as u32; - - ( - c, - writer.get_ref().to_vec(), - data_contents_raw, - raw_chunk_boundaries, - ) - } - #[test] fn test_compress_decompress() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(55, 53212, false, CompressionScheme::LZ4); + build_cas_object(55, ChunkSize::Fixed(53212), CompressionScheme::LZ4); // Act & Assert let mut writer: Cursor> = Cursor::new(Vec::new()); @@ -791,7 +732,7 @@ mod tests { fn test_hash_generation_compression() { // Arrange let (c, cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(55, 53212, false, CompressionScheme::LZ4); + build_cas_object(55, ChunkSize::Fixed(53212), CompressionScheme::LZ4); // Act & Assert let mut buf: Cursor> = Cursor::new(Vec::new()); assert!(CasObject::serialize( @@ -813,7 +754,7 @@ mod tests { fn test_basic_serialization_mem() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(3, 100, false, CompressionScheme::None); + build_cas_object(3, ChunkSize::Fixed(100), CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -832,7 +773,7 @@ mod tests { fn test_serialization_deserialization_mem_medium() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(32, 16384, false, CompressionScheme::None); + build_cas_object(32, 
ChunkSize::Fixed(16384), CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -863,7 +804,7 @@ mod tests { fn test_serialization_deserialization_mem_large_random() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(32, 65536, true, CompressionScheme::None); + build_cas_object(32, ChunkSize::Random(512, 65536), CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -893,7 +834,7 @@ mod tests { fn test_serialization_deserialization_file_large_random() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(256, 65536, true, CompressionScheme::None); + build_cas_object(256, ChunkSize::Random(512, 65536), CompressionScheme::None); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -923,7 +864,7 @@ mod tests { fn test_basic_mem_lz4() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(1, 8, false, CompressionScheme::LZ4); + build_cas_object(1, ChunkSize::Fixed(8), CompressionScheme::LZ4); let mut writer: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -952,7 +893,7 @@ mod tests { fn test_serialization_deserialization_mem_medium_lz4() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(32, 16384, false, CompressionScheme::LZ4); + build_cas_object(32, ChunkSize::Fixed(16384), CompressionScheme::LZ4); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -983,7 +924,7 @@ mod tests { fn test_serialization_deserialization_mem_large_random_lz4() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(32, 65536, true, CompressionScheme::LZ4); + build_cas_object(32, ChunkSize::Random(512, 65536), CompressionScheme::LZ4); let mut buf: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( @@ -1013,7 +954,7 @@ mod tests { fn test_serialization_deserialization_file_large_random_lz4() { // Arrange let (c, _cas_data, raw_data, raw_chunk_boundaries) = - build_cas_object(256, 65536, true, CompressionScheme::LZ4); + build_cas_object(256, ChunkSize::Random(512, 65536), CompressionScheme::LZ4); let mut writer: Cursor> = Cursor::new(Vec::new()); // Act & Assert assert!(CasObject::serialize( diff --git a/cas_object/src/lib.rs b/cas_object/src/lib.rs index 77beb0eb..15d1a6d3 100644 --- a/cas_object/src/lib.rs +++ b/cas_object/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] mod cas_chunk_format; mod cas_object_format; mod compression_scheme; diff --git a/data/src/cas_interface.rs b/data/src/cas_interface.rs index 613a644b..7ca2f458 100644 --- a/data/src/cas_interface.rs +++ b/data/src/cas_interface.rs @@ -1,11 +1,11 @@ use crate::configurations::*; use crate::errors::Result; -use cas::auth::AuthConfig; use cas_client::RemoteClient; use std::env::current_dir; use std::path::Path; use std::sync::Arc; use tracing::info; +use utils::auth::AuthConfig; pub use cas_client::Client; diff --git a/data/src/configurations.rs b/data/src/configurations.rs index ccf6c031..8b659f8a 100644 --- a/data/src/configurations.rs +++ b/data/src/configurations.rs @@ -1,8 +1,8 @@ use crate::errors::Result; use crate::repo_salt::RepoSalt; -use cas::auth::AuthConfig; use std::path::PathBuf; use std::str::FromStr; +use utils::auth::AuthConfig; #[derive(Debug)] pub enum Endpoint { diff --git 
a/data/src/data_processing.rs b/data/src/data_processing.rs index 15a9f34f..da3dac5a 100644 --- a/data/src/data_processing.rs +++ b/data/src/data_processing.rs @@ -243,9 +243,9 @@ pub(crate) async fn register_new_cas_block( let chunk_boundaries = cas_data .chunks .iter() - .map(|(_, len)| { + .map(|(hash, len)| { pos += *len; - pos as u64 + (*hash, pos as u32) }) .collect(); diff --git a/data/src/errors.rs b/data/src/errors.rs index 5f1d52f8..251298e8 100644 --- a/data/src/errors.rs +++ b/data/src/errors.rs @@ -1,10 +1,10 @@ -use cas::errors::{AuthError, SingleflightError}; use cas_client::CasClientError; use mdb_shard::error::MDBShardError; use merkledb::error::MerkleDBError; use shard_client::error::ShardClientError; use std::string::FromUtf8Error; use std::sync::mpsc::RecvError; +use utils::errors::{AuthError, SingleflightError}; use xet_error::Error; #[derive(Error, Debug)] diff --git a/data/src/remote_shard_interface.rs b/data/src/remote_shard_interface.rs index 016c6400..14dfece0 100644 --- a/data/src/remote_shard_interface.rs +++ b/data/src/remote_shard_interface.rs @@ -4,7 +4,6 @@ use super::shard_interface::{create_shard_client, create_shard_manager}; use crate::cas_interface::Client; use crate::constants::{FILE_RECONSTRUCTION_CACHE_SIZE, MAX_CONCURRENT_UPLOADS}; use crate::repo_salt::RepoSalt; -use cas::singleflight; use lru::LruCache; use mdb_shard::constants::MDB_SHARD_MIN_TARGET_SIZE; use mdb_shard::session_directory::consolidate_shards_in_directory; @@ -21,6 +20,7 @@ use std::sync::Arc; use std::sync::Mutex; use tokio::task::JoinHandle; use tracing::{debug, info}; +use utils::singleflight; pub struct RemoteShardInterface { pub file_query_policy: FileQueryPolicy, diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 7d90cadc..260ccaee 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -367,6 +367,7 @@ dependencies = [ "async-trait", "base64 0.13.1", "byteorder", + "cas_types", "chrono", "lazy_static", "lru", @@ -447,6 +448,7 @@ dependencies = [ "lz4_flex", "merkledb", "merklehash", + "rand 0.8.5", "tempfile", "tracing", "xet_error", @@ -1937,6 +1939,7 @@ dependencies = [ "tempfile", "tokio", "tracing", + "utils", "uuid", "xet_error", ] diff --git a/hf_xet/src/config.rs b/hf_xet/src/config.rs index a6c00235..d5a7bb8b 100644 --- a/hf_xet/src/config.rs +++ b/hf_xet/src/config.rs @@ -1,9 +1,9 @@ -use cas::auth::{AuthConfig, TokenRefresher}; use data::configurations::*; use data::errors; use std::env::current_dir; use std::fs; use std::sync::Arc; +use utils::auth::{AuthConfig, TokenRefresher}; pub const SMALL_FILE_THRESHOLD: usize = 1; diff --git a/hf_xet/src/data_client.rs b/hf_xet/src/data_client.rs index fc165767..77234c11 100644 --- a/hf_xet/src/data_client.rs +++ b/hf_xet/src/data_client.rs @@ -1,5 +1,5 @@ use crate::config::default_config; -use cas::auth::TokenRefresher; +use utils::auth::TokenRefresher; use data::errors::DataProcessingError; use data::{errors, PointerFile, PointerFileTranslator}; use parutils::{tokio_par_for_each, ParallelError}; @@ -111,7 +111,7 @@ async fn smudge_file( fs::create_dir_all(parent_dir)?; } let mut f: Box = Box::new(File::create(&path)?); - proc.smudge_file_from_pointer(&pointer_file, &mut f, None) + proc.smudge_file_from_pointer(pointer_file, &mut f, None) .await?; Ok(pointer_file.path().to_string()) } diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index 4a528a19..b6c6ef93 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -3,7 +3,7 @@ mod data_client; mod log; mod token_refresh; -use cas::auth::TokenRefresher; +use 
utils::auth::TokenRefresher; use data::PointerFile; use pyo3::exceptions::PyException; use pyo3::prelude::*; diff --git a/hf_xet/src/token_refresh.rs b/hf_xet/src/token_refresh.rs index 7d82f5be..e86dae66 100644 --- a/hf_xet/src/token_refresh.rs +++ b/hf_xet/src/token_refresh.rs @@ -1,10 +1,10 @@ -use cas::auth::{TokenInfo, TokenRefresher}; -use cas::errors::AuthError; use pyo3::exceptions::PyTypeError; use pyo3::prelude::PyAnyMethods; use pyo3::{Py, PyAny, PyErr, PyResult, Python}; use std::fmt::{Debug, Formatter}; use tracing::{error, info}; +use utils::auth::{TokenInfo, TokenRefresher}; +use utils::errors::AuthError; /// A wrapper struct of a python function to refresh the CAS auth token. /// Since tokens are generated by hub, we want to be able to refresh the diff --git a/mdb_shard/Cargo.toml b/mdb_shard/Cargo.toml index 70e51e8e..4d59057c 100644 --- a/mdb_shard/Cargo.toml +++ b/mdb_shard/Cargo.toml @@ -6,9 +6,11 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +merklehash = { path = "../merklehash" } +xet_error = {path = "../xet_error"} +utils = {path = "../utils"} more-asserts = "0.3.*" tempdir = "0.3.7" -merklehash = { path = "../merklehash" } serde = {version="1.0.129", features = ["derive"]} tokio = { version = "1.36", features = ["full"] } lazy_static = "1.4.0" @@ -21,7 +23,6 @@ tempfile = "3.2.0" clap = { version = "3.1.6", features = ["derive"] } anyhow = "1" rand = {version = "0.8.5", features = ["small_rng"]} -xet_error = {path = "../xet_error"} async-trait = "0.1.9" [[bin]] diff --git a/mdb_shard/src/cas_structs.rs b/mdb_shard/src/cas_structs.rs index 9b641402..33554995 100644 --- a/mdb_shard/src/cas_structs.rs +++ b/mdb_shard/src/cas_structs.rs @@ -1,8 +1,8 @@ -use crate::serialization_utils::*; use merklehash::MerkleHash; use std::fmt::Debug; use std::io::{Read, Write}; use std::mem::size_of; +use utils::serialization_utils::*; pub const MDB_DEFAULT_CAS_FLAG: u32 = 0; diff --git a/mdb_shard/src/file_structs.rs b/mdb_shard/src/file_structs.rs index c8a0633c..9555f6ac 100644 --- a/mdb_shard/src/file_structs.rs +++ b/mdb_shard/src/file_structs.rs @@ -1,9 +1,9 @@ use crate::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader}; -use crate::serialization_utils::*; use merklehash::MerkleHash; use std::fmt::Debug; use std::io::{Cursor, Read, Write}; use std::mem::size_of; +use utils::serialization_utils::*; pub const MDB_DEFAULT_FILE_FLAG: u32 = 0; diff --git a/mdb_shard/src/serialization_utils.rs b/mdb_shard/src/interpolation_search.rs similarity index 86% rename from mdb_shard/src/serialization_utils.rs rename to mdb_shard/src/interpolation_search.rs index 32bf6abc..56289a50 100644 --- a/mdb_shard/src/serialization_utils.rs +++ b/mdb_shard/src/interpolation_search.rs @@ -1,64 +1,7 @@ -use merklehash::MerkleHash; use std::cmp::Ordering; -use std::io::{Read, Seek, SeekFrom, Write}; -use std::mem::{size_of, transmute}; - -pub fn write_hash(writer: &mut W, m: &MerkleHash) -> Result<(), std::io::Error> { - writer.write_all(m.as_bytes()) -} - -pub fn write_u32(writer: &mut W, v: u32) -> Result<(), std::io::Error> { - writer.write_all(&v.to_le_bytes()) -} - -pub fn write_u64(writer: &mut W, v: u64) -> Result<(), std::io::Error> { - writer.write_all(&v.to_le_bytes()) -} - -pub fn write_u32s(writer: &mut W, vs: &[u32]) -> Result<(), std::io::Error> { - for e in vs { - write_u32(writer, *e)?; - } - - Ok(()) -} - -pub fn write_u64s(writer: &mut W, vs: &[u64]) -> Result<(), std::io::Error> { - for e 
in vs { - write_u64(writer, *e)?; - } - - Ok(()) -} - -pub fn read_hash(reader: &mut R) -> Result { - let mut m = [0u8; 32]; - reader.read_exact(&mut m)?; // Not endian safe. - - Ok(MerkleHash::from(unsafe { - transmute::<[u8; 32], [u64; 4]>(m) - })) -} - -pub fn read_u32(reader: &mut R) -> Result { - let mut buf = [0u8; size_of::()]; - reader.read_exact(&mut buf[..])?; - Ok(u32::from_le_bytes(buf)) -} - -pub fn read_u64(reader: &mut R) -> Result { - let mut buf = [0u8; size_of::()]; - reader.read_exact(&mut buf[..])?; - Ok(u64::from_le_bytes(buf)) -} - -pub fn read_u64s(reader: &mut R, vs: &mut [u64]) -> Result<(), std::io::Error> { - for e in vs.iter_mut() { - *e = read_u64(reader)?; - } - - Ok(()) -} +use std::io::{Read, Seek, SeekFrom}; +use std::mem::size_of; +use utils::serialization_utils::*; /// Performs an interpolation search on a block of sorted, possibly multile /// u64 hash keys with a simple payload. diff --git a/mdb_shard/src/lib.rs b/mdb_shard/src/lib.rs index 3df56868..366a071a 100644 --- a/mdb_shard/src/lib.rs +++ b/mdb_shard/src/lib.rs @@ -2,7 +2,7 @@ pub mod cas_structs; pub mod constants; pub mod error; pub mod file_structs; -pub mod serialization_utils; +pub mod interpolation_search; pub mod session_directory; pub mod set_operations; pub mod shard_dedup_probe; diff --git a/mdb_shard/src/set_operations.rs b/mdb_shard/src/set_operations.rs index b0c5e9a1..51918463 100644 --- a/mdb_shard/src/set_operations.rs +++ b/mdb_shard/src/set_operations.rs @@ -2,7 +2,6 @@ use crate::error::Result; use crate::{ cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader}, file_structs::{FileDataSequenceEntry, FileDataSequenceHeader}, - serialization_utils::{write_u32, write_u64}, shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo}, utils::truncate_hash, }; @@ -14,6 +13,7 @@ use std::{ mem::size_of, path::Path, }; +use utils::serialization_utils::*; use uuid::Uuid; #[derive(PartialEq, Debug, Copy, Clone)] diff --git a/mdb_shard/src/shard_format.rs b/mdb_shard/src/shard_format.rs index 24c27032..720835a9 100644 --- a/mdb_shard/src/shard_format.rs +++ b/mdb_shard/src/shard_format.rs @@ -2,7 +2,7 @@ use crate::cas_structs::*; use crate::constants::*; use crate::error::{MDBShardError, Result}; use crate::file_structs::*; -use crate::serialization_utils::*; +use crate::interpolation_search::search_on_sorted_u64s; use crate::shard_in_memory::MDBInMemoryShard; use crate::shard_version; use crate::utils::truncate_hash; @@ -13,6 +13,7 @@ use std::io::{Read, Seek, SeekFrom, Write}; use std::mem::size_of; use std::sync::Arc; use tracing::debug; +use utils::serialization_utils::*; // Same size for FileDataSequenceHeader and FileDataSequenceEntry const MDB_FILE_INFO_ENTRY_SIZE: u64 = (size_of::<[u64; 4]>() + 4 * size_of::()) as u64; diff --git a/progress_reporting/src/data_progress.rs b/progress_reporting/src/data_progress.rs index 2f8e9f72..6d2a6e01 100644 --- a/progress_reporting/src/data_progress.rs +++ b/progress_reporting/src/data_progress.rs @@ -1,9 +1,9 @@ -use cas::output_bytes; use crossterm::{cursor, QueueableCommand}; use std::io::{stderr, Write}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Instant; +use utils::output_bytes; const MAX_PRINT_INTERVAL_MS: u64 = 250; diff --git a/shard_client/src/http_shard_client.rs b/shard_client/src/http_shard_client.rs index 8bf2fb33..2fb5e3ed 100644 --- a/shard_client/src/http_shard_client.rs +++ b/shard_client/src/http_shard_client.rs @@ -2,13 +2,11 @@ use 
crate::error::{Result, ShardClientError}; use crate::{RegistrationClient, ShardClientInterface}; use async_trait::async_trait; use bytes::Buf; -use cas::auth::AuthConfig; use cas_client::build_reqwest_client; use cas_types::Key; use cas_types::{QueryReconstructionResponse, UploadShardResponse, UploadShardResponseType}; use file_utils::write_all_safe; use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo}; -use mdb_shard::serialization_utils::read_u32; use mdb_shard::shard_dedup_probe::ShardDedupProber; use mdb_shard::shard_file_reconstructor::FileReconstructor; use merklehash::MerkleHash; @@ -20,6 +18,8 @@ use std::path::PathBuf; use std::str::FromStr; use tokio::task::JoinSet; use tracing::warn; +use utils::auth::AuthConfig; +use utils::serialization_utils::read_u32; const NUM_RETRIES: usize = 5; const BASE_RETRY_DELAY_MS: u64 = 3000; diff --git a/utils/Cargo.toml b/utils/Cargo.toml index 9542bd03..f34251a4 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -4,7 +4,7 @@ version = "0.14.5" edition = "2021" [lib] -name = "cas" +name = "utils" path = "src/lib.rs" [dependencies] diff --git a/utils/build.rs b/utils/build.rs index 2a073e56..5516d623 100644 --- a/utils/build.rs +++ b/utils/build.rs @@ -1,8 +1,6 @@ fn main() -> Result<(), Box> { tonic_build::configure().compile(&["proto/common.proto"], &["proto"])?; - tonic_build::configure().compile(&["proto/cas.proto"], &["proto"])?; tonic_build::configure().compile(&["proto/infra.proto"], &["proto"])?; tonic_build::configure().compile(&["proto/alb.proto"], &["proto"])?; - tonic_build::configure().compile(&["proto/shard.proto"], &["proto"])?; Ok(()) } diff --git a/utils/examples/infra.rs b/utils/examples/infra.rs index d44515f3..f463e5d4 100644 --- a/utils/examples/infra.rs +++ b/utils/examples/infra.rs @@ -1,9 +1,9 @@ -use cas::common::Empty; -use cas::consistenthash::ConsistentHash; -use cas::infra::infra_utils_client::InfraUtilsClient; use clap::Parser; use http::Uri; use tonic::transport::Channel; +use utils::common::Empty; +use utils::consistenthash::ConsistentHash; +use utils::infra::infra_utils_client::InfraUtilsClient; pub type InfraUtilsClientType = InfraUtilsClient; pub async fn get_infra_client(server_name: &str) -> anyhow::Result { diff --git a/utils/proto/cas.proto b/utils/proto/cas.proto deleted file mode 100644 index e6a7daee..00000000 --- a/utils/proto/cas.proto +++ /dev/null @@ -1,81 +0,0 @@ -/* - * https://www.notion.so/cantorsystems/GlodHub-Client-Architecture-0992cff0f95e4203bf7763e9951f1fe8 - */ -syntax = "proto3"; -package cas; -import public "common.proto"; - -// The CAS (Content Addressed Storage) service. -service Cas { - // Initiates uploads of an object. - rpc Initiate(common.InitiateRequest) returns (common.InitiateResponse); - - // Uploads the provided bytes to an object. - // This will also verify the hash is correct. - rpc Put(PutRequest) returns (PutResponse); - - // Completes uploads of an object. This verifies that hash is correct. - rpc PutComplete(PutCompleteRequest) returns (PutCompleteResponse); - - - // Downloads all bytes for the indicated object. - rpc Get(GetRequest) returns (GetResponse); - - - // Downloads a set of ranges within an object. - rpc GetRange(GetRangeRequest) returns (GetRangeResponse); - - // Retrieve metadata about a particular object. 
- rpc Head(HeadRequest) returns (HeadResponse); -} - - - -message PutRequest { - common.Key key = 1; - bytes data = 2; - repeated uint64 chunk_boundaries = 3; -} - -message PutResponse { - bool was_inserted = 1; -} - - -message PutCompleteRequest { - common.Key key = 1; - repeated uint64 chunk_boundaries = 2; -} - -message PutCompleteResponse { - bool was_inserted = 1; -} - -message GetRequest { - common.Key key = 1; -} - -message GetResponse { - bytes data = 1; -} -message GetRangeRequest { - common.Key key = 1; - repeated Range ranges = 2; -} - -message GetRangeResponse { - repeated bytes data = 1; -} - -message HeadRequest { - common.Key key = 1; -} - -message HeadResponse { - uint64 size = 1; -} - -message Range { - uint64 start = 1; - uint64 end = 2; -} diff --git a/utils/proto/shard.proto b/utils/proto/shard.proto deleted file mode 100644 index 644533fc..00000000 --- a/utils/proto/shard.proto +++ /dev/null @@ -1,69 +0,0 @@ -/* - * https://www.notion.so/xethub/MerkleDBv2-Xet-CLI-Architecture-62c3177c92834864883bd3fa442feadc - * https://www.notion.so/xethub/MerkleDBv2-The-Final-Stage-cc654b5266294d399503c3431131fafa - */ -syntax = "proto3"; -package shard; -import public "common.proto"; - -// The Shard service. -service Shard { - // Queries for file->shard information. - rpc QueryFile(QueryFileRequest) returns (QueryFileResponse); - - // Synchronizes a shard from CAS to the Shard Service for querying - rpc SyncShard(SyncShardRequest) returns (SyncShardResponse); - - // Queries for chunk->shard information. - rpc QueryChunk(QueryChunkRequest) returns (QueryChunkResponse); - - // SyncShard + synchronizes chunk->shard information to the Shard Service - rpc SyncShardWithSalt(SyncShardWithSaltRequest) returns (SyncShardResponse); -} -message QueryFileRequest { - bytes file_id = 1; -} - -message Range { - uint64 start = 1; - uint64 end = 2; -} - -message CASReconstructionTerm { - bytes cas_id = 1; - uint64 unpacked_length = 2; - Range range = 3; -} - -message QueryFileResponse { - repeated CASReconstructionTerm reconstruction = 1; - common.Key shard_id = 2; -} - -message SyncShardRequest { - common.Key key = 1; - bool force_sync = 2; -} - -enum SyncShardResponseType { - Exists = 0; - SyncPerformed = 1; -} - -message SyncShardResponse { - SyncShardResponseType response = 1; -} - -message QueryChunkRequest { - string prefix = 1; - repeated bytes chunk = 2; -} - -message QueryChunkResponse { - repeated bytes shard = 1; -} - -message SyncShardWithSaltRequest { - SyncShardRequest ssr = 1; - bytes salt = 2; -} diff --git a/utils/src/key.rs b/utils/src/key.rs deleted file mode 100644 index 95126167..00000000 --- a/utils/src/key.rs +++ /dev/null @@ -1,122 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::fmt::{Display, Formatter}; - -use merklehash::MerkleHash; - -use crate::errors::KeyError; - -/// A Key indicates a prefixed merkle hash for some data stored in the CAS DB. 
-#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Ord, PartialOrd, Eq, Hash, Clone)] -pub struct Key { - pub prefix: String, - pub hash: MerkleHash, -} - -impl TryFrom<&crate::common::Key> for Key { - type Error = KeyError; - - fn try_from(proto_key: &crate::common::Key) -> Result { - let hash = MerkleHash::try_from(proto_key.hash.as_slice()) - .map_err(|e| KeyError::UnparsableKey(format!("{e:?}")))?; - Ok(Key { - prefix: proto_key.prefix.clone(), - hash, - }) - } -} - -impl Display for Key { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}/{:x}", self.prefix, self.hash) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_from_proto() { - let proto = crate::common::Key { - prefix: "abc".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key = Key::try_from(&proto).unwrap(); - assert_eq!(proto.prefix, key.prefix); - assert_eq!(&proto.hash, key.hash.as_bytes()); - } - - #[test] - fn test_from_invalid_proto() { - let proto = crate::common::Key { - prefix: "".to_string(), - hash: vec![1, 2, 3], - }; - let res = Key::try_from(&proto); - assert!(res.is_err()); - assert!(matches!(res, Err(KeyError::UnparsableKey(..)))) - } - - #[test] - fn test_display() { - let proto = crate::common::Key { - prefix: "abc".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key = Key::try_from(&proto).unwrap(); - assert_eq!( - "abc/0101010101010101010101010101010101010101010101010101010101010101", - key.to_string() - ); - } - - #[test] - fn test_equality() { - let proto1 = crate::common::Key { - prefix: "abc".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key1 = Key::try_from(&proto1).unwrap(); - assert_eq!(key1, key1); - - let proto2 = crate::common::Key { - prefix: "abc".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key2 = Key::try_from(&proto2).unwrap(); - assert_eq!(key2, key1); - assert_eq!(key1, key2); // symmetry - - let proto3 = crate::common::Key { - prefix: "abc".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key3 = Key::try_from(&proto3).unwrap(); - assert_eq!(key1, key3); - assert_eq!(key2, key3); // transitive - } - - #[test] - fn test_inequality() { - let proto1 = crate::common::Key { - prefix: "abc".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key1 = Key::try_from(&proto1).unwrap(); - - let proto2 = crate::common::Key { - prefix: "def".to_string(), - hash: [1u8; 32].to_vec(), - }; - let key2 = Key::try_from(&proto2).unwrap(); - assert_ne!(key2, key1); - - let proto3 = crate::common::Key { - prefix: "abc".to_string(), - hash: [2u8; 32].to_vec(), - }; - let key3 = Key::try_from(&proto3).unwrap(); - assert_ne!(key1, key3); - assert_ne!(key2, key3); - } -} diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 06709a2e..30387cbc 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,9 +1,5 @@ #![cfg_attr(feature = "strict", deny(warnings))] -use std::ops::Range; - -use tonic::Status; - // The auto code generated by tonic has clippy warnings. Disabling until those are // resolved in tonic/prost. 
pub mod common { @@ -11,11 +7,6 @@ pub mod common { tonic::include_proto!("common"); } -pub mod cas { - #![allow(clippy::derive_partial_eq_without_eq)] - tonic::include_proto!("cas"); -} - pub mod infra { #![allow(clippy::derive_partial_eq_without_eq)] tonic::include_proto!("infra"); @@ -26,52 +17,20 @@ pub mod alb { tonic::include_proto!("aws"); } -pub mod shard { - #![allow(clippy::derive_partial_eq_without_eq)] - tonic::include_proto!("shard"); -} - +pub mod auth; pub mod consistenthash; pub mod constants; pub mod errors; pub mod gitbaretools; -pub mod key; -pub mod safeio; +pub mod serialization_utils; pub mod singleflight; pub mod version; -pub mod auth; mod output_bytes; use crate::common::{CompressionScheme, InitiateResponse}; pub use output_bytes::output_bytes; -impl TryFrom for Range { - type Error = Status; - - fn try_from(range_proto: cas::Range) -> Result { - if range_proto.start > range_proto.end { - return Err(Status::failed_precondition(format!( - "Range: {range_proto:?} has an end smaller than the start" - ))); - } - Ok(range_proto.start..range_proto.end) - } -} - -impl TryFrom<&cas::Range> for Range { - type Error = Status; - - fn try_from(range_proto: &cas::Range) -> Result { - if range_proto.start > range_proto.end { - return Err(Status::failed_precondition(format!( - "Range: {range_proto:?} has an end smaller than the start" - ))); - } - Ok(range_proto.start..range_proto.end) - } -} - impl std::fmt::Display for common::Scheme { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let as_str = match self { @@ -107,29 +66,6 @@ mod tests { use super::*; - #[test] - fn test_range_conversion() { - let r = cas::Range { start: 0, end: 10 }; - let range = Range::try_from(r).unwrap(); - assert_eq!(range.start, 0); - assert_eq!(range.end, 10); - } - - #[test] - fn test_range_conversion_zero_len() { - let r = cas::Range { start: 10, end: 10 }; - let range = Range::try_from(r).unwrap(); - assert_eq!(range.start, 10); - assert_eq!(range.end, 10); - } - - #[test] - fn test_range_conversion_failed() { - let r = cas::Range { start: 20, end: 10 }; - let res = Range::try_from(r); - assert!(res.is_err()); - } - #[test] fn test_endpoint_config_to_endpoint_string() { let host = "xetxet"; diff --git a/utils/src/safeio.rs b/utils/src/safeio.rs deleted file mode 100644 index 8e2163c9..00000000 --- a/utils/src/safeio.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::{ - io::{self, Write}, - path::Path, -}; -use tempfile::NamedTempFile; - -/// Write all bytes -pub fn write_all_file_safe(path: &Path, bytes: &[u8]) -> io::Result<()> { - if !path.as_os_str().is_empty() { - let dir = path.parent().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("Unable to find parent path from {path:?}"), - ) - })?; - - // Make sure dir exists. 
- if !dir.exists() { - std::fs::create_dir_all(dir)?; - } - - let mut tempfile = create_temp_file(dir, "")?; - tempfile.write_all(bytes)?; - tempfile.persist(path).map_err(|e| e.error)?; - } - - Ok(()) -} - -pub fn create_temp_file(dir: &Path, suffix: &str) -> io::Result { - let tempfile = tempfile::Builder::new() - .prefix(&format!("{}.", std::process::id())) - .suffix(suffix) - .tempfile_in(dir)?; - - Ok(tempfile) -} - -#[cfg(test)] -mod test { - use anyhow::Result; - use std::fs; - use tempfile::TempDir; - - use super::write_all_file_safe; - - #[test] - fn test_small_file_write() -> Result<()> { - let tmp_dir = TempDir::new()?; - let bytes = vec![1u8; 1000]; - let file_name = tmp_dir.path().join("data"); - - write_all_file_safe(&file_name, &bytes)?; - - assert_eq!(fs::read(file_name)?, bytes); - - Ok(()) - } -} diff --git a/utils/src/serialization_utils.rs b/utils/src/serialization_utils.rs new file mode 100644 index 00000000..4c64ba24 --- /dev/null +++ b/utils/src/serialization_utils.rs @@ -0,0 +1,60 @@ +use merklehash::MerkleHash; +use std::io::{Read, Write}; +use std::mem::{size_of, transmute}; + +pub fn write_hash(writer: &mut W, m: &MerkleHash) -> Result<(), std::io::Error> { + writer.write_all(m.as_bytes()) +} + +pub fn write_u32(writer: &mut W, v: u32) -> Result<(), std::io::Error> { + writer.write_all(&v.to_le_bytes()) +} + +pub fn write_u64(writer: &mut W, v: u64) -> Result<(), std::io::Error> { + writer.write_all(&v.to_le_bytes()) +} + +pub fn write_u32s(writer: &mut W, vs: &[u32]) -> Result<(), std::io::Error> { + for e in vs { + write_u32(writer, *e)?; + } + + Ok(()) +} + +pub fn write_u64s(writer: &mut W, vs: &[u64]) -> Result<(), std::io::Error> { + for e in vs { + write_u64(writer, *e)?; + } + + Ok(()) +} + +pub fn read_hash(reader: &mut R) -> Result { + let mut m = [0u8; 32]; + reader.read_exact(&mut m)?; // Not endian safe. + + Ok(MerkleHash::from(unsafe { + transmute::<[u8; 32], [u64; 4]>(m) + })) +} + +pub fn read_u32(reader: &mut R) -> Result { + let mut buf = [0u8; size_of::()]; + reader.read_exact(&mut buf[..])?; + Ok(u32::from_le_bytes(buf)) +} + +pub fn read_u64(reader: &mut R) -> Result { + let mut buf = [0u8; size_of::()]; + reader.read_exact(&mut buf[..])?; + Ok(u64::from_le_bytes(buf)) +} + +pub fn read_u64s(reader: &mut R, vs: &mut [u64]) -> Result<(), std::io::Error> { + for e in vs.iter_mut() { + *e = read_u64(reader)?; + } + + Ok(()) +} diff --git a/utils/src/singleflight.rs b/utils/src/singleflight.rs index 982c8cc9..4e7b96d3 100644 --- a/utils/src/singleflight.rs +++ b/utils/src/singleflight.rs @@ -9,7 +9,7 @@ //! use std::sync::Arc; //! use std::time::Duration; //! -//! use cas::singleflight::Group; +//! use utils::singleflight::Group; //! //! const RES: usize = 7; //! 
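The serialization_utils module introduced above replaces the removed safeio helpers with plain little-endian primitives for fixed-width integers and hashes. As a rough standalone sketch of the intended round-trip usage (the write_u32/read_u32 below are local re-implementations against std only, not the workspace crate):

use std::io::{Cursor, Read, Result, Write};

fn write_u32<W: Write>(writer: &mut W, v: u32) -> Result<()> {
    writer.write_all(&v.to_le_bytes())
}

fn read_u32<R: Read>(reader: &mut R) -> Result<u32> {
    let mut buf = [0u8; 4];
    reader.read_exact(&mut buf)?;
    Ok(u32::from_le_bytes(buf))
}

fn main() -> Result<()> {
    // Values are written and read back in little-endian order, matching the
    // convention of the new helpers.
    let mut buf = Vec::new();
    write_u32(&mut buf, 0xDEAD_BEEF)?;
    write_u32(&mut buf, 42)?;

    let mut cursor = Cursor::new(buf);
    assert_eq!(read_u32(&mut cursor)?, 0xDEAD_BEEF);
    assert_eq!(read_u32(&mut cursor)?, 42);
    Ok(())
}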
From adc982e438ed2ed4d2abab91360e9225320cb3a6 Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 14:33:10 -0700 Subject: [PATCH 11/19] fix typo --- data/src/data_processing.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/src/data_processing.rs b/data/src/data_processing.rs index da3dac5a..f72ae9cc 100644 --- a/data/src/data_processing.rs +++ b/data/src/data_processing.rs @@ -61,7 +61,7 @@ pub struct PointerFileTranslator { global_cas_data: Arc>, } -// Constructorscas_data_accumulator +// Constructors impl PointerFileTranslator { pub async fn new(config: TranslatorConfig) -> Result { let cas_client = create_cas_client(&config.cas_storage_config, &config.repo_info)?; From 9718b993189ac67730fea5cc6511f99f6881dba2 Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 14:43:08 -0700 Subject: [PATCH 12/19] make CasObject::get_byte_offset public --- cas_object/src/cas_object_format.rs | 2 +- cas_object/src/lib.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index 3087967c..e97509ce 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -461,7 +461,7 @@ impl CasObject { } /// Helper function to translate a range of chunk indices to physical byte offset range. - fn get_byte_offset( + pub fn get_byte_offset( &self, chunk_index_start: u32, chunk_index_end: u32, diff --git a/cas_object/src/lib.rs b/cas_object/src/lib.rs index 15d1a6d3..77beb0eb 100644 --- a/cas_object/src/lib.rs +++ b/cas_object/src/lib.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] mod cas_chunk_format; mod cas_object_format; mod compression_scheme; From 950817de42ccecb9af36696ed9ec39db32037111 Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 14:51:26 -0700 Subject: [PATCH 13/19] move range start offset from each reconstruction term to global, because for a file range query only the first may be non-zero --- cas_types/src/lib.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cas_types/src/lib.rs b/cas_types/src/lib.rs index 0b88ccf0..57090c62 100644 --- a/cas_types/src/lib.rs +++ b/cas_types/src/lib.rs @@ -23,7 +23,6 @@ pub struct CASReconstructionTerm { pub unpacked_length: u32, // chunk index start and end in a xorb pub range: Range, - pub range_start_offset: u32, pub url: String, // byte index start and end in a xorb pub url_range: Range, @@ -31,6 +30,9 @@ pub struct CASReconstructionTerm { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct QueryReconstructionResponse { + // For range query [a, b) into a file content, the location + // of "a" into the first range. 
+ pub offset_into_first_range: u32, pub reconstruction: Vec, } From 804b0cef79bb52783c01b944392cc6ecc8f93b5d Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 15:01:43 -0700 Subject: [PATCH 14/19] make test util struct ChunkSize Debug, Clone and Copy, and impl Display for it --- cas_object/src/cas_object_format.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index e97509ce..9b38a671 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -519,11 +519,21 @@ pub mod test_utils { data } + #[derive(Debug, Clone, Copy)] pub enum ChunkSize { Random(u32, u32), Fixed(u32), } + impl std::fmt::Display for ChunkSize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChunkSize::Random(a, b) => write!(f, "[{a}, {b}]"), + ChunkSize::Fixed(a) => write!(f, "{a}"), + } + } + } + /// Utility test method for creating a cas object /// Returns (CasObject, chunks serialized, raw data, raw data chunk boundaries) #[allow(clippy::type_complexity)] From d51a5dcbe733d39c2e465e4bb1d108af23037fbb Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 15:27:17 -0700 Subject: [PATCH 15/19] bug fix and chunk uncompress length check --- cas_client/src/remote_client.rs | 23 ++++++++--------------- cas_object/src/cas_chunk_format.rs | 19 ++++++++++++++----- cas_object/src/error.rs | 6 ------ 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 3026b6c5..2ef7b00c 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -12,7 +12,7 @@ use merklehash::MerkleHash; use reqwest::{StatusCode, Url}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware}; use std::io::{Cursor, Write}; -use tracing::{debug, warn}; +use tracing::{debug, error}; use utils::auth::AuthConfig; pub const CAS_ENDPOINT: &str = "http://localhost:8080"; @@ -143,15 +143,9 @@ impl RemoteClient { ) -> Result { let info = reconstruction_response.reconstruction; let total_len = info.iter().fold(0, |acc, x| acc + x.unpacked_length); - let futs = info.into_iter().map(|term| { - tokio::spawn(async move { - let piece = get_one(&term).await?; - if piece.len() != (term.range.end - term.range.start) as usize { - warn!("got back a smaller range than requested"); - } - Result::::Ok(piece) - }) - }); + let futs = info + .into_iter() + .map(|term| tokio::spawn(async move { Result::::Ok(get_one(&term).await?) })); for fut in futs { let piece = fut .await @@ -205,14 +199,13 @@ async fn get_one(term: &CASReconstructionTerm) -> Result { .bytes() .await .map_err(CasClientError::ReqwestError)?; + if xorb_bytes.len() as u32 != term.url_range.end - term.url_range.start { + error!("got back a smaller range than requested"); + } let mut readseek = Cursor::new(xorb_bytes.to_vec()); let data = cas_object::deserialize_chunks(&mut readseek)?; - let len = (term.range.end - term.range.start) as usize; - let offset = term.range_start_offset as usize; - - let sliced = data[offset..offset + len].to_vec(); - Ok(Bytes::from(sliced)) + Ok(Bytes::from(data)) } /// builds the client to talk to CAS. 
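With the range start offset moved from each CASReconstructionTerm to the response-level offset_into_first_range, and get_one now returning the full deserialized chunk data for a term, the caller is expected to trim only the first piece when assembling a ranged read. The patches here do not show that consumer, so the following is a minimal, self-contained sketch of the stitching step under that assumption; Piece and assemble are hypothetical stand-ins, not the crate's types:

// Hypothetical type: `Piece` stands in for the unpacked bytes of one
// reconstruction term; the real client keeps these as `Bytes`.
struct Piece {
    data: Vec<u8>,
}

// Stitch term pieces into the requested byte range: only the first piece is
// trimmed by the response-level offset, later pieces are appended whole, and
// the tail is truncated to the requested length.
fn assemble(pieces: Vec<Piece>, offset_into_first_range: usize, total_len: usize) -> Vec<u8> {
    let mut out = Vec::with_capacity(total_len);
    for (i, piece) in pieces.into_iter().enumerate() {
        if i == 0 {
            out.extend_from_slice(&piece.data[offset_into_first_range..]);
        } else {
            out.extend_from_slice(&piece.data);
        }
    }
    out.truncate(total_len);
    out
}

fn main() {
    let pieces = vec![
        Piece { data: b"xxhello ".to_vec() },
        Piece { data: b"world".to_vec() },
    ];
    assert_eq!(assemble(pieces, 2, 11), b"hello world");
}

Only the first term of a file range query can start mid-chunk, which is why a single response-level offset is sufficient.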
diff --git a/cas_object/src/cas_chunk_format.rs b/cas_object/src/cas_chunk_format.rs index 7ece83e7..350a6ccf 100644 --- a/cas_object/src/cas_chunk_format.rs +++ b/cas_object/src/cas_chunk_format.rs @@ -5,8 +5,8 @@ use std::{ }; use crate::error::CasObjectError; -use anyhow::anyhow; use crate::CompressionScheme; +use anyhow::anyhow; use lz4_flex::frame::{FrameDecoder, FrameEncoder}; pub const CAS_CHUNK_HEADER_LENGTH: usize = size_of::(); @@ -145,14 +145,23 @@ pub fn deserialize_chunk_to_writer( let mut compressed_buf = vec![0u8; header.get_compressed_length() as usize]; reader.read_exact(&mut compressed_buf)?; - match header.get_compression_scheme() { - CompressionScheme::None => writer.write_all(&compressed_buf)?, + let uncompressed_len = match header.get_compression_scheme() { + CompressionScheme::None => { + writer.write_all(&compressed_buf)?; + compressed_buf.len() as u32 + } CompressionScheme::LZ4 => { let mut dec = FrameDecoder::new(Cursor::new(compressed_buf)); - copy(&mut dec, writer)?; + copy(&mut dec, writer)? as u32 } }; + if uncompressed_len != header.get_uncompressed_length() { + return Err(CasObjectError::FormatError(anyhow!( + "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header" + ))); + } + Ok(header.get_compressed_length() as usize + CAS_CHUNK_HEADER_LENGTH) } @@ -191,8 +200,8 @@ mod tests { use std::io::Cursor; use super::*; - use CompressionScheme; use rand::Rng; + use CompressionScheme; const COMP_LEN: u32 = 0x010203; const UNCOMP_LEN: u32 = 0x040506; diff --git a/cas_object/src/error.rs b/cas_object/src/error.rs index e63fb8e4..d334d420 100644 --- a/cas_object/src/error.rs +++ b/cas_object/src/error.rs @@ -1,6 +1,4 @@ use std::convert::Infallible; - -use merklehash::MerkleHash; use xet_error::Error; #[non_exhaustive] @@ -26,9 +24,6 @@ pub enum CasObjectError { #[error("Internal Hash Parsing Error")] HashParsingError(#[from] Infallible), - - #[error("CAS Hash not found")] - XORBNotFound(MerkleHash), } // Define our own result type here (this seems to be the standard). @@ -37,7 +32,6 @@ pub type Result = std::result::Result; impl PartialEq for CasObjectError { fn eq(&self, other: &CasObjectError) -> bool { match (self, other) { - (CasObjectError::XORBNotFound(a), CasObjectError::XORBNotFound(b)) => a == b, (e1, e2) => std::mem::discriminant(e1) == std::mem::discriminant(e2), } } From 0a4b584297fbdf6bc101d5da0dd98f7185bb1ae3 Mon Sep 17 00:00:00 2001 From: seanses Date: Wed, 2 Oct 2024 15:32:12 -0700 Subject: [PATCH 16/19] fix linting --- cas_client/src/remote_client.rs | 2 +- cas_object/src/error.rs | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 2ef7b00c..2b01c241 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -145,7 +145,7 @@ impl RemoteClient { let total_len = info.iter().fold(0, |acc, x| acc + x.unpacked_length); let futs = info .into_iter() - .map(|term| tokio::spawn(async move { Result::::Ok(get_one(&term).await?) 
})); + .map(|term| tokio::spawn(async move { get_one(&term).await })); for fut in futs { let piece = fut .await diff --git a/cas_object/src/error.rs b/cas_object/src/error.rs index d334d420..1e9f3310 100644 --- a/cas_object/src/error.rs +++ b/cas_object/src/error.rs @@ -31,8 +31,6 @@ pub type Result = std::result::Result; impl PartialEq for CasObjectError { fn eq(&self, other: &CasObjectError) -> bool { - match (self, other) { - (e1, e2) => std::mem::discriminant(e1) == std::mem::discriminant(e2), - } + std::mem::discriminant(self) == std::mem::discriminant(other) } } From 302b8c3aaebf1a85d98083b28172886c5402a285 Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Wed, 2 Oct 2024 18:55:16 -0700 Subject: [PATCH 17/19] Fix range logic and added range unit-tests --- cas_object/src/cas_object_format.rs | 73 +++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 4 deletions(-) diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index 9b38a671..c1f0a149 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -365,7 +365,10 @@ impl CasObject { /// Generate a hash for securing a chunk range. /// - /// chunk_start_index, chunk_end_index: indices for chunks in CasObject. + /// chunk_start_index, chunk_end_index: indices for chunks in CasObject. + /// The indices should be [start, end) - meaning start is inclusive and end is exclusive. + /// Ex. For specifying the 1st chunk: chunk_start_index: 0, chunk_end_index: 1 + /// /// key: additional key incorporated into generating hash. /// /// This hash ensures validity of the knowledge of chunks, since ranges are public, @@ -379,13 +382,12 @@ impl CasObject { ) -> Result { self.validate_cas_object_info()?; - if chunk_end_index < chunk_start_index || chunk_end_index > self.info.num_chunks { + if chunk_end_index <= chunk_start_index || chunk_end_index > self.info.num_chunks { return Err(CasObjectError::InvalidArguments); } // Collect relevant hashes - let range_hashes = - self.info.chunk_hashes[chunk_start_index as usize..chunk_end_index as usize].as_ref(); + let range_hashes = self.info.chunk_hashes[chunk_start_index as usize..chunk_end_index as usize].as_ref(); // TODO: Make this more robust, currently appends range hashes together, adds key to end let mut combined: Vec = range_hashes @@ -690,7 +692,70 @@ mod tests { let range_hash = c.generate_chunk_range_hash(0, 3, &key).unwrap(); assert_eq!(range_hash, expected_hash); } + + #[test] + fn test_generate_range_hash_partial() { + // Arrange + let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = + build_cas_object(5, ChunkSize::Fixed(100), CompressionScheme::None); + let key = [b'K', b'E', b'Y', b'B', b'A', b'B', b'Y']; + + let mut hashes : Vec = c.info.chunk_hashes.as_slice()[1..=3].to_vec().iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); + hashes.extend_from_slice(&key); + let expected_hash = merklehash::compute_data_hash(&hashes); + + // Act & Assert + let range_hash = c.generate_chunk_range_hash(1, 4, &key).unwrap(); + assert_eq!(range_hash, expected_hash); + let mut hashes : Vec = c.info.chunk_hashes.as_slice()[0..1].to_vec().iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); + hashes.extend_from_slice(&key); + let expected_hash = merklehash::compute_data_hash(&hashes); + + let range_hash = c.generate_chunk_range_hash(0, 1, &key).unwrap(); + assert_eq!(range_hash, expected_hash); + } + + #[test] + fn test_generate_range_hash_invalid_range() { + // Arrange + let (c, _cas_data, _raw_data, 
_raw_chunk_boundaries) = + build_cas_object(5, ChunkSize::Fixed(100), CompressionScheme::None); + let key = [b'K', b'E', b'Y', b'B', b'A', b'B', b'Y']; + + // Act & Assert + assert_eq!(c.generate_chunk_range_hash(1, 6, &key), Err(CasObjectError::InvalidArguments)); + assert_eq!(c.generate_chunk_range_hash(100, 10, &key), Err(CasObjectError::InvalidArguments)); + assert_eq!(c.generate_chunk_range_hash(0, 0, &key), Err(CasObjectError::InvalidArguments)); + } + + #[test] + fn test_validate_cas_object_info() { + // Arrange & Act & Assert + let (c, _cas_data, _raw_data, _raw_chunk_boundaries) = + build_cas_object(5, ChunkSize::Fixed(100), CompressionScheme::None); + let result = c.validate_cas_object_info(); + assert!(result.is_ok()); + + // no chunks + let c = CasObject::default(); + let result = c.validate_cas_object_info(); + assert_eq!(result, Err(CasObjectError::FormatError(anyhow!("Invalid CasObjectInfo, no chunks in CasObject.")))); + + // num_chunks doesn't match chunk_boundaries.len() + let mut c = CasObject::default(); + c.info.num_chunks = 1; + let result = c.validate_cas_object_info(); + assert_eq!(result, Err(CasObjectError::FormatError(anyhow!("Invalid CasObjectInfo, num chunks not matching boundaries or hashes.")))); + + // no hash + let (mut c, _cas_data, _raw_data, _raw_chunk_boundaries) = + build_cas_object(1, ChunkSize::Fixed(100), CompressionScheme::None); + c.info.cashash = MerkleHash::default(); + let result = c.validate_cas_object_info(); + assert_eq!(result, Err(CasObjectError::FormatError(anyhow!("Invalid CasObjectInfo, Missing cashash.")))); + } + #[test] fn test_compress_decompress() { // Arrange From 715ee10eb7ce7b04566b89ea9803b5d69f580936 Mon Sep 17 00:00:00 2001 From: seanses Date: Fri, 4 Oct 2024 12:05:19 -0700 Subject: [PATCH 18/19] cas client interface update; integration tests --- Cargo.toml | 2 - cas_client/src/caching_client.rs | 46 ++++- cas_client/src/error.rs | 9 +- cas_client/src/interface.rs | 96 +++++----- cas_client/src/lib.rs | 3 +- cas_client/src/local_client.rs | 45 +++-- cas_client/src/remote_client.rs | 53 +++--- cas_object/src/cas_object_format.rs | 24 ++- data/Cargo.toml | 5 - data/src/cas_interface.rs | 41 ++++- data/src/data_processing.rs | 8 +- data/src/lib.rs | 1 + data/src/test_utils/local_test_client.rs | 97 ++++++++++ data/src/test_utils/mod.rs | 3 + data/tests/integration_tests.rs | 123 +++++++++++++ data/tests/integration_tests/initialize.sh | 168 ++++++++++++++++++ .../test_basic_clean_smudge.sh | 32 ++++ 17 files changed, 640 insertions(+), 116 deletions(-) create mode 100644 data/src/test_utils/local_test_client.rs create mode 100644 data/src/test_utils/mod.rs create mode 100644 data/tests/integration_tests.rs create mode 100644 data/tests/integration_tests/initialize.sh create mode 100644 data/tests/integration_tests/test_basic_clean_smudge.sh diff --git a/Cargo.toml b/Cargo.toml index 60031d16..41d09a43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,4 @@ [workspace] -rust-version = "1.79" - resolver = "2" members = [ diff --git a/cas_client/src/caching_client.rs b/cas_client/src/caching_client.rs index d4519ff7..1b84c667 100644 --- a/cas_client/src/caching_client.rs +++ b/cas_client/src/caching_client.rs @@ -1,16 +1,20 @@ #![allow(unused_variables)] - use crate::error::Result; use crate::interface::*; use async_trait::async_trait; +use cas_types::QueryReconstructionResponse; use merklehash::MerkleHash; use std::io::Write; +use std::path::Path; #[derive(Debug)] -pub struct CachingClient {} +#[allow(private_bounds)] +pub struct 
CachingClient { + client: T, +} #[async_trait] -impl UploadClient for CachingClient { +impl UploadClient for CachingClient { async fn put( &self, prefix: &str, @@ -31,8 +35,22 @@ impl UploadClient for CachingClient { } #[async_trait] -impl ReconstructionClient for CachingClient { +impl ReconstructionClient for CachingClient { async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { + /* + let file_info = self.reconstruct(hash, None).await?; + + for entry in file_info.reconstruction { + if let Some(bytes) = self.cache.get(entry.hash, entry.range) { + // write out + } else { + let bytes = crate::get_one_range(&entry).await?; + // put into cache + // write out + } + } + */ + todo!() } @@ -47,4 +65,22 @@ impl ReconstructionClient for CachingClient { } } -impl Client for CachingClient {} +#[async_trait] +impl Reconstructable for CachingClient { + async fn reconstruct( + &self, + hash: &MerkleHash, + byte_range: Option<(u64, u64)>, + ) -> Result { + self.reconstruct(hash, byte_range).await + } +} + +impl Client for CachingClient {} + +#[allow(private_bounds)] +impl CachingClient { + pub fn new(client: T, cache_directory: &Path, cache_size: u64) -> Self { + Self { client } + } +} diff --git a/cas_client/src/error.rs b/cas_client/src/error.rs index 5dc65b8b..18195567 100644 --- a/cas_client/src/error.rs +++ b/cas_client/src/error.rs @@ -17,16 +17,16 @@ pub enum CasClientError { #[error("Invalid Arguments")] InvalidArguments, - #[error("Hash Mismatch")] - HashMismatch, + #[error("File not found for hash: {0}")] + FileNotFound(MerkleHash), #[error("IO Error: {0}")] IOError(#[from] std::io::Error), #[error("Other Internal Error: {0}")] - InternalError(anyhow::Error), + InternalError(#[from] anyhow::Error), - #[error("CAS Hash not found")] + #[error("CAS object not found for hash: {0}")] XORBNotFound(MerkleHash), #[error("Cas Object Error: {0}")] @@ -37,6 +37,7 @@ pub enum CasClientError { #[error("ReqwestMiddleware Error: {0}")] ReqwestMiddlewareError(#[from] reqwest_middleware::Error), + #[error("Reqwest Error: {0}")] ReqwestError(#[from] reqwest::Error), diff --git a/cas_client/src/interface.rs b/cas_client/src/interface.rs index 8756dfbf..690e9acf 100644 --- a/cas_client/src/interface.rs +++ b/cas_client/src/interface.rs @@ -1,7 +1,8 @@ use crate::error::Result; use async_trait::async_trait; +use cas_types::QueryReconstructionResponse; use merklehash::MerkleHash; -use std::{io::Write, sync::Arc}; +use std::io::Write; /// A Client to the CAS (Content Addressed Storage) service to allow storage and /// management of XORBs (Xet Object Remote Block). A XORB represents a collection @@ -53,51 +54,64 @@ pub trait ReconstructionClient { ) -> Result<()>; } -/* - * If T implements Client, Arc also implements Client - */ +pub trait Client: UploadClient + ReconstructionClient {} + +/// A Client to the CAS (Content Addressed Storage) service that is able to obtain +/// the reconstruction info of a file by FileID (MerkleHash). +/// This trait is meant for internal (caching): external users to this crate don't +/// access these trait functions. 
#[async_trait] -impl UploadClient for Arc { - async fn put( +pub(crate) trait Reconstructable { + async fn reconstruct( &self, - prefix: &str, hash: &MerkleHash, - data: Vec, - chunk_and_boundaries: Vec<(MerkleHash, u32)>, - ) -> Result<()> { - (**self).put(prefix, hash, data, chunk_and_boundaries).await - } + byte_range: Option<(u64, u64)>, + ) -> Result; +} - async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { - (**self).exists(prefix, hash).await - } +/* + * If T implements Client, Arc also implements Client + */ +// #[async_trait] +// impl UploadClient for Arc { +// async fn put( +// &self, +// prefix: &str, +// hash: &MerkleHash, +// data: Vec, +// chunk_and_boundaries: Vec<(MerkleHash, u32)>, +// ) -> Result<()> { +// (**self).put(prefix, hash, data, chunk_and_boundaries).await +// } - /// Clients may do puts in the background. A flush is necessary - /// to enforce completion of all puts. If an error occured during any - /// background put it will be returned here.force completion of all puts. - async fn flush(&self) -> Result<()> { - (**self).flush().await - } -} +// async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { +// (**self).exists(prefix, hash).await +// } -#[async_trait] -impl ReconstructionClient for Arc { - /// Get a entire file by file hash. - async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { - (**self).get_file(hash, writer).await - } +// /// Clients may do puts in the background. A flush is necessary +// /// to enforce completion of all puts. If an error occured during any +// /// background put it will be returned here.force completion of all puts. +// async fn flush(&self) -> Result<()> { +// (**self).flush().await +// } +// } - async fn get_file_byte_range( - &self, - hash: &MerkleHash, - offset: u64, - length: u64, - writer: &mut Box, - ) -> Result<()> { - (**self) - .get_file_byte_range(hash, offset, length, writer) - .await - } -} +// #[async_trait] +// impl ReconstructionClient for Arc { +// /// Get a entire file by file hash. +// async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { +// (**self).get_file(hash, writer).await +// } -pub trait Client: UploadClient + ReconstructionClient {} +// async fn get_file_byte_range( +// &self, +// hash: &MerkleHash, +// offset: u64, +// length: u64, +// writer: &mut Box, +// ) -> Result<()> { +// (**self) +// .get_file_byte_range(hash, offset, length, writer) +// .await +// } +// } diff --git a/cas_client/src/lib.rs b/cas_client/src/lib.rs index c97bbbf0..0aa58f44 100644 --- a/cas_client/src/lib.rs +++ b/cas_client/src/lib.rs @@ -4,7 +4,8 @@ pub use crate::error::CasClientError; pub use auth::AuthMiddleware; pub use caching_client::CachingClient; -pub use interface::Client; +pub use interface::{Client, ReconstructionClient, UploadClient}; +pub use local_client::tests_utils; pub use local_client::LocalClient; pub use remote_client::build_reqwest_client; pub use remote_client::RemoteClient; diff --git a/cas_client/src/local_client.rs b/cas_client/src/local_client.rs index f52f3b32..79be927a 100644 --- a/cas_client/src/local_client.rs +++ b/cas_client/src/local_client.rs @@ -206,8 +206,7 @@ impl UploadClient for LocalClient { } } -#[cfg(test)] -mod tests_utils { +pub mod tests_utils { use super::LocalClient; use crate::{error::Result, CasClientError}; use cas_object::CasObject; @@ -242,15 +241,17 @@ mod tests_utils { Ok(result) } + /// Get uncompressed bytes from a CAS object within chunk ranges. 
+ /// Each tuple in chunk_ranges represents a chunk index range [a, b) fn get_object_range( &self, prefix: &str, hash: &MerkleHash, - ranges: Vec<(u32, u32)>, + chunk_ranges: Vec<(u32, u32)>, ) -> Result>> { // Handle the case where we aren't asked for any real data. - if ranges.len() == 1 && ranges[0].0 == ranges[0].1 { - return Ok(vec![Vec::::new()]); + if chunk_ranges.is_empty() { + return Ok(vec![vec![]]); } let file_path = self.get_path_for_entry(prefix, hash); @@ -265,9 +266,13 @@ mod tests_utils { let cas = CasObject::deserialize(&mut reader)?; let mut ret: Vec> = Vec::new(); - let all_uncompressed_bytes = cas.get_all_bytes(&mut reader)?; - for r in ranges { - let data = all_uncompressed_bytes[r.0 as usize..r.1 as usize].to_vec(); + for r in chunk_ranges { + if r.0 >= r.1 { + ret.push(vec![]); + continue; + } + + let data = cas.get_bytes_by_chunk_range(&mut reader, r.0, r.1)?; ret.push(data); } Ok(ret) @@ -338,26 +343,32 @@ mod tests { async fn test_basic_put_get_range_random_small() { // Arrange let client = LocalClient::default(); - let (c, _, data, chunk_boundaries) = build_cas_object(3, ChunkSize::Random(512, 2048), LZ4); - let data_again = data.clone(); + let (c, _, data, chunk_and_boundaries) = + build_cas_object(3, ChunkSize::Random(512, 2048), LZ4); // Act & Assert assert!(client - .put("", &c.info.cashash, data, chunk_boundaries) + .put( + "", + &c.info.cashash, + data.clone(), + chunk_and_boundaries.clone() + ) .await .is_ok()); - let ranges: Vec<(u32, u32)> = vec![(0, 100), (100, 1500)]; - let ranges_again = ranges.clone(); + let ranges: Vec<(u32, u32)> = vec![(0, 1), (2, 3)]; let returned_ranges = client .get_object_range("", &c.info.cashash, ranges) .unwrap(); + let expected = vec![ + data[0..chunk_and_boundaries[0].1 as usize].to_vec(), + data[chunk_and_boundaries[1].1 as usize..chunk_and_boundaries[2].1 as usize].to_vec(), + ]; + for idx in 0..returned_ranges.len() { - assert_eq!( - data_again[ranges_again[idx].0 as usize..ranges_again[idx].1 as usize], - returned_ranges[idx] - ); + assert_eq!(expected[idx], returned_ranges[idx]); } } diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 2b01c241..cfeb38b1 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -76,9 +76,9 @@ impl UploadClient for RemoteClient { impl ReconstructionClient for RemoteClient { async fn get_file(&self, hash: &MerkleHash, writer: &mut Box) -> Result<()> { // get manifest of xorbs to download - let manifest = self.reconstruct_file(hash, None).await?; + let manifest = self.reconstruct(hash, None).await?; - self.reconstruct(manifest, None, writer).await?; + self.get_ranges(manifest, None, writer).await?; Ok(()) } @@ -95,6 +95,29 @@ impl ReconstructionClient for RemoteClient { } } +#[async_trait] +impl Reconstructable for RemoteClient { + async fn reconstruct( + &self, + file_id: &MerkleHash, + _bytes_range: Option<(u64, u64)>, + ) -> Result { + let url = Url::parse(&format!( + "{}/reconstruction/{}", + self.endpoint, + file_id.hex() + ))?; + + let response = self.client.get(url).send().await?; + let response_body = response.bytes().await?; + let response_parsed: QueryReconstructionResponse = + serde_json::from_reader(response_body.reader()) + .map_err(|_| CasClientError::FileNotFound(*file_id))?; + + Ok(response_parsed) + } +} + impl Client for RemoteClient {} impl RemoteClient { @@ -135,7 +158,7 @@ impl RemoteClient { Ok(response_parsed.was_inserted) } - async fn reconstruct( + async fn get_ranges( &self, reconstruction_response: 
QueryReconstructionResponse, _byte_range: Option<(u64, u64)>, @@ -145,7 +168,7 @@ impl RemoteClient { let total_len = info.iter().fold(0, |acc, x| acc + x.unpacked_length); let futs = info .into_iter() - .map(|term| tokio::spawn(async move { get_one(&term).await })); + .map(|term| tokio::spawn(async move { get_one_range(&term).await })); for fut in futs { let piece = fut .await @@ -154,29 +177,9 @@ impl RemoteClient { } Ok(total_len as usize) } - - /// Reconstruct the file - async fn reconstruct_file( - &self, - file_id: &MerkleHash, - _bytes_range: Option<(u64, u64)>, - ) -> Result { - let url = Url::parse(&format!( - "{}/reconstruction/{}", - self.endpoint, - file_id.hex() - ))?; - - let response = self.client.get(url).send().await?; - let response_body = response.bytes().await?; - let response_parsed: QueryReconstructionResponse = - serde_json::from_reader(response_body.reader())?; - - Ok(response_parsed) - } } -async fn get_one(term: &CASReconstructionTerm) -> Result { +pub(crate) async fn get_one_range(term: &CASReconstructionTerm) -> Result { debug!("term: {term:?}"); if term.range.end < term.range.start || term.url_range.end < term.url_range.start { diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index c1f0a149..d1acfe75 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -421,21 +421,21 @@ impl CasObject { fn get_range( &self, reader: &mut R, - start: u32, - end: u32, + byte_start: u32, + byte_end: u32, ) -> Result, CasObjectError> { - if end < start { + if byte_end < byte_start { return Err(CasObjectError::InvalidRange); } self.validate_cas_object_info()?; // make sure the end of the range is within the bounds of the xorb - let end = min(end, self.get_contents_length()?); + let end = min(byte_end, self.get_contents_length()?); // read chunk bytes - let mut chunk_data = vec![0u8; (end - start) as usize]; - reader.seek(std::io::SeekFrom::Start(start as u64))?; + let mut chunk_data = vec![0u8; (end - byte_start) as usize]; + reader.seek(std::io::SeekFrom::Start(byte_start as u64))?; reader.read_exact(&mut chunk_data)?; // build up result vector by processing these chunks @@ -449,6 +449,18 @@ impl CasObject { self.get_range(reader, 0, self.get_contents_length()?) } + /// Convenient function to get content bytes by chunk range, mainly for internal testing + pub fn get_bytes_by_chunk_range( + &self, + reader: &mut R, + chunk_index_start: u32, + chunk_index_end: u32, + ) -> Result, CasObjectError> { + let (byte_start, byte_end) = self.get_byte_offset(chunk_index_start, chunk_index_end)?; + + self.get_range(reader, byte_start, byte_end) + } + /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec. fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result, CasObjectError> { // walk chunk_data, deserialize into Chunks, and then get_bytes() from each of them. 
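Both generate_chunk_range_hash and the new get_bytes_by_chunk_range take half-open chunk index ranges [start, end). A standalone illustration of that convention follows; chunk_range_to_byte_range is a hypothetical helper over uncompressed cumulative chunk boundaries (the same shape the test utilities produce), not the CasObject API:

// Hypothetical helper (not the CasObject API): map a half-open chunk index
// range [start, end) to an uncompressed byte range, given cumulative chunk
// end offsets.
fn chunk_range_to_byte_range(boundaries: &[u32], start: u32, end: u32) -> Option<(u32, u32)> {
    if start >= end || end as usize > boundaries.len() {
        return None; // mirrors the InvalidArguments / empty-range handling
    }
    let byte_start = if start == 0 {
        0
    } else {
        boundaries[start as usize - 1]
    };
    let byte_end = boundaries[end as usize - 1];
    Some((byte_start, byte_end))
}

fn main() {
    // Three chunks of 100 uncompressed bytes each.
    let boundaries = [100, 200, 300];
    assert_eq!(chunk_range_to_byte_range(&boundaries, 0, 1), Some((0, 100))); // first chunk only
    assert_eq!(chunk_range_to_byte_range(&boundaries, 1, 3), Some((100, 300))); // chunks 1 and 2
    assert_eq!(chunk_range_to_byte_range(&boundaries, 2, 2), None); // empty range
}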
diff --git a/data/Cargo.toml b/data/Cargo.toml index 19350464..dc35d5de 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -3,11 +3,6 @@ name = "data" version = "0.14.5" edition = "2021" -[profile.release] -opt-level = 3 -lto = true -debug = 1 - [lib] doctest = false diff --git a/data/src/cas_interface.rs b/data/src/cas_interface.rs index 7ca2f458..c3e46f2e 100644 --- a/data/src/cas_interface.rs +++ b/data/src/cas_interface.rs @@ -1,6 +1,8 @@ use crate::configurations::*; use crate::errors::Result; +use crate::test_utils::LocalTestClient; use cas_client::RemoteClient; +use mdb_shard::ShardFileManager; use std::env::current_dir; use std::path::Path; use std::sync::Arc; @@ -12,28 +14,51 @@ pub use cas_client::Client; pub(crate) fn create_cas_client( cas_storage_config: &StorageConfig, _maybe_repo_info: &Option, + shard_manager: Arc, ) -> Result> { match cas_storage_config.endpoint { - Endpoint::Server(ref endpoint) => remote_client(endpoint, &cas_storage_config.auth), - Endpoint::FileSystem(ref path) => local_test_cas_client(path), + Endpoint::Server(ref endpoint) => remote_client( + endpoint, + &cas_storage_config.cache_config, + &cas_storage_config.auth, + ), + Endpoint::FileSystem(ref path) => { + local_test_cas_client(&cas_storage_config.prefix, path, shard_manager) + } } } -pub(crate) fn remote_client( +fn remote_client( endpoint: &str, + _cache_config: &Option, auth: &Option, ) -> Result> { // Raw remote client. - let remote_client = Arc::new(RemoteClient::new(endpoint, auth)); + let remote_client = RemoteClient::new(endpoint, auth); - Ok(remote_client) + /* + if let Some(cache) = cache_config { + let caching_client = + CachingClient::new(remote_client, &cache.cache_directory, cache.cache_size); + return Ok(Arc::new(caching_client)); + } + */ + + Ok(Arc::new(remote_client)) } -fn local_test_cas_client(path: &Path) -> Result> { +fn local_test_cas_client( + prefix: &str, + path: &Path, + shard_manager: Arc, +) -> Result> { info!("Using local CAS with path: {:?}.", path); - let _path = match path.is_absolute() { + let path = match path.is_absolute() { true => path, false => ¤t_dir()?.join(path), }; - unimplemented!() + + let client = LocalTestClient::new(prefix, path, shard_manager); + + Ok(Arc::new(client)) } diff --git a/data/src/data_processing.rs b/data/src/data_processing.rs index f72ae9cc..69017db1 100644 --- a/data/src/data_processing.rs +++ b/data/src/data_processing.rs @@ -64,10 +64,14 @@ pub struct PointerFileTranslator { // Constructors impl PointerFileTranslator { pub async fn new(config: TranslatorConfig) -> Result { - let cas_client = create_cas_client(&config.cas_storage_config, &config.repo_info)?; - let shard_manager = Arc::new(create_shard_manager(&config.shard_storage_config).await?); + let cas_client = create_cas_client( + &config.cas_storage_config, + &config.repo_info, + shard_manager.clone(), + )?; + let remote_shards = { if let Some(dedup) = &config.dedup_config { RemoteShardInterface::new( diff --git a/data/src/lib.rs b/data/src/lib.rs index cdf1b417..0919c3c7 100644 --- a/data/src/lib.rs +++ b/data/src/lib.rs @@ -13,6 +13,7 @@ mod remote_shard_interface; mod repo_salt; mod shard_interface; mod small_file_determination; +mod test_utils; pub use constants::SMALL_FILE_THRESHOLD; pub use data_processing::PointerFileTranslator; diff --git a/data/src/test_utils/local_test_client.rs b/data/src/test_utils/local_test_client.rs new file mode 100644 index 00000000..a1991393 --- /dev/null +++ b/data/src/test_utils/local_test_client.rs @@ -0,0 +1,97 @@ +use anyhow::anyhow; 
+use async_trait::async_trait; +use cas_client::tests_utils::*; +use cas_client::{CasClientError, Client, LocalClient, ReconstructionClient, UploadClient}; +use mdb_shard::{shard_file_reconstructor::FileReconstructor, ShardFileManager}; +use merklehash::MerkleHash; +use std::path::Path; +use std::{io::Write, sync::Arc}; + +/// A CAS client only for the purpose of testing. It utilizes LocalClient to upload +/// and download xorbs and ShardFileManager to retrieve file reconstruction info. +pub struct LocalTestClient { + prefix: String, + cas: LocalClient, + shard_manager: Arc, +} + +impl LocalTestClient { + pub fn new(prefix: &str, path: &Path, shard_manager: Arc) -> Self { + let cas = LocalClient::new(path, false); + Self { + prefix: prefix.to_owned(), + cas, + shard_manager, + } + } +} + +#[async_trait] +impl UploadClient for LocalTestClient { + async fn put( + &self, + prefix: &str, + hash: &MerkleHash, + data: Vec, + chunk_and_boundaries: Vec<(MerkleHash, u32)>, + ) -> Result<(), CasClientError> { + self.cas.put(prefix, hash, data, chunk_and_boundaries).await + } + + async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result { + self.cas.exists(prefix, hash).await + } + + async fn flush(&self) -> Result<(), CasClientError> { + self.cas.flush().await + } +} + +#[async_trait] +impl ReconstructionClient for LocalTestClient { + async fn get_file( + &self, + hash: &MerkleHash, + writer: &mut Box, + ) -> Result<(), CasClientError> { + let Some((file_info, _)) = self + .shard_manager + .get_file_reconstruction_info(hash) + .await + .map_err(|e| anyhow!("{e}"))? + else { + return Err(CasClientError::FileNotFound(*hash)); + }; + + for entry in file_info.segments { + let Some(one_range) = self + .cas + .get_object_range( + &self.prefix, + &entry.cas_hash, + vec![(entry.chunk_index_start, entry.chunk_index_end)], + )? + .pop() + else { + return Err(CasClientError::InvalidRange); + }; + + writer.write_all(&one_range)?; + } + + Ok(()) + } + + #[allow(unused_variables)] + async fn get_file_byte_range( + &self, + hash: &MerkleHash, + offset: u64, + length: u64, + writer: &mut Box, + ) -> Result<(), CasClientError> { + todo!() + } +} + +impl Client for LocalTestClient {} diff --git a/data/src/test_utils/mod.rs b/data/src/test_utils/mod.rs new file mode 100644 index 00000000..66a03583 --- /dev/null +++ b/data/src/test_utils/mod.rs @@ -0,0 +1,3 @@ +mod local_test_client; + +pub use local_test_client::LocalTestClient; diff --git a/data/tests/integration_tests.rs b/data/tests/integration_tests.rs new file mode 100644 index 00000000..6b05fbf3 --- /dev/null +++ b/data/tests/integration_tests.rs @@ -0,0 +1,123 @@ +use anyhow::anyhow; +use std::{io::Write, path::Path, process::Command}; +use tempfile::TempDir; +use tracing::info; + +/// Set this to true to see the output of the tests on success. 
+const DEBUG: bool = false; + +struct IntegrationTest { + test_script: String, + arguments: Vec, + assets: Vec<(String, &'static [u8])>, +} + +impl IntegrationTest { + fn new(test_script: &str) -> Self { + Self { + test_script: test_script.to_owned(), + arguments: Vec::new(), + assets: Vec::new(), + } + } + + #[allow(unused)] + fn add_arguments(&mut self, args: &[&str]) { + self.arguments.extend(args.iter().map(|s| s.to_string())) + } + + #[allow(unused)] + fn add_asset(&mut self, name: &str, arg: &'static [u8]) { + self.assets.push((name.to_owned(), arg)); + } + + fn run(&self) -> anyhow::Result<()> { + // Create a temporary directory + let tmp_repo_dest = TempDir::new().unwrap(); + let tmp_path_path = tmp_repo_dest.path().to_path_buf(); + + std::fs::write(tmp_path_path.join("test_script.sh"), &self.test_script).unwrap(); + + std::fs::write( + tmp_path_path.join("initialize.sh"), + include_str!("integration_tests/initialize.sh"), + ) + .unwrap(); + + // Write the assets into the tmp path + for (name, data) in self.assets.iter() { + std::fs::write(tmp_path_path.join(name), data)?; + } + + let mut cmd = Command::new("bash"); + cmd.args(["-e", "-x", "test_script.sh"]); + cmd.args(&self.arguments[..]); + cmd.current_dir(tmp_path_path.clone()); + + // Add in the path of the test bin executable + + let test_bin_path = env!("CARGO_BIN_EXE_x"); + let buildpath = Path::new(&test_bin_path).parent().unwrap(); + info!("Adding {:?} to path.", &buildpath); + cmd.env( + "PATH", + format!( + "{}:{}", + &buildpath.to_str().unwrap(), + &std::env::var("PATH").unwrap() + ), + ); + + // Now, to prevent ~/.gitconfig to be read, we need to reset the home directory; otherwise + // these tests will not be run in an isolated environment. + // + // NOTE: this is not a problem with git version 2.32 or later. There, GIT_CONFIG_GLOBAL + // works and the scripts take advantage of it. However, outside of that, this is needed + // to avoid issues with a lesser git. + cmd.env("HOME", tmp_path_path.as_os_str()); + + // Now, run the script. + let out = cmd.output()?; + let status = out.status; + + if status.success() { + if DEBUG { + // Just dump things to the output + eprintln!("Test succeeded, STDOUT:"); + std::io::stdout().write_all(&out.stdout).unwrap(); + eprintln!("STDERR:"); + std::io::stderr().write_all(&out.stderr).unwrap(); + } + Ok(()) + } else { + eprintln!("Test failed, STDOUT:"); + std::io::stderr().write_all(&out.stderr).unwrap(); + // Parse output for error string: + let stderr_out = std::str::from_utf8(&out.stderr)?; + + eprintln!("STDERR:\n{}", &stderr_out); + + let error_re = regex::Regex::new("ERROR:>>>>>(.*)<<<<<").unwrap(); + + let captures = error_re.captures(stderr_out); + + if let Some(captured_text) = captures { + Err(anyhow!( + "Test failed: {}", + captured_text.get(1).unwrap().as_str() + )) + } else { + Err(anyhow!("Test failed: Unknown Error.")) + } + } + } +} + +#[cfg(all(test, unix))] +mod git_integration_tests { + use super::*; + #[test] + fn test_basic_read() -> anyhow::Result<()> { + IntegrationTest::new(include_str!("integration_tests/test_basic_clean_smudge.sh")).run() + } +} diff --git a/data/tests/integration_tests/initialize.sh b/data/tests/integration_tests/initialize.sh new file mode 100644 index 00000000..d7994b10 --- /dev/null +++ b/data/tests/integration_tests/initialize.sh @@ -0,0 +1,168 @@ +#!/usr/bin/env bash + +# With these, Log the filename, function name, and line number when showing where we're executing. 
+set -o xtrace +export PS4='+($(basename ${BASH_SOURCE}):${LINENO}): ${FUNCNAME[0]:+${FUNCNAME[0]}(): }' + +die() { + echo >&2 "ERROR:>>>>> $1 <<<<<" + return 1 +} +export -f die + +# support both Mac OS and Linux for these scripts +if hash md5 2>/dev/null; then + checksum() { + md5 -q $1 + } + checksum_string() { + echo $1 | md5 -q + } +else + checksum() { + md5sum $1 | head -c 32 + } + checksum_string() { + echo $1 | md5sum | head -c 32 + } +fi + +export -f checksum +export -f checksum_string + +create_data_file() { + f="$1" + len=$2 + + printf '\xff' >$f # Start with this to ensure utf-8 encoding fails quickly. + cat /dev/random | head -c $(($2 - 1)) >>$f + echo $(checksum $f) +} +export -f create_data_file + +append_data_file() { + f="$1" + len=$2 + + printf '\xff' >>$f # Start with this to ensure utf-8 encoding fails quickly. + cat /dev/random | head -c $(($2 - 1)) >>$f + echo $(checksum $f) +} +export -f append_data_file + +assert_files_equal() { + # Use fastest way to determine content equality. + cmp --silent $1 $2 || die "Assert Failed: Files $1 and $2 not equal." +} +export -f assert_files_equal + +assert_files_not_equal() { + # Use fastest way to determine content equality. + cmp --silent $1 $2 && die "Assert Failed: Files $1 and $2 should not be equal." || echo >&2 "Files $1 and $2 not equal." +} +export -f assert_files_not_equal + +assert_is_pointer_file() { + file=$1 + match=$(cat $file | head -n 1 | grep -F '# xet version' || echo "") + [[ ! -z "$match" ]] || die "File $file does not appear to be a pointer file." +} +export -f assert_is_pointer_file + +assert_pointer_file_size() { + file=$1 + size=$2 + + assert_is_pointer_file $file + + filesize=$(cat $file | grep -F filesize | sed -E 's|.*filesize = ([0-9]+).*|\1|' || echo "") + [[ $filesize == $size ]] || die "Pointer file $file gives incorrect size; $filesize, expected $size." 
+} +export -f assert_pointer_file_size + +pseudorandom_stream() { + key=$1 + + while true; do + key=$(checksum_string $key) + echo "$(echo $key | xxd -r -p)" 2>/dev/null || exit 0 + done +} +export -f pseudorandom_stream + +create_csv_file() { + local set_x_status=$([[ "$-" == *x* ]] && echo 1) + set +x + + csv_file="$1" + key="$2" + n_lines="$3" + n_repeats="${4:-1}" + n_lines_p_1=$((n_lines + 1)) + + pseudorandom_stream "$key" | hexdump -v -e '5/1 "%02x""\n"' | + awk -v OFS='\t' 'NR == 1 { print "foo", "bar", "baz" } + { print "S"substr($0, 1, 4), substr($0, 5, 2), substr($0, 7, 2)"."substr($0, 9, 1), 6, 3}' | + head -n $((n_lines + 1)) | tr 'abcdef' '123456' >$csv_file.part + + cat $csv_file.part >$csv_file + + for i in {0..n_repeats}; do + tail -n $n_lines $csv_file.part >>$csv_file + done + + rm $csv_file.part + [[ $set_x_status != "1" ]] || set -x +} +export -f create_csv_file + +create_random_csv_file() { + f="$1" + n_lines="$2" + n_repeats="${3:-1}" + n_lines_p_1=$((n_lines + 1)) + + cat /dev/random | hexdump -v -e '5/1 "%02x""\n"' | + awk -v OFS='\t' 'NR == 1 { print "foo", "bar", "baz" } + { print "S"substr($0, 1, 4), substr($0, 5, 2), substr($0, 7, 2)"."substr($0, 9, 1), 6, 3}' | + head -n $((n_lines + 1)) | tr 'abcdef' '123456' >$f.part + + cat $f.part >$f + + for i in {0..n_repeats}; do + tail -n $n_lines $f.part >>$f + done + + rm $f.part +} +export -f create_random_csv_file + +create_text_file() { + local set_x_status=$([[ "$-" == *x* ]] && echo 1) + set +x + + text_file="$1" + key="$2" + n_lines="$3" + n_repeats="${4:-1}" + + create_csv_file "$text_file.temp" "$key" "$n_lines" "$n_repeats" + + cat "$text_file.temp" | tr ',0123456789' 'ghijklmnopq' >$text_file + rm "$text_file.temp" + [[ $set_x_status != "1" ]] || set -x +} +export -f create_text_file + +random_tag() { + cat /dev/random | head -c 64 | checksum_string +} +export -f random_tag + +raw_file_size() { + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + stat --printf="%s" $1 + elif [[ "$OSTYPE" == "darwin"* ]]; then + stat -f%z $1 + fi +} diff --git a/data/tests/integration_tests/test_basic_clean_smudge.sh b/data/tests/integration_tests/test_basic_clean_smudge.sh new file mode 100644 index 00000000..93aad16f --- /dev/null +++ b/data/tests/integration_tests/test_basic_clean_smudge.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +set -e +set -x + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]:-$0}")" &>/dev/null && pwd 2>/dev/null)" +. 
"$SCRIPT_DIR/initialize.sh" + +# Test small binary file clean & smudge +create_data_file small.dat 1452 + +x clean -d small.pft small.dat +assert_is_pointer_file small.pft +assert_pointer_file_size small.pft 1452 + +x smudge -f small.pft small.dat.2 +assert_files_equal small.dat small.dat.2 + +# Test big binary file clean & smudge +create_data_file large.dat 4621684 # 4.6 MB + +x clean -d large.pft large.dat +assert_is_pointer_file large.pft +assert_pointer_file_size large.pft 4621684 + +x smudge -f large.pft large.dat.2 +assert_files_equal large.dat large.dat.2 + +# Test small text file clean +create_text_file small.txt key1 100 1 + +x clean -d small.pft small.txt +assert_files_equal small.pft small.txt # not converted to pointer file From a7ffdfad45c3e544cc2952469248c8d3564cb093 Mon Sep 17 00:00:00 2001 From: seanses Date: Fri, 4 Oct 2024 12:41:23 -0700 Subject: [PATCH 19/19] remote client retry on error and error handling --- cas_client/src/remote_client.rs | 81 ++++++++++++++++++++++++------ data/src/remote_shard_interface.rs | 5 -- 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index cfeb38b1..c5698d27 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -11,16 +11,44 @@ use error_printer::OptionPrinter; use merklehash::MerkleHash; use reqwest::{StatusCode, Url}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware}; +use retry_strategy::RetryStrategy; use std::io::{Cursor, Write}; -use tracing::{debug, error}; +use tracing::{debug, error, warn}; use utils::auth::AuthConfig; pub const CAS_ENDPOINT: &str = "http://localhost:8080"; pub const PREFIX_DEFAULT: &str = "default"; +const NUM_RETRIES: usize = 5; +const BASE_RETRY_DELAY_MS: u64 = 3000; + +fn retry_http_status_code(stat: &reqwest::StatusCode) -> bool { + stat.is_server_error() || *stat == reqwest::StatusCode::TOO_MANY_REQUESTS +} + +fn is_status_retriable_and_print(err: &reqwest::Error) -> bool { + let ret = err + .status() + .as_ref() + .map(retry_http_status_code) + .unwrap_or(true); // network issues should be retried + if ret { + warn!("{err:?}. 
Retrying..."); + } + ret +} + +fn is_middleware_status_retriable_and_print(err: &reqwest_middleware::Error) -> bool { + match err { + reqwest_middleware::Error::Reqwest(error) => is_status_retriable_and_print(error), + _ => false, + } +} + #[derive(Debug)] pub struct RemoteClient { client: ClientWithMiddleware, + retry_strategy: RetryStrategy, endpoint: String, } @@ -108,7 +136,18 @@ impl Reconstructable for RemoteClient { file_id.hex() ))?; - let response = self.client.get(url).send().await?; + let response = self + .retry_strategy + .retry( + || async { + let url = url.clone(); + self.client.get(url).send().await + }, + is_middleware_status_retriable_and_print, + ) + .await + .map_err(|e| CasClientError::InternalError(anyhow!("request failed with code {e}")))?; + let response_body = response.bytes().await?; let response_parsed: QueryReconstructionResponse = serde_json::from_reader(response_body.reader()) @@ -125,6 +164,7 @@ impl RemoteClient { let client = build_reqwest_client(auth_config).unwrap(); Self { client, + retry_strategy: RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS), endpoint: endpoint.to_string(), } } @@ -183,27 +223,36 @@ pub(crate) async fn get_one_range(term: &CASReconstructionTerm) -> Result debug!("term: {term:?}"); if term.range.end < term.range.start || term.url_range.end < term.url_range.start { - return Err(CasClientError::InternalError(anyhow!( - "invalid range in reconstruction" - ))); + return Err(CasClientError::InvalidRange); } let url = Url::parse(term.url.as_str())?; - let response = reqwest::Client::new() - .request(hyper::Method::GET, url) - .header( - reqwest::header::RANGE, - format!("bytes={}-{}", term.url_range.start, term.url_range.end), + let client = reqwest::Client::new(); + let retry_strategy = RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS); + + let response = retry_strategy + .retry( + || async { + let url = url.clone(); + + client + .get(url) + .header( + reqwest::header::RANGE, + format!("bytes={}-{}", term.url_range.start, term.url_range.end), + ) + .send() + .await + }, + is_status_retriable_and_print, ) - .send() - .await? - .error_for_status()?; - let xorb_bytes = response - .bytes() .await - .map_err(CasClientError::ReqwestError)?; + .map_err(|e| CasClientError::InternalError(anyhow!("request failed with code {e}")))?; + + let xorb_bytes = response.bytes().await?; if xorb_bytes.len() as u32 != term.url_range.end - term.url_range.start { error!("got back a smaller range than requested"); + return Err(CasClientError::InvalidRange); } let mut readseek = Cursor::new(xorb_bytes.to_vec()); let data = cas_object::deserialize_chunks(&mut readseek)?; diff --git a/data/src/remote_shard_interface.rs b/data/src/remote_shard_interface.rs index 14dfece0..76f76cca 100644 --- a/data/src/remote_shard_interface.rs +++ b/data/src/remote_shard_interface.rs @@ -20,7 +20,6 @@ use std::sync::Arc; use std::sync::Mutex; use tokio::task::JoinHandle; use tracing::{debug, info}; -use utils::singleflight; pub struct RemoteShardInterface { pub file_query_policy: FileQueryPolicy, @@ -35,9 +34,6 @@ pub struct RemoteShardInterface { pub shard_client: Option>, pub reconstruction_cache: Mutex)>>, - - // A gate on downloading and registering new shards. - pub shard_downloads: Arc>, } impl RemoteShardInterface { @@ -88,7 +84,6 @@ impl RemoteShardInterface { std::num::NonZero::new(FILE_RECONSTRUCTION_CACHE_SIZE).unwrap(), )), cas, - shard_downloads: Arc::new(singleflight::Group::new()), })) }