From ec98cef9b8c39832f3c3c0940c0b947d5f1b897c Mon Sep 17 00:00:00 2001 From: Rajat Arya Date: Fri, 13 Sep 2024 18:02:49 -0700 Subject: [PATCH] CAS HTTP remote client (#6) * Partial Remote Client - not building yet. * Everything builds, but unlikely anything works * comments * Rust file reconstruction Co-authored by Assaf * Fixing clippy & build --- Cargo.lock | 181 +++++- Cargo.toml | 2 + cas_client/Cargo.toml | 5 + cas_client/src/data_transport.rs | 28 +- cas_client/src/error.rs | 9 + cas_client/src/grpc.rs | 822 ---------------------------- cas_client/src/lib.rs | 4 - cas_client/src/remote_client.rs | 773 +++++++++----------------- cas_client/src/util.rs | 5 +- cas_object/src/cas_object_format.rs | 2 +- data/src/cas_interface.rs | 10 +- 11 files changed, 470 insertions(+), 1371 deletions(-) delete mode 100644 cas_client/src/grpc.rs diff --git a/Cargo.lock b/Cargo.lock index bc48fd48..cb4c42ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -227,7 +227,7 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", - "sync_wrapper", + "sync_wrapper 0.1.2", "tower", "tower-layer", "tower-service", @@ -430,6 +430,7 @@ dependencies = [ "bytes", "cache", "cas_object", + "cas_types", "clap 2.34.0", "deadpool", "error_printer", @@ -437,7 +438,7 @@ dependencies = [ "http 0.2.12", "http-body-util", "hyper 1.4.1", - "hyper-rustls", + "hyper-rustls 0.26.0", "hyper-util", "itertools 0.10.5", "lazy_static", @@ -452,8 +453,10 @@ dependencies = [ "prost", "rand 0.8.5", "rcgen", + "reqwest 0.12.7", "retry_strategy", "rustls-pemfile 2.1.3", + "serde", "serde_json", "tempfile", "tokio", @@ -465,7 +468,9 @@ dependencies = [ "tower", "tracing", "tracing-opentelemetry", + "tracing-test", "trait-set", + "url", "utils", "uuid", "xet_error", @@ -922,7 +927,7 @@ dependencies = [ "rand 0.8.5", "rand_chacha", "regex", - "reqwest", + "reqwest 0.11.27", "retry_strategy", "ring 0.16.20", "rstest", @@ -1714,6 +1719,23 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.4.1", + "hyper-util", + "rustls 0.23.13", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", + "tower-service", +] + [[package]] name = "hyper-timeout" version = "0.4.1" @@ -1739,6 +1761,22 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.8" @@ -3113,7 +3151,7 @@ dependencies = [ "http 0.2.12", "http-body 0.4.6", "hyper 0.14.30", - "hyper-tls", + "hyper-tls 0.5.0", "ipnet", "js-sys", "log", @@ -3126,8 +3164,8 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", - "system-configuration", + "sync_wrapper 0.1.2", + "system-configuration 0.5.1", "tokio", "tokio-native-tls", "tower-service", @@ -3139,6 +3177,49 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-rustls 0.27.3", + "hyper-tls 0.6.0", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile 2.1.3", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "system-configuration 0.6.1", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows-registry", +] + [[package]] name = "retain_mut" version = "0.1.9" @@ -3277,6 +3358,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls" +version = "0.23.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" @@ -3584,7 +3678,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-jaeger", "rand 0.8.5", - "reqwest", + "reqwest 0.11.27", "retry_strategy", "serde_json", "tempfile", @@ -3836,6 +3930,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] + [[package]] name = "synchronoise" version = "1.0.1" @@ -3868,7 +3971,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "system-configuration-sys 0.6.0", ] [[package]] @@ -3881,6 +3995,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tabled" version = "0.12.2" @@ -4174,6 +4298,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.13", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.16" @@ -4815,6 +4950,36 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index 22670a76..60031d16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ members = [ "shard_client", "utils", "xet_error", + "cas_object", + "cas_types", ] exclude = [ diff --git a/cas_client/Cargo.toml b/cas_client/Cargo.toml index 0a89e5c1..a570006f 100644 --- a/cas_client/Cargo.toml +++ b/cas_client/Cargo.toml @@ -51,6 +51,10 @@ tokio-rustls = "0.25.0" rustls-pemfile = "2.0.0" hyper-rustls = { version = "0.26.0", features = ["http2"] } lz4 = "1.24.0" +reqwest = "0.12.7" +serde = { version = "1.0.210", features = ["derive"] } +cas_types = { version = "0.1.0", path = "../cas_types" } +url = "2.5.2" [dev-dependencies] trait-set = "0.3.0" @@ -58,3 +62,4 @@ lazy_static = "1.4.0" tokio-stream = { version = "0.1", features = ["net"] } rcgen = "0.12.0" rand = "0.8.5" +tracing-test = "0.2.5" diff --git a/cas_client/src/data_transport.rs b/cas_client/src/data_transport.rs index edf26b59..d8e04613 100644 --- a/cas_client/src/data_transport.rs +++ b/cas_client/src/data_transport.rs @@ -2,11 +2,7 @@ use cas::constants::*; use std::str::FromStr; use std::time::Duration; -use crate::{ - cas_connection_pool::CasConnectionConfig, - grpc::{get_request_id, trace_forwarding}, - remote_client::CAS_PROTOCOL_VERSION, -}; +use crate::cas_connection_pool::CasConnectionConfig; use anyhow::{anyhow, Result}; use cas::common::CompressionScheme; use cas::compression::{ @@ -27,13 +23,12 @@ use hyper_util::client::legacy::Client; use hyper_util::rt::{TokioExecutor, TokioTimer}; use lazy_static::lazy_static; use lz4::block::CompressionMode; -use opentelemetry::propagation::{Injector, TextMapPropagator}; +use opentelemetry::propagation::Injector; use retry_strategy::RetryStrategy; use rustls_pemfile::Item; use tokio_rustls::rustls; use tokio_rustls::rustls::pki_types::CertificateDer; -use tracing::{debug, error, info_span, warn, Instrument, Span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; +use tracing::{debug, error, info_span, warn, Instrument}; use xet_error::Error; use merklehash::MerkleHash; @@ -196,19 +191,19 @@ impl DataTransport { debug!("Calling {} with address: {}", method, dest); let user_id = self.cas_connection_config.user_id.clone(); let auth = self.cas_connection_config.auth.clone(); - let request_id = get_request_id(); + // let request_id = get_request_id(); let repo_paths = self.cas_connection_config.repo_paths.clone(); let git_xet_version = self.cas_connection_config.git_xet_version.clone(); - let cas_protocol_version = CAS_PROTOCOL_VERSION.clone(); + // let cas_protocol_version = CAS_PROTOCOL_VERSION.clone(); let mut req = Request::builder() .method(method.clone()) .header(USER_ID_HEADER, user_id) .header(AUTH_HEADER, auth) - .header(REQUEST_ID_HEADER, request_id) + //.header(REQUEST_ID_HEADER, request_id) .header(REPO_PATHS_HEADER, repo_paths) .header(GIT_XET_VERSION_HEADER, git_xet_version) - .header(CAS_PROTOCOL_VERSION_HEADER, cas_protocol_version) + //.header(CAS_PROTOCOL_VERSION_HEADER, cas_protocol_version) .uri(&dest) .version(Version::HTTP_2); @@ -219,6 +214,7 @@ impl DataTransport { ); } + /* if trace_forwarding() { if let Some(headers) = req.headers_mut() { let mut injector = HeaderInjector(headers); @@ -228,6 +224,8 @@ impl DataTransport { propagator.inject_context(&ctx, &mut injector); } } + */ + let bytes = match body { None => Bytes::new(), Some(data) => Bytes::from(data), @@ -581,11 +579,5 @@ mod tests { // check against values in config assert_eq!(get_header_value(GIT_XET_VERSION_HEADER), git_xet_version); assert_eq!(get_header_value(USER_ID_HEADER), user_id); - - // check against global static - assert_eq!( - get_header_value(CAS_PROTOCOL_VERSION_HEADER), - CAS_PROTOCOL_VERSION.as_str() - ); } } diff --git a/cas_client/src/error.rs b/cas_client/src/error.rs index 218d2e3e..1d84847d 100644 --- a/cas_client/src/error.rs +++ b/cas_client/src/error.rs @@ -65,6 +65,15 @@ pub enum CasClientError { #[error("Cas Object Error: {0}")] CasObjectError(#[from] cas_object::error::CasObjectError), + + #[error("Parse Error: {0}")] + ParseError(#[from] url::ParseError), + + #[error("Reqwest Error: {0}")] + ReqwestError(#[from] reqwest::Error), + + #[error("Serde Error: {0}")] + SerdeError(#[from] serde_json::Error), } // Define our own result type here (this seems to be the standard). diff --git a/cas_client/src/grpc.rs b/cas_client/src/grpc.rs deleted file mode 100644 index 5815bda7..00000000 --- a/cas_client/src/grpc.rs +++ /dev/null @@ -1,822 +0,0 @@ -use crate::error::Result; -use std::env::VarError; -use std::str::FromStr; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::Duration; - -use crate::cas_connection_pool::CasConnectionConfig; -use crate::remote_client::CAS_PROTOCOL_VERSION; -use http::Uri; -use opentelemetry::propagation::{Injector, TextMapPropagator}; -use retry_strategy::RetryStrategy; -use tonic::codegen::InterceptedService; -use tonic::metadata::{Ascii, Binary, MetadataKey, MetadataMap, MetadataValue}; -use tonic::service::Interceptor; -use tonic::transport::{Certificate, ClientTlsConfig}; -use tonic::{transport::Channel, Code, Request, Status}; -use tracing::{debug, info, warn, Span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; -use uuid::Uuid; - -use cas::common::CompressionScheme; -use cas::{ - cas::{ - cas_client::CasClient, GetRangeRequest, GetRequest, HeadRequest, PutCompleteRequest, - PutRequest, Range, - }, - common::{EndpointConfig, InitiateRequest, InitiateResponse, Key, Scheme}, - constants::*, -}; -use merklehash::MerkleHash; - -use crate::CasClientError; -pub type CasClientType = CasClient>; - -const DEFAULT_H2_PORT: u16 = 443; -const DEFAULT_PUT_COMPLETE_PORT: u16 = 5000; - -const HTTP2_KEEPALIVE_TIMEOUT_SEC: u64 = 20; -const HTTP2_KEEPALIVE_INTERVAL_SEC: u64 = 1; -const NUM_RETRIES: usize = 5; -const BASE_RETRY_DELAY_MS: u64 = 3000; - -// production ready settings -const INITIATE_CAS_SCHEME: &str = "https"; -const HTTP_CAS_SCHEME: &str = "http"; - -lazy_static::lazy_static! { - static ref DEFAULT_UUID: Uuid = Uuid::new_v4(); - static ref REQUEST_COUNTER: AtomicUsize = AtomicUsize::new(0); - static ref TRACE_FORWARDING: AtomicBool = AtomicBool::new(false); -} - -async fn get_channel(endpoint: &str, root_ca: &Option>) -> Result { - debug!("server name: {}", endpoint); - let mut server_uri: Uri = endpoint - .parse() - .map_err(|e| CasClientError::ConfigurationError(format!("Error parsing endpoint: {e}.")))?; - - // supports an absolute URI (above) or just the host:port (below) - // only used on first endpoint, all other endpoints should come from CAS - // with scheme info already included - // in local/witt modes overridden CAS initial URI should include scheme e.g. - // http://localhost:40000 - if server_uri.scheme().is_none() { - let scheme = if cfg!(test) { - HTTP_CAS_SCHEME - } else { - INITIATE_CAS_SCHEME - }; - server_uri = format!("{scheme}://{endpoint}").parse().unwrap(); - } - - debug!("Connecting to URI: {}", server_uri); - - let mut builder = Channel::builder(server_uri); - if let Some(root_ca) = root_ca { - let tls_config = - ClientTlsConfig::new().ca_certificate(Certificate::from_pem(root_ca.as_str())); - builder = builder.tls_config(tls_config)?; - } - let channel = builder - .keep_alive_timeout(Duration::new(HTTP2_KEEPALIVE_TIMEOUT_SEC, 0)) - .http2_keep_alive_interval(Duration::new(HTTP2_KEEPALIVE_INTERVAL_SEC, 0)) - .timeout(Duration::new(GRPC_TIMEOUT_SEC, 0)) - .connect_timeout(Duration::new(GRPC_TIMEOUT_SEC, 0)) - .connect() - .await?; - Ok(channel) -} - -pub async fn get_client(cas_connection_config: CasConnectionConfig) -> Result { - let timeout_channel = get_channel( - cas_connection_config.endpoint.as_str(), - &cas_connection_config.root_ca, - ) - .await?; - - let client: CasClientType = CasClient::with_interceptor( - timeout_channel, - MetadataHeaderInterceptor::new(cas_connection_config), - ); - Ok(client) -} - -/// Adds common metadata headers to all requests. Currently, this includes -/// authorization and xet-user-id. -/// TODO: at some point, we should re-evaluate how we authenticate/authorize requests to CAS. -#[derive(Debug, Clone)] -pub struct MetadataHeaderInterceptor { - config: CasConnectionConfig, -} - -impl MetadataHeaderInterceptor { - fn new(config: CasConnectionConfig) -> MetadataHeaderInterceptor { - MetadataHeaderInterceptor { config } - } -} - -impl Interceptor for MetadataHeaderInterceptor { - // note original Interceptor trait accepts non-mut request - // but may accept mut request like in this case - fn call(&mut self, mut request: Request<()>) -> std::result::Result, Status> { - request.set_timeout(Duration::new(GRPC_TIMEOUT_SEC, 0)); - let metadata = request.metadata_mut(); - - // pass user_id and repo_paths received from xetconfig - let user_id = get_metadata_ascii_from_str_with_default(&self.config.user_id, DEFAULT_USER); - metadata.insert(USER_ID_HEADER, user_id); - let auth = get_metadata_ascii_from_str_with_default(&self.config.auth, DEFAULT_AUTH); - metadata.insert(AUTH_HEADER, auth); - - let repo_paths = get_repo_paths_metadata_value(&self.config.repo_paths); - metadata.insert_bin(REPO_PATHS_HEADER, repo_paths); - - let git_xet_version = - get_metadata_ascii_from_str_with_default(&self.config.git_xet_version, DEFAULT_VERSION); - metadata.insert(GIT_XET_VERSION_HEADER, git_xet_version); - - let cas_protocol_version: MetadataValue = - MetadataValue::from_static(CAS_PROTOCOL_VERSION.as_str()); - metadata.insert(CAS_PROTOCOL_VERSION_HEADER, cas_protocol_version); - - // propagate tracing context (e.g. trace_id, span_id) to service - if trace_forwarding() { - let mut injector = HeaderInjector(metadata); - let propagator = opentelemetry_jaeger::Propagator::new(); - let cur_span = Span::current(); - let ctx = cur_span.context(); - propagator.inject_context(&ctx, &mut injector); - } - - let request_id = get_request_id(); - metadata.insert( - REQUEST_ID_HEADER, - MetadataValue::from_str(&request_id) - .map_err(|e| Status::internal(format!("Metadata error: {e:?}")))?, - ); - - Ok(request) - } -} - -pub fn set_trace_forwarding(should_enable: bool) { - TRACE_FORWARDING.store(should_enable, Ordering::Relaxed); -} - -pub fn trace_forwarding() -> bool { - TRACE_FORWARDING.load(Ordering::Relaxed) -} - -pub struct HeaderInjector<'a>(pub &'a mut MetadataMap); - -impl<'a> Injector for HeaderInjector<'a> { - /// Set a key and value in the HeaderMap. Does nothing if the key or value are not valid inputs. - fn set(&mut self, key: &str, value: String) { - if let Ok(name) = MetadataKey::from_str(key) { - if let Ok(val) = MetadataValue::from_str(&value) { - self.0.insert(name, val); - } - } - } -} - -fn get_metadata_ascii_from_str_with_default( - value: &str, - default: &'static str, -) -> MetadataValue { - MetadataValue::from_str(value) - .map_err(|_| VarError::NotPresent) - .unwrap_or_else(|_| MetadataValue::from_static(default)) -} - -fn get_repo_paths_metadata_value(repo_paths: &str) -> MetadataValue { - MetadataValue::from_bytes(repo_paths.as_bytes()) -} - -pub fn get_request_id() -> String { - format!( - "{}.{}", - *DEFAULT_UUID, - REQUEST_COUNTER.load(Ordering::Relaxed) - ) -} - -fn inc_request_id() { - REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed); -} - -/// CAS Client that uses GRPC for communication. -/// -/// ## Implementation note -/// The GrpcClient is thread-safe and allows multiplexing requests on the -/// underlying gRPC connection. This is done by cheaply cloning the client: -/// https://docs.rs/tonic/0.1.0/tonic/transport/struct.Channel.html#multiplexing-requests -#[derive(Debug)] -pub struct GrpcClient { - pub endpoint: String, - client: CasClientType, - retry_strategy: RetryStrategy, -} - -impl Clone for GrpcClient { - fn clone(&self) -> Self { - GrpcClient { - endpoint: self.endpoint.clone(), - client: self.client.clone(), - retry_strategy: self.retry_strategy, - } - } -} - -impl GrpcClient { - pub fn new(endpoint: String, client: CasClientType, retry_strategy: RetryStrategy) -> Self { - Self { - endpoint, - client, - retry_strategy, - } - } - - pub async fn from_config(cas_connection_config: CasConnectionConfig) -> Result { - let endpoint = cas_connection_config.endpoint.clone(); - let client: CasClientType = get_client(cas_connection_config).await?; - // Retry policy: Exponential backoff starting at BASE_RETRY_DELAY_MS and retrying NUM_RETRIES times - let retry_strategy = RetryStrategy::new(NUM_RETRIES, BASE_RETRY_DELAY_MS); - Ok(GrpcClient::new(endpoint, client, retry_strategy)) - } -} - -pub fn is_status_retriable(err: &Status) -> bool { - match err.code() { - Code::Ok - | Code::Cancelled - | Code::InvalidArgument - | Code::NotFound - | Code::AlreadyExists - | Code::PermissionDenied - | Code::FailedPrecondition - | Code::OutOfRange - | Code::Unimplemented - | Code::Unauthenticated => false, - Code::Unknown - | Code::DeadlineExceeded - | Code::ResourceExhausted - | Code::Aborted - | Code::Internal - | Code::Unavailable - | Code::DataLoss => true, - } -} - -pub fn is_status_retriable_and_print(err: &Status) -> bool { - let ret = is_status_retriable(err); - if ret { - info!("GRPC Error {}. Retrying...", err); - } - ret -} - -pub fn print_final_retry_error(err: Status) -> Status { - if is_status_retriable(&err) { - warn!("Many failures {}", err); - } - err -} - -impl Drop for GrpcClient { - fn drop(&mut self) { - debug!("GrpcClient: Dropping GRPC Client."); - } -} - -// DTO for initiate rpc response info -pub struct EndpointsInfo { - pub data_plane_endpoint: EndpointConfig, - pub put_complete_endpoint: EndpointConfig, - pub accepted_encodings: Vec, -} - -impl GrpcClient { - #[tracing::instrument(skip_all, name = "cas.client", err, fields(prefix = prefix, hash = hash.hex().as_str(), api = "put", request_id = tracing::field::Empty))] - pub async fn put( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<()> { - inc_request_id(); - Span::current().record("request_id", &get_request_id()); - debug!( - "GrpcClient Req {}: put to {}/{} of length {} bytes", - get_request_id(), - prefix, - hash, - data.len(), - ); - let request = PutRequest { - key: Some(get_key_for_request(prefix, hash)), - data, - chunk_boundaries, - }; - - let response = self - .retry_strategy - .retry( - || async { - let req = Request::new(request.clone()); - self.client.clone().put(req).await - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error) - .map_err(|e| { - info!( - "GrpcClient Req {}: Error on Put {}/{} : {:?}", - get_request_id(), - prefix, - hash, - e - ); - CasClientError::Grpc(anyhow::Error::from(e)) - })?; - - debug!( - "GrpcClient Req {}: put to {}/{} complete.", - get_request_id(), - prefix, - hash, - ); - - if !response.into_inner().was_inserted { - debug!( - "GrpcClient Req {}: XORB {}/{} not inserted; already present.", - get_request_id(), - prefix, - hash - ); - } - - Ok(()) - } - - /// on success returns a type of 2 EndpointConfig's - /// first the EndpointConfig for the h2 dataplane endpoint - /// second the EndpointConfig for the grpc endpoint used for put complete rpc - #[tracing::instrument(skip_all, name = "cas.client", err, fields(prefix = prefix, hash = hash.hex().as_str(), api = "initiate", request_id = tracing::field::Empty))] - pub async fn initiate( - &self, - prefix: &str, - hash: &MerkleHash, - payload_size: usize, - ) -> Result { - debug!( - "GrpcClient Req {}. initiate {}/{}, size={payload_size}", - get_request_id(), - prefix, - hash - ); - inc_request_id(); - Span::current().record("request_id", &get_request_id()); - let request = InitiateRequest { - key: Some(get_key_for_request(prefix, hash)), - payload_size: payload_size as u64, - }; - - let response = self - .retry_strategy - .retry( - || async { - let req = Request::new(request.clone()); - self.client.clone().initiate(req).await - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error) - .map_err(|e| CasClientError::Grpc(anyhow::Error::from(e)))?; - - let InitiateResponse { - data_plane_endpoint, - put_complete_endpoint, - cas_hostname, - accepted_encodings, - } = response.into_inner(); - - let accepted_encodings = accepted_encodings - .into_iter() - .filter_map(|i| CompressionScheme::try_from(i).ok()) - .collect(); - - if data_plane_endpoint.is_none() || put_complete_endpoint.is_none() { - info!("CAS initiate response indicates cas protocol version < v0.2.0, defaulting to v0.1.0 config"); - // default case, relevant for using CAS until prod is synced with v0.2.0 - return Ok(EndpointsInfo { - data_plane_endpoint: EndpointConfig { - host: cas_hostname.clone(), - port: DEFAULT_H2_PORT.into(), - scheme: Scheme::Http.into(), - ..Default::default() - }, - put_complete_endpoint: EndpointConfig { - host: cas_hostname, - port: DEFAULT_PUT_COMPLETE_PORT.into(), - scheme: Scheme::Http.into(), - ..Default::default() - }, - accepted_encodings, - }); - } - debug!( - "GrpcClient Req {}. initiate {}/{}, size={payload_size} complete", - get_request_id(), - prefix, - hash - ); - - Ok(EndpointsInfo { - data_plane_endpoint: data_plane_endpoint.unwrap(), - put_complete_endpoint: put_complete_endpoint.unwrap(), - accepted_encodings, - }) - } - - #[tracing::instrument(skip_all, name = "cas.client", err, fields(prefix = prefix, hash = hash.hex().as_str(), api = "put_complete", request_id = tracing::field::Empty))] - pub async fn put_complete( - &self, - prefix: &str, - hash: &MerkleHash, - chunk_boundaries: &[u64], - ) -> Result<()> { - debug!( - "GrpcClient Req {}. put_complete of {}/{}", - get_request_id(), - prefix, - hash - ); - Span::current().record("request_id", &get_request_id()); - let request = PutCompleteRequest { - key: Some(get_key_for_request(prefix, hash)), - chunk_boundaries: chunk_boundaries.to_owned(), - }; - - let _ = self - .retry_strategy - .retry( - || async { - let req = Request::new(request.clone()); - self.client.clone().put_complete(req).await - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error) - .map_err(|e| CasClientError::Grpc(anyhow::Error::from(e)))?; - - debug!( - "GrpcClient Req {}. put_complete of {}/{} complete.", - get_request_id(), - prefix, - hash - ); - Ok(()) - } - - #[tracing::instrument(skip_all, name = "cas.client", err, fields(prefix = prefix, hash = hash.hex().as_str(), api = "get", request_id = tracing::field::Empty))] - pub async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - inc_request_id(); - Span::current().record("request_id", &get_request_id()); - debug!( - "GrpcClient Req {}. Get of {}/{}", - get_request_id(), - prefix, - hash - ); - let request = GetRequest { - key: Some(get_key_for_request(prefix, hash)), - }; - let response = self - .retry_strategy - .retry( - || async { - let req = Request::new(request.clone()); - self.client.clone().get(req).await - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error) - .map_err(|e| { - info!( - "GrpcClient Req {}. Error on Get {}/{} : {:?}", - get_request_id(), - prefix, - hash, - e - ); - CasClientError::Grpc(anyhow::Error::from(e)) - })?; - - debug!( - "GrpcClient Req {}. Get of {}/{} complete.", - get_request_id(), - prefix, - hash - ); - - Ok(response.into_inner().data) - } - - #[tracing::instrument(skip_all, name = "cas.client", err, fields(prefix = prefix, hash = hash.hex().as_str(), api = "get_range", request_id = tracing::field::Empty))] - pub async fn get_object_range( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - inc_request_id(); - Span::current().record("request_id", &get_request_id()); - debug!( - "GrpcClient Req {}. GetObjectRange of {}/{}", - get_request_id(), - prefix, - hash - ); - // Handle the case where we aren't asked for any real data. - if ranges.len() == 1 && ranges[0].0 == ranges[0].1 { - return Ok(vec![Vec::::new()]); - } - - let range_vec: Vec = ranges - .into_iter() - .map(|(start, end)| Range { start, end }) - .collect(); - let request = GetRangeRequest { - key: Some(get_key_for_request(prefix, hash)), - ranges: range_vec, - }; - let response = self - .retry_strategy - .retry( - || async { - let req = Request::new(request.clone()); - self.client.clone().get_range(req).await - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error) - .map_err(|e| { - info!( - "GrpcClient Req {}. Error on GetObjectRange of {}/{} : {:?}", - get_request_id(), - prefix, - hash, - e - ); - CasClientError::Grpc(anyhow::Error::from(e)) - })?; - - debug!( - "GrpcClient Req {}. GetObjectRange of {}/{} complete.", - get_request_id(), - prefix, - hash - ); - - Ok(response.into_inner().data) - } - - #[tracing::instrument(skip_all, name = "cas.client", fields(prefix = prefix, hash = hash.hex().as_str(), api = "get_length", request_id = tracing::field::Empty))] - pub async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - inc_request_id(); - Span::current().record("request_id", &get_request_id()); - debug!( - "GrpcClient Req {}. GetLength of {}/{}", - get_request_id(), - prefix, - hash - ); - let request = HeadRequest { - key: Some(get_key_for_request(prefix, hash)), - }; - let response = self - .retry_strategy - .retry( - || async { - let req = Request::new(request.clone()); - - self.client.clone().head(req).await - }, - is_status_retriable_and_print, - ) - .await - .map_err(print_final_retry_error) - .map_err(|e| { - debug!( - "GrpcClient Req {}. Error on GetLength of {}/{} : {:?}", - get_request_id(), - prefix, - hash, - e - ); - CasClientError::Grpc(anyhow::Error::from(e)) - })?; - debug!( - "GrpcClient Req {}. GetLength of {}/{} complete.", - get_request_id(), - prefix, - hash - ); - Ok(response.into_inner().size) - } -} - -pub fn get_key_for_request(prefix: &str, hash: &MerkleHash) -> Key { - Key { - prefix: prefix.to_string(), - hash: hash.as_bytes().to_vec(), - } -} - -#[cfg(test)] -mod tests { - use std::sync::atomic::{AtomicU32, Ordering}; - use std::sync::Arc; - - use tonic::Response; - - use cas::cas::PutResponse; - - use crate::util::grpc_mock::{MockService, ShutdownHook}; - - use super::*; - - #[tokio::test] - async fn test_put_with_retry() { - let count = Arc::new(AtomicU32::new(0)); - let put_count = count.clone(); - let put_api = move |req: Request| { - assert_eq!(req.into_inner().chunk_boundaries, vec![32, 54, 63]); - if 0 == put_count.fetch_add(1, Ordering::SeqCst) { - return Err(Status::internal("Failed")); - } - Ok(Response::new(PutResponse { was_inserted: true })) - }; - - let (mut hook, client): (ShutdownHook, GrpcClient) = - MockService::default().with_put(put_api).start().await; - - let resp = client - .put("pre1", &MerkleHash::default(), vec![0], vec![32, 54, 63]) - .await; - assert_eq!(2, count.load(Ordering::SeqCst)); - assert!(resp.is_ok()); - hook.async_drop().await; - } - - #[tokio::test] - async fn test_put_exhausted_retries() { - let count = Arc::new(AtomicU32::new(0)); - let put_count = count.clone(); - let put_api = move |req: Request| { - assert_eq!(req.into_inner().chunk_boundaries, vec![31, 54, 63]); - put_count.fetch_add(1, Ordering::SeqCst); - Err(Status::internal("Failed")) - }; - - let (mut hook, client) = MockService::default().with_put(put_api).start().await; - - let resp = client - .put("pre1", &MerkleHash::default(), vec![0], vec![31, 54, 63]) - .await; - assert_eq!(3, count.load(Ordering::SeqCst)); - assert!(resp.is_err()); - hook.async_drop().await - } - - #[tokio::test] - async fn test_put_no_retries() { - let count = Arc::new(AtomicU32::new(0)); - let put_count = count.clone(); - let put_api = move |req: Request| { - assert_eq!(req.into_inner().chunk_boundaries, vec![32, 95, 63]); - put_count.fetch_add(1, Ordering::SeqCst); - Err(Status::internal("Failed")) - }; - let (mut hook, client) = MockService::default() - .with_put(put_api) - .start_with_retry_strategy(RetryStrategy::new(0, 1)) - .await; - - let resp = client - .put("pre1", &MerkleHash::default(), vec![0], vec![32, 95, 63]) - .await; - assert_eq!(1, count.load(Ordering::SeqCst)); - assert!(resp.is_err()); - hook.async_drop().await - } - - #[tokio::test] - async fn test_put_application_error() { - let count = Arc::new(AtomicU32::new(0)); - let put_count = count.clone(); - let put_api = move |req: Request| { - assert_eq!(req.into_inner().chunk_boundaries, vec![32, 56, 63]); - put_count.fetch_add(1, Ordering::SeqCst); - Err(Status::failed_precondition("Failed precondition")) - }; - let (mut hook, client) = MockService::default().with_put(put_api).start().await; - - let resp = client - .put("pre1", &MerkleHash::default(), vec![0], vec![32, 56, 63]) - .await; - assert_eq!(1, count.load(Ordering::SeqCst)); - assert!(resp.is_err()); - hook.async_drop().await - } - - #[test] - fn metadata_header_interceptor_test() { - const XET_VERSION: &str = "0.1.0"; - let cas_connection_cofig: CasConnectionConfig = CasConnectionConfig::new( - "".to_string(), - "xet_user".to_string(), - "xet_auth".to_string(), - vec!["example".to_string()], - XET_VERSION.to_string(), - ); - let mut mh_interceptor = MetadataHeaderInterceptor::new(cas_connection_cofig); - let request = Request::new(()); - - { - // scoped so md reference to request is dropped - let md = request.metadata(); - assert!(md.get(USER_ID_HEADER).is_none()); - assert!(md.get(REQUEST_ID_HEADER).is_none()); - assert!(md.get(REPO_PATHS_HEADER).is_none()); - assert!(md.get(GIT_XET_VERSION_HEADER).is_none()); - assert!(md.get(CAS_PROTOCOL_VERSION_HEADER).is_none()); - } - let request = mh_interceptor.call(request).unwrap(); - - let md = request.metadata(); - let user_id_val = md.get(USER_ID_HEADER).unwrap(); - assert_eq!(user_id_val.to_str().unwrap(), "xet_user"); - let repo_path_val = md.get_bin(REPO_PATHS_HEADER).unwrap(); - assert_eq!(repo_path_val.to_bytes().unwrap().as_ref(), b"[\"example\"]"); - assert!(md.get(REQUEST_ID_HEADER).is_some()); - - assert!(md.get(GIT_XET_VERSION_HEADER).is_some()); - let xet_version = md.get(GIT_XET_VERSION_HEADER).unwrap().to_str().unwrap(); - assert_eq!(xet_version, XET_VERSION); - - // check that global static CAS_PROTOCOL_VERSION is what's set in the header - assert!(md.get(CAS_PROTOCOL_VERSION_HEADER).is_some()); - let cas_protocol_version = md - .get(CAS_PROTOCOL_VERSION_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(cas_protocol_version, CAS_PROTOCOL_VERSION.as_str()); - - let data: Vec> = vec![ - vec!["user1/repo-😀".to_string(), "user1/répô_123".to_string()], - vec![ - "user2/👾_repo".to_string(), - "user2/üникод".to_string(), - "user2/foobar!@#".to_string(), - ], - vec!["user3/sømè_repo".to_string(), "user3/你好-世界".to_string()], - vec!["user4/✨🌈repo".to_string()], - vec!["user5/Ω≈ç√repo".to_string()], - vec!["user6/42°_repo".to_string()], - vec![ - "user7/äëïöü_repo".to_string(), - "user7/ĀāĒēĪīŌōŪū".to_string(), - ], - ]; - for inner_vec in data { - let config = CasConnectionConfig::new( - "".to_string(), - "".to_string(), - "".to_string(), - inner_vec.clone(), - "".to_string(), - ); - let mut mh_interceptor = MetadataHeaderInterceptor::new(config); - let request = Request::new(()); - let request = mh_interceptor.call(request).unwrap(); - let md = request.metadata(); - let repo_path_val = md.get_bin(REPO_PATHS_HEADER).unwrap(); - let repo_path_str = - String::from_utf8(repo_path_val.to_bytes().unwrap().to_vec()).unwrap(); - let vec_of_strings: Vec = - serde_json::from_str(repo_path_str.as_str()).expect("Failed to deserialize JSON"); - assert_eq!(vec_of_strings, inner_vec); - } - } -} diff --git a/cas_client/src/lib.rs b/cas_client/src/lib.rs index cb0579a6..c0f752ec 100644 --- a/cas_client/src/lib.rs +++ b/cas_client/src/lib.rs @@ -3,14 +3,11 @@ pub use crate::error::CasClientError; pub use caching_client::{CachingClient, DEFAULT_BLOCK_SIZE}; -pub use grpc::set_trace_forwarding; -pub use grpc::GrpcClient; pub use interface::Client; pub use local_client::LocalClient; pub use merklehash::MerkleHash; // re-export since this is required for the client API. pub use passthrough_staging_client::PassthroughStagingClient; pub use remote_client::RemoteClient; -pub use remote_client::CAS_PROTOCOL_VERSION; pub use staging_client::{new_staging_client, new_staging_client_with_progressbar, StagingClient}; pub use staging_trait::{Staging, StagingBypassable}; @@ -19,7 +16,6 @@ mod cas_connection_pool; mod client_adapter; mod data_transport; mod error; -pub mod grpc; mod interface; mod local_client; mod passthrough_staging_client; diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 26f76400..70c3176a 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -1,577 +1,328 @@ -use async_trait::async_trait; -use cas::singleflight; -use itertools::Itertools; -use lazy_static::lazy_static; -use tracing::{debug, debug_span, error, info, info_span, Instrument}; +use std::io::{Cursor, Write}; -use merklehash::MerkleHash; - -use cas::common::CompressionScheme; -use std::collections::HashMap; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use tokio::sync::Mutex; - -use crate::cas_connection_pool::{self, CasConnectionConfig, FromConnectionConfig}; -use crate::data_transport::DataTransport; -use crate::error::{CasClientError, Result}; -use crate::grpc::{EndpointsInfo, GrpcClient}; -use crate::Client; -use retry_strategy::RetryStrategy; - -/// cas protocol version as seen from the client -/// cas protocol determines the parameters and protocols used for -/// cas operations -/// -/// 0.1.0: -/// grpc initiate; port 443; scheme https -/// h2 get + h2 put; port 443; scheme http; host from initiate rpc response -/// grpc put_complete; port 5000; scheme http; host from initiate rpc response -/// 0.2.0: -/// grpc initiate; port 443; scheme https -/// h2 get + h2 put; port, host, and scheme from initiate rpc response -/// grpc put_complete; port, host, and scheme from initiate rpc response -/// defaults to 0.1.0 if initiate does not respond with required info -/// 0.3.0: -/// grpc initiate; port 443; scheme https; widely trusted certificate -/// h2 get + h2 put; port, host and scheme from initiate rpc response; includes custom root certificate authority -/// grpc put_complete; port, host and scheme from initiate rpc response; includes custom root certificate authority -/// defaults to 0.2.0 if initiate does not respond with correct info -const _CAS_PROTOCOL_VERSION: &str = "0.3.0"; - -lazy_static! { - pub static ref CAS_PROTOCOL_VERSION: String = - std::env::var("XET_CAS_PROTOCOL_VERSION").unwrap_or(_CAS_PROTOCOL_VERSION.to_string()); -} - -// Completely arbitrary CAS size for using a single-hit put call. -// This should be tuned after performance testing. -const _MINIMUM_DATA_TRANSPORT_UPLOAD_SIZE: usize = 500; +use anyhow::anyhow; +use bytes::Buf; +use cas::key::Key; +use cas_types::{ + QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse +}; +use reqwest::{StatusCode, Url}; +use serde::{de::DeserializeOwned, Serialize}; -const PUT_MAX_RETRIES: usize = 3; -const PUT_RETRY_DELAY_MS: u64 = 1000; +use bytes::Bytes; +use cas_object::cas_object_format::CasObject; +use cas_types::CASReconstructionTerm; +use tracing::warn; -// We have different pool sizes since the GRPC connections are shorter-lived and -// thus, not as many of them are needed. This helps reduce the impact that connection -// creation can have (which, on MacOS, can be significant (hundreds of ms)). -const H2_TRANSPORT_POOL_SIZE: usize = 16; +use crate::{error::Result, CasClientError}; -type DataTransportPoolMap = cas_connection_pool::ConnectionPoolMap; - -// Apply an id for instrumentation when new connections are created to help with -// debugging / investigating performance issues related to connection creation. -lazy_static::lazy_static! { - static ref GRPC_CLIENT_ID: AtomicUsize = AtomicUsize::new(0); - static ref H2_CLIENT_ID: AtomicUsize = AtomicUsize::new(0); -} +use merklehash::MerkleHash; +use tracing::debug; -#[async_trait] -impl FromConnectionConfig for DataTransport { - async fn new_from_connection_config(config: CasConnectionConfig) -> Result { - let id = H2_CLIENT_ID.fetch_add(1, Ordering::SeqCst); - Ok(DataTransport::from_config(config) - .instrument(info_span!("transport.connect", id)) - .await?) - } -} +use crate::Client; -#[async_trait] -impl FromConnectionConfig for GrpcClient { - async fn new_from_connection_config(config: CasConnectionConfig) -> Result { - let id = GRPC_CLIENT_ID.fetch_add(1, Ordering::SeqCst); - Ok(GrpcClient::from_config(config) - .instrument(info_span!("grpc.connect", id)) - .await?) - } -} +pub const CAS_ENDPOINT: &str = "localhost:8080"; +pub const SCHEME: &str = "http:/"; +pub const PREFIX_DEFAULT: &str = "default"; -/// CAS Remote client. This negotiates between the control plane (gRPC) -/// and data plane (HTTP) to optimize the uploads and fetches according to -/// the network, file size, and other dynamic qualities. #[derive(Debug)] pub struct RemoteClient { - lb_endpoint: String, - user_id: String, - auth: String, - repo_paths: Vec, - grpc_connection_map: Arc>>, - dt_connection_map: DataTransportPoolMap, - length_singleflight: singleflight::Group, - length_cache: Arc>>, - git_xet_version: String, -} - -// DTO's for organization moving around endpoint info -#[derive(Clone)] -struct InitiateResponseEndpointInfo { endpoint: String, - root_ca: String, + client: CASAPIClient, } -#[derive(Clone)] -struct InitiateResponseEndpoints { - h2: InitiateResponseEndpointInfo, - put_complete: InitiateResponseEndpointInfo, - accepted_encodings: Vec, -} +// TODO: add retries +#[async_trait::async_trait] +impl Client for RemoteClient { + async fn put( + &self, + prefix: &str, + hash: &MerkleHash, + data: Vec, + chunk_boundaries: Vec, + ) -> Result<()> { + let key = Key { + prefix: prefix.to_string(), + hash: *hash, + }; -impl RemoteClient { - pub fn new( - lb_endpoint: String, - user_id: String, - auth: String, - repo_paths: Vec, - grpc_connection_map: Mutex>, - dt_connection_map: DataTransportPoolMap, - git_xet_version: String, - ) -> Self { - Self { - lb_endpoint, - user_id, - auth, - repo_paths, - grpc_connection_map: Arc::new(grpc_connection_map), - dt_connection_map, - length_singleflight: singleflight::Group::new(), - length_cache: Arc::new(Mutex::new(HashMap::new())), - git_xet_version, + let was_uploaded = self.client.upload(&key, data, chunk_boundaries).await?; + + if !was_uploaded { + debug!("{key:?} not inserted into CAS."); + } else { + debug!("{key:?} inserted into CAS."); } - } - pub async fn from_config( - endpoint: &str, - user_id: &str, - auth: &str, - repo_paths: Vec, - git_xet_version: String, - ) -> Self { - // optionally switch between a CAS and a local server running on CAS_GRPC_PORT and - // CAS_HTTP_PORT - Self::new( - endpoint.to_string(), - String::from(user_id), - String::from(auth), - repo_paths, - Mutex::new(HashMap::new()), - cas_connection_pool::ConnectionPoolMap::new_with_pool_size(H2_TRANSPORT_POOL_SIZE), - git_xet_version, - ) + Ok(()) } - /// utility to generate connection config for an endpoint and other owned information - /// currently only other owned info is `user_id` - fn get_cas_connection_config_for_endpoint(&self, endpoint: String) -> CasConnectionConfig { - CasConnectionConfig::new( - endpoint, - self.user_id.clone(), - self.auth.clone(), - self.repo_paths.clone(), - self.git_xet_version.clone(), - ) + async fn flush(&self) -> Result<()> { + Ok(()) } - async fn get_grpc_connection_for_config( - &self, - cas_connection_config: CasConnectionConfig, - ) -> Result { - Self::get_grpc_connection_for_config_from_map( - self.grpc_connection_map.clone(), - cas_connection_config, - ) - .await + async fn get(&self, _prefix: &str, _hash: &merklehash::MerkleHash) -> Result> { + Err(CasClientError::InvalidArguments) } - /// makes an initiate call to the ALB endpoint and returns - /// a tuple of 2 strings, the first being the http direct endpoint - /// and the second is the grpc direct endpoint - async fn initiate_cas_server_query( + async fn get_object_range( &self, - prefix: &str, - hash: &MerkleHash, - len: usize, - ) -> Result { - let cas_connection_config = - self.get_cas_connection_config_for_endpoint(self.lb_endpoint.clone()); - let lb_grpc_client = self - .get_grpc_connection_for_config(cas_connection_config) - .await?; - - let EndpointsInfo { - data_plane_endpoint, - put_complete_endpoint, - accepted_encodings, - } = lb_grpc_client.initiate(prefix, hash, len).await?; - drop(lb_grpc_client); - - debug!("cas initiate response; data plane endpoint: {data_plane_endpoint}; put complete endpoint: {put_complete_endpoint}"); - - Ok(InitiateResponseEndpoints { - h2: InitiateResponseEndpointInfo { - endpoint: data_plane_endpoint.to_string(), - root_ca: data_plane_endpoint.root_ca_certificate, - }, - put_complete: InitiateResponseEndpointInfo { - endpoint: put_complete_endpoint.to_string(), - root_ca: put_complete_endpoint.root_ca_certificate, - }, - accepted_encodings, - }) + _prefix: &str, + _hash: &merklehash::MerkleHash, + _ranges: Vec<(u64, u64)>, + ) -> Result>> { + Err(CasClientError::InvalidArguments) } - async fn put_impl_h2( - &self, - prefix: &str, - hash: &MerkleHash, - data: &[u8], - chunk_boundaries: &[u64], - ) -> Result<()> { - debug!("H2 Put executed with {} {}", prefix, hash); - let InitiateResponseEndpoints { - h2, - put_complete, - accepted_encodings, - } = self - .initiate_cas_server_query(prefix, hash, data.len()) - .instrument(debug_span!("remote_client.initiate")) - .await?; - - let encoding = choose_encoding(accepted_encodings); - - debug!("H2 Put initiate response h2 endpoint: {}, put complete endpoint {}\nh2 cert: {}, put complete cert {}", h2.endpoint, put_complete.endpoint, h2.root_ca, put_complete.root_ca); - - { - // separate scoped to drop transport so that the connection can be reclaimed by the pool - let transport = self - .dt_connection_map - .get_connection_for_config( - self.get_cas_connection_config_for_endpoint(h2.endpoint) - .with_root_ca(h2.root_ca), - ) - .await?; - transport - .put(prefix, hash, data, encoding) - .instrument(debug_span!("remote_client.put_h2")) - .await?; + async fn get_length(&self, prefix: &str, hash: &merklehash::MerkleHash) -> Result { + let key = Key { + prefix: prefix.to_string(), + hash: *hash, + }; + match self.client.get_length(&key).await? { + Some(length) => Ok(length), + None => Err(CasClientError::XORBNotFound(*hash)), } - - debug!("Data transport completed"); - - let cas_connection_config = self - .get_cas_connection_config_for_endpoint(put_complete.endpoint) - .with_root_ca(put_complete.root_ca); - let grpc_client = self - .get_grpc_connection_for_config(cas_connection_config) - .await?; - - debug!( - "Received grpc connection from pool: {}", - grpc_client.endpoint - ); - - grpc_client - .put_complete(prefix, hash, chunk_boundaries) - .await } +} - // default implementation, parallel unary - #[allow(dead_code)] - async fn put_impl_unary( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<()> { - debug!("Unary Put executed with {} {}", prefix, hash); +impl RemoteClient { + pub async fn from_config(endpoint: String) -> Self { + Self { endpoint, client: CASAPIClient::default() } + } +} - let cas_connection_config = - self.get_cas_connection_config_for_endpoint(self.lb_endpoint.clone()); - let grpc_client = self - .get_grpc_connection_for_config(cas_connection_config) - .await?; +#[derive(Debug)] +pub struct CASAPIClient { + client: reqwest::Client, + scheme: String, + endpoint: String, +} - grpc_client.put(prefix, hash, data, chunk_boundaries).await +impl Default for CASAPIClient { + fn default() -> Self { + Self::new(SCHEME, CAS_ENDPOINT) } +} - // Default implementation, parallel unary - #[allow(dead_code)] - async fn get_impl_unary(&self, prefix: &str, hash: &MerkleHash) -> Result> { - let cas_connection_config = - self.get_cas_connection_config_for_endpoint(self.lb_endpoint.clone()); - let grpc_client = self - .get_grpc_connection_for_config(cas_connection_config) - .await?; +impl CASAPIClient { + pub fn new(scheme: &str, endpoint: &str) -> Self { + let client = reqwest::Client::builder() + .http2_prior_knowledge() + .build() + .unwrap(); + Self { client, scheme: scheme.to_string(), endpoint: endpoint.to_string() } + } - grpc_client.get(prefix, hash).await + pub async fn exists(&self, key: &Key) -> Result { + let url = Url::parse(&format!("{0}/{1}/xorb/{key}", self.scheme, self.endpoint))?; + let response = self.client.head(url).send().await?; + match response.status() { + StatusCode::OK => Ok(true), + StatusCode::NOT_FOUND => Ok(false), + e => Err(CasClientError::InternalError(anyhow!( + "unrecognized status code {e}" + ))), + } } - async fn get_impl_h2(&self, prefix: &str, hash: &MerkleHash) -> Result> { - debug!("H2 Get executed with {} {}", prefix, hash); - - let InitiateResponseEndpoints { h2, .. } = self - .initiate_cas_server_query(prefix, hash, 0) - .instrument(debug_span!("remote_client.initiate")) - .await?; - - let cas_connection_config = self - .get_cas_connection_config_for_endpoint(h2.endpoint) - .with_root_ca(h2.root_ca); - let transport = self - .dt_connection_map - .get_connection_for_config(cas_connection_config) - .instrument(debug_span!("remote_client.get_transport_connection")) - .await?; - let data = transport - .get(prefix, hash) - .instrument(debug_span!("remote_client.h2_get")) - .await?; - drop(transport); - - debug!("Data transport completed"); - Ok(data) + pub async fn get_length(&self, key: &Key) -> Result> { + let url = Url::parse(&format!("{0}/{1}/xorb/{key}", self.scheme, self.endpoint))?; + let response = self.client.head(url).send().await?; + let status = response.status(); + if status == StatusCode::NOT_FOUND { + return Ok(None); + } + if status != StatusCode::OK { + return Err(CasClientError::InternalError(anyhow!( + "unrecognized status code {status}" + ))); + } + let hv = match response.headers().get("Content-Length") { + Some(hv) => hv, + None => { + return Err(CasClientError::InternalError(anyhow!( + "HEAD missing content length header" + ))) + } + }; + let length: u64 = hv + .to_str() + .map_err(|_| { + CasClientError::InternalError(anyhow!("HEAD missing content length header")) + })? + .parse() + .map_err(|_| CasClientError::InternalError(anyhow!("failed to parse length")))?; + + Ok(Some(length)) } - // Default implementation, parallel unary - #[allow(dead_code)] - async fn get_object_range_impl_unary( + pub async fn upload>( &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - debug!("Unary GetRange executed with {} {}", prefix, hash); + key: &Key, + contents: T, + chunk_boundaries: Vec, + ) -> Result { + let chunk_boundaries_query = chunk_boundaries + .iter() + .map(|num| num.to_string()) + .collect::>() + .join(","); + let url = Url::parse(&format!("{0}/{1}/xorb/{key}?{chunk_boundaries_query}", self.scheme, self.endpoint))?; - let cas_connection_config = - self.get_cas_connection_config_for_endpoint(self.lb_endpoint.clone()); - let grpc_client = self - .get_grpc_connection_for_config(cas_connection_config) - .await?; + debug!("Upload: POST to {url:?} for {key:?}"); - grpc_client.get_object_range(prefix, hash, ranges).await - } + let response = self.client.post(url).body(contents.into()).send().await?; + let response_body = response.bytes().await?; + let response_parsed: UploadXorbResponse = serde_json::from_reader(response_body.reader())?; - async fn get_object_range_impl_h2( - &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - debug!("H2 GetRange executed with {} {}", prefix, hash); - - let InitiateResponseEndpoints { h2, .. } = self - .initiate_cas_server_query(prefix, hash, 0) - .instrument(debug_span!("remote_client.initiate")) - .await?; - - let cas_connection_config = self - .get_cas_connection_config_for_endpoint(h2.endpoint) - .with_root_ca(h2.root_ca); - let transport = self - .dt_connection_map - .get_connection_for_config(cas_connection_config) - .await?; - - let mut handlers = Vec::new(); - for range in ranges { - handlers.push(transport.get_range(prefix, hash, range)); - } - let results = futures::future::join_all(handlers).await; - let errors: Vec = results - .iter() - .filter_map(|r| r.as_deref().err().map(|s| s.to_string())) - .collect(); - if !errors.is_empty() { - let error_description: String = errors.join("-"); - Err(CasClientError::BatchError(error_description))?; - } - let data = results - .into_iter() - // unwrap is safe since we verified in the above if that no elements have an error - .map(|r| r.unwrap()) - .collect_vec(); - Ok(data) - } -} + Ok(response_parsed.was_inserted) + } -fn choose_encoding(accepted_encodings: Vec) -> CompressionScheme { - if accepted_encodings.is_empty() { - return CompressionScheme::None; - } - if accepted_encodings.contains(&CompressionScheme::Lz4) { - return CompressionScheme::Lz4; - } - CompressionScheme::None -} + /// Reconstruct a file and write to writer. + pub async fn write_file(&self, file_id: &MerkleHash, writer: &mut W) -> Result { -fn cas_client_error_retriable(err: &CasClientError) -> bool { - // we do not retry the logical errors - !matches!( - err, - CasClientError::InvalidRange - | CasClientError::InvalidArguments - | CasClientError::HashMismatch - ) -} + // get manifest of xorbs to download + let manifest = self.reconstruct_file(file_id).await?; -#[async_trait] -impl Client for RemoteClient { - async fn put( - &self, - prefix: &str, - hash: &MerkleHash, - data: Vec, - chunk_boundaries: Vec, - ) -> Result<()> { - // We first check if the block already exists, to avoid an unnecessary upload - if let Ok(xorb_size) = self.get_length(prefix, hash).await { - if xorb_size > 0 { - return Ok(()); - } - } - // We could potentially narrow down the error conditions - // further, but that gets complicated. - // So we just do something pretty coarse-grained - let strategy = RetryStrategy::new(PUT_MAX_RETRIES, PUT_RETRY_DELAY_MS); - let res = strategy - .retry( - || async { - self.put_impl_h2(prefix, hash, &data, &chunk_boundaries) - .await - }, - |e| { - let retry = cas_client_error_retriable(e); - if retry { - info!("Put error {:?}. Retrying...", e); - } - retry - }, - ) - .await; - - if let Err(ref e) = res { - if cas_client_error_retriable(e) { - error!("Too many failures writing {:?}: {:?}.", hash, e); - } - } - res + self.reconstruct(manifest, writer).await } - async fn flush(&self) -> Result<()> { - // this client does not background so no flush is needed - Ok(()) + async fn reconstruct(&self, reconstruction_response: QueryReconstructionResponse, writer: &mut W) -> Result { + let info = reconstruction_response.reconstruction; + let total_len = info.iter().fold(0, |acc, x| acc + x.unpacked_length); + let futs = info.into_iter().map(|term| { + tokio::spawn(async move { + let piece = get_one(&term).await?; + if piece.len() != (term.range.end - term.range.start) as usize { + warn!("got back a smaller range than requested"); + } + Result::::Ok(piece) + }) + }); + for fut in futs { + let piece = fut + .await + .map_err(|e| CasClientError::InternalError(anyhow!("join error {e}")))??; + writer.write_all(&piece)?; + } + Ok(total_len as usize) } - async fn get(&self, prefix: &str, hash: &MerkleHash) -> Result> { - self.get_impl_h2(prefix, hash).await + /// Reconstruct the file + async fn reconstruct_file(&self, file_id: &MerkleHash) -> Result { + let url = Url::parse(&format!("{0}/{1}/reconstruction/{2}", self.scheme, self.endpoint, file_id.hex()))?; + + let response = self.client.get(url).send().await?; + let response_body = response.bytes().await?; + let response_parsed: QueryReconstructionResponse = serde_json::from_reader(response_body.reader())?; + + Ok(response_parsed) } - async fn get_object_range( + pub async fn shard_query_chunk( &self, - prefix: &str, - hash: &MerkleHash, - ranges: Vec<(u64, u64)>, - ) -> Result>> { - self.get_object_range_impl_h2(prefix, hash, ranges).await + key: &Key, + ) -> Result { + let url = Url::parse(&format!("{0}/{1}/chunk/{key}", self.scheme, self.endpoint))?; + let response = self.client.get(url).send().await?; + let response_body = response.bytes().await?; + let response_parsed: QueryChunkResponse = serde_json::from_reader(response_body.reader())?; + + Ok(response_parsed) } - async fn get_length(&self, prefix: &str, hash: &MerkleHash) -> Result { - let key = format!("{}:{}", prefix, hash.hex()); + async fn post_json(&self, url: Url, request_body: &ReqT) -> Result + where + ReqT: Serialize, + RespT: DeserializeOwned, + { + let body = serde_json::to_vec(request_body)?; + let response = self.client.post(url).body(body).send().await?; + let response_bytes = response.bytes().await?; + serde_json::from_reader(response_bytes.reader()).map_err(CasClientError::SerdeError) + } +} - let cache = self.length_cache.clone(); +async fn get_one(term: &CASReconstructionTerm) -> Result { + let url = Url::parse(term.url.as_str())?; + let response = reqwest::Client::new() + .request(hyper::Method::GET, url) + .send() + .await? + .error_for_status()?; + let xorb_bytes = response + .bytes() + .await + .map_err(CasClientError::ReqwestError)?; + let mut readseek = Cursor::new(xorb_bytes.to_vec()); - // See if it's in the cache first before we try to launch it; this is cheap. - { - let cache = cache.lock().await; - if let Some(v) = cache.get(&key) { - return Ok(*v); - } - } - let cas_connection_config = - self.get_cas_connection_config_for_endpoint(self.lb_endpoint.clone()); - let connection_map = self.grpc_connection_map.clone(); - - let (res, _dedup) = self - .length_singleflight - .work( - &key, - Self::get_length_from_remote( - connection_map, - cas_connection_config, - cache, - prefix.to_string(), - *hash, - ), - ) - .await; - - return match res { - Ok(v) => Ok(v), - Err(singleflight::SingleflightError::InternalError(e)) => Err(e), - Err(e) => Err(CasClientError::InternalError(anyhow::Error::from(e))), - }; - } + let cas_object = CasObject::deserialize(&mut readseek)?; + let data = cas_object.get_range(&mut readseek, term.range.start as u32, term.range.end as u32)?; + + Ok(Bytes::from(data)) } -// static functions that can be used in spawned tasks -impl RemoteClient { - async fn get_length_from_remote( - connection_map: Arc>>, - cas_connection_config: CasConnectionConfig, - cache: Arc>>, - prefix: String, - hash: MerkleHash, - ) -> Result { - let key = format!("{}:{}", prefix, hash.hex()); - { - let cache = cache.lock().await; - if let Some(v) = cache.get(&key) { - return Ok(*v); - } - } +#[cfg(test)] +mod tests { - let grpc_client = - Self::get_grpc_connection_for_config_from_map(connection_map, cas_connection_config) - .await?; + use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB}; + use merklehash::DataHash; + use rand::Rng; + use tracing_test::traced_test; - debug!("RemoteClient: GetLength of {}/{}", prefix, hash); + use super::*; - let res = grpc_client.get_length(&prefix, &hash).await?; + #[ignore] + #[traced_test] + #[tokio::test] + async fn test_basic_put() { + // Arrange + let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string()).await; + let prefix = PREFIX_DEFAULT; + let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true); - debug!( - "RemoteClient: GetLength of {}/{} request complete", - prefix, hash - ); + // Act + let result = rc.put(prefix, &hash, data, chunk_boundaries).await; - // See if it's in the cache - { - let mut cache = cache.lock().await; - let _ = cache.insert(key.clone(), res); + // Assert + assert!(result.is_ok()); + } + + fn gen_dummy_xorb(num_chunks: u32, uncompressed_chunk_size: u32, randomize_chunk_sizes: bool) -> (DataHash, Vec, Vec) { + let mut contents = Vec::new(); + let mut chunks: Vec = Vec::new(); + let mut chunk_boundaries = Vec::with_capacity(num_chunks as usize); + for _idx in 0..num_chunks { + + let chunk_size: u32 = if randomize_chunk_sizes { + let mut rng = rand::thread_rng(); + rng.gen_range(1024..=uncompressed_chunk_size) + } else { + uncompressed_chunk_size + }; + + let bytes = gen_random_bytes(chunk_size); + + chunks.push(Chunk { hash: merklehash::compute_data_hash(&bytes), length: bytes.len() }); + + contents.extend(bytes); + chunk_boundaries.push(contents.len() as u64); } - Ok(res) + let mut db = MerkleMemDB::default(); + let mut staging = db.start_insertion_staging(); + db.add_file(&mut staging, &chunks); + let ret = db.finalize(staging); + let hash = *ret.hash(); + + (hash, contents, chunk_boundaries) } - async fn get_grpc_connection_for_config_from_map( - grpc_connection_map: Arc>>, - cas_connection_config: CasConnectionConfig, - ) -> Result { - let mut map = grpc_connection_map.lock().await; - if let Some(client) = map.get(&cas_connection_config.endpoint) { - return Ok(client.clone()); - } - // yes the lock is held through to endpoint creation. - // While strictly by locking patterns we should release the - // lock here, create the client, then re-acquire the lock to insert - // into the map, in practice *thousands* of threads could call this - // method simultaneously leading to a "race" where we create - // thousands of connections. - // - // Really we need to "single-flight" connection creation per endpoint. - // Since each RemoteClient really connects to only 1 endpoint, - // just locking the whole method here pretty much does what we need. - let endpoint = cas_connection_config.endpoint.clone(); - let new_client = GrpcClient::new_from_connection_config(cas_connection_config).await?; - map.insert(endpoint, new_client.clone()); - Ok(new_client) + fn gen_random_bytes(uncompressed_chunk_size: u32) -> Vec { + let mut rng = rand::thread_rng(); + let mut data = vec![0u8; uncompressed_chunk_size as usize]; + rng.fill(&mut data[..]); + data } -} +} \ No newline at end of file diff --git a/cas_client/src/util.rs b/cas_client/src/util.rs index 3a25bc38..d70047a4 100644 --- a/cas_client/src/util.rs +++ b/cas_client/src/util.rs @@ -14,8 +14,6 @@ pub(crate) mod grpc_mock { use tonic::{Request, Response, Status}; use crate::cas_connection_pool::CasConnectionConfig; - use crate::grpc::get_client; - use crate::grpc::GrpcClient; use cas::cas::cas_server::{Cas, CasServer}; use cas::cas::{ GetRangeRequest, GetRangeResponse, GetRequest, GetResponse, HeadRequest, HeadResponse, @@ -101,6 +99,7 @@ pub(crate) mod grpc_mock { } } + /* pub async fn start(self) -> (ShutdownHook, GrpcClient) { self.start_with_retry_strategy(RetryStrategy::new(2, 1)) .await @@ -140,6 +139,8 @@ pub(crate) mod grpc_mock { let client = GrpcClient::new("127.0.0.1".to_string(), cas_client, strategy); (shutdown_hook, client) } + + */ } // Unsafe hacks so that we can dynamically add in overrides to the mock functionality diff --git a/cas_object/src/cas_object_format.rs b/cas_object/src/cas_object_format.rs index 5dac2e71..71f3b82a 100644 --- a/cas_object/src/cas_object_format.rs +++ b/cas_object/src/cas_object_format.rs @@ -318,7 +318,7 @@ mod tests { use super::*; use merklehash::compute_data_hash; use rand::Rng; - use std::io::{Cursor}; + use std::io::Cursor; #[test] fn test_default_header_initialization() { diff --git a/data/src/cas_interface.rs b/data/src/cas_interface.rs index ed9938c5..e1158e2f 100644 --- a/data/src/cas_interface.rs +++ b/data/src/cas_interface.rs @@ -1,6 +1,6 @@ use super::configurations::{Endpoint::*, RepoInfo, StorageConfig}; use super::errors::Result; -use crate::constants::{MAX_CONCURRENT_DOWNLOADS, XET_VERSION}; +use crate::constants::MAX_CONCURRENT_DOWNLOADS; use crate::metrics::FILTER_BYTES_SMUDGED; use cas_client::{new_staging_client, CachingClient, LocalClient, RemoteClient, Staging}; use futures::prelude::stream::*; @@ -37,11 +37,11 @@ pub(crate) async fn create_cas_client( }; // Auth info. - let user_id = &cas_storage_config.auth.user_id; - let auth = &cas_storage_config.auth.login_id; + let _user_id = &cas_storage_config.auth.user_id; + let _auth = &cas_storage_config.auth.login_id; // Usage tracking. - let repo_paths = maybe_repo_info + let _repo_paths = maybe_repo_info .as_ref() .map(|repo_info| &repo_info.repo_paths) .cloned() @@ -49,7 +49,7 @@ pub(crate) async fn create_cas_client( // Raw remote client. let remote_client = Arc::new( - RemoteClient::from_config(endpoint, user_id, auth, repo_paths, XET_VERSION.clone()).await, + RemoteClient::from_config(endpoint.to_string()).await, ); // Try add in caching capability.