Skip to content

Commit

Permalink
[Experimental] Move identity & origin event middleware config (#1534)
Browse files Browse the repository at this point in the history
Restructure the config slightly for better management of http
specific configs allowing them to be configured per http service.
  • Loading branch information
allada authored Dec 13, 2024
1 parent 4896b5c commit 45520d9
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 106 deletions.
17 changes: 5 additions & 12 deletions nativelink-config/src/cas_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,6 @@ pub struct BepConfig {
/// The store name referenced in the `stores` map in the main config.
#[serde(deserialize_with = "convert_string_with_shellexpand")]
pub store: StoreRefName,

/// The config related to identifying the client.
/// The value of this header will be used to identify the caller and
/// will be added to the `BuildEvent::identity` field of the message.
/// Default: {see `IdentityHeaderSpec`}
#[serde(default)]
pub experimental_identity_header: IdentityHeaderSpec,
}

#[derive(Deserialize, Clone, Debug, Default)]
Expand Down Expand Up @@ -250,11 +243,6 @@ pub struct OriginEventsSpec {
/// Default: 65536 (zero defaults to this)
#[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
pub max_event_queue_size: usize,

/// The config related to identifying the client.
/// Default: {see `IdentityHeaderSpec`}
#[serde(default)]
pub identity_header: IdentityHeaderSpec,
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -452,6 +440,11 @@ pub struct ServerConfig {

/// Services to attach to server.
pub services: Option<ServicesConfig>,

/// The config related to identifying the client.
/// Default: {see `IdentityHeaderSpec`}
#[serde(default)]
pub experimental_identity_header: IdentityHeaderSpec,
}

#[allow(non_camel_case_types)]
Expand Down
51 changes: 10 additions & 41 deletions nativelink-service/src/bep_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::pin::Pin;
use bytes::BytesMut;
use futures::stream::unfold;
use futures::Stream;
use nativelink_config::cas_server::IdentityHeaderSpec;
use nativelink_error::{Error, ResultExt};
use nativelink_proto::com::github::trace_machina::nativelink::events::{bep_event, BepEvent};
use nativelink_proto::google::devtools::build::v1::publish_build_event_server::{
Expand All @@ -29,25 +28,25 @@ use nativelink_proto::google::devtools::build::v1::{
PublishLifecycleEventRequest,
};
use nativelink_store::store_manager::StoreManager;
use nativelink_util::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY};
use nativelink_util::store_trait::{Store, StoreDriver, StoreKey, StoreLike};
use prost::Message;
use tonic::metadata::MetadataMap;
use tonic::{Request, Response, Result, Status, Streaming};
use tracing::{instrument, Level};

/// Current version of the BEP event. This might be used in the future if
/// there is a breaking change in the BEP event format.
const BEP_EVENT_VERSION: u32 = 0;

/// Default identity header name.
/// Note: If this is changed, the default value in the [`IdentityHeaderSpec`]
// TODO(allada) This has a mirror in origin_event_middleware.rs.
// We should consolidate these.
const DEFAULT_IDENTITY_HEADER: &str = "x-identity";
fn get_identity() -> Result<Option<String>, Status> {
ActiveOriginContext::get()
.map_or(Ok(None), |ctx| ctx.get_value(&ORIGIN_IDENTITY))
.err_tip(|| "In BepServer")
.map_or_else(|e| Err(e.into()), |v| Ok(v.map(|v| v.as_ref().clone())))
}

pub struct BepServer {
store: Store,
identity_header: IdentityHeaderSpec,
}

impl BepServer {
Expand All @@ -59,41 +58,13 @@ impl BepServer {
.get_store(&config.store)
.err_tip(|| format!("Expected store {} to exist in store manager", &config.store))?;

let mut identity_header = config.experimental_identity_header.clone();
if identity_header.header_name.is_none() {
identity_header.header_name = Some(DEFAULT_IDENTITY_HEADER.to_string());
}

Ok(Self {
store,
identity_header,
})
Ok(Self { store })
}

pub fn into_service(self) -> PublishBuildEventServer<BepServer> {
PublishBuildEventServer::new(self)
}

fn get_identity(&self, request_metadata: &MetadataMap) -> Result<Option<String>, Status> {
let header_name = self
.identity_header
.header_name
.as_deref()
.unwrap_or(DEFAULT_IDENTITY_HEADER);
if header_name.is_empty() {
return Ok(None);
}
let identity = request_metadata
.get(header_name)
.and_then(|header| header.to_str().ok().map(str::to_string));
if identity.is_none() && self.identity_header.required {
return Err(Status::unauthenticated(format!(
"'{header_name}' header is required"
)));
}
Ok(identity)
}

async fn inner_publish_lifecycle_event(
&self,
request: PublishLifecycleEventRequest,
Expand Down Expand Up @@ -243,8 +214,7 @@ impl PublishBuildEvent for BepServer {
&self,
grpc_request: Request<PublishLifecycleEventRequest>,
) -> Result<Response<()>, Status> {
let identity = self.get_identity(grpc_request.metadata())?;
self.inner_publish_lifecycle_event(grpc_request.into_inner(), identity)
self.inner_publish_lifecycle_event(grpc_request.into_inner(), get_identity()?)
.await
.map_err(Error::into)
}
Expand All @@ -260,8 +230,7 @@ impl PublishBuildEvent for BepServer {
&self,
grpc_request: Request<Streaming<PublishBuildToolEventStreamRequest>>,
) -> Result<Response<Self::PublishBuildToolEventStreamStream>, Status> {
let identity = self.get_identity(grpc_request.metadata())?;
self.inner_publish_build_tool_event_stream(grpc_request.into_inner(), identity)
self.inner_publish_build_tool_event_stream(grpc_request.into_inner(), get_identity()?)
.await
.map_err(Error::into)
}
Expand Down
3 changes: 1 addition & 2 deletions nativelink-service/tests/bep_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;

use futures::StreamExt;
use hyper::body::Frame;
use nativelink_config::cas_server::{BepConfig, IdentityHeaderSpec};
use nativelink_config::cas_server::BepConfig;
use nativelink_config::stores::{MemorySpec, StoreSpec};
use nativelink_error::{Error, ResultExt};
use nativelink_macro::nativelink_test;
Expand Down Expand Up @@ -69,7 +69,6 @@ fn make_bep_server(store_manager: &StoreManager) -> Result<BepServer, Error> {
BepServer::new(
&BepConfig {
store: BEP_STORE_NAME.to_string(),
experimental_identity_header: IdentityHeaderSpec::default(),
},
store_manager,
)
Expand Down
4 changes: 4 additions & 0 deletions nativelink-util/src/origin_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ macro_rules! make_symbol {
};
}

// Symbol that represents the identity of the origin of a request.
// See: IdentityHeaderSpec for details.
make_symbol!(ORIGIN_IDENTITY, String);

pub struct NLSymbol<T: Send + Sync + 'static> {
pub name: &'static str,
pub _phantom: std::marker::PhantomData<T>,
Expand Down
41 changes: 21 additions & 20 deletions nativelink-util/src/origin_event_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tower::layer::Layer;
use tower::Service;
use tracing::trace_span;

use crate::origin_context::OriginContext;
use crate::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY};
use crate::origin_event::{OriginEventCollector, ORIGIN_EVENT_COLLECTOR};

/// Default identity header name.
Expand All @@ -45,17 +45,17 @@ pub struct OriginRequestMetadata {

#[derive(Clone)]
pub struct OriginEventMiddlewareLayer {
origin_event_tx: mpsc::Sender<OriginEvent>,
maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
idenity_header_config: Arc<IdentityHeaderSpec>,
}

impl OriginEventMiddlewareLayer {
pub fn new(
origin_event_tx: mpsc::Sender<OriginEvent>,
maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
idenity_header_config: IdentityHeaderSpec,
) -> Self {
Self {
origin_event_tx,
maybe_origin_event_tx,
idenity_header_config: Arc::new(idenity_header_config),
}
}
Expand All @@ -67,7 +67,7 @@ impl<S> Layer<S> for OriginEventMiddlewareLayer {
fn layer(&self, service: S) -> Self::Service {
OriginEventMiddleware {
inner: service,
origin_event_tx: self.origin_event_tx.clone(),
maybe_origin_event_tx: self.maybe_origin_event_tx.clone(),
idenity_header_config: self.idenity_header_config.clone(),
}
}
Expand All @@ -76,7 +76,7 @@ impl<S> Layer<S> for OriginEventMiddlewareLayer {
#[derive(Clone)]
pub struct OriginEventMiddleware<S> {
inner: S,
origin_event_tx: mpsc::Sender<OriginEvent>,
maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
idenity_header_config: Arc<IdentityHeaderSpec>,
}

Expand All @@ -101,7 +101,8 @@ where
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);

let (identity, bazel_metadata) = {
let mut context = ActiveOriginContext::fork().unwrap_or_default();
let identity = {
let identity_header = self
.idenity_header_config
.header_name
Expand All @@ -124,24 +125,24 @@ where
.unwrap())
});
}

context.set_value(&ORIGIN_IDENTITY, Arc::new(identity.clone()));
identity
};
if let Some(origin_event_tx) = &self.maybe_origin_event_tx {
let bazel_metadata = req
.headers()
.get("build.bazel.remote.execution.v2.requestmetadata-bin")
.and_then(|header| BASE64_STANDARD_NO_PAD.decode(header.as_bytes()).ok())
.and_then(|data| RequestMetadata::decode(data.as_slice()).ok());
(identity, bazel_metadata)
};

let mut context = OriginContext::new();
context.set_value(
&ORIGIN_EVENT_COLLECTOR,
Arc::new(OriginEventCollector::new(
self.origin_event_tx.clone(),
identity,
bazel_metadata,
)),
);
context.set_value(
&ORIGIN_EVENT_COLLECTOR,
Arc::new(OriginEventCollector::new(
origin_event_tx.clone(),
identity,
bazel_metadata,
)),
);
}

Box::pin(async move {
Arc::new(context)
Expand Down
61 changes: 30 additions & 31 deletions src/bin/nativelink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,30 @@ async fn inner_main(
schedulers: action_schedulers.clone(),
}));

let maybe_origin_event_tx = cfg
.experimental_origin_events
.as_ref()
.map(|origin_events_cfg| {
let mut max_queued_events = origin_events_cfg.max_event_queue_size;
if max_queued_events == 0 {
max_queued_events = DEFAULT_MAX_QUEUE_EVENTS;
}
let (tx, rx) = mpsc::channel(max_queued_events);
let store_name = origin_events_cfg.publisher.store.as_str();
let store = store_manager.get_store(store_name).err_tip(|| {
format!("Could not get store {store_name} for origin event publisher")
})?;

root_futures.push(Box::pin(
OriginEventPublisher::new(store, rx, shutdown_tx.clone())
.run()
.map(Ok),
));

Ok::<_, Error>(tx)
})
.transpose()?;

for (server_cfg, connected_clients_mux) in servers_and_clients {
let services = server_cfg
.services
Expand Down Expand Up @@ -457,37 +481,12 @@ async fn inner_main(

let health_registry = health_registry_builder.lock().await.build();

let mut svc = Router::new();
let maybe_middleware_layer = cfg
.experimental_origin_events
.as_ref()
.map(|origin_events_cfg| {
let mut max_queued_events = origin_events_cfg.max_event_queue_size;
if max_queued_events == 0 {
max_queued_events = DEFAULT_MAX_QUEUE_EVENTS;
}
let (tx, rx) = mpsc::channel(max_queued_events);
let store_name = origin_events_cfg.publisher.store.as_str();
let store = store_manager.get_store(store_name).err_tip(|| {
format!("Could not get store {store_name} for origin event publisher")
})?;

root_futures.push(Box::pin(
OriginEventPublisher::new(store, rx, shutdown_tx.clone())
.run()
.map(Ok),
));

Ok::<_, Error>((tx, origin_events_cfg))
})
.transpose()?
.map(|(tx, cfg)| OriginEventMiddlewareLayer::new(tx, cfg.identity_header.clone()));
let tonic_axum_rounter = tonic_services.into_service().into_axum_router();
svc = if let Some(middleware) = maybe_middleware_layer {
svc.merge(tonic_axum_rounter.layer(middleware))
} else {
svc.merge(tonic_axum_rounter)
};
let mut svc = Router::new().merge(tonic_services.into_service().into_axum_router().layer(
OriginEventMiddlewareLayer::new(
maybe_origin_event_tx.clone(),
server_cfg.experimental_identity_header.clone(),
),
));

if let Some(health_cfg) = services.health {
let path = if health_cfg.path.is_empty() {
Expand Down

0 comments on commit 45520d9

Please sign in to comment.