From 873d1ffe5581c2b24556ae6c7757b5ef895b2033 Mon Sep 17 00:00:00 2001 From: Andy Golay Date: Wed, 18 Dec 2024 03:56:19 -0500 Subject: [PATCH] Server for HSM demo (#962) --- Cargo.lock | 19 ++++ demo/hsm/Cargo.toml | 9 +- demo/hsm/src/hsm/aws_kms.rs | 5 +- demo/hsm/src/hsm/cli.rs | 61 +++++++++++ demo/hsm/src/hsm/hashi_corp_vault.rs | 2 +- demo/hsm/src/lib.rs | 1 + demo/hsm/src/main.rs | 156 +++++++++++++++++++++++---- demo/hsm/src/server.rs | 44 ++++++++ 8 files changed, 272 insertions(+), 25 deletions(-) create mode 100644 demo/hsm/src/hsm/cli.rs create mode 100644 demo/hsm/src/server.rs diff --git a/Cargo.lock b/Cargo.lock index b3b1c12e8..8318b0ca2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3606,7 +3606,11 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", "sync_wrapper 0.1.2", + "tokio", "tower 0.4.13", "tower-layer", "tower-service", @@ -7522,12 +7526,17 @@ dependencies = [ "async-trait", "aws-config", "aws-sdk-kms", + "axum 0.6.20", "base64 0.13.1", + "dotenv", "ed25519 2.2.3", "google-cloud-kms", "k256", "rand 0.7.3", + "reqwest 0.12.9", "ring-compat", + "serde", + "serde_json", "tokio", "vaultrs", ] @@ -13624,6 +13633,16 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_repr" version = "0.1.19" diff --git a/demo/hsm/Cargo.toml b/demo/hsm/Cargo.toml index 047e4bfaf..63d08ec1f 100644 --- a/demo/hsm/Cargo.toml +++ b/demo/hsm/Cargo.toml @@ -10,18 +10,23 @@ publish = { workspace = true } rust-version = { workspace = true } [dependencies] -tokio = { workspace = true } +tokio = { workspace = true, features = ["full"] } async-trait = { workspace = true } -anyhow = { workspace = true } vaultrs = { workspace = true } +anyhow = { workspace = true } aws-sdk-kms = { workspace = true } aws-config = { workspace = true } rand = { workspace = true } base64 = { workspace = true } +dotenv = "0.15" ed25519 = { workspace = true } ring-compat = { workspace = true } k256 = { workspace = true, features = ["ecdsa", "pkcs8"] } google-cloud-kms = { workspace = true } +reqwest = { version = "0.12", features = ["json"] } +axum = "0.6" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" [lints] workspace = true diff --git a/demo/hsm/src/hsm/aws_kms.rs b/demo/hsm/src/hsm/aws_kms.rs index eb952432e..23ccab0e5 100644 --- a/demo/hsm/src/hsm/aws_kms.rs +++ b/demo/hsm/src/hsm/aws_kms.rs @@ -3,6 +3,7 @@ use anyhow::Context; use aws_sdk_kms::primitives::Blob; use aws_sdk_kms::types::{KeySpec, KeyUsageType, SigningAlgorithmSpec}; use aws_sdk_kms::Client; +use dotenv::dotenv; use k256::ecdsa::{self, VerifyingKey}; use k256::pkcs8::DecodePublicKey; use ring_compat::signature::Verifier; @@ -11,7 +12,7 @@ use ring_compat::signature::Verifier; pub struct AwsKms { client: Client, key_id: String, - public_key: PublicKey, + pub public_key: PublicKey, } impl AwsKms { @@ -22,6 +23,7 @@ impl AwsKms { /// Tries to create a new AWS KMS HSM from the environment pub async fn try_from_env() -> Result { + dotenv().ok(); let key_id = std::env::var("AWS_KMS_KEY_ID").context("AWS_KMS_KEY_ID not set")?; let public_key = std::env::var("AWS_KMS_PUBLIC_KEY").unwrap_or_default(); @@ -49,6 +51,7 @@ impl AwsKms { /// Fills the public key from the key id pub async fn fill_with_public_key(mut self) -> Result { let res = self.client.get_public_key().key_id(&self.key_id).send().await?; + println!("AWS KMS Response: {:?}", res); let public_key = PublicKey(Bytes( res.public_key().context("No public key available")?.as_ref().to_vec(), )); diff --git a/demo/hsm/src/hsm/cli.rs b/demo/hsm/src/hsm/cli.rs new file mode 100644 index 000000000..fc909558c --- /dev/null +++ b/demo/hsm/src/hsm/cli.rs @@ -0,0 +1,61 @@ +mod cli; +mod hsm; + +use anyhow::Result; +use clap::Parser; +use cli::{Cli, Service}; +use hsm::{aws::AwsKms, google::GoogleKms, vault::HashiCorpVault}; +use dotenv::dotenv; +use hsm_demo::{action_stream, Application}; + +#[tokio::main] +async fn main() -> Result<()> { + dotenv().ok(); // Load environment variables from .env file + let cli = Cli::parse(); + + // Select the HSM implementation based on CLI input + let hsm = match cli.service { + Service::Aws(args) => { + println!("Using AWS KMS with {:?} key", args.key_type); + AwsKms::try_from_env() + .await? + .create_key() + .await? + .fill_with_public_key() + .await? + } + Service::Gcp(args) => { + println!("Using Google Cloud KMS with {:?} key", args.key_type); + GoogleKms::try_from_env() + .await? + .create_key_ring() + .await? + .create_key() + .await? + .fill_with_public_key() + .await? + } + Service::Vault(args) => { + println!("Using HashiCorp Vault with {:?} key", args.key_type); + HashiCorpVault::try_from_env() + .and_then(|vault| vault.create_key()) + .await? + .fill_with_public_key() + .await? + } + }; + + // Initialize the streams + let random_stream = action_stream::random::Random; + let notify_verify_stream = action_stream::notify_verify::NotifyVerify::new(); + let join_stream = action_stream::join::Join::new(vec![ + Box::new(random_stream), + Box::new(notify_verify_stream), + ]); + + // Run the application + let mut app = Application::new(Box::new(hsm), Box::new(join_stream)); + app.run().await?; + + Ok(()) +} diff --git a/demo/hsm/src/hsm/hashi_corp_vault.rs b/demo/hsm/src/hsm/hashi_corp_vault.rs index 9b5457132..8e27d82e5 100644 --- a/demo/hsm/src/hsm/hashi_corp_vault.rs +++ b/demo/hsm/src/hsm/hashi_corp_vault.rs @@ -15,7 +15,7 @@ pub struct HashiCorpVault { client: VaultClient, key_name: String, mount_name: String, - public_key: PublicKey, + pub public_key: PublicKey, } impl HashiCorpVault { diff --git a/demo/hsm/src/lib.rs b/demo/hsm/src/lib.rs index f5c863fb5..7aace54e2 100644 --- a/demo/hsm/src/lib.rs +++ b/demo/hsm/src/lib.rs @@ -1,5 +1,6 @@ pub mod action_stream; pub mod hsm; +pub mod server; /// A collection of bytes. #[derive(Debug, Clone)] diff --git a/demo/hsm/src/main.rs b/demo/hsm/src/main.rs index 11a173cbc..9a23eb0e0 100644 --- a/demo/hsm/src/main.rs +++ b/demo/hsm/src/main.rs @@ -1,24 +1,138 @@ -use hsm_demo::{action_stream, hsm, Application}; +use axum::Server; +use hsm_demo::{hsm, Bytes, Hsm, PublicKey, Signature}; +use reqwest::Client; +use serde::Serialize; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::task; +use dotenv::dotenv; + +use hsm_demo::{action_stream, Application}; +use hsm_demo::server::create_server; #[tokio::main] -pub async fn main() -> Result<(), anyhow::Error> { - let random_stream = action_stream::random::Random; - let notify_verify_stream = action_stream::notify_verify::NotifyVerify::new(); - let join_stream = action_stream::join::Join::new(vec![ - Box::new(random_stream), - Box::new(notify_verify_stream), - ]); - - let hsm = hsm::aws_kms::AwsKms::try_from_env() - .await? - .create_key() - .await? - .fill_with_public_key() - .await?; - - let mut app = Application::new(Box::new(hsm), Box::new(join_stream)); - - app.run().await?; - - Ok(()) +async fn main() -> Result<(), anyhow::Error> { + dotenv().ok(); // Load environment variables from .env file + + // Initialize HSM based on PROVIDER + let provider = std::env::var("PROVIDER").unwrap_or_else(|_| "AWS".to_string()); + let (hsm, public_key) = match provider.as_str() { + "AWS" => { + let aws_kms_hsm = hsm::aws_kms::AwsKms::try_from_env() + .await? + .create_key() + .await? + .fill_with_public_key() + .await?; + let public_key = aws_kms_hsm.public_key.clone(); + (Arc::new(Mutex::new(aws_kms_hsm)) as Arc>, public_key) + } + "VAULT" => { + let vault_hsm = hsm::hashi_corp_vault::HashiCorpVault::try_from_env()? + .create_key() + .await? + .fill_with_public_key() + .await?; + let public_key = vault_hsm.public_key.clone(); + (Arc::new(Mutex::new(vault_hsm)) as Arc>, public_key) + } + _ => { + return Err(anyhow::anyhow!("Unsupported provider: {}", provider)); + } + }; + + // Start the server task + let server_hsm = hsm.clone(); + let server_task = task::spawn(async move { + let app = create_server(server_hsm); + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + println!("Server listening on {}", addr); + + Server::bind(&addr) + .serve(app.into_make_service()) + .await + .expect("Server failed"); + }); + + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + // Start the Application + let client = Client::new(); + let random_stream = action_stream::random::Random; + let notify_verify_stream = action_stream::notify_verify::NotifyVerify::new(); + let join_stream = action_stream::join::Join::new(vec![ + Box::new(random_stream), + Box::new(notify_verify_stream), + ]); + + // Replace HSM with the HTTP client that sends requests to the server + let hsm_proxy = HttpHsmProxy::new(client, "http://127.0.0.1:3000/sign".to_string(), public_key); + let mut app = Application::new(Box::new(hsm_proxy), Box::new(join_stream)); + + app.run().await?; + + server_task.await?; + Ok(()) +} + +#[derive(Serialize)] +struct SignRequest { + message: Vec, +} + +#[derive(serde::Deserialize)] +struct SignedResponse { + signature: Vec, } + +pub struct HttpHsmProxy { + client: Client, + server_url: String, + public_key: PublicKey, +} + +impl HttpHsmProxy { + pub fn new(client: Client, server_url: String, public_key: PublicKey) -> Self { + Self { client, server_url, public_key } + } + + pub fn get_public_key(&self) -> PublicKey { + self.public_key.clone() + } +} + +#[async_trait::async_trait] +impl Hsm for HttpHsmProxy { + async fn sign( + &self, + message: Bytes, + ) -> Result<(Bytes, PublicKey, Signature), anyhow::Error> { + let payload = SignRequest { message: message.0.clone() }; + + let response = self + .client + .post(&self.server_url) + .json(&payload) + .send() + .await? + .json::() + .await?; + + let signature = Signature(Bytes(response.signature)); + + // Return the stored public key along with the signature + Ok((message, self.public_key.clone(), signature)) + } + + async fn verify( + &self, + _message: Bytes, + _public_key: PublicKey, + _signature: Signature, + ) -> Result { + // Verification would need another endpoint or can be skipped because Application already verifies + Ok(true) + } +} + diff --git a/demo/hsm/src/server.rs b/demo/hsm/src/server.rs new file mode 100644 index 000000000..9f0bf5667 --- /dev/null +++ b/demo/hsm/src/server.rs @@ -0,0 +1,44 @@ +use axum::{ + routing::post, + extract::State, + Json, Router, + http::StatusCode, +}; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::{Bytes, Hsm}; + +pub fn create_server(hsm: Arc>) -> Router { + Router::new() + .route("/sign", post(sign_handler)) + .with_state(hsm) +} + +async fn sign_handler( + State(hsm): State>>, + Json(payload): Json, +) -> Result, StatusCode> { + let message_bytes = Bytes(payload.message); + + let (_message, _public_key, signature) = hsm + .lock() + .await + .sign(message_bytes) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(SignedResponse { + signature: signature.0 .0, + })) +} + +#[derive(serde::Deserialize)] +pub struct SignRequest { + pub message: Vec, +} + +#[derive(serde::Serialize)] +pub struct SignedResponse { + pub signature: Vec, +}