diff --git a/Cargo.lock b/Cargo.lock index 94baec5b631..32481828874 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -233,6 +233,55 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-executor" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17adb73da160dfb475c183343c8cccd80721ea5a605d3eb57125f0a7b7a92d0b" +dependencies = [ + "async-lock", + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-global-executor" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" +dependencies = [ + "async-channel", + "async-executor", + "async-io", + "async-lock", + "blocking", + "futures-lite", + "once_cell", +] + +[[package]] +name = "async-io" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c374dda1ed3e7d8f0d9ba58715f924862c63eae6849c92d3a18e7fbde9e2794" +dependencies = [ + "async-lock", + "autocfg", + "concurrent-queue", + "futures-lite", + "libc", + "log", + "parking", + "polling", + "slab", + "socket2", + "waker-fn", + "windows-sys 0.42.0", +] + [[package]] name = "async-lock" version = "2.6.0" @@ -249,6 +298,51 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "390a110411bbc7c93b77a736cbd694f64cb06dfa2702173f63169d7a1e1b5298" +[[package]] +name = "async-process" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6381ead98388605d0d9ff86371043b5aa922a3905824244de40dc263a14fcba4" +dependencies = [ + "async-io", + "async-lock", + "autocfg", + "blocking", + "cfg-if", + "event-listener", + "futures-lite", + "libc", + "signal-hook", + "windows-sys 0.42.0", +] + +[[package]] +name = "async-std" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"62565bb4402e926b29953c785397c6dc0391b7b446e45008b0049eb43cec6f5d" +dependencies = [ + "async-channel", + "async-global-executor", + "async-io", + "async-lock", + "async-process", + "crossbeam-utils", + "futures-channel", + "futures-core", + "futures-io", + "futures-lite", + "gloo-timers", + "kv-log-macro", + "log", + "memchr", + "once_cell", + "pin-project-lite", + "pin-utils", + "slab", + "wasm-bindgen-futures", +] + [[package]] name = "async-stream" version = "0.3.3" @@ -270,6 +364,12 @@ dependencies = [ "syn", ] +[[package]] +name = "async-task" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524" + [[package]] name = "async-trait" version = "0.1.64" @@ -290,6 +390,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "atomic-waker" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "debc29dde2e69f9e47506b525f639ed42300fc014a3e007832592448fa8e4599" + [[package]] name = "atty" version = "0.2.14" @@ -482,6 +588,20 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blocking" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c67b173a56acffd6d2326fb7ab938ba0b00a71480e14902b2591c87bc5741e8" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "atomic-waker", + "fastrand", + "futures-lite", +] + [[package]] name = "bs58" version = "0.4.0" @@ -2396,6 +2516,15 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "984e109462d46ad18314f10e392c286c3d47bce203088a09012de1015b45b737" +[[package]] +name = "kv-log-macro" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f" +dependencies = [ + "log", +] + [[package]] name = "lazy-regex" version = "2.4.1" @@ -2490,6 +2619,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ "cfg-if", + "value-bag", ] [[package]] @@ -2730,6 +2860,7 @@ dependencies = [ "aes", "anyhow", "assert_matches", + "async-std", "async-trait", "atomic", "base64 0.21.0", @@ -3804,6 +3935,20 @@ dependencies = [ "miniz_oxide 0.6.2", ] +[[package]] +name = "polling" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22122d5ec4f9fe1b3916419b76be1e80bcb93f618d071d2edf841b137b2a2bd6" +dependencies = [ + "autocfg", + "cfg-if", + "libc", + "log", + "wepoll-ffi", + "windows-sys 0.42.0", +] + [[package]] name = "poly1305" version = "0.7.2" @@ -5825,6 +5970,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "value-bag" +version = "1.0.0-alpha.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2209b78d1249f7e6f3293657c9779fe31ced465df091bbd433a1cf88e916ec55" +dependencies = [ + "ctor", + "version_check", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -6053,6 +6208,15 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" +[[package]] +name = "wepoll-ffi" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d743fdedc5c64377b5fc2bc036b01c7fd642205a0d96356034ae3404d49eb7fb" +dependencies = [ + "cc", +] + [[package]] name = "which" version = "4.4.0" diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index c3a5e3283c3..5f90c8556aa 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -27,6 +27,7 @@ testing = ["dep:http"] [dependencies] aes = "0.8.1" atomic = "0.5.1" +async-std = { version = 
"1.12.0", features = ["unstable"] } async-trait = { workspace = true } base64 = { workspace = true } bs58 = { version = "0.4.0", optional = true } diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index 1943e4d2309..f4c1197bea6 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -69,6 +69,9 @@ pub(crate) struct KeysQueryListener { pub(crate) enum UserKeyQueryResult { WasPending, WasNotPending, + + /// A query was pending, but we gave up waiting + TimeoutExpired, } impl KeysQueryListener { diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index fb8afe13af6..3731ac0b752 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -1259,9 +1259,7 @@ impl OlmMachine { async fn wait_if_user_pending(&self, user_id: &UserId, timeout: Option) { if let Some(timeout) = timeout { - let listener = self.identity_manager.listen_for_received_queries(); - - let _ = listener.wait_if_user_pending(timeout, user_id).await; + self.store.wait_if_user_key_query_pending(timeout, user_id).await; } } diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index 22f39128af0..5687ade05f3 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -176,11 +176,11 @@ impl SessionManager { let user_devices = if user_devices.is_empty() { match self - .keys_query_listener - .wait_if_user_pending(Self::KEYS_QUERY_WAIT_TIME, user_id) + .store + .wait_if_user_key_query_pending(Self::KEYS_QUERY_WAIT_TIME, user_id) .await { - Ok(WasPending) => self.store.get_readonly_devices_filtered(user_id).await?, + WasPending => self.store.get_readonly_devices_filtered(user_id).await?, _ => user_devices, } } else { @@ -404,15 +404,22 @@ mod tests { 
use matrix_sdk_common::locks::Mutex; use matrix_sdk_test::{async_test, response_from_file}; use ruma::{ - api::{client::keys::claim_keys::v3::Response as KeyClaimResponse, IncomingResponse}, + api::{ + client::keys::{ + claim_keys::v3::Response as KeyClaimResponse, + get_keys::v3::Response as KeysQueryResponse, + }, + IncomingResponse, + }, device_id, user_id, DeviceId, UserId, }; use serde_json::json; + use tracing::info; use super::SessionManager; use crate::{ gossiping::GossipMachine, - identities::{KeysQueryListener, ReadOnlyDevice}, + identities::{IdentityManager, KeysQueryListener, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, session_manager::GroupSessionCache, store::{IntoCryptoStore, MemoryStore, Store}, @@ -528,6 +535,56 @@ mod tests { assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none()); } + #[async_test] + async fn session_creation_waits_for_keys_query() { + let manager = session_manager().await; + let identity_manager = IdentityManager::new( + manager.account.user_id.clone(), + manager.account.device_id.clone(), + manager.store.clone(), + ); + + // start a keys query request. At this point, we are only interested in our own + // devices. + let (key_query_txn_id, key_query_request) = + identity_manager.users_for_key_query().await.unwrap().pop().unwrap(); + info!("Initial key query: {:?}", key_query_request); + + // now bob turns up, and we start tracking his devices... + let bob = bob_account(); + let bob_device = ReadOnlyDevice::from_account(&bob).await; + manager.store.update_tracked_users(iter::once(bob.user_id())).await.unwrap(); + + // ... and start off an attempt to get the missing sessions. This should block + // for now. 
+ let missing_sessions_future = manager.get_missing_sessions(iter::once(bob.user_id())); + + // the initial keys query completes, and we start another + let response_json = json!({ "device_keys": { manager.account.user_id(): {}}}); + let response = + KeysQueryResponse::try_from_http_response(response_from_file(&response_json)).unwrap(); + identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap(); + + let (key_query_txn_id, key_query_request) = + identity_manager.users_for_key_query().await.unwrap().pop().unwrap(); + info!("Second key query: {:?}", key_query_request); + + // that second request completes with info on bob's device + let response_json = json!({ "device_keys": { bob.user_id(): { + bob_device.device_id(): bob_device.as_device_keys() + }}}); + let response = + KeysQueryResponse::try_from_http_response(response_from_file(&response_json)).unwrap(); + identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap(); + + // the missing_sessions_future should now finally complete, with a claim + // including bob's device + let (_, keys_claim_request) = missing_sessions_future.await.unwrap().unwrap(); + //info!("Key claim: {:?}", keys_claim_request); + let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap(); + assert!(bob_key_claims.contains_key(bob_device.device_id())); + } + // This test doesn't run on macos because we're modifying the session // creation time so we can get around the UNWEDGING_INTERVAL. 
#[async_test] diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 467cba2d182..8623f527098 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -42,9 +42,11 @@ use std::{ collections::{HashMap, HashSet}, fmt::Debug, ops::Deref, - sync::{atomic::AtomicBool, Arc}, + sync::{atomic::AtomicBool, Arc, Weak}, + time::Duration, }; +use async_std::sync::{Condvar, Mutex as AsyncStdMutex}; use atomic::Ordering; use dashmap::DashSet; use matrix_sdk_common::locks::Mutex; @@ -80,10 +82,12 @@ mod traits; pub mod integration_tests; pub use error::{CryptoStoreError, Result}; +use matrix_sdk_common::timeout::timeout; pub use memorystore::MemoryStore; pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore}; pub use crate::gossiping::{GossipRequest, SecretInfo}; +use crate::identities::UserKeyQueryResult; /// A wrapper for our CryptoStore trait object. /// @@ -98,7 +102,17 @@ pub(crate) struct Store { inner: Arc, verification_machine: VerificationMachine, tracked_users_cache: Arc>, - users_for_key_query: Arc>, + + /// Record of the users that are waiting for a /keys/query. + // + // This uses an async_std::sync::Mutex rather than a + // matrix_sdk_common::locks::Mutex because it has to match the Condvar (and tokio lacks a + // working Condvar implementation) + users_for_key_query: Arc>, + + // condition variable that is notified each time an update is received for a user. + users_for_key_query_condvar: Arc, + tracked_user_loading_lock: Arc>, tracked_users_loaded: Arc, } @@ -118,6 +132,12 @@ struct UsersForKeyQuery { /// The users pending a lookup, together with the sequence number at which /// they were added to the list user_map: HashMap, + + /// A list of tasks waiting for key queries to complete. + /// + /// We expect this list to remain fairly short, so don't bother partitioning + /// by user. 
+ tasks_awaiting_key_query: Vec>, } // We use wrapping arithmetic for the sequence numbers, to make sure we never @@ -133,7 +153,11 @@ type InvalidationSequenceNumber = i64; impl UsersForKeyQuery { /// Create a new, empty, `UsersForKeyQueryCache` fn new() -> Self { - UsersForKeyQuery { next_sequence_number: 0, user_map: HashMap::new() } + UsersForKeyQuery { + next_sequence_number: 0, + user_map: HashMap::new(), + tasks_awaiting_key_query: Vec::new(), + } } /// Record a new user that requires a key query @@ -159,6 +183,24 @@ impl UsersForKeyQuery { ) -> bool { let last_invalidation = self.user_map.get(user); + // if there were any jobs waiting for this key query to complete, we can flag + // them as completed and remove them from our list. + // we also clear out any tasks that have been cancelled. + self.tasks_awaiting_key_query.retain(|waiter| { + let Some(waiter) = waiter.upgrade() else { + // the TaskAwaitingKeyQuery has been dropped, so it probably timed out and the + // caller went away. We can remove it from our list whether or not it's for this + // user. + return false; + }; + if waiter.user == user && waiter.sequence_number.wrapping_sub(query_sequence) <= 0 { + waiter.completed.store(true, Ordering::Relaxed); + false + } else { + true + } + }); + if let Some(invalidation_sequence) = last_invalidation { Span::current().record("invalidation_sequence", invalidation_sequence); if invalidation_sequence.wrapping_sub(query_sequence) > 0 { @@ -182,6 +224,44 @@ impl UsersForKeyQuery { let sequence_number = self.next_sequence_number.wrapping_sub(1); (self.user_map.keys().cloned().collect(), sequence_number) } + + /// Check if a key query is pending for a user, and register for a wakeup if + /// so. + /// + /// If no key query is currently pending, returns `None`. Otherwise, returns + /// (an `Arc` to) a `KeysQueryWaiter`, whose `completed` flag will + /// be set once the lookup completes. 
+ fn maybe_register_waiting_task(&mut self, user: &UserId) -> Option> { + match self.user_map.get(user) { + None => None, + Some(&sequence_number) => { + let waiter = Arc::new(KeysQueryWaiter { + sequence_number, + user: user.to_owned(), + completed: AtomicBool::new(false), + }); + self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter)); + Some(waiter) + } + } + } +} + +/// Information on a task which is waiting for a `/keys/query` to complete. +#[derive(Debug)] +struct KeysQueryWaiter { + /// The user that we are waiting for + user: OwnedUserId, + + /// The sequence number of the last invalidation of the user's device list + /// when we started waiting (ie, any `/keys/query` result with the same or + /// greater sequence number will satisfy this waiter) + sequence_number: InvalidationSequenceNumber, + + /// Whether the `/keys/query` has completed. + /// + /// This is only modified whilst holding the mutex on `users_for_key_query`. + completed: AtomicBool, } #[derive(Default, Debug)] @@ -371,7 +451,8 @@ impl Store { inner: store, verification_machine, tracked_users_cache: DashSet::new().into(), - users_for_key_query: Mutex::new(UsersForKeyQuery::new()).into(), + users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()).into(), + users_for_key_query_condvar: Condvar::new().into(), tracked_users_loaded: AtomicBool::new(false).into(), tracked_user_loading_lock: Mutex::new(()).into(), } @@ -770,6 +851,8 @@ impl Store { } } self.inner.save_tracked_users(&store_updates).await?; + // wake up any tasks that may have been waiting for updates + self.users_for_key_query_condvar.notify_all(); Ok(()) } @@ -826,6 +909,42 @@ impl Store { Ok(self.users_for_key_query.lock().await.users_for_key_query()) } + /// Wait for a `/keys/query` response to be received if one is expected for + /// the given user.
+ /// + /// If the given timeout elapses, the method will stop waiting and return + /// `UserKeyQueryResult::TimeoutExpired` + pub async fn wait_if_user_key_query_pending( + &self, + timeout_duration: Duration, + user: &UserId, + ) -> UserKeyQueryResult { + let mut g = self.users_for_key_query.lock().await; + + let Some(w) = g.maybe_register_waiting_task(user) else { + return UserKeyQueryResult::WasNotPending; + }; + + let f1 = async { + while !w.completed.load(Ordering::Relaxed) { + g = self.users_for_key_query_condvar.wait(g).await; + } + }; + + match timeout(Box::pin(f1), timeout_duration).await { + Err(_) => { + warn!( + user_id = ?user, + "The user has a pending `/key/query` request which did \ + not finish yet, some devices might be missing." + ); + + UserKeyQueryResult::TimeoutExpired + } + _ => UserKeyQueryResult::WasPending, + } + } + /// See the docs for [`crate::OlmMachine::tracked_users()`]. pub async fn tracked_users(&self) -> Result> { self.load_tracked_users().await?;