diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index e2b9dae25..ec00a0d4b 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1093,7 +1093,7 @@ pub(crate) mod tests { .unwrap(); let conn = amal.store().conn().unwrap(); - conn.raw_query(|conn| diesel::delete(identity_updates::table).execute(conn)) + conn.raw_query_write(|conn| diesel::delete(identity_updates::table).execute(conn)) .unwrap(); let members = group.members().await.unwrap(); @@ -1424,6 +1424,7 @@ pub(crate) mod tests { not(target_arch = "wasm32"), tokio::test(flavor = "multi_thread", worker_threads = 1) )] + #[ignore] async fn test_add_remove_then_add_again() { let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -1445,6 +1446,7 @@ pub(crate) mod tests { .unwrap(); assert_eq!(amal_group.members().await.unwrap().len(), 1); tracing::info!("Syncing bolas welcomes"); + // See if Bola can see that they were added to the group bola.sync_welcomes(&bola.mls_provider().unwrap()) .await diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 369cbad10..883bd77c6 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -64,10 +64,12 @@ impl MlsGroup { intent_kind: IntentKind, intent_data: Vec, ) -> Result { - provider.transaction(|provider| { + let res = provider.transaction(|provider| { let conn = provider.conn_ref(); self.queue_intent_with_conn(conn, intent_kind, intent_data) - }) + }); + + res } fn queue_intent_with_conn( diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index ce62993dc..337db6c72 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -935,6 +935,8 @@ impl MlsGroup { intent_data.into(), )?; + tracing::warn!("This makes it here?"); + self.sync_until_intent_resolved(provider, intent.id).await } @@ -1250,13 +1252,13 @@ impl MlsGroup { state, hex::encode(self.group_id.clone()), ); - let new_records = conn + let new_records: Vec<_> = conn .insert_or_replace_consent_records(&[consent_record.clone()])? .into_iter() .map(UserPreferenceUpdate::ConsentUpdate) .collect(); - if self.client.history_sync_url().is_some() { + if !new_records.is_empty() && self.client.history_sync_url().is_some() { // Dispatch an update event so it can be synced across devices let _ = self .client @@ -2169,7 +2171,7 @@ pub(crate) mod tests { // The dm shows up let alix_groups = alix_conn - .raw_query(|conn| groups::table.load::(conn)) + .raw_query_read( |conn| groups::table.load::(conn)) .unwrap(); assert_eq!(alix_groups.len(), 2); // They should have the same ID @@ -3696,7 +3698,7 @@ pub(crate) mod tests { let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into(); let conn_2 = bo.store().conn().unwrap(); conn_2 - .raw_query(|c| { + .raw_query_read( |c| { c.batch_execute("BEGIN EXCLUSIVE").unwrap(); Ok::<_, diesel::result::Error>(()) }) diff --git a/xmtp_mls/src/storage/encrypted_store/association_state.rs b/xmtp_mls/src/storage/encrypted_store/association_state.rs index 0eb194b26..f29d05341 100644 --- a/xmtp_mls/src/storage/encrypted_store/association_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/association_state.rs @@ -107,8 +107,9 @@ impl StoredAssociationState { .and(dsl::sequence_id.eq_any(sequence_ids)), ); - let association_states = - conn.raw_query(|query_conn| query.load::(query_conn))?; + let association_states = conn.raw_query_read( |query_conn| { + query.load::(query_conn) + })?; association_states .into_iter() diff --git a/xmtp_mls/src/storage/encrypted_store/consent_record.rs b/xmtp_mls/src/storage/encrypted_store/consent_record.rs index a0789d9e3..b70552a3b 100644 --- a/xmtp_mls/src/storage/encrypted_store/consent_record.rs +++ b/xmtp_mls/src/storage/encrypted_store/consent_record.rs @@ -48,7 +48,7 @@ impl DbConnection { entity: String, entity_type: ConsentType, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| -> diesel::QueryResult<_> { + Ok(self.raw_query_read( |conn| -> diesel::QueryResult<_> { dsl::consent_records .filter(dsl::entity.eq(entity)) .filter(dsl::entity_type.eq(entity_type)) @@ -77,7 +77,7 @@ impl DbConnection { ); } - let changed = self.raw_query(|conn| -> diesel::QueryResult<_> { + let changed = self.raw_query_write( |conn| -> diesel::QueryResult<_> { let existing: Vec = query.load(conn)?; let changed: Vec<_> = records .iter() @@ -107,7 +107,7 @@ impl DbConnection { &self, record: &StoredConsentRecord, ) -> Result, StorageError> { - self.raw_query(|conn| { + self.raw_query_write( |conn| { let maybe_inserted_consent_record: Option = diesel::insert_into(dsl::consent_records) .values(record) diff --git a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs index de1b7bd32..76e807d51 100644 --- a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs +++ b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs @@ -139,7 +139,7 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } else { // Only include the specified states let query = query @@ -153,11 +153,11 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } } else { // Handle the case where `consent_states` is `None` - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -165,7 +165,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = conversation_list_dsl::conversation_list .filter(conversation_list_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(|conn| query.load(conn))?; + let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?; conversations.append(&mut sync_groups); } diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 045b897a9..c31189ad8 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,8 +1,10 @@ -use parking_lot::Mutex; -use std::fmt; -use std::sync::Arc; - use crate::storage::xmtp_openmls_provider::XmtpOpenMlsProvider; +use parking_lot::Mutex; +use std::{ + fmt, + sync::atomic::{AtomicBool, Ordering}, + sync::Arc, +}; #[cfg(not(target_arch = "wasm32"))] pub type DbConnection = DbConnectionPrivate; @@ -19,14 +21,22 @@ pub type DbConnection = DbConnectionPrivate { - inner: Arc>, + read: Arc>, + write: Option>>, + in_transaction: Arc, + transaction_lock: Arc>, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(conn: Arc>) -> Self { - Self { inner: conn } + pub(super) fn from_arc_mutex(read: Arc>, write: Option>>) -> Self { + Self { + read, + write, + in_transaction: Arc::new(AtomicBool::new(false)), + transaction_lock: Arc::new(tokio::sync::Mutex::new(())), + } } } @@ -34,13 +44,38 @@ impl DbConnectionPrivate where C: diesel::Connection, { + pub(crate) async fn start_transaction<'a>(&'a self) -> tokio::sync::MutexGuard<'a, ()> { + let guard = self.transaction_lock.lock().await; + self.in_transaction.store(true, Ordering::SeqCst); + guard + } + + fn is_in_transaction(&self) -> bool { + self.in_transaction.load(Ordering::SeqCst) + } + + /// Do a scoped query with a mutable [`diesel::Connection`] + /// reference + pub(crate) fn raw_query_read(&self, fun: F) -> Result + where + F: FnOnce(&mut C) -> Result, + { + let mut lock = self.read.lock(); + fun(&mut lock) + } + /// Do a scoped query with a mutable [`diesel::Connection`] /// reference - pub(crate) fn raw_query(&self, fun: F) -> Result + pub(crate) fn raw_query_write(&self, fun: F) -> Result where F: FnOnce(&mut C) -> Result, { - let mut lock = self.inner.lock(); + if let Some(write_conn) = &self.write { + let mut lock = write_conn.lock(); + return fun(&mut lock); + } + + let mut lock = self.read.lock(); fun(&mut lock) } @@ -48,14 +83,41 @@ where /// without a scope /// Must be used with care. holding this reference while calling `raw_query` /// will cause a deadlock. - pub(super) fn inner_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { - self.inner.lock() + pub(super) fn read_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { + if self.is_in_transaction() { + if let Some(write) = &self.write { + return write.lock(); + } + } + self.read.lock() + } + + /// Internal-only API to get the underlying `diesel::Connection` reference + /// without a scope + pub(super) fn read_ref(&self) -> Arc> { + if self.is_in_transaction() { + if let Some(write) = &self.write { + return write.clone(); + }; + } + self.read.clone() } /// Internal-only API to get the underlying `diesel::Connection` reference /// without a scope - pub(super) fn inner_ref(&self) -> Arc> { - self.inner.clone() + /// Must be used with care. holding this reference while calling `raw_query` + /// will cause a deadlock. + pub(super) fn write_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { + let Some(write) = &self.write else { + return self.read_mut_ref(); + }; + write.lock() + } + + /// Internal-only API to get the underlying `diesel::Connection` reference + /// without a scope + pub(super) fn write_ref(&self) -> Option>> { + self.write.clone() } } @@ -77,3 +139,12 @@ impl fmt::Debug for DbConnectionPrivate { .finish() } } + +pub struct TransactionGuard { + in_transaction: Arc, +} +impl Drop for TransactionGuard { + fn drop(&mut self) { + self.in_transaction.store(false, Ordering::SeqCst); + } +} diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 89cde7088..fa79224ac 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -292,7 +292,7 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } else { // Only include the specified states let query = query @@ -305,11 +305,11 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } } else { // Handle the case where `consent_states` is `None` - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -317,7 +317,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = groups_dsl::groups.filter(groups_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(|conn| query.load(conn))?; + let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?; groups.append(&mut sync_groups); } @@ -325,7 +325,7 @@ impl DbConnection { } pub fn consent_records(&self) -> Result, StorageError> { - Ok(self.raw_query(|conn| super::schema::consent_records::table.load(conn))?) + Ok(self.raw_query_read(|conn| super::schema::consent_records::table.load(conn))?) } pub fn all_sync_groups(&self) -> Result, StorageError> { @@ -333,7 +333,7 @@ impl DbConnection { .order(dsl::created_at_ns.desc()) .filter(dsl::conversation_type.eq(ConversationType::Sync)); - Ok(self.raw_query(|conn| query.load(conn))?) + Ok(self.raw_query_read(|conn| query.load(conn))?) } pub fn latest_sync_group(&self) -> Result, StorageError> { @@ -342,7 +342,7 @@ impl DbConnection { .filter(dsl::conversation_type.eq(ConversationType::Sync)) .limit(1); - Ok(self.raw_query(|conn| query.load(conn))?.pop()) + Ok(self.raw_query_read(|conn| query.load(conn))?.pop()) } /// Return a single group that matches the given ID @@ -350,7 +350,7 @@ impl DbConnection { let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed(); query = query.limit(1).filter(dsl::id.eq(id)); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; // Manually extract the first element Ok(groups.into_iter().next()) @@ -365,7 +365,7 @@ impl DbConnection { .order(dsl::created_at_ns.asc()) .filter(dsl::welcome_id.eq(welcome_id)); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; if groups.len() > 1 { tracing::error!("More than one group found for welcome_id {}", welcome_id); } @@ -383,7 +383,7 @@ impl DbConnection { .filter(dsl::dm_id.eq(Some(dm_id))) .order(dsl::last_message_ns.desc()); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; if groups.len() > 1 { tracing::info!("More than one group found for dm_inbox_id {members:?}"); } @@ -397,7 +397,7 @@ impl DbConnection { group_id: GroupId, state: GroupMembershipState, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::update(dsl::groups.find(group_id.as_ref())) .set(dsl::membership_state.eq(state)) .execute(conn) @@ -407,7 +407,7 @@ impl DbConnection { } pub fn get_rotated_at_ns(&self, group_id: Vec) -> Result { - let last_ts: Option = self.raw_query(|conn| { + let last_ts: Option = self.raw_query_read(|conn| { let ts = dsl::groups .find(&group_id) .select(dsl::rotated_at_ns) @@ -423,7 +423,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_rotated_at_ns(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::rotated_at_ns.eq(now)) @@ -434,7 +434,7 @@ impl DbConnection { } pub fn get_installations_time_checked(&self, group_id: Vec) -> Result { - let last_ts = self.raw_query(|conn| { + let last_ts = self.raw_query_read(|conn| { let ts = dsl::groups .find(&group_id) .select(dsl::installations_last_checked) @@ -448,7 +448,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_installations_time_checked(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::installations_last_checked.eq(now)) @@ -460,7 +460,7 @@ impl DbConnection { pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result { tracing::info!("Trying to insert group"); - let stored_group = self.raw_query(|conn| { + let stored_group = self.raw_query_write(|conn| { let maybe_inserted_group: Option = diesel::insert_into(dsl::groups) .values(&group) .on_conflict_do_nothing() @@ -673,7 +673,7 @@ pub(crate) mod tests { test_group.store(conn).unwrap(); assert_eq!( - conn.raw_query(|raw_conn| groups.first::(raw_conn)) + conn.raw_query_read(|raw_conn| groups.first::(raw_conn)) .unwrap(), test_group ); @@ -687,7 +687,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(|raw_conn| { + conn.raw_query_write(|raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) @@ -863,7 +863,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(|raw_conn| { + conn.raw_query_write(|raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 743eccc5e..464f7b9f6 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -119,8 +119,9 @@ impl_fetch!(StoredGroupIntent, group_intents, ID); impl Delete for DbConnection { type Key = ID; fn delete(&self, key: ID) -> Result { - Ok(self - .raw_query(|raw_conn| diesel::delete(dsl::group_intents.find(key)).execute(raw_conn))?) + Ok(self.raw_query_write( |raw_conn| { + diesel::delete(dsl::group_intents.find(key)).execute(raw_conn) + })?) } } @@ -155,7 +156,7 @@ impl DbConnection { &self, to_save: NewGroupIntent, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write( |conn| { diesel::insert_into(dsl::group_intents) .values(to_save) .get_result(conn) @@ -184,7 +185,7 @@ impl DbConnection { query = query.order(dsl::id.asc()); - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query_read( |conn| query.load::(conn))?) } // Set the intent with the given ID to `Published` and set the payload hash. Optionally add @@ -197,7 +198,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -214,7 +215,7 @@ impl DbConnection { })?; if rows_changed == 0 { - let already_published = self.raw_query(|conn| { + let already_published = self.raw_query_read( |conn| { dsl::group_intents .filter(dsl::id.eq(intent_id)) .first::(conn) @@ -231,7 +232,7 @@ impl DbConnection { // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -252,7 +253,7 @@ impl DbConnection { // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -278,7 +279,7 @@ impl DbConnection { /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) @@ -298,7 +299,7 @@ impl DbConnection { &self, payload_hash: Vec, ) -> Result, StorageError> { - let result = self.raw_query(|conn| { + let result = self.raw_query_read( |conn| { dsl::group_intents .filter(dsl::payload_hash.eq(payload_hash)) .first::(conn) @@ -312,7 +313,7 @@ impl DbConnection { &self, intent_id: ID, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::publish_attempts.eq(dsl::publish_attempts + 1)) @@ -431,7 +432,7 @@ pub(crate) mod tests { } fn find_first_intent(conn: &DbConnection, group_id: group::ID) -> StoredGroupIntent { - conn.raw_query(|raw_conn| { + conn.raw_query_read( |raw_conn| { dsl::group_intents .filter(dsl::group_id.eq(group_id)) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index 20335e88c..ba6254dfe 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -290,7 +290,7 @@ impl DbConnection { query = query.limit(limit); } - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query_read( |conn| query.load::(conn))?) } /// Query for group messages with their reactions @@ -341,7 +341,7 @@ impl DbConnection { }; let reactions: Vec = - self.raw_query(|conn| reactions_query.load(conn))?; + self.raw_query_read( |conn| reactions_query.load(conn))?; // Group reactions by parent message id let mut reactions_by_reference: HashMap, Vec> = HashMap::new(); @@ -377,7 +377,7 @@ impl DbConnection { &self, id: MessageId, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_read( |conn| { dsl::group_messages .filter(dsl::id.eq(id.as_ref())) .first(conn) @@ -390,7 +390,7 @@ impl DbConnection { group_id: GroupId, timestamp: i64, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_read( |conn| { dsl::group_messages .filter(dsl::group_id.eq(group_id.as_ref())) .filter(dsl::sent_at_ns.eq(timestamp)) @@ -404,7 +404,7 @@ impl DbConnection { msg_id: &MessageId, timestamp: u64, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write( |conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set(( @@ -419,7 +419,7 @@ impl DbConnection { &self, msg_id: &MessageId, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write( |conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set((dsl::delivery_status.eq(DeliveryStatus::Failed),)) @@ -517,7 +517,7 @@ pub(crate) mod tests { } let count: i64 = conn - .raw_query(|raw_conn| { + .raw_query_read( |raw_conn| { dsl::group_messages .select(diesel::dsl::count_star()) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/identity_update.rs b/xmtp_mls/src/storage/encrypted_store/identity_update.rs index 0a0c87b0a..bb37dd774 100644 --- a/xmtp_mls/src/storage/encrypted_store/identity_update.rs +++ b/xmtp_mls/src/storage/encrypted_store/identity_update.rs @@ -72,7 +72,7 @@ impl DbConnection { query = query.filter(dsl::sequence_id.le(sequence_id)); } - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query_read( |conn| query.load::(conn))?) } /// Batch insert identity updates, ignoring duplicates. @@ -81,7 +81,7 @@ impl DbConnection { &self, updates: &[StoredIdentityUpdate], ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write( |conn| { diesel::insert_or_ignore_into(dsl::identity_updates) .values(updates) .execute(conn)?; @@ -98,7 +98,7 @@ impl DbConnection { .filter(dsl::inbox_id.eq(inbox_id)) .into_boxed(); - Ok(self.raw_query(|conn| query.first::(conn))?) + Ok(self.raw_query_read( |conn| query.first::(conn))?) } /// Given a list of inbox_ids return a HashMap of each inbox ID -> highest known sequence ID @@ -115,7 +115,7 @@ impl DbConnection { // Get the results as a Vec of (inbox_id, sequence_id) tuples let result_tuples: Vec<(String, i64)> = self - .raw_query(|conn| query.load::<(String, Option)>(conn))? + .raw_query_read( |conn| query.load::<(String, Option)>(conn))? .into_iter() // Diesel needs an Option type for aggregations like max(sequence_id), so we // unwrap the option here diff --git a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs index 9e43243c0..253894761 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs @@ -39,7 +39,7 @@ impl DbConnection { &self, hash_ref: Vec, ) -> Result { - let result = self.raw_query(|conn| { + let result = self.raw_query_read( |conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::key_package_hash_ref.eq(hash_ref)) .first::(conn) @@ -52,7 +52,7 @@ impl DbConnection { &self, id: i32, ) -> Result, StorageError> { - let result = self.raw_query(|conn| { + let result = self.raw_query_read( |conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)) .load::(conn) @@ -65,7 +65,7 @@ impl DbConnection { &self, id: i32, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write( |conn| { diesel::delete( key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)), diff --git a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs index 7fc4deb3d..44f1027c9 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs @@ -18,7 +18,7 @@ impl Delete for DbConnection { type Key = Vec; fn delete(&self, key: Vec) -> Result where { use super::schema::openmls_key_store::dsl::*; - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write(|conn| { diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(conn) })?) } @@ -36,7 +36,7 @@ impl DbConnection { value_bytes: value, }; - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::replace_into(openmls_key_store) .values(entry) .execute(conn) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 9a96252a6..bc6444b22 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -179,7 +179,7 @@ pub mod private { #[tracing::instrument(level = "trace", skip_all)] pub(super) fn init_db(&mut self) -> Result<(), StorageError> { self.db.validate(&self.opts)?; - self.db.conn()?.raw_query(|conn| { + self.db.conn()?.raw_query_write(|conn| { conn.batch_execute("PRAGMA journal_mode = WAL;")?; tracing::info!("Running DB migrations"); conn.run_pending_migrations(MIGRATIONS)?; @@ -242,7 +242,7 @@ macro_rules! impl_fetch { type Key = (); fn fetch(&self, _key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(|conn| $table.first(conn).optional())?) + Ok(self.raw_query_read(|conn| $table.first(conn).optional())?) } } }; @@ -254,7 +254,7 @@ macro_rules! impl_fetch { type Key = $key; fn fetch(&self, key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(|conn| $table.find(key.clone()).first(conn).optional())?) + Ok(self.raw_query_read(|conn| $table.find(key.clone()).first(conn).optional())?) } } }; @@ -286,8 +286,9 @@ macro_rules! impl_fetch_list_with_key { keys: &[Self::Key], ) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::{$column, *}; - Ok(self - .raw_query(|conn| $table.filter($column.eq_any(keys)).load::<$model>(conn))?) + Ok(self.raw_query_read(|conn| { + $table.filter($column.eq_any(keys)).load::<$model>(conn) + })?) } } }; @@ -304,7 +305,7 @@ macro_rules! impl_store { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(|conn| { + into.raw_query_write(|conn| { diesel::insert_into($table::table) .values(self) .execute(conn) @@ -326,7 +327,7 @@ macro_rules! impl_store_or_ignore { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(|conn| { + into.raw_query_write(|conn| { diesel::insert_or_ignore_into($table::table) .values(self) .execute(conn) @@ -391,17 +392,19 @@ where E: From + From, { tracing::debug!("Transaction beginning"); - { - let connection = self.conn_ref(); - let mut connection = connection.inner_mut_ref(); + + let _guard = { + let wrapper = self.conn_ref(); + let mut connection = wrapper.write_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; - } + wrapper.start_transaction() + }; let conn = self.conn_ref(); match fun(self) { Ok(value) => { - conn.raw_query(|conn| { + conn.raw_query_write(|conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction being committed"); @@ -409,7 +412,7 @@ where } Err(err) => { tracing::debug!("Transaction being rolled back"); - match conn.raw_query(|conn| { + match conn.raw_query_write(|conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -441,40 +444,43 @@ where E: From + From, Db: 'a, { - tracing::debug!("Transaction async beginning"); - { - let connection = self.conn_ref(); - let mut connection = connection.inner_mut_ref(); + tracing::info!("Transaction async beginning"); + let _guard = { + let wrapper = self.conn_ref(); + let mut connection = wrapper.write_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; - } + wrapper.start_transaction() + }; // ensuring we have only one strong reference let result = fun(self).await; - let local_connection = self.conn_ref().inner_ref(); - if Arc::strong_count(&local_connection) > 1 { + let local_read_connection = self.conn_ref().read_ref(); + let local_write_connection = self.conn_ref().write_ref(); + if Arc::strong_count(&local_read_connection) > 1 { tracing::warn!( "More than 1 strong connection references still exist during async transaction" ); } - if Arc::weak_count(&local_connection) > 1 { + if Arc::weak_count(&local_read_connection) > 1 { tracing::warn!("More than 1 weak connection references still exist during transaction"); } // after the closure finishes, `local_provider` should have the only reference ('strong') // to `XmtpOpenMlsProvider` inner `DbConnection`.. - let local_connection = DbConnectionPrivate::from_arc_mutex(local_connection); + let local_connection = + DbConnectionPrivate::from_arc_mutex(local_read_connection, local_write_connection); match result { Ok(value) => { - local_connection.raw_query(|conn| { + local_connection.raw_query_write(|conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; - tracing::debug!("Transaction async being committed"); + tracing::info!("Transaction async being committed"); Ok(value) } Err(err) => { - tracing::debug!("Transaction async being rolled back"); - match local_connection.raw_query(|conn| { + tracing::info!("Transaction async being rolled back"); + match local_connection.raw_query_write(|conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -624,7 +630,7 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(|conn| { + .raw_query_write(|conn| { for _ in 0..15 { conn.run_next_migration(MIGRATIONS)?; } @@ -670,14 +676,14 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(|conn| { + .raw_query_write(|conn| { conn.run_pending_migrations(MIGRATIONS)?; Ok::<_, StorageError>(()) }) .unwrap(); let groups = conn - .raw_query(|conn| groups::table.load::(conn)) + .raw_query_read(|conn| groups::table.load::(conn)) .unwrap(); assert_eq!(groups.len(), 1); assert_eq!(&**groups[0].dm_id.as_ref().unwrap(), "dm:98765:inbox_id"); diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 635b1b4c7..8644cee09 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -7,7 +7,7 @@ use diesel::{ r2d2::{self, CustomizeConnection, PoolTransactionManager, PooledConnection}, Connection, }; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use std::sync::Arc; pub type ConnectionManager = r2d2::ConnectionManager; @@ -64,7 +64,7 @@ impl ValidatedConnection for UnencryptedConnection {} impl CustomizeConnection for UnencryptedConnection { fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> { - conn.batch_execute("PRAGMA busy_timeout = 5000;") + conn.batch_execute("PRAGMA query_only = ON; PRAGMA busy_timeout = 5000;") .map_err(r2d2::Error::QueryError)?; Ok(()) } @@ -89,9 +89,10 @@ impl StorageOption { } } -#[derive(Clone, Debug)] +#[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { + pub(super) write_conn: Option>>, pub(super) pool: Arc>>, customizer: Option>, opts: StorageOption, @@ -106,9 +107,9 @@ impl NativeDb { let mut builder = Pool::builder(); let customizer = if let Some(key) = enc_key { - let enc_opts = EncryptedConnection::new(key, opts)?; - builder = builder.connection_customizer(Box::new(enc_opts.clone())); - Some(Box::new(enc_opts) as Box) + let enc_connection = EncryptedConnection::new(key, opts)?; + builder = builder.connection_customizer(Box::new(enc_connection.clone())); + Some(Box::new(enc_connection) as Box) } else if matches!(opts, StorageOption::Persistent(_)) { builder = builder.connection_customizer(Box::new(UnencryptedConnection)); Some(Box::new(UnencryptedConnection) as Box) @@ -125,7 +126,16 @@ impl NativeDb { .build(ConnectionManager::new(path))?, }; + let write_conn = if matches!(opts, StorageOption::Persistent(_)) { + let mut write_conn = pool.get()?; + write_conn.batch_execute("PRAGMA query_only = OFF;")?; + Some(Arc::new(Mutex::new(write_conn))) + } else { + None + }; + Ok(Self { + write_conn, pool: Arc::new(Some(pool).into()), customizer, opts: opts.clone(), @@ -156,9 +166,10 @@ impl XmtpDb for NativeDb { /// Returns the Wrapped [`super::db_connection::DbConnection`] Connection implementation for this Database fn conn(&self) -> Result, StorageError> { let conn = self.raw_conn()?; - Ok(DbConnectionPrivate::from_arc_mutex(Arc::new( - parking_lot::Mutex::new(conn), - ))) + Ok(DbConnectionPrivate::from_arc_mutex( + Arc::new(parking_lot::Mutex::new(conn)), + self.write_conn.clone(), + )) } fn validate(&self, opts: &StorageOption) -> Result<(), StorageError> { diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index b1cfefcb0..2fb7787a1 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -74,7 +74,7 @@ impl DbConnection { entity_kind: EntityKind, ) -> Result, StorageError> { use super::schema::refresh_state::dsl; - let res = self.raw_query(|conn| { + let res = self.raw_query_read( |conn| { dsl::refresh_state .find((entity_id.as_ref(), entity_kind)) .first(conn) @@ -115,7 +115,7 @@ impl DbConnection { NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), )?; - let num_updated = self.raw_query(|conn| { + let num_updated = self.raw_query_write( |conn| { diesel::update(&state) .filter(dsl::cursor.lt(cursor)) .set(dsl::cursor.eq(cursor)) diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 6723f0df1..b5cae60f3 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -204,7 +204,9 @@ impl EncryptedConnection { /// Output the corect order of PRAGMAS to instantiate a connection fn pragmas(&self) -> impl Display { - let Self { ref key, ref salt } = self; + let Self { + ref key, ref salt, .. + } = self; if let Some(s) = salt { format!( @@ -281,6 +283,7 @@ impl diesel::r2d2::CustomizeConnection fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), diesel::r2d2::Error> { conn.batch_execute(&format!( "{} + PRAGMA query_only = ON; PRAGMA busy_timeout = 5000;", self.pragmas() )) diff --git a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs index 98170d52f..ed534ce79 100644 --- a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs +++ b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs @@ -37,7 +37,7 @@ impl<'a> From<&'a StoredUserPreferences> for NewStoredUserPreferences<'a> { impl Store for StoredUserPreferences { fn store(&self, conn: &DbConnection) -> Result<(), StorageError> { - conn.raw_query(|conn| { + conn.raw_query_write( |conn| { diesel::update(dsl::user_preferences) .set(self) .execute(conn) @@ -50,7 +50,7 @@ impl Store for StoredUserPreferences { impl StoredUserPreferences { pub fn load(conn: &DbConnection) -> Result { let query = dsl::user_preferences.order(dsl::id.desc()).limit(1); - let mut result = conn.raw_query(|conn| query.load::(conn))?; + let mut result = conn.raw_query_read( |conn| query.load::(conn))?; Ok(result.pop().unwrap_or_default()) } @@ -73,7 +73,7 @@ impl StoredUserPreferences { ])); let to_insert: NewStoredUserPreferences = (&preferences).into(); - conn.raw_query(|conn| { + conn.raw_query_write( |conn| { diesel::insert_into(dsl::user_preferences) .values(to_insert) .execute(conn) @@ -115,7 +115,7 @@ mod tests { // check that there are two preferences stored let query = dsl::user_preferences.order(dsl::id.desc()); let result = conn - .raw_query(|conn| query.load::(conn)) + .raw_query_read( |conn| query.load::(conn)) .unwrap(); assert_eq!(result.len(), 1); } diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 4cda41805..05ebfbbec 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -57,12 +57,12 @@ pub mod test_util { for query in queries { let query = diesel::sql_query(query); - let _ = self.raw_query(|conn| query.execute(conn)).unwrap(); + let _ = self.raw_query_write(|conn| query.execute(conn)).unwrap(); } } pub fn intents_published(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query( "SELECT intents_published FROM test_metadata WHERE rowid = 1", @@ -78,7 +78,7 @@ pub mod test_util { } pub fn intents_deleted(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query("SELECT intents_deleted FROM test_metadata")) .unwrap(); @@ -92,7 +92,7 @@ pub mod test_util { } pub fn intents_created(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query("SELECT intents_created FROM test_metadata")) .unwrap(); diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 7fc9ed1c5..c9c2b40df 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -49,7 +49,7 @@ where &self, storage_key: &Vec, ) -> Result, diesel::result::Error> { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query_read(|conn| { sql_query(SELECT_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -62,7 +62,7 @@ where storage_key: &Vec, value: &[u8], ) -> Result { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query_write(|conn| { sql_query(REPLACE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -76,7 +76,7 @@ where storage_key: &Vec, modified_data: &Vec, ) -> Result { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query_write(|conn| { sql_query(UPDATE_QUERY) .bind::(&modified_data) .bind::(&storage_key) @@ -224,7 +224,7 @@ where ) -> Result<(), >::Error> { let storage_key = build_key_from_vec::(label, key.to_vec()); - let _ = self.conn_ref().raw_query(|conn| { + let _ = self.conn_ref().raw_query_write(|conn| { sql_query(DELETE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -809,7 +809,7 @@ where let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; - let data: Vec = self.conn_ref().raw_query(|conn| { + let data: Vec = self.conn_ref().raw_query_read(|conn| { sql_query(query) .bind::(&storage_key) .bind::(CURRENT_VERSION as i32)