Skip to content

Commit

Permalink
Add custom certificate extenstion (#221)
Browse files Browse the repository at this point in the history
* Add custom certificate extenstion

- Validate certificate extension and premissions on quic messages

* Add new cert extenstion to sms control

* Fix formatting errors

* Fix cargo deny issues

* Fix formatting errors
  • Loading branch information
kosticmarin authored Sep 23, 2023
1 parent 81dd26f commit 4cb0ed4
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 306 deletions.
527 changes: 258 additions & 269 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions crates/lunatic-control-axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ getrandom = "0.2.8"
http = "0.2.8"
log = { workspace = true }
rcgen = "0.10"
asn1-rs = "0.5.2"
serde = { workspace = true }
serde_json = "1.0.89"
tokio = { workspace = true, features = ["io-util", "rt", "sync", "time", "fs"] }
Expand Down
30 changes: 26 additions & 4 deletions crates/lunatic-control-axum/src/routes.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use std::{collections::HashMap, sync::Arc};

use asn1_rs::ToDer;
use axum::{
body::Bytes,
extract::{DefaultBodyLimit, Query},
routing::{get, post},
Extension, Json, Router,
};
use lunatic_control::{api::*, NodeInfo};
use lunatic_distributed::control::cert::TEST_ROOT_CERT;
use rcgen::CertificateSigningRequest;
use lunatic_distributed::{control::cert::TEST_ROOT_CERT, CertAttrs, SUBJECT_DIR_ATTRS};
use rcgen::{CertificateSigningRequest, CustomExtension};
use tower_http::limit::RequestBodyLimitLayer;

use crate::{
Expand All @@ -24,8 +25,29 @@ pub async fn register(
log::info!("Registration for node name {}", reg.node_name);

let control = control.as_ref();
let cert_pem = CertificateSigningRequest::from_pem(&reg.csr_pem)
.and_then(|sign_request| sign_request.serialize_pem_with_signer(&control.ca_cert))

let mut sign_request = CertificateSigningRequest::from_pem(&reg.csr_pem).map_err(|e| {
ApiError::custom(
"sign_error",
format!("Certificate Signing Request invalid pem format: {}", e),
)
})?;
// Add json to custom certificate extension
sign_request
.params
.custom_extensions
.push(CustomExtension::from_oid_content(
&SUBJECT_DIR_ATTRS,
serde_json::to_string(&CertAttrs {
allowed_envs: vec![],
is_privileged: true,
})
.unwrap()
.to_der_vec()
.map_err(|e| ApiError::log_internal("Error serializing allowed envs to der", e))?,
));
let cert_pem = sign_request
.serialize_pem_with_signer(&control.ca_cert)
.map_err(|e| ApiError::custom("sign_error", e.to_string()))?;

let mut authentication_token = [0u8; 32];
Expand Down
2 changes: 2 additions & 0 deletions crates/lunatic-distributed-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ lunatic-error-api = { workspace = true }
lunatic-process = { workspace = true }
lunatic-process-api = { workspace = true }

asn1-rs = "0.5.2"
anyhow = { workspace = true }
bincode = { workspace = true }
rcgen = { version = "0.10", features = ["pem", "x509-parser"] }
rmp-serde = "1.1.1"
log = { workspace = true }
serde_json = "1.0.89"
tokio = { workspace = true, features = ["time"] }
wasmtime = { workspace = true }
29 changes: 21 additions & 8 deletions crates/lunatic-distributed-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
use std::{future::Future, sync::Arc, time::Duration};

use anyhow::{anyhow, Result};
use asn1_rs::ToDer;
use lunatic_common_api::{get_memory, write_to_guest_vec, IntoTrap};
use lunatic_distributed::{
distributed::{
self,
client::{EnvironmentId, NodeId, ProcessId, SendParams, SpawnParams},
message::{ClientError, Spawn, Val},
},
DistributedCtx,
CertAttrs, DistributedCtx, SUBJECT_DIR_ATTRS,
};
use lunatic_error_api::ErrorCtx;
use lunatic_process::{
env::Environment,
message::{DataMessage, Message},
};
use lunatic_process_api::ProcessCtx;
use rcgen::{Certificate, CertificateParams, CertificateSigningRequest, KeyPair};
use rcgen::{Certificate, CertificateParams, CertificateSigningRequest, CustomExtension, KeyPair};
use tokio::time::timeout;
use wasmtime::{Caller, Linker, ResourceLimiter};

Expand Down Expand Up @@ -319,12 +320,24 @@ where
let ca_cert =
Certificate::from_params(cert_params).or_trap("lunatic::distributed::sign_node")?;

let Ok(cert_pem) = CertificateSigningRequest::from_pem(csr_pem)
.and_then(|sign_request| sign_request.serialize_pem_with_signer(&ca_cert))
else {
return Ok(0);
};

let mut csr = CertificateSigningRequest::from_pem(csr_pem)
.or_trap("lunatic::distributed::sign_node")?;
// Add json to custom certificate extension
csr.params
.custom_extensions
.push(CustomExtension::from_oid_content(
&SUBJECT_DIR_ATTRS,
serde_json::to_string(&CertAttrs {
allowed_envs: vec![],
is_privileged: true,
})
.or_trap("lunatic::distributed::sign_node")?
.to_der_vec()
.or_trap("lunatic::distributed::sign_node")?,
));
let cert_pem = csr
.serialize_pem_with_signer(&ca_cert)
.or_trap("lunatic::distributed::sign_node")?;
let data = bincode::serialize(&cert_pem).or_trap("lunatic::distributed::sign_node")?;
let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
.await
Expand Down
2 changes: 2 additions & 0 deletions crates/lunatic-distributed/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ reqwest = { workspace = true, features = ["json"] }
rustls = { version = "0.21.6" }
rustls-pemfile = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.89"
tokio = { workspace = true, features = ["io-util", "rt", "sync", "time"] }
uuid = { version = "1.0", features = ["serde", "v4"] }
wasmtime = { workspace = true }
x509-parser = "0.14.0"
79 changes: 59 additions & 20 deletions crates/lunatic-distributed/src/distributed/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use wasmtime::ResourceLimiter;

use crate::{
distributed::message::{Request, Response},
quic::{self},
quic::{self, NodeEnvPermission},
DistributedCtx, DistributedProcessState,
};

Expand Down Expand Up @@ -82,21 +82,75 @@ where
Ok(())
}

pub async fn handle_message<T, E>(ctx: ServerCtx<T, E>, msg_id: u64, msg: Request)
where
pub async fn handle_message<T, E>(
ctx: ServerCtx<T, E>,
msg_id: u64,
msg: Request,
node_permissions: Arc<NodeEnvPermission>,
) where
T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + Sync + 'static,
E: Environment + 'static,
{
if let Err(e) = handle_message_err(ctx, msg_id, msg).await {
if let Err(e) = handle_message_err(ctx, msg_id, msg, node_permissions).await {
log::error!("Error handling message: {e}");
}
}

async fn handle_message_err<T, E>(ctx: ServerCtx<T, E>, msg_id: u64, msg: Request) -> Result<()>
async fn handle_message_err<T, E>(
ctx: ServerCtx<T, E>,
msg_id: u64,
msg: Request,
node_permissions: Arc<NodeEnvPermission>,
) -> Result<()>
where
T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + Sync + 'static,
E: Environment + 'static,
{
let env_id = match &msg {
Request::Spawn(spawn) => Some((spawn.response_node_id, spawn.environment_id)),
Request::Message {
node_id,
environment_id,
process_id: _,
tag: _,
data: _,
} => Some((*node_id, *environment_id)),
Request::Response(_) => None,
};
if let Some((node_id, env_id)) = env_id {
if let Some(ref allowed_envs) = node_permissions.0 {
if !allowed_envs.contains(&env_id) {
ctx.node_client
.send_response(ResponseParams {
node_id: NodeId(node_id),
response: Response {
message_id: msg_id,
content: ResponseContent::Error(ClientError::Unexpected(format!(
"The node sending the request does not have access to the environment {env_id}"
))),
},
})
.await?;
return Ok(());
}
}
if let Some(ref allowed_envs) = ctx.allowed_envs {
if !allowed_envs.contains(&env_id) {
ctx.node_client
.send_response(ResponseParams {
node_id: NodeId(node_id),
response: Response {
message_id: msg_id,
content: ResponseContent::Error(ClientError::Unexpected(format!(
"This node does not have access to environment {env_id}"
))),
},
})
.await?;
return Ok(());
}
}
}
match msg {
Request::Spawn(spawn) => {
log::trace!("lunatic::distributed::server process Spawn");
Expand Down Expand Up @@ -196,14 +250,6 @@ where
config,
..
} = spawn;

if let Some(ref allowed_envs) = ctx.allowed_envs {
if !allowed_envs.contains(&environment_id) {
return Ok(Err(ClientError::Unexpected(format!(
"This node does not have access to environment {environment_id}"
))));
}
}
let config: T::Config = rmp_serde::from_slice(&config[..])?;
let config = Arc::new(config);

Expand Down Expand Up @@ -261,13 +307,6 @@ where
T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + 'static,
E: Environment,
{
if let Some(ref allowed_envs) = ctx.allowed_envs {
if !allowed_envs.contains(&environment_id) {
return Err(ClientError::Unexpected(format!(
"This node does not have access to environment {environment_id}"
)));
}
}
let env = ctx.envs.get(environment_id).await;
if let Some(env) = env {
if let Some(proc) = env.get_process(process_id) {
Expand Down
9 changes: 9 additions & 0 deletions crates/lunatic-distributed/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use lunatic_process::{
runtimes::wasmtime::{WasmtimeCompiledModule, WasmtimeRuntime},
state::ProcessState,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

pub trait DistributedCtx<E: Environment>: ProcessState + Sized {
Expand Down Expand Up @@ -50,3 +51,11 @@ impl DistributedProcessState {
self.node_id
}
}

pub const SUBJECT_DIR_ATTRS: [u64; 4] = [2, 5, 29, 9];

#[derive(Debug, Serialize, Deserialize)]
pub struct CertAttrs {
pub allowed_envs: Vec<u64>,
pub is_privileged: bool,
}
59 changes: 54 additions & 5 deletions crates/lunatic-distributed/src/quic/quin.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{collections::HashSet, net::SocketAddr, sync::Arc, time::Duration};

use anyhow::{anyhow, Result};
use bytes::Bytes;
use dashmap::DashMap;
use lunatic_process::{env::Environment, state::ProcessState};
use quinn::{ClientConfig, Connecting, ConnectionError, Endpoint, ServerConfig};
use quinn::{ClientConfig, Connecting, Connection, ConnectionError, Endpoint, ServerConfig};
use rustls::server::AllowAnyAuthenticatedClient;
use rustls_pemfile::Item;
use wasmtime::ResourceLimiter;
use x509_parser::{der_parser::oid, oid_registry::asn1_rs::Utf8String, prelude::FromDer};

use crate::{
distributed::{self},
DistributedCtx,
CertAttrs, DistributedCtx,
};

#[derive(Clone)]
Expand Down Expand Up @@ -43,6 +44,28 @@ impl Client {
}
}

fn get_cert_attrs(conn: &Connection) -> Result<CertAttrs> {
let peer_identity = match conn
.peer_identity()
.ok_or(anyhow!("Peer must provide an identity."))?
.downcast::<Vec<rustls::Certificate>>()
{
Ok(certs) => Ok(certs),
Err(_) => Err(anyhow!("Failed to downcast peer identity.")),
}?;
if peer_identity.len() != 1 {
return Err(anyhow!("More than one identity certificate detected."));
}
let cert = peer_identity.get(0).unwrap();
let (_rem, x509) = x509_parser::certificate::X509Certificate::from_der(&cert.0)?;
let oid = oid!(2.5.29 .9);
let ext = x509
.get_extension_unique(&oid)?
.ok_or_else(|| anyhow!("Missing critical Lunatic certificate extension."))?;
let (_rem, value) = Utf8String::from_der(ext.value)?;
Ok(serde_json::from_str(&value.string())?)
}

pub fn new_quic_client(ca_cert: &str, cert: &str, key: &str) -> Result<Client> {
let mut ca_cert = ca_cert.as_bytes();
let ca_cert = rustls_pemfile::read_one(&mut ca_cert)?.unwrap();
Expand Down Expand Up @@ -137,6 +160,19 @@ where
Err(anyhow!("Node server exited"))
}

pub struct NodeEnvPermission(pub Option<HashSet<u64>>);

impl NodeEnvPermission {
fn new(cert_attrs: CertAttrs) -> Self {
let some_set: Option<HashSet<u64>> = if cert_attrs.is_privileged {
None
} else {
Some(cert_attrs.allowed_envs.into_iter().collect())
};
Self(some_set)
}
}

async fn handle_quic_connection_node<T, E>(
ctx: distributed::server::ServerCtx<T, E>,
conn: Connecting,
Expand All @@ -147,6 +183,8 @@ where
{
log::info!("New node connection");
let conn = conn.await?;
let node_cert_attrs = get_cert_attrs(&conn)?;
let node_permissions = Arc::new(NodeEnvPermission::new(node_cert_attrs));
log::info!("Remote {} connected", conn.remote_address());
loop {
if let Some(reason) = conn.close_reason() {
Expand All @@ -157,7 +195,11 @@ where
log::info!("Stream from remote {} accepted", conn.remote_address());
match stream {
Ok(recv) => {
tokio::spawn(handle_quic_stream_node(ctx.clone(), recv));
tokio::spawn(handle_quic_stream_node(
ctx.clone(),
recv,
node_permissions.clone(),
));
}
Err(ConnectionError::LocallyClosed) => {
log::trace!("distributed::server::stream locally closed");
Expand All @@ -173,6 +215,7 @@ where
async fn handle_quic_stream_node<T, E>(
ctx: distributed::server::ServerCtx<T, E>,
recv: quinn::RecvStream,
node_permissions: Arc<NodeEnvPermission>,
) where
T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
E: Environment + 'static,
Expand All @@ -184,7 +227,13 @@ async fn handle_quic_stream_node<T, E>(
log::trace!("distributed::server::handle_quic_stream started");
while let Ok((msg_id, bytes)) = read_next_stream_message(&mut recv_ctx).await {
if let Ok(request) = rmp_serde::from_slice::<distributed::message::Request>(&bytes) {
distributed::server::handle_message(ctx.clone(), msg_id, request).await;
distributed::server::handle_message(
ctx.clone(),
msg_id,
request,
node_permissions.clone(),
)
.await;
} else {
log::debug!("Error deserializing request");
}
Expand Down

0 comments on commit 4cb0ed4

Please sign in to comment.