From 7075f3f714330a540e7f2b736b4f60e02527e01d Mon Sep 17 00:00:00 2001 From: Naomi Plasterer Date: Fri, 17 May 2024 16:21:58 -0700 Subject: [PATCH] Migrate to new OpenMLS storage API (#692) * bump openMLS to get errors for updating * small tweaks * attempt at overriding this * format it * remove crypto config * get through all the initial errors * fix up the sql keystore to actually read and write correctly * getting the impl closer to compiling * error for methods we wont use * get everything compiling * more mls tweaks * fix up the basic credential issues * remove save functions * fix up the errirs * update the table name * fix up some of the mut stuff * cargo build works * bring back key value signature * change version to integer * use correct sqlite * update the migration and database name * fix up merge * update the schema * fix up some testings * get the test compiling * write some tests for it * keep it dry * fix up the test * make sure all the keys align * small tweaks to naming * update to latest mls with fixes * modify the tests a bit * comment out failing tests for now * fix up merge errors * fix up the concurrency of the db * format it * fix up a bunch of lint issues * fix up all the lint issues * fixup append and read_list * fix remove_item * fixup clear proposals * fix up the lint * serialize aad before writing * fixup psk sotrage and some debug info * store resumption psks * fix welcome decryption * fix encryption epoch key pairs * fix lock on the db * reformat * point at new openmls version * fix up a small lint issue * fix diesel deprecation warnings (#743) * point to the merge commit * remove some unwraps * remove another unwrap * pull out queries into constants * remove the final unwraps * dry up the query code * code clean up 1 * final pass on cleaning up code * fix up the linter * undo the changes that may have broken it * small tweak to linter * remove two unwraps * remove the unwrap from hpke * remove the last unwrap * surface pending commit error --------- Co-authored-by: Franziskus Kiefer Co-authored-by: Andrew Plaza Co-authored-by: Andrew Plaza --- Cargo.lock | 61 +- Cargo.toml | 8 +- bindings_ffi/Cargo.lock | 61 +- mls_validation_service/src/handlers.rs | 17 +- .../down.sql | 1 + .../2024-05-06-192337_openmls_storage/up.sql | 6 + xmtp_mls/src/client.rs | 8 +- xmtp_mls/src/groups/members.rs | 2 +- xmtp_mls/src/groups/mod.rs | 45 +- xmtp_mls/src/groups/sync.rs | 24 +- xmtp_mls/src/groups/validated_commit.rs | 15 +- xmtp_mls/src/groups/validated_commit_v2.rs | 2 +- xmtp_mls/src/hpke.rs | 67 +- xmtp_mls/src/identity/v3/legacy.rs | 55 +- xmtp_mls/src/identity/xmtp_id/identity.rs | 2 +- .../src/storage/encrypted_store/schema.rs | 9 + xmtp_mls/src/storage/errors.rs | 13 +- xmtp_mls/src/storage/sql_key_store.rs | 1303 ++++++++++++++++- xmtp_mls/src/verified_key_package.rs | 11 +- xmtp_mls/src/verified_key_package_v2.rs | 2 +- xmtp_mls/src/xmtp_openmls_provider.rs | 8 +- xmtp_proto/src/convert.rs | 2 +- 22 files changed, 1514 insertions(+), 208 deletions(-) create mode 100644 xmtp_mls/migrations/2024-05-06-192337_openmls_storage/down.sql create mode 100644 xmtp_mls/migrations/2024-05-06-192337_openmls_storage/up.sql diff --git a/Cargo.lock b/Cargo.lock index 2034beeef..124b90550 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,15 @@ dependencies = [ "libc", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.13" @@ -1914,6 +1923,9 @@ name = "hex" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +dependencies = [ + "serde", +] [[package]] name = "hkdf" @@ -2796,18 +2808,18 @@ dependencies = [ [[package]] name = "openmls" version = "0.5.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "backtrace", "itertools 0.10.5", "log", "openmls_basic_credential", + "openmls_memory_storage", "openmls_rust_crypto", + "openmls_test", "openmls_traits", "rand", "rayon", - "rstest", - "rstest_reuse", "serde", "serde_json", "thiserror", @@ -2818,7 +2830,7 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "ed25519-dalek", "openmls_traits", @@ -2829,11 +2841,14 @@ dependencies = [ ] [[package]] -name = "openmls_memory_keystore" +name = "openmls_memory_storage" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ + "hex", + "log", "openmls_traits", + "serde", "serde_json", "thiserror", ] @@ -2841,7 +2856,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -2851,7 +2866,7 @@ dependencies = [ "hpke-rs", "hpke-rs-crypto", "hpke-rs-rust-crypto", - "openmls_memory_keystore", + "openmls_memory_storage", "openmls_traits", "p256", "rand", @@ -2862,10 +2877,25 @@ dependencies = [ "tls_codec 0.4.2-pre.1", ] +[[package]] +name = "openmls_test" +version = "0.1.0" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" +dependencies = [ + "ansi_term", + "openmls_rust_crypto", + "openmls_traits", + "proc-macro2", + "quote", + "rstest", + "rstest_reuse", + "syn 2.0.55", +] + [[package]] name = "openmls_traits" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "serde", "tls_codec 0.4.2-pre.1", @@ -3803,9 +3833,9 @@ dependencies = [ [[package]] name = "rstest" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b07f2d176c472198ec1e6551dc7da28f1c089652f66a7b722676c2238ebc0edf" +checksum = "de1bb486a691878cd320c2f0d319ba91eeaa2e894066d8b5f8f117c000e9d962" dependencies = [ "futures", "futures-timer", @@ -3815,9 +3845,9 @@ dependencies = [ [[package]] name = "rstest_macros" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7229b505ae0706e64f37ffc54a9c163e11022a6636d58fe1f3f52018257ff9f7" +checksum = "290ca1a1c8ca7edb7c3283bd44dc35dd54fdec6253a3912e201ba1072018fca8" dependencies = [ "cfg-if", "proc-macro2", @@ -3829,11 +3859,12 @@ dependencies = [ [[package]] name = "rstest_reuse" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9b5aed35457441e7e0db509695ba3932d4c47e046777141c167efe584d0ec17" +checksum = "45f80dcc84beab3a327bbe161f77db25f336a1452428176787c8c79ac79d7073" dependencies = [ "quote", + "rand", "rustc_version", "syn 1.0.109", ] diff --git a/Cargo.toml b/Cargo.toml index c078032a3..e772fcd4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,10 +32,10 @@ futures-core = "0.3.30" hex = "0.4.3" jsonrpsee = { version = "0.22", features = ["macros", "server", "client-core"] } log = "0.4" -openmls = { git = "https://github.com/xmtp/openmls", rev = "52cad0e" } -openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "52cad0e" } -openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "52cad0e" } -openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "52cad0e" } +openmls = { git = "https://github.com/xmtp/openmls", rev = "0239b96" } +openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "0239b96" } +openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "0239b96" } +openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "0239b96" } prost = "^0.12" prost-types = "^0.12" rand = "0.8.5" diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index 24d640cd8..4c602f5da 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -86,6 +86,15 @@ dependencies = [ "libc", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.11" @@ -1900,6 +1909,9 @@ name = "hex" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +dependencies = [ + "serde", +] [[package]] name = "hkdf" @@ -2592,18 +2604,18 @@ dependencies = [ [[package]] name = "openmls" version = "0.5.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "backtrace", "itertools 0.10.5", "log", "openmls_basic_credential", + "openmls_memory_storage", "openmls_rust_crypto", + "openmls_test", "openmls_traits", "rand", "rayon", - "rstest", - "rstest_reuse", "serde", "serde_json", "thiserror", @@ -2614,7 +2626,7 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "ed25519-dalek", "openmls_traits", @@ -2625,11 +2637,14 @@ dependencies = [ ] [[package]] -name = "openmls_memory_keystore" +name = "openmls_memory_storage" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ + "hex", + "log", "openmls_traits", + "serde", "serde_json", "thiserror", ] @@ -2637,7 +2652,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -2647,7 +2662,7 @@ dependencies = [ "hpke-rs", "hpke-rs-crypto", "hpke-rs-rust-crypto", - "openmls_memory_keystore", + "openmls_memory_storage", "openmls_traits", "p256", "rand", @@ -2658,10 +2673,25 @@ dependencies = [ "tls_codec 0.4.2-pre.1", ] +[[package]] +name = "openmls_test" +version = "0.1.0" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" +dependencies = [ + "ansi_term", + "openmls_rust_crypto", + "openmls_traits", + "proc-macro2", + "quote", + "rstest", + "rstest_reuse", + "syn 2.0.48", +] + [[package]] name = "openmls_traits" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=52cad0e#52cad0e35cb2c88f83002e786e177bbc9065a76c" +source = "git+https://github.com/xmtp/openmls?rev=0239b96#0239b966f771c8c6ae4e3d4556879e9b591eaf48" dependencies = [ "serde", "tls_codec 0.4.2-pre.1", @@ -3587,9 +3617,9 @@ dependencies = [ [[package]] name = "rstest" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b07f2d176c472198ec1e6551dc7da28f1c089652f66a7b722676c2238ebc0edf" +checksum = "de1bb486a691878cd320c2f0d319ba91eeaa2e894066d8b5f8f117c000e9d962" dependencies = [ "futures", "futures-timer", @@ -3599,9 +3629,9 @@ dependencies = [ [[package]] name = "rstest_macros" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7229b505ae0706e64f37ffc54a9c163e11022a6636d58fe1f3f52018257ff9f7" +checksum = "290ca1a1c8ca7edb7c3283bd44dc35dd54fdec6253a3912e201ba1072018fca8" dependencies = [ "cfg-if", "proc-macro2", @@ -3613,11 +3643,12 @@ dependencies = [ [[package]] name = "rstest_reuse" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9b5aed35457441e7e0db509695ba3932d4c47e046777141c167efe584d0ec17" +checksum = "45f80dcc84beab3a327bbe161f77db25f336a1452428176787c8c79ac79d7073" dependencies = [ "quote", + "rand", "rustc_version", "syn 1.0.109", ] diff --git a/mls_validation_service/src/handlers.rs b/mls_validation_service/src/handlers.rs index e1665b873..412f98d26 100644 --- a/mls_validation_service/src/handlers.rs +++ b/mls_validation_service/src/handlers.rs @@ -368,7 +368,8 @@ fn validate_key_package(key_package_bytes: Vec) -> Result SignatureKeyPair { @@ -569,7 +564,7 @@ mod tests { async fn test_validate_key_packages_happy_path() { let (identity, keypair, account_address) = generate_identity(); - let credential: OpenMlsCredential = BasicCredential::new(identity).unwrap().into(); + let credential: OpenMlsCredential = BasicCredential::new(identity).into(); let credential_with_key = CredentialWithKey { credential, signature_key: keypair.to_public_vec().into(), @@ -602,7 +597,7 @@ mod tests { let (identity, keypair, account_address) = generate_identity(); let (_, other_keypair, _) = generate_identity(); - let credential: OpenMlsCredential = BasicCredential::new(identity).unwrap().into(); + let credential: OpenMlsCredential = BasicCredential::new(identity).into(); let credential_with_key = CredentialWithKey { credential, // Use the wrong signature key to make the validation fail diff --git a/xmtp_mls/migrations/2024-05-06-192337_openmls_storage/down.sql b/xmtp_mls/migrations/2024-05-06-192337_openmls_storage/down.sql new file mode 100644 index 000000000..7c35d4288 --- /dev/null +++ b/xmtp_mls/migrations/2024-05-06-192337_openmls_storage/down.sql @@ -0,0 +1 @@ +DROP TABLE openmls_key_value; \ No newline at end of file diff --git a/xmtp_mls/migrations/2024-05-06-192337_openmls_storage/up.sql b/xmtp_mls/migrations/2024-05-06-192337_openmls_storage/up.sql new file mode 100644 index 000000000..c610cef2f --- /dev/null +++ b/xmtp_mls/migrations/2024-05-06-192337_openmls_storage/up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS openmls_key_value ( + version INT NOT NULL, + key_bytes BLOB NOT NULL, + value_bytes BLOB NOT NULL, + PRIMARY KEY (version, key_bytes) +); \ No newline at end of file diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 67085114a..2927ff62d 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -33,7 +33,7 @@ use crate::{ db_connection::DbConnection, group::{GroupMembershipState, StoredGroup}, refresh_state::EntityKind, - EncryptedMessageStore, StorageError, + sql_key_store, EncryptedMessageStore, StorageError, }, types::Address, verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage}, @@ -99,11 +99,13 @@ pub enum MessageProcessingError { #[error("invalid payload")] InvalidPayload, #[error("openmls process message error: {0}")] - OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError), + OpenMlsProcessMessage( + #[from] openmls::prelude::ProcessMessageError, + ), #[error("merge pending commit: {0}")] MergePendingCommit(#[from] openmls::group::MergePendingCommitError), #[error("merge staged commit: {0}")] - MergeStagedCommit(#[from] openmls::group::MergeCommitError), + MergeStagedCommit(#[from] openmls::group::MergeCommitError), #[error( "no pending commit to merge. group epoch is {group_epoch:?} and got {message_epoch:?}" )] diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index 3d5c1eac7..2594170ec 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -35,7 +35,7 @@ pub fn aggregate_member_list(openmls_group: &OpenMlsGroup) -> Result = openmls_group .members() .filter_map(|member| { - let basic_credential = BasicCredential::try_from(&member.credential).ok()?; + let basic_credential = BasicCredential::try_from(member.credential).ok()?; Identity::get_validated_account_address( basic_credential.identity(), &member.signature_key, diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index bd8c5fae9..7ee15607a 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -22,9 +22,8 @@ use openmls::{ group::{CreateGroupContextExtProposalError, MlsGroupCreateConfig, MlsGroupJoinConfig}, messages::proposals::ProposalType, prelude::{ - BasicCredentialError, Capabilities, CredentialWithKey, CryptoConfig, - Error as TlsCodecError, GroupId, MlsGroup as OpenMlsGroup, StagedWelcome, - Welcome as MlsWelcome, WireFormatPolicy, + BasicCredentialError, Capabilities, CredentialWithKey, Error as TlsCodecError, GroupId, + MlsGroup as OpenMlsGroup, StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy, }, }; use openmls_traits::OpenMlsProvider; @@ -83,7 +82,7 @@ use crate::{ group::{GroupMembershipState, Purpose, StoredGroup}, group_intent::{IntentKind, NewGroupIntent}, group_message::{DeliveryStatus, GroupMessageKind, StoredGroupMessage}, - StorageError, + sql_key_store, }, utils::{id::calculate_message_id, time::now_ns}, xmtp_openmls_provider::XmtpOpenMlsProvider, @@ -103,19 +102,19 @@ pub enum GroupError { #[error("intent error: {0}")] Intent(#[from] IntentError), #[error("create message: {0}")] - CreateMessage(#[from] openmls::prelude::CreateMessageError), + CreateMessage(#[from] openmls::prelude::CreateMessageError), #[error("TLS Codec error: {0}")] TlsError(#[from] TlsCodecError), #[error("add members: {0}")] - AddMembers(#[from] openmls::prelude::AddMembersError), + AddMembers(#[from] openmls::prelude::AddMembersError), #[error("remove members: {0}")] - RemoveMembers(#[from] openmls::prelude::RemoveMembersError), + RemoveMembers(#[from] openmls::prelude::RemoveMembersError), #[error("group create: {0}")] - GroupCreate(#[from] openmls::prelude::NewGroupError), + GroupCreate(#[from] openmls::group::NewGroupError), #[error("self update: {0}")] - SelfUpdate(#[from] openmls::group::SelfUpdateError), + SelfUpdate(#[from] openmls::group::SelfUpdateError), #[error("welcome error: {0}")] - WelcomeError(#[from] openmls::prelude::WelcomeError), + WelcomeError(#[from] openmls::prelude::WelcomeError), #[error("Invalid extension {0}")] InvalidExtension(#[from] openmls::prelude::InvalidExtensionError), #[error("Invalid signature: {0}")] @@ -151,7 +150,9 @@ pub enum GroupError { #[error("serialization error: {0}")] EncodeError(#[from] prost::EncodeError), #[error("create group context proposal error: {0}")] - CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError), + CreateGroupContextExtProposalError( + #[from] CreateGroupContextExtProposalError, + ), #[error("Credential error")] CredentialError(#[from] BasicCredentialError), #[error("LeafNode error")] @@ -205,7 +206,8 @@ impl MlsGroup { // Load the stored MLS group from the OpenMLS provider's keystore fn load_mls_group(&self, provider: impl OpenMlsProvider) -> Result { let mls_group = - OpenMlsGroup::load(&GroupId::from_slice(&self.group_id), provider.key_store()) + OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) + .map_err(|_| GroupError::GroupNotFound)? .ok_or(GroupError::GroupNotFound)?; Ok(mls_group) @@ -236,7 +238,7 @@ impl MlsGroup { mutable_permissions, )?; - let mut mls_group = OpenMlsGroup::new( + let mls_group = OpenMlsGroup::new( &provider, &context.identity.installation_keys, &group_config, @@ -245,7 +247,6 @@ impl MlsGroup { signature_key: context.identity.installation_keys.to_public_vec().into(), }, )?; - mls_group.save(provider.key_store())?; let group_id = mls_group.group_id().to_vec(); let stored_group = StoredGroup::new( @@ -255,7 +256,7 @@ impl MlsGroup { added_by_address.clone(), ); - stored_group.store(&provider.conn())?; + stored_group.store(provider.conn_ref())?; Ok(Self::new( context.clone(), group_id, @@ -274,8 +275,7 @@ impl MlsGroup { let mls_welcome = StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)?; - let mut mls_group = mls_welcome.into_group(provider)?; - mls_group.save(provider.key_store())?; + let mls_group = mls_welcome.into_group(provider)?; let group_id = mls_group.group_id().to_vec(); let metadata = extract_group_metadata(&mls_group)?; let group_type = metadata.conversation_type; @@ -320,7 +320,7 @@ impl MlsGroup { let added_by_node = staged_welcome.welcome_sender()?; - let added_by_credential = BasicCredential::try_from(added_by_node.credential())?; + let added_by_credential = BasicCredential::try_from(added_by_node.credential().clone())?; let pub_key_bytes = added_by_node.signature_key().as_slice(); let account_address = Identity::get_validated_account_address(added_by_credential.identity(), pub_key_bytes)?; @@ -340,6 +340,7 @@ impl MlsGroup { context.inbox_id(), context.inbox_latest_sequence_id(), ); + let mutable_permissions = build_mutable_permissions_extension(PreconfiguredPolicies::default().to_policy_set())?; let group_config = build_group_config( @@ -348,7 +349,7 @@ impl MlsGroup { group_membership, mutable_permissions, )?; - let mut mls_group = OpenMlsGroup::new( + let mls_group = OpenMlsGroup::new( &provider, &context.identity.installation_keys, &group_config, @@ -357,13 +358,13 @@ impl MlsGroup { signature_key: context.identity.installation_keys.to_public_vec().into(), }, )?; - mls_group.save(provider.key_store())?; let group_id = mls_group.group_id().to_vec(); let stored_group = StoredGroup::new_sync_group(group_id.clone(), now_ns(), GroupMembershipState::Allowed); - stored_group.store(&provider.conn())?; + stored_group.store(provider.conn_ref())?; + Ok(Self::new( context.clone(), stored_group.id, @@ -786,7 +787,7 @@ fn build_group_config( Ok(MlsGroupCreateConfig::builder() .with_group_context_extensions(extensions)? .capabilities(capabilities) - .crypto_config(CryptoConfig::with_default_version(CIPHERSUITE)) + .ciphersuite(CIPHERSUITE) .wire_format_policy(WireFormatPolicy::default()) .max_past_epochs(3) // Trying with 3 max past epochs for now .use_ratchet_tree_extension(true) diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index 170c65ae3..a3e069d61 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -219,7 +219,15 @@ impl MlsGroup { openmls_group.merge_pending_commit(&provider) { log::error!("error merging commit: {}", err); - openmls_group.clear_pending_commit(); + match openmls_group.clear_pending_commit(provider.storage()) { + Ok(_) => (), + Err(_) => { + return Err(MessageProcessingError::Generic( + "Error clearing pending commit".to_string(), + )) + } + } + conn.set_group_intent_to_publish(intent.id)?; } else { // If no error committing the change, write a transcript message @@ -337,6 +345,7 @@ impl MlsGroup { &self.context.account_address(), &idempotency_key, ); + StoredGroupMessage { id: message_id, group_id: self.group_id.clone(), @@ -478,7 +487,6 @@ impl MlsGroup { msgv1.id, |provider| -> Result<(), MessageProcessingError> { self.process_message(openmls_group, provider, msgv1, true)?; - openmls_group.save(provider.key_store())?; Ok(()) }, )?; @@ -496,6 +504,7 @@ impl MlsGroup { { let provider = self.context.mls_provider(conn); let mut openmls_group = self.load_mls_group(&provider)?; + log::debug!(" loaded openmls group"); let receive_errors: Vec = messages .into_iter() @@ -525,9 +534,7 @@ impl MlsGroup { ApiClient: XmtpApi, { let messages = client.query_group_messages(&self.group_id, conn).await?; - self.process_messages(messages, conn.clone(), client)?; - Ok(()) } @@ -599,7 +606,6 @@ impl MlsGroup { Some(vec![IntentState::ToPublish]), None, )?; - let num_intents = intents.len(); for intent in intents { let result = retry_async!( @@ -640,10 +646,6 @@ impl MlsGroup { )?; } - if num_intents > 0 { - openmls_group.save(provider.key_store())?; - } - Ok(()) } @@ -966,7 +968,7 @@ fn validate_message_sender( if let Sender::Member(leaf_node_index) = decrypted_message.sender() { if let Some(member) = openmls_group.member_at(*leaf_node_index) { if member.credential.eq(decrypted_message.credential()) { - let basic_credential = BasicCredential::try_from(&member.credential)?; + let basic_credential = BasicCredential::try_from(member.credential)?; sender_account_address = Identity::get_validated_account_address( basic_credential.identity(), &member.signature_key, @@ -978,7 +980,7 @@ fn validate_message_sender( } if sender_account_address.is_none() { - let basic_credential = BasicCredential::try_from(decrypted_message.credential())?; + let basic_credential = BasicCredential::try_from(decrypted_message.credential().clone())?; return Err(MessageProcessingError::InvalidSender { message_time_ns: message_created_ns, credential: basic_credential.identity().to_vec(), diff --git a/xmtp_mls/src/groups/validated_commit.rs b/xmtp_mls/src/groups/validated_commit.rs index aa7cd03b0..77350d4e6 100644 --- a/xmtp_mls/src/groups/validated_commit.rs +++ b/xmtp_mls/src/groups/validated_commit.rs @@ -191,7 +191,7 @@ fn extract_actor( if let Some(leaf_node) = group.member_at(leaf_index) { let signature_key = leaf_node.signature_key.as_slice(); - let basic_credential = BasicCredential::try_from(&leaf_node.credential)?; + let basic_credential = BasicCredential::try_from(leaf_node.credential)?; let account_address = Identity::get_validated_account_address(basic_credential.identity(), signature_key)?; @@ -243,7 +243,7 @@ fn extract_identity_from_remove( if let Some(member) = group.member_at(leaf_index) { let signature_key = member.signature_key.as_slice(); - let basic_credential = BasicCredential::try_from(&member.credential)?; + let basic_credential = BasicCredential::try_from(member.credential)?; let account_address = Identity::get_validated_account_address(basic_credential.identity(), signature_key)?; let is_creator = account_address.eq(&group_metadata.creator_account_address); @@ -448,11 +448,9 @@ mod tests { use openmls::{ credentials::{BasicCredential, CredentialWithKey}, extensions::ExtensionType, - group::config::CryptoConfig, messages::proposals::ProposalType, prelude::Capabilities, prelude_test::KeyPackage, - versions::ProtocolVersion, }; use xmtp_api_grpc::Client as GrpcClient; use xmtp_cryptography::utils::generate_local_wallet; @@ -611,15 +609,12 @@ mod tests { let bad_key_package = KeyPackage::builder() .leaf_node_capabilities(capabilities) .build( - CryptoConfig { - ciphersuite: CIPHERSUITE, - version: ProtocolVersion::default(), - }, + CIPHERSUITE, &bola_provider, &bola.identity().installation_keys, CredentialWithKey { // Broken credential - credential: BasicCredential::new(vec![1, 2, 3]).unwrap().into(), + credential: BasicCredential::new(vec![1, 2, 3]).into(), signature_key: bola.identity().installation_keys.to_public_vec().into(), }, ) @@ -629,7 +624,7 @@ mod tests { .add_members( &amal_provider, &amal.identity().installation_keys, - &[bad_key_package], + &[bad_key_package.key_package().clone()], ) .unwrap(); diff --git a/xmtp_mls/src/groups/validated_commit_v2.rs b/xmtp_mls/src/groups/validated_commit_v2.rs index 68954a448..733e47a6d 100644 --- a/xmtp_mls/src/groups/validated_commit_v2.rs +++ b/xmtp_mls/src/groups/validated_commit_v2.rs @@ -404,7 +404,7 @@ pub fn extract_group_membership( fn inbox_id_from_credential( credential: &OpenMlsCredential, ) -> Result { - let basic_credential = BasicCredential::try_from(credential)?; + let basic_credential = BasicCredential::try_from(credential.clone())?; let identity_bytes = basic_credential.identity(); let decoded = MlsCredential::decode(identity_bytes)?; diff --git a/xmtp_mls/src/hpke.rs b/xmtp_mls/src/hpke.rs index 37c6f6fd8..090cb9528 100644 --- a/xmtp_mls/src/hpke.rs +++ b/xmtp_mls/src/hpke.rs @@ -1,17 +1,20 @@ -use openmls::ciphersuite::hpke::{ - decrypt_with_label, encrypt_with_label, Error as OpenmlsHpkeError, -}; -use openmls::prelude::tls_codec::{Deserialize, Error as TlsCodecError, Serialize}; -use openmls_rust_crypto::RustCrypto; -use openmls_traits::types::HpkeCiphertext; -use openmls_traits::OpenMlsProvider; -use openmls_traits::{key_store::OpenMlsKeyStore, types::HpkePrivateKey}; -use thiserror::Error; - use crate::{ configuration::{CIPHERSUITE, WELCOME_HPKE_LABEL}, + storage::sql_key_store::KEY_PACKAGE_REFERENCES, xmtp_openmls_provider::XmtpOpenMlsProvider, }; +use openmls::{ + ciphersuite::hash_ref::KeyPackageRef, + prelude::tls_codec::{Deserialize, Error as TlsCodecError, Serialize}, +}; +use openmls::{ + ciphersuite::hpke::{decrypt_with_label, encrypt_with_label, Error as OpenmlsHpkeError}, + key_packages::KeyPackageBundle, +}; +use openmls_rust_crypto::RustCrypto; +use openmls_traits::OpenMlsProvider; +use openmls_traits::{storage::StorageProvider, types::HpkeCiphertext}; +use thiserror::Error; #[derive(Debug, Error)] pub enum HpkeError { @@ -44,19 +47,37 @@ pub fn decrypt_welcome( hpke_public_key: &[u8], ciphertext: &[u8], ) -> Result, HpkeError> { - let private_key = provider - .key_store() - .read::(hpke_public_key) - .ok_or(HpkeError::KeyNotFound)?; - let ciphertext = HpkeCiphertext::tls_deserialize_exact(ciphertext)?; - Ok(decrypt_with_label( - private_key.to_vec().as_slice(), - WELCOME_HPKE_LABEL, - &[], - &ciphertext, - CIPHERSUITE, - &RustCrypto::default(), - )?) + let serialized_hpke_public_key = hpke_public_key.tls_serialize_detached()?; + + let hash_ref: Option = match provider + .storage() + .read(KEY_PACKAGE_REFERENCES, &serialized_hpke_public_key) + { + Ok(hash_ref) => hash_ref, + Err(_) => return Err(HpkeError::KeyNotFound), + }; + + if let Some(hash_ref) = hash_ref { + // With the hash reference we can read the key package. + let key_package: Option = match provider.storage().key_package(&hash_ref) + { + Ok(key_package) => key_package, + Err(_) => return Err(HpkeError::KeyNotFound), + }; + + if let Some(kp) = key_package { + return Ok(decrypt_with_label( + kp.init_private_key(), + WELCOME_HPKE_LABEL, + &[], + &ciphertext, + CIPHERSUITE, + &RustCrypto::default(), + )?); + } + } + + Err(HpkeError::KeyNotFound) } diff --git a/xmtp_mls/src/identity/v3/legacy.rs b/xmtp_mls/src/identity/v3/legacy.rs index f7c4b8cfd..3e8031933 100644 --- a/xmtp_mls/src/identity/v3/legacy.rs +++ b/xmtp_mls/src/identity/v3/legacy.rs @@ -10,10 +10,9 @@ use openmls::{ messages::proposals::ProposalType, prelude::{ tls_codec::{Error as TlsCodecError, Serialize}, - Capabilities, Credential as OpenMlsCredential, CredentialWithKey, CryptoConfig, Extension, - ExtensionType, Extensions, KeyPackage, KeyPackageNewError, Lifetime, + Capabilities, Credential as OpenMlsCredential, CredentialWithKey, Extension, ExtensionType, + Extensions, KeyPackage, KeyPackageNewError, Lifetime, }, - versions::ProtocolVersion, }; use openmls_basic_credential::SignatureKeyPair; use openmls_traits::{types::CryptoError, OpenMlsProvider}; @@ -31,7 +30,11 @@ use crate::{ MUTABLE_METADATA_EXTENSION_ID, }, credential::{AssociationError, Credential, UnsignedGrantMessagingAccessData}, - storage::{identity::StoredIdentity, StorageError}, + storage::{ + identity::StoredIdentity, + sql_key_store::{MemoryStorageError, KEY_PACKAGE_REFERENCES}, + StorageError, + }, types::Address, utils::time::now_ns, xmtp_openmls_provider::XmtpOpenMlsProvider, @@ -49,7 +52,7 @@ pub enum IdentityError { #[error("storage error: {0}")] StorageError(#[from] StorageError), #[error("generating key package: {0}")] - KeyPackageGenerationError(#[from] KeyPackageNewError), + KeyPackageGenerationError(#[from] KeyPackageNewError), #[error("deserialization: {0}")] Deserialization(#[from] prost::DecodeError), #[error("invalid extension: {0}")] @@ -68,6 +71,8 @@ pub enum IdentityError { BasicCredential(#[from] BasicCredentialError), #[error(transparent)] Signature(#[from] ed25519_dalek::SignatureError), + #[error(transparent)] + MemoryStorage(#[from] MemoryStorageError), } #[derive(Debug)] @@ -109,7 +114,7 @@ impl Identity { Credential::create_from_legacy(&signature_keys, legacy_signed_private_key)?; let credential_proto: CredentialProto = credential.into(); let mls_credential: OpenMlsCredential = - BasicCredential::new(credential_proto.encode_to_vec())?.into(); + BasicCredential::new(credential_proto.encode_to_vec()).into(); info!("Successfully created identity from legacy key"); Ok(Self { account_address, @@ -129,7 +134,8 @@ impl Identity { ApiClient: XmtpApi, { // Do not re-register if already registered - let stored_identity: Option = provider.conn().fetch(&())?; + let conn = provider.conn(); + let stored_identity: Option = conn.fetch(&())?; if stored_identity.is_some() { info!("Identity already registered, skipping registration"); return Ok(()); @@ -150,7 +156,7 @@ impl Identity { )? .into(); let credential: OpenMlsCredential = - BasicCredential::new(credential_proto.encode_to_vec())?.into(); + BasicCredential::new(credential_proto.encode_to_vec()).into(); self.set_credential(credential)?; } @@ -160,7 +166,7 @@ impl Identity { api_client.register_installation(kp_bytes).await?; // Only persist the installation keys if the registration was successful - self.installation_keys.store(provider.key_store())?; + self.installation_keys.store(provider.storage())?; StoredIdentity::from(self).store(provider.conn_ref())?; Ok(()) @@ -224,10 +230,7 @@ impl Identity { .key_package_extensions(key_package_extensions) .key_package_lifetime(Lifetime::new(6 * 30 * 86400)) .build( - CryptoConfig { - ciphersuite: CIPHERSUITE, - version: ProtocolVersion::default(), - }, + CIPHERSUITE, provider, &self.installation_keys, CredentialWithKey { @@ -236,7 +239,31 @@ impl Identity { }, )?; - Ok(kp) + // Store the hash reference, keyed with the public init key. + // This is needed to get to the private key when decrypting welcome messages. + let public_init_key = kp.key_package().hpke_init_key().tls_serialize_detached()?; + + let key_package_hash_ref = match kp.key_package().hash_ref(provider.crypto()) { + Ok(key_package_hash_ref) => key_package_hash_ref, + Err(_) => return Err(IdentityError::UninitializedIdentity), + }; + + // Serialize the hash reference + let hash_ref = match serde_json::to_vec(&key_package_hash_ref) { + Ok(hash_ref) => hash_ref, + Err(_) => return Err(IdentityError::UninitializedIdentity), + }; + + // Store the hash reference, keyed with the public init key + provider + .storage() + .write::<{ openmls_traits::storage::CURRENT_VERSION }>( + KEY_PACKAGE_REFERENCES, + &public_init_key, + &hash_ref, + )?; + + Ok(kp.key_package().clone()) } pub(crate) fn get_validated_account_address( diff --git a/xmtp_mls/src/identity/xmtp_id/identity.rs b/xmtp_mls/src/identity/xmtp_id/identity.rs index 13fbc894f..4ff10cea7 100644 --- a/xmtp_mls/src/identity/xmtp_id/identity.rs +++ b/xmtp_mls/src/identity/xmtp_id/identity.rs @@ -286,5 +286,5 @@ fn create_credential(inbox_id: InboxId) -> Result Integer, + key_bytes -> Binary, + value_bytes -> Binary, + } +} + diesel::table! { refresh_state (entity_id, entity_kind) { entity_id -> Binary, @@ -90,5 +98,6 @@ diesel::allow_tables_to_appear_in_same_query!( identity_inbox, identity_updates, openmls_key_store, + openmls_key_value, refresh_state, ); diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index 49ceae82f..ccce414c5 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -54,12 +54,9 @@ impl RetryableError for openmls::group::CreateCommitError { } } -impl RetryableError for openmls::key_packages::errors::KeyPackageNewError { +impl RetryableError for openmls::key_packages::errors::KeyPackageNewError { fn is_retryable(&self) -> bool { - match self { - Self::KeyStoreError(storage) => retryable!(storage), - _ => false, - } + matches!(self, Self::StorageError) } } @@ -75,7 +72,7 @@ impl RetryableError for openmls::group::RemoveMembersError { impl RetryableError for openmls::group::NewGroupError { fn is_retryable(&self) -> bool { match self { - Self::KeyStoreError(storage) => retryable!(storage), + Self::StorageError(storage) => retryable!(storage), _ => false, } } @@ -85,7 +82,7 @@ impl RetryableError for openmls::group::SelfUpdateError { fn is_retryable(&self) -> bool { match self { Self::CreateCommitError(commit) => retryable!(commit), - Self::KeyStoreError => true, + Self::StorageError(storage) => retryable!(storage), _ => false, } } @@ -94,7 +91,7 @@ impl RetryableError for openmls::group::SelfUpdateError { impl RetryableError for openmls::group::WelcomeError { fn is_retryable(&self) -> bool { match self { - Self::KeyStoreError(storage) => retryable!(storage), + Self::StorageError(storage) => retryable!(storage), _ => false, } } diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 228dcaddb..34bacfdfa 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -1,16 +1,32 @@ -use log::{debug, error}; -use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore}; - -use super::{ - encrypted_store::{db_connection::DbConnection, key_store_entry::StoredKeyStoreEntry}, - serialization::{db_deserialize, db_serialize}, - StorageError, +use super::encrypted_store::db_connection::DbConnection; +use diesel::{ + prelude::*, + sql_types::Binary, + {sql_query, RunQueryDsl}, }; -use crate::{Delete, Fetch}; +use log::error; +use openmls_traits::storage::*; +use serde::Serialize; +use serde_json::{from_slice, from_value, Value}; + +const SELECT_QUERY: &str = + "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; +const REPLACE_QUERY: &str = + "REPLACE INTO openmls_key_value (key_bytes, version, value_bytes) VALUES (?, ?, ?)"; +const UPDATE_QUERY: &str = + "UPDATE openmls_key_value SET value_bytes = ? WHERE key_bytes = ? AND version = ?"; +const DELETE_QUERY: &str = "DELETE FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; + +#[derive(QueryableByName, Debug, Clone, PartialEq, Eq)] +#[diesel(table_name = openmls_key_value)] +struct StorageData { + #[diesel(sql_type = Binary)] + value_bytes: Vec, +} -/// CRUD Operations for an [`OpenMlsKeyStore`] #[derive(Debug, Clone)] pub struct SqlKeyStore { + // Directly wrap the DbConnection which is a SqliteConnection in this case conn: DbConnection, } @@ -26,66 +42,1081 @@ impl SqlKeyStore { pub fn conn_ref(&self) -> &DbConnection { &self.conn } -} -impl OpenMlsKeyStore for SqlKeyStore { - /// The error type returned by the [`OpenMlsKeyStore`]. - type Error = StorageError; - - /// Store a value `v` that implements the [`MlsEntity`] trait for - /// serialization for ID `k`. - /// - /// Returns an error if storing fails. - fn store(&self, k: &[u8], v: &V) -> Result<(), Self::Error> { - self.conn() - .insert_or_update_key_store_entry(k.to_vec(), db_serialize(v)?)?; + fn select_query( + &self, + storage_key: &Vec, + ) -> Result, diesel::result::Error> { + self.conn().raw_query(|conn| { + sql_query(SELECT_QUERY) + .bind::(&storage_key) + .bind::(VERSION as i32) + .load(conn) + }) + } + + fn replace_query( + &self, + storage_key: &Vec, + value: &[u8], + ) -> Result { + self.conn().raw_query(|conn| { + sql_query(REPLACE_QUERY) + .bind::(&storage_key) + .bind::(VERSION as i32) + .bind::(&value) + .execute(conn) + }) + } + + fn update_query( + &self, + storage_key: &Vec, + modified_data: &Vec, + ) -> Result { + self.conn().raw_query(|conn| { + sql_query(UPDATE_QUERY) + .bind::(&modified_data) + .bind::(&storage_key) + .bind::(VERSION as i32) + .execute(conn) + }) + } + + pub fn write( + &self, + label: &[u8], + key: &[u8], + value: &[u8], + ) -> Result<(), >::Error> { + log::debug!("write {}", String::from_utf8_lossy(label)); + + let storage_key = build_key_from_vec::(label, key.to_vec()); + + let _ = self.replace_query::(&storage_key, value); + Ok(()) } - /// Read and return a value stored for ID `k` that implements the - /// [`MlsEntity`] trait for deserialization. - /// - /// Returns [`None`] if no value is stored for `k` or reading fails. - fn read(&self, k: &[u8]) -> Option { - let fetch_result = self.conn().fetch(&k.to_vec()); + pub fn append( + &self, + label: &[u8], + key: &[u8], + value: &[u8], + ) -> Result<(), >::Error> { + log::debug!("append {}", String::from_utf8_lossy(label)); + + let storage_key = build_key_from_vec::(label, key.to_vec()); + let current_data: Result, diesel::result::Error> = + self.select_query::(&storage_key); + + match current_data { + Ok(data) => { + if let Some(entry) = data.into_iter().next() { + // The value in the storage is an array of array of bytes, encoded as json. + match from_slice::(&entry.value_bytes) { + Ok(mut deserialized) => { + // Assuming value is JSON and needs to be added to an array + if let Value::Array(ref mut arr) = deserialized { + arr.push(Value::from(value)); + } + + let modified_data = serde_json::to_vec(&deserialized) + .map_err(|_| MemoryStorageError::SerializationError)?; + + let _ = self.update_query::(&storage_key, &modified_data); + Ok(()) + } + Err(_e) => Err(MemoryStorageError::SerializationError), + } + } else { + // Add a first entry + let value_bytes = &serde_json::to_vec(&[value])?; + let _ = self.replace_query::(&storage_key, value_bytes); + + Ok(()) + } + } + Err(_) => Err(MemoryStorageError::None), + } + } + + pub fn remove_item( + &self, + label: &[u8], + key: &[u8], + value: &[u8], + ) -> Result<(), >::Error> { + log::debug!("remove_item {}", String::from_utf8_lossy(label)); + + let storage_key = build_key_from_vec::(label, key.to_vec()); + let current_data: Result, diesel::result::Error> = + self.select_query::(&storage_key); + + match current_data { + Ok(data) => { + if let Some(entry) = data.into_iter().next() { + // The value in the storage is an array of array of bytes, encoded as json. + match from_slice::(&entry.value_bytes) { + Ok(mut deserialized) => { + if let Value::Array(ref mut arr) = deserialized { + // Find and remove the value. + let vpos = arr.iter().position(|v| { + match from_value::>(v.clone()) { + Ok(deserialized_value) => deserialized_value == value, + Err(_) => false, + } + }); + + if let Some(pos) = vpos { + arr.remove(pos); + } + } + let modified_data = serde_json::to_vec(&deserialized) + .map_err(|_| MemoryStorageError::SerializationError)?; + + let _ = self.update_query::(&storage_key, &modified_data); + Ok(()) + } + Err(_) => Err(MemoryStorageError::SerializationError), + } + } else { + // Add a first entry + let value_bytes = serde_json::to_vec(&[value]) + .map_err(|_| MemoryStorageError::SerializationError)?; + let _ = self.replace_query::(&storage_key, &value_bytes); + Ok(()) + } + } + Err(_) => Err(MemoryStorageError::None), + } + } + + pub fn read>( + &self, + label: &[u8], + key: &[u8], + ) -> Result, >::Error> { + log::debug!("read {}", String::from_utf8_lossy(label)); + + let storage_key = build_key_from_vec::(label, key.to_vec()); + + let results: Result, diesel::result::Error> = + self.select_query::(&storage_key); + + match results { + Ok(data) => { + if let Some(entry) = data.into_iter().next() { + match serde_json::from_slice::(&entry.value_bytes) { + Ok(deserialized) => Ok(Some(deserialized)), + Err(e) => { + eprintln!("Error occurred: {}", e); + Err(MemoryStorageError::SerializationError) + } + } + } else { + Ok(None) + } + } + Err(_e) => Err(MemoryStorageError::None), + } + } + + pub fn read_list>( + &self, + label: &[u8], + key: &[u8], + ) -> Result, >::Error> { + log::debug!("read_list {}", String::from_utf8_lossy(label)); + + let storage_key = build_key_from_vec::(label, key.to_vec()); + + match self.select_query::(&storage_key) { + Ok(results) => { + if let Some(entry) = results.into_iter().next() { + let list = from_slice::>>(&entry.value_bytes)?; - if let Err(e) = fetch_result { - error!("Failed to fetch key: {:?}", e); - return None; + // Read the values from the bytes in the list + let mut deserialized_list = Vec::new(); + for v in list { + match serde_json::from_slice(&v) { + Ok(deserialized_value) => deserialized_list.push(deserialized_value), + Err(_) => return Err(MemoryStorageError::SerializationError), + } + } + Ok(deserialized_list) + } else { + Ok(vec![]) + } + } + Err(_e) => Err(MemoryStorageError::None), } - let entry_option: Option = fetch_result.unwrap(); + } + + pub fn delete( + &self, + label: &[u8], + key: &[u8], + ) -> Result<(), >::Error> { + let storage_key = build_key_from_vec::(label, key.to_vec()); + + let _ = self.conn().raw_query(|conn| { + sql_query(DELETE_QUERY) + .bind::(&storage_key) + .bind::(VERSION as i32) + .execute(conn) + }); + Ok(()) + } +} + +/// Errors thrown by the key store. +#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)] +pub enum MemoryStorageError { + #[error("The key store does not allow storing serialized values.")] + UnsupportedValueTypeBytes, + #[error("Updating is not supported by this key store.")] + UnsupportedMethod, + #[error("Error serializing value.")] + SerializationError, + #[error("Value does not exist.")] + None, +} + +const KEY_PACKAGE_LABEL: &[u8] = b"KeyPackage"; +const ENCRYPTION_KEY_PAIR_LABEL: &[u8] = b"EncryptionKeyPair"; +const SIGNATURE_KEY_PAIR_LABEL: &[u8] = b"SignatureKeyPair"; +const EPOCH_KEY_PAIRS_LABEL: &[u8] = b"EpochKeyPairs"; +pub const KEY_PACKAGE_REFERENCES: &[u8] = b"KeyPackageReferences"; + +// related to PublicGroup +const TREE_LABEL: &[u8] = b"Tree"; +const GROUP_CONTEXT_LABEL: &[u8] = b"GroupContext"; +const INTERIM_TRANSCRIPT_HASH_LABEL: &[u8] = b"InterimTranscriptHash"; +const CONFIRMATION_TAG_LABEL: &[u8] = b"ConfirmationTag"; + +// related to CoreGroup +const OWN_LEAF_NODE_INDEX_LABEL: &[u8] = b"OwnLeafNodeIndex"; +const EPOCH_SECRETS_LABEL: &[u8] = b"EpochSecrets"; +const MESSAGE_SECRETS_LABEL: &[u8] = b"MessageSecrets"; +const USE_RATCHET_TREE_LABEL: &[u8] = b"UseRatchetTree"; + +// related to MlsGroup +const JOIN_CONFIG_LABEL: &[u8] = b"MlsGroupJoinConfig"; +const OWN_LEAF_NODES_LABEL: &[u8] = b"OwnLeafNodes"; +const AAD_LABEL: &[u8] = b"AAD"; +const GROUP_STATE_LABEL: &[u8] = b"GroupState"; +const QUEUED_PROPOSAL_LABEL: &[u8] = b"QueuedProposal"; +const PROPOSAL_QUEUE_REFS_LABEL: &[u8] = b"ProposalQueueRefs"; +const RESUMPTION_PSK_STORE_LABEL: &[u8] = b"ResumptionPskStore"; + +impl StorageProvider for SqlKeyStore { + type Error = MemoryStorageError; + + fn queue_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + group_id: &GroupId, + proposal_ref: &ProposalRef, + proposal: &QueuedProposal, + ) -> Result<(), Self::Error> { + // write proposal to key (group_id, proposal_ref) + let key = serde_json::to_vec(&(group_id, proposal_ref))?; + let value = serde_json::to_vec(proposal)?; + self.write::(QUEUED_PROPOSAL_LABEL, &key, &value)?; + + // update proposal list for group_id + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id)?; + let value = serde_json::to_vec(proposal_ref)?; + self.append::(PROPOSAL_QUEUE_REFS_LABEL, &key, &value)?; + + Ok(()) + } + + fn write_tree< + GroupId: traits::GroupId, + TreeSync: traits::TreeSync, + >( + &self, + group_id: &GroupId, + tree: &TreeSync, + ) -> Result<(), Self::Error> { + let key = build_key::(TREE_LABEL, group_id)?; + let value = serde_json::to_vec(&tree)?; + self.write::(TREE_LABEL, &key, &value) + } + + fn write_interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + group_id: &GroupId, + interim_transcript_hash: &InterimTranscriptHash, + ) -> Result<(), Self::Error> { + let key = build_key::(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?; + let value = serde_json::to_vec(&interim_transcript_hash)?; + let _ = self.write::(INTERIM_TRANSCRIPT_HASH_LABEL, &key, &value); + + Ok(()) + } + + fn write_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + group_id: &GroupId, + group_context: &GroupContext, + ) -> Result<(), Self::Error> { + let key = build_key::(GROUP_CONTEXT_LABEL, group_id)?; + let value = serde_json::to_vec(&group_context)?; + let _ = self.write::(GROUP_CONTEXT_LABEL, &key, &value); + + Ok(()) + } + + fn write_confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + group_id: &GroupId, + confirmation_tag: &ConfirmationTag, + ) -> Result<(), Self::Error> { + let key = build_key::(CONFIRMATION_TAG_LABEL, group_id)?; + let value = serde_json::to_vec(&confirmation_tag)?; + let _ = self.write::(CONFIRMATION_TAG_LABEL, &key, &value); + + Ok(()) + } + + fn write_signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + public_key: &SignaturePublicKey, + signature_key_pair: &SignatureKeyPair, + ) -> Result<(), Self::Error> { + let key = build_key::( + SIGNATURE_KEY_PAIR_LABEL, + public_key, + )?; + let value = serde_json::to_vec(&signature_key_pair)?; + let _ = self.write::(SIGNATURE_KEY_PAIR_LABEL, &key, &value); + + Ok(()) + } + + fn queued_proposal_refs< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id)?; + self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key) + } + + fn queued_proposals< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id)?; + let refs: Vec = self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)?; + + refs.into_iter() + .map(|proposal_ref| -> Result<_, _> { + let key = serde_json::to_vec(&(group_id, &proposal_ref))?; + match self.read(QUEUED_PROPOSAL_LABEL, &key)? { + Some(proposal) => Ok((proposal_ref, proposal)), + None => Err(MemoryStorageError::SerializationError), + } + }) + .collect::, _>>() + } + + fn treesync< + GroupId: traits::GroupId, + TreeSync: traits::TreeSync, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(TREE_LABEL, group_id)?; + + self.read(TREE_LABEL, &key) + } + + fn group_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(GROUP_CONTEXT_LABEL, group_id)?; + + self.read(GROUP_CONTEXT_LABEL, &key) + } + + fn interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?; + + self.read(INTERIM_TRANSCRIPT_HASH_LABEL, &key) + } + + fn confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(CONFIRMATION_TAG_LABEL, group_id)?; + + self.read(CONFIRMATION_TAG_LABEL, &key) + } + + fn signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + public_key: &SignaturePublicKey, + ) -> Result, Self::Error> { + let key = build_key::( + SIGNATURE_KEY_PAIR_LABEL, + public_key, + )?; + + self.read(SIGNATURE_KEY_PAIR_LABEL, &key) + } + + fn write_key_package< + HashReference: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &HashReference, + key_package: &KeyPackage, + ) -> Result<(), Self::Error> { + let key = build_key::(KEY_PACKAGE_LABEL, hash_ref)?; + let value = serde_json::to_vec(&key_package)?; + + // Store the key package + self.write::(KEY_PACKAGE_LABEL, &key, &value)?; + + Ok(()) + } + + fn write_psk< + PskId: traits::PskId, + PskBundle: traits::PskBundle, + >( + &self, + _psk_id: &PskId, + _psk: &PskBundle, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn write_encryption_key_pair< + EncryptionKey: traits::EncryptionKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + public_key: &EncryptionKey, + key_pair: &HpkeKeyPair, + ) -> Result<(), Self::Error> { + let key = + build_key::(ENCRYPTION_KEY_PAIR_LABEL, public_key)?; + self.write::( + ENCRYPTION_KEY_PAIR_LABEL, + &key, + &serde_json::to_vec(key_pair)?, + ) + } + + fn key_package< + HashReference: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &HashReference, + ) -> Result, Self::Error> { + let key = build_key::(KEY_PACKAGE_LABEL, hash_ref)?; + self.read(KEY_PACKAGE_LABEL, &key) + } + + fn psk, PskId: traits::PskId>( + &self, + _psk_id: &PskId, + ) -> Result, Self::Error> { + Ok(None) + } + + fn encryption_key_pair< + HpkeKeyPair: traits::HpkeKeyPair, + EncryptionKey: traits::EncryptionKey, + >( + &self, + public_key: &EncryptionKey, + ) -> Result, Self::Error> { + let key = + build_key::(ENCRYPTION_KEY_PAIR_LABEL, public_key)?; + self.read(ENCRYPTION_KEY_PAIR_LABEL, &key) + } + + fn delete_signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + >( + &self, + public_key: &SignaturePublicKey, + ) -> Result<(), Self::Error> { + let key = build_key::( + SIGNATURE_KEY_PAIR_LABEL, + public_key, + )?; + self.delete::(SIGNATURE_KEY_PAIR_LABEL, &key) + } + + fn delete_encryption_key_pair>( + &self, + public_key: &EncryptionKey, + ) -> Result<(), Self::Error> { + let key = + build_key::(ENCRYPTION_KEY_PAIR_LABEL, public_key)?; + self.delete::(ENCRYPTION_KEY_PAIR_LABEL, &key) + } + + fn delete_key_package>( + &self, + hash_ref: &HashReference, + ) -> Result<(), Self::Error> { + let key = build_key::(KEY_PACKAGE_LABEL, hash_ref)?; + self.delete::(KEY_PACKAGE_LABEL, &key) + } + + fn delete_psk>( + &self, + _psk_id: &PskKey, + ) -> Result<(), Self::Error> { + Err(MemoryStorageError::UnsupportedMethod) + } + + fn group_state< + GroupState: traits::GroupState, + GroupId: traits::GroupId, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(GROUP_STATE_LABEL, group_id)?; + self.read(GROUP_STATE_LABEL, &key) + } + + fn write_group_state< + GroupState: traits::GroupState, + GroupId: traits::GroupId, + >( + &self, + group_id: &GroupId, + group_state: &GroupState, + ) -> Result<(), Self::Error> { + let key = build_key::(GROUP_STATE_LABEL, group_id)?; + self.write::(GROUP_STATE_LABEL, &key, &serde_json::to_vec(group_state)?) + } + + fn delete_group_state>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(GROUP_STATE_LABEL, group_id)?; + self.delete::(GROUP_STATE_LABEL, &key) + } + + fn message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(MESSAGE_SECRETS_LABEL, group_id)?; + self.read(MESSAGE_SECRETS_LABEL, &key) + } + + fn write_message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + group_id: &GroupId, + message_secrets: &MessageSecrets, + ) -> Result<(), Self::Error> { + let key = build_key::(MESSAGE_SECRETS_LABEL, group_id)?; + self.write::( + MESSAGE_SECRETS_LABEL, + &key, + &serde_json::to_vec(message_secrets)?, + ) + } + + fn delete_message_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(MESSAGE_SECRETS_LABEL, group_id)?; + self.delete::(MESSAGE_SECRETS_LABEL, &key) + } + + fn resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(RESUMPTION_PSK_STORE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn write_resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + group_id: &GroupId, + resumption_psk_store: &ResumptionPskStore, + ) -> Result<(), Self::Error> { + self.write::( + RESUMPTION_PSK_STORE_LABEL, + &serde_json::to_vec(group_id)?, + &serde_json::to_vec(resumption_psk_store)?, + ) + } + + fn delete_all_resumption_psk_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(RESUMPTION_PSK_STORE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(OWN_LEAF_NODE_INDEX_LABEL, group_id)?; + self.read(OWN_LEAF_NODE_INDEX_LABEL, &key) + } + + fn write_own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + group_id: &GroupId, + own_leaf_index: &LeafNodeIndex, + ) -> Result<(), Self::Error> { + let key = build_key::(OWN_LEAF_NODE_INDEX_LABEL, group_id)?; + self.write::( + OWN_LEAF_NODE_INDEX_LABEL, + &key, + &serde_json::to_vec(own_leaf_index)?, + ) + } + + fn delete_own_leaf_index>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(OWN_LEAF_NODE_INDEX_LABEL, group_id)?; + self.delete::(OWN_LEAF_NODE_INDEX_LABEL, &key) + } + + fn use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(USE_RATCHET_TREE_LABEL, group_id)?; + self.read(USE_RATCHET_TREE_LABEL, &key) + } + + fn set_use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + value: bool, + ) -> Result<(), Self::Error> { + let key = build_key::(USE_RATCHET_TREE_LABEL, group_id)?; + self.write::(USE_RATCHET_TREE_LABEL, &key, &serde_json::to_vec(&value)?) + } + + fn delete_use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(USE_RATCHET_TREE_LABEL, group_id)?; + self.delete::(USE_RATCHET_TREE_LABEL, &key) + } + + fn group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(EPOCH_SECRETS_LABEL, group_id)?; + self.read(EPOCH_SECRETS_LABEL, &key) + } + + fn write_group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + group_id: &GroupId, + group_epoch_secrets: &GroupEpochSecrets, + ) -> Result<(), Self::Error> { + let key = build_key::(EPOCH_SECRETS_LABEL, group_id)?; + self.write::( + EPOCH_SECRETS_LABEL, + &key, + &serde_json::to_vec(group_epoch_secrets)?, + ) + } + + fn delete_group_epoch_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(EPOCH_SECRETS_LABEL, group_id)?; + self.delete::(EPOCH_SECRETS_LABEL, &key) + } + + fn write_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + key_pairs: &[HpkeKeyPair], + ) -> Result<(), Self::Error> { + let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?; + let value = serde_json::to_vec(key_pairs)?; + log::debug!("Writing encryption epoch key pairs"); + log::debug!(" key: {}", hex::encode(&key)); + log::debug!(" value: {}", hex::encode(&value)); + + self.write::(EPOCH_KEY_PAIRS_LABEL, &key, &value) + } - if entry_option.is_none() { - debug!("No entry to read for key {:?}", k); - return None; + fn encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result, Self::Error> { + log::debug!("Reading encryption epoch key pairs"); + + let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?; + let storage_key = build_key_from_vec::(EPOCH_KEY_PAIRS_LABEL, key); + log::debug!(" key: {}", hex::encode(&storage_key)); + + let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; + + let results: Result, diesel::result::Error> = + self.conn().raw_query(|conn| { + sql_query(query) + .bind::(&storage_key) + .bind::(CURRENT_VERSION as i32) + .load(conn) + }); + + match results { + Ok(data) => { + if let Some(entry) = data.into_iter().next() { + match serde_json::from_slice::>(&entry.value_bytes) { + Ok(deserialized) => Ok(deserialized), + Err(e) => { + eprintln!("Error occurred: {}", e); + Err(MemoryStorageError::SerializationError) + } + } + } else { + Ok(vec![]) + } + } + Err(_e) => Err(MemoryStorageError::None), } - db_deserialize(&entry_option.unwrap().value_bytes).ok() - } - - /// Delete a value stored for ID `k`. - /// - /// Interface is unclear on expected behavior when item is already deleted - - /// we choose to not surface an error if this is the case. - fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { - let conn: &dyn Delete> = &self.conn(); - let num_deleted = conn.delete(k.to_vec())?; - if num_deleted == 0 { - debug!("No entry to delete for key {:?}", k); + } + + fn delete_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result<(), Self::Error> { + let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?; + self.delete::(EPOCH_KEY_PAIRS_LABEL, &key) + } + + fn clear_proposal_queue< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id)?; + let proposal_refs: Vec = self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)?; + + for proposal_ref in proposal_refs { + let key = serde_json::to_vec(&(group_id, proposal_ref))?; + let _ = self.delete::(QUEUED_PROPOSAL_LABEL, &key); } + + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id)?; + let _ = self.delete::(PROPOSAL_QUEUE_REFS_LABEL, &key); + Ok(()) } + + fn mls_group_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(JOIN_CONFIG_LABEL, group_id)?; + self.read(JOIN_CONFIG_LABEL, &key) + } + + fn write_mls_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + group_id: &GroupId, + config: &MlsGroupJoinConfig, + ) -> Result<(), Self::Error> { + let key = build_key::(JOIN_CONFIG_LABEL, group_id)?; + let value = serde_json::to_vec(config)?; + + self.write::(JOIN_CONFIG_LABEL, &key, &value) + } + + fn own_leaf_nodes< + GroupId: traits::GroupId, + LeafNode: traits::LeafNode, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + log::debug!("own_leaf_nodes"); + let key = build_key::(OWN_LEAF_NODES_LABEL, group_id)?; + self.read_list(OWN_LEAF_NODES_LABEL, &key) + } + + fn append_own_leaf_node< + GroupId: traits::GroupId, + LeafNode: traits::LeafNode, + >( + &self, + group_id: &GroupId, + leaf_node: &LeafNode, + ) -> Result<(), Self::Error> { + let key = build_key::(OWN_LEAF_NODES_LABEL, group_id)?; + let value = serde_json::to_vec(leaf_node)?; + self.append::(OWN_LEAF_NODES_LABEL, &key, &value) + } + + fn clear_own_leaf_nodes>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(OWN_LEAF_NODES_LABEL, group_id)?; + self.delete::(OWN_LEAF_NODES_LABEL, &key) + } + + fn aad>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = build_key::(AAD_LABEL, group_id)?; + match self.read::>(AAD_LABEL, &key) { + Ok(Some(value)) => Ok(value), + Ok(None) => Ok(Vec::new()), + Err(e) => Err(e), + } + } + + fn write_aad>( + &self, + group_id: &GroupId, + aad: &[u8], + ) -> Result<(), Self::Error> { + let key = build_key::(AAD_LABEL, group_id)?; + let value = serde_json::to_vec(&aad)?; + + self.write::(AAD_LABEL, &key, &value) + } + + fn delete_aad>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(AAD_LABEL, group_id)?; + self.delete::(AAD_LABEL, &key) + } + + fn delete_own_leaf_nodes>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(OWN_LEAF_NODES_LABEL, group_id)?; + self.delete::(OWN_LEAF_NODES_LABEL, &key) + } + + fn delete_group_config>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(JOIN_CONFIG_LABEL, group_id)?; + self.delete::(JOIN_CONFIG_LABEL, &key) + } + + fn delete_tree>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(TREE_LABEL, group_id)?; + self.delete::(TREE_LABEL, &key) + } + + fn delete_confirmation_tag>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(CONFIRMATION_TAG_LABEL, group_id)?; + self.delete::(CONFIRMATION_TAG_LABEL, &key) + } + + fn delete_context>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(GROUP_CONTEXT_LABEL, group_id)?; + self.delete::(GROUP_CONTEXT_LABEL, &key) + } + + fn delete_interim_transcript_hash>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = build_key::(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?; + self.delete::(INTERIM_TRANSCRIPT_HASH_LABEL, &key) + } + + fn remove_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + proposal_ref: &ProposalRef, + ) -> Result<(), Self::Error> { + // Delete the proposal ref + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id)?; + let value = serde_json::to_vec(proposal_ref)?; + self.remove_item::(PROPOSAL_QUEUE_REFS_LABEL, &key, &value)?; + + // Delete the proposal + let key = serde_json::to_vec(&(group_id, proposal_ref))?; + self.delete::(QUEUED_PROPOSAL_LABEL, &key) + } +} + +/// Build a key with version and label. +fn build_key_from_vec(label: &[u8], key: Vec) -> Vec { + let mut key_out = label.to_vec(); + key_out.extend_from_slice(&key); + key_out.extend_from_slice(&u16::to_be_bytes(V)); + key_out +} + +/// Build a key with version and label. +fn build_key( + label: &[u8], + key: K, +) -> Result, MemoryStorageError> { + let key_vec = serde_json::to_vec(&key)?; + Ok(build_key_from_vec::(label, key_vec)) +} + +fn epoch_key_pairs_id( + group_id: &impl traits::GroupId, + epoch: &impl traits::EpochKey, + leaf_index: u32, +) -> Result, MemoryStorageError> { + let mut key = serde_json::to_vec(group_id)?; + key.extend_from_slice(&serde_json::to_vec(epoch)?); + key.extend_from_slice(&serde_json::to_vec(&leaf_index)?); + Ok(key) +} + +impl From for MemoryStorageError { + fn from(_: serde_json::Error) -> Self { + Self::SerializationError + } } #[cfg(test)] mod tests { - use openmls_basic_credential::SignatureKeyPair; - use openmls_traits::key_store::OpenMlsKeyStore; + use openmls::group::GroupId; + use openmls_basic_credential::{SignatureKeyPair, StorageId}; + use openmls_traits::{ + storage::{traits, Entity, Key, StorageProvider, CURRENT_VERSION}, + OpenMlsProvider, + }; + use serde::{Deserialize, Serialize}; use super::SqlKeyStore; use crate::{ configuration::CIPHERSUITE, - storage::{EncryptedMessageStore, StorageOption}, + storage::{sql_key_store::MemoryStorageError, EncryptedMessageStore, StorageOption}, utils::test::tmp_path, + xmtp_openmls_provider::XmtpOpenMlsProvider, }; #[test] @@ -96,14 +1127,176 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); + let conn = &store.conn().unwrap(); let key_store = SqlKeyStore::new(conn.clone()); + let signature_keys = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm()).unwrap(); - let index = "index".as_bytes(); - assert!(key_store.read::(index).is_none()); - key_store.store(index, &signature_keys).unwrap(); - assert!(key_store.read::(index).is_some()); - key_store.delete::(index).unwrap(); - assert!(key_store.read::(index).is_none()); + let public_key = StorageId::from(signature_keys.to_public_vec()); + assert!(key_store + .signature_key_pair::(&public_key) + .unwrap() + .is_none()); + + key_store + .write_signature_key_pair::(&public_key, &signature_keys) + .unwrap(); + + assert!(key_store + .signature_key_pair::(&public_key) + .unwrap() + .is_some()); + + key_store + .delete_signature_key_pair::(&public_key) + .unwrap(); + + assert!(key_store + .signature_key_pair::(&public_key) + .unwrap() + .is_none()); + } + + #[test] + fn list_write_remove() { + let db_path = tmp_path(); + let store = EncryptedMessageStore::new( + StorageOption::Persistent(db_path), + EncryptedMessageStore::generate_enc_key(), + ) + .unwrap(); + let conn = store.conn().unwrap(); + let key_store = SqlKeyStore::new(conn.clone()); + let provider = XmtpOpenMlsProvider::new(conn); + let group_id = GroupId::random(provider.rand()); + + assert!(key_store.aad::(&group_id).unwrap().is_empty()); + + key_store + .write_aad::(&group_id, &"test".as_bytes()) + .unwrap(); + + assert!(!key_store.aad::(&group_id).unwrap().is_empty()); + + key_store.delete_aad::(&group_id).unwrap(); + + assert!(key_store.aad::(&group_id).unwrap().is_empty()); + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + struct Proposal(Vec); + impl traits::QueuedProposal for Proposal {} + impl Entity for Proposal {} + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)] + struct ProposalRef(usize); + impl traits::ProposalRef for ProposalRef {} + impl Key for ProposalRef {} + impl Entity for ProposalRef {} + + #[tokio::test] + async fn list_append_remove() { + let db_path = tmp_path(); + let store = EncryptedMessageStore::new( + StorageOption::Persistent(db_path), + EncryptedMessageStore::generate_enc_key(), + ) + .unwrap(); + let conn = store.conn().unwrap(); + let key_store = SqlKeyStore::new(conn.clone()); + let provider = XmtpOpenMlsProvider::new(conn); + let group_id = GroupId::random(provider.rand()); + let proposals = (0..10) + .map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec())) + .collect::>(); + + // Store proposals + for (i, proposal) in proposals.iter().enumerate() { + key_store + .queue_proposal::( + &group_id, + &ProposalRef(i), + proposal, + ) + .unwrap(); + } + + // Read proposal refs + let proposal_refs_read: Vec = + key_store.queued_proposal_refs(&group_id).unwrap(); + assert_eq!( + (0..10).map(|i| ProposalRef(i)).collect::>(), + proposal_refs_read + ); + + // Read proposals + let proposals_read: Vec<(ProposalRef, Proposal)> = + key_store.queued_proposals(&group_id).unwrap(); + let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) + .map(|i| ProposalRef(i)) + .zip(proposals.clone().into_iter()) + .collect(); + assert_eq!(proposals_expected, proposals_read); + + // Remove proposal 5 + key_store + .remove_proposal(&group_id, &ProposalRef(5)) + .unwrap(); + + let proposal_refs_read: Vec = + key_store.queued_proposal_refs(&group_id).unwrap(); + let mut expected = (0..10).map(|i| ProposalRef(i)).collect::>(); + expected.remove(5); + assert_eq!(expected, proposal_refs_read); + + let proposals_read: Vec<(ProposalRef, Proposal)> = + key_store.queued_proposals(&group_id).unwrap(); + let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) + .map(|i| ProposalRef(i)) + .zip(proposals.clone().into_iter()) + .collect(); + proposals_expected.remove(5); + assert_eq!(proposals_expected, proposals_read); + + // Clear all proposals + key_store + .clear_proposal_queue::(&group_id) + .unwrap(); + let proposal_refs_read: Result, MemoryStorageError> = + key_store.queued_proposal_refs(&group_id); + assert!(proposal_refs_read.unwrap().is_empty()); + + let proposals_read: Result, MemoryStorageError> = + key_store.queued_proposals(&group_id); + assert!(proposals_read.unwrap().is_empty()); + } + + #[tokio::test] + async fn group_state() { + let db_path = tmp_path(); + let store = EncryptedMessageStore::new( + StorageOption::Persistent(db_path), + EncryptedMessageStore::generate_enc_key(), + ) + .unwrap(); + let conn = store.conn().unwrap(); + let key_store = SqlKeyStore::new(conn.clone()); + let provider = XmtpOpenMlsProvider::new(conn); + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)] + struct GroupState(usize); + impl traits::GroupState for GroupState {} + impl Entity for GroupState {} + + let group_id = GroupId::random(provider.rand()); + + // Group state + key_store + .write_group_state(&group_id, &GroupState(77)) + .unwrap(); + + // Read group state + let group_state: Option = key_store.group_state(&group_id).unwrap(); + assert_eq!(GroupState(77), group_state.unwrap()); } } diff --git a/xmtp_mls/src/verified_key_package.rs b/xmtp_mls/src/verified_key_package.rs index 82b7b8df9..e138a9b58 100644 --- a/xmtp_mls/src/verified_key_package.rs +++ b/xmtp_mls/src/verified_key_package.rs @@ -54,7 +54,7 @@ impl VerifiedKeyPackage { /// Validates starting with a KeyPackage (which is already validated by OpenMLS) pub fn from_key_package(kp: KeyPackage) -> Result { let leaf_node = kp.leaf_node(); - let basic_credential = BasicCredential::try_from(leaf_node.credential())?; + let basic_credential = BasicCredential::try_from(leaf_node.credential().clone())?; let pub_key_bytes = leaf_node.signature_key().as_slice(); let account_address = identity_to_account_address(basic_credential.identity(), pub_key_bytes)?; @@ -124,10 +124,8 @@ mod tests { extensions::{ ApplicationIdExtension, Extension, ExtensionType, Extensions, LastResortExtension, }, - group::config::CryptoConfig, prelude::Capabilities, prelude_test::KeyPackage, - versions::ProtocolVersion, }; use xmtp_cryptography::utils::generate_local_wallet; @@ -163,10 +161,7 @@ mod tests { .key_package_extensions(Extensions::single(last_resort)) .leaf_node_extensions(leaf_node_extensions) .build( - CryptoConfig { - ciphersuite: CIPHERSUITE, - version: ProtocolVersion::default(), - }, + CIPHERSUITE, &provider, &client.identity().installation_keys, CredentialWithKey { @@ -176,7 +171,7 @@ mod tests { ) .unwrap(); - let verified_kp_result = VerifiedKeyPackage::from_key_package(kp); + let verified_kp_result = VerifiedKeyPackage::from_key_package(kp.key_package().clone()); assert!(verified_kp_result.is_err()); assert_eq!( KeyPackageVerificationError::ApplicationIdCredentialMismatch( diff --git a/xmtp_mls/src/verified_key_package_v2.rs b/xmtp_mls/src/verified_key_package_v2.rs index 29ff3d7d8..769877291 100644 --- a/xmtp_mls/src/verified_key_package_v2.rs +++ b/xmtp_mls/src/verified_key_package_v2.rs @@ -63,7 +63,7 @@ impl TryFrom for VerifiedKeyPackageV2 { fn try_from(kp: KeyPackage) -> Result { let leaf_node = kp.leaf_node(); - let basic_credential = BasicCredential::try_from(leaf_node.credential())?; + let basic_credential = BasicCredential::try_from(leaf_node.credential().clone())?; let pub_key_bytes = leaf_node.signature_key().as_slice().to_vec(); let credential = MlsCredential::decode(basic_credential.identity())?; diff --git a/xmtp_mls/src/xmtp_openmls_provider.rs b/xmtp_mls/src/xmtp_openmls_provider.rs index 1902cfce1..dcc93ee1f 100644 --- a/xmtp_mls/src/xmtp_openmls_provider.rs +++ b/xmtp_mls/src/xmtp_openmls_provider.rs @@ -38,7 +38,7 @@ impl XmtpOpenMlsProvider { impl OpenMlsProvider for XmtpOpenMlsProvider { type CryptoProvider = RustCrypto; type RandProvider = RustCrypto; - type KeyStoreProvider = SqlKeyStore; + type StorageProvider = SqlKeyStore; fn crypto(&self) -> &Self::CryptoProvider { &self.crypto @@ -48,7 +48,7 @@ impl OpenMlsProvider for XmtpOpenMlsProvider { &self.crypto } - fn key_store(&self) -> &Self::KeyStoreProvider { + fn storage(&self) -> &Self::StorageProvider { &self.key_store } } @@ -56,7 +56,7 @@ impl OpenMlsProvider for XmtpOpenMlsProvider { impl<'a> OpenMlsProvider for &'a XmtpOpenMlsProvider { type CryptoProvider = RustCrypto; type RandProvider = RustCrypto; - type KeyStoreProvider = SqlKeyStore; + type StorageProvider = SqlKeyStore; fn crypto(&self) -> &Self::CryptoProvider { &self.crypto @@ -66,7 +66,7 @@ impl<'a> OpenMlsProvider for &'a XmtpOpenMlsProvider { &self.crypto } - fn key_store(&self) -> &Self::KeyStoreProvider { + fn storage(&self) -> &Self::StorageProvider { &self.key_store } } diff --git a/xmtp_proto/src/convert.rs b/xmtp_proto/src/convert.rs index 4f05e0a80..f85e8c8a0 100644 --- a/xmtp_proto/src/convert.rs +++ b/xmtp_proto/src/convert.rs @@ -12,7 +12,7 @@ mod inbox_id { fn try_from(proto: MlsCredential) -> Result { let bytes = proto.encode_to_vec(); - Ok(BasicCredential::new(bytes)?.into()) + Ok(BasicCredential::new(bytes).into()) } } }