From f24e62e512cc621f4a1334616dcf6f892f5dfdb8 Mon Sep 17 00:00:00 2001 From: Jason LeBrun Date: Fri, 5 Apr 2024 01:43:02 +0000 Subject: [PATCH] Update ack branch to main so builds continue succeeding. Squashed commit of the following: commit 1b24ec113436b21d12be8a2f18af879bf8258fb1 Author: Tiziano Santoro Date: Thu Apr 4 00:50:26 2024 +0100 Update nix deps (#4979) Among other things, this update xz to v. 5.4.6 commit f68df2b2e39058ff67c6b06e6155ab15b2225d21 Author: Tiziano Santoro Date: Thu Apr 4 00:46:45 2024 +0100 Align with internal linter (#4978) b/332740854 commit 8bdd77345b0573c8361ce1fbf464f243fb7c2378 Author: jblebrun Date: Wed Apr 3 21:43:20 2024 +0000 Update h2 to resolve vulnerability discovered by deny (#4977) https://rustsec.org/advisories/RUSTSEC-2024-0332 commit 5bc91beb210f6bd9fbe3e8cb02a880bfc044b6e4 Author: jul-sh Date: Wed Apr 3 16:05:06 2024 -0400 Directly issue kernel provenance for attestation measurements (#4976) * Directly issue kernel provenance attestation measurements Previously the provenance created by the SLSA builder was just for the bzImage. Not the artifact that would be measured in the attestation. With this PR the provenance subjects should include binaries measured in the attestation. Change-Id: I16e3234d0d65e3790319294c416c378cd7611681 * fix typo Change-Id: I3d078256d085ef05171e5997743d7497fc530ad0 commit 2ae6255c8e4eade70114c3bb87773729f4328a07 Author: Andri Saar Date: Tue Apr 2 20:55:51 2024 +0000 Do a page state change operation before invoking `PVALIDATE` commit 845288530b8f7daa90bd51cfdd55a3eecd55df25 Author: conradgrobler <58467069+conradgrobler@users.noreply.github.com> Date: Wed Apr 3 17:54:04 2024 +0100 Ensure CPUID triggered the #VC exception (#4974) We want to make sure that the instruction pointer in a #VC exception really pointed to a CPUID instruction since it is the only #VC exception type we support. commit 4ad534f14f49865ab412b00eb511285961ad8af8 Author: thmsbinder <129782017+thmsbinder@users.noreply.github.com> Date: Wed Apr 3 18:09:46 2024 +0200 Add and verify endorsement field for text reference value (#4973) The kernel command line reference value now follows the pattern from other reference values: skip, TR endorsement, or direct verification. When using TR endorsements in conjunction with the kernel command line the regex feature needs to be enabled. commit fa50670c417e2751b5dcd7a4a5f0ed08d7e98a41 Author: Patrick McGrath <48302523+pmcgrath17@users.noreply.github.com> Date: Tue Apr 2 10:43:22 2024 -0700 Unary gRPC transport template class (#4970) Implement unary transport class template for future Oak clients that use the unary interface. commit 65f6b4685e31bc8d520bbbdddc84bc9febbfd1a0 Author: k-naliuka <126095038+k-naliuka@users.noreply.github.com> Date: Fri Mar 29 00:33:37 2024 +0100 Add go and java options to the TcbVersion proto (#4969) commit cefb3c34b68c7bd15cf086be02477cd009d30ce7 Author: Andri Saar Date: Thu Mar 28 15:46:31 2024 +0000 Collect, and print out, some `PVALIDATE` stats in stage0 commit 579e92c257d5d42994295a0256764840aa54bf8e Author: k-naliuka <126095038+k-naliuka@users.noreply.github.com> Date: Wed Mar 27 20:49:53 2024 +0100 Refactor text reference values matching (#4965) Allow literal string comparison and make regex optional commit 121a6b0db570ef2f78c4ca0d2248dd55926a346d Author: Ivan Petrov Date: Wed Mar 27 19:13:14 2024 +0000 Sign group keys as part of Key Provisioning (#4961) This PR adds the ability to sign group keys in the attestation evidence as part of Key Provisioning. Ref https://github.com/project-oak/oak/issues/4442 commit 2a57cd6e802c9d333ce54f38a0c4c8a499656de5 Author: jul-sh Date: Wed Mar 27 12:10:56 2024 -0400 Revert "Increase the size of the certificate in Stage0 DICE data (#4946)" (#4966) This reverts commit c86964406aade767a449afd51eb10f99a0c84cb4, as it introduced a breaking change that broke imports. commit 57a8f734719715be5f119a6990fd9deebea94196 Author: Ivan Petrov Date: Wed Mar 27 15:29:07 2024 +0000 Add GroupEncryptionKeyHandle to C++ Containers SDK (#4964) Ref https://github.com/project-oak/oak/issues/4442 commit 863ee006a3154a7f1f6d82b45230c4b4f36f8c85 Author: k-naliuka <126095038+k-naliuka@users.noreply.github.com> Date: Wed Mar 27 14:15:48 2024 +0100 Include regex in Bazel oak_crates_index (#4960) commit 83d881ddf4b2fe9ced46c85891687aa8a8ee041c Author: Tiziano Santoro Date: Wed Mar 27 09:53:32 2024 +0000 Fix username and host when building kernel (#4963) b/330744888 Change-Id: Iac4a71c2d14238ccaca13c3997f47aa265a789ba --- .clang-format | 7 + Cargo.lock | 43 ++-- ...cted_kernel_simple_io_init_rd_wrapper.toml | 2 +- .../insecure_attestation_verifier.cc | 7 +- .../insecure_attestation_verifier.h | 3 +- cc/attestation/verification/utils.cc | 9 +- cc/client/client.cc | 25 ++- cc/client/client.h | 3 +- cc/client/client_test.cc | 23 +- cc/client/grpc_client_cli.cc | 10 +- .../hello_world_trusted_app/app_service.cc | 20 +- .../hello_world_trusted_app/app_service.h | 11 +- cc/containers/hello_world_trusted_app/main.cc | 6 +- cc/containers/sdk/common.h | 3 +- cc/containers/sdk/encryption_key_handle.cc | 18 +- cc/containers/sdk/encryption_key_handle.h | 13 +- cc/containers/sdk/orchestrator_client.cc | 12 +- cc/containers/sdk/orchestrator_client.h | 4 +- .../sdk/orchestrator_crypto_client.cc | 9 +- .../sdk/orchestrator_crypto_client.h | 7 +- cc/crypto/client_encryptor.cc | 30 ++- cc/crypto/client_encryptor.h | 27 ++- cc/crypto/common.h | 3 +- cc/crypto/encryption_key.cc | 6 +- cc/crypto/encryption_key.h | 3 +- cc/crypto/encryptor_test.cc | 72 +++--- cc/crypto/hpke/jni/context_jni.cc | 46 ++-- cc/crypto/hpke/jni/hpke_jni.cc | 69 +++--- cc/crypto/hpke/jni/keypair_jni.cc | 15 +- cc/crypto/hpke/recipient_context.cc | 59 ++--- cc/crypto/hpke/recipient_context.h | 14 +- cc/crypto/hpke/recipient_context_test.cc | 122 ++++++---- cc/crypto/hpke/sender_context.cc | 51 +++-- cc/crypto/hpke/sender_context.h | 25 ++- cc/crypto/hpke/sender_context_test.cc | 94 +++++--- cc/crypto/hpke/utils.cc | 22 +- cc/crypto/hpke/utils.h | 36 +-- cc/crypto/server_encryptor.cc | 39 ++-- cc/crypto/server_encryptor.h | 32 +-- cc/oak_echo_raw_enclave_app/main.cc | 4 +- cc/oak_functions/native_sdk/native_sdk.cc | 26 ++- cc/oak_functions/native_sdk/native_sdk.h | 4 +- cc/oak_functions/native_sdk/native_sdk_ffi.h | 13 +- cc/transport/BUILD | 46 ++++ cc/transport/grpc_streaming_transport.cc | 41 ++-- cc/transport/grpc_streaming_transport.h | 12 +- cc/transport/grpc_streaming_transport_test.cc | 64 ++++-- cc/transport/grpc_unary_transport.h | 86 +++++++ cc/transport/grpc_unary_transport_test.cc | 106 +++++++++ cc/transport/transport.h | 3 +- cc/transport/util.cc | 27 +++ cc/transport/util.h | 31 +++ cc/utils/cose/cose.cc | 57 +++-- cc/utils/cose/cose.h | 42 ++-- cc/utils/cose/cose_test.cc | 12 +- cc/utils/cose/cwt.cc | 22 +- cc/utils/cose/cwt.h | 7 +- cc/utils/cose/cwt_test.cc | 26 ++- flake.lock | 18 +- flake.nix | 6 + java/proto/server/secure_proxy.proto | 3 +- justfile | 9 + micro_rpc_workspace_test/Cargo.lock | 4 +- micro_rpc_workspace_test/proto/stubs.proto | 9 +- oak_attestation/src/dice.rs | 46 +++- .../tests/verifier_tests.rs | 17 +- oak_attestation_verification/BUILD | 30 +++ oak_attestation_verification/Cargo.toml | 2 +- oak_attestation_verification/src/verifier.rs | 91 +++++--- .../tests/verifier_tests.rs | 79 +++++-- oak_containers/proto/interfaces.proto | 52 +++-- oak_containers_orchestrator/src/main.rs | 67 +++--- oak_dice/src/evidence.rs | 15 +- .../proto/unary_server.proto | 21 +- oak_kernel_measurement/src/main.rs | 12 + oak_ml_transparency/runner/Cargo.lock | 1 - oak_restricted_kernel/src/interrupts.rs | 11 + oak_sev_guest/src/ghcb.rs | 131 ++++++++++- oak_sev_guest/src/instructions.rs | 2 +- proto/attestation/dice.proto | 7 +- proto/attestation/evidence.proto | 14 ++ proto/attestation/reference_value.proto | 54 +++-- proto/attestation/tcb_version.proto | 4 + .../containers/hostlib_key_provisioning.proto | 18 +- proto/containers/orchestrator_crypto.proto | 16 +- proto/crypto/crypto.proto | 9 +- proto/key_provisioning/key_provisioning.proto | 11 +- proto/micro_rpc/messages.proto | 11 +- proto/oak_functions/abi.proto | 36 +-- proto/oak_functions/application_config.proto | 9 +- .../sdk/oak_functions_wasm.proto | 24 +- .../oak_functions/service/oak_functions.proto | 32 +-- proto/session/BUILD | 1 + proto/session/messages.proto | 4 +- proto/session/service_streaming.proto | 23 +- proto/session/service_unary.proto | 3 +- stage0/src/sev.rs | 209 +++++++++++++----- 97 files changed, 1936 insertions(+), 804 deletions(-) create mode 100644 cc/transport/grpc_unary_transport.h create mode 100644 cc/transport/grpc_unary_transport_test.cc create mode 100644 cc/transport/util.cc create mode 100644 cc/transport/util.h diff --git a/.clang-format b/.clang-format index f5e3c53f7e3..c2be08feed6 100644 --- a/.clang-format +++ b/.clang-format @@ -2,3 +2,10 @@ BasedOnStyle: Google ColumnLimit: 100 DerivePointerAlignment: false PointerAlignment: Left +--- +Language: Cpp +ColumnLimit: 80 +--- +Language: Proto +ColumnLimit: 80 +--- diff --git a/Cargo.lock b/Cargo.lock index 39d18cf3d01..746de50998f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,6 +62,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "ahash" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.7" @@ -90,12 +101,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" -[[package]] -name = "allocator-api2" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" - [[package]] name = "aml" version = "0.16.4" @@ -1508,9 +1513,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" dependencies = [ "bytes", "fnv", @@ -1536,6 +1541,9 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash 0.7.8", +] [[package]] name = "hashbrown" @@ -1543,7 +1551,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" dependencies = [ - "ahash", + "ahash 0.8.7", ] [[package]] @@ -1552,8 +1560,7 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ - "ahash", - "allocator-api2", + "ahash 0.8.7", ] [[package]] @@ -1783,7 +1790,7 @@ version = "0.11.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "321f0f839cd44a4686e9504b0a62b4d69a50b62072144c71c68f5873c167b8d9" dependencies = [ - "ahash", + "ahash 0.8.7", "indexmap 2.1.0", "is-terminal", "itoa", @@ -1885,7 +1892,7 @@ dependencies = [ name = "key_value_lookup" version = "0.1.0" dependencies = [ - "hashbrown 0.14.3", + "hashbrown 0.12.3", "http", "log", "maplit", @@ -2629,7 +2636,7 @@ name = "oak_echo_service" version = "0.1.0" dependencies = [ "async-trait", - "hashbrown 0.14.3", + "hashbrown 0.12.3", "log", "micro_rpc", "micro_rpc_build", @@ -2774,7 +2781,7 @@ dependencies = [ "command-fds", "env_logger", "futures", - "hashbrown 0.14.3", + "hashbrown 0.12.3", "log", "micro_rpc", "micro_rpc_build", @@ -2839,7 +2846,7 @@ dependencies = [ "byteorder", "criterion", "criterion-macro", - "hashbrown 0.14.3", + "hashbrown 0.12.3", "log", "micro_rpc", "micro_rpc_build", @@ -2934,7 +2941,7 @@ dependencies = [ "bmrng", "clap", "command-fds", - "hashbrown 0.14.3", + "hashbrown 0.12.3", "log", "micro_rpc", "micro_rpc_build", @@ -3079,7 +3086,7 @@ name = "oak_restricted_kernel_sdk_proc_macro" version = "0.1.0" dependencies = [ "quote", - "syn 2.0.48", + "syn 1.0.109", ] [[package]] diff --git a/buildconfigs/oak_restricted_kernel_simple_io_init_rd_wrapper.toml b/buildconfigs/oak_restricted_kernel_simple_io_init_rd_wrapper.toml index 9eb661cfaff..42bfb4735d7 100644 --- a/buildconfigs/oak_restricted_kernel_simple_io_init_rd_wrapper.toml +++ b/buildconfigs/oak_restricted_kernel_simple_io_init_rd_wrapper.toml @@ -6,4 +6,4 @@ command = [ "just", "oak_restricted_kernel_simple_io_init_rd_wrapper", ] -artifact_path = "./oak_restricted_kernel_wrapper/target/x86_64-unknown-none/release/oak_restricted_kernel_simple_io_init_rd_wrapper_bin" +artifact_path = "./oak_restricted_kernel_wrapper/target/released_bin_with_components_oak_restricted_kernel_simple_io_init_rd/*" diff --git a/cc/attestation/verification/insecure_attestation_verifier.cc b/cc/attestation/verification/insecure_attestation_verifier.cc index 31646f14199..bcc03c8fcf4 100644 --- a/cc/attestation/verification/insecure_attestation_verifier.cc +++ b/cc/attestation/verification/insecure_attestation_verifier.cc @@ -34,9 +34,10 @@ using ::oak::attestation::v1::Evidence; } // namespace absl::StatusOr InsecureAttestationVerifier::Verify( - std::chrono::time_point now, const Evidence& evidence, - const Endorsements& endorsements) const { - absl::StatusOr encryption_public_key = ExtractEncryptionPublicKey(evidence); + std::chrono::time_point now, + const Evidence& evidence, const Endorsements& endorsements) const { + absl::StatusOr encryption_public_key = + ExtractEncryptionPublicKey(evidence); if (!encryption_public_key.ok()) { return encryption_public_key.status(); } diff --git a/cc/attestation/verification/insecure_attestation_verifier.h b/cc/attestation/verification/insecure_attestation_verifier.h index c812e7c47f6..e4b14ded143 100644 --- a/cc/attestation/verification/insecure_attestation_verifier.h +++ b/cc/attestation/verification/insecure_attestation_verifier.h @@ -28,7 +28,8 @@ namespace oak::attestation::verification { -// Verifier implementation that doesn't verify attestation evidence and is used for testing. +// Verifier implementation that doesn't verify attestation evidence and is used +// for testing. class InsecureAttestationVerifier : public AttestationVerifier { public: // Doesn't perform attestation verification and just returns a success value. diff --git a/cc/attestation/verification/utils.cc b/cc/attestation/verification/utils.cc index a72266849fd..934c5990e0c 100644 --- a/cc/attestation/verification/utils.cc +++ b/cc/attestation/verification/utils.cc @@ -39,12 +39,15 @@ absl::StatusOr ExtractPublicKey(absl::string_view certificate) { return std::string(public_key.begin(), public_key.end()); } -absl::StatusOr ExtractEncryptionPublicKey(const Evidence& evidence) { - return ExtractPublicKey(evidence.application_keys().encryption_public_key_certificate()); +absl::StatusOr ExtractEncryptionPublicKey( + const Evidence& evidence) { + return ExtractPublicKey( + evidence.application_keys().encryption_public_key_certificate()); } absl::StatusOr ExtractSigningPublicKey(const Evidence& evidence) { - return ExtractPublicKey(evidence.application_keys().signing_public_key_certificate()); + return ExtractPublicKey( + evidence.application_keys().signing_public_key_certificate()); } } // namespace oak::attestation::verification diff --git a/cc/client/client.cc b/cc/client/client.cc index f335cd130c0..8277c6d48a6 100644 --- a/cc/client/client.cc +++ b/cc/client/client.cc @@ -51,26 +51,29 @@ using ::oak::transport::TransportWrapper; constexpr absl::string_view kEmptyAssociatedData = ""; absl::StatusOr> OakClient::Create( - std::unique_ptr transport, AttestationVerifier& verifier) { - absl::StatusOr endorsed_evidence = transport->GetEndorsedEvidence(); + std::unique_ptr transport, + AttestationVerifier& verifier) { + absl::StatusOr endorsed_evidence = + transport->GetEndorsedEvidence(); if (!endorsed_evidence.ok()) { return endorsed_evidence.status(); } - absl::StatusOr attestation_results = - verifier.Verify(std::chrono::system_clock::now(), endorsed_evidence->evidence(), - endorsed_evidence->endorsements()); + absl::StatusOr attestation_results = verifier.Verify( + std::chrono::system_clock::now(), endorsed_evidence->evidence(), + endorsed_evidence->endorsements()); if (!attestation_results.ok()) { return attestation_results.status(); } switch (attestation_results->status()) { case AttestationResults::STATUS_SUCCESS: - return absl::WrapUnique( - new OakClient(std::move(transport), attestation_results->encryption_public_key())); + return absl::WrapUnique(new OakClient( + std::move(transport), attestation_results->encryption_public_key())); case AttestationResults::STATUS_GENERIC_FAILURE: return absl::FailedPreconditionError( - absl::StrCat("couldn't verify endorsed evidence: ", attestation_results->reason())); + absl::StrCat("couldn't verify endorsed evidence: ", + attestation_results->reason())); case AttestationResults::STATUS_UNSPECIFIED: default: return absl::InternalError("illegal status code in attestation results"); @@ -93,13 +96,15 @@ absl::StatusOr OakClient::Invoke(absl::string_view request_body) { } // Send request. - absl::StatusOr encrypted_response = transport_->Invoke(*encrypted_request); + absl::StatusOr encrypted_response = + transport_->Invoke(*encrypted_request); if (!encrypted_response.ok()) { return encrypted_response.status(); } // Decrypt response. - absl::StatusOr response = (*client_encryptor)->Decrypt(*encrypted_response); + absl::StatusOr response = + (*client_encryptor)->Decrypt(*encrypted_response); if (!response.ok()) { return response.status(); } diff --git a/cc/client/client.h b/cc/client/client.h index 427a635de0f..8e92afe6c2a 100644 --- a/cc/client/client.h +++ b/cc/client/client.h @@ -47,7 +47,8 @@ class OakClient { private: std::unique_ptr transport_; - // TODO(#4157): Store client encryptor once crypto sessions are implemented on the server. + // TODO(#4157): Store client encryptor once crypto sessions are implemented on + // the server. std::string server_encryption_public_key_; OakClient(std::unique_ptr transport, diff --git a/cc/client/client_test.cc b/cc/client/client_test.cc index 75d2b7b0e58..d7bd7b5ff12 100644 --- a/cc/client/client_test.cc +++ b/cc/client/client_test.cc @@ -65,15 +65,19 @@ class OakClientTest : public testing::Test { std::shared_ptr encryption_key_; }; -// TODO(#3641): Send test remote attestation report to the client and add corresponding tests. +// TODO(#3641): Send test remote attestation report to the client and add +// corresponding tests. class TestTransport : public TransportWrapper { public: explicit TestTransport(std::shared_ptr encryption_key) : encryption_key_(encryption_key) {} - absl::StatusOr GetEndorsedEvidence() override { return EndorsedEvidence(); } + absl::StatusOr GetEndorsedEvidence() override { + return EndorsedEvidence(); + } - absl::StatusOr Invoke(const EncryptedRequest& encrypted_request) override { + absl::StatusOr Invoke( + const EncryptedRequest& encrypted_request) override { ServerEncryptor server_encryptor = ServerEncryptor(*encryption_key_); auto decrypted_request = server_encryptor.Decrypt(encrypted_request); if (!decrypted_request.ok()) { @@ -81,9 +85,10 @@ class TestTransport : public TransportWrapper { } if (decrypted_request->plaintext != kTestRequest) { - return absl::InvalidArgumentError(std::string("incorrect request, expected: ") + - std::string(kTestRequest) + - ", got : " + decrypted_request->plaintext); + return absl::InvalidArgumentError( + std::string("incorrect request, expected: ") + + std::string(kTestRequest) + + ", got : " + decrypted_request->plaintext); } return server_encryptor.Encrypt(kTestResponse, kTestAssociatedData); @@ -95,11 +100,13 @@ class TestTransport : public TransportWrapper { class TestAttestationVerifier : public AttestationVerifier { public: - explicit TestAttestationVerifier(std::shared_ptr encryption_key) + explicit TestAttestationVerifier( + std::shared_ptr encryption_key) : encryption_key_(encryption_key) {} absl::StatusOr<::oak::attestation::v1::AttestationResults> Verify( - std::chrono::time_point now, const Evidence& evidence, + std::chrono::time_point now, + const Evidence& evidence, const Endorsements& endorsements) const override { AttestationResults attestation_results; attestation_results.set_status(AttestationResults::STATUS_SUCCESS); diff --git a/cc/client/grpc_client_cli.cc b/cc/client/grpc_client_cli.cc index 61a50b165b8..74768074ef2 100644 --- a/cc/client/grpc_client_cli.cc +++ b/cc/client/grpc_client_cli.cc @@ -52,17 +52,19 @@ int main(int argc, char* argv[]) { // Create gRPC client stub. LOG(INFO) << "connecting to: " << address; - std::shared_ptr channel = CreateChannel(address, InsecureChannelCredentials()); + std::shared_ptr channel = + CreateChannel(address, InsecureChannelCredentials()); std::shared_ptr stub = StreamingSession::NewStub(channel); ClientContext context; - std::unique_ptr> channel_reader_writer = - stub->Stream(&context); + std::unique_ptr> + channel_reader_writer = stub->Stream(&context); // Create Oak Client. LOG(INFO) << "creating Oak Client"; std::unique_ptr transport = - std::make_unique(std::move(channel_reader_writer)); + std::make_unique( + std::move(channel_reader_writer)); InsecureAttestationVerifier verifier = InsecureAttestationVerifier(); absl::StatusOr> oak_client = OakClient::Create(std::move(transport), verifier); diff --git a/cc/containers/hello_world_trusted_app/app_service.cc b/cc/containers/hello_world_trusted_app/app_service.cc index 087f25b69a2..e8e6df4b524 100644 --- a/cc/containers/hello_world_trusted_app/app_service.cc +++ b/cc/containers/hello_world_trusted_app/app_service.cc @@ -38,23 +38,27 @@ using ::oak::crypto::v1::EncryptedResponse; constexpr absl::string_view kEmptyAssociatedData = ""; grpc::Status TrustedApplicationImpl::Hello(grpc::ServerContext* context, - const HelloRequest* request, HelloResponse* response) { + const HelloRequest* request, + HelloResponse* response) { ServerEncryptor server_encryptor(*encryption_key_handle_); absl::StatusOr decrypted_request = server_encryptor.Decrypt(request->encrypted_request()); if (!decrypted_request.ok()) { - return grpc::Status(static_cast(decrypted_request.status().code()), - std::string(decrypted_request.status().message())); + return grpc::Status( + static_cast(decrypted_request.status().code()), + std::string(decrypted_request.status().message())); } - std::string greeting = absl::StrCat("Hello from the trusted side, ", decrypted_request->plaintext, - "! Btw, the Trusted App has a config with a length of ", - application_config_.size(), " bytes."); + std::string greeting = absl::StrCat( + "Hello from the trusted side, ", decrypted_request->plaintext, + "! Btw, the Trusted App has a config with a length of ", + application_config_.size(), " bytes."); absl::StatusOr encrypted_response = server_encryptor.Encrypt(greeting, kEmptyAssociatedData); if (!encrypted_response.ok()) { - return grpc::Status(static_cast(encrypted_response.status().code()), - std::string(encrypted_response.status().message())); + return grpc::Status( + static_cast(encrypted_response.status().code()), + std::string(encrypted_response.status().message())); } *response->mutable_encrypted_response() = *std::move(encrypted_response); diff --git a/cc/containers/hello_world_trusted_app/app_service.h b/cc/containers/hello_world_trusted_app/app_service.h index b2abf1de4d2..5a66a3dc866 100644 --- a/cc/containers/hello_world_trusted_app/app_service.h +++ b/cc/containers/hello_world_trusted_app/app_service.h @@ -28,14 +28,17 @@ namespace oak::oak_containers_hello_world_trusted_app { -class TrustedApplicationImpl : public containers::example::TrustedApplication::Service { +class TrustedApplicationImpl + : public containers::example::TrustedApplication::Service { public: - TrustedApplicationImpl(std::unique_ptr<::oak::crypto::EncryptionKeyHandle> encryption_key_handle, - absl::string_view application_config) + TrustedApplicationImpl( + std::unique_ptr<::oak::crypto::EncryptionKeyHandle> encryption_key_handle, + absl::string_view application_config) : encryption_key_handle_(std::move(encryption_key_handle)), application_config_(application_config) {} - grpc::Status Hello(grpc::ServerContext* context, const containers::example::HelloRequest* request, + grpc::Status Hello(grpc::ServerContext* context, + const containers::example::HelloRequest* request, containers::example::HelloResponse* response) override; private: diff --git a/cc/containers/hello_world_trusted_app/main.cc b/cc/containers/hello_world_trusted_app/main.cc index db112b03961..eeccd701d75 100644 --- a/cc/containers/hello_world_trusted_app/main.cc +++ b/cc/containers/hello_world_trusted_app/main.cc @@ -33,10 +33,12 @@ int main(int argc, char* argv[]) { absl::InitializeLog(); OrchestratorClient client; - absl::StatusOr application_config = client.GetApplicationConfig(); + absl::StatusOr application_config = + client.GetApplicationConfig(); QCHECK_OK(application_config); TrustedApplicationImpl service( - std::make_unique<::oak::containers::sdk::InstanceEncryptionKeyHandle>(), *application_config); + std::make_unique<::oak::containers::sdk::InstanceEncryptionKeyHandle>(), + *application_config); grpc::ServerBuilder builder; builder.AddListeningPort("[::]:8080", grpc::InsecureServerCredentials()); diff --git a/cc/containers/sdk/common.h b/cc/containers/sdk/common.h index 9078df3c81d..9b664401699 100644 --- a/cc/containers/sdk/common.h +++ b/cc/containers/sdk/common.h @@ -20,7 +20,8 @@ namespace oak::containers::sdk { // Unix socket used to connect to the Orchestrator. -inline static const char kOrchestratorSocket[] = "unix:/oak_utils/orchestrator_ipc"; +inline static const char kOrchestratorSocket[] = + "unix:/oak_utils/orchestrator_ipc"; inline static const char kContextAuthority[] = "[::]:0"; diff --git a/cc/containers/sdk/encryption_key_handle.cc b/cc/containers/sdk/encryption_key_handle.cc index c28c5ffbeb9..e595a12a21d 100644 --- a/cc/containers/sdk/encryption_key_handle.cc +++ b/cc/containers/sdk/encryption_key_handle.cc @@ -35,8 +35,22 @@ using ::oak::crypto::v1::SessionKeys; absl::StatusOr> InstanceEncryptionKeyHandle::GenerateRecipientContext( absl::string_view serialized_encapsulated_public_key) { - absl::StatusOr session_keys = orchestrator_crypto_client_.DeriveSessionKeys( - KeyOrigin::INSTANCE, serialized_encapsulated_public_key); + absl::StatusOr session_keys = + orchestrator_crypto_client_.DeriveSessionKeys( + KeyOrigin::INSTANCE, serialized_encapsulated_public_key); + if (!session_keys.ok()) { + return absl::InternalError("couldn't derive session keys"); + } + + return RecipientContext::Deserialize(*session_keys); +} + +absl::StatusOr> +GroupEncryptionKeyHandle::GenerateRecipientContext( + absl::string_view serialized_encapsulated_public_key) { + absl::StatusOr session_keys = + orchestrator_crypto_client_.DeriveSessionKeys( + KeyOrigin::GROUP, serialized_encapsulated_public_key); if (!session_keys.ok()) { return absl::InternalError("couldn't derive session keys"); } diff --git a/cc/containers/sdk/encryption_key_handle.h b/cc/containers/sdk/encryption_key_handle.h index a2513324c66..16bed3efb61 100644 --- a/cc/containers/sdk/encryption_key_handle.h +++ b/cc/containers/sdk/encryption_key_handle.h @@ -30,7 +30,18 @@ namespace oak::containers::sdk { class InstanceEncryptionKeyHandle : public ::oak::crypto::EncryptionKeyHandle { public: - absl::StatusOr> GenerateRecipientContext( + absl::StatusOr> + GenerateRecipientContext( + absl::string_view serialized_encapsulated_public_key) override; + + private: + OrchestratorCryptoClient orchestrator_crypto_client_; +}; + +class GroupEncryptionKeyHandle : public ::oak::crypto::EncryptionKeyHandle { + public: + absl::StatusOr> + GenerateRecipientContext( absl::string_view serialized_encapsulated_public_key) override; private: diff --git a/cc/containers/sdk/orchestrator_client.cc b/cc/containers/sdk/orchestrator_client.cc index 5547220e506..9ba704d3f57 100644 --- a/cc/containers/sdk/orchestrator_client.cc +++ b/cc/containers/sdk/orchestrator_client.cc @@ -38,8 +38,10 @@ absl::StatusOr OrchestratorClient::GetApplicationConfig() const { grpc::ClientContext context; context.set_authority(kContextAuthority); GetApplicationConfigResponse response; - if (auto status = stub_->GetApplicationConfig(&context, {}, &response); !status.ok()) { - return absl::Status(static_cast(status.error_code()), status.error_message()); + if (auto status = stub_->GetApplicationConfig(&context, {}, &response); + !status.ok()) { + return absl::Status(static_cast(status.error_code()), + status.error_message()); } return std::move(*response.mutable_config()); } @@ -48,8 +50,10 @@ absl::Status OrchestratorClient::NotifyAppReady() const { grpc::ClientContext context; context.set_authority(kContextAuthority); google::protobuf::Empty response; - if (auto status = stub_->NotifyAppReady(&context, {}, &response); !status.ok()) { - return absl::Status(static_cast(status.error_code()), status.error_message()); + if (auto status = stub_->NotifyAppReady(&context, {}, &response); + !status.ok()) { + return absl::Status(static_cast(status.error_code()), + status.error_message()); } return absl::OkStatus(); } diff --git a/cc/containers/sdk/orchestrator_client.h b/cc/containers/sdk/orchestrator_client.h index 396f85865eb..645219c67b5 100644 --- a/cc/containers/sdk/orchestrator_client.h +++ b/cc/containers/sdk/orchestrator_client.h @@ -34,8 +34,8 @@ namespace oak::containers::sdk { class OrchestratorClient { public: OrchestratorClient() - : OrchestratorClient( - grpc::CreateChannel(kOrchestratorSocket, grpc::InsecureChannelCredentials())) {} + : OrchestratorClient(grpc::CreateChannel( + kOrchestratorSocket, grpc::InsecureChannelCredentials())) {} absl::StatusOr GetApplicationConfig() const; absl::Status NotifyAppReady() const; diff --git a/cc/containers/sdk/orchestrator_crypto_client.cc b/cc/containers/sdk/orchestrator_crypto_client.cc index f7dfbc8a850..04c848e3ba1 100644 --- a/cc/containers/sdk/orchestrator_crypto_client.cc +++ b/cc/containers/sdk/orchestrator_crypto_client.cc @@ -39,15 +39,18 @@ using ::oak::crypto::v1::SessionKeys; } // namespace absl::StatusOr OrchestratorCryptoClient::DeriveSessionKeys( - KeyOrigin key_origin, absl::string_view serialized_encapsulated_public_key) const { + KeyOrigin key_origin, + absl::string_view serialized_encapsulated_public_key) const { ClientContext context; context.set_authority(kContextAuthority); DeriveSessionKeysRequest request; request.set_key_origin(key_origin); - request.set_serialized_encapsulated_public_key(serialized_encapsulated_public_key); + request.set_serialized_encapsulated_public_key( + serialized_encapsulated_public_key); DeriveSessionKeysResponse response; - ::grpc::Status status = stub_->DeriveSessionKeys(&context, request, &response); + ::grpc::Status status = + stub_->DeriveSessionKeys(&context, request, &response); if (!status.ok()) { return absl::InternalError("couldn't derive session keys"); } diff --git a/cc/containers/sdk/orchestrator_crypto_client.h b/cc/containers/sdk/orchestrator_crypto_client.h index 22e0eff89b8..483ad8ca535 100644 --- a/cc/containers/sdk/orchestrator_crypto_client.h +++ b/cc/containers/sdk/orchestrator_crypto_client.h @@ -33,15 +33,16 @@ namespace oak::containers::sdk { class OrchestratorCryptoClient { public: OrchestratorCryptoClient() - : OrchestratorCryptoClient( - grpc::CreateChannel(kOrchestratorSocket, grpc::InsecureChannelCredentials())) {} + : OrchestratorCryptoClient(grpc::CreateChannel( + kOrchestratorSocket, grpc::InsecureChannelCredentials())) {} absl::StatusOr<::oak::crypto::v1::SessionKeys> DeriveSessionKeys( ::oak::containers::v1::KeyOrigin key_origin, absl::string_view serialized_encapsulated_public_key) const; private: - explicit OrchestratorCryptoClient(const std::shared_ptr& channel) + explicit OrchestratorCryptoClient( + const std::shared_ptr& channel) : stub_(::oak::containers::v1::OrchestratorCrypto::NewStub(channel)) {} std::unique_ptr<::oak::containers::v1::OrchestratorCrypto::Stub> stub_; diff --git a/cc/crypto/client_encryptor.cc b/cc/crypto/client_encryptor.cc index e19fdd7b6a1..ae888dea31e 100644 --- a/cc/crypto/client_encryptor.cc +++ b/cc/crypto/client_encryptor.cc @@ -45,8 +45,8 @@ absl::StatusOr> ClientEncryptor::Create( return std::make_unique(std::move(*sender_context)); } -absl::StatusOr ClientEncryptor::Encrypt(absl::string_view plaintext, - absl::string_view associated_data) { +absl::StatusOr ClientEncryptor::Encrypt( + absl::string_view plaintext, absl::string_view associated_data) { // Encrypt request. absl::StatusOr> nonce = GenerateRandomNonce(); if (!nonce.ok()) { @@ -62,10 +62,13 @@ absl::StatusOr ClientEncryptor::Encrypt(absl::string_view plai EncryptedRequest encrypted_request; *encrypted_request.mutable_encrypted_message()->mutable_nonce() = std::string(nonce->begin(), nonce->end()); - *encrypted_request.mutable_encrypted_message()->mutable_ciphertext() = *ciphertext; - *encrypted_request.mutable_encrypted_message()->mutable_associated_data() = associated_data; + *encrypted_request.mutable_encrypted_message()->mutable_ciphertext() = + *ciphertext; + *encrypted_request.mutable_encrypted_message()->mutable_associated_data() = + associated_data; - // Encapsulated public key is only sent in the initial request message of the session. + // Encapsulated public key is only sent in the initial request message of the + // session. if (!serialized_encapsulated_public_key_has_been_sent_) { *encrypted_request.mutable_serialized_encapsulated_public_key() = sender_context_->GetSerializedEncapsulatedPublicKey(); @@ -75,18 +78,21 @@ absl::StatusOr ClientEncryptor::Encrypt(absl::string_view plai return encrypted_request; } -absl::StatusOr ClientEncryptor::Decrypt(EncryptedResponse encrypted_response) { +absl::StatusOr ClientEncryptor::Decrypt( + EncryptedResponse encrypted_response) { // Decrypt response. - const std::vector nonce(encrypted_response.encrypted_message().nonce().begin(), - encrypted_response.encrypted_message().nonce().end()); - absl::StatusOr plaintext = - sender_context_->Open(nonce, encrypted_response.encrypted_message().ciphertext(), - encrypted_response.encrypted_message().associated_data()); + const std::vector nonce( + encrypted_response.encrypted_message().nonce().begin(), + encrypted_response.encrypted_message().nonce().end()); + absl::StatusOr plaintext = sender_context_->Open( + nonce, encrypted_response.encrypted_message().ciphertext(), + encrypted_response.encrypted_message().associated_data()); if (!plaintext.ok()) { return plaintext.status(); } - return DecryptionResult{*plaintext, encrypted_response.encrypted_message().associated_data()}; + return DecryptionResult{ + *plaintext, encrypted_response.encrypted_message().associated_data()}; } } // namespace oak::crypto diff --git a/cc/crypto/client_encryptor.h b/cc/crypto/client_encryptor.h index e0a2b61f7ba..a9902e17d08 100644 --- a/cc/crypto/client_encryptor.h +++ b/cc/crypto/client_encryptor.h @@ -29,21 +29,23 @@ namespace oak::crypto { -// Encryptor class for encrypting client requests that will be sent to the server and decrypting -// server responses that are received by the client. Each Encryptor corresponds to a single crypto -// session between the client and the server. +// Encryptor class for encrypting client requests that will be sent to the +// server and decrypting server responses that are received by the client. Each +// Encryptor corresponds to a single crypto session between the client and the +// server. // -// Sequence numbers for requests and responses are incremented separately, meaning that there could -// be multiple responses per request and multiple requests per response. +// Sequence numbers for requests and responses are incremented separately, +// meaning that there could be multiple responses per request and multiple +// requests per response. class ClientEncryptor { public: // Creates a new instance of [`ClientEncryptor`]. - // The corresponding encryption and decryption keys are generated using the server public key with - // Hybrid Public Key Encryption (HPKE). + // The corresponding encryption and decryption keys are generated using the + // server public key with Hybrid Public Key Encryption (HPKE). // // - // `serialized_server_public_key` must be a NIST P-256 SEC1 encoded point public key. - // + // `serialized_server_public_key` must be a NIST P-256 SEC1 encoded point + // public key. static absl::StatusOr> Create( absl::string_view serialized_server_public_key); @@ -56,14 +58,15 @@ class ClientEncryptor { // // // Returns an [`oak.crypto.EncryptedRequest`] proto message. - absl::StatusOr<::oak::crypto::v1::EncryptedRequest> Encrypt(absl::string_view plaintext, - absl::string_view associated_data); + absl::StatusOr<::oak::crypto::v1::EncryptedRequest> Encrypt( + absl::string_view plaintext, absl::string_view associated_data); // Decrypts a [`EncryptedResponse`] proto message using AEAD. // // // Returns a response message plaintext and associated data. - absl::StatusOr Decrypt(oak::crypto::v1::EncryptedResponse encrypted_response); + absl::StatusOr Decrypt( + oak::crypto::v1::EncryptedResponse encrypted_response); private: // Encapsulated public key needed to establish a symmetric session key. diff --git a/cc/crypto/common.h b/cc/crypto/common.h index a7fc75f6b45..226bc8c34d4 100644 --- a/cc/crypto/common.h +++ b/cc/crypto/common.h @@ -22,7 +22,8 @@ namespace oak::crypto { // Info string used by Hybrid Public Key Encryption. -inline constexpr absl::string_view kOakHPKEInfo = "Oak Hybrid Public Key Encryption v1"; +inline constexpr absl::string_view kOakHPKEInfo = + "Oak Hybrid Public Key Encryption v1"; struct DecryptionResult { std::string plaintext; diff --git a/cc/crypto/encryption_key.cc b/cc/crypto/encryption_key.cc index 3b3f78805c9..b5cd7a029db 100644 --- a/cc/crypto/encryption_key.cc +++ b/cc/crypto/encryption_key.cc @@ -33,9 +33,11 @@ absl::StatusOr EncryptionKeyProvider::Create() { return EncryptionKeyProvider(*key_pair); } -absl::StatusOr> EncryptionKeyProvider::GenerateRecipientContext( +absl::StatusOr> +EncryptionKeyProvider::GenerateRecipientContext( absl::string_view serialized_encapsulated_public_key) { - return SetupBaseRecipient(serialized_encapsulated_public_key, key_pair_, kOakHPKEInfo); + return SetupBaseRecipient(serialized_encapsulated_public_key, key_pair_, + kOakHPKEInfo); } } // namespace oak::crypto diff --git a/cc/crypto/encryption_key.h b/cc/crypto/encryption_key.h index 04516f6f7a8..ba71b331a5f 100644 --- a/cc/crypto/encryption_key.h +++ b/cc/crypto/encryption_key.h @@ -29,7 +29,8 @@ namespace oak::crypto { class EncryptionKeyHandle { public: - virtual absl::StatusOr> GenerateRecipientContext( + virtual absl::StatusOr> + GenerateRecipientContext( absl::string_view serialized_encapsulated_public_key) = 0; virtual ~EncryptionKeyHandle() = default; diff --git a/cc/crypto/encryptor_test.cc b/cc/crypto/encryptor_test.cc index 4e1b4eac87d..864c1582e97 100644 --- a/cc/crypto/encryptor_test.cc +++ b/cc/crypto/encryptor_test.cc @@ -29,7 +29,8 @@ namespace { using ::testing::StrEq; -constexpr absl::string_view kOakHPKEInfoTest = "Oak Hybrid Public Key Encryption Test"; +constexpr absl::string_view kOakHPKEInfoTest = + "Oak Hybrid Public Key Encryption Test"; // Client Encryptor and Server Encryptor can communicate. TEST(EncryptorTest, ClientEncryptorAndServerEncryptorCommunicateSuccess) { // Set up client and server encryptors. @@ -40,50 +41,66 @@ TEST(EncryptorTest, ClientEncryptorAndServerEncryptorCommunicateSuccess) { ASSERT_TRUE(client_encryptor.ok()); ServerEncryptor server_encryptor = ServerEncryptor(*encryption_key); - // Here we have the client send 2 encrypted messages to the server to ensure that nonce's align - // for multi-message communication. + // Here we have the client send 2 encrypted messages to the server to ensure + // that nonce's align for multi-message communication. std::string client_plaintext_message1 = "Hello server"; // Encrypt plaintext message and have server encryptor decrypt message. auto client_ciphertext1 = (*client_encryptor)->Encrypt(client_plaintext_message1, kOakHPKEInfoTest); ASSERT_TRUE(client_ciphertext1.ok()); - auto server_decryption_result1 = server_encryptor.Decrypt(*client_ciphertext1); + auto server_decryption_result1 = + server_encryptor.Decrypt(*client_ciphertext1); ASSERT_TRUE(server_decryption_result1.ok()); - EXPECT_THAT(client_plaintext_message1, StrEq(server_decryption_result1->plaintext)); - EXPECT_THAT(kOakHPKEInfoTest, StrEq(server_decryption_result1->associated_data)); + EXPECT_THAT(client_plaintext_message1, + StrEq(server_decryption_result1->plaintext)); + EXPECT_THAT(kOakHPKEInfoTest, + StrEq(server_decryption_result1->associated_data)); std::string client_plaintext_message2 = "Hello again, server!"; auto client_ciphertext2 = (*client_encryptor)->Encrypt(client_plaintext_message2, kOakHPKEInfoTest); ASSERT_TRUE(client_ciphertext2.ok()); - auto server_decryption_result2 = server_encryptor.Decrypt(*client_ciphertext2); + auto server_decryption_result2 = + server_encryptor.Decrypt(*client_ciphertext2); ASSERT_TRUE(server_decryption_result2.ok()); - EXPECT_THAT(client_plaintext_message2, StrEq(server_decryption_result2->plaintext)); - EXPECT_THAT(kOakHPKEInfoTest, StrEq(server_decryption_result2->associated_data)); + EXPECT_THAT(client_plaintext_message2, + StrEq(server_decryption_result2->plaintext)); + EXPECT_THAT(kOakHPKEInfoTest, + StrEq(server_decryption_result2->associated_data)); - // We have the server send 2 encrypted messages back to the client. Again this is to ensure the - // nonce's align for the multiple messages. + // We have the server send 2 encrypted messages back to the client. Again this + // is to ensure the nonce's align for the multiple messages. std::string server_plaintext_message1 = "Hello client"; - // Server responds with an encrypted message that the client successfully decrypts. - auto server_ciphertext1 = server_encryptor.Encrypt(server_plaintext_message1, kOakHPKEInfoTest); + // Server responds with an encrypted message that the client successfully + // decrypts. + auto server_ciphertext1 = + server_encryptor.Encrypt(server_plaintext_message1, kOakHPKEInfoTest); ASSERT_TRUE(server_ciphertext1.ok()); - auto client_decryption_result1 = (*client_encryptor)->Decrypt(*server_ciphertext1); + auto client_decryption_result1 = + (*client_encryptor)->Decrypt(*server_ciphertext1); ASSERT_TRUE(client_decryption_result1.ok()); - EXPECT_THAT(server_plaintext_message1, StrEq(client_decryption_result1->plaintext)); - EXPECT_THAT(kOakHPKEInfoTest, StrEq(client_decryption_result1->associated_data)); + EXPECT_THAT(server_plaintext_message1, + StrEq(client_decryption_result1->plaintext)); + EXPECT_THAT(kOakHPKEInfoTest, + StrEq(client_decryption_result1->associated_data)); std::string server_plaintext_message2 = "Hello again, client!"; - auto server_ciphertext2 = server_encryptor.Encrypt(server_plaintext_message2, kOakHPKEInfoTest); + auto server_ciphertext2 = + server_encryptor.Encrypt(server_plaintext_message2, kOakHPKEInfoTest); ASSERT_TRUE(server_ciphertext2.ok()); - auto client_decryption_result2 = (*client_encryptor)->Decrypt(*server_ciphertext2); + auto client_decryption_result2 = + (*client_encryptor)->Decrypt(*server_ciphertext2); ASSERT_TRUE(client_decryption_result2.ok()); - EXPECT_THAT(server_plaintext_message2, StrEq(client_decryption_result2->plaintext)); - EXPECT_THAT(kOakHPKEInfoTest, StrEq(client_decryption_result2->associated_data)); + EXPECT_THAT(server_plaintext_message2, + StrEq(client_decryption_result2->plaintext)); + EXPECT_THAT(kOakHPKEInfoTest, + StrEq(client_decryption_result2->associated_data)); } -TEST(EncryptorTest, ClientEncryptorAndServerEncryptorCommunicateMismatchPublicKeysFailure) { +TEST(EncryptorTest, + ClientEncryptorAndServerEncryptorCommunicateMismatchPublicKeysFailure) { // Set up client and server encryptors. auto encryption_key = EncryptionKeyProvider::Create(); ASSERT_TRUE(encryption_key.ok()); @@ -96,14 +113,17 @@ TEST(EncryptorTest, ClientEncryptorAndServerEncryptorCommunicateMismatchPublicKe std::string client_plaintext_message = "Hello server"; - // Encrypt plaintext message and have server encryptor decrypt message. This should result in - // failure since the public key is incorrect. - auto client_ciphertext = (*client_encryptor)->Encrypt(client_plaintext_message, kOakHPKEInfoTest); + // Encrypt plaintext message and have server encryptor decrypt message. This + // should result in failure since the public key is incorrect. + auto client_ciphertext = + (*client_encryptor)->Encrypt(client_plaintext_message, kOakHPKEInfoTest); ASSERT_TRUE(client_ciphertext.ok()); auto server_decryption_result = server_encryptor.Decrypt(*client_ciphertext); EXPECT_FALSE(server_decryption_result.ok()); - EXPECT_EQ(server_decryption_result.status().code(), absl::StatusCode::kAborted); - EXPECT_THAT(server_decryption_result.status().message(), StrEq("Unable to decrypt message")); + EXPECT_EQ(server_decryption_result.status().code(), + absl::StatusCode::kAborted); + EXPECT_THAT(server_decryption_result.status().message(), + StrEq("Unable to decrypt message")); } } // namespace diff --git a/cc/crypto/hpke/jni/context_jni.cc b/cc/crypto/hpke/jni/context_jni.cc index 7eda9095c8b..018164b2029 100644 --- a/cc/crypto/hpke/jni/context_jni.cc +++ b/cc/crypto/hpke/jni/context_jni.cc @@ -21,15 +21,18 @@ #include "com_google_oak_crypto_hpke_SenderContext.h" #include "jni_helper.h" -JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_SenderContext_nativeSeal( - JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray plaintext, jbyteArray associated_data) { +JNIEXPORT jbyteArray JNICALL +Java_com_google_oak_crypto_hpke_SenderContext_nativeSeal( + JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray plaintext, + jbyteArray associated_data) { if (nonce == NULL || plaintext == NULL || associated_data == NULL) { return {}; } std::string nonce_str = convert_jbytearray_to_string(env, nonce); std::string plaintext_str = convert_jbytearray_to_string(env, plaintext); - std::string associated_data_str = convert_jbytearray_to_string(env, associated_data); + std::string associated_data_str = + convert_jbytearray_to_string(env, associated_data); jclass sender_context_class = env->GetObjectClass(obj); jfieldID fid = env->GetFieldID(sender_context_class, "nativePtr", "J"); @@ -52,15 +55,18 @@ JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_SenderContext_nativ return ret; } -JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_SenderContext_nativeOpen( - JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray ciphertext, jbyteArray associated_data) { +JNIEXPORT jbyteArray JNICALL +Java_com_google_oak_crypto_hpke_SenderContext_nativeOpen( + JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray ciphertext, + jbyteArray associated_data) { if (ciphertext == NULL || associated_data == NULL) { return {}; } std::string nonce_str = convert_jbytearray_to_string(env, nonce); std::string ciphertext_str = convert_jbytearray_to_string(env, ciphertext); - std::string associated_data_str = convert_jbytearray_to_string(env, associated_data); + std::string associated_data_str = + convert_jbytearray_to_string(env, associated_data); jclass sender_context_class = env->GetObjectClass(obj); jfieldID fid = env->GetFieldID(sender_context_class, "nativePtr", "J"); @@ -83,15 +89,18 @@ JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_SenderContext_nativ return ret; } -JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_RecipientContext_nativeOpen( - JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray ciphertext, jbyteArray associated_data) { +JNIEXPORT jbyteArray JNICALL +Java_com_google_oak_crypto_hpke_RecipientContext_nativeOpen( + JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray ciphertext, + jbyteArray associated_data) { if (ciphertext == NULL || associated_data == NULL) { return {}; } std::string nonce_str = convert_jbytearray_to_string(env, nonce); std::string ciphertext_str = convert_jbytearray_to_string(env, ciphertext); - std::string associated_data_str = convert_jbytearray_to_string(env, associated_data); + std::string associated_data_str = + convert_jbytearray_to_string(env, associated_data); jclass recipient_context_class = env->GetObjectClass(obj); jfieldID fid = env->GetFieldID(recipient_context_class, "nativePtr", "J"); @@ -114,15 +123,18 @@ JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_RecipientContext_na return ret; } -JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_RecipientContext_nativeSeal( - JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray plaintext, jbyteArray associated_data) { +JNIEXPORT jbyteArray JNICALL +Java_com_google_oak_crypto_hpke_RecipientContext_nativeSeal( + JNIEnv* env, jobject obj, jbyteArray nonce, jbyteArray plaintext, + jbyteArray associated_data) { if (nonce == NULL || plaintext == NULL || associated_data == NULL) { return {}; } std::string nonce_str = convert_jbytearray_to_string(env, nonce); std::string plaintext_str = convert_jbytearray_to_string(env, plaintext); - std::string associated_data_str = convert_jbytearray_to_string(env, associated_data); + std::string associated_data_str = + convert_jbytearray_to_string(env, associated_data); jclass recipient_context_class = env->GetObjectClass(obj); jfieldID fid = env->GetFieldID(recipient_context_class, "nativePtr", "J"); @@ -145,8 +157,9 @@ JNIEXPORT jbyteArray JNICALL Java_com_google_oak_crypto_hpke_RecipientContext_na return ret; } -JNIEXPORT void JNICALL Java_com_google_oak_crypto_hpke_SenderContext_nativeDestroy(JNIEnv* env, - jobject obj) { +JNIEXPORT void JNICALL +Java_com_google_oak_crypto_hpke_SenderContext_nativeDestroy(JNIEnv* env, + jobject obj) { jclass context_class = env->GetObjectClass(obj); jfieldID fid = env->GetFieldID(context_class, "nativePtr", "J"); oak::crypto::SenderContext* sender_context = @@ -159,8 +172,9 @@ JNIEXPORT void JNICALL Java_com_google_oak_crypto_hpke_SenderContext_nativeDestr env->SetLongField(obj, fid, 0); } -JNIEXPORT void JNICALL Java_com_google_oak_crypto_hpke_RecipientContext_nativeDestroy(JNIEnv* env, - jobject obj) { +JNIEXPORT void JNICALL +Java_com_google_oak_crypto_hpke_RecipientContext_nativeDestroy(JNIEnv* env, + jobject obj) { jclass context_class = env->GetObjectClass(obj); jfieldID fid = env->GetFieldID(context_class, "nativePtr", "J"); oak::crypto::RecipientContext* recipient_context = diff --git a/cc/crypto/hpke/jni/hpke_jni.cc b/cc/crypto/hpke/jni/hpke_jni.cc index aca6856c6da..179d5afe458 100644 --- a/cc/crypto/hpke/jni/hpke_jni.cc +++ b/cc/crypto/hpke/jni/hpke_jni.cc @@ -23,8 +23,10 @@ #include "com_google_oak_crypto_hpke_Hpke.h" #include "jni_helper.h" -JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseSender( - JNIEnv* env, jclass obj, jbyteArray serialized_recipient_public_key, jbyteArray info) { +JNIEXPORT jobject JNICALL +Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseSender( + JNIEnv* env, jclass obj, jbyteArray serialized_recipient_public_key, + jbyteArray info) { if (serialized_recipient_public_key == NULL || info == NULL) { return {}; } @@ -33,8 +35,9 @@ JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseSe convert_jbytearray_to_string(env, serialized_recipient_public_key); std::string info_str = convert_jbytearray_to_string(env, info); - absl::StatusOr> native_sender_context = - oak::crypto::SetupBaseSender(serialized_recipient_public_key_str, info_str); + absl::StatusOr> + native_sender_context = oak::crypto::SetupBaseSender( + serialized_recipient_public_key_str, info_str); if (!native_sender_context.ok()) { return {}; } @@ -44,23 +47,30 @@ JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseSe jbyteArray serialized_encapsulated_public_key = env->NewByteArray(serialized_encapsulated_public_key_len); env->SetByteArrayRegion( - serialized_encapsulated_public_key, 0, serialized_encapsulated_public_key_len, - reinterpret_cast( - (*native_sender_context)->GetSerializedEncapsulatedPublicKey().data())); - - jclass sender_context_class = env->FindClass("com/google/oak/crypto/hpke/SenderContext"); - jmethodID sender_context_constructor = env->GetMethodID(sender_context_class, "", "([BJ)V"); + serialized_encapsulated_public_key, 0, + serialized_encapsulated_public_key_len, + reinterpret_cast((*native_sender_context) + ->GetSerializedEncapsulatedPublicKey() + .data())); + + jclass sender_context_class = + env->FindClass("com/google/oak/crypto/hpke/SenderContext"); + jmethodID sender_context_constructor = + env->GetMethodID(sender_context_class, "", "([BJ)V"); jobject sender_context = env->NewObject(sender_context_class, sender_context_constructor, - serialized_encapsulated_public_key, (long)native_sender_context->release()); + serialized_encapsulated_public_key, + (long)native_sender_context->release()); return sender_context; } -JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseRecipient( +JNIEXPORT jobject JNICALL +Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseRecipient( JNIEnv* env, jclass obj, jbyteArray serialized_encapsulated_public_key, jobject recipient_key_pair, jbyteArray info) { - if (serialized_encapsulated_public_key == NULL || info == NULL || recipient_key_pair == NULL) { + if (serialized_encapsulated_public_key == NULL || info == NULL || + recipient_key_pair == NULL) { return {}; } @@ -69,43 +79,50 @@ JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_Hpke_nativeSetupBaseRe std::string info_str = convert_jbytearray_to_string(env, info); jclass key_pair_class = env->GetObjectClass(recipient_key_pair); - jfieldID private_key_fid = env->GetFieldID(key_pair_class, "privateKey", "[B"); - jbyteArray private_key = - static_cast(env->GetObjectField(recipient_key_pair, private_key_fid)); + jfieldID private_key_fid = + env->GetFieldID(key_pair_class, "privateKey", "[B"); + jbyteArray private_key = static_cast( + env->GetObjectField(recipient_key_pair, private_key_fid)); std::string private_key_str = convert_jbytearray_to_string(env, private_key); jfieldID public_key_fid = env->GetFieldID(key_pair_class, "publicKey", "[B"); - jbyteArray public_key = - static_cast(env->GetObjectField(recipient_key_pair, public_key_fid)); + jbyteArray public_key = static_cast( + env->GetObjectField(recipient_key_pair, public_key_fid)); std::string public_key_str = convert_jbytearray_to_string(env, public_key); oak::crypto::KeyPair key_pair; key_pair.public_key = public_key_str; key_pair.private_key = private_key_str; - absl::StatusOr> native_recipient_context = - oak::crypto::SetupBaseRecipient(serialized_encapsulated_public_key_str, key_pair, info_str); + absl::StatusOr> + native_recipient_context = oak::crypto::SetupBaseRecipient( + serialized_encapsulated_public_key_str, key_pair, info_str); if (!native_recipient_context.ok()) { return {}; } - jclass recipient_context_class = env->FindClass("com/google/oak/crypto/hpke/RecipientContext"); + jclass recipient_context_class = + env->FindClass("com/google/oak/crypto/hpke/RecipientContext"); jmethodID recipient_context_constructor = env->GetMethodID(recipient_context_class, "", "(J)V"); - jobject recipient_context = env->NewObject(recipient_context_class, recipient_context_constructor, - (long)native_recipient_context->release()); + jobject recipient_context = + env->NewObject(recipient_context_class, recipient_context_constructor, + (long)native_recipient_context->release()); return recipient_context; } JNIEXPORT jbyteArray JNICALL -Java_com_google_oak_crypto_hpke_Hpke_nativeGenerateRandomNonce(JNIEnv* env, jclass obj) { - absl::StatusOr> nonce = oak::crypto::GenerateRandomNonce(); +Java_com_google_oak_crypto_hpke_Hpke_nativeGenerateRandomNonce(JNIEnv* env, + jclass obj) { + absl::StatusOr> nonce = + oak::crypto::GenerateRandomNonce(); if (!nonce.ok()) { return {}; } jbyteArray ret = env->NewByteArray(nonce->size()); - env->SetByteArrayRegion(ret, 0, nonce->size(), reinterpret_cast(&nonce->front())); + env->SetByteArrayRegion(ret, 0, nonce->size(), + reinterpret_cast(&nonce->front())); return ret; } diff --git a/cc/crypto/hpke/jni/keypair_jni.cc b/cc/crypto/hpke/jni/keypair_jni.cc index f4363fab1de..00382ffe1d6 100644 --- a/cc/crypto/hpke/jni/keypair_jni.cc +++ b/cc/crypto/hpke/jni/keypair_jni.cc @@ -18,9 +18,11 @@ #include "absl/status/statusor.h" #include "com_google_oak_crypto_hpke_KeyPair.h" -JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_KeyPair_nativeGenerate(JNIEnv* env, - jclass obj) { - absl::StatusOr<::oak::crypto::KeyPair> kp_status = oak::crypto::KeyPair::Generate(); +JNIEXPORT jobject JNICALL +Java_com_google_oak_crypto_hpke_KeyPair_nativeGenerate(JNIEnv* env, + jclass obj) { + absl::StatusOr<::oak::crypto::KeyPair> kp_status = + oak::crypto::KeyPair::Generate(); if (!kp_status.ok()) { return {}; } @@ -35,8 +37,9 @@ JNIEXPORT jobject JNICALL Java_com_google_oak_crypto_hpke_KeyPair_nativeGenerate reinterpret_cast(public_key.c_str())); jclass key_pair_class = env->FindClass("com/google/oak/crypto/hpke/KeyPair"); - jmethodID key_pair_constructor = env->GetMethodID(key_pair_class, "", "([B[B)V"); - jobject key_pair = - env->NewObject(key_pair_class, key_pair_constructor, private_key_arr, public_key_arr); + jmethodID key_pair_constructor = + env->GetMethodID(key_pair_class, "", "([B[B)V"); + jobject key_pair = env->NewObject(key_pair_class, key_pair_constructor, + private_key_arr, public_key_arr); return key_pair; } \ No newline at end of file diff --git a/cc/crypto/hpke/recipient_context.cc b/cc/crypto/hpke/recipient_context.cc index 74c90096531..16d7741c972 100644 --- a/cc/crypto/hpke/recipient_context.cc +++ b/cc/crypto/hpke/recipient_context.cc @@ -34,8 +34,9 @@ namespace oak::crypto { namespace { using ::oak::crypto::v1::SessionKeys; -// Validates that the public and private key pairing is valid for HPKE. If the public and private -// keys are valid, the recipient_keys argument will be an initialized HPKE_KEY. +// Validates that the public and private key pairing is valid for HPKE. If the +// public and private keys are valid, the recipient_keys argument will be an +// initialized HPKE_KEY. absl::Status ValidateKeys(std::vector& public_key_bytes, std::vector& private_key_bytes, std::vector encap_public_key_bytes, @@ -47,7 +48,8 @@ absl::Status ValidateKeys(std::vector& public_key_bytes, return absl::InvalidArgumentError("A public key must be provided"); } if (encap_public_key_bytes.empty()) { - return absl::InvalidArgumentError("An encapsulated public key must be provided"); + return absl::InvalidArgumentError( + "An encapsulated public key must be provided"); } if (!EVP_HPKE_KEY_init( @@ -100,9 +102,9 @@ absl::StatusOr> RecipientContext::Deserialize( /* response_aead_context= */ std::move(response_aead_context)); } -absl::StatusOr RecipientContext::Open(const std::vector& nonce, - absl::string_view ciphertext, - absl::string_view associated_data) { +absl::StatusOr RecipientContext::Open( + const std::vector& nonce, absl::string_view ciphertext, + absl::string_view associated_data) { absl::StatusOr plaintext = AeadOpen(request_aead_context_.get(), nonce, ciphertext, associated_data); if (!plaintext.ok()) { @@ -111,9 +113,9 @@ absl::StatusOr RecipientContext::Open(const std::vector& n return plaintext; } -absl::StatusOr RecipientContext::Seal(const std::vector& nonce, - absl::string_view plaintext, - absl::string_view associated_data) { +absl::StatusOr RecipientContext::Seal( + const std::vector& nonce, absl::string_view plaintext, + absl::string_view associated_data) { absl::StatusOr ciphertext = AeadSeal(response_aead_context_.get(), nonce, plaintext, associated_data); if (!ciphertext.ok()) { @@ -128,20 +130,23 @@ RecipientContext::~RecipientContext() { } absl::StatusOr> SetupBaseRecipient( - absl::string_view serialized_encapsulated_public_key, const KeyPair& recipient_key_pair, - absl::string_view info) { - // First verify that the supplied key pairing is valid using the BoringSSL library. + absl::string_view serialized_encapsulated_public_key, + const KeyPair& recipient_key_pair, absl::string_view info) { + // First verify that the supplied key pairing is valid using the BoringSSL + // library. std::vector private_key_bytes(recipient_key_pair.private_key.begin(), recipient_key_pair.private_key.end()); std::vector public_key_bytes(recipient_key_pair.public_key.begin(), recipient_key_pair.public_key.end()); - std::vector encap_public_key_bytes(serialized_encapsulated_public_key.begin(), - serialized_encapsulated_public_key.end()); + std::vector encap_public_key_bytes( + serialized_encapsulated_public_key.begin(), + serialized_encapsulated_public_key.end()); bssl::ScopedEVP_HPKE_KEY recipient_keys; - absl::Status keys_valid_status = ValidateKeys(public_key_bytes, private_key_bytes, - encap_public_key_bytes, recipient_keys.get()); + absl::Status keys_valid_status = + ValidateKeys(public_key_bytes, private_key_bytes, encap_public_key_bytes, + recipient_keys.get()); if (!keys_valid_status.ok()) { return keys_valid_status; } @@ -165,26 +170,30 @@ absl::StatusOr> SetupBaseRecipient( } // Configure recipient request context. - // This is a deviation from the HPKE RFC, because we are deriving both session request and - // response keys from the exporter secret, instead of having a request key be directly derived - // from the shared secret. This is required to be able to share session keys between the Kernel - // and the Application via RPC. + // This is a deviation from the HPKE RFC, because we are deriving both session + // request and response keys from the exporter secret, instead of having a + // request key be directly derived from the shared secret. This is required to + // be able to share session keys between the Kernel and the Application via + // RPC. // - auto request_aead_context = GetContext(hpke_recipient_context.get(), "request_key"); + auto request_aead_context = + GetContext(hpke_recipient_context.get(), "request_key"); if (!request_aead_context.ok()) { return request_aead_context.status(); } // Configure recipient response context. - auto response_aead_context = GetContext(hpke_recipient_context.get(), "response_key"); + auto response_aead_context = + GetContext(hpke_recipient_context.get(), "response_key"); if (!response_aead_context.ok()) { return response_aead_context.status(); } // Create recipient context. - std::unique_ptr recipient_context = std::make_unique( - /* request_aead_context= */ *std::move(request_aead_context), - /* response_aead_context= */ *std::move(response_aead_context)); + std::unique_ptr recipient_context = + std::make_unique( + /* request_aead_context= */ *std::move(request_aead_context), + /* response_aead_context= */ *std::move(response_aead_context)); EVP_HPKE_CTX_free(hpke_recipient_context.release()); return recipient_context; diff --git a/cc/crypto/hpke/recipient_context.h b/cc/crypto/hpke/recipient_context.h index 86593f3d720..f06ec256de4 100644 --- a/cc/crypto/hpke/recipient_context.h +++ b/cc/crypto/hpke/recipient_context.h @@ -49,13 +49,15 @@ class RecipientContext { // Decrypts message and validates associated data using AEAD. // - absl::StatusOr Open(const std::vector& nonce, absl::string_view ciphertext, + absl::StatusOr Open(const std::vector& nonce, + absl::string_view ciphertext, absl::string_view associated_data); - // Encrypts response message with associated data using AEAD as part of bidirectional - // communication. + // Encrypts response message with associated data using AEAD as part of + // bidirectional communication. // - absl::StatusOr Seal(const std::vector& nonce, absl::string_view plaintext, + absl::StatusOr Seal(const std::vector& nonce, + absl::string_view plaintext, absl::string_view associated_data); ~RecipientContext(); @@ -68,8 +70,8 @@ class RecipientContext { // Sets up an HPKE recipient by creating a recipient context. // absl::StatusOr> SetupBaseRecipient( - absl::string_view serialized_encapsulated_public_key, const KeyPair& recipient_key_pair, - absl::string_view info); + absl::string_view serialized_encapsulated_public_key, + const KeyPair& recipient_key_pair, absl::string_view info); } // namespace oak::crypto diff --git a/cc/crypto/hpke/recipient_context_test.cc b/cc/crypto/hpke/recipient_context_test.cc index 614320f1693..433f4dbe3a4 100644 --- a/cc/crypto/hpke/recipient_context_test.cc +++ b/cc/crypto/hpke/recipient_context_test.cc @@ -38,22 +38,28 @@ using ::testing::StrNe; class RecipientContextTest : public testing::Test { protected: void SetUp() override { - // This key pairing was randomly generated using EVP_HPKE_KEY_generate() with the x25519 KEM. + // This key pairing was randomly generated using EVP_HPKE_KEY_generate() + // with the x25519 KEM. const std::vector public_key_bytes = { - 236, 102, 18, 92, 231, 237, 92, 56, 199, 21, 200, 213, 172, 150, 80, 217, - 64, 33, 77, 203, 109, 68, 21, 12, 76, 219, 16, 62, 110, 19, 69, 8}; - std::string serialized_public_key(public_key_bytes.begin(), public_key_bytes.end()); + 236, 102, 18, 92, 231, 237, 92, 56, 199, 21, 200, + 213, 172, 150, 80, 217, 64, 33, 77, 203, 109, 68, + 21, 12, 76, 219, 16, 62, 110, 19, 69, 8}; + std::string serialized_public_key(public_key_bytes.begin(), + public_key_bytes.end()); const std::vector private_key_bytes = { - 255, 12, 169, 64, 221, 170, 194, 165, 224, 77, 222, 165, 95, 179, 124, 55, - 236, 237, 58, 11, 130, 177, 153, 40, 31, 221, 13, 138, 71, 107, 243, 173}; - std::string serialized_private_key(private_key_bytes.begin(), private_key_bytes.end()); + 255, 12, 169, 64, 221, 170, 194, 165, 224, 77, 222, + 165, 95, 179, 124, 55, 236, 237, 58, 11, 130, 177, + 153, 40, 31, 221, 13, 138, 71, 107, 243, 173}; + std::string serialized_private_key(private_key_bytes.begin(), + private_key_bytes.end()); recipient_key_pair_.public_key = serialized_public_key; recipient_key_pair_.private_key = serialized_private_key; // Random encapsulated public key from the SetupBaseSender function. const std::vector encap_public_key_bytes = { - 85, 255, 224, 169, 132, 101, 176, 248, 95, 67, 86, 31, 44, 31, 230, 224, - 226, 174, 242, 10, 200, 162, 222, 196, 255, 25, 114, 64, 4, 15, 193, 89}; + 85, 255, 224, 169, 132, 101, 176, 248, 95, 67, 86, + 31, 44, 31, 230, 224, 226, 174, 242, 10, 200, 162, + 222, 196, 255, 25, 114, 64, 4, 15, 193, 89}; std::string encap_public_key_string(encap_public_key_bytes.begin(), encap_public_key_bytes.end()); encap_public_key_ = encap_public_key_string; @@ -62,18 +68,22 @@ class RecipientContextTest : public testing::Test { associated_data_response_ = "Test response associated data"; associated_data_request_ = "Test request associated data"; - const std::vector request_key = {164, 174, 176, 213, 235, 46, 157, 155, 157, 138, 173, - 65, 231, 242, 53, 28, 46, 170, 179, 170, 172, 110, - 195, 108, 240, 157, 178, 24, 91, 148, 232, 121}; - *crypto_context_.mutable_request_key() = std::string(request_key.begin(), request_key.end()); - const std::vector request_base_nonce = {155, 198, 201, 66, 230, 227, - 208, 99, 5, 64, 207, 183}; + const std::vector request_key = { + 164, 174, 176, 213, 235, 46, 157, 155, 157, 138, 173, + 65, 231, 242, 53, 28, 46, 170, 179, 170, 172, 110, + 195, 108, 240, 157, 178, 24, 91, 148, 232, 121}; + *crypto_context_.mutable_request_key() = + std::string(request_key.begin(), request_key.end()); + const std::vector request_base_nonce = { + 155, 198, 201, 66, 230, 227, 208, 99, 5, 64, 207, 183}; const std::vector response_key = { - 109, 21, 112, 119, 203, 119, 184, 30, 12, 31, 93, 71, 171, 224, 74, 241, - 113, 168, 228, 50, 145, 105, 164, 174, 206, 149, 197, 5, 25, 186, 254, 154}; - *crypto_context_.mutable_response_key() = std::string(response_key.begin(), response_key.end()); - const std::vector response_base_nonce = {111, 93, 22, 215, 77, 149, - 30, 204, 13, 168, 55, 163}; + 109, 21, 112, 119, 203, 119, 184, 30, 12, 31, 93, + 71, 171, 224, 74, 241, 113, 168, 228, 50, 145, 105, + 164, 174, 206, 149, 197, 5, 25, 186, 254, 154}; + *crypto_context_.mutable_response_key() = + std::string(response_key.begin(), response_key.end()); + const std::vector response_base_nonce = { + 111, 93, 22, 215, 77, 149, 30, 204, 13, 168, 55, 163}; } KeyPair recipient_key_pair_; std::string encap_public_key_; @@ -85,53 +95,65 @@ class RecipientContextTest : public testing::Test { TEST_F(RecipientContextTest, SetupBaseRecipientEmptyEncapKeyReturnsFailure) { std::string empty_string = ""; - auto recipient_context = SetupBaseRecipient(empty_string, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(empty_string, recipient_key_pair_, info_string_); EXPECT_FALSE(recipient_context.ok()); } TEST_F(RecipientContextTest, SetupBaseRecipientEmptyPublicKeyReturnsFailure) { recipient_key_pair_.public_key = ""; - auto recipient_context = SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); EXPECT_FALSE(recipient_context.ok()); } TEST_F(RecipientContextTest, SetupBaseRecipientEmptyPrivateKeyReturnsFailure) { recipient_key_pair_.private_key = ""; - auto recipient_context = SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); EXPECT_FALSE(recipient_context.ok()); } -TEST_F(RecipientContextTest, SetupBaseRecipientMismatchedKeyPairReturnsFailure) { +TEST_F(RecipientContextTest, + SetupBaseRecipientMismatchedKeyPairReturnsFailure) { // We edit the default public key to produce an invalid key pairing. - std::vector different_public_key(recipient_key_pair_.public_key.begin(), - recipient_key_pair_.public_key.end()); + std::vector different_public_key( + recipient_key_pair_.public_key.begin(), + recipient_key_pair_.public_key.end()); different_public_key[0] += 1; - std::string different_public_key_str(different_public_key.begin(), different_public_key.end()); + std::string different_public_key_str(different_public_key.begin(), + different_public_key.end()); recipient_key_pair_.public_key = different_public_key_str; - auto recipient_context = SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); EXPECT_FALSE(recipient_context.ok()); } TEST_F(RecipientContextTest, SetupBaseRecipientReturnsValidPointersOnSuccess) { - auto recipient_context = SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); ASSERT_TRUE(recipient_context.ok()); EXPECT_TRUE(*recipient_context); } TEST_F(RecipientContextTest, RecipientContextOpenSuccess) { // Initialize an HPKE sender. - auto sender_context = SetupBaseSender(recipient_key_pair_.public_key, info_string_); + auto sender_context = + SetupBaseSender(recipient_key_pair_.public_key, info_string_); ASSERT_TRUE(sender_context.ok()); std::string plaintext = "Hello World"; absl::StatusOr> nonce = GenerateRandomNonce(); ASSERT_TRUE(nonce.ok()); - auto ciphertext = (*sender_context)->Seal(*nonce, plaintext, associated_data_request_); + auto ciphertext = + (*sender_context)->Seal(*nonce, plaintext, associated_data_request_); ASSERT_TRUE(ciphertext.ok()); - std::string encap_public_key = (*sender_context)->GetSerializedEncapsulatedPublicKey(); - auto recipient_context = SetupBaseRecipient(encap_public_key, recipient_key_pair_, info_string_); + std::string encap_public_key = + (*sender_context)->GetSerializedEncapsulatedPublicKey(); + auto recipient_context = + SetupBaseRecipient(encap_public_key, recipient_key_pair_, info_string_); ASSERT_TRUE(recipient_context.ok()); auto received_plaintext = (*recipient_context)->Open(*nonce, *ciphertext, associated_data_request_); @@ -142,50 +164,60 @@ TEST_F(RecipientContextTest, RecipientContextOpenSuccess) { TEST_F(RecipientContextTest, RecipientRequestContextOpenFailure) { // Initialize an HPKE sender. - auto sender_context = SetupBaseSender(recipient_key_pair_.public_key, info_string_); + auto sender_context = + SetupBaseSender(recipient_key_pair_.public_key, info_string_); ASSERT_TRUE(sender_context.ok()); std::string plaintext = "Hello World"; absl::StatusOr> nonce = GenerateRandomNonce(); ASSERT_TRUE(nonce.ok()); - auto ciphertext = (*sender_context)->Seal(*nonce, plaintext, associated_data_request_); + auto ciphertext = + (*sender_context)->Seal(*nonce, plaintext, associated_data_request_); ASSERT_TRUE(ciphertext.ok()); std::string edited_ciphertext = absl::StrCat(*ciphertext, "no!"); - std::string encap_public_key = (*sender_context)->GetSerializedEncapsulatedPublicKey(); - auto recipient_context = SetupBaseRecipient(encap_public_key, recipient_key_pair_, info_string_); + std::string encap_public_key = + (*sender_context)->GetSerializedEncapsulatedPublicKey(); + auto recipient_context = + SetupBaseRecipient(encap_public_key, recipient_key_pair_, info_string_); ASSERT_TRUE(recipient_context.ok()); auto received_plaintext = - (*recipient_context)->Open(*nonce, edited_ciphertext, associated_data_request_); + (*recipient_context) + ->Open(*nonce, edited_ciphertext, associated_data_request_); EXPECT_FALSE(received_plaintext.ok()); EXPECT_EQ(received_plaintext.status().code(), absl::StatusCode::kAborted); } TEST_F(RecipientContextTest, RecipientResponseContextSealSuccess) { - auto recipient_context = SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); ASSERT_TRUE(recipient_context.ok()); std::string plaintext = "Hello World"; absl::StatusOr> nonce = GenerateRandomNonce(); ASSERT_TRUE(nonce.ok()); - auto ciphertext = (*recipient_context)->Seal(*nonce, plaintext, associated_data_response_); + auto ciphertext = + (*recipient_context)->Seal(*nonce, plaintext, associated_data_response_); ASSERT_TRUE(ciphertext.ok()); EXPECT_THAT(plaintext, StrNe(*ciphertext)); } TEST_F(RecipientContextTest, RecipientResponseContextSealFailure) { - auto recipient_context = SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key_, recipient_key_pair_, info_string_); ASSERT_TRUE(recipient_context.ok()); std::string empty_plaintext = ""; absl::StatusOr> nonce = GenerateRandomNonce(); ASSERT_TRUE(nonce.ok()); - auto ciphertext = (*recipient_context)->Seal(*nonce, empty_plaintext, associated_data_response_); + auto ciphertext = + (*recipient_context) + ->Seal(*nonce, empty_plaintext, associated_data_response_); EXPECT_FALSE(ciphertext.ok()); EXPECT_EQ(ciphertext.status().code(), absl::StatusCode::kInvalidArgument); } @@ -197,9 +229,11 @@ TEST_F(RecipientContextTest, GenerateKeysAndSetupBaseRecipientSuccess) { auto sender_context = SetupBaseSender(key_pair->public_key, info_string_); ASSERT_TRUE(sender_context.ok()); - std::string encap_public_key = (*sender_context)->GetSerializedEncapsulatedPublicKey(); + std::string encap_public_key = + (*sender_context)->GetSerializedEncapsulatedPublicKey(); - auto recipient_context = SetupBaseRecipient(encap_public_key, *key_pair, info_string_); + auto recipient_context = + SetupBaseRecipient(encap_public_key, *key_pair, info_string_); EXPECT_TRUE(recipient_context.ok()); } diff --git a/cc/crypto/hpke/sender_context.cc b/cc/crypto/hpke/sender_context.cc index f119f05219d..3ea0207da15 100644 --- a/cc/crypto/hpke/sender_context.cc +++ b/cc/crypto/hpke/sender_context.cc @@ -30,9 +30,9 @@ namespace oak::crypto { -absl::StatusOr SenderContext::Seal(const std::vector& nonce, - absl::string_view plaintext, - absl::string_view associated_data) { +absl::StatusOr SenderContext::Seal( + const std::vector& nonce, absl::string_view plaintext, + absl::string_view associated_data) { absl::StatusOr ciphertext = AeadSeal(request_aead_context_.get(), nonce, plaintext, associated_data); if (!ciphertext.ok()) { @@ -41,11 +41,11 @@ absl::StatusOr SenderContext::Seal(const std::vector& nonc return ciphertext; } -absl::StatusOr SenderContext::Open(const std::vector& nonce, - absl::string_view ciphertext, - absl::string_view associated_data) { - absl::StatusOr plaintext = - AeadOpen(response_aead_context_.get(), nonce, ciphertext, associated_data); +absl::StatusOr SenderContext::Open( + const std::vector& nonce, absl::string_view ciphertext, + absl::string_view associated_data) { + absl::StatusOr plaintext = AeadOpen( + response_aead_context_.get(), nonce, ciphertext, associated_data); if (!plaintext.ok()) { return plaintext.status(); } @@ -59,12 +59,15 @@ SenderContext::~SenderContext() { absl::StatusOr> SetupBaseSender( absl::string_view serialized_recipient_public_key, absl::string_view info) { - // First collect encapsulated public key information and sender request context. + // First collect encapsulated public key information and sender request + // context. KeyInfo encap_public_key_info; - encap_public_key_info.key_bytes = std::vector(EVP_HPKE_MAX_ENC_LENGTH); + encap_public_key_info.key_bytes = + std::vector(EVP_HPKE_MAX_ENC_LENGTH); - std::vector recipient_public_key_bytes(serialized_recipient_public_key.begin(), - serialized_recipient_public_key.end()); + std::vector recipient_public_key_bytes( + serialized_recipient_public_key.begin(), + serialized_recipient_public_key.end()); if (recipient_public_key_bytes.empty()) { return absl::InvalidArgumentError("No key was provided"); @@ -94,27 +97,31 @@ absl::StatusOr> SetupBaseSender( encap_public_key_info.key_bytes.resize(encap_public_key_info.key_size); // Configure sender request context. - // This is a deviation from the HPKE RFC, because we are deriving both session request and - // response keys from the exporter secret, instead of having a request key be directly derived - // from the shared secret. This is required to be able to share session keys between the Kernel - // and the Application via RPC. + // This is a deviation from the HPKE RFC, because we are deriving both session + // request and response keys from the exporter secret, instead of having a + // request key be directly derived from the shared secret. This is required to + // be able to share session keys between the Kernel and the Application via + // RPC. // - auto request_aead_context = GetContext(hpke_sender_context.get(), "request_key"); + auto request_aead_context = + GetContext(hpke_sender_context.get(), "request_key"); if (!request_aead_context.ok()) { return request_aead_context.status(); } // Configure sender response context. - auto response_aead_context = GetContext(hpke_sender_context.get(), "response_key"); + auto response_aead_context = + GetContext(hpke_sender_context.get(), "response_key"); if (!response_aead_context.ok()) { return response_aead_context.status(); } // Create sender context. - std::unique_ptr sender_context = std::make_unique( - /* encapsulated_public_key= */ encap_public_key_info.key_bytes, - /* request_aead_context= */ *std::move(request_aead_context), - /* response_aead_context= */ *std::move(response_aead_context)); + std::unique_ptr sender_context = + std::make_unique( + /* encapsulated_public_key= */ encap_public_key_info.key_bytes, + /* request_aead_context= */ *std::move(request_aead_context), + /* response_aead_context= */ *std::move(response_aead_context)); EVP_HPKE_CTX_free(hpke_sender_context.release()); return sender_context; diff --git a/cc/crypto/hpke/sender_context.h b/cc/crypto/hpke/sender_context.h index 630e3b6bdb2..940517c5f0b 100644 --- a/cc/crypto/hpke/sender_context.h +++ b/cc/crypto/hpke/sender_context.h @@ -29,8 +29,9 @@ namespace oak::crypto { -// Context for generating encrypted requests to the recipient and for decrypting encrypted responses -// from the recipient. This is based on bi-directional encryption using AEAD. +// Context for generating encrypted requests to the recipient and for decrypting +// encrypted responses from the recipient. This is based on bi-directional +// encryption using AEAD. // . class SenderContext { public: @@ -48,13 +49,15 @@ class SenderContext { // Encrypts message with associated data using AEAD. // - absl::StatusOr Seal(const std::vector& nonce, absl::string_view plaintext, + absl::StatusOr Seal(const std::vector& nonce, + absl::string_view plaintext, absl::string_view associated_data); - // Decrypts response message and validates associated data using AEAD as part of - // bidirectional communication. + // Decrypts response message and validates associated data using AEAD as part + // of bidirectional communication. // - absl::StatusOr Open(const std::vector& nonce, absl::string_view ciphertext, + absl::StatusOr Open(const std::vector& nonce, + absl::string_view ciphertext, absl::string_view associated_data); ~SenderContext(); @@ -65,13 +68,13 @@ class SenderContext { std::unique_ptr response_aead_context_; }; -// Sets up an HPKE sender by generating an ephemeral keypair (and serializing the corresponding -// public key) and creating a sender context. -// Returns the encapsulated public key, sender request and sender response contexts. +// Sets up an HPKE sender by generating an ephemeral keypair (and serializing +// the corresponding public key) and creating a sender context. Returns the +// encapsulated public key, sender request and sender response contexts. // // -// Encapsulated public key is represented as a NIST P-256 SEC1 encoded point public key. -// +// Encapsulated public key is represented as a NIST P-256 SEC1 encoded point +// public key. absl::StatusOr> SetupBaseSender( absl::string_view serialized_recipient_public_key, absl::string_view info); diff --git a/cc/crypto/hpke/sender_context_test.cc b/cc/crypto/hpke/sender_context_test.cc index 4491a953062..75b452a2da8 100644 --- a/cc/crypto/hpke/sender_context_test.cc +++ b/cc/crypto/hpke/sender_context_test.cc @@ -35,24 +35,29 @@ using ::testing::StrNe; class SenderContextTest : public testing::Test { protected: void SetUp() override { - std::vector public_key_bytes = {11, 107, 5, 176, 4, 145, 171, 193, 163, 81, 105, - 238, 171, 115, 56, 160, 130, 85, 22, 227, 118, 76, - 77, 89, 144, 223, 10, 112, 11, 149, 205, 199}; - std::string serialized_public_key(public_key_bytes.begin(), public_key_bytes.end()); + std::vector public_key_bytes = { + 11, 107, 5, 176, 4, 145, 171, 193, 163, 81, 105, + 238, 171, 115, 56, 160, 130, 85, 22, 227, 118, 76, + 77, 89, 144, 223, 10, 112, 11, 149, 205, 199}; + std::string serialized_public_key(public_key_bytes.begin(), + public_key_bytes.end()); serialized_public_key_ = serialized_public_key; info_string_ = "Test HPKE info"; associated_data_response_ = "Test response associated data"; associated_data_request_ = "Test request associated data"; - default_response_key_ = {166, 107, 125, 81, 22, 76, 76, 237, 160, 40, 232, - 236, 244, 165, 13, 38, 157, 220, 162, 233, 235, 158, - 226, 157, 152, 52, 162, 106, 93, 68, 12, 171}; + default_response_key_ = {166, 107, 125, 81, 22, 76, 76, 237, + 160, 40, 232, 236, 244, 165, 13, 38, + 157, 220, 162, 233, 235, 158, 226, 157, + 152, 52, 162, 106, 93, 68, 12, 171}; default_nonce_bytes_ = {1, 242, 45, 144, 96, 26, 190, 43, 156, 154, 2, 69}; } - SenderContext CreateTestSenderContext(std::unique_ptr response_aead_context) { - return SenderContext(std::vector(), - nullptr, // Sender request sealing is not tested in this test. - std::move(response_aead_context)); + SenderContext CreateTestSenderContext( + std::unique_ptr response_aead_context) { + return SenderContext( + std::vector(), + nullptr, // Sender request sealing is not tested in this test. + std::move(response_aead_context)); } std::string serialized_public_key_; @@ -67,14 +72,17 @@ TEST_F(SenderContextTest, SetupBaseSenderReturnsUniqueEncapsulatedKey) { absl::StatusOr> sender_context = SetupBaseSender(serialized_public_key_, info_string_); ASSERT_TRUE(sender_context.ok()); - std::string encapsulated_public_key1 = (*sender_context)->GetSerializedEncapsulatedPublicKey(); + std::string encapsulated_public_key1 = + (*sender_context)->GetSerializedEncapsulatedPublicKey(); auto sender_context2 = SetupBaseSender(serialized_public_key_, info_string_); ASSERT_TRUE(sender_context2.ok()); - std::string encapsulated_public_key2 = (*sender_context2)->GetSerializedEncapsulatedPublicKey(); + std::string encapsulated_public_key2 = + (*sender_context2)->GetSerializedEncapsulatedPublicKey(); EXPECT_THAT(encapsulated_public_key1, StrNe(encapsulated_public_key2)); } -TEST_F(SenderContextTest, SetupBaseSenderReturnsInvalidArgumentErrorForEmptyKey) { +TEST_F(SenderContextTest, + SetupBaseSenderReturnsInvalidArgumentErrorForEmptyKey) { std::string empty_public_key = ""; absl::StatusOr> sender_context = SetupBaseSender(empty_public_key, info_string_); @@ -110,11 +118,13 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageSuccess) { /* tag_len= */ 0)); // Configure sender context. - SenderContext sender_context = CreateTestSenderContext(std::move(response_aead_context_receive)); + SenderContext sender_context = + CreateTestSenderContext(std::move(response_aead_context_receive)); // Generate encrypted message. std::string plaintext_message = "Hello World"; - std::vector plaintext_bytes(plaintext_message.begin(), plaintext_message.end()); + std::vector plaintext_bytes(plaintext_message.begin(), + plaintext_message.end()); std::unique_ptr response_aead_context_send(EVP_AEAD_CTX_new( /* aead= */ aead_version, @@ -122,7 +132,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageSuccess) { /* key_len= */ default_response_key_.size(), /* tag_len= */ 0)); - std::vector ciphertext_bytes(plaintext_bytes.size() + EVP_HPKE_MAX_OVERHEAD); + std::vector ciphertext_bytes(plaintext_bytes.size() + + EVP_HPKE_MAX_OVERHEAD); size_t ciphertext_size; ASSERT_TRUE(EVP_AEAD_CTX_seal( /* ctx= */ response_aead_context_send.get(), @@ -139,8 +150,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageSuccess) { std::string ciphertext(ciphertext_bytes.begin(), ciphertext_bytes.end()); // Successfully open encrypted message and get back original plaintext. - auto decyphered_message = - sender_context.Open(default_nonce_bytes_, ciphertext, associated_data_response_); + auto decyphered_message = sender_context.Open( + default_nonce_bytes_, ciphertext, associated_data_response_); EXPECT_TRUE(decyphered_message.ok()); // Cleanup the lingering context. @@ -149,7 +160,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageSuccess) { TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureNoncesNotAligned) { // The second set of nonce bytes are not the same. - std::vector nonce_bytes_diff = {0, 242, 45, 144, 96, 26, 190, 43, 156, 154, 2, 69}; + std::vector nonce_bytes_diff = {0, 242, 45, 144, 96, 26, + 190, 43, 156, 154, 2, 69}; std::vector associated_data_bytes(associated_data_response_.begin(), associated_data_response_.end()); @@ -163,11 +175,13 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureNoncesNotAligned) { /* tag_len= */ 0)); // Configure sender context. - SenderContext sender_context = CreateTestSenderContext(std::move(response_aead_context_receive)); + SenderContext sender_context = + CreateTestSenderContext(std::move(response_aead_context_receive)); // Generate encrypted message. std::string plaintext_message = "Hello World"; - std::vector plaintext_bytes(plaintext_message.begin(), plaintext_message.end()); + std::vector plaintext_bytes(plaintext_message.begin(), + plaintext_message.end()); std::unique_ptr response_aead_context_send(EVP_AEAD_CTX_new( /* aead= */ aead_version, @@ -175,7 +189,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureNoncesNotAligned) { /* key_len= */ default_response_key_.size(), /* tag_len= */ 0)); - std::vector ciphertext_bytes(plaintext_bytes.size() + EVP_HPKE_MAX_OVERHEAD); + std::vector ciphertext_bytes(plaintext_bytes.size() + + EVP_HPKE_MAX_OVERHEAD); size_t ciphertext_size; ASSERT_TRUE(EVP_AEAD_CTX_seal( /* ctx= */ response_aead_context_send.get(), @@ -192,8 +207,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureNoncesNotAligned) { std::string ciphertext(ciphertext_bytes.begin(), ciphertext_bytes.end()); // Attempt to open the encrypted message. This should fail. - auto decyphered_message = - sender_context.Open(default_nonce_bytes_, ciphertext, associated_data_response_); + auto decyphered_message = sender_context.Open( + default_nonce_bytes_, ciphertext, associated_data_response_); EXPECT_FALSE(decyphered_message.ok()); EXPECT_EQ(decyphered_message.status().code(), absl::StatusCode::kAborted); @@ -201,7 +216,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureNoncesNotAligned) { EVP_AEAD_CTX_free(response_aead_context_send.release()); } -TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureAssociatedDataNotAligned) { +TEST_F(SenderContextTest, + SenderOpensEncryptedMessageFailureAssociatedDataNotAligned) { std::vector associated_data_bytes(associated_data_response_.begin(), associated_data_response_.end()); @@ -214,11 +230,13 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureAssociatedDataNotAli /* tag_len= */ 0)); // Configure sender context. - SenderContext sender_context = CreateTestSenderContext(std::move(response_aead_context_receive)); + SenderContext sender_context = + CreateTestSenderContext(std::move(response_aead_context_receive)); // Generate encrypted message. std::string plaintext_message = "Hello World"; - std::vector plaintext_bytes(plaintext_message.begin(), plaintext_message.end()); + std::vector plaintext_bytes(plaintext_message.begin(), + plaintext_message.end()); std::unique_ptr response_aead_context_send(EVP_AEAD_CTX_new( /* aead= */ aead_version, @@ -226,7 +244,8 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureAssociatedDataNotAli /* key_len= */ default_response_key_.size(), /* tag_len= */ 0)); - std::vector ciphertext_bytes(plaintext_bytes.size() + EVP_HPKE_MAX_OVERHEAD); + std::vector ciphertext_bytes(plaintext_bytes.size() + + EVP_HPKE_MAX_OVERHEAD); size_t ciphertext_size; ASSERT_TRUE(EVP_AEAD_CTX_seal( /* ctx= */ response_aead_context_send.get(), @@ -242,10 +261,11 @@ TEST_F(SenderContextTest, SenderOpensEncryptedMessageFailureAssociatedDataNotAli ciphertext_bytes.resize(ciphertext_size); std::string ciphertext(ciphertext_bytes.begin(), ciphertext_bytes.end()); - // Attempt to open the encrypted message using different associated data. This should fail. + // Attempt to open the encrypted message using different associated data. This + // should fail. std::string different_associated_data = "Different response associated data"; - auto decyphered_message = - sender_context.Open(default_nonce_bytes_, ciphertext, different_associated_data); + auto decyphered_message = sender_context.Open( + default_nonce_bytes_, ciphertext, different_associated_data); EXPECT_FALSE(decyphered_message.ok()); EXPECT_EQ(decyphered_message.status().code(), absl::StatusCode::kAborted); @@ -266,14 +286,16 @@ TEST_F(SenderContextTest, SenderOpensEmptyEncryptedMessageFailure) { /* tag_len= */ 0)); // Configure sender context. - SenderContext sender_context = CreateTestSenderContext(std::move(response_aead_context_receive)); + SenderContext sender_context = + CreateTestSenderContext(std::move(response_aead_context_receive)); // We use an empty ciphertext. std::string ciphertext = ""; - auto decyphered_message = - sender_context.Open(default_nonce_bytes_, ciphertext, associated_data_response_); + auto decyphered_message = sender_context.Open( + default_nonce_bytes_, ciphertext, associated_data_response_); EXPECT_FALSE(decyphered_message.ok()); - EXPECT_EQ(decyphered_message.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(decyphered_message.status().code(), + absl::StatusCode::kInvalidArgument); } } // namespace diff --git a/cc/crypto/hpke/utils.cc b/cc/crypto/hpke/utils.cc index 46ca3c2633b..d973d47601b 100644 --- a/cc/crypto/hpke/utils.cc +++ b/cc/crypto/hpke/utils.cc @@ -37,10 +37,11 @@ namespace oak::crypto { constexpr size_t kAeadAlgorithmKeySizeBytes = 32; constexpr size_t kAeadNonceSizeBytes = 12; -absl::StatusOr> GetContext(EVP_HPKE_CTX* hpke_ctx, - absl::string_view key_context_string) { +absl::StatusOr> GetContext( + EVP_HPKE_CTX* hpke_ctx, absl::string_view key_context_string) { std::vector key(kAeadAlgorithmKeySizeBytes); - std::vector key_context_bytes(key_context_string.begin(), key_context_string.end()); + std::vector key_context_bytes(key_context_string.begin(), + key_context_string.end()); if (!EVP_HPKE_CTX_export( /* ctx= */ hpke_ctx, @@ -72,16 +73,19 @@ absl::StatusOr> GenerateRandomNonce() { return nonce; } -absl::StatusOr AeadSeal(const EVP_AEAD_CTX* context, std::vector nonce, +absl::StatusOr AeadSeal(const EVP_AEAD_CTX* context, + std::vector nonce, absl::string_view plaintext, absl::string_view associated_data) { std::vector plaintext_bytes(plaintext.begin(), plaintext.end()); if (plaintext_bytes.empty()) { return absl::InvalidArgumentError("No plaintext was provided"); } - std::vector associated_data_bytes(associated_data.begin(), associated_data.end()); + std::vector associated_data_bytes(associated_data.begin(), + associated_data.end()); size_t max_out_len = - plaintext_bytes.size() + EVP_AEAD_max_overhead(EVP_HPKE_AEAD_aead(EVP_hpke_aes_256_gcm())); + plaintext_bytes.size() + + EVP_AEAD_max_overhead(EVP_HPKE_AEAD_aead(EVP_hpke_aes_256_gcm())); std::vector ciphertext_bytes(max_out_len); size_t ciphertext_bytes_len = 0; @@ -104,14 +108,16 @@ absl::StatusOr AeadSeal(const EVP_AEAD_CTX* context, std::vector AeadOpen(const EVP_AEAD_CTX* context, std::vector nonce, +absl::StatusOr AeadOpen(const EVP_AEAD_CTX* context, + std::vector nonce, absl::string_view ciphertext, absl::string_view associated_data) { std::vector ciphertext_bytes(ciphertext.begin(), ciphertext.end()); if (ciphertext_bytes.empty()) { return absl::InvalidArgumentError("No ciphertext was provided"); } - std::vector associated_data_bytes(associated_data.begin(), associated_data.end()); + std::vector associated_data_bytes(associated_data.begin(), + associated_data.end()); // The plaintext should not be longer than the ciphertext. std::vector plaintext_bytes(ciphertext_bytes.size()); diff --git a/cc/crypto/hpke/utils.h b/cc/crypto/hpke/utils.h index bb224298818..bda3a2f1ff0 100644 --- a/cc/crypto/hpke/utils.h +++ b/cc/crypto/hpke/utils.h @@ -31,35 +31,43 @@ namespace oak::crypto { -// Helpful struct for keeping track of key information returned from the BoringSSL HPKE library. +// Helpful struct for keeping track of key information returned from the +// BoringSSL HPKE library. struct KeyInfo { size_t key_size; std::vector key_bytes; }; // Generate session key for the AEAD context. -absl::StatusOr> GetContext(EVP_HPKE_CTX* hpke_ctx, - absl::string_view key_context_string); +absl::StatusOr> GetContext( + EVP_HPKE_CTX* hpke_ctx, absl::string_view key_context_string); // Generates random nonce for AEAD. -// RFC 9180 uses deterministic nonces which leads to the possibility of the following attack: -// - An attacker can record a client request and wait until the application database changes +// RFC 9180 uses deterministic nonces which leads to the possibility of the +// following attack: +// - An attacker can record a client request and wait until the application +// database changes // - I.e. it updates an internal lookup database based on the public data -// - Then if an attacker replays the same request it can get a different response encrypted with the -// same nonce -// - And having 2 different messages encrypted with the same nonce breaks AES-GCM +// - Then if an attacker replays the same request it can get a different +// response encrypted with the same nonce +// - And having 2 different messages encrypted with the same nonce breaks +// AES-GCM // - The attack is called AES-GCM Forbidden Attack -// To mitigate the AES-GCM Forbidden Attack Oak is using random nonces for encrypting messages with -// AEAD. +// To mitigate the AES-GCM Forbidden Attack Oak is using random nonces for +// encrypting messages with AEAD. absl::StatusOr> GenerateRandomNonce(); -// Encrypts `plaintext` and authenticates `associated_data` using AEAD with `context` and `nonce`. -absl::StatusOr AeadSeal(const EVP_AEAD_CTX* context, std::vector nonce, +// Encrypts `plaintext` and authenticates `associated_data` using AEAD with +// `context` and `nonce`. +absl::StatusOr AeadSeal(const EVP_AEAD_CTX* context, + std::vector nonce, absl::string_view plaintext, absl::string_view associated_data); -// Decrypts `ciphertext` and authenticates `associated_data` using AEAD using `context` and `nonce`. -absl::StatusOr AeadOpen(const EVP_AEAD_CTX* context, std::vector nonce, +// Decrypts `ciphertext` and authenticates `associated_data` using AEAD using +// `context` and `nonce`. +absl::StatusOr AeadOpen(const EVP_AEAD_CTX* context, + std::vector nonce, absl::string_view ciphertext, absl::string_view associated_data); diff --git a/cc/crypto/server_encryptor.cc b/cc/crypto/server_encryptor.cc index f211dd70d50..222ea84f4e0 100644 --- a/cc/crypto/server_encryptor.cc +++ b/cc/crypto/server_encryptor.cc @@ -37,7 +37,8 @@ using ::oak::crypto::v1::EncryptedRequest; using ::oak::crypto::v1::EncryptedResponse; } // namespace -absl::StatusOr ServerEncryptor::Decrypt(EncryptedRequest encrypted_request) { +absl::StatusOr ServerEncryptor::Decrypt( + EncryptedRequest encrypted_request) { // Get recipient context. if (!recipient_context_) { absl::Status status = InitializeRecipientContexts(encrypted_request); @@ -47,20 +48,22 @@ absl::StatusOr ServerEncryptor::Decrypt(EncryptedRequest encry } // Decrypt request. - const std::vector nonce(encrypted_request.encrypted_message().nonce().begin(), - encrypted_request.encrypted_message().nonce().end()); - absl::StatusOr plaintext = - recipient_context_->Open(nonce, encrypted_request.encrypted_message().ciphertext(), - encrypted_request.encrypted_message().associated_data()); + const std::vector nonce( + encrypted_request.encrypted_message().nonce().begin(), + encrypted_request.encrypted_message().nonce().end()); + absl::StatusOr plaintext = recipient_context_->Open( + nonce, encrypted_request.encrypted_message().ciphertext(), + encrypted_request.encrypted_message().associated_data()); if (!plaintext.ok()) { return plaintext.status(); } - return DecryptionResult{*plaintext, encrypted_request.encrypted_message().associated_data()}; + return DecryptionResult{ + *plaintext, encrypted_request.encrypted_message().associated_data()}; } -absl::StatusOr ServerEncryptor::Encrypt(absl::string_view plaintext, - absl::string_view associated_data) { +absl::StatusOr ServerEncryptor::Encrypt( + absl::string_view plaintext, absl::string_view associated_data) { // Get recipient context. if (!recipient_context_) { return absl::InternalError("server encryptor is not initialized"); @@ -81,23 +84,29 @@ absl::StatusOr ServerEncryptor::Encrypt(absl::string_view pla EncryptedResponse encrypted_response; *encrypted_response.mutable_encrypted_message()->mutable_nonce() = std::string(nonce->begin(), nonce->end()); - *encrypted_response.mutable_encrypted_message()->mutable_ciphertext() = *ciphertext; - *encrypted_response.mutable_encrypted_message()->mutable_associated_data() = associated_data; + *encrypted_response.mutable_encrypted_message()->mutable_ciphertext() = + *ciphertext; + *encrypted_response.mutable_encrypted_message()->mutable_associated_data() = + associated_data; return encrypted_response; } -absl::Status ServerEncryptor::InitializeRecipientContexts(const EncryptedRequest& request) { +absl::Status ServerEncryptor::InitializeRecipientContexts( + const EncryptedRequest& request) { // Get serialized encapsulated public key. if (!request.has_serialized_encapsulated_public_key()) { return absl::InvalidArgumentError( - "serialized encapsulated public key is not present in the initial request message"); + "serialized encapsulated public key is not present in the initial " + "request message"); } - std::string serialized_encapsulated_public_key = request.serialized_encapsulated_public_key(); + std::string serialized_encapsulated_public_key = + request.serialized_encapsulated_public_key(); // Create recipient contexts. absl::StatusOr> recipient_context = - encryption_key_handle_.GenerateRecipientContext(serialized_encapsulated_public_key); + encryption_key_handle_.GenerateRecipientContext( + serialized_encapsulated_public_key); if (!recipient_context.ok()) { return recipient_context.status(); } diff --git a/cc/crypto/server_encryptor.h b/cc/crypto/server_encryptor.h index 7e959ef8caf..a704b1ccd6e 100644 --- a/cc/crypto/server_encryptor.h +++ b/cc/crypto/server_encryptor.h @@ -31,39 +31,45 @@ namespace oak::crypto { -// Encryptor class for decrypting client requests that are received by the server and encrypting -// server responses that will be sent back to the client. Each Encryptor corresponds to a single -// crypto session between the client and the server. +// Encryptor class for decrypting client requests that are received by the +// server and encrypting server responses that will be sent back to the client. +// Each Encryptor corresponds to a single crypto session between the client and +// the server. // -// Sequence numbers for requests and responses are incremented separately, meaning that there could -// be multiple responses per request and multiple requests per response. +// Sequence numbers for requests and responses are incremented separately, +// meaning that there could be multiple responses per request and multiple +// requests per response. class ServerEncryptor { public: // Constructor for `ServerEncryptor`. - // `EncryptionKeyHandle` argument is a long-term object containing the private key and - // should outlive the per-session `ServerEncryptor` object. + // `EncryptionKeyHandle` argument is a long-term object containing the private + // key and should outlive the per-session `ServerEncryptor` object. ServerEncryptor(EncryptionKeyHandle& encryption_key_handle) - : encryption_key_handle_(encryption_key_handle), recipient_context_(nullptr){}; + : encryption_key_handle_(encryption_key_handle), + recipient_context_(nullptr){}; // Decrypts a [`EncryptedRequest`] proto message using AEAD. // // // Returns a response message plaintext and associated data. - absl::StatusOr Decrypt(oak::crypto::v1::EncryptedRequest encrypted_request); + absl::StatusOr Decrypt( + oak::crypto::v1::EncryptedRequest encrypted_request); // Encrypts `plaintext` and authenticates `associated_data` using AEAD. // // // Returns an [`oak.crypto.EncryptedResponse`] proto message. - // TODO(#3843): Return unserialized proto messages once we have Java encryption without JNI. - absl::StatusOr<::oak::crypto::v1::EncryptedResponse> Encrypt(absl::string_view plaintext, - absl::string_view associated_data); + // TODO(#3843): Return unserialized proto messages once we have Java + // encryption without JNI. + absl::StatusOr<::oak::crypto::v1::EncryptedResponse> Encrypt( + absl::string_view plaintext, absl::string_view associated_data); private: EncryptionKeyHandle& encryption_key_handle_; std::unique_ptr recipient_context_; - absl::Status InitializeRecipientContexts(const oak::crypto::v1::EncryptedRequest& request); + absl::Status InitializeRecipientContexts( + const oak::crypto::v1::EncryptedRequest& request); }; } // namespace oak::crypto diff --git a/cc/oak_echo_raw_enclave_app/main.cc b/cc/oak_echo_raw_enclave_app/main.cc index a816d6ed038..f546f365cbb 100644 --- a/cc/oak_echo_raw_enclave_app/main.cc +++ b/cc/oak_echo_raw_enclave_app/main.cc @@ -23,8 +23,8 @@ constexpr int CHANNEL_FD = 10; int main(int argc, char* argv[]) { char buf; - // This should be set up by the runtime for us, but it isn't. It's a bug in the toolchain we need - // to fix. + // This should be set up by the runtime for us, but it isn't. It's a bug in + // the toolchain we need to fix. std::ios_base::Init init; std::cerr << "In main!" << std::endl; diff --git a/cc/oak_functions/native_sdk/native_sdk.cc b/cc/oak_functions/native_sdk/native_sdk.cc index f1faaae10f8..ffd591e1d73 100644 --- a/cc/oak_functions/native_sdk/native_sdk.cc +++ b/cc/oak_functions/native_sdk/native_sdk.cc @@ -28,14 +28,16 @@ namespace oak::functions::sdk { // These functions are implemented in Rust, but called from C. -// This code does not take ownership of the raw pointers resulting from the Rust callbacks. +// This code does not take ownership of the raw pointers resulting from the Rust +// callbacks. extern "C" { const uint8_t* (*read_request_callback)(size_t* len); bool (*write_response_callback)(const uint8_t* data, size_t len); bool (*log_callback)(const uint8_t* data, size_t len); -const uint8_t* (*storage_get_item_callback)(const uint8_t* key, size_t key_len, size_t* item_len); +const uint8_t* (*storage_get_item_callback)(const uint8_t* key, size_t key_len, + size_t* item_len); const uint8_t* (*read_error_callback)(uint32_t* status_code, size_t* len); } // extern "C" @@ -47,8 +49,10 @@ namespace { absl::Status read_status() { uint32_t status_code; size_t len; - const char* message = reinterpret_cast(read_error_callback(&status_code, &len)); - return absl::Status(static_cast(status_code), absl::string_view(message, len)); + const char* message = + reinterpret_cast(read_error_callback(&status_code, &len)); + return absl::Status(static_cast(status_code), + absl::string_view(message, len)); } } // namespace @@ -64,15 +68,16 @@ absl::StatusOr read_request() { } absl::Status write_response(absl::string_view response) { - if (!write_response_callback(reinterpret_cast(response.data()), - response.size())) { + if (!write_response_callback( + reinterpret_cast(response.data()), response.size())) { return read_status(); } return absl::OkStatus(); } absl::Status log(absl::string_view message) { - if (!log_callback(reinterpret_cast(message.data()), message.size())) { + if (!log_callback(reinterpret_cast(message.data()), + message.size())) { return read_status(); } return absl::OkStatus(); @@ -80,8 +85,8 @@ absl::Status log(absl::string_view message) { absl::StatusOr storage_get_item(absl::string_view key) { size_t item_len; - const uint8_t* item = storage_get_item_callback(reinterpret_cast(key.data()), - key.size(), &item_len); + const uint8_t* item = storage_get_item_callback( + reinterpret_cast(key.data()), key.size(), &item_len); if (item == nullptr) { return read_status(); } @@ -93,7 +98,8 @@ extern "C" void register_callbacks( const uint8_t* (*read_request_cb)(size_t* len), bool (*write_response_cb)(const uint8_t* data, size_t len), bool (*log_cb)(const uint8_t* data, size_t len), - const uint8_t* (*storage_get_item_cb)(const uint8_t* key, size_t key_len, size_t* item_len), + const uint8_t* (*storage_get_item_cb)(const uint8_t* key, size_t key_len, + size_t* item_len), const uint8_t* (*read_error_cb)(uint32_t* status_code, size_t* len)) { read_request_callback = read_request_cb; write_response_callback = write_response_cb; diff --git a/cc/oak_functions/native_sdk/native_sdk.h b/cc/oak_functions/native_sdk/native_sdk.h index b4610d97ef7..bec05bfb34f 100644 --- a/cc/oak_functions/native_sdk/native_sdk.h +++ b/cc/oak_functions/native_sdk/native_sdk.h @@ -39,8 +39,8 @@ absl::Status write_response(absl::string_view response); // Calls the log Rust function that writes a debug log message if in debug mode. absl::Status log(absl::string_view message); -// Calls the lookup_data Rust function that looks up an item from the in-memory key/value lookup -// store. +// Calls the lookup_data Rust function that looks up an item from the in-memory +// key/value lookup store. absl::StatusOr storage_get_item(absl::string_view key); } // namespace oak::functions::sdk diff --git a/cc/oak_functions/native_sdk/native_sdk_ffi.h b/cc/oak_functions/native_sdk/native_sdk_ffi.h index cf9894c7b52..a7baaeb44fa 100644 --- a/cc/oak_functions/native_sdk/native_sdk_ffi.h +++ b/cc/oak_functions/native_sdk/native_sdk_ffi.h @@ -27,12 +27,13 @@ extern "C" { // This is for use with bindgen. Note that only basic types that work in // both C and Rust. The char type is complicated and uint8_t is used instead. -void register_callbacks(const uint8_t* (*read_request_cb)(size_t* len), - bool (*write_response_cb)(const uint8_t* data, size_t len), - bool (*log_cb)(const uint8_t* data, size_t len), - const uint8_t* (*storage_get_item_cb)(const uint8_t* key, size_t key_len, - size_t* item_len), - const uint8_t* (*read_error_cb)(uint32_t* status_code, size_t* len)); +void register_callbacks( + const uint8_t* (*read_request_cb)(size_t* len), + bool (*write_response_cb)(const uint8_t* data, size_t len), + bool (*log_cb)(const uint8_t* data, size_t len), + const uint8_t* (*storage_get_item_cb)(const uint8_t* key, size_t key_len, + size_t* item_len), + const uint8_t* (*read_error_cb)(uint32_t* status_code, size_t* len)); void oak_main(); diff --git a/cc/transport/BUILD b/cc/transport/BUILD index 5f3b8eeb122..ce75129f9bc 100644 --- a/cc/transport/BUILD +++ b/cc/transport/BUILD @@ -36,6 +36,7 @@ cc_library( hdrs = ["grpc_streaming_transport.h"], deps = [ ":transport", + ":util", "//proto/crypto:crypto_cc_proto", "//proto/session:messages_cc_proto", "//proto/session:service_streaming_cc_grpc", @@ -48,6 +49,34 @@ cc_library( ], ) +cc_library( + name = "grpc_unary_transport", + hdrs = ["grpc_unary_transport.h"], + deps = [ + ":transport", + ":util", + "//proto/crypto:crypto_cc_proto", + "//proto/session:messages_cc_proto", + "//proto/session:service_unary_cc_grpc", + "//proto/session:service_unary_cc_proto", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/status", + ], +) + cc_test( name = "grpc_streaming_transport_test", srcs = ["grpc_streaming_transport_test.cc"], @@ -64,3 +93,20 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "grpc_unary_transport_test", + srcs = ["grpc_unary_transport_test.cc"], + deps = [ + ":grpc_unary_transport", + "//proto/crypto:crypto_cc_proto", + "//proto/session:messages_cc_proto", + "//proto/session:service_unary_cc_grpc", + "//proto/session:service_unary_cc_proto", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/cc/transport/grpc_streaming_transport.cc b/cc/transport/grpc_streaming_transport.cc index 4ae85dd7355..2048e5a2502 100644 --- a/cc/transport/grpc_streaming_transport.cc +++ b/cc/transport/grpc_streaming_transport.cc @@ -23,6 +23,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "cc/transport/util.h" #include "grpcpp/channel.h" #include "grpcpp/client_context.h" #include "grpcpp/create_channel.h" @@ -32,25 +33,19 @@ namespace oak::transport { -namespace { using ::oak::crypto::v1::EncryptedRequest; using ::oak::crypto::v1::EncryptedResponse; using ::oak::session::v1::EndorsedEvidence; using ::oak::session::v1::GetEndorsedEvidenceRequest; using ::oak::session::v1::RequestWrapper; using ::oak::session::v1::ResponseWrapper; -} // namespace - -absl::Status to_absl_status(const grpc::Status& grpc_status) { - return absl::Status(static_cast(grpc_status.error_code()), - grpc_status.error_message()); -} absl::StatusOr GrpcStreamingTransport::GetEndorsedEvidence() { // Create request. RequestWrapper request; GetEndorsedEvidenceRequest get_endorsed_evidence_request; - *request.mutable_get_endorsed_evidence_request() = get_endorsed_evidence_request; + *request.mutable_get_endorsed_evidence_request() = + get_endorsed_evidence_request; // Send request. auto response = Send(request); @@ -63,10 +58,12 @@ absl::StatusOr GrpcStreamingTransport::GetEndorsedEvidence() { case ResponseWrapper::kGetEndorsedEvidenceResponseFieldNumber: return response->get_endorsed_evidence_response().endorsed_evidence(); case ResponseWrapper::kInvokeResponseFieldNumber: - return absl::InternalError("received InvokeResponse instead of GetEndorsedEvidenceResponse"); + return absl::InternalError( + "received InvokeResponse instead of GetEndorsedEvidenceResponse"); case ResponseWrapper::RESPONSE_NOT_SET: default: - return absl::InternalError("received unsupported response: " + absl::StrCat(*response)); + return absl::InternalError("received unsupported response: " + + absl::StrCat(*response)); } } @@ -74,7 +71,8 @@ absl::StatusOr GrpcStreamingTransport::Invoke( const EncryptedRequest& encrypted_request) { // Create request. RequestWrapper request; - *request.mutable_invoke_request()->mutable_encrypted_request() = encrypted_request; + *request.mutable_invoke_request()->mutable_encrypted_request() = + encrypted_request; // Send request. auto response = Send(request); @@ -85,12 +83,14 @@ absl::StatusOr GrpcStreamingTransport::Invoke( // Process response. switch (response->response_case()) { case ResponseWrapper::kGetEndorsedEvidenceResponseFieldNumber: - return absl::InternalError("received GetEndorsedEvidenceResponse instead of InvokeResponse"); + return absl::InternalError( + "received GetEndorsedEvidenceResponse instead of InvokeResponse"); case ResponseWrapper::kInvokeResponseFieldNumber: return response->invoke_response().encrypted_response(); case ResponseWrapper::RESPONSE_NOT_SET: default: - return absl::InternalError("received unsupported response: " + absl::StrCat(*response)); + return absl::InternalError("received unsupported response: " + + absl::StrCat(*response)); } } @@ -101,15 +101,18 @@ GrpcStreamingTransport::~GrpcStreamingTransport() { } } -absl::StatusOr GrpcStreamingTransport::Send(const RequestWrapper& request) { +absl::StatusOr GrpcStreamingTransport::Send( + const RequestWrapper& request) { // Send a request. if (!channel_reader_writer_->Write(request)) { absl::Status status = Close(); if (status.ok()) { return absl::InternalError( - "failed to read request for unspecified reason. This is likely an implementation bug."); + "failed to read request for unspecified reason. This is likely an " + "implementation bug."); } else { - return absl::Status(status.code(), absl::StrCat("while writing request: ", status.message())); + return absl::Status(status.code(), absl::StrCat("while writing request: ", + status.message())); } } @@ -119,9 +122,11 @@ absl::StatusOr GrpcStreamingTransport::Send(const RequestWrappe absl::Status status = Close(); if (status.ok()) { return absl::InternalError( - "failed to write request for unspecified reason. This is likely an implementation bug."); + "failed to write request for unspecified reason. This is likely an " + "implementation bug."); } else { - return absl::Status(status.code(), absl::StrCat("while reading request: ", status.message())); + return absl::Status(status.code(), absl::StrCat("while reading request: ", + status.message())); } } return response; diff --git a/cc/transport/grpc_streaming_transport.h b/cc/transport/grpc_streaming_transport.h index 8e1e32de9d7..7fa15d81c16 100644 --- a/cc/transport/grpc_streaming_transport.h +++ b/cc/transport/grpc_streaming_transport.h @@ -33,20 +33,22 @@ namespace oak::transport { class GrpcStreamingTransport : public TransportWrapper { public: explicit GrpcStreamingTransport( - std::unique_ptr<::grpc::ClientReaderWriterInterface<::oak::session::v1::RequestWrapper, - ::oak::session::v1::ResponseWrapper>> + std::unique_ptr<::grpc::ClientReaderWriterInterface< + ::oak::session::v1::RequestWrapper, + ::oak::session::v1::ResponseWrapper>> channel_reader_writer) : channel_reader_writer_(std::move(channel_reader_writer)) {} - absl::StatusOr<::oak::session::v1::EndorsedEvidence> GetEndorsedEvidence() override; + absl::StatusOr<::oak::session::v1::EndorsedEvidence> GetEndorsedEvidence() + override; absl::StatusOr<::oak::crypto::v1::EncryptedResponse> Invoke( const oak::crypto::v1::EncryptedRequest& encrypted_request) override; ~GrpcStreamingTransport() override; private: - std::unique_ptr<::grpc::ClientReaderWriterInterface<::oak::session::v1::RequestWrapper, - ::oak::session::v1::ResponseWrapper>> + std::unique_ptr<::grpc::ClientReaderWriterInterface< + ::oak::session::v1::RequestWrapper, ::oak::session::v1::ResponseWrapper>> channel_reader_writer_; absl::once_flag close_once_; absl::Status close_status_; diff --git a/cc/transport/grpc_streaming_transport_test.cc b/cc/transport/grpc_streaming_transport_test.cc index 3ac1932d580..cf9f814c0a4 100644 --- a/cc/transport/grpc_streaming_transport_test.cc +++ b/cc/transport/grpc_streaming_transport_test.cc @@ -38,12 +38,16 @@ using oak::session::v1::RequestWrapper; using oak::session::v1::ResponseWrapper; using oak::transport::GrpcStreamingTransport; using ::testing::_; -using ServerStream = ::grpc::ServerReaderWriter; -using ClientStream = ::grpc::ClientReaderWriterInterface; +using ServerStream = + ::grpc::ServerReaderWriter; +using ClientStream = + ::grpc::ClientReaderWriterInterface; -class MockServiceStreaming : public ::oak::session::v1::StreamingSession::Service { +class MockServiceStreaming + : public ::oak::session::v1::StreamingSession::Service { public: - MOCK_METHOD(grpc::Status, Stream, (grpc::ServerContext*, (ServerStream*)), (override)); + MOCK_METHOD(grpc::Status, Stream, (grpc::ServerContext*, (ServerStream*)), + (override)); }; class GrpcStreamingTransportTest : public ::testing::Test { protected: @@ -72,52 +76,64 @@ class GrpcStreamingTransportTest : public ::testing::Test { TEST_F(GrpcStreamingTransportTest, InvokePropagatesSendError) { GrpcStreamingTransport transport(stub_->Stream(&context_)); - EXPECT_CALL(mock_service_, Stream(_, _)).WillOnce([](grpc::ServerContext*, ServerStream* stream) { - return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "fake error"); - }); + EXPECT_CALL(mock_service_, Stream(_, _)) + .WillOnce([](grpc::ServerContext*, ServerStream* stream) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, + "fake error"); + }); EncryptedRequest request; - absl::StatusOr response = transport.Invoke(request); - ASSERT_EQ(response.status(), absl::Status(absl::StatusCode::kFailedPrecondition, - "while writing request: fake error")); + absl::StatusOr response = + transport.Invoke(request); + ASSERT_EQ(response.status(), + absl::Status(absl::StatusCode::kFailedPrecondition, + "while writing request: fake error")); } TEST_F(GrpcStreamingTransportTest, GetEndorsedEvidencePropagatesSendError) { GrpcStreamingTransport transport(stub_->Stream(&context_)); - EXPECT_CALL(mock_service_, Stream(_, _)).WillOnce([](grpc::ServerContext*, ServerStream* stream) { - return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "fake error"); - }); + EXPECT_CALL(mock_service_, Stream(_, _)) + .WillOnce([](grpc::ServerContext*, ServerStream* stream) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, + "fake error"); + }); absl::StatusOr response = transport.GetEndorsedEvidence(); - ASSERT_EQ(response.status(), absl::Status(absl::StatusCode::kFailedPrecondition, - "while writing request: fake error")); + ASSERT_EQ(response.status(), + absl::Status(absl::StatusCode::kFailedPrecondition, + "while writing request: fake error")); } TEST_F(GrpcStreamingTransportTest, InvokePropagatesWeirdError) { GrpcStreamingTransport transport(stub_->Stream(&context_)); - EXPECT_CALL(mock_service_, Stream(_, _)).WillOnce([](grpc::ServerContext*, ServerStream* stream) { - return grpc::Status::OK; - }); + EXPECT_CALL(mock_service_, Stream(_, _)) + .WillOnce([](grpc::ServerContext*, ServerStream* stream) { + return grpc::Status::OK; + }); EncryptedRequest request; - absl::StatusOr response = transport.Invoke(request); + absl::StatusOr response = + transport.Invoke(request); ASSERT_EQ(response.status().code(), absl::StatusCode::kInternal); - EXPECT_THAT(response.status().message(), testing::StartsWith("failed to read request")); + EXPECT_THAT(response.status().message(), + testing::StartsWith("failed to read request")); } TEST_F(GrpcStreamingTransportTest, GetEndorsedEvidencePropagatesWeirdError) { ::grpc::ClientContext context; GrpcStreamingTransport transport(stub_->Stream(&context)); - EXPECT_CALL(mock_service_, Stream(_, _)).WillOnce([](grpc::ServerContext*, ServerStream* stream) { - return grpc::Status::OK; - }); + EXPECT_CALL(mock_service_, Stream(_, _)) + .WillOnce([](grpc::ServerContext*, ServerStream* stream) { + return grpc::Status::OK; + }); absl::StatusOr response = transport.GetEndorsedEvidence(); ASSERT_EQ(response.status().code(), absl::StatusCode::kInternal); - EXPECT_THAT(response.status().message(), testing::StartsWith("failed to read request")); + EXPECT_THAT(response.status().message(), + testing::StartsWith("failed to read request")); } diff --git a/cc/transport/grpc_unary_transport.h b/cc/transport/grpc_unary_transport.h new file mode 100644 index 00000000000..ec8f490c408 --- /dev/null +++ b/cc/transport/grpc_unary_transport.h @@ -0,0 +1,86 @@ +/* + * Copyright 2023 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CC_TRANSPORT_GRPC_UNARY_TRANSPORT_H_ +#define CC_TRANSPORT_GRPC_UNARY_TRANSPORT_H_ + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "cc/transport/transport.h" +#include "cc/transport/util.h" +#include "grpcpp/client_context.h" +#include "proto/crypto/crypto.pb.h" +#include "proto/session/messages.pb.h" +#include "proto/session/service_unary.pb.h" + +namespace oak::transport { + +// Transport class for communication with unary gRPC Oak service. Evidence +// must be collected from the enclave and verified prior to issuing any data +// requests. This template class can be used to communicate with any unary +// stubby Oak service that use a gRPC interface consistent with +// oak/proto/session/service_unary.proto. +template +class GrpcUnaryTransport : public ::oak::transport::TransportWrapper { + public: + explicit GrpcUnaryTransport(OakBackendStub* const client_stub) + : client_stub_(client_stub) {} + + // Collects the enclave's evidence that needs to be verified by the client. + absl::StatusOr<::oak::session::v1::EndorsedEvidence> GetEndorsedEvidence() + override { + ::grpc::ClientContext context; + ::oak::session::v1::GetEndorsedEvidenceRequest request; + ::oak::session::v1::GetEndorsedEvidenceResponse response; + + grpc::Status status = + client_stub_->GetEndorsedEvidence(&context, request, &response); + if (!status.ok()) { + absl::Status absl_status = to_absl_status(status); + LOG(ERROR) << "Failed to fetch evidence with status: " << absl_status; + return absl_status; + } + + return response.endorsed_evidence(); + } + + // Takes an encrypted request and sends it to the enclave, returning the + // enclave's encrypted response. + absl::StatusOr<::oak::crypto::v1::EncryptedResponse> Invoke( + const ::oak::crypto::v1::EncryptedRequest& encrypted_request) override { + ::grpc::ClientContext context; + ::oak::session::v1::InvokeRequest request; + ::oak::session::v1::InvokeResponse response; + + *request.mutable_encrypted_request() = encrypted_request; + grpc::Status status = client_stub_->Invoke(&context, request, &response); + if (!status.ok()) { + absl::Status absl_status = to_absl_status(status); + LOG(ERROR) << "Failed to call invoke with status: " << absl_status; + return absl_status; + } + + return response.encrypted_response(); + } + + private: + OakBackendStub* client_stub_; +}; + +} // namespace oak::transport + +#endif // CC_TRANSPORT_GRPC_UNARY_TRANSPORT_H_ diff --git a/cc/transport/grpc_unary_transport_test.cc b/cc/transport/grpc_unary_transport_test.cc new file mode 100644 index 00000000000..5f95c842202 --- /dev/null +++ b/cc/transport/grpc_unary_transport_test.cc @@ -0,0 +1,106 @@ +/* + * Copyright 2023 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cc/transport/grpc_unary_transport.h" + +#include +#include + +#include "absl/status/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "proto/crypto/crypto.pb.h" +#include "proto/session/messages.pb.h" +#include "proto/session/service_unary.pb.h" +#include "proto/session/service_unary_mock.grpc.pb.h" + +namespace oak::transport { +namespace { + +using ::oak::crypto::v1::AeadEncryptedMessage; +using ::oak::crypto::v1::EncryptedRequest; +using ::oak::crypto::v1::EncryptedResponse; +using ::oak::session::v1::EndorsedEvidence; +using ::oak::session::v1::GetEndorsedEvidenceRequest; +using ::oak::session::v1::GetEndorsedEvidenceResponse; +using ::oak::session::v1::InvokeRequest; +using ::oak::session::v1::InvokeResponse; +using ::oak::session::v1::MockUnarySessionStub; + +using ::testing::_; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; +using ::testing::StrEq; + +TEST(StubbyUnaryTransportTest, KeyRetrievedAndInvokeCalledSuccess) { + auto mock_stub = std::make_unique(); + + // Test the get endorsed evidence method. + GetEndorsedEvidenceRequest empty_request; + GetEndorsedEvidenceResponse evidence_response; + std::string application_key = "001"; + *evidence_response.mutable_endorsed_evidence() + ->mutable_evidence() + ->mutable_application_keys() + ->mutable_encryption_public_key_certificate() = application_key; + + EXPECT_CALL(*mock_stub, GetEndorsedEvidence(_, _, _)) + .WillOnce( + DoAll(SetArgPointee<2>(evidence_response), Return(grpc::Status::OK))); + + GrpcUnaryTransport unary_transport(mock_stub.get()); + + auto actual_endorsed_evidence = unary_transport.GetEndorsedEvidence(); + ASSERT_TRUE(actual_endorsed_evidence.ok()); + EXPECT_THAT(actual_endorsed_evidence->evidence() + .application_keys() + .encryption_public_key_certificate(), + StrEq(application_key)); + + // Now we test the invoke method. + + const std::string request_ciphertext = "Some encrypted request."; + AeadEncryptedMessage request_aead_encrypted_message; + request_aead_encrypted_message.set_ciphertext(request_ciphertext); + EncryptedRequest encrypted_request; + *encrypted_request.mutable_encrypted_message() = + request_aead_encrypted_message; + InvokeRequest invoke_request; + *invoke_request.mutable_encrypted_request() = encrypted_request; + + const std::string response_ciphertext = "Some encrypted response."; + AeadEncryptedMessage response_aead_encrypted_message; + response_aead_encrypted_message.set_ciphertext(response_ciphertext); + EncryptedResponse encrypted_response; + *encrypted_response.mutable_encrypted_message() = + response_aead_encrypted_message; + InvokeResponse invoke_response; + *invoke_response.mutable_encrypted_response() = encrypted_response; + + EXPECT_CALL(*mock_stub, Invoke(_, _, _)) + .WillOnce( + DoAll(SetArgPointee<2>(invoke_response), Return(::grpc::Status::OK))); + + auto actual_encrypted_response = unary_transport.Invoke(encrypted_request); + ASSERT_TRUE(actual_encrypted_response.ok()); + + EXPECT_THAT(actual_encrypted_response->encrypted_message().ciphertext(), + StrEq(response_ciphertext)); +} + +} // namespace +} // namespace oak::transport diff --git a/cc/transport/transport.h b/cc/transport/transport.h index 0277d46e1cc..75c8f12ceba 100644 --- a/cc/transport/transport.h +++ b/cc/transport/transport.h @@ -32,7 +32,8 @@ class EvidenceProvider { virtual ~EvidenceProvider() = default; // Returns evidence about the trustworthiness of a remote server. - virtual absl::StatusOr<::oak::session::v1::EndorsedEvidence> GetEndorsedEvidence() = 0; + virtual absl::StatusOr<::oak::session::v1::EndorsedEvidence> + GetEndorsedEvidence() = 0; }; // Abstract class for sending messages to the enclave. diff --git a/cc/transport/util.cc b/cc/transport/util.cc new file mode 100644 index 00000000000..7924d64cc43 --- /dev/null +++ b/cc/transport/util.cc @@ -0,0 +1,27 @@ +/* + * Copyright 2023 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "absl/status/status.h" +#include "grpcpp/grpcpp.h" + +namespace oak::transport { + +absl::Status to_absl_status(const grpc::Status& grpc_status) { + return absl::Status(static_cast(grpc_status.error_code()), + grpc_status.error_message()); +} + +} // namespace oak::transport diff --git a/cc/transport/util.h b/cc/transport/util.h new file mode 100644 index 00000000000..42b84d2a78e --- /dev/null +++ b/cc/transport/util.h @@ -0,0 +1,31 @@ +/* + * Copyright 2023 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CC_TRANSPORT_UTIL_H_ +#define CC_TRANSPORT_UTIL_H_ + +#include "absl/status/status.h" +#include "grpcpp/grpcpp.h" + +namespace oak::transport { + +// Converts gRPC status to an absl status. The gRPC error status code is casted +// and the error message is copied. +absl::Status to_absl_status(const grpc::Status& grpc_status); + +} // namespace oak::transport + +#endif // CC_TRANSPORT_UTIL_H_ diff --git a/cc/utils/cose/cose.cc b/cc/utils/cose/cose.cc index b35ebabce52..aacce47942a 100644 --- a/cc/utils/cose/cose.cc +++ b/cc/utils/cose/cose.cc @@ -35,7 +35,8 @@ absl::StatusOr CoseSign1::Deserialize(absl::string_view data) { auto [item, end, error] = cppbor::parse(reinterpret_cast(data.data()), data.size()); if (!error.empty()) { - return absl::InvalidArgumentError(absl::StrCat("couldn't parse COSE_Sign1: ", error)); + return absl::InvalidArgumentError( + absl::StrCat("couldn't parse COSE_Sign1: ", error)); } if (item->type() != cppbor::ARRAY) { return UnexpectedCborTypeError("COSE_Sign1", cppbor::ARRAY, item->type()); @@ -43,16 +44,19 @@ absl::StatusOr CoseSign1::Deserialize(absl::string_view data) { const cppbor::Array* array = item->asArray(); if (array->size() != 4) { return absl::InvalidArgumentError( - absl::StrCat("invalid COSE_Sign1 CBOR array size, expected 4, found ", array->size())); + absl::StrCat("invalid COSE_Sign1 CBOR array size, expected 4, found ", + array->size())); } const auto& protected_headers = array->get(0); if (protected_headers->type() != cppbor::BSTR) { - return UnexpectedCborTypeError("protected_headers", cppbor::BSTR, protected_headers->type()); + return UnexpectedCborTypeError("protected_headers", cppbor::BSTR, + protected_headers->type()); } const auto& unprotected_headers = array->get(1); if (unprotected_headers->type() != cppbor::MAP) { - return UnexpectedCborTypeError("unprotected_headers", cppbor::MAP, unprotected_headers->type()); + return UnexpectedCborTypeError("unprotected_headers", cppbor::MAP, + unprotected_headers->type()); } const auto& payload = array->get(2); if (payload->type() != cppbor::BSTR) { @@ -60,14 +64,16 @@ absl::StatusOr CoseSign1::Deserialize(absl::string_view data) { } const auto& signature = array->get(3); if (signature->type() != cppbor::BSTR) { - return UnexpectedCborTypeError("signature", cppbor::BSTR, signature->type()); + return UnexpectedCborTypeError("signature", cppbor::BSTR, + signature->type()); } - return CoseSign1(protected_headers->asBstr(), unprotected_headers->asMap(), payload->asBstr(), - signature->asBstr(), std::move(item)); + return CoseSign1(protected_headers->asBstr(), unprotected_headers->asMap(), + payload->asBstr(), signature->asBstr(), std::move(item)); } -absl::StatusOr> CoseSign1::Serialize(const std::vector& payload) { +absl::StatusOr> CoseSign1::Serialize( + const std::vector& payload) { cppbor::Array array; // TODO(#4818): Implement headers and signature. std::vector protected_headers; @@ -78,35 +84,41 @@ absl::StatusOr> CoseSign1::Serialize(const std::vector encoded_array(array.encodedSize()); - array.encode(encoded_array.data(), encoded_array.data() + encoded_array.size()); + array.encode(encoded_array.data(), + encoded_array.data() + encoded_array.size()); return encoded_array; } -absl::StatusOr CoseKey::DeserializeHpkePublicKey(absl::string_view data) { +absl::StatusOr CoseKey::DeserializeHpkePublicKey( + absl::string_view data) { auto [item, end, error] = cppbor::parse(reinterpret_cast(data.data()), data.size()); if (!error.empty()) { - return absl::InvalidArgumentError(absl::StrCat("couldn't parse COSE_Key: ", error)); + return absl::InvalidArgumentError( + absl::StrCat("couldn't parse COSE_Key: ", error)); } return DeserializeHpkePublicKey(std::move(item)); } -absl::StatusOr CoseKey::DeserializeHpkePublicKey(const std::vector& data) { +absl::StatusOr CoseKey::DeserializeHpkePublicKey( + const std::vector& data) { auto [item, end, error] = cppbor::parse(data); if (!error.empty()) { - return absl::InvalidArgumentError(absl::StrCat("couldn't parse COSE_Key: ", error)); + return absl::InvalidArgumentError( + absl::StrCat("couldn't parse COSE_Key: ", error)); } return DeserializeHpkePublicKey(std::move(item)); } -absl::StatusOr CoseKey::DeserializeHpkePublicKey(std::unique_ptr&& item) { +absl::StatusOr CoseKey::DeserializeHpkePublicKey( + std::unique_ptr&& item) { if (item->type() != cppbor::MAP) { return UnexpectedCborTypeError("COSE_Key", cppbor::MAP, item->type()); } const cppbor::Map* map = item->asMap(); if (map->size() < 5) { - return absl::InvalidArgumentError( - absl::StrCat("invalid COSE_Key CBOR map size, expected >= 5, found ", map->size())); + return absl::InvalidArgumentError(absl::StrCat( + "invalid COSE_Key CBOR map size, expected >= 5, found ", map->size())); } const auto& kty = map->get(KTY); @@ -146,8 +158,8 @@ absl::StatusOr CoseKey::DeserializeHpkePublicKey(std::unique_ptrtype()); } - return CoseKey(kty->asUint(), alg->asNint(), key_ops->asArray(), crv->asUint(), x->asBstr(), - std::move(item)); + return CoseKey(kty->asUint(), alg->asNint(), key_ops->asArray(), + crv->asUint(), x->asBstr(), std::move(item)); } absl::StatusOr> CoseKey::SerializeHpkePublicKey( @@ -188,11 +200,12 @@ std::string CborTypeToString(cppbor::MajorType cbor_type) { } } -absl::Status UnexpectedCborTypeError(std::string_view name, cppbor::MajorType expected, +absl::Status UnexpectedCborTypeError(std::string_view name, + cppbor::MajorType expected, cppbor::MajorType found) { - return absl::InvalidArgumentError(absl::StrCat("expected ", name, " to have ", - CborTypeToString(expected), " CBOR type, found ", - CborTypeToString(found))); + return absl::InvalidArgumentError( + absl::StrCat("expected ", name, " to have ", CborTypeToString(expected), + " CBOR type, found ", CborTypeToString(found))); } } // namespace oak::utils::cose diff --git a/cc/utils/cose/cose.h b/cc/utils/cose/cose.h index ae04c075cb8..0ada9bd6d64 100644 --- a/cc/utils/cose/cose.h +++ b/cc/utils/cose/cose.h @@ -33,13 +33,16 @@ namespace oak::utils::cose { // class CoseSign1 { public: - // Parameters about the current layer that are to be cryptographically protected. + // Parameters about the current layer that are to be cryptographically + // protected. const cppbor::Bstr* protected_headers; - // Parameters about the current layer that are not cryptographically protected. + // Parameters about the current layer that are not cryptographically + // protected. const cppbor::Map* unprotected_headers; // Serialized content to be signed. const cppbor::Bstr* payload; - // Array of signatures. Each signature is represented as a COSE_Signature structure. + // Array of signatures. Each signature is represented as a COSE_Signature + // structure. const cppbor::Bstr* signature; static absl::StatusOr Deserialize(absl::string_view data); @@ -47,15 +50,16 @@ class CoseSign1 { // Puts payload into a COSE_Sign1 and serializes it. // TODO(#4818): This function is currently used for tests only. We need to // refactor COSE classes to support both serialization and deserialization. - static absl::StatusOr> Serialize(const std::vector& payload); + static absl::StatusOr> Serialize( + const std::vector& payload); private: // Parsed CBOR item containing COSE_Sign1 object. std::unique_ptr item_; - CoseSign1(const cppbor::Bstr* protected_headers, const cppbor::Map* unprotected_headers, - const cppbor::Bstr* payload, const cppbor::Bstr* signature, - std::unique_ptr&& item) + CoseSign1(const cppbor::Bstr* protected_headers, + const cppbor::Map* unprotected_headers, const cppbor::Bstr* payload, + const cppbor::Bstr* signature, std::unique_ptr&& item) : protected_headers(protected_headers), unprotected_headers(unprotected_headers), payload(payload), @@ -81,8 +85,10 @@ class CoseKey { // Deserializes HPKE public key as a COSE_Key. // - static absl::StatusOr DeserializeHpkePublicKey(absl::string_view data); - static absl::StatusOr DeserializeHpkePublicKey(const std::vector& data); + static absl::StatusOr DeserializeHpkePublicKey( + absl::string_view data); + static absl::StatusOr DeserializeHpkePublicKey( + const std::vector& data); // Transforms HPKE public key into a COSE_Key and serializes it. // TODO(#4818): This function is currently used for tests only. We need to @@ -141,16 +147,24 @@ class CoseKey { // Parsed CBOR item containing COSE_Key object. std::unique_ptr item_; - CoseKey(const cppbor::Uint* kty, const cppbor::Nint* alg, const cppbor::Array* key_ops, - const cppbor::Uint* crv, const cppbor::Bstr* x, std::unique_ptr&& item) - : kty(kty), alg(alg), key_ops(key_ops), crv(crv), x(x), item_(std::move(item)) {} + CoseKey(const cppbor::Uint* kty, const cppbor::Nint* alg, + const cppbor::Array* key_ops, const cppbor::Uint* crv, + const cppbor::Bstr* x, std::unique_ptr&& item) + : kty(kty), + alg(alg), + key_ops(key_ops), + crv(crv), + x(x), + item_(std::move(item)) {} - static absl::StatusOr DeserializeHpkePublicKey(std::unique_ptr&& item); + static absl::StatusOr DeserializeHpkePublicKey( + std::unique_ptr&& item); }; std::string CborTypeToString(cppbor::MajorType cbor_type); -absl::Status UnexpectedCborTypeError(std::string_view name, cppbor::MajorType expected, +absl::Status UnexpectedCborTypeError(std::string_view name, + cppbor::MajorType expected, cppbor::MajorType found); } // namespace oak::utils::cose diff --git a/cc/utils/cose/cose_test.cc b/cc/utils/cose/cose_test.cc index e7dc9546692..e42ca59e990 100644 --- a/cc/utils/cose/cose_test.cc +++ b/cc/utils/cose/cose_test.cc @@ -35,9 +35,11 @@ TEST(CoseTest, CoseSign1SerializeDeserializeSuccess) { auto serialized_cose_sign1_string = std::string(serialized_cose_sign1->begin(), serialized_cose_sign1->end()); - auto deserialized_cose_sign1 = CoseSign1::Deserialize(serialized_cose_sign1_string); + auto deserialized_cose_sign1 = + CoseSign1::Deserialize(serialized_cose_sign1_string); EXPECT_TRUE(deserialized_cose_sign1.ok()) << deserialized_cose_sign1.status(); - EXPECT_THAT(deserialized_cose_sign1->payload->value(), ElementsAreArray(test_payload)); + EXPECT_THAT(deserialized_cose_sign1->payload->value(), + ElementsAreArray(test_payload)); } TEST(CoseTest, CoseKeySerializeDeserializeSuccess) { @@ -45,9 +47,11 @@ TEST(CoseTest, CoseKeySerializeDeserializeSuccess) { auto serialized_cose_key = CoseKey::SerializeHpkePublicKey(test_public_key); EXPECT_TRUE(serialized_cose_key.ok()) << serialized_cose_key.status(); - auto deserialized_cose_key = CoseKey::DeserializeHpkePublicKey(*serialized_cose_key); + auto deserialized_cose_key = + CoseKey::DeserializeHpkePublicKey(*serialized_cose_key); EXPECT_TRUE(deserialized_cose_key.ok()) << deserialized_cose_key.status(); - EXPECT_THAT(deserialized_cose_key->GetPublicKey(), ElementsAreArray(test_public_key)); + EXPECT_THAT(deserialized_cose_key->GetPublicKey(), + ElementsAreArray(test_public_key)); } } // namespace diff --git a/cc/utils/cose/cwt.cc b/cc/utils/cose/cwt.cc index f74aaf5ab49..99497e234fc 100644 --- a/cc/utils/cose/cwt.cc +++ b/cc/utils/cose/cwt.cc @@ -41,15 +41,16 @@ absl::StatusOr Cwt::Deserialize(absl::string_view data) { // Deserialize COSE_Sign1 payload. auto [item, end, error] = cppbor::parse(cose_sign1->payload->value()); if (!error.empty()) { - return absl::InvalidArgumentError(absl::StrCat("couldn't deserialize CWT: ", error)); + return absl::InvalidArgumentError( + absl::StrCat("couldn't deserialize CWT: ", error)); } if (item->type() != cppbor::MAP) { return UnexpectedCborTypeError("CWT", cppbor::MAP, item->type()); } const cppbor::Map* map = item->asMap(); if (map->size() < 3) { - return absl::InvalidArgumentError( - absl::StrCat("invalid CWT map size, expected >= 3, found ", map->size())); + return absl::InvalidArgumentError(absl::StrCat( + "invalid CWT map size, expected >= 3, found ", map->size())); } // Get CWT claims. @@ -67,7 +68,8 @@ absl::StatusOr Cwt::Deserialize(absl::string_view data) { if (sub->type() != cppbor::TSTR) { return UnexpectedCborTypeError("sub", cppbor::TSTR, sub->type()); } - const auto& subject_public_key_item = map->get(SUBJECT_PUBLIC_KEY_ID); + const auto& subject_public_key_item = + map->get(SUBJECT_PUBLIC_KEY_ID); if (subject_public_key_item == nullptr) { return absl::InvalidArgumentError("SUB not found"); } @@ -78,17 +80,20 @@ absl::StatusOr Cwt::Deserialize(absl::string_view data) { // Deserialize COSE_Key. absl::StatusOr subject_public_key = - CoseKey::DeserializeHpkePublicKey(subject_public_key_item->asBstr()->value()); + CoseKey::DeserializeHpkePublicKey( + subject_public_key_item->asBstr()->value()); if (!subject_public_key.ok()) { return subject_public_key.status(); } - return Cwt(iss->asTstr(), sub->asTstr(), std::move(*subject_public_key), std::move(item)); + return Cwt(iss->asTstr(), sub->asTstr(), std::move(*subject_public_key), + std::move(item)); } absl::StatusOr> Cwt::SerializeHpkePublicKey( const std::vector& public_key) { - auto serialized_public_key_certificate = CoseKey::SerializeHpkePublicKey(public_key); + auto serialized_public_key_certificate = + CoseKey::SerializeHpkePublicKey(public_key); if (!serialized_public_key_certificate.ok()) { return serialized_public_key_certificate.status(); } @@ -97,7 +102,8 @@ absl::StatusOr> Cwt::SerializeHpkePublicKey( // TODO(#4818): Implement assigning ISS and SUB public fields. map.add(ISS, cppbor::Tstr("")); map.add(SUB, cppbor::Tstr("")); - map.add(SUBJECT_PUBLIC_KEY_ID, cppbor::Bstr(*serialized_public_key_certificate)); + map.add(SUBJECT_PUBLIC_KEY_ID, + cppbor::Bstr(*serialized_public_key_certificate)); std::vector encoded_map(map.encodedSize()); map.encode(encoded_map.data(), encoded_map.data() + encoded_map.size()); diff --git a/cc/utils/cose/cwt.h b/cc/utils/cose/cwt.h index b6575359d07..076bbdc6bec 100644 --- a/cc/utils/cose/cwt.h +++ b/cc/utils/cose/cwt.h @@ -57,7 +57,8 @@ class Cwt { IAT = 6, CTI = 7, - // Public key associated with the subject in the form of a COSE_Key structure. + // Public key associated with the subject in the form of a COSE_Key + // structure. // SUBJECT_PUBLIC_KEY_ID = -4670552, }; @@ -65,8 +66,8 @@ class Cwt { // Parsed CBOR item containing CWT object. std::unique_ptr item_; - Cwt(const cppbor::Tstr* iss, const cppbor::Tstr* sub, CoseKey&& subject_public_key, - std::unique_ptr&& item) + Cwt(const cppbor::Tstr* iss, const cppbor::Tstr* sub, + CoseKey&& subject_public_key, std::unique_ptr&& item) : iss(iss), sub(sub), subject_public_key(std::move(subject_public_key)), diff --git a/cc/utils/cose/cwt_test.cc b/cc/utils/cose/cwt_test.cc index 9978cc5cd18..61940ab7fba 100644 --- a/cc/utils/cose/cwt_test.cc +++ b/cc/utils/cose/cwt_test.cc @@ -37,24 +37,28 @@ using ::testing::ElementsAreArray; constexpr absl::string_view kTestEvidencePath = "oak_attestation_verification/testdata/oc_evidence.textproto"; -// Public key extracted from the `kTestEvidencePath` `encryption_public_key_certificate`. -constexpr uint8_t kTestPublicKey[] = {169, 153, 134, 149, 237, 126, 255, 33, 224, 237, 186, - 74, 214, 193, 103, 57, 197, 109, 186, 1, 225, 116, - 71, 4, 227, 236, 105, 90, 14, 138, 10, 91}; +// Public key extracted from the `kTestEvidencePath` +// `encryption_public_key_certificate`. +constexpr uint8_t kTestPublicKey[] = {169, 153, 134, 149, 237, 126, 255, 33, + 224, 237, 186, 74, 214, 193, 103, 57, + 197, 109, 186, 1, 225, 116, 71, 4, + 227, 236, 105, 90, 14, 138, 10, 91}; class CertificateTest : public testing::Test { protected: void SetUp() override { std::ifstream test_evidence_file(kTestEvidencePath.data()); ASSERT_TRUE(test_evidence_file); - google::protobuf::io::IstreamInputStream test_evidence_protobuf_stream(&test_evidence_file); + google::protobuf::io::IstreamInputStream test_evidence_protobuf_stream( + &test_evidence_file); auto test_evidence = std::make_unique(); - bool parse_success = - google::protobuf::TextFormat::Parse(&test_evidence_protobuf_stream, test_evidence.get()); + bool parse_success = google::protobuf::TextFormat::Parse( + &test_evidence_protobuf_stream, test_evidence.get()); ASSERT_TRUE(parse_success); - public_key_certificate_ = test_evidence->application_keys().encryption_public_key_certificate(); + public_key_certificate_ = + test_evidence->application_keys().encryption_public_key_certificate(); } std::string public_key_certificate_; @@ -63,14 +67,16 @@ class CertificateTest : public testing::Test { TEST_F(CertificateTest, CwtDeserializeSuccess) { auto cwt = Cwt::Deserialize(public_key_certificate_); EXPECT_TRUE(cwt.ok()) << cwt.status(); - EXPECT_THAT(cwt->subject_public_key.GetPublicKey(), ElementsAreArray(kTestPublicKey)); + EXPECT_THAT(cwt->subject_public_key.GetPublicKey(), + ElementsAreArray(kTestPublicKey)); } TEST_F(CertificateTest, CwtSerializeDeserializeSuccess) { std::vector test_public_key = {1, 2, 3, 4}; auto serialized_cwt = Cwt::SerializeHpkePublicKey(test_public_key); EXPECT_TRUE(serialized_cwt.ok()) << serialized_cwt.status(); - auto serialized_cwt_string = std::string(serialized_cwt->begin(), serialized_cwt->end()); + auto serialized_cwt_string = + std::string(serialized_cwt->begin(), serialized_cwt->end()); auto deserialized_cwt = Cwt::Deserialize(serialized_cwt_string); EXPECT_TRUE(deserialized_cwt.ok()) << deserialized_cwt.status(); diff --git a/flake.lock b/flake.lock index bb250a65082..0fb0e33b888 100644 --- a/flake.lock +++ b/flake.lock @@ -7,11 +7,11 @@ ] }, "locked": { - "lastModified": 1710886643, - "narHash": "sha256-saTZuv9YeZ9COHPuj8oedGdUwJZcbQ3vyRqe7NVJMsQ=", + "lastModified": 1712180168, + "narHash": "sha256-sYe00cK+kKnQlVo1wUIZ5rZl9x8/r3djShUqNgfjnM4=", "owner": "ipetkov", "repo": "crane", - "rev": "5bace74e9a65165c918205cf67ad3977fe79c584", + "rev": "06a9ff255c1681299a87191c2725d9d579f28b82", "type": "github" }, "original": { @@ -42,11 +42,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1710806803, - "narHash": "sha256-qrxvLS888pNJFwJdK+hf1wpRCSQcqA6W5+Ox202NDa0=", + "lastModified": 1712122226, + "narHash": "sha256-pmgwKs8Thu1WETMqCrWUm0CkN1nmCKX3b51+EXsAZyY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "b06025f1533a1e07b6db3e75151caa155d1c7eb3", + "rev": "08b9151ed40350725eb40b1fe96b0b86304a654b", "type": "github" }, "original": { @@ -75,11 +75,11 @@ ] }, "locked": { - "lastModified": 1710987136, - "narHash": "sha256-Q8GRdlAIKZ8tJUXrbcRO1pA33AdoPfTUirsSnmGQnOU=", + "lastModified": 1712110341, + "narHash": "sha256-8LU2IM4ctHz043hlzoFUwQS1QIdhiMGEH/oIfPCxoWU=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "97596b54ac34ad8184ca1eef44b1ec2e5c2b5f9e", + "rev": "74deb67494783168f5b6d2071d73177e6bccab65", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 3410947c45c..08707c53b18 100644 --- a/flake.nix +++ b/flake.nix @@ -40,6 +40,12 @@ # - CONFIG_MODULE_SIG_ALL is not set # - CONFIG_DEBUG_INFO_DWARF_TOOLCHAIN_DEFAULT is not set configfile = ./oak_containers_kernel/configs/6.1.33/minimal.config; + # And also the following build variables. + # See https://docs.kernel.org/kbuild/reproducible-builds.html. + extraMakeFlags = [ + "KBUILD_BUILD_USER=user" + "KBUILD_BUILD_HOST=host" + ]; version = linux_kernel_version; src = linux_kernel_src; allowImportFromDerivation = true; diff --git a/java/proto/server/secure_proxy.proto b/java/proto/server/secure_proxy.proto index 069f83f55e7..38b7774a63a 100644 --- a/java/proto/server/secure_proxy.proto +++ b/java/proto/server/secure_proxy.proto @@ -26,7 +26,8 @@ option java_package = "com.google.oak.server"; // Secure proxy wraps a service in an encrypted connection. service SecureProxy { // EncryptedConnect creates an encrypted connection between a client and a - // server, and delegates the processing of the requests to another, unencrypted, service. + // server, and delegates the processing of the requests to another, + // unencrypted, service. rpc EncryptedConnect(stream oak.session.v1.RequestWrapper) returns (stream oak.session.v1.ResponseWrapper) {} } diff --git a/justfile b/justfile index 198f8494cc4..643dd8c4137 100644 --- a/justfile +++ b/justfile @@ -26,6 +26,15 @@ oak_restricted_kernel_bin: _wrap_kernel kernel_bin_prefix: env --chdir=oak_restricted_kernel_wrapper OAK_RESTRICTED_KERNEL_FILE_NAME={{kernel_bin_prefix}}_bin cargo build --release rust-objcopy --output-target=binary oak_restricted_kernel_wrapper/target/x86_64-unknown-none/release/oak_restricted_kernel_wrapper oak_restricted_kernel_wrapper/target/x86_64-unknown-none/release/{{kernel_bin_prefix}}_wrapper_bin + rm -rf oak_restricted_kernel_wrapper/target/released_bin_with_components_{{kernel_bin_prefix}} + mkdir -p oak_restricted_kernel_wrapper/target/released_bin_with_components_{{kernel_bin_prefix}} + cp \ + oak_restricted_kernel_wrapper/target/x86_64-unknown-none/release/{{kernel_bin_prefix}}_wrapper_bin \ + oak_restricted_kernel_wrapper/target/released_bin_with_components_{{kernel_bin_prefix}}/bzimage + cargo run --package oak_kernel_measurement -- \ + --kernel=oak_restricted_kernel_wrapper/target/released_bin_with_components_{{kernel_bin_prefix}}/bzimage \ + --kernel-setup-data-output=oak_restricted_kernel_wrapper/target/released_bin_with_components_{{kernel_bin_prefix}}/kernel_setup_data \ + --kernel-image-output=oak_restricted_kernel_wrapper/target/released_bin_with_components_{{kernel_bin_prefix}}/kernel_image oak_restricted_kernel_wrapper: oak_restricted_kernel_bin just _wrap_kernel oak_restricted_kernel diff --git a/micro_rpc_workspace_test/Cargo.lock b/micro_rpc_workspace_test/Cargo.lock index b25acd9ac21..9cc1af0edac 100644 --- a/micro_rpc_workspace_test/Cargo.lock +++ b/micro_rpc_workspace_test/Cargo.lock @@ -543,9 +543,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" dependencies = [ "bytes", "fnv", diff --git a/micro_rpc_workspace_test/proto/stubs.proto b/micro_rpc_workspace_test/proto/stubs.proto index 5522cd03ac1..d3092a9f476 100644 --- a/micro_rpc_workspace_test/proto/stubs.proto +++ b/micro_rpc_workspace_test/proto/stubs.proto @@ -16,10 +16,11 @@ syntax = "proto3"; -// This file contains empty proto message stubs to make protoc happy when compiling code that -// depends on Oak protos that are not available in other repos. -// The actual implementations are provided in Rust using the `extern_paths` option of -// tonic: https://docs.rs/tonic-build/0.10.2/tonic_build/struct.Builder.html#method.extern_path +// This file contains empty proto message stubs to make protoc happy when +// compiling code that depends on Oak protos that are not available in other +// repos. The actual implementations are provided in Rust using the +// `extern_paths` option of tonic: +// https://docs.rs/tonic-build/0.10.2/tonic_build/struct.Builder.html#method.extern_path // See also the `build.rs` file in this crate. package oak.crypto.v1; diff --git a/oak_attestation/src/dice.rs b/oak_attestation/src/dice.rs index f8cc9b51278..4b6bdbe6751 100644 --- a/oak_attestation/src/dice.rs +++ b/oak_attestation/src/dice.rs @@ -96,6 +96,8 @@ impl DiceBuilder { additional_claims: Vec<(ClaimName, ciborium::Value)>, kem_public_key: &[u8], verifying_key: &VerifyingKey, + group_kem_public_key: Option<&[u8]>, + group_verifying_key: Option<&VerifyingKey>, ) -> anyhow::Result { // The last evidence layer contains the certificate for the current signing key. // Since the builder contains an existing signing key there must be at @@ -127,7 +129,7 @@ impl DiceBuilder { let signing_public_key_certificate = generate_signing_certificate( &self.signing_key, - issuer_id, + issuer_id.clone(), verifying_key, additional_claims, ) @@ -136,9 +138,44 @@ impl DiceBuilder { .to_vec() .map_err(anyhow::Error::msg)?; + // Generate group keys certificates as part of Key Provisioning. + let group_encryption_public_key_certificate = + if let Some(group_kem_public_key) = group_kem_public_key { + generate_kem_certificate( + &self.signing_key, + issuer_id.clone(), + group_kem_public_key, + vec![], + ) + .map_err(anyhow::Error::msg) + .context("couldn't generate encryption public key certificate")? + .to_vec() + .map_err(anyhow::Error::msg)? + } else { + vec![] + }; + + let group_signing_public_key_certificate = + if let Some(group_verifying_key) = group_verifying_key { + generate_signing_certificate( + &self.signing_key, + issuer_id.clone(), + group_verifying_key, + vec![], + ) + .map_err(anyhow::Error::msg) + .context("couldn't generate signing public key certificate")? + .to_vec() + .map_err(anyhow::Error::msg)? + } else { + vec![] + }; + evidence.application_keys = Some(ApplicationKeys { encryption_public_key_certificate, signing_public_key_certificate, + group_encryption_public_key_certificate, + group_signing_public_key_certificate, }); Ok(evidence) @@ -237,5 +274,10 @@ fn application_keys_to_proto( let signing_public_key_certificate = oak_dice::utils::cbor_encoded_bytes_to_vec(&value.signing_public_key_certificate[..]) .map_err(anyhow::Error::msg)?; - Ok(ApplicationKeys { encryption_public_key_certificate, signing_public_key_certificate }) + Ok(ApplicationKeys { + encryption_public_key_certificate, + signing_public_key_certificate, + group_encryption_public_key_certificate: vec![], + group_signing_public_key_certificate: vec![], + }) } diff --git a/oak_attestation_integration_tests/tests/verifier_tests.rs b/oak_attestation_integration_tests/tests/verifier_tests.rs index bdb67e6a91c..efc3b7a3c4e 100644 --- a/oak_attestation_integration_tests/tests/verifier_tests.rs +++ b/oak_attestation_integration_tests/tests/verifier_tests.rs @@ -18,11 +18,11 @@ use oak_attestation::dice::evidence_to_proto; use oak_attestation_verification::verifier::{to_attestation_results, verify, verify_dice_chain}; use oak_proto_rust::oak::attestation::v1::{ attestation_results::Status, binary_reference_value, endorsements, - kernel_binary_reference_value, reference_values, regex_reference_value, + kernel_binary_reference_value, reference_values, text_reference_value, ApplicationLayerReferenceValues, BinaryReferenceValue, Endorsements, InsecureReferenceValues, KernelBinaryReferenceValue, KernelLayerReferenceValues, OakRestrictedKernelEndorsements, - OakRestrictedKernelReferenceValues, ReferenceValues, RegexReferenceValue, - RootLayerEndorsements, RootLayerReferenceValues, SkipVerification, + OakRestrictedKernelReferenceValues, ReferenceValues, RootLayerEndorsements, + RootLayerReferenceValues, SkipVerification, TextReferenceValue, }; use oak_restricted_kernel_sdk::attestation::EvidenceProvider; @@ -78,11 +78,12 @@ fn verify_mock_evidence() { SkipVerification {}, )), }), - kernel_image: Some(skip.clone()), - kernel_setup_data: Some(skip.clone()), - kernel_cmd_line: Some(skip.clone()), - kernel_cmd_line_regex: Some(RegexReferenceValue { - r#type: Some(regex_reference_value::Type::Skip(SkipVerification {})), + kernel_image: None, + kernel_setup_data: None, + kernel_cmd_line: None, + kernel_cmd_line_regex: None, + kernel_cmd_line_text: Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::Skip(SkipVerification {})), }), init_ram_fs: Some(skip.clone()), memory_map: Some(skip.clone()), diff --git a/oak_attestation_verification/BUILD b/oak_attestation_verification/BUILD index 53ecb3610c4..062cda204cd 100644 --- a/oak_attestation_verification/BUILD +++ b/oak_attestation_verification/BUILD @@ -26,6 +26,36 @@ rust_library( compile_data = [ "//oak_attestation_verification/data:ask_milan.pem", ], + deps = [ + "//oak_dice", + "//oak_proto_rust", + "//oak_sev_snp_attestation_report", + "@oak_crates_index//:anyhow", + "@oak_crates_index//:base64", + "@oak_crates_index//:coset", + "@oak_crates_index//:ecdsa", + "@oak_crates_index//:getrandom", + "@oak_crates_index//:hex", + "@oak_crates_index//:p256", + "@oak_crates_index//:p384", + "@oak_crates_index//:prost", + "@oak_crates_index//:rsa", + "@oak_crates_index//:serde", + "@oak_crates_index//:serde_json", + "@oak_crates_index//:sha2", + "@oak_crates_index//:time", + "@oak_crates_index//:x509-cert", + "@oak_crates_index//:zerocopy", + ], +) + +rust_library( + name = "oak_attestation_verification_with_regex", + srcs = glob(["src/**"]), + compile_data = [ + "//oak_attestation_verification/data:ask_milan.pem", + ], + crate_features = ["regex"], deps = [ "//oak_dice", "//oak_proto_rust", diff --git a/oak_attestation_verification/Cargo.toml b/oak_attestation_verification/Cargo.toml index 2c3efabbe39..d6ffee6bbdb 100644 --- a/oak_attestation_verification/Cargo.toml +++ b/oak_attestation_verification/Cargo.toml @@ -35,7 +35,7 @@ p384 = { version = "0.13.0", default-features = false, features = [ prost = { workspace = true, default-features = false, features = [ "prost-derive", ] } -regex = { version = "*", default-features = false } +regex = { version = "*", default-features = false, optional = true } rsa = { version = "0.9.6", default-features = false } serde = { version = "*", default-features = false, features = ["derive"] } serde_json = { version = "*", default-features = false, features = ["alloc"] } diff --git a/oak_attestation_verification/src/verifier.rs b/oak_attestation_verification/src/verifier.rs index 075b751179e..e50bf909fc0 100644 --- a/oak_attestation_verification/src/verifier.rs +++ b/oak_attestation_verification/src/verifier.rs @@ -33,8 +33,8 @@ use oak_proto_rust::oak::{ attestation::v1::{ attestation_results::Status, binary_reference_value, endorsements, extracted_evidence::EvidenceValues, kernel_binary_reference_value, reference_values, - regex_reference_value, root_layer_data::Report, AmdAttestationReport, - AmdSevReferenceValues, ApplicationKeys, ApplicationLayerData, ApplicationLayerEndorsements, + root_layer_data::Report, text_reference_value, AmdAttestationReport, AmdSevReferenceValues, + ApplicationKeys, ApplicationLayerData, ApplicationLayerEndorsements, ApplicationLayerReferenceValues, AttestationResults, BinaryReferenceValue, CbData, CbEndorsements, CbReferenceValues, ContainerLayerData, ContainerLayerEndorsements, ContainerLayerReferenceValues, Endorsements, Evidence, ExtractedEvidence, @@ -43,14 +43,15 @@ use oak_proto_rust::oak::{ KernelLayerEndorsements, KernelLayerReferenceValues, OakContainersData, OakContainersEndorsements, OakContainersReferenceValues, OakRestrictedKernelData, OakRestrictedKernelEndorsements, OakRestrictedKernelReferenceValues, ReferenceValues, - RegexReferenceValue, RootLayerData, RootLayerEndorsements, RootLayerEvidence, - RootLayerReferenceValues, SystemLayerData, SystemLayerEndorsements, - SystemLayerReferenceValues, TcbVersion, TeePlatform, TransparentReleaseEndorsement, + RootLayerData, RootLayerEndorsements, RootLayerEvidence, RootLayerReferenceValues, + SystemLayerData, SystemLayerEndorsements, SystemLayerReferenceValues, TcbVersion, + TeePlatform, TextReferenceValue, TransparentReleaseEndorsement, }, HexDigest, RawDigest, }; use oak_sev_snp_attestation_report::AttestationReport; use prost::Message; +#[cfg(feature = "regex")] use regex::Regex; use x509_cert::{ der::{Decode, DecodePem}, @@ -518,27 +519,29 @@ fn verify_kernel_layer( .context("kernel failed verification")?; if let Some(kernel_raw_cmd_line) = values.kernel_raw_cmd_line.as_ref() { - verify_regex( + verify_text( + now_utc_millis, kernel_raw_cmd_line.as_str(), reference_values - .kernel_cmd_line_regex + .kernel_cmd_line_text .as_ref() - .context("no kernel command line regex reference values")?, + .context("no kernel command line text reference values")?, + endorsements.and_then(|value| value.kernel_cmd_line.as_ref()), ) .context("kernel command line failed verification")?; } else { - // Support missing kernel_cmd_line_regex but only if the corresponding reference + // Support missing kernel_raw_cmd_line but only if the corresponding reference // value is set to skip. This is a temporary workaround until all clients are // migrated. anyhow::ensure!( matches!( reference_values - .kernel_cmd_line_regex + .kernel_cmd_line_text .as_ref() - .expect("no kernel command line regex reference values") + .expect("no kernel command line text reference values") .r#type .as_ref(), - Some(regex_reference_value::Type::Skip(_)) + Some(text_reference_value::Type::Skip(_)) ), "No kernel_raw_cmd_line provided" ) @@ -760,25 +763,63 @@ fn verify_hex_digests(actual: &HexDigest, expected: &HexDigest) -> anyhow::Resul } } -fn verify_regex(actual: &str, expected: &RegexReferenceValue) -> anyhow::Result<()> { +fn verify_text( + now_utc_millis: i64, + actual: &str, + expected: &TextReferenceValue, + endorsement: Option<&TransparentReleaseEndorsement>, +) -> anyhow::Result<()> { match expected.r#type.as_ref() { - Some(regex_reference_value::Type::Skip(_)) => Ok(()), - Some(regex_reference_value::Type::Regex(regex)) => { - let re = Regex::new(regex.value.as_str()).map_err(|msg| { - anyhow::anyhow!("Couldn't parse regex in the reference value: {msg}") - })?; - if re.is_match(actual) { - Ok(()) - } else { - anyhow::bail!(format!( - "kernel cmd line doesn't match the reference value: {actual}" - )) + Some(text_reference_value::Type::Skip(_)) => Ok(()), + Some(text_reference_value::Type::Endorsement(public_keys)) => { + let endorsement = + endorsement.context("matching endorsement not found for text reference value")?; + verify_binary_endorsement( + now_utc_millis, + &endorsement.endorsement, + &endorsement.endorsement_signature, + &endorsement.rekor_log_entry, + &public_keys.endorser_public_key, + &public_keys.rekor_public_key, + )?; + // Compare the actual command line against the one inlined in the endorsement. + let regex = String::from_utf8(endorsement.subject.clone()) + .expect("endorsement subject is not utf8"); + verify_regex(actual, ®ex).context("regex from endorsement does not match") + } + Some(text_reference_value::Type::Regex(regex)) => { + verify_regex(actual, ®ex.value).context("regex from reference values does not match") + } + Some(text_reference_value::Type::StringLiterals(string_literals)) => { + anyhow::ensure!(!string_literals.value.is_empty()); + for sl in string_literals.value.iter() { + if sl == actual { + return Ok(()); + } } + Err(anyhow::anyhow!(format!( + "value doesn't match the reference value string literal: {actual}" + ))) } - None => Err(anyhow::anyhow!("missing skip or value in the regex reference value")), + None => Err(anyhow::anyhow!("missing skip or value in the text reference value")), } } +#[cfg(feature = "regex")] +fn verify_regex(actual: &str, regex: &str) -> anyhow::Result<()> { + let re = Regex::new(regex) + .map_err(|msg| anyhow::anyhow!("couldn't parse regex in the reference value: {msg}"))?; + Ok(anyhow::ensure!( + re.is_match(actual), + format!("value doesn't match the reference value regex: {actual}") + )) +} + +#[cfg(not(feature = "regex"))] +fn verify_regex(_actual: &str, _regex: &str) -> anyhow::Result<()> { + Err(anyhow::anyhow!("verification of regex values not supported")) +} + struct ApplicationKeyValues { encryption_public_key: Vec, signing_public_key: Vec, diff --git a/oak_attestation_verification/tests/verifier_tests.rs b/oak_attestation_verification/tests/verifier_tests.rs index 364d2a665aa..2e62f804146 100644 --- a/oak_attestation_verification/tests/verifier_tests.rs +++ b/oak_attestation_verification/tests/verifier_tests.rs @@ -23,15 +23,15 @@ use oak_attestation_verification::{ use oak_proto_rust::oak::{ attestation::v1::{ attestation_results::Status, binary_reference_value, extracted_evidence::EvidenceValues, - kernel_binary_reference_value, reference_values, regex_reference_value, - root_layer_data::Report, AmdSevReferenceValues, ApplicationLayerEndorsements, + kernel_binary_reference_value, reference_values, root_layer_data::Report, + text_reference_value, AmdSevReferenceValues, ApplicationLayerEndorsements, ApplicationLayerReferenceValues, BinaryReferenceValue, ContainerLayerEndorsements, ContainerLayerReferenceValues, Digests, EndorsementReferenceValue, Endorsements, Evidence, InsecureReferenceValues, KernelBinaryReferenceValue, KernelLayerEndorsements, KernelLayerReferenceValues, OakContainersEndorsements, OakContainersReferenceValues, OakRestrictedKernelEndorsements, OakRestrictedKernelReferenceValues, ReferenceValues, - Regex, RegexReferenceValue, RootLayerEndorsements, RootLayerReferenceValues, - SkipVerification, SystemLayerEndorsements, SystemLayerReferenceValues, TcbVersion, + Regex, RootLayerEndorsements, RootLayerReferenceValues, SkipVerification, StringLiterals, + SystemLayerEndorsements, SystemLayerReferenceValues, TcbVersion, TextReferenceValue, TransparentReleaseEndorsement, }, RawDigest, @@ -167,11 +167,16 @@ fn create_containers_reference_values() -> ReferenceValues { kernel: Some(KernelBinaryReferenceValue { r#type: Some(kernel_binary_reference_value::Type::Skip(SkipVerification {})), }), - kernel_image: Some(skip.clone()), - kernel_setup_data: Some(skip.clone()), - kernel_cmd_line: Some(skip.clone()), - kernel_cmd_line_regex: Some(RegexReferenceValue { - r#type: Some(regex_reference_value::Type::Regex(Regex { value: String::from("^.*$") })), + kernel_setup_data: None, + kernel_image: None, + kernel_cmd_line: None, + kernel_cmd_line_regex: None, + kernel_cmd_line_text: Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::StringLiterals(StringLiterals { + value: vec![String::from( + "console=ttyS0 panic=-1 earlycon=uart,io,0x3F8 brd.rd_nr=1 brd.rd_size=3072000 brd.max_part=1 ip=10.0.2.15:::255.255.255.0::eth0:off net.ifnames=0 quiet", + )], + })), }), init_ram_fs: Some(skip.clone()), memory_map: Some(skip.clone()), @@ -210,11 +215,14 @@ fn create_rk_reference_values() -> ReferenceValues { kernel: Some(KernelBinaryReferenceValue { r#type: Some(kernel_binary_reference_value::Type::Skip(SkipVerification {})), }), - kernel_image: Some(skip.clone()), - kernel_setup_data: Some(skip.clone()), - kernel_cmd_line: Some(skip.clone()), - kernel_cmd_line_regex: Some(RegexReferenceValue { - r#type: Some(regex_reference_value::Type::Regex(Regex { value: String::from("^.*$") })), + kernel_setup_data: None, + kernel_image: None, + kernel_cmd_line: None, + kernel_cmd_line_regex: None, + kernel_cmd_line_text: Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::StringLiterals(StringLiterals { + value: vec![String::from("console=ttyS0")], + })), }), init_ram_fs: Some(skip.clone()), memory_map: Some(skip.clone()), @@ -498,8 +506,8 @@ fn verify_fails_with_non_matching_command_line_reference_value_set() { let mut reference_values = create_rk_reference_values(); match reference_values.r#type.as_mut() { Some(reference_values::Type::OakRestrictedKernel(rfs)) => { - rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_regex = Some(RegexReferenceValue { - r#type: Some(regex_reference_value::Type::Regex(Regex { + rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_text = Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::Regex(Regex { value: String::from("this will fail"), })), }); @@ -519,14 +527,43 @@ fn verify_fails_with_non_matching_command_line_reference_value_set() { } #[test] -fn verify_succeeds_with_matching_command_line_reference_value_set() { +#[cfg(not(feature = "regex"))] +fn verify_fails_with_matching_command_line_reference_value_regex_set_and_regex_disabled() { + let evidence = create_rk_evidence(); + let endorsements = create_rk_endorsements(); + let mut reference_values = create_rk_reference_values(); + match reference_values.r#type.as_mut() { + Some(reference_values::Type::OakRestrictedKernel(rfs)) => { + rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_text = Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::Regex(Regex { + value: String::from("^console=[a-zA-Z0-9]+$"), + })), + }); + } + Some(_) => {} + None => {} + }; + + let r = verify(NOW_UTC_MILLIS, &evidence, &endorsements, &reference_values); + let p = to_attestation_results(&r); + + eprintln!("======================================"); + eprintln!("code={} reason={}", p.status as i32, p.reason); + eprintln!("======================================"); + assert!(r.is_err()); + assert!(p.status() == Status::GenericFailure); +} + +#[test] +#[cfg(feature = "regex")] +fn verify_succeeds_with_matching_command_line_reference_value_regex_set_and_regex_enabled() { let evidence = create_rk_evidence(); let endorsements = create_rk_endorsements(); let mut reference_values = create_rk_reference_values(); match reference_values.r#type.as_mut() { Some(reference_values::Type::OakRestrictedKernel(rfs)) => { - rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_regex = Some(RegexReferenceValue { - r#type: Some(regex_reference_value::Type::Regex(Regex { + rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_text = Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::Regex(Regex { value: String::from("^console=[a-zA-Z0-9]+$"), })), }); @@ -568,8 +605,8 @@ fn verify_succeeds_with_skip_command_line_reference_value_set_and_obsolete_evide let mut reference_values = create_rk_reference_values(); match reference_values.r#type.as_mut() { Some(reference_values::Type::OakRestrictedKernel(rfs)) => { - rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_regex = Some(RegexReferenceValue { - r#type: Some(regex_reference_value::Type::Skip(SkipVerification {})), + rfs.kernel_layer.as_mut().unwrap().kernel_cmd_line_text = Some(TextReferenceValue { + r#type: Some(text_reference_value::Type::Skip(SkipVerification {})), }); } Some(_) => {} diff --git a/oak_containers/proto/interfaces.proto b/oak_containers/proto/interfaces.proto index f11b978adb9..2d6cad964cd 100644 --- a/oak_containers/proto/interfaces.proto +++ b/oak_containers/proto/interfaces.proto @@ -24,9 +24,9 @@ import "proto/attestation/endorsement.proto"; import "proto/attestation/evidence.proto"; import "proto/session/messages.proto"; -// As images can be large (hundreds of megabytes), the launcher chunks up the response into smaller -// pieces to respect proto/gRPC limits. The image needs to be reassembled in the stage1 or the -// orchestrator. +// As images can be large (hundreds of megabytes), the launcher chunks up the +// response into smaller pieces to respect proto/gRPC limits. The image needs to +// be reassembled in the stage1 or the orchestrator. message GetImageResponse { bytes image_chunk = 1; } @@ -41,39 +41,45 @@ message SendAttestationEvidenceRequest { oak.attestation.v1.Evidence dice_evidence = 2; } -// Defines the service exposed by the launcher, that can be invoked by the stage1 and the -// orchestrator. +// Defines the service exposed by the launcher, that can be invoked by the +// stage1 and the orchestrator. service Launcher { - // Provides stage1 with the Oak system image (which contains the Linux distribution and the - // orchestrator binary). - rpc GetOakSystemImage(google.protobuf.Empty) returns (stream GetImageResponse) {} + // Provides stage1 with the Oak system image (which contains the Linux + // distribution and the orchestrator binary). + rpc GetOakSystemImage(google.protobuf.Empty) + returns (stream GetImageResponse) {} // Provides orchestrator with the trusted container image. - rpc GetContainerBundle(google.protobuf.Empty) returns (stream GetImageResponse) {} + rpc GetContainerBundle(google.protobuf.Empty) + returns (stream GetImageResponse) {} // This method is used by the orchestrator to load and measure the trusted // application config. The orchestrator will later, separately expose this // config to the application. - rpc GetApplicationConfig(google.protobuf.Empty) returns (GetApplicationConfigResponse) {} + rpc GetApplicationConfig(google.protobuf.Empty) + returns (GetApplicationConfigResponse) {} - // Sends Attestation Evidence containing the Attestation Report with corresponding measurements - // and public keys to the Launcher. - // This API is called exactly once after the Attestation Evidence is generated. Calling this API - // a second time will result in an error. - rpc SendAttestationEvidence(SendAttestationEvidenceRequest) returns (google.protobuf.Empty) {} + // Sends Attestation Evidence containing the Attestation Report with + // corresponding measurements and public keys to the Launcher. This API is + // called exactly once after the Attestation Evidence is generated. Calling + // this API a second time will result in an error. + rpc SendAttestationEvidence(SendAttestationEvidenceRequest) + returns (google.protobuf.Empty) {} - // Notifies the launcher that the trusted app is ready to serve requests and listening on the - // pre-arranged port (8080). + // Notifies the launcher that the trusted app is ready to serve requests and + // listening on the pre-arranged port (8080). rpc NotifyAppReady(google.protobuf.Empty) returns (google.protobuf.Empty) {} } -// Defines the service exposed by the orchestrator, that can be invoked by the application. +// Defines the service exposed by the orchestrator, that can be invoked by the +// application. service Orchestrator { - // Exposes the previously loaded trusted application config to the application, - // which may choose to retrieve it. - rpc GetApplicationConfig(google.protobuf.Empty) returns (GetApplicationConfigResponse) {} + // Exposes the previously loaded trusted application config to the + // application, which may choose to retrieve it. + rpc GetApplicationConfig(google.protobuf.Empty) + returns (GetApplicationConfigResponse) {} - // Notifies the orchestrator that the trusted app is ready to serve requests and listening on the - // pre-arranged port (8080). + // Notifies the orchestrator that the trusted app is ready to serve requests + // and listening on the pre-arranged port (8080). rpc NotifyAppReady(google.protobuf.Empty) returns (google.protobuf.Empty) {} } diff --git a/oak_containers_orchestrator/src/main.rs b/oak_containers_orchestrator/src/main.rs index bdf761a9af4..22961d67942 100644 --- a/oak_containers_orchestrator/src/main.rs +++ b/oak_containers_orchestrator/src/main.rs @@ -56,78 +56,89 @@ async fn main() -> anyhow::Result<()> { .map_err(|error| anyhow!("couldn't create client: {:?}", error))?, ); + // Get key provisioning role. + let key_provisioning_role = launcher_client + .get_key_provisioning_role() + .await + .map_err(|error| anyhow!("couldn't get key provisioning role: {:?}", error))?; + + // Generate application keys. + let (instance_keys, instance_public_keys) = generate_instance_keys(); + let (mut group_keys, group_public_keys) = + if key_provisioning_role == KeyProvisioningRole::Leader { + let (group_keys, group_public_keys) = instance_keys.generate_group_keys(); + (Some(Arc::new(group_keys)), Some(group_public_keys)) + } else { + (None, None) + }; + + // Load application. let container_bundle = launcher_client .get_container_bundle() .await .map_err(|error| anyhow!("couldn't get container bundle: {:?}", error))?; - let application_config = launcher_client .get_application_config() .await .map_err(|error| anyhow!("couldn't get application config: {:?}", error))?; + // Generate attestation evidence and send it to the Hostlib. let dice_builder = oak_containers_orchestrator::dice::load_stage1_dice_data()?; let additional_claims = oak_containers_orchestrator::dice::measure_container_and_config( &container_bundle, &application_config, ); - let (instance_keys, instance_public_keys) = generate_instance_keys(); - let evidence = dice_builder.add_application_keys( additional_claims, &instance_public_keys.encryption_public_key, &instance_public_keys.signing_public_key, + if let Some(ref group_public_keys) = group_public_keys { + Some(&group_public_keys.encryption_public_key) + } else { + None + }, + None, )?; launcher_client .send_attestation_evidence(evidence) .await .map_err(|error| anyhow!("couldn't send attestation evidence: {:?}", error))?; + // Request group keys. + if key_provisioning_role == KeyProvisioningRole::Follower { + let get_group_keys_response = launcher_client + .get_group_keys() + .await + .map_err(|error| anyhow!("couldn't get group keys: {:?}", error))?; + let provisioned_group_keys = instance_keys + .provide_group_keys(get_group_keys_response) + .context("couldn't provide group keys")?; + group_keys = Some(Arc::new(provisioned_group_keys)); + } + if let Some(path) = args.ipc_socket_path.parent() { tokio::fs::create_dir_all(path).await?; } - let key_provisioning_role = launcher_client - .get_key_provisioning_role() - .await - .map_err(|error| anyhow!("couldn't get key provisioning role: {:?}", error))?; - let group_keys = Arc::new(match key_provisioning_role { - KeyProvisioningRole::Unspecified => anyhow::bail!("unspecified key provisioning role"), - KeyProvisioningRole::Leader => { - // TODO(#4442): Sign group public keys in the enclave evidence. - let (group_keys, _) = instance_keys.generate_group_keys(); - group_keys - } - KeyProvisioningRole::Dependant => { - let get_group_keys_response = launcher_client - .get_group_keys() - .await - .map_err(|error| anyhow!("couldn't get group keys: {:?}", error))?; - instance_keys - .provide_group_keys(get_group_keys_response) - .context("couldn't provide group keys")? - } - }); - let _metrics = oak_containers_orchestrator::metrics::run(launcher_client.clone())?; + // Start application and gRPC servers. let user = nix::unistd::User::from_name(&args.runtime_user) .context(format!("error resolving user {}", args.runtime_user))? .context(format!("user `{}` not found", args.runtime_user))?; - let cancellation_token = CancellationToken::new(); tokio::try_join!( oak_containers_orchestrator::ipc_server::create( &args.ipc_socket_path, instance_keys, - group_keys.clone(), + group_keys.clone().context("group keys were not provisioned")?, application_config, launcher_client, cancellation_token.clone(), ), oak_containers_orchestrator::key_provisioning::create( &args.orchestrator_addr, - group_keys, + group_keys.context("group keys were not provisioned")?, cancellation_token.clone(), ), oak_containers_orchestrator::container_runtime::run( diff --git a/oak_dice/src/evidence.rs b/oak_dice/src/evidence.rs index 0ce5390f31b..0e59d79c5f2 100644 --- a/oak_dice/src/evidence.rs +++ b/oak_dice/src/evidence.rs @@ -53,10 +53,7 @@ pub const X25519_PRIVATE_KEY_SIZE: usize = 32; /// public key. pub const PUBLIC_KEY_SIZE: usize = 256; -/// The maximum size of a larger serialized CWT certificate. -pub const LARGE_CERTIFICATE_SIZE: usize = 1536; - -/// The maximum size of a standard serialized CWT certificate. +/// The maximum size of a serialized CWT certificate. pub const CERTIFICATE_SIZE: usize = 1024; /// The name of the kernel command-line parameter that is used to send the @@ -142,7 +139,7 @@ pub struct LayerEvidence { /// Serialized CWT certificate for the ECA private key owned by the /// corresponding layer. The certificate must include measurements of /// the layer that owns the private key. - pub eca_certificate: [u8; LARGE_CERTIFICATE_SIZE], + pub eca_certificate: [u8; CERTIFICATE_SIZE], } impl LayerEvidence { @@ -153,7 +150,7 @@ impl LayerEvidence { } } -static_assertions::assert_eq_size!([u8; LARGE_CERTIFICATE_SIZE], LayerEvidence); +static_assertions::assert_eq_size!([u8; CERTIFICATE_SIZE], LayerEvidence); /// Private key that can be used by a layer to sign a certificate for the next /// layer. @@ -193,7 +190,7 @@ pub struct Stage0DiceData { pub layer_1_certificate_authority: CertificateAuthority, /// The compound device identifier for Layer 1. pub layer_1_cdi: CompoundDeviceIdentifier, - _padding_1: [u8; 128], + _padding_1: [u8; 640], } static_assertions::assert_eq_size!([u8; 4096], Stage0DiceData); @@ -251,7 +248,7 @@ pub struct Evidence { pub application_keys: ApplicationKeys, } -static_assertions::assert_eq_size!([u8; 5904], Evidence); +static_assertions::assert_eq_size!([u8; 5392], Evidence); /// Wrapper for passing the attestation evidence and private keys from the /// Restricted Kernel to the application. @@ -262,4 +259,4 @@ pub struct RestrictedKernelDiceData { pub application_private_keys: ApplicationPrivateKeys, } -static_assertions::assert_eq_size!([u8; 6032], RestrictedKernelDiceData); +static_assertions::assert_eq_size!([u8; 5520], RestrictedKernelDiceData); diff --git a/oak_grpc_unary_attestation/proto/unary_server.proto b/oak_grpc_unary_attestation/proto/unary_server.proto index 797cd2dcd18..29cb0032186 100644 --- a/oak_grpc_unary_attestation/proto/unary_server.proto +++ b/oak_grpc_unary_attestation/proto/unary_server.proto @@ -54,21 +54,26 @@ service UnarySession { // // The expected message seqeuence starts with an intial handshake: // - Client->Server: `UnaryRequest` with a serialized `ClientHello` message. - // - Server->Client: `UnaryResponse` with a serialized `ServerIdentity` message. - // - Client->Server: `UnaryRequest` with a serialized `ClientIdentity` message. - // - Server->Client: `UnaryResponse` with an empty message, confirming handshake completion. + // - Server->Client: `UnaryResponse` with a serialized `ServerIdentity` + // message. + // - Client->Server: `UnaryRequest` with a serialized `ClientIdentity` + // message. + // - Server->Client: `UnaryResponse` with an empty message, confirming + // handshake completion. // // After the handshake, both client and server derive matching session keys // and are then able to exchange multiple UnaryRequest/UnaryResponse request // pairs that contain a seralized `EncryptedData` message: // - Client->Server: `UnaryRequest` with a serialized `EncryptedData` message. - // - Server->Client: `UnaryResponse` with a serialized `EncryptedData` message. + // - Server->Client: `UnaryResponse` with a serialized `EncryptedData` + // message. // - // Messages are represented as serialized messages defined in the `remote_attestation::message.rs` - // and `com.google.oak.remote_attestation.Message`. + // Messages are represented as serialized messages defined in the + // `remote_attestation::message.rs` and + // `com.google.oak.remote_attestation.Message`. rpc Message(UnaryRequest) returns (UnaryResponse); - // Gets the public key and the attestation report that binds the public key to a specific instance - // of the code running in a TEE. + // Gets the public key and the attestation report that binds the public key to + // a specific instance of the code running in a TEE. rpc GetPublicKeyInfo(google.protobuf.Empty) returns (PublicKeyInfo); } diff --git a/oak_kernel_measurement/src/main.rs b/oak_kernel_measurement/src/main.rs index 60054e48e2b..e6cda81bac9 100644 --- a/oak_kernel_measurement/src/main.rs +++ b/oak_kernel_measurement/src/main.rs @@ -29,6 +29,10 @@ const DEFAULT_LINUX_KERNEL: &str = "oak_containers_kernel/target/bzImage"; struct Cli { #[arg(long, help = "The location of the kernel bzImage file")] kernel: Option, + #[arg(long, help = "The location of output the extracted kernel setup data file to")] + kernel_setup_data_output: Option, + #[arg(long, help = "The location of output the extracted kernel image file to")] + kernel_image_output: Option, } impl Cli { @@ -51,6 +55,14 @@ fn main() -> anyhow::Result<()> { setup_hasher.update(&kernel_info.setup_data); println!("Kernel Setup Data Measurement: sha2-256:{}", hex::encode(setup_hasher.finalize())); + if let Some(path) = cli.kernel_setup_data_output { + std::fs::write(path, kernel_info.setup_data).context("couldn't write kernel setup data")?; + } + + if let Some(path) = cli.kernel_image_output { + std::fs::write(path, kernel_info.kernel_image).context("couldn't write kernel image")?; + } + Ok(()) } diff --git a/oak_ml_transparency/runner/Cargo.lock b/oak_ml_transparency/runner/Cargo.lock index d205af7192d..6c26de3f55e 100644 --- a/oak_ml_transparency/runner/Cargo.lock +++ b/oak_ml_transparency/runner/Cargo.lock @@ -660,7 +660,6 @@ dependencies = [ "p256", "p384", "prost", - "regex", "rsa", "serde", "serde_json", diff --git a/oak_restricted_kernel/src/interrupts.rs b/oak_restricted_kernel/src/interrupts.rs index 6da184117ba..578657213fc 100644 --- a/oak_restricted_kernel/src/interrupts.rs +++ b/oak_restricted_kernel/src/interrupts.rs @@ -108,6 +108,17 @@ mutable_interrupt_handler_with_error_code!( ) { match error_code { 0x72 => { + // Make sure it was triggered from a CPUID instruction. + const CPUID_INSTRUCTION: u16 = 0xa20f; + // Safety: we are copying two bytes and interpreting it as a + // 16-bit number without making any other assumptions about + // the layout. + let instruction: u16 = + unsafe { core::ptr::read_unaligned(stack_frame.rip.as_ptr()) }; + if instruction != CPUID_INSTRUCTION { + panic!("INSTRUCTION WAS NOT CPUID"); + } + if let Some(cpuid_page) = CPUID_PAGE.get() { let target = stack_frame.into(); let count = cpuid_page.count as usize; diff --git a/oak_sev_guest/src/ghcb.rs b/oak_sev_guest/src/ghcb.rs index 1387f04cfab..b305269e9cb 100644 --- a/oak_sev_guest/src/ghcb.rs +++ b/oak_sev_guest/src/ghcb.rs @@ -19,14 +19,18 @@ //! hypervisor. use bitflags::bitflags; -use x86_64::{PhysAddr, VirtAddr}; +use x86_64::{ + structures::paging::{page::NotGiantPageSize, PageSize, PhysFrame, Size2MiB, Size4KiB}, + PhysAddr, VirtAddr, +}; use zerocopy::{AsBytes, FromBytes, FromZeroes}; use crate::{ cpuid::{CpuidInput, CpuidOutput}, + instructions::PageSize as SevPageSize, msr::{ - register_ghcb_location, set_ghcb_address_and_exit, GhcbGpa, RegisterGhcbGpaError, - RegisterGhcbGpaRequest, + register_ghcb_location, set_ghcb_address_and_exit, GhcbGpa, PageAssignment, + RegisterGhcbGpaError, RegisterGhcbGpaRequest, }, Translator, }; @@ -69,6 +73,12 @@ const SW_EXIT_CODE_MMIO_WRITE: u64 = 0x8000_0002; /// See table 6 in . const SW_EXIT_CODE_AP_JUMP_TABLE: u64 = 0x8000_0005; +/// +/// The value of the sw_exit_code field when doing a Page State Change request. +/// +/// See table 6 in . +const SW_EXIT_CODE_PAGE_STATE_CHANGE: u64 = 0x8000_0010; + /// /// The value of the sw_exit_code field when doing a Guest Message request. /// @@ -110,6 +120,9 @@ const BASE_VALID_BITMAP: ValidBitmap = /// and RDX, so we only use the least significant 32 bits. const MSR_REGISTER_MASK: u64 = 0xffff_ffff; +/// Size of the shared buffer space in the GHCB structure. +const SHARED_BUFFER_SIZE: usize = 2032; + /// The guest-host communications block. /// /// See: Table 3 in @@ -168,7 +181,7 @@ pub struct Ghcb { _reserved_7: [u8; 1016], /// Area that can be used as a shared buffer for communicating additional /// information. - pub shared_buffer: [u8; 2032], + pub shared_buffer: [u8; SHARED_BUFFER_SIZE], /// Reserved. Must be 0. _reserved_8: [u8; 10], /// The version of the GHCB protocol and page layout in use. @@ -583,6 +596,46 @@ where } } + /// Performs a Page State Change operation on the given physical frame. + /// + /// See section 4.1.6 in . + pub fn page_state_change( + &mut self, + frame: PhysFrame, + assignment: PageAssignment, + ) -> Result<(), &'static str> { + let mut page_state_change = PageStateChange::new_zeroed(); + page_state_change.header.cur_entry = 0; + page_state_change.header.end_entry = 0; + page_state_change.entry[0] = + PageStateChangeEntry { page_operation: assignment, gfn: frame, current_page: 0 }.into(); + + let gpa_base = self.get_gpa().as_u64(); + let ghcb = self.ghcb.as_mut(); + ghcb.reset(); + ghcb.sw_exit_code = SW_EXIT_CODE_PAGE_STATE_CHANGE; + ghcb.sw_scratch = gpa_base + (core::mem::offset_of!(Ghcb, shared_buffer) as u64); + page_state_change + .write_to(&mut ghcb.shared_buffer) + .ok_or("Unexpected length mismatch between PSC request and GHCB shared buffer")?; + ghcb.valid_bitmap = BASE_VALID_BITMAP | ValidBitmap::SW_SCRATCH; + + self.do_vmg_exit()?; + + let ghcb = self.ghcb.as_ref(); + let page_state_change = PageStateChange::read_from(&ghcb.shared_buffer) + .ok_or("Unexpected length mismatch between PSC request and GHCB shared buffer")?; + // If cur_entry did not move past end_entry, SW_EXITINFO2 will contain + // additional information. For now, we just return a generic error. + if page_state_change.header.cur_entry <= page_state_change.header.end_entry { + return Err("cur_entry did not move!"); + }; + // According to documentation the `current_page` field in each + // `PageStateChangeEntry` struct should change to show that a page has been + // successfully processed, but I always get zeroes there. + Ok(()) + } + /// Sets the address of the GHCB, exits to the hypervisor, and checks the /// return value when execution resumes. fn do_vmg_exit(&mut self) -> Result<(), &'static str> { @@ -614,3 +667,73 @@ fn reset_slice(slice: &mut [u8]) { *byte = 0; } } + +#[repr(C)] +#[derive(AsBytes, FromBytes, FromZeroes)] +struct PageStateChangeHeader { + cur_entry: u16, + end_entry: u16, + _reserved: u32, +} + +/// Page State Change structure. +/// +/// See section 4.1.6 in . +#[repr(C)] +#[derive(AsBytes, FromBytes, FromZeroes)] +struct PageStateChange { + header: PageStateChangeHeader, + entry: [u64; 253], +} +static_assertions::assert_eq_size!(PageStateChange, [u8; SHARED_BUFFER_SIZE]); + +/// Page State Change Entry. +/// +/// See Table 9 in . +struct PageStateChangeEntry { + page_operation: PageAssignment, + + /// Physical frame in guest memory for the operation. + gfn: PhysFrame, + + /// Input: offset, in 4K increments, on which to bein the page state change + /// operation. Output: offset of the current page in 4K increments that + /// has been successfully processed. + current_page: u16, +} + +impl From> for u64 { + fn from(value: PageStateChangeEntry) -> Self { + // [63:57]: Reserved, must be zero. Reset the whole thing to zero. + let mut entry = 0u64; + // [56]: Page size. + entry |= match S::SIZE { + Size4KiB::SIZE => (SevPageSize::Page4KiB as u64) << 56, + Size2MiB::SIZE => (SevPageSize::Page2MiB as u64) << 56, + _ => unreachable!("Unexpected non-giant page size (not 4 KiB or 2 MiB)"), + }; + // [55:52]: Page operation. + entry |= (value.page_operation as u64) << 52; + // [51:12]: Guest physical frame number + // PhysFrame guarantees that the address is properly aligned. + entry |= value.gfn.start_address().as_u64(); + // [11:0]: Current page. + entry |= (value.current_page as u64) & 0xFFF; + entry + } +} + +impl TryFrom for PageStateChangeEntry { + type Error = &'static str; + + fn try_from(value: u64) -> Result { + let page_operation = PageAssignment::from_repr(((value & 0x10_0000_0000_0000) >> 52) as u8) + .ok_or("Invalid page assignment field value")?; + let address = PhysAddr::new(value & 0x000F_FFFF_FFFF_F000); + let gfn = + PhysFrame::from_start_address(address).map_err(|_| "Frame address not aligned")?; + let current_page = (value & 0x0FFF) as u16; + + Ok(PageStateChangeEntry { page_operation, gfn, current_page }) + } +} diff --git a/oak_sev_guest/src/instructions.rs b/oak_sev_guest/src/instructions.rs index dfe8a02013a..0b7d1f1e039 100644 --- a/oak_sev_guest/src/instructions.rs +++ b/oak_sev_guest/src/instructions.rs @@ -33,7 +33,7 @@ pub enum Validation { } /// The size of a memory page. -#[derive(Debug, FromRepr)] +#[derive(Clone, Copy, Debug, FromRepr)] #[repr(u32)] pub enum PageSize { /// The page is a 4KiB page. diff --git a/proto/attestation/dice.proto b/proto/attestation/dice.proto index ea025f2e73f..e8cca10ea20 100644 --- a/proto/attestation/dice.proto +++ b/proto/attestation/dice.proto @@ -24,10 +24,11 @@ option go_package = "proto/oak/attestation/v1"; option java_multiple_files = true; option java_package = "com.google.oak.attestation.v1"; -// Message for passing embedded certificate authority information between layers. -// Will never appear in the evidence that is sent to the client. +// Message for passing embedded certificate authority information between +// layers. Will never appear in the evidence that is sent to the client. message CertificateAuthority { - // ECA private key that will be used by a layer to sign a certificate for the next layer. + // ECA private key that will be used by a layer to sign a certificate for the + // next layer. bytes eca_private_key = 1; } diff --git a/proto/attestation/evidence.proto b/proto/attestation/evidence.proto index cf5e88beaf4..a54a52bc30e 100644 --- a/proto/attestation/evidence.proto +++ b/proto/attestation/evidence.proto @@ -84,6 +84,20 @@ message ApplicationKeys { // Represented as a CBOR/COSE/CWT ECA certificate. // bytes signing_public_key_certificate = 2; + + // Certificate signing the group encryption public key as part of Key + // Provisioning. + // + // Represented as a CBOR/COSE/CWT ECA certificate. + // + bytes group_encryption_public_key_certificate = 3; + + // Certificate signing the group signing public key as part of Key + // Provisioning. + // + // Represented as a CBOR/COSE/CWT ECA certificate. + // + bytes group_signing_public_key_certificate = 4; } // Attestation Evidence used by the client to the identity of firmware and diff --git a/proto/attestation/reference_value.proto b/proto/attestation/reference_value.proto index f0166e163fd..abfaa43a5ac 100644 --- a/proto/attestation/reference_value.proto +++ b/proto/attestation/reference_value.proto @@ -19,8 +19,8 @@ syntax = "proto3"; package oak.attestation.v1; -import "proto/digest.proto"; import "proto/attestation/tcb_version.proto"; +import "proto/digest.proto"; option go_package = "proto/oak/attestation/v1"; option java_multiple_files = true; @@ -79,14 +79,17 @@ message FileReferenceValue { // Allowable digests for the file. Digests digests = 1; - // Absolute path to the file in question, or just the file name. Relative paths are - // not supported. + // Absolute path to the file in question, or just the file name. Relative + // paths are not supported. string path = 2; } -// Verifies that a particular string is equal to at least one of the specified ones. -// No checks are performed if this is empty. +// Verifies that a particular string is equal to at least one of the specified +// ones. No checks are performed if this is empty. message StringReferenceValue { + // Use TextReferenceValue instead. + option deprecated = true; + repeated string values = 1; } @@ -94,21 +97,41 @@ message Regex { string value = 1; } +// A match in at least one value is considered a success. At least one value +// must be specified, otherwise verification fails. +message StringLiterals { + repeated string value = 1; +} + message RegexReferenceValue { + // Use TextReferenceValue instead. + option deprecated = true; + oneof type { SkipVerification skip = 1; Regex regex = 2; } } +// Reference value to match text via endorsement, or directly via constants +// or a regular expression. +message TextReferenceValue { + oneof type { + SkipVerification skip = 1; + EndorsementReferenceValue endorsement = 4; + Regex regex = 2; + StringLiterals string_literals = 3; + } +} + message RootLayerReferenceValues { // Switches between AMD SEV-SNP and Intel TDX based on TeePlatform value. // Verification is skipped when not running in a TEE. AmdSevReferenceValues amd_sev = 1; IntelTdxReferenceValues intel_tdx = 2; - // When insecure is set no verification of the TEE platform is performed. This can - // be used when not running in a TEE or when the client is agnostic about the - // platform and doesn't care about the hardware verification. + // When insecure is set no verification of the TEE platform is performed. This + // can be used when not running in a TEE or when the client is agnostic about + // the platform and doesn't care about the hardware verification. InsecureReferenceValues insecure = 3; } @@ -139,17 +162,16 @@ message KernelLayerReferenceValues { // Verifies the kernel based on endorsement. KernelBinaryReferenceValue kernel = 1; - // No longer used. Will be removed. - // TODO: b/325979696 - Validate the kernel command-line using a regex. - BinaryReferenceValue kernel_cmd_line = 2 [deprecated = true]; - - // Validates the kernel command-line using a regex. - RegexReferenceValue kernel_cmd_line_regex = 8; + // Verifies the kernel command line, i.e. the parameters passed to the + // kernel on boot. + TextReferenceValue kernel_cmd_line_text = 9; - // Fields are deprecated and kept only for backwards compatibility. - // Remove ASAP. + // Fields are deprecated and kept only for backwards compatibility. They are + // not being used by the verifier. Remove ASAP. BinaryReferenceValue kernel_setup_data = 3 [deprecated = true]; BinaryReferenceValue kernel_image = 7 [deprecated = true]; + RegexReferenceValue kernel_cmd_line_regex = 8 [deprecated = true]; + BinaryReferenceValue kernel_cmd_line = 2 [deprecated = true]; // Verifies the stage1 binary if running as Oak Containers. BinaryReferenceValue init_ram_fs = 4; diff --git a/proto/attestation/tcb_version.proto b/proto/attestation/tcb_version.proto index 0b5f3024e80..359dbbe35ff 100644 --- a/proto/attestation/tcb_version.proto +++ b/proto/attestation/tcb_version.proto @@ -17,6 +17,10 @@ syntax = "proto3"; package oak.attestation.v1; +option go_package = "proto/oak/attestation/v1"; +option java_multiple_files = true; +option java_package = "com.google.oak.attestation.v1"; + // The versions of the components in the AMD SEV-SNP platform Trusted Compute // Base (TCB). message TcbVersion { diff --git a/proto/containers/hostlib_key_provisioning.proto b/proto/containers/hostlib_key_provisioning.proto index 2b06cbfcb7d..49051894f56 100644 --- a/proto/containers/hostlib_key_provisioning.proto +++ b/proto/containers/hostlib_key_provisioning.proto @@ -24,7 +24,7 @@ import "proto/key_provisioning/key_provisioning.proto"; enum KeyProvisioningRole { KEY_PROVISIONING_ROLE_UNSPECIFIED = 0; LEADER = 1; - DEPENDANT = 2; + FOLLOWER = 2; } message GetKeyProvisioningRoleResponse { @@ -35,18 +35,22 @@ message GetGroupKeysResponse { oak.key_provisioning.v1.GroupKeys group_keys = 1; } -// Defines the service exposed by the Hostlib that is used provide the Orchestrator with group keys. +// Defines the service exposed by the Hostlib that is used provide the +// Orchestrator with group keys. service HostlibKeyProvisioning { // Get the enclave role for Key Provisioning. // Could be one of the following: // - Leader that generates group keys and distributes them. - // - Dependant that requests group keys from the leader. - rpc GetKeyProvisioningRole(google.protobuf.Empty) returns (GetKeyProvisioningRoleResponse) {} + // - Follower that requests group keys from the leader. + rpc GetKeyProvisioningRole(google.protobuf.Empty) + returns (GetKeyProvisioningRoleResponse) {} // Get enclave group keys to the enclave as part of Key Provisioning. - // This method is only called by the Dependant Orchestrator. + // This method is only called by the Follower Orchestrator. // - // This method must be called after `oak.containers.Launcher.SendAttestationEvidence`, because - // Hostlib needs to have the Attestation Evidence in order to request group keys from the leader. + // This method must be called after + // `oak.containers.Launcher.SendAttestationEvidence`, because Hostlib needs to + // have the Attestation Evidence in order to request group keys from the + // leader. rpc GetGroupKeys(google.protobuf.Empty) returns (GetGroupKeysResponse) {} } diff --git a/proto/containers/orchestrator_crypto.proto b/proto/containers/orchestrator_crypto.proto index b64ffc6220b..5c6e0714010 100644 --- a/proto/containers/orchestrator_crypto.proto +++ b/proto/containers/orchestrator_crypto.proto @@ -20,8 +20,8 @@ package oak.containers.v1; import "proto/crypto/crypto.proto"; -// Choice between a key generated by the enclave instance and the key distributed to the enclave -// group with Key Provisioning. +// Choice between a key generated by the enclave instance and the key +// distributed to the enclave group with Key Provisioning. enum KeyOrigin { KEY_ORIGIN_UNSPECIFIED = 0; INSTANCE = 1; @@ -30,12 +30,14 @@ enum KeyOrigin { message DeriveSessionKeysRequest { KeyOrigin key_origin = 1; - // Ephemeral Diffie-Hellman client public key that is needed to derive session keys. + // Ephemeral Diffie-Hellman client public key that is needed to derive session + // keys. bytes serialized_encapsulated_public_key = 2; } message DeriveSessionKeysResponse { - // Session keys for decrypting client requests and encrypting enclave responses. + // Session keys for decrypting client requests and encrypting enclave + // responses. oak.crypto.v1.SessionKeys session_keys = 1; } @@ -53,8 +55,10 @@ message SignResponse { // - Sign arbitrary data // TODO(#4504): Implement data signing. service OrchestratorCrypto { - // Derives session keys for decrypting client requests and encrypting enclave responses. - rpc DeriveSessionKeys(DeriveSessionKeysRequest) returns (DeriveSessionKeysResponse) {} + // Derives session keys for decrypting client requests and encrypting enclave + // responses. + rpc DeriveSessionKeys(DeriveSessionKeysRequest) + returns (DeriveSessionKeysResponse) {} // Signs the provided message using the hardware rooted signing key. rpc Sign(SignRequest) returns (SignResponse) {} } diff --git a/proto/crypto/crypto.proto b/proto/crypto/crypto.proto index 5dcd755f197..4d9b0605325 100644 --- a/proto/crypto/crypto.proto +++ b/proto/crypto/crypto.proto @@ -29,8 +29,8 @@ message EncryptedRequest { // Message encrypted with Authenticated Encryption with Associated Data (AEAD) // using the derived session key. AeadEncryptedMessage encrypted_message = 1; - // Ephemeral Diffie-Hellman client public key that is needed to derive a session key. - // Only sent in the first message of the secure session. + // Ephemeral Diffie-Hellman client public key that is needed to derive a + // session key. Only sent in the first message of the secure session. optional bytes serialized_encapsulated_public_key = 2; } @@ -51,8 +51,9 @@ message AeadEncryptedMessage { bytes nonce = 3; } -// Envelope containing session keys required to encrypt/decrypt messages within a secure session. -// Needed to serialize contexts in order to send them over an RPC. +// Envelope containing session keys required to encrypt/decrypt messages within +// a secure session. Needed to serialize contexts in order to send them over an +// RPC. message SessionKeys { // AEAD key for encrypting/decrypting client requests. bytes request_key = 1; diff --git a/proto/key_provisioning/key_provisioning.proto b/proto/key_provisioning/key_provisioning.proto index 073cdefcc92..17516d85201 100644 --- a/proto/key_provisioning/key_provisioning.proto +++ b/proto/key_provisioning/key_provisioning.proto @@ -23,14 +23,14 @@ import "proto/attestation/evidence.proto"; import "proto/attestation/endorsement.proto"; message GroupKeys { - // Encryption private key that was encrypted with HPKE using the encryption public key provided - // in the endorsed evidence. + // Encryption private key that was encrypted with HPKE using the encryption + // public key provided in the endorsed evidence. oak.crypto.v1.EncryptedRequest encrypted_encryption_private_key = 1; } message GetGroupKeysRequest { - // Evidence contains the encryption public key for encrypting the group encryption key with - // Hybrid Encryption. + // Evidence contains the encryption public key for encrypting the group + // encryption key with Hybrid Encryption. // oak.attestation.v1.Evidence evidence = 1; oak.attestation.v1.Endorsements endorsements = 2; @@ -42,6 +42,7 @@ message GetGroupKeysResponse { // Defines the Key Provisioning Service that distributes keys between enclaves. service KeyProvisioning { - // Request enclave group keys from for other enclaves as part of Key Provisioning. + // Request enclave group keys from for other enclaves as part of Key + // Provisioning. rpc GetGroupKeys(GetGroupKeysRequest) returns (GetGroupKeysResponse) {} } diff --git a/proto/micro_rpc/messages.proto b/proto/micro_rpc/messages.proto index 26993ff128b..2198646bceb 100644 --- a/proto/micro_rpc/messages.proto +++ b/proto/micro_rpc/messages.proto @@ -23,10 +23,10 @@ option java_package = "com.google.micro_rpc"; // A wrapper message representing a request over a transport. message RequestWrapper { - // The id of the method to invoke. This is usually specified via the IDL that drives the code - // generation. In other contexts (e.g. gRPC), IDLs may use the method name as part of the - // serialization, but here we prefer an opaque numeric identifier for both conciseness and - // stability across renames. + // The id of the method to invoke. This is usually specified via the IDL that + // drives the code generation. In other contexts (e.g. gRPC), IDLs may use the + // method name as part of the serialization, but here we prefer an opaque + // numeric identifier for both conciseness and stability across renames. uint32 method_id = 1; // The bytes of the serialized request. bytes body = 2; @@ -34,7 +34,8 @@ message RequestWrapper { // A message representing an error status code with associated message. // -// Similar to https://github.com/googleapis/googleapis/blob/master/google/rpc/status.proto. +// Similar to +// https://github.com/googleapis/googleapis/blob/master/google/rpc/status.proto. message Status { // The status code, which should be an enum value of // https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto. diff --git a/proto/oak_functions/abi.proto b/proto/oak_functions/abi.proto index 05b8a78a291..5234e984750 100644 --- a/proto/oak_functions/abi.proto +++ b/proto/oak_functions/abi.proto @@ -21,7 +21,8 @@ package oak.functions.abi; option java_multiple_files = true; option java_package = "oak.functions.abi"; -// The client can check the configuration report for the configuration of the Oak Functions runtime. +// The client can check the configuration report for the configuration of the +// Oak Functions runtime. message ConfigurationReport { // Hash of the loaded Wasm module. bytes wasm_hash = 1; @@ -29,26 +30,29 @@ message ConfigurationReport { ServerPolicy policy = 2; } -/// Server-side policy describing limits on the size of the response and response processing time to -/// avoid side-channel leaks. +/// Server-side policy describing limits on the size of the response and +/// response processing time to avoid side-channel leaks. message ServerPolicy { // A fixed size for responses returned by the trusted runtime. // - // This size only applies to the body of the Oak Functions response. If the response body - // computed by the Wasm module is smaller than this amount, it is padded with additional - // data before serialization and inclusion in the HTTP response to the client. If the body is - // larger than this amount, the trusted runtime discards the response and instead uses a - // response with a body of exactly this size, containing an error message indicating the - // policy violation. The body included in the HTTP response sent to the client is the binary - // protobuf encoding of the Oak Functions response, and will have a size larger than - // `constant_response_size_bytes`. However, this size is still guaranteed to be a constant. + // This size only applies to the body of the Oak Functions response. If the + // response body computed by the Wasm module is smaller than this amount, it + // is padded with additional data before serialization and inclusion in the + // HTTP response to the client. If the body is larger than this amount, the + // trusted runtime discards the response and instead uses a response with a + // body of exactly this size, containing an error message indicating the + // policy violation. The body included in the HTTP response sent to the client + // is the binary protobuf encoding of the Oak Functions response, and will + // have a size larger than `constant_response_size_bytes`. However, this size + // is still guaranteed to be a constant. uint32 constant_response_size_bytes = 1; // A fixed response time, in milliseconds. // - // Similar to the previous one, but controls the amount of time the function is allowed to run - // for. If the function finishes before this time, the response is not sent back until the - // time is elapsed. If the function does not finish within this deadline, the trusted runtime - // sends a response to the client containing an error message indicating the failure. The size - // of this response is equal to the size specified by the previous parameter. + // Similar to the previous one, but controls the amount of time the function + // is allowed to run for. If the function finishes before this time, the + // response is not sent back until the time is elapsed. If the function does + // not finish within this deadline, the trusted runtime sends a response to + // the client containing an error message indicating the failure. The size of + // this response is equal to the size specified by the previous parameter. uint32 constant_processing_time_ms = 2; } diff --git a/proto/oak_functions/application_config.proto b/proto/oak_functions/application_config.proto index 4c4bb0c9205..076654e457d 100644 --- a/proto/oak_functions/application_config.proto +++ b/proto/oak_functions/application_config.proto @@ -25,7 +25,8 @@ enum HandlerType { // Use a wasm interpreter to load the module. HANDLER_WASM = 1; - // Interpret the module as a native .so file. Only supported when running on Oak Containers. + // Interpret the module as a native .so file. Only supported when running on + // Oak Containers. HANDLER_NATIVE = 2; } @@ -45,9 +46,11 @@ message ApplicationConfig { // Communication channel parameters. // The default behaviour depends on the flavour of Oak Functions: - // - when running on Restricted Kernel this setting is ignored completely as the communication + // - when running on Restricted Kernel this setting is ignored completely as + // the communication // channel is abstracted away by Restricted Kernel itself. - // - on Oak Containers, if not specified, the default communication channel is TCP. + // - on Oak Containers, if not specified, the default communication channel + // is TCP. oneof communication_channel { TcpCommunicationChannel tcp_channel = 2; VsockCommunicationChannel vsock_channel = 3; diff --git a/proto/oak_functions/sdk/oak_functions_wasm.proto b/proto/oak_functions/sdk/oak_functions_wasm.proto index d85561831b6..79827e61014 100644 --- a/proto/oak_functions/sdk/oak_functions_wasm.proto +++ b/proto/oak_functions/sdk/oak_functions_wasm.proto @@ -21,10 +21,11 @@ package oak.functions.wasm.v1; import "google/protobuf/wrappers.proto"; import "proto/micro_rpc/options.proto"; -// The standard API for Oak Functions, which is exposed to Wasm modules via micro RPC, and wrapped -// by the Oak Functions SDK. +// The standard API for Oak Functions, which is exposed to Wasm modules via +// micro RPC, and wrapped by the Oak Functions SDK. // -// Other forks of Oak Functions may customize the API and provide a different set of methods. +// Other forks of Oak Functions may customize the API and provide a different +// set of methods. service StdWasmApi { // Read a request from the client. // @@ -37,9 +38,9 @@ service StdWasmApi { // Write a response for the client. // - // Multiple calls overwrite the response, and only the last value is sent to the client. - // If the Oak Functions WebAssembly module never invokes this method, the Oak Functions - // runtime sends an empty response to the client. + // Multiple calls overwrite the response, and only the last value is sent to + // the client. If the Oak Functions WebAssembly module never invokes this + // method, the Oak Functions runtime sends an empty response to the client. // // method_id: 1 rpc WriteResponse(WriteResponseRequest) returns (WriteResponseResponse) { @@ -48,8 +49,8 @@ service StdWasmApi { // Writes a debug log message. // - // These log messages are considered sensitive, so will only be logged by the runtime if running - // in debug mode. + // These log messages are considered sensitive, so will only be logged by the + // runtime if running in debug mode. // // method_id: 2 rpc Log(LogRequest) returns (LogResponse) { @@ -66,7 +67,8 @@ service StdWasmApi { // Looks up multiple items from the in-memory key/value lookup store. // // method_id: 4 - rpc LookupDataMulti(LookupDataMultiRequest) returns (LookupDataMultiResponse) { + rpc LookupDataMulti(LookupDataMultiRequest) + returns (LookupDataMultiResponse) { option (.oak.micro_rpc.method_id) = 4; } @@ -124,7 +126,7 @@ message TestResponse { message BytesValue { bytes value = 1; - // If true, the value was found in the store. This is useful to distinguish between a value that - // was not found and a value that was found and is empty. + // If true, the value was found in the store. This is useful to distinguish + // between a value that was not found and a value that was found and is empty. bool found = 2; } diff --git a/proto/oak_functions/service/oak_functions.proto b/proto/oak_functions/service/oak_functions.proto index e574981dc28..5b87b733b33 100644 --- a/proto/oak_functions/service/oak_functions.proto +++ b/proto/oak_functions/service/oak_functions.proto @@ -38,20 +38,23 @@ service OakFunctions { } // Extends the next lookup data by the given chunk of lookup data. Only - // after the sender calls finishes building the next lookup data, the receiver replaces the - // current lookup data with the next lookup data, and only then chunk is will be served in - // lookups. + // after the sender calls finishes building the next lookup data, the receiver + // replaces the current lookup data with the next lookup data, and only then + // chunk is will be served in lookups. // // method_id: 2 - rpc ExtendNextLookupData(ExtendNextLookupDataRequest) returns (ExtendNextLookupDataResponse) { + rpc ExtendNextLookupData(ExtendNextLookupDataRequest) + returns (ExtendNextLookupDataResponse) { option (.oak.micro_rpc.method_id) = 2; } - // Finishes building the next lookup data with the given chunk of lookup data. The receiver - // replaces the current lookup data and the next lookup data will be served in lookups. + // Finishes building the next lookup data with the given chunk of lookup data. + // The receiver replaces the current lookup data and the next lookup data will + // be served in lookups. // // method_id: 3 - rpc FinishNextLookupData(FinishNextLookupDataRequest) returns (FinishNextLookupDataResponse) { + rpc FinishNextLookupData(FinishNextLookupDataRequest) + returns (FinishNextLookupDataResponse) { option (.oak.micro_rpc.method_id) = 3; } @@ -62,19 +65,22 @@ service OakFunctions { option (.oak.micro_rpc.method_id) = 4; } - // Streaming version combining `ExtendNextLookupData` and `FinishNextLookupData`. + // Streaming version combining `ExtendNextLookupData` and + // `FinishNextLookupData`. // // This is mainly for use with gRPC, as microRPC doesn't support streaming. // // method_id: 5 - rpc StreamLookupData(stream LookupDataChunk) returns (FinishNextLookupDataResponse) { + rpc StreamLookupData(stream LookupDataChunk) + returns (FinishNextLookupDataResponse) { option (.oak.micro_rpc.method_id) = 5; } // Reserves additional capacity for entries in the lookup table. // - // It should be called before `ExtendNextLookupData`/`StreamLookupData` to reduce the - // number of memory allocations, but it's not mandatory to call this RPC. + // It should be called before `ExtendNextLookupData`/`StreamLookupData` to + // reduce the number of memory allocations, but it's not mandatory to call + // this RPC. // // method_id: 6 rpc Reserve(ReserveRequest) returns (ReserveResponse) { @@ -108,8 +114,8 @@ message LookupDataChunk { repeated LookupDataEntry items = 1; } -// If the definition of ExtendNextLookupData changes, the estimation of the size when -// serialized in the Oak Functions Launcher needs to change, too. +// If the definition of ExtendNextLookupData changes, the estimation of the size +// when serialized in the Oak Functions Launcher needs to change, too. message ExtendNextLookupDataRequest { LookupDataChunk chunk = 1; } diff --git a/proto/session/BUILD b/proto/session/BUILD index b92faa7f5f6..c790d87e75c 100644 --- a/proto/session/BUILD +++ b/proto/session/BUILD @@ -75,6 +75,7 @@ cc_proto_library( cc_grpc_library( name = "service_unary_cc_grpc", srcs = [":service_unary_proto"], + generate_mocks = True, grpc_only = True, deps = [":service_unary_cc_proto"], ) diff --git a/proto/session/messages.proto b/proto/session/messages.proto index ec93a58524f..319e4b3a792 100644 --- a/proto/session/messages.proto +++ b/proto/session/messages.proto @@ -25,8 +25,8 @@ import "proto/attestation/evidence.proto"; option java_multiple_files = true; option java_package = "com.google.oak.session.v1"; -// Endorsed evidence contains an attestation evidence provided by the enclave and the corresponding -// attestation endorsements provided by the hostlib. +// Endorsed evidence contains an attestation evidence provided by the enclave +// and the corresponding attestation endorsements provided by the hostlib. message EndorsedEvidence { oak.attestation.v1.Evidence evidence = 1; oak.attestation.v1.Endorsements endorsements = 2; diff --git a/proto/session/service_streaming.proto b/proto/session/service_streaming.proto index a454ae3eac5..1b779a881a4 100644 --- a/proto/session/service_streaming.proto +++ b/proto/session/service_streaming.proto @@ -39,18 +39,21 @@ message ResponseWrapper { // Service definition for streaming communication with an Oak server. service StreamingSession { - // Used to send a sequence of messages ensuring that they are all handled by the same server - // instance, by virtue of being multiplexed over a single gRPC stream. + // Used to send a sequence of messages ensuring that they are all handled by + // the same server instance, by virtue of being multiplexed over a single gRPC + // stream. // - // The `RequestWrapper` and `ResponseWrapper` messages are thin wrappers around the underlying - // messages exchanged by client and server, giving it a minimal amount of structure and type - // safety. + // The `RequestWrapper` and `ResponseWrapper` messages are thin wrappers + // around the underlying messages exchanged by client and server, giving it a + // minimal amount of structure and type safety. // - // The expected message sequence starts with the client sending a `GetEndorsedEvidenceRequest` - // message in order to fetch the evidence of the enclave. This method may be handled by the - // untrusted launcher or by the enclave, depending on the server implementation. + // The expected message sequence starts with the client sending a + // `GetEndorsedEvidenceRequest` message in order to fetch the evidence of the + // enclave. This method may be handled by the untrusted launcher or by the + // enclave, depending on the server implementation. // - // Then the client encrypts the payload with the public key contained in the evidence via a hybrid - // encryption protocol, and sends the encrypted payload as part of a `InvokeRequest` message. + // Then the client encrypts the payload with the public key contained in the + // evidence via a hybrid encryption protocol, and sends the encrypted payload + // as part of a `InvokeRequest` message. rpc Stream(stream RequestWrapper) returns (stream ResponseWrapper); } diff --git a/proto/session/service_unary.proto b/proto/session/service_unary.proto index 7b513e4080a..4c4e0d91370 100644 --- a/proto/session/service_unary.proto +++ b/proto/session/service_unary.proto @@ -26,7 +26,8 @@ option java_package = "com.google.oak.session.v1"; // Service definition for unary communication with Oak server. service UnarySession { // Gets a attestation evidence and endorsements. - rpc GetEndorsedEvidence(GetEndorsedEvidenceRequest) returns (GetEndorsedEvidenceResponse); + rpc GetEndorsedEvidence(GetEndorsedEvidenceRequest) + returns (GetEndorsedEvidenceResponse); // Performs lookup for a list of encrypted keys. The keys should be encrypted // using the Public key provided by the enclave. The response is encrypted diff --git a/stage0/src/sev.rs b/stage0/src/sev.rs index 9c9942c64df..43de67b9158 100644 --- a/stage0/src/sev.rs +++ b/stage0/src/sev.rs @@ -19,6 +19,7 @@ use core::{ alloc::{AllocError, Allocator, Layout}, ops::{Deref, DerefMut}, ptr::NonNull, + sync::atomic::{AtomicUsize, Ordering}, }; use oak_core::sync::OnceCell; @@ -38,8 +39,9 @@ use spinning_top::Spinlock; use x86_64::{ instructions::tlb, structures::paging::{ - frame::PhysFrameRange, page::AddressNotAligned, Page, PageSize, PageTable, PageTableFlags, - PhysFrame, Size1GiB, Size2MiB, Size4KiB, + frame::PhysFrameRange, + page::{AddressNotAligned, NotGiantPageSize}, + Page, PageSize, PageTable, PageTableFlags, PhysFrame, Size1GiB, Size2MiB, Size4KiB, }, PhysAddr, VirtAddr, }; @@ -230,7 +232,7 @@ pub fn unshare_page(page: Page) { } tlb::flush_all(); // We have to revalidate the page again after un-sharing it. - if let Err(err) = page.pvalidate() { + if let Err(err) = page.pvalidate(&counters::VALIDATED_4K) { if err != InstructionError::ValidationStatusNotUpdated { panic!("shared page revalidation failed"); } @@ -256,12 +258,14 @@ impl ValidatablePageSize for Size2MiB { } trait Validate { - fn pvalidate(&self) -> Result<(), InstructionError>; + fn pvalidate(&self, counter: &AtomicUsize) -> Result<(), InstructionError>; } impl Validate for Page { - fn pvalidate(&self) -> Result<(), InstructionError> { - pvalidate(self.start_address().as_u64() as usize, S::SEV_PAGE_SIZE, Validation::Validated) + fn pvalidate(&self, counter: &AtomicUsize) -> Result<(), InstructionError> { + pvalidate(self.start_address().as_u64() as usize, S::SEV_PAGE_SIZE, Validation::Validated)?; + counter.fetch_add(1, Ordering::SeqCst); + Ok(()) } } @@ -280,11 +284,12 @@ impl MappedPage { } } -fn pvalidate_range( +fn pvalidate_range( range: &PhysFrameRange, memory: &mut MappedPage, encrypted: u64, flags: PageTableFlags, + success_counter: &AtomicUsize, mut f: F, ) -> Result<(), InstructionError> where @@ -324,7 +329,7 @@ where .iter() .zip(pages) .filter(|(entry, _)| !entry.is_unused()) - .map(|(entry, page)| (entry, page.pvalidate())) + .map(|(entry, page)| (entry, page.pvalidate(success_counter))) .map(|(entry, result)| result.or_else(|err| f(entry.addr(), err))) .find(|result| result.is_err()) { @@ -338,6 +343,23 @@ where Ok(()) } +pub mod counters { + use core::sync::atomic::AtomicUsize; + + /// Number of PVALIDATE invocations that did not change Validated state. + pub static ERROR_VALIDATION_STATUS_NOT_UPDATED: AtomicUsize = AtomicUsize::new(0); + + /// Number of FAIL_SIZEMISMATCH errors when invoking PVALIDATE on 2 MiB + /// pages. + pub static ERROR_FAIL_SIZE_MISMATCH: AtomicUsize = AtomicUsize::new(0); + + /// Number of successful PVALIDATE invocations on 2 MiB pages. + pub static VALIDATED_2M: AtomicUsize = AtomicUsize::new(0); + + /// Number of successful PVALIDATE invocations on 4 KiB pages. + pub static VALIDATED_4K: AtomicUsize = AtomicUsize::new(0); +} + trait Validatable4KiB { /// Validate a region of memory using 4 KiB pages. /// @@ -357,15 +379,23 @@ impl Validatable4KiB for PhysFrameRange { pt: &mut MappedPage, encrypted: u64, ) -> Result<(), InstructionError> { - pvalidate_range(self, pt, encrypted, PageTableFlags::empty(), |_addr, err| match err { - InstructionError::ValidationStatusNotUpdated => { - // We don't treat this as an error. It only happens if SEV-SNP is not enabled, - // or it is already validated. See the PVALIDATE instruction in - // for more details. - Ok(()) - } - other => Err(other), - }) + pvalidate_range( + self, + pt, + encrypted, + PageTableFlags::empty(), + &counters::VALIDATED_4K, + |_addr, err| match err { + InstructionError::ValidationStatusNotUpdated => { + // We don't treat this as an error. It only happens if SEV-SNP is not enabled, + // or it is already validated. See the PVALIDATE instruction in + // for more details. + counters::ERROR_VALIDATION_STATUS_NOT_UPDATED.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + other => Err(other), + }, + ) } } @@ -393,26 +423,55 @@ impl Validatable2MiB for PhysFrameRange { pt: &mut MappedPage, encrypted: u64, ) -> Result<(), InstructionError> { - pvalidate_range(self, pd, encrypted, PageTableFlags::HUGE_PAGE, |addr, err| match err { - InstructionError::FailSizeMismatch => { - // 2MiB is no go, fail back to 4KiB pages. - // This will not panic as every address that is 2 MiB-aligned is by definition - // also 4 KiB-aligned. - let start = PhysFrame::::from_start_address(PhysAddr::new( - addr.as_u64() & !encrypted, - )) - .unwrap(); - let range = PhysFrame::range(start, start + 512); - range.pvalidate(pt, encrypted) - } - InstructionError::ValidationStatusNotUpdated => { - // We don't treat this as an error. It only happens if SEV-SNP is not enabled, - // or it is already validated. See the PVALIDATE instruction in - // for more details. - Ok(()) - } - other => Err(other), - }) + pvalidate_range( + self, + pd, + encrypted, + PageTableFlags::HUGE_PAGE, + &counters::VALIDATED_2M, + |addr, err| match err { + InstructionError::FailSizeMismatch => { + // 2MiB is no go, fail back to 4KiB pages. + // This will not panic as every address that is 2 MiB-aligned is by definition + // also 4 KiB-aligned. + counters::ERROR_FAIL_SIZE_MISMATCH.fetch_add(1, Ordering::SeqCst); + let start = PhysFrame::::from_start_address(PhysAddr::new( + addr.as_u64() & !encrypted, + )) + .unwrap(); + let range = PhysFrame::range(start, start + 512); + range.pvalidate(pt, encrypted) + } + InstructionError::ValidationStatusNotUpdated => { + // We don't treat this as an error. It only happens if SEV-SNP is not enabled, + // or it is already validated. See the PVALIDATE instruction in + // for more details. + counters::ERROR_VALIDATION_STATUS_NOT_UPDATED.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + other => Err(other), + }, + ) + } +} + +trait PageStateChange { + fn page_state_change(&self, assignment: PageAssignment) -> Result<(), &'static str>; +} + +impl PageStateChange for PhysFrameRange { + fn page_state_change(&self, assignment: PageAssignment) -> Result<(), &'static str> { + // Future optimization: do this operation in batches of 253 frames (that's how + // many can fit in one PageStateChange request) instead of one at a time. + for frame in *self { + GHCB_WRAPPER + .get() + .expect("GHCB not initialized") + .lock() + .page_state_change(frame, assignment)?; + } + + Ok(()) } } @@ -459,34 +518,82 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) { continue; } - let start_address = PhysAddr::new(entry.addr() as u64); - let limit_address = PhysAddr::new((entry.addr() + entry.size()) as u64); + let start_address = PhysAddr::new(entry.addr() as u64).align_up(Size4KiB::SIZE); + let limit_address = + PhysAddr::new((entry.addr() + entry.size()) as u64).align_down(Size4KiB::SIZE); + + if start_address > limit_address { + log::error!( + "nonsensical entry in E820 table: [{}, {})", + entry.addr(), + entry.addr() + entry.size() + ); + continue; + } - // If the memory boundaries align with 2 MiB, start with that. - if start_address.is_aligned(Size2MiB::SIZE) && limit_address.is_aligned(Size2MiB::SIZE) { + // Attempt to validate as many pages as possible using 2 MiB pages (aka + // hugepages). + let hugepage_start = start_address.align_up(Size2MiB::SIZE); + let hugepage_limit = limit_address.align_down(Size2MiB::SIZE); + + // If start_address == hugepage_start, we're aligned with 2M address boundary. + // Otherwise, we need to process any 4K pages before the alignment. + // Note that limit_address may be less than hugepage_start, which means that the + // E820 entry was less than 2M in size and didn't cross a 2M boundary. + if hugepage_start > start_address { + let limit = core::cmp::min(hugepage_start, limit_address); + // We know the addresses are aligned to at least 4K, so the unwraps are safe. + let range = PhysFrame::::range( + PhysFrame::from_start_address(start_address).unwrap(), + PhysFrame::from_start_address(limit).unwrap(), + ); + range.page_state_change(PageAssignment::Private).unwrap(); + range.pvalidate(&mut validation_pt, encrypted).expect("failed to validate memory"); + } + + // If hugepage_limit > hugepage_start, we've got some contiguous 2M chunks that + // we can process as hugepages. + if hugepage_limit > hugepage_start { // These unwraps can't fail as we've made sure that the addresses are 2 // MiB-aligned. let range = PhysFrame::::range( - PhysFrame::from_start_address(start_address).unwrap(), - PhysFrame::from_start_address(limit_address).unwrap(), + PhysFrame::from_start_address(hugepage_start).unwrap(), + PhysFrame::from_start_address(hugepage_limit).unwrap(), ); - range.pvalidate(&mut validation_pd, &mut validation_pt, encrypted) - } else { - // No such luck, fail over to 4K pages. - // The unwraps can't fail as we make sure that the addresses are 4 KiB-aligned. + range.page_state_change(PageAssignment::Private).unwrap(); + range + .pvalidate(&mut validation_pd, &mut validation_pt, encrypted) + .expect("failed to validate memory"); + } + + // And finally, we may have some trailing 4K pages in [hugepage_limit, + // limit_address) that we need to process. + if limit_address > hugepage_limit { + let start = core::cmp::max(start_address, hugepage_limit); + // We know the addresses are aligned to at least 4K, so the unwraps are safe. let range = PhysFrame::::range( - PhysFrame::from_start_address(start_address.align_up(Size4KiB::SIZE)).unwrap(), - PhysFrame::from_start_address(limit_address.align_down(Size4KiB::SIZE)).unwrap(), + PhysFrame::from_start_address(start).unwrap(), + PhysFrame::from_start_address(limit_address).unwrap(), ); - range.pvalidate(&mut validation_pt, encrypted) + range.page_state_change(PageAssignment::Private).unwrap(); + range.pvalidate(&mut validation_pt, encrypted).expect("failed to validate memory"); } - .expect("failed to validate memory"); } page_tables.pd_0[1].set_unused(); page_tables.pdpt[1].set_unused(); tlb::flush_all(); log::info!("SEV-SNP memory validation complete."); + log::info!(" Validated using 2 MiB pages: {}", counters::VALIDATED_2M.load(Ordering::SeqCst)); + log::info!(" Validated using 4 KiB pages: {}", counters::VALIDATED_4K.load(Ordering::SeqCst)); + log::info!( + " Valid state not updated: {}", + counters::ERROR_VALIDATION_STATUS_NOT_UPDATED.load(Ordering::SeqCst) + ); + log::info!( + " RMP page size mismatch errors (fallback to 4K): {}", + counters::ERROR_FAIL_SIZE_MISMATCH.load(Ordering::SeqCst) + ); } /// Initializes the Guest Message encryptor using VMPCK0.