Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pool read-only, with a single write connection. #1517

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
6 changes: 4 additions & 2 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1078,8 +1078,10 @@ pub(crate) mod tests {
.unwrap();

let conn = amal.store().conn().unwrap();
conn.raw_query(|conn| diesel::delete(identity_updates::table).execute(conn))
.unwrap();
conn.raw_query(true, |conn| {
diesel::delete(identity_updates::table).execute(conn)
})
.unwrap();

let members = group.members().await.unwrap();
// // The three installations should count as two members
Expand Down
8 changes: 4 additions & 4 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,13 +1250,13 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
state,
hex::encode(self.group_id.clone()),
);
let new_records = conn
let new_records: Vec<_> = conn
.insert_or_replace_consent_records(&[consent_record.clone()])?
.into_iter()
.map(UserPreferenceUpdate::ConsentUpdate)
.collect();

if self.client.history_sync_url().is_some() {
if !new_records.is_empty() && self.client.history_sync_url().is_some() {
// Dispatch an update event so it can be synced across devices
let _ = self
.client
Expand Down Expand Up @@ -2170,7 +2170,7 @@ pub(crate) mod tests {

// The dm shows up
let alix_groups = alix_conn
.raw_query(|conn| groups::table.load::<StoredGroup>(conn))
.raw_query(false, |conn| groups::table.load::<StoredGroup>(conn))
.unwrap();
assert_eq!(alix_groups.len(), 2);
// They should have the same ID
Expand Down Expand Up @@ -3697,7 +3697,7 @@ pub(crate) mod tests {
let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into();
let conn_2 = bo.store().conn().unwrap();
conn_2
.raw_query(|c| {
.raw_query(false, |c| {
c.batch_execute("BEGIN EXCLUSIVE").unwrap();
Ok::<_, diesel::result::Error>(())
})
Expand Down
5 changes: 3 additions & 2 deletions xmtp_mls/src/storage/encrypted_store/association_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ impl StoredAssociationState {
.and(dsl::sequence_id.eq_any(sequence_ids)),
);

let association_states =
conn.raw_query(|query_conn| query.load::<StoredAssociationState>(query_conn))?;
let association_states = conn.raw_query(false, |query_conn| {
query.load::<StoredAssociationState>(query_conn)
})?;

association_states
.into_iter()
Expand Down
6 changes: 3 additions & 3 deletions xmtp_mls/src/storage/encrypted_store/consent_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl DbConnection {
entity: String,
entity_type: ConsentType,
) -> Result<Option<StoredConsentRecord>, StorageError> {
Ok(self.raw_query(|conn| -> diesel::QueryResult<_> {
Ok(self.raw_query(false, |conn| -> diesel::QueryResult<_> {
dsl::consent_records
.filter(dsl::entity.eq(entity))
.filter(dsl::entity_type.eq(entity_type))
Expand Down Expand Up @@ -77,7 +77,7 @@ impl DbConnection {
);
}

let changed = self.raw_query(|conn| -> diesel::QueryResult<_> {
let changed = self.raw_query(true, |conn| -> diesel::QueryResult<_> {
let existing: Vec<StoredConsentRecord> = query.load(conn)?;
let changed: Vec<_> = records
.iter()
Expand Down Expand Up @@ -107,7 +107,7 @@ impl DbConnection {
&self,
record: &StoredConsentRecord,
) -> Result<Option<StoredConsentRecord>, StorageError> {
self.raw_query(|conn| {
self.raw_query(true, |conn| {
let maybe_inserted_consent_record: Option<StoredConsentRecord> =
diesel::insert_into(dsl::consent_records)
.values(record)
Expand Down
8 changes: 4 additions & 4 deletions xmtp_mls/src/storage/encrypted_store/conversation_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl DbConnection {
.select(conversation_list::all_columns())
.order(conversation_list_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
self.raw_query(false, |conn| query.load::<ConversationListItem>(conn))?
} else {
let query = query
.inner_join(
Expand All @@ -141,18 +141,18 @@ impl DbConnection {
.select(conversation_list::all_columns())
.order(conversation_list_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
self.raw_query(false, |conn| query.load::<ConversationListItem>(conn))?
}
} else {
self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
self.raw_query(false, |conn| query.load::<ConversationListItem>(conn))?
};

// Were sync groups explicitly asked for? Was the include_sync_groups flag set to true?
// Then query for those separately
if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups {
let query = conversation_list_dsl::conversation_list
.filter(conversation_list_dsl::conversation_type.eq(ConversationType::Sync));
let mut sync_groups = self.raw_query(|conn| query.load(conn))?;
let mut sync_groups = self.raw_query(false, |conn| query.load(conn))?;
conversations.append(&mut sync_groups);
}

Expand Down
32 changes: 23 additions & 9 deletions xmtp_mls/src/storage/encrypted_store/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ pub type DbConnection = DbConnectionPrivate<sqlite_web::connection::WasmSqliteCo
// callers should be able to accomplish everything with one conn/reference.
#[doc(hidden)]
pub struct DbConnectionPrivate<C> {
inner: Arc<Mutex<C>>,
read: Arc<Mutex<C>>,
write: Option<Arc<Mutex<C>>>,
}

/// Owned DBConnection Methods
impl<C> DbConnectionPrivate<C> {
/// Create a new [`DbConnectionPrivate`] from an existing Arc<Mutex<C>>
pub(super) fn from_arc_mutex(conn: Arc<Mutex<C>>) -> Self {
Self { inner: conn }
pub(super) fn from_arc_mutex(read: Arc<Mutex<C>>, write: Option<Arc<Mutex<C>>>) -> Self {
Self { read, write }
}
}

Expand All @@ -36,26 +37,39 @@ where
{
/// Do a scoped query with a mutable [`diesel::Connection`]
/// reference
pub(crate) fn raw_query<T, E, F>(&self, fun: F) -> Result<T, E>
pub(crate) fn raw_query<T, E, F>(&self, write: bool, fun: F) -> Result<T, E>
where
F: FnOnce(&mut C) -> Result<T, E>,
{
let mut lock = self.inner.lock();
if write {
if let Some(write_conn) = &self.write {
let mut lock = write_conn.lock();
return fun(&mut lock);
}
}

let mut lock = self.read.lock();
fun(&mut lock)
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
/// Must be used with care. holding this reference while calling `raw_query`
/// will cause a deadlock.
pub(super) fn inner_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> {
self.inner.lock()
pub(super) fn read_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> {
self.read.lock()
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
pub(super) fn read_ref(&self) -> Arc<Mutex<C>> {
self.read.clone()
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
pub(super) fn inner_ref(&self) -> Arc<Mutex<C>> {
self.inner.clone()
pub(super) fn write_ref(&self) -> Option<Arc<Mutex<C>>> {
self.write.clone()
}
}

Expand Down
40 changes: 21 additions & 19 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ impl DbConnection {
.select(groups_dsl::groups::all_columns())
.order(groups_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<StoredGroup>(conn))?
self.raw_query(false, |conn| query.load::<StoredGroup>(conn))?
} else {
let query = query
.inner_join(
Expand All @@ -293,34 +293,36 @@ impl DbConnection {
.select(groups_dsl::groups::all_columns())
.order(groups_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<StoredGroup>(conn))?
self.raw_query(false, |conn| query.load::<StoredGroup>(conn))?
}
} else {
self.raw_query(|conn| query.load::<StoredGroup>(conn))?
self.raw_query(false, |conn| query.load::<StoredGroup>(conn))?
};

// Were sync groups explicitly asked for? Was the include_sync_groups flag set to true?
// Then query for those separately
if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups {
let query =
groups_dsl::groups.filter(groups_dsl::conversation_type.eq(ConversationType::Sync));
let mut sync_groups = self.raw_query(|conn| query.load(conn))?;
let mut sync_groups = self.raw_query(false, |conn| query.load(conn))?;
groups.append(&mut sync_groups);
}

Ok(groups)
}

pub fn consent_records(&self) -> Result<Vec<StoredConsentRecord>, StorageError> {
Ok(self.raw_query(|conn| super::schema::consent_records::table.load(conn))?)
Ok(self.raw_query(false, |conn| {
super::schema::consent_records::table.load(conn)
})?)
}

pub fn all_sync_groups(&self) -> Result<Vec<StoredGroup>, StorageError> {
let query = dsl::groups
.order(dsl::created_at_ns.desc())
.filter(dsl::conversation_type.eq(ConversationType::Sync));

Ok(self.raw_query(|conn| query.load(conn))?)
Ok(self.raw_query(false, |conn| query.load(conn))?)
}

pub fn latest_sync_group(&self) -> Result<Option<StoredGroup>, StorageError> {
Expand All @@ -329,15 +331,15 @@ impl DbConnection {
.filter(dsl::conversation_type.eq(ConversationType::Sync))
.limit(1);

Ok(self.raw_query(|conn| query.load(conn))?.pop())
Ok(self.raw_query(false, |conn| query.load(conn))?.pop())
}

/// Return a single group that matches the given ID
pub fn find_group(&self, id: Vec<u8>) -> Result<Option<StoredGroup>, StorageError> {
let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed();

query = query.limit(1).filter(dsl::id.eq(id));
let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
let groups: Vec<StoredGroup> = self.raw_query(false, |conn| query.load(conn))?;

// Manually extract the first element
Ok(groups.into_iter().next())
Expand All @@ -352,7 +354,7 @@ impl DbConnection {
.order(dsl::created_at_ns.asc())
.filter(dsl::welcome_id.eq(welcome_id));

let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
let groups: Vec<StoredGroup> = self.raw_query(false, |conn| query.load(conn))?;
if groups.len() > 1 {
tracing::error!("More than one group found for welcome_id {}", welcome_id);
}
Expand All @@ -370,7 +372,7 @@ impl DbConnection {
.filter(dsl::dm_id.eq(Some(dm_id)))
.order(dsl::last_message_ns.desc());

let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
let groups: Vec<StoredGroup> = self.raw_query(false, |conn| query.load(conn))?;
if groups.len() > 1 {
tracing::info!("More than one group found for dm_inbox_id {members:?}");
}
Expand All @@ -384,7 +386,7 @@ impl DbConnection {
group_id: GroupId,
state: GroupMembershipState,
) -> Result<(), StorageError> {
self.raw_query(|conn| {
self.raw_query(true, |conn| {
diesel::update(dsl::groups.find(group_id.as_ref()))
.set(dsl::membership_state.eq(state))
.execute(conn)
Expand All @@ -394,7 +396,7 @@ impl DbConnection {
}

pub fn get_rotated_at_ns(&self, group_id: Vec<u8>) -> Result<i64, StorageError> {
let last_ts: Option<i64> = self.raw_query(|conn| {
let last_ts: Option<i64> = self.raw_query(false, |conn| {
let ts = dsl::groups
.find(&group_id)
.select(dsl::rotated_at_ns)
Expand All @@ -410,7 +412,7 @@ impl DbConnection {

/// Updates the 'last time checked' we checked for new installations.
pub fn update_rotated_at_ns(&self, group_id: Vec<u8>) -> Result<(), StorageError> {
self.raw_query(|conn| {
self.raw_query(true, |conn| {
let now = xmtp_common::time::now_ns();
diesel::update(dsl::groups.find(&group_id))
.set(dsl::rotated_at_ns.eq(now))
Expand All @@ -421,7 +423,7 @@ impl DbConnection {
}

pub fn get_installations_time_checked(&self, group_id: Vec<u8>) -> Result<i64, StorageError> {
let last_ts = self.raw_query(|conn| {
let last_ts = self.raw_query(false, |conn| {
let ts = dsl::groups
.find(&group_id)
.select(dsl::installations_last_checked)
Expand All @@ -435,7 +437,7 @@ impl DbConnection {

/// Updates the 'last time checked' we checked for new installations.
pub fn update_installations_time_checked(&self, group_id: Vec<u8>) -> Result<(), StorageError> {
self.raw_query(|conn| {
self.raw_query(true, |conn| {
let now = xmtp_common::time::now_ns();
diesel::update(dsl::groups.find(&group_id))
.set(dsl::installations_last_checked.eq(now))
Expand All @@ -447,7 +449,7 @@ impl DbConnection {

pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result<StoredGroup, StorageError> {
tracing::info!("Trying to insert group");
let stored_group = self.raw_query(|conn| {
let stored_group = self.raw_query(true, |conn| {
let maybe_inserted_group: Option<StoredGroup> = diesel::insert_into(dsl::groups)
.values(&group)
.on_conflict_do_nothing()
Expand Down Expand Up @@ -660,7 +662,7 @@ pub(crate) mod tests {

test_group.store(conn).unwrap();
assert_eq!(
conn.raw_query(|raw_conn| groups.first::<StoredGroup>(raw_conn))
conn.raw_query(false, |raw_conn| groups.first::<StoredGroup>(raw_conn))
.unwrap(),
test_group
);
Expand All @@ -674,7 +676,7 @@ pub(crate) mod tests {
with_connection(|conn| {
let test_group = generate_group(None);

conn.raw_query(|raw_conn| {
conn.raw_query(true, |raw_conn| {
diesel::insert_into(groups)
.values(test_group.clone())
.execute(raw_conn)
Expand Down Expand Up @@ -850,7 +852,7 @@ pub(crate) mod tests {
with_connection(|conn| {
let test_group = generate_group(None);

conn.raw_query(|raw_conn| {
conn.raw_query(true, |raw_conn| {
diesel::insert_into(groups)
.values(test_group.clone())
.execute(raw_conn)
Expand Down
Loading
Loading