From dea33825433f6a9fee8e5d04ab54a3948d5d1a17 Mon Sep 17 00:00:00 2001 From: Jason LeBrun Date: Fri, 21 Feb 2025 20:47:53 +0000 Subject: [PATCH] Create a function to create IdentityKey from bytes Updates the C++ bindings tests to generate keys from C++-provided bytes. BUG=396377174 Change-Id: Idd429602922ab3abbff95eb636af20f1d974e68c --- cc/oak_session/client_server_session_test.cc | 36 ++++++++++----- cc/oak_session/oak_session_bindings.h | 7 ++- cc/oak_session/oak_session_bindings_test.cc | 32 +++++++++---- oak_crypto/src/identity_key.rs | 6 ++- .../src/noise_handshake/crypto_wrapper.rs | 6 ++- oak_session/ffi/config.rs | 45 ++++++++++++++++++- 6 files changed, 108 insertions(+), 24 deletions(-) diff --git a/cc/oak_session/client_server_session_test.cc b/cc/oak_session/client_server_session_test.cc index 7fd4fbf42b8..d6bd8dab803 100644 --- a/cc/oak_session/client_server_session_test.cc +++ b/cc/oak_session/client_server_session_test.cc @@ -43,6 +43,11 @@ constexpr absl::string_view kFakeAttesterId = "fake_attester"; constexpr absl::string_view kFakeEvent = "fake event"; constexpr absl::string_view kFakePlatform = "fake platform"; +constexpr absl::string_view kClientKeyBytes = + "clientkeybytes12clientkeybytes34"; +constexpr absl::string_view kServerKeyBytes = + "serverkeybytes12serverkeybytes34"; + SessionConfig* TestConfigUnattestedNN() { return SessionConfigBuilder(AttestationType::kUnattested, HandshakeType::kNoiseNN) @@ -159,15 +164,18 @@ TEST(ClientServerSessionTest, AttestedNNHandshakeSucceeds) { } TEST(ClientServerSessionTest, UnattestedNKHandshakeSucceeds) { - bindings::IdentityKey* identity_key = bindings::new_identity_key(); + bindings::ErrorOrIdentityKey identity_key = + bindings::new_identity_key_from_bytes( + ffi::bindings::BytesView(kServerKeyBytes)); + ASSERT_THAT(identity_key, IsResult()); ffi::bindings::ErrorOrRustBytes public_key = - bindings::identity_key_get_public_key(identity_key); + bindings::identity_key_get_public_key(identity_key.result); ASSERT_THAT(public_key, IsResult()); auto client_session = ClientSession::Create(TestConfigUnattestedNKClient(*public_key.result)); auto server_session = - ServerSession::Create(TestConfigUnattestedNKServer(identity_key)); + ServerSession::Create(TestConfigUnattestedNKServer(identity_key.result)); DoHandshake(**client_session, **server_session); @@ -175,20 +183,26 @@ TEST(ClientServerSessionTest, UnattestedNKHandshakeSucceeds) { } TEST(ClientServerSessionTest, UnattestedKKHandshakeSucceeds) { - bindings::IdentityKey* client_identity_key = bindings::new_identity_key(); + bindings::ErrorOrIdentityKey client_identity_key = + bindings::new_identity_key_from_bytes( + ffi::bindings::BytesView(kClientKeyBytes)); + ASSERT_THAT(client_identity_key, IsResult()); ffi::bindings::ErrorOrRustBytes client_public_key = - bindings::identity_key_get_public_key(client_identity_key); + bindings::identity_key_get_public_key(client_identity_key.result); ASSERT_THAT(client_public_key, IsResult()); - bindings::IdentityKey* server_identity_key = bindings::new_identity_key(); + bindings::ErrorOrIdentityKey server_identity_key = + bindings::new_identity_key_from_bytes( + ffi::bindings::BytesView(kServerKeyBytes)); + ASSERT_THAT(server_identity_key, IsResult()); ffi::bindings::ErrorOrRustBytes server_public_key = - bindings::identity_key_get_public_key(server_identity_key); + bindings::identity_key_get_public_key(server_identity_key.result); ASSERT_THAT(client_public_key, IsResult()); - auto client_session = ClientSession::Create( - TestConfigUnattestedKK(*server_public_key.result, client_identity_key)); - auto server_session = ServerSession::Create( - TestConfigUnattestedKK(*client_public_key.result, server_identity_key)); + auto client_session = ClientSession::Create(TestConfigUnattestedKK( + *server_public_key.result, client_identity_key.result)); + auto server_session = ServerSession::Create(TestConfigUnattestedKK( + *client_public_key.result, server_identity_key.result)); DoHandshake(**client_session, **server_session); diff --git a/cc/oak_session/oak_session_bindings.h b/cc/oak_session/oak_session_bindings.h index 8c054b626b5..e7e215169e5 100644 --- a/cc/oak_session/oak_session_bindings.h +++ b/cc/oak_session/oak_session_bindings.h @@ -118,8 +118,12 @@ struct ErrorOrFfiEndorser { struct SigningKey; struct IdentityKey; -extern "C" { +struct ErrorOrIdentityKey { + IdentityKey* result; + ffi::bindings::Error* error; +}; +extern "C" { // Corresponds to functions in oak_session/ffi/config.rs extern ErrorOrSessionConfigBuilder new_session_config_builder(uint32_t, uint32_t); @@ -147,6 +151,7 @@ extern ffi::bindings::RustBytes signing_key_verifying_key_bytes(SigningKey*); extern void free_signing_key(SigningKey*); extern IdentityKey* new_identity_key(); +extern ErrorOrIdentityKey new_identity_key_from_bytes(ffi::bindings::BytesView); extern ffi::bindings::ErrorOrRustBytes identity_key_get_public_key( IdentityKey*); diff --git a/cc/oak_session/oak_session_bindings_test.cc b/cc/oak_session/oak_session_bindings_test.cc index 2a451ec0e57..b8e7e44255a 100644 --- a/cc/oak_session/oak_session_bindings_test.cc +++ b/cc/oak_session/oak_session_bindings_test.cc @@ -38,6 +38,11 @@ using ::oak::session::v1::SessionRequest; using ::oak::session::v1::SessionResponse; using ::testing::Eq; +constexpr absl::string_view kClientKeyBytes = + "clientkeybytes12clientkeybytes34"; +constexpr absl::string_view kServerKeyBytes = + "serverkeybytes12serverkeybytes34"; + void DoHandshake(ServerSession* server_session, ClientSession* client_session) { while (!client_is_open(client_session) && !server_is_open(server_session)) { if (!client_is_open(client_session)) { @@ -218,12 +223,15 @@ TEST(OakSessionBindingsTest, TestNNHandshake) { } TEST(OakSessionBindingsTest, TestNKHandshake) { - IdentityKey* identity_key = new_identity_key(); - ErrorOrRustBytes public_key = identity_key_get_public_key(identity_key); + ErrorOrIdentityKey identity_key = + new_identity_key_from_bytes(BytesView(kServerKeyBytes)); + ASSERT_THAT(identity_key, IsResult()); + ErrorOrRustBytes public_key = + identity_key_get_public_key(identity_key.result); ASSERT_THAT(public_key, IsResult()); ErrorOrServerSession server_session_result = - new_server_session(TestConfigUnattestedNKServer(identity_key)); + new_server_session(TestConfigUnattestedNKServer(identity_key.result)); ASSERT_THAT(server_session_result, IsResult()); ServerSession* server_session = server_session_result.result; ErrorOrClientSession client_session_result = new_client_session( @@ -239,23 +247,27 @@ TEST(OakSessionBindingsTest, TestNKHandshake) { } TEST(OakSessionBindingsTest, TestKKHandshake) { - IdentityKey* client_identity_key = new_identity_key(); - IdentityKey* server_identity_key = new_identity_key(); + ErrorOrIdentityKey client_identity_key = + new_identity_key_from_bytes(BytesView(kClientKeyBytes)); + ASSERT_THAT(client_identity_key, IsResult()); + ErrorOrIdentityKey server_identity_key = + new_identity_key_from_bytes(BytesView(kServerKeyBytes)); + ASSERT_THAT(server_identity_key, IsResult()); ErrorOrRustBytes client_public_key = - identity_key_get_public_key(client_identity_key); + identity_key_get_public_key(client_identity_key.result); ASSERT_THAT(client_public_key, IsResult()); ErrorOrRustBytes server_public_key = - identity_key_get_public_key(server_identity_key); + identity_key_get_public_key(server_identity_key.result); ASSERT_THAT(server_public_key, IsResult()); ErrorOrServerSession server_session_result = new_server_session(TestConfigUnattestedKK( - BytesView(*(client_public_key.result)), server_identity_key)); + BytesView(*(client_public_key.result)), server_identity_key.result)); ASSERT_THAT(server_session_result, IsResult()); ServerSession* server_session = server_session_result.result; ErrorOrClientSession client_session_result = new_client_session(TestConfigUnattestedKK( - BytesView(*(server_public_key.result)), client_identity_key)); + BytesView(*(server_public_key.result)), client_identity_key.result)); ASSERT_THAT(client_session_result, IsResult()); ClientSession* client_session = client_session_result.result; @@ -457,5 +469,7 @@ TEST(OakSessionBindingsTest, ErrorsAreReturned) { free_client_session(client_session); } +TEST(OakSessionBindingsTest, IncorrectKeyLenReturnsError) {} + } // namespace } // namespace oak::session::bindings diff --git a/oak_crypto/src/identity_key.rs b/oak_crypto/src/identity_key.rs index 196915f4e18..042d3bc570c 100644 --- a/oak_crypto/src/identity_key.rs +++ b/oak_crypto/src/identity_key.rs @@ -21,7 +21,7 @@ use alloc::vec::Vec; use anyhow::anyhow; -use crate::noise_handshake::{p256_scalar_mult, P256Scalar}; +use crate::noise_handshake::{client::P256_SCALAR_LEN, p256_scalar_mult, P256Scalar}; pub trait IdentityKeyHandle: Send { fn get_public_key(&self) -> anyhow::Result>; @@ -38,6 +38,10 @@ impl IdentityKey { pub fn generate() -> Self { Self { private_key: P256Scalar::generate() } } + + pub fn from_bytes(bytes: [u8; P256_SCALAR_LEN]) -> Self { + Self { private_key: P256Scalar::from_bytes(bytes) } + } } impl IdentityKeyHandle for IdentityKey { diff --git a/oak_crypto/src/noise_handshake/crypto_wrapper.rs b/oak_crypto/src/noise_handshake/crypto_wrapper.rs index e76e17920ee..440e538c74f 100644 --- a/oak_crypto/src/noise_handshake/crypto_wrapper.rs +++ b/oak_crypto/src/noise_handshake/crypto_wrapper.rs @@ -103,7 +103,11 @@ impl P256Scalar { pub fn generate() -> P256Scalar { let mut ret = [0u8; P256_SCALAR_LEN]; rand_bytes(&mut ret); - P256Scalar { v: p256::Scalar::from_repr(ret.into()).unwrap() } + P256Scalar::from_bytes(ret) + } + + pub fn from_bytes(bytes: [u8; P256_SCALAR_LEN]) -> P256Scalar { + P256Scalar { v: p256::Scalar::from_repr(bytes.into()).unwrap() } } pub fn compute_public_key(&self) -> [u8; P256_X962_LEN] { diff --git a/oak_session/ffi/config.rs b/oak_session/ffi/config.rs index f1c7220caa0..28406e1e998 100644 --- a/oak_session/ffi/config.rs +++ b/oak_session/ffi/config.rs @@ -36,7 +36,10 @@ use std::ffi::c_void; use oak_attestation_types::{attester::Attester, endorser::Endorser}; use oak_attestation_verification_types::verifier::AttestationVerifier; -use oak_crypto::identity_key::{IdentityKey, IdentityKeyHandle}; +use oak_crypto::{ + identity_key::{IdentityKey, IdentityKeyHandle}, + noise_handshake::client::P256_SCALAR_LEN, +}; use oak_ffi_bytes::BytesView; use oak_ffi_error::{Error, ErrorOrRustBytes}; use oak_session::{ @@ -294,6 +297,30 @@ pub extern "C" fn new_identity_key() -> *mut IdentityKey { Box::into_raw(Box::new(IdentityKey::generate())) } +/// Create a new IdentityKey instance from the provided bytes. +/// +/// If the functions succeeds, +/// `ErrorOrIdentity::result` will contain a pointer to the +/// [`IdentityKey`]. It should be freed by returning it to Rust via a function +/// call that reclaims ownership. +/// +/// In case of an error, `ErrorOrIdentityKey::error` will contain a poiner to +/// an error, containing a string description of the Rust error encountered. +/// The error should be freed with `oak_session_ffi_types::free_error`. +/// +/// # Safety +/// +/// * bytes is a valid, properly aligned pointer to a SessionConfigBuilder. +#[no_mangle] +pub unsafe extern "C" fn new_identity_key_from_bytes(bytes: BytesView) -> ErrorOrIdentityKey { + match <[u8; P256_SCALAR_LEN]>::try_from(bytes.as_slice()) { + Ok(bytes) => { + ErrorOrIdentityKey::ok(Box::into_raw(Box::new(IdentityKey::from_bytes(bytes)))) + } + Err(e) => ErrorOrIdentityKey::err(e.to_string()), + } +} + /// Call get_public_key on the provided IdentityKey. /// /// # Safety @@ -324,3 +351,19 @@ impl ErrorOrSessionConfigBuilder { Self { result: std::ptr::null_mut(), error: Error::new_raw(message) } } } + +#[repr(C)] +pub struct ErrorOrIdentityKey { + pub result: *mut IdentityKey, + pub error: *const Error, +} + +impl ErrorOrIdentityKey { + fn ok(result: *mut IdentityKey) -> Self { + Self { result, error: std::ptr::null() } + } + + fn err(message: impl AsRef) -> Self { + Self { result: std::ptr::null_mut(), error: Error::new_raw(message) } + } +}