Skip to content

Commit

Permalink
feat: add auth with environment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
asr2003 authored Dec 20, 2024
1 parent 069bbc4 commit 9746166
Showing 1 changed file with 104 additions and 7 deletions.
111 changes: 104 additions & 7 deletions nativelink-store/src/gcs_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::borrow::Cow;
use std::env;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -36,8 +37,10 @@ use nativelink_util::retry::{Retrier, RetryResult};
use nativelink_util::store_trait::{StoreDriver, StoreKey, UploadSizeInfo};
use rand::rngs::OsRng;
use rand::Rng;
use tokio::time::sleep;
use tokio::time::{sleep, Instant};
use tonic::metadata::MetadataValue;
use tonic::transport::Channel;
use tonic::{Request, Status};

// use tracing::{event, Level};
use crate::cas_utils::is_zero_digest;
Expand Down Expand Up @@ -66,6 +69,66 @@ These differences emphasize the need for tailored approaches to storage backends
// Note: If you change this, adjust the docs in the config.
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024;

/// `CredentialProvider` manages the authentication token required for GCS operations.
/// It fetches tokens dynamically and ensures they are refreshed periodically.
pub struct CredentialProvider {
token: Arc<tokio::sync::Mutex<(String, Instant)>>,
}

impl CredentialProvider {
async fn new() -> Result<Self, Error> {
let token = Self::fetch_gcs_token().await?;
let expiry = Instant::now() + Duration::from_secs(3600); // Default expiry duration
Ok(Self {
token: Arc::new(tokio::sync::Mutex::new((token, expiry))),
})
}

/// Fetch a new token if the current token is expired.
pub async fn get_token(&self) -> Result<String, Error> {
let mut lock = self.token.lock().await;

// Refresh the token if it has expired
if Instant::now() >= lock.1 {
lock.0 = Self::fetch_gcs_token().await?;
lock.1 = Instant::now() + Duration::from_secs(3600);
}
Ok(lock.0.clone())
}

// /// Starts a background task to refresh the token periodically.
// /// This ensures the token remains valid for ongoing operations.
// pub async fn start_token_refresh(self: Arc<Self>) {
// nativelink_util::background_spawn(async move {
// loop {
// {
// let mut lock = self.token.lock().await;
// lock.0 = Self::fetch_gcs_token()
// .await
// .unwrap_or_else(|_| String::new());
// lock.1 = Instant::now() + Duration::from_secs(3600);
// }
// sleep(Duration::from_secs(3500)).await; // Refresh before expiry
// }
// });
// }

pub fn get_token_sync(&self) -> String {
// Returns the current token synchronously
let lock = self.token.blocking_lock();
lock.0.clone()
}

/// Fetches a GCS token using either an environment variable or the `gcloud` CLI.
async fn fetch_gcs_token() -> Result<String, Error> {
if let Ok(token) = env::var("GCS_AUTH_TOKEN") {
Ok(token.trim().to_string())
} else {
Err(make_err!(Code::Unavailable, "GCS_AUTH_TOKEN not found"))
}
}
}

#[derive(MetricsComponent)]
pub struct GCSStore<NowFn> {
// The gRPC client for GCS
Expand All @@ -84,13 +147,15 @@ pub struct GCSStore<NowFn> {
resumable_chunk_size: usize,
#[metric(help = "The number of concurrent uploads allowed for resumable uploads")]
resumable_max_concurrent_uploads: usize,
credential_provider: Arc<CredentialProvider>,
}

impl<I, NowFn> GCSStore<NowFn>
where
I: InstantWrapper,
NowFn: Fn() -> I + Send + Sync + Unpin + 'static,
{
/// Create a GCS client with an interceptor for dynamic token injection.
pub async fn new(spec: &GCSSpec, now_fn: NowFn) -> Result<Arc<Self>, Error> {
let jitter_amt = spec.retry.jitter;
let jitter_fn = Arc::new(move |delay: Duration| {
Expand All @@ -102,22 +167,28 @@ where
delay.mul_f32(OsRng.gen_range(min..max))
});

let channel = tonic::transport::Channel::from_static("https://storage.googleapis.com")
let endpoint = env::var("GCS_ENDPOINT")
.unwrap_or_else(|_| "https://storage.googleapis.com".to_string());
let channel = tonic::transport::Channel::from_shared(endpoint)
.map_err(|e| make_err!(Code::InvalidArgument, "Invalid GCS endpoint: {e:?}"))?
.connect()
.await
.map_err(|e| make_err!(Code::Unavailable, "Failed to connect to GCS: {e:?}"))?;

let gcs_client = StorageClient::new(channel);
let credential_provider = Arc::new(CredentialProvider::new().await?);

Self::new_with_client_and_jitter(spec, gcs_client, jitter_fn, now_fn)
Self::new_with_client_and_jitter(spec, channel, credential_provider, jitter_fn, now_fn)
}

pub fn new_with_client_and_jitter(
spec: &GCSSpec,
gcs_client: StorageClient<tonic::transport::Channel>,
channel: Channel,
credential_provider: Arc<CredentialProvider>,
jitter_fn: Arc<dyn Fn(Duration) -> Duration + Send + Sync>,
now_fn: NowFn,
) -> Result<Arc<Self>, Error> {
let gcs_client = StorageClient::new(channel);

Ok(Arc::new(Self {
gcs_client: Arc::new(gcs_client),
now_fn,
Expand All @@ -136,29 +207,55 @@ where
.resumable_chunk_size
.unwrap_or(DEFAULT_CHUNK_SIZE as usize),
resumable_max_concurrent_uploads: 0,
credential_provider,
}))
}

fn make_gcs_path(&self, key: &StoreKey<'_>) -> String {
format!("{}{}", self.key_prefix, key.as_str())
}

async fn inject_auth<F>(&self, mut request: Request<F>) -> Result<Request<F>, Status> {
let token = self
.credential_provider
.get_token()
.await
.map_err(|_| Status::unauthenticated("Failed to retrieve auth token"))?;

let auth_header = format!("Bearer {token}");
request.metadata_mut().insert(
"authorization",
MetadataValue::try_from(auth_header).unwrap(),
);
Ok(request)
}

/// Check if the object exists and is not expired
pub async fn has(self: Pin<&Self>, digest: &StoreKey<'_>) -> Result<Option<u64>, Error> {
let client = Arc::clone(&self.gcs_client);

self.retrier
.retry(unfold((), move |state| {
let mut client = (*client).clone();

async move {
let object_path = self.make_gcs_path(digest);
let request = ReadObjectRequest {
let raw_request = ReadObjectRequest {
bucket: self.bucket.clone(),
object: object_path.clone(),
..Default::default()
};

let result = client.read_object(request).await;
let Ok(authenticated_request) =
self.inject_auth(Request::new(raw_request)).await
else {
return Some((
RetryResult::Err(make_err!(Code::Unauthenticated, "Auth failed")),
state,
));
};

let result = client.read_object(authenticated_request).await;

match result {
Ok(response) => {
Expand Down

0 comments on commit 9746166

Please sign in to comment.