Skip to content

Commit

Permalink
Server for HSM demo (#962)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygolay authored Dec 18, 2024
1 parent 642227e commit 873d1ff
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 25 deletions.
19 changes: 19 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions demo/hsm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,23 @@ publish = { workspace = true }
rust-version = { workspace = true }

[dependencies]
tokio = { workspace = true }
tokio = { workspace = true, features = ["full"] }
async-trait = { workspace = true }
anyhow = { workspace = true }
vaultrs = { workspace = true }
anyhow = { workspace = true }
aws-sdk-kms = { workspace = true }
aws-config = { workspace = true }
rand = { workspace = true }
base64 = { workspace = true }
dotenv = "0.15"
ed25519 = { workspace = true }
ring-compat = { workspace = true }
k256 = { workspace = true, features = ["ecdsa", "pkcs8"] }
google-cloud-kms = { workspace = true }
reqwest = { version = "0.12", features = ["json"] }
axum = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

[lints]
workspace = true
5 changes: 4 additions & 1 deletion demo/hsm/src/hsm/aws_kms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use anyhow::Context;
use aws_sdk_kms::primitives::Blob;
use aws_sdk_kms::types::{KeySpec, KeyUsageType, SigningAlgorithmSpec};
use aws_sdk_kms::Client;
use dotenv::dotenv;
use k256::ecdsa::{self, VerifyingKey};
use k256::pkcs8::DecodePublicKey;
use ring_compat::signature::Verifier;
Expand All @@ -11,7 +12,7 @@ use ring_compat::signature::Verifier;
pub struct AwsKms {
client: Client,
key_id: String,
public_key: PublicKey,
pub public_key: PublicKey,
}

impl AwsKms {
Expand All @@ -22,6 +23,7 @@ impl AwsKms {

/// Tries to create a new AWS KMS HSM from the environment
pub async fn try_from_env() -> Result<Self, anyhow::Error> {
dotenv().ok();
let key_id = std::env::var("AWS_KMS_KEY_ID").context("AWS_KMS_KEY_ID not set")?;
let public_key = std::env::var("AWS_KMS_PUBLIC_KEY").unwrap_or_default();

Expand Down Expand Up @@ -49,6 +51,7 @@ impl AwsKms {
/// Fills the public key from the key id
pub async fn fill_with_public_key(mut self) -> Result<Self, anyhow::Error> {
let res = self.client.get_public_key().key_id(&self.key_id).send().await?;
println!("AWS KMS Response: {:?}", res);
let public_key = PublicKey(Bytes(
res.public_key().context("No public key available")?.as_ref().to_vec(),
));
Expand Down
61 changes: 61 additions & 0 deletions demo/hsm/src/hsm/cli.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
mod cli;
mod hsm;

use anyhow::Result;
use clap::Parser;
use cli::{Cli, Service};
use hsm::{aws::AwsKms, google::GoogleKms, vault::HashiCorpVault};
use dotenv::dotenv;
use hsm_demo::{action_stream, Application};

#[tokio::main]
async fn main() -> Result<()> {
dotenv().ok(); // Load environment variables from .env file
let cli = Cli::parse();

// Select the HSM implementation based on CLI input
let hsm = match cli.service {
Service::Aws(args) => {
println!("Using AWS KMS with {:?} key", args.key_type);
AwsKms::try_from_env()
.await?
.create_key()
.await?
.fill_with_public_key()
.await?
}
Service::Gcp(args) => {
println!("Using Google Cloud KMS with {:?} key", args.key_type);
GoogleKms::try_from_env()
.await?
.create_key_ring()
.await?
.create_key()
.await?
.fill_with_public_key()
.await?
}
Service::Vault(args) => {
println!("Using HashiCorp Vault with {:?} key", args.key_type);
HashiCorpVault::try_from_env()
.and_then(|vault| vault.create_key())
.await?
.fill_with_public_key()
.await?
}
};

// Initialize the streams
let random_stream = action_stream::random::Random;
let notify_verify_stream = action_stream::notify_verify::NotifyVerify::new();
let join_stream = action_stream::join::Join::new(vec![
Box::new(random_stream),
Box::new(notify_verify_stream),
]);

// Run the application
let mut app = Application::new(Box::new(hsm), Box::new(join_stream));
app.run().await?;

Ok(())
}
2 changes: 1 addition & 1 deletion demo/hsm/src/hsm/hashi_corp_vault.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct HashiCorpVault {
client: VaultClient,
key_name: String,
mount_name: String,
public_key: PublicKey,
pub public_key: PublicKey,
}

impl HashiCorpVault {
Expand Down
1 change: 1 addition & 0 deletions demo/hsm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod action_stream;
pub mod hsm;
pub mod server;

/// A collection of bytes.
#[derive(Debug, Clone)]
Expand Down
156 changes: 135 additions & 21 deletions demo/hsm/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,138 @@
use hsm_demo::{action_stream, hsm, Application};
use axum::Server;
use hsm_demo::{hsm, Bytes, Hsm, PublicKey, Signature};
use reqwest::Client;
use serde::Serialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task;
use dotenv::dotenv;

use hsm_demo::{action_stream, Application};
use hsm_demo::server::create_server;

#[tokio::main]
pub async fn main() -> Result<(), anyhow::Error> {
let random_stream = action_stream::random::Random;
let notify_verify_stream = action_stream::notify_verify::NotifyVerify::new();
let join_stream = action_stream::join::Join::new(vec![
Box::new(random_stream),
Box::new(notify_verify_stream),
]);

let hsm = hsm::aws_kms::AwsKms::try_from_env()
.await?
.create_key()
.await?
.fill_with_public_key()
.await?;

let mut app = Application::new(Box::new(hsm), Box::new(join_stream));

app.run().await?;

Ok(())
async fn main() -> Result<(), anyhow::Error> {
dotenv().ok(); // Load environment variables from .env file

// Initialize HSM based on PROVIDER
let provider = std::env::var("PROVIDER").unwrap_or_else(|_| "AWS".to_string());
let (hsm, public_key) = match provider.as_str() {
"AWS" => {
let aws_kms_hsm = hsm::aws_kms::AwsKms::try_from_env()
.await?
.create_key()
.await?
.fill_with_public_key()
.await?;
let public_key = aws_kms_hsm.public_key.clone();
(Arc::new(Mutex::new(aws_kms_hsm)) as Arc<Mutex<dyn hsm_demo::Hsm + Send + Sync>>, public_key)
}
"VAULT" => {
let vault_hsm = hsm::hashi_corp_vault::HashiCorpVault::try_from_env()?
.create_key()
.await?
.fill_with_public_key()
.await?;
let public_key = vault_hsm.public_key.clone();
(Arc::new(Mutex::new(vault_hsm)) as Arc<Mutex<dyn hsm_demo::Hsm + Send + Sync>>, public_key)
}
_ => {
return Err(anyhow::anyhow!("Unsupported provider: {}", provider));
}
};

// Start the server task
let server_hsm = hsm.clone();
let server_task = task::spawn(async move {
let app = create_server(server_hsm);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("Server listening on {}", addr);

Server::bind(&addr)
.serve(app.into_make_service())
.await
.expect("Server failed");
});

tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

// Start the Application
let client = Client::new();
let random_stream = action_stream::random::Random;
let notify_verify_stream = action_stream::notify_verify::NotifyVerify::new();
let join_stream = action_stream::join::Join::new(vec![
Box::new(random_stream),
Box::new(notify_verify_stream),
]);

// Replace HSM with the HTTP client that sends requests to the server
let hsm_proxy = HttpHsmProxy::new(client, "http://127.0.0.1:3000/sign".to_string(), public_key);
let mut app = Application::new(Box::new(hsm_proxy), Box::new(join_stream));

app.run().await?;

server_task.await?;
Ok(())
}

#[derive(Serialize)]
struct SignRequest {
message: Vec<u8>,
}

#[derive(serde::Deserialize)]
struct SignedResponse {
signature: Vec<u8>,
}

pub struct HttpHsmProxy {
client: Client,
server_url: String,
public_key: PublicKey,
}

impl HttpHsmProxy {
pub fn new(client: Client, server_url: String, public_key: PublicKey) -> Self {
Self { client, server_url, public_key }
}

pub fn get_public_key(&self) -> PublicKey {
self.public_key.clone()
}
}

#[async_trait::async_trait]
impl Hsm for HttpHsmProxy {
async fn sign(
&self,
message: Bytes,
) -> Result<(Bytes, PublicKey, Signature), anyhow::Error> {
let payload = SignRequest { message: message.0.clone() };

let response = self
.client
.post(&self.server_url)
.json(&payload)
.send()
.await?
.json::<SignedResponse>()
.await?;

let signature = Signature(Bytes(response.signature));

// Return the stored public key along with the signature
Ok((message, self.public_key.clone(), signature))
}

async fn verify(
&self,
_message: Bytes,
_public_key: PublicKey,
_signature: Signature,
) -> Result<bool, anyhow::Error> {
// Verification would need another endpoint or can be skipped because Application already verifies
Ok(true)
}
}

44 changes: 44 additions & 0 deletions demo/hsm/src/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use axum::{
routing::post,
extract::State,
Json, Router,
http::StatusCode,
};
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::{Bytes, Hsm};

pub fn create_server(hsm: Arc<Mutex<dyn Hsm + Send + Sync>>) -> Router {
Router::new()
.route("/sign", post(sign_handler))
.with_state(hsm)
}

async fn sign_handler(
State(hsm): State<Arc<Mutex<dyn Hsm + Send + Sync>>>,
Json(payload): Json<SignRequest>,
) -> Result<Json<SignedResponse>, StatusCode> {
let message_bytes = Bytes(payload.message);

let (_message, _public_key, signature) = hsm
.lock()
.await
.sign(message_bytes)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(Json(SignedResponse {
signature: signature.0 .0,
}))
}

#[derive(serde::Deserialize)]
pub struct SignRequest {
pub message: Vec<u8>,
}

#[derive(serde::Serialize)]
pub struct SignedResponse {
pub signature: Vec<u8>,
}

0 comments on commit 873d1ff

Please sign in to comment.