From 97461667ced8bc119a5056a7c9f31c8f2f064685 Mon Sep 17 00:00:00 2001 From: asr2003 <162500856+asr2003@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:41:57 +0530 Subject: [PATCH] feat: add auth with environment variable --- nativelink-store/src/gcs_store.rs | 111 ++++++++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 7 deletions(-) diff --git a/nativelink-store/src/gcs_store.rs b/nativelink-store/src/gcs_store.rs index 08230cbc5..fa0d7a9bc 100644 --- a/nativelink-store/src/gcs_store.rs +++ b/nativelink-store/src/gcs_store.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::borrow::Cow; +use std::env; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -36,8 +37,10 @@ use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::store_trait::{StoreDriver, StoreKey, UploadSizeInfo}; use rand::rngs::OsRng; use rand::Rng; -use tokio::time::sleep; +use tokio::time::{sleep, Instant}; +use tonic::metadata::MetadataValue; use tonic::transport::Channel; +use tonic::{Request, Status}; // use tracing::{event, Level}; use crate::cas_utils::is_zero_digest; @@ -66,6 +69,66 @@ These differences emphasize the need for tailored approaches to storage backends // Note: If you change this, adjust the docs in the config. const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; +/// `CredentialProvider` manages the authentication token required for GCS operations. +/// It fetches tokens dynamically and ensures they are refreshed periodically. +pub struct CredentialProvider { + token: Arc>, +} + +impl CredentialProvider { + async fn new() -> Result { + let token = Self::fetch_gcs_token().await?; + let expiry = Instant::now() + Duration::from_secs(3600); // Default expiry duration + Ok(Self { + token: Arc::new(tokio::sync::Mutex::new((token, expiry))), + }) + } + + /// Fetch a new token if the current token is expired. 
+ pub async fn get_token(&self) -> Result { + let mut lock = self.token.lock().await; + + // Refresh the token if it has expired + if Instant::now() >= lock.1 { + lock.0 = Self::fetch_gcs_token().await?; + lock.1 = Instant::now() + Duration::from_secs(3600); + } + Ok(lock.0.clone()) + } + + // /// Starts a background task to refresh the token periodically. + // /// This ensures the token remains valid for ongoing operations. + // pub async fn start_token_refresh(self: Arc) { + // nativelink_util::background_spawn(async move { + // loop { + // { + // let mut lock = self.token.lock().await; + // lock.0 = Self::fetch_gcs_token() + // .await + // .unwrap_or_else(|_| String::new()); + // lock.1 = Instant::now() + Duration::from_secs(3600); + // } + // sleep(Duration::from_secs(3500)).await; // Refresh before expiry + // } + // }); + // } + + pub fn get_token_sync(&self) -> String { + // Returns the cached token synchronously without refreshing it; `blocking_lock` panics if called from an async context + let lock = self.token.blocking_lock(); + lock.0.clone() + } + + /// Fetches a GCS token from the `GCS_AUTH_TOKEN` environment variable (no `gcloud` CLI fallback is implemented). + async fn fetch_gcs_token() -> Result { + if let Ok(token) = env::var("GCS_AUTH_TOKEN") { + Ok(token.trim().to_string()) + } else { + Err(make_err!(Code::Unavailable, "GCS_AUTH_TOKEN not found")) + } + } +} + #[derive(MetricsComponent)] pub struct GCSStore { // The gRPC client for GCS @@ -84,6 +147,7 @@ pub struct GCSStore { resumable_chunk_size: usize, #[metric(help = "The number of concurrent uploads allowed for resumable uploads")] resumable_max_concurrent_uploads: usize, + credential_provider: Arc, } impl GCSStore @@ -91,6 +155,7 @@ where I: InstantWrapper, NowFn: Fn() -> I + Send + Sync + Unpin + 'static, { + /// Create a GCS store with a `CredentialProvider`; no interceptor is installed, so callers attach the token per request via `inject_auth`. 
pub async fn new(spec: &GCSSpec, now_fn: NowFn) -> Result, Error> { let jitter_amt = spec.retry.jitter; let jitter_fn = Arc::new(move |delay: Duration| { @@ -102,22 +167,28 @@ where delay.mul_f32(OsRng.gen_range(min..max)) }); - let channel = tonic::transport::Channel::from_static("https://storage.googleapis.com") + let endpoint = env::var("GCS_ENDPOINT") + .unwrap_or_else(|_| "https://storage.googleapis.com".to_string()); + let channel = tonic::transport::Channel::from_shared(endpoint) + .map_err(|e| make_err!(Code::InvalidArgument, "Invalid GCS endpoint: {e:?}"))? .connect() .await .map_err(|e| make_err!(Code::Unavailable, "Failed to connect to GCS: {e:?}"))?; - let gcs_client = StorageClient::new(channel); + let credential_provider = Arc::new(CredentialProvider::new().await?); - Self::new_with_client_and_jitter(spec, gcs_client, jitter_fn, now_fn) + Self::new_with_client_and_jitter(spec, channel, credential_provider, jitter_fn, now_fn) } pub fn new_with_client_and_jitter( spec: &GCSSpec, - gcs_client: StorageClient, + channel: Channel, + credential_provider: Arc, jitter_fn: Arc Duration + Send + Sync>, now_fn: NowFn, ) -> Result, Error> { + let gcs_client = StorageClient::new(channel); + Ok(Arc::new(Self { gcs_client: Arc::new(gcs_client), now_fn, @@ -136,6 +207,7 @@ where .resumable_chunk_size .unwrap_or(DEFAULT_CHUNK_SIZE as usize), resumable_max_concurrent_uploads: 0, + credential_provider, })) } @@ -143,6 +215,21 @@ where format!("{}{}", self.key_prefix, key.as_str()) } + async fn inject_auth(&self, mut request: Request) -> Result, Status> { + let token = self + .credential_provider + .get_token() + .await + .map_err(|_| Status::unauthenticated("Failed to retrieve auth token"))?; + + let auth_header = format!("Bearer {token}"); + request.metadata_mut().insert( + "authorization", + MetadataValue::try_from(auth_header).unwrap(), + ); + Ok(request) + } + /// Check if the object exists and is not expired pub async fn has(self: Pin<&Self>, digest: &StoreKey<'_>) 
-> Result, Error> { let client = Arc::clone(&self.gcs_client); @@ -150,15 +237,25 @@ where self.retrier .retry(unfold((), move |state| { let mut client = (*client).clone(); + async move { let object_path = self.make_gcs_path(digest); - let request = ReadObjectRequest { + let raw_request = ReadObjectRequest { bucket: self.bucket.clone(), object: object_path.clone(), ..Default::default() }; - let result = client.read_object(request).await; + let Ok(authenticated_request) = + self.inject_auth(Request::new(raw_request)).await + else { + return Some(( + RetryResult::Err(make_err!(Code::Unauthenticated, "Auth failed")), + state, + )); + }; + + let result = client.read_object(authenticated_request).await; match result { Ok(response) => {