diff --git a/Cargo.lock b/Cargo.lock index 53908908420..b1b63f849b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -727,6 +727,18 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg-vis" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a2c3bf5fc10fe2ca157564fbe08a4cb2b0a7d2ff3fe2f9683e65d5e7c7859c" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "chacha20" version = "0.8.2" @@ -2517,6 +2529,7 @@ dependencies = [ "backoff", "bytes", "bytesize", + "cfg-vis", "ctor", "dashmap", "dirs", @@ -4101,6 +4114,7 @@ dependencies = [ "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "winreg", @@ -5723,6 +5737,19 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasm-streams" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.61" diff --git a/crates/matrix-sdk-ui/src/timeline/futures.rs b/crates/matrix-sdk-ui/src/timeline/futures.rs new file mode 100644 index 00000000000..c42b7713082 --- /dev/null +++ b/crates/matrix-sdk-ui/src/timeline/futures.rs @@ -0,0 +1,69 @@ +use std::{ + fs, + future::{Future, IntoFuture}, + path::Path, + pin::Pin, +}; + +use eyeball::{shared::Observable as SharedObservable, Subscriber}; +use matrix_sdk::{attachment::AttachmentConfig, room::Room, TransmissionProgress}; +use mime::Mime; + +use super::{Error, Timeline}; + +pub struct SendAttachment<'a> { + timeline: &'a Timeline, + url: String, + mime_type: Mime, + config: AttachmentConfig, + pub(crate) send_progress: SharedObservable, +} + +impl<'a> SendAttachment<'a> { + pub(crate) fn new( + timeline: &'a Timeline, + url: String, + mime_type: Mime, + config: AttachmentConfig, + ) -> Self { + Self { timeline, url, mime_type, config, send_progress: Default::default() } + } + + /// Get a subscriber to observe the progress of sending the request + /// body. + #[cfg(not(target_arch = "wasm32"))] + pub fn subscribe_to_send_progress(&self) -> Subscriber { + self.send_progress.subscribe() + } +} + +impl<'a> IntoFuture for SendAttachment<'a> { + type Output = Result<(), Error>; + #[cfg(target_arch = "wasm32")] + type IntoFuture = Pin + 'a>>; + #[cfg(not(target_arch = "wasm32"))] + type IntoFuture = Pin + Send + 'a>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { timeline, url, mime_type, config, send_progress } = self; + Box::pin(async move { + let Room::Joined(room) = Room::from(timeline.room().clone()) else { + return Err(Error::RoomNotJoined); + }; + + let body = Path::new(&url) + .file_name() + .ok_or(Error::InvalidAttachmentFileName)? + .to_str() + .expect("path was created from UTF-8 string, hence filename part is UTF-8 too"); + let data = fs::read(&url).map_err(|_| Error::InvalidAttachmentData)?; + + room.send_attachment(body, &mime_type, data, config) + .with_send_progress_observable(send_progress) + .await + .map_err(|_| Error::FailedSendingAttachment)?; + + Ok(()) + }) + } +} diff --git a/crates/matrix-sdk-ui/src/timeline/mod.rs b/crates/matrix-sdk-ui/src/timeline/mod.rs index 659836e8936..ea570792d96 100644 --- a/crates/matrix-sdk-ui/src/timeline/mod.rs +++ b/crates/matrix-sdk-ui/src/timeline/mod.rs @@ -16,7 +16,7 @@ //! //! See [`Timeline`] for details. -use std::{fs, path::Path, pin::Pin, sync::Arc, task::Poll, time::Duration}; +use std::{pin::Pin, sync::Arc, task::Poll, time::Duration}; use async_std::sync::{Condvar, Mutex}; use eyeball_im::VectorDiff; @@ -47,6 +47,7 @@ use tracing::{debug, error, info, instrument, warn}; mod builder; mod event_handler; mod event_item; +mod futures; mod inner; mod pagination; mod read_receipts; @@ -70,6 +71,7 @@ pub use self::{ OtherState, Profile, ReactionGroup, RepliedToEvent, RoomMembershipChange, Sticker, TimelineDetails, TimelineItemContent, }, + futures::SendAttachment, pagination::{PaginationOptions, PaginationOutcome}, traits::RoomExt, virtual_item::VirtualTimelineItem, @@ -346,25 +348,13 @@ impl Timeline { /// * `config` - An attachment configuration object containing details about /// the attachment /// like a thumbnail, its size, duration etc. - pub async fn send_attachment( + pub fn send_attachment( &self, url: String, mime_type: Mime, config: AttachmentConfig, - ) -> Result<(), Error> { - let Room::Joined(room) = Room::from(self.room().clone()) else { - return Err(Error::RoomNotJoined); - }; - - let body = - Path::new(&url).file_name().ok_or(Error::InvalidAttachmentFileName)?.to_str().unwrap(); - let data = fs::read(&url).map_err(|_| Error::InvalidAttachmentData)?; - - room.send_attachment(body, &mime_type, data, config) - .await - .map_err(|_| Error::FailedSendingAttachment)?; - - Ok(()) + ) -> SendAttachment<'_> { + SendAttachment::new(self, url, mime_type, config) } /// Retry sending a message that previously failed to send. diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index 56f52f28699..6073135a272 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -16,12 +16,7 @@ features = ["docsrs"] rustdoc-args = ["--cfg", "docsrs"] [features] -default = [ - "e2e-encryption", - "automatic-room-key-forwarding", - "sqlite", - "native-tls", -] +default = ["e2e-encryption", "automatic-room-key-forwarding", "sqlite", "native-tls"] testing = [] e2e-encryption = [ @@ -47,16 +42,14 @@ appservice = ["ruma/appservice-api-s"] image-proc = ["dep:image"] image-rayon = ["image-proc", "image?/jpeg_rayon"] -experimental-sliding-sync = ["matrix-sdk-base/experimental-sliding-sync", "reqwest/gzip", "dep:eyeball-im-util"] - -docsrs = [ - "e2e-encryption", - "sqlite", - "sso-login", - "qrcode", - "image-proc", +experimental-sliding-sync = [ + "matrix-sdk-base/experimental-sliding-sync", + "reqwest/gzip", + "dep:eyeball-im-util", ] +docsrs = ["e2e-encryption", "sqlite", "sso-login", "qrcode", "image-proc"] + [dependencies] anyhow = { workspace = true, optional = true } anymap2 = "0.13.0" @@ -64,6 +57,7 @@ async-stream = { workspace = true } async-trait = { workspace = true } bytes = "1.1.0" bytesize = "1.1" +cfg-vis = "0.3.0" dashmap = { workspace = true } event-listener = "2.5.2" eyeball = { workspace = true } @@ -82,7 +76,6 @@ matrix-sdk-sqlite = { version = "0.1.0", path = "../matrix-sdk-sqlite", default- mime = "0.3.16" mime2ext = "0.1.52" rand = { version = "0.8.5", optional = true } -reqwest = { version = "0.11.10", default_features = false } ruma = { workspace = true, features = ["rand", "unstable-msc2448", "unstable-msc2965"] } serde = { workspace = true } serde_html_form = { workspace = true } @@ -116,10 +109,14 @@ optional = true [target.'cfg(target_arch = "wasm32")'.dependencies] gloo-timers = { version = "0.2.6", features = ["futures"] } +reqwest = { version = "0.11.10", default_features = false } tokio = { workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] backoff = { version = "0.4.0", features = ["tokio"] } +# only activate reqwest's stream feature on non-wasm, the wasm part seems to not +# support *sending* streams, which makes it useless for us. +reqwest = { version = "0.11.10", default_features = false, features = ["stream"] } tokio = { workspace = true, features = ["fs", "rt", "macros"] } [dev-dependencies] diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index 13bf03f9476..8a7223af77f 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -371,6 +371,7 @@ impl ClientBuilder { None, None, &[MatrixVersion::V1_0], + Default::default(), ) .await .map_err(|e| match e { diff --git a/crates/matrix-sdk/src/client/futures.rs b/crates/matrix-sdk/src/client/futures.rs new file mode 100644 index 00000000000..70df412d2be --- /dev/null +++ b/crates/matrix-sdk/src/client/futures.rs @@ -0,0 +1,99 @@ +use std::{ + fmt::Debug, + future::{Future, IntoFuture}, + pin::Pin, +}; + +use cfg_vis::cfg_vis; +use eyeball::shared::Observable as SharedObservable; +#[cfg(not(target_arch = "wasm32"))] +use eyeball::Subscriber; +use ruma::api::{client::error::ErrorKind, error::FromHttpResponseError, OutgoingRequest}; + +use super::super::Client; +use crate::{ + config::RequestConfig, + error::{HttpError, HttpResult}, + RefreshTokenError, TransmissionProgress, +}; + +/// `IntoFuture` returned by [`Client::send`]. +#[allow(missing_debug_implementations)] +pub struct SendRequest { + pub(crate) client: Client, + pub(crate) request: R, + pub(crate) config: Option, + pub(crate) send_progress: SharedObservable, +} + +impl SendRequest { + /// Replace the default `SharedObservable` used for tracking upload + /// progress. + /// + /// Note that any subscribers obtained from + /// [`subscribe_to_send_progress`][Self::subscribe_to_send_progress] + /// will be invalidated by this. + #[cfg_vis(target_arch = "wasm32", pub(crate))] + pub fn with_send_progress_observable( + mut self, + send_progress: SharedObservable, + ) -> Self { + self.send_progress = send_progress; + self + } + + /// Get a subscriber to observe the progress of sending the request + /// body. + #[cfg(not(target_arch = "wasm32"))] + pub fn subscribe_to_send_progress(&self) -> Subscriber { + self.send_progress.subscribe() + } +} + +impl IntoFuture for SendRequest +where + R: OutgoingRequest + Clone + Debug + Send + Sync + 'static, + R::IncomingResponse: Send + Sync, + HttpError: From>, +{ + type Output = HttpResult; + #[cfg(target_arch = "wasm32")] + type IntoFuture = Pin>>; + #[cfg(not(target_arch = "wasm32"))] + type IntoFuture = Pin + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { client, request, config, send_progress } = self; + Box::pin(async move { + let res = + Box::pin(client.send_inner(request.clone(), config, None, send_progress.clone())) + .await; + + // If this is an `M_UNKNOWN_TOKEN` error and refresh token handling is active, + // try to refresh the token and retry the request. + if client.inner.handle_refresh_tokens { + if let Err(Some(ErrorKind::UnknownToken { .. })) = + res.as_ref().map_err(HttpError::client_api_error_kind) + { + if let Err(refresh_error) = client.refresh_access_token().await { + match &refresh_error { + HttpError::RefreshToken(RefreshTokenError::RefreshTokenRequired) => { + // Refreshing access tokens is not supported + // by + // this `Session`, ignore. + } + _ => { + return Err(refresh_error); + } + } + } else { + return Box::pin(client.send_inner(request, config, None, send_progress)) + .await; + } + } + } + + res + }) + } +} diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index a8eeea5ae54..a2622b4cb36 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -25,7 +25,7 @@ use std::{ }; use dashmap::DashMap; -use eyeball::{unique::Observable, Subscriber}; +use eyeball::{shared::Observable as SharedObservable, unique::Observable, Subscriber}; use futures_core::Stream; use futures_util::StreamExt; use matrix_sdk_base::{ @@ -84,16 +84,18 @@ use crate::{ http_client::HttpClient, room, sync::{RoomUpdate, SyncResponse}, - Account, Error, Media, RefreshTokenError, Result, RumaApiError, + Account, Error, Media, RefreshTokenError, Result, RumaApiError, TransmissionProgress, }; mod builder; +mod futures; mod login_builder; #[cfg(feature = "sso-login")] pub use self::login_builder::SsoLoginBuilder; pub use self::{ builder::{ClientBuildError, ClientBuilder}, + futures::SendRequest, login_builder::LoginBuilder, }; @@ -1452,6 +1454,7 @@ impl Client { self.access_token().as_deref(), self.user_id(), self.server_versions().await?, + Default::default(), ) .await; @@ -1827,40 +1830,16 @@ impl Client { /// // returned /// # anyhow::Ok(()) }; /// ``` - pub async fn send( + pub fn send( &self, request: Request, config: Option, - ) -> HttpResult + ) -> SendRequest where Request: OutgoingRequest + Clone + Debug, HttpError: From>, { - let res = Box::pin(self.send_inner(request.clone(), config, None)).await; - - // If this is an `M_UNKNOWN_TOKEN` error and refresh token handling is active, - // try to refresh the token and retry the request. - if self.inner.handle_refresh_tokens { - if let Err(Some(ErrorKind::UnknownToken { .. })) = - res.as_ref().map_err(HttpError::client_api_error_kind) - { - if let Err(refresh_error) = self.refresh_access_token().await { - match &refresh_error { - HttpError::RefreshToken(RefreshTokenError::RefreshTokenRequired) => { - // Refreshing access tokens is not supported by - // this `Session`, ignore. - } - _ => { - return Err(refresh_error); - } - } - } else { - return Box::pin(self.send_inner(request, config, None)).await; - } - } - } - - res + SendRequest { client: self.clone(), request, config, send_progress: Default::default() } } #[cfg(feature = "experimental-sliding-sync")] @@ -1876,8 +1855,13 @@ impl Client { Request: OutgoingRequest + Clone + Debug, HttpError: From>, { - let res = - Box::pin(self.send_inner(request.clone(), config, sliding_sync_proxy.clone())).await; + let res = Box::pin(self.send_inner( + request.clone(), + config, + sliding_sync_proxy.clone(), + Default::default(), + )) + .await; // If this is an `M_UNKNOWN_TOKEN` error and refresh token handling is active, // try to refresh the token and retry the request. @@ -1896,7 +1880,13 @@ impl Client { } } } else { - return Box::pin(self.send_inner(request, config, sliding_sync_proxy)).await; + return Box::pin(self.send_inner( + request, + config, + sliding_sync_proxy, + Default::default(), + )) + .await; } } } @@ -1909,6 +1899,7 @@ impl Client { request: Request, config: Option, homeserver: Option, + send_progress: SharedObservable, ) -> HttpResult where Request: OutgoingRequest + Debug, @@ -1929,6 +1920,7 @@ impl Client { self.access_token().as_deref(), self.user_id(), self.server_versions().await?, + send_progress, ) .await; @@ -1957,6 +1949,7 @@ impl Client { None, None, &[MatrixVersion::V1_0], + Default::default(), ) .await? .known_versions() diff --git a/crates/matrix-sdk/src/encryption/futures.rs b/crates/matrix-sdk/src/encryption/futures.rs new file mode 100644 index 00000000000..0cd65a1f71b --- /dev/null +++ b/crates/matrix-sdk/src/encryption/futures.rs @@ -0,0 +1,90 @@ +use std::{ + future::{Future, IntoFuture}, + io::Read, + pin::Pin, +}; + +use cfg_vis::cfg_vis; +use eyeball::shared::Observable as SharedObservable; +#[cfg(not(target_arch = "wasm32"))] +use eyeball::Subscriber; + +use crate::{Client, Result, TransmissionProgress}; + +/// Future returned by [`Client::prepare_encrypted_file`]. +#[allow(missing_debug_implementations)] +pub struct PrepareEncryptedFile<'a, R: ?Sized> { + client: &'a Client, + content_type: &'a mime::Mime, + reader: &'a mut R, + send_progress: SharedObservable, +} + +impl<'a, R: ?Sized> PrepareEncryptedFile<'a, R> { + pub(crate) fn new(client: &'a Client, content_type: &'a mime::Mime, reader: &'a mut R) -> Self { + Self { client, content_type, reader, send_progress: Default::default() } + } + + /// Replace the default `SharedObservable` used for tracking upload + /// progress. + /// + /// Note that any subscribers obtained from + /// [`subscribe_to_send_progress`][Self::subscribe_to_send_progress] + /// will be invalidated by this. + #[cfg_vis(target_arch = "wasm32", pub(crate))] + pub fn with_send_progress_observable( + mut self, + send_progress: SharedObservable, + ) -> Self { + self.send_progress = send_progress; + self + } + + /// Get a subscriber to observe the progress of sending the request + /// body. + #[cfg(not(target_arch = "wasm32"))] + pub fn subscribe_to_send_progress(&self) -> Subscriber { + self.send_progress.subscribe() + } +} + +impl<'a, R> IntoFuture for PrepareEncryptedFile<'a, R> +where + R: Read + Send + ?Sized + 'a, +{ + type Output = Result; + #[cfg(target_arch = "wasm32")] + type IntoFuture = Pin + 'a>>; + #[cfg(not(target_arch = "wasm32"))] + type IntoFuture = Pin + Send + 'a>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { client, content_type, reader, send_progress } = self; + Box::pin(async move { + let mut encryptor = matrix_sdk_base::crypto::AttachmentEncryptor::new(reader); + + let mut buf = Vec::new(); + encryptor.read_to_end(&mut buf)?; + + let response = client + .media() + .upload(content_type, buf) + .with_send_progress_observable(send_progress) + .await?; + + let file: ruma::events::room::EncryptedFile = { + let keys = encryptor.finish(); + ruma::events::room::EncryptedFileInit { + url: response.content_uri, + key: keys.key, + iv: keys.iv, + hashes: keys.hashes, + v: keys.version, + } + .into() + }; + + Ok(file) + }) + } +} diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index 0a428f08d8e..9c090834bcb 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -16,8 +16,6 @@ #![doc = include_str!("../docs/encryption.md")] #![cfg_attr(target_arch = "wasm32", allow(unused_imports))] -pub mod identities; -pub mod verification; use std::{ collections::{BTreeMap, HashSet}, io::{Read, Write}, @@ -25,16 +23,8 @@ use std::{ path::PathBuf, }; +use eyeball::shared::Observable as SharedObservable; use futures_util::stream::{self, StreamExt}; -pub use matrix_sdk_base::crypto::{ - olm::{ - SessionCreationError as MegolmSessionCreationError, - SessionExportError as OlmSessionExportError, - }, - vodozemac, CrossSigningStatus, CryptoStoreError, DecryptorError, EventError, KeyExportError, - LocalTrust, MediaEncryptionInfo, MegolmError, OlmError, RoomKeyImportResult, SecretImportError, - SessionCreationError, SignatureError, VERSION, -}; use matrix_sdk_base::crypto::{OlmMachine, OutgoingRequest, RoomMessageRequest, ToDeviceRequest}; use ruma::{ api::client::{ @@ -53,7 +43,6 @@ use ruma::{ use tokio::sync::RwLockReadGuard; use tracing::{debug, instrument, trace, warn}; -pub use crate::error::RoomKeyImportError; use crate::{ attachment::{AttachmentInfo, Thumbnail}, encryption::{ @@ -61,9 +50,26 @@ use crate::{ verification::{SasVerification, Verification, VerificationRequest}, }, error::HttpResult, - room, Client, Error, Result, + room, Client, Error, Result, TransmissionProgress, }; +mod futures; +pub mod identities; +pub mod verification; + +pub use matrix_sdk_base::crypto::{ + olm::{ + SessionCreationError as MegolmSessionCreationError, + SessionExportError as OlmSessionExportError, + }, + vodozemac, CrossSigningStatus, CryptoStoreError, DecryptorError, EventError, KeyExportError, + LocalTrust, MediaEncryptionInfo, MegolmError, OlmError, RoomKeyImportResult, SecretImportError, + SessionCreationError, SignatureError, VERSION, +}; + +pub use self::futures::PrepareEncryptedFile; +pub use crate::error::RoomKeyImportError; + impl Client { pub(crate) async fn olm_machine(&self) -> RwLockReadGuard<'_, Option> { self.base_client().olm_machine().await @@ -137,31 +143,12 @@ impl Client { /// room.send(CustomEventContent { encrypted_file }, None).await?; /// # anyhow::Ok(()) }; /// ``` - pub async fn prepare_encrypted_file<'a, R: Read + ?Sized + 'a>( - &self, - content_type: &mime::Mime, + pub fn prepare_encrypted_file<'a, R: Read + ?Sized + 'a>( + &'a self, + content_type: &'a mime::Mime, reader: &'a mut R, - ) -> Result { - let mut encryptor = matrix_sdk_base::crypto::AttachmentEncryptor::new(reader); - - let mut buf = Vec::new(); - encryptor.read_to_end(&mut buf)?; - - let response = self.media().upload(content_type, buf).await?; - - let file: ruma::events::room::EncryptedFile = { - let keys = encryptor.finish(); - ruma::events::room::EncryptedFileInit { - url: response.content_uri, - key: keys.key, - iv: keys.iv, - hashes: keys.hashes, - v: keys.version, - } - .into() - }; - - Ok(file) + ) -> PrepareEncryptedFile<'a, R> { + PrepareEncryptedFile::new(self, content_type, reader) } /// Encrypt and upload the file to be read from `reader` and construct an @@ -173,11 +160,16 @@ impl Client { data: Vec, info: Option, thumbnail: Option, + send_progress: SharedObservable, ) -> Result { + // FIXME: Upload the thumbnail in parallel with the main file let (thumbnail_source, thumbnail_info) = if let Some(thumbnail) = thumbnail { let mut cursor = Cursor::new(thumbnail.data); - let file = self.prepare_encrypted_file(content_type, &mut cursor).await?; + let file = self + .prepare_encrypted_file(content_type, &mut cursor) + .with_send_progress_observable(send_progress.clone()) + .await?; use ruma::events::room::ThumbnailInfo; #[rustfmt::skip] @@ -192,7 +184,10 @@ impl Client { }; let mut cursor = Cursor::new(data); - let file = self.prepare_encrypted_file(content_type, &mut cursor).await?; + let file = self + .prepare_encrypted_file(content_type, &mut cursor) + .with_send_progress_observable(send_progress) + .await?; use std::io::Cursor; diff --git a/crates/matrix-sdk/src/http_client.rs b/crates/matrix-sdk/src/http_client.rs index a1fa2af9afa..39e263128be 100644 --- a/crates/matrix-sdk/src/http_client.rs +++ b/crates/matrix-sdk/src/http_client.rs @@ -25,6 +25,7 @@ use std::{ use async_trait::async_trait; use bytes::{Bytes, BytesMut}; use bytesize::ByteSize; +use eyeball::shared::Observable as SharedObservable; use matrix_sdk_common::AsyncTraitDeps; use ruma::{ api::{ @@ -57,12 +58,16 @@ pub trait HttpSend: AsyncTraitDeps { /// `Request`. /// /// * `timeout` - A timeout for the full request > response cycle. + /// /// # Examples /// /// ``` /// use std::time::Duration; /// - /// use matrix_sdk::{async_trait, bytes::Bytes, HttpError, HttpSend}; + /// use eyeball::shared::Observable as SharedObservable; + /// use matrix_sdk::{ + /// async_trait, bytes::Bytes, HttpError, HttpSend, TransmissionProgress, + /// }; /// /// #[derive(Debug)] /// struct Client(reqwest::Client); @@ -83,6 +88,7 @@ pub trait HttpSend: AsyncTraitDeps { /// &self, /// request: http::Request, /// timeout: Duration, + /// _send_progress: SharedObservable, /// ) -> Result, HttpError> { /// Ok(self /// .response_to_http_response( @@ -98,6 +104,7 @@ pub trait HttpSend: AsyncTraitDeps { &self, request: http::Request, timeout: Duration, + send_progress: SharedObservable, ) -> Result, HttpError>; } @@ -170,27 +177,60 @@ impl HttpClient { &self, request: http::Request, config: RequestConfig, + send_progress: SharedObservable, ) -> Result where R: OutgoingRequest + Debug, HttpError: From>, { - // There's a bunch of state here, factor out a pinned inner future to - // reduce this size of futures that await this function. + // There's a bunch of state in send_request_with_retries, factor + // out a pinned inner future to reduce this size of futures that + // await this function. #[cfg(not(target_arch = "wasm32"))] - let (status_code, response_size, response) = Box::pin(async move { - use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; - use ruma::api::client::error::{ - ErrorBody as ClientApiErrorBody, ErrorKind as ClientApiErrorKind, - }; + let response = + Box::pin(self.send_request_with_retries::(config, request, send_progress)).await?; + + #[cfg(target_arch = "wasm32")] + let response = { + let response = self.inner.send_request(request, config.timeout, send_progress).await?; + + let status_code = response.status(); + let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); + tracing::Span::current() + .record("status", status_code.as_u16()) + .record("response_size", response_size.to_string_as(true)); + + R::IncomingResponse::try_from_http_response(response)? + }; + + Ok(response) + } + + #[cfg(not(target_arch = "wasm32"))] + async fn send_request_with_retries( + &self, + config: RequestConfig, + request: http::Request, + send_progress: SharedObservable, + ) -> Result + where + R: OutgoingRequest + Debug, + HttpError: From>, + { + use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; + use ruma::api::client::error::{ + ErrorBody as ClientApiErrorBody, ErrorKind as ClientApiErrorKind, + }; - use crate::RumaApiError; + use crate::RumaApiError; - let backoff = - ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() }; - let retry_count = AtomicU64::new(1); + let backoff = + ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() }; + let retry_count = AtomicU64::new(1); - let send_request = || async { + let send_request = || { + let send_progress = send_progress.clone(); + async { let stop = if let Some(retry_limit) = config.retry_limit { retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit } else { @@ -233,42 +273,27 @@ impl HttpClient { let response = self .inner - .send_request(clone_request(&request), config.timeout) + .send_request(clone_request(&request), config.timeout, send_progress) .await .map_err(error_type)?; let status_code = response.status(); - let response_size = response.body().len(); - - let response = R::IncomingResponse::try_from_http_response(response) - .map_err(|e| error_type(HttpError::from(e)))?; - - Ok((status_code, response_size, response)) - }; - - retry::<_, HttpError, _, _, _>(backoff, send_request).await - }) - .await?; - - #[cfg(target_arch = "wasm32")] - let (status_code, response_size, response) = { - let response = self.inner.send_request(request, config.timeout).await?; - let status_code = response.status(); - let response_size = response.body().len(); + let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); + tracing::Span::current() + .record("status", status_code.as_u16()) + .record("response_size", response_size.to_string_as(true)); - (status_code, response_size, R::IncomingResponse::try_from_http_response(response)?) + R::IncomingResponse::try_from_http_response(response) + .map_err(|e| error_type(HttpError::from(e))) + } }; - let response_bytesize = ByteSize(response_size.try_into().unwrap_or(u64::MAX)); - tracing::Span::current() - .record("status", status_code.as_u16()) - .record("response_size", response_bytesize.to_string_as(true)); - - Ok(response) + retry::<_, HttpError, _, _, _>(backoff, send_request).await } + #[allow(clippy::too_many_arguments)] #[instrument( - skip(self, access_token, config, request, user_id), + skip(self, access_token, config, request, user_id, send_progress), fields( config, path, @@ -288,6 +313,7 @@ impl HttpClient { access_token: Option<&str>, user_id: Option<&UserId>, server_versions: &[MatrixVersion], + send_progress: SharedObservable, ) -> Result where R: OutgoingRequest + Debug, @@ -350,7 +376,7 @@ impl HttpClient { }; debug!("Sending request"); - match self.send_request::(request, config).await { + match self.send_request::(request, config, send_progress).await { Ok(response) => { debug!("Got response"); Ok(response) @@ -417,6 +443,15 @@ impl HttpSettings { } } +/// Progress of sending or receiving a payload. +#[derive(Clone, Copy, Debug, Default)] +pub struct TransmissionProgress { + /// How many bytes were already transferred. + pub current: usize, + /// How many bytes there are in total. + pub total: usize, +} + // Clones all request parts except the extensions which can't be cloned. // See also https://github.com/hyperium/http/issues/395 #[cfg(not(target_arch = "wasm32"))] @@ -455,18 +490,108 @@ impl HttpSend for reqwest::Client { &self, request: http::Request, _timeout: Duration, + _send_progress: SharedObservable, ) -> Result, HttpError> { - #[allow(unused_mut)] - let mut request = reqwest::Request::try_from(request)?; - - // reqwest's timeout functionality is not available on WASM #[cfg(not(target_arch = "wasm32"))] - { + let request = { + use std::convert::Infallible; + + use futures_util::stream; + + let mut request = if _send_progress.subscriber_count() != 0 { + _send_progress.update(|p| p.total += request.body().len()); + reqwest::Request::try_from(request.map(|body| { + let chunks = stream::iter(BytesChunks::new(body, 8192).map( + move |chunk| -> Result<_, Infallible> { + _send_progress.update(|p| p.current += chunk.len()); + Ok(chunk) + }, + )); + reqwest::Body::wrap_stream(chunks) + }))? + } else { + reqwest::Request::try_from(request)? + }; + *request.timeout_mut() = Some(_timeout); - } + request + }; + + // Both reqwest::Body::wrap_stream and the timeout functionality are + // not available on WASM + #[cfg(target_arch = "wasm32")] + let request = reqwest::Request::try_from(request)?; let response = self.execute(request).await?; Ok(response_to_http_response(response).await?) } } + +#[cfg(not(target_arch = "wasm32"))] +struct BytesChunks { + bytes: Bytes, + size: usize, +} + +#[cfg(not(target_arch = "wasm32"))] +impl BytesChunks { + fn new(bytes: Bytes, size: usize) -> Self { + assert_ne!(size, 0); + Self { bytes, size } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl Iterator for BytesChunks { + type Item = Bytes; + + fn next(&mut self) -> Option { + use std::mem; + + if self.bytes.is_empty() { + None + } else if self.bytes.len() < self.size { + Some(mem::take(&mut self.bytes)) + } else { + Some(self.bytes.split_to(self.size)) + } + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use bytes::Bytes; + + use super::BytesChunks; + + #[test] + fn bytes_chunks() { + let bytes = Bytes::new(); + assert!(BytesChunks::new(bytes, 1).collect::>().is_empty()); + + let bytes = Bytes::from_iter([1, 2]); + assert_eq!(BytesChunks::new(bytes, 2).collect::>(), [Bytes::from_iter([1, 2])]); + + let bytes = Bytes::from_iter([1, 2]); + assert_eq!(BytesChunks::new(bytes, 3).collect::>(), [Bytes::from_iter([1, 2])]); + + let bytes = Bytes::from_iter([1, 2, 3]); + assert_eq!( + BytesChunks::new(bytes, 1).collect::>(), + [Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])] + ); + + let bytes = Bytes::from_iter([1, 2, 3]); + assert_eq!( + BytesChunks::new(bytes, 2).collect::>(), + [Bytes::from_iter([1, 2]), Bytes::from_iter([3])] + ); + + let bytes = Bytes::from_iter([1, 2, 3, 4]); + assert_eq!( + BytesChunks::new(bytes, 2).collect::>(), + [Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])] + ); + } +} diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index a7d59fb1598..ab82a357f1d 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -51,11 +51,13 @@ pub mod encryption; pub use account::Account; #[cfg(feature = "sso-login")] pub use client::SsoLoginBuilder; -pub use client::{Client, ClientBuildError, ClientBuilder, LoginBuilder, LoopCtrl, UnknownToken}; +pub use client::{ + Client, ClientBuildError, ClientBuilder, LoginBuilder, LoopCtrl, SendRequest, UnknownToken, +}; #[cfg(feature = "image-proc")] pub use error::ImageError; pub use error::{Error, HttpError, HttpResult, RefreshTokenError, Result, RumaApiError}; -pub use http_client::HttpSend; +pub use http_client::{HttpSend, TransmissionProgress}; pub use media::Media; pub use ruma::{IdParseError, OwnedServerName, ServerName}; #[cfg(feature = "experimental-sliding-sync")] diff --git a/crates/matrix-sdk/src/media.rs b/crates/matrix-sdk/src/media.rs index 66d80aa8890..e626dbdc417 100644 --- a/crates/matrix-sdk/src/media.rs +++ b/crates/matrix-sdk/src/media.rs @@ -21,6 +21,7 @@ use std::io::Read; use std::path::Path; use std::time::Duration; +use eyeball::shared::Observable as SharedObservable; pub use matrix_sdk_base::media::*; use mime::Mime; #[cfg(not(target_arch = "wasm32"))] @@ -38,7 +39,7 @@ use tokio::{fs::File as TokioFile, io::AsyncWriteExt}; use crate::{ attachment::{AttachmentInfo, Thumbnail}, - Client, Result, + Client, Result, SendRequest, TransmissionProgress, }; /// A conservative upload speed of 1Mbps @@ -74,6 +75,9 @@ impl MediaFileHandle { } } +/// `IntoFuture` returned by [`Media::upload`]. +pub type SendUploadRequest = SendRequest; + impl Media { pub(crate) fn new(client: Client) -> Self { Self { client } @@ -106,11 +110,7 @@ impl Media { /// println!("Cat URI: {}", response.content_uri); /// # anyhow::Ok(()) }; /// ``` - pub async fn upload( - &self, - content_type: &Mime, - data: Vec, - ) -> Result { + pub fn upload(&self, content_type: &Mime, data: Vec) -> SendUploadRequest { let timeout = std::cmp::max( Duration::from_secs(data.len() as u64 / DEFAULT_UPLOAD_SPEED), MIN_UPLOAD_REQUEST_TIMEOUT, @@ -121,7 +121,7 @@ impl Media { }); let request_config = self.client.request_config().timeout(timeout); - Ok(self.client.send(request, Some(request_config)).await?) + self.client.send(request, Some(request_config)) } /// Gets a media file by copying it to a temporary location on disk. @@ -400,9 +400,14 @@ impl Media { data: Vec, info: Option, thumbnail: Option, + send_progress: SharedObservable, ) -> Result { + // FIXME: Upload the thumbnail in parallel with the main file let (thumbnail_source, thumbnail_info) = if let Some(thumbnail) = thumbnail { - let response = self.upload(&thumbnail.content_type, thumbnail.data).await?; + let response = self + .upload(&thumbnail.content_type, thumbnail.data) + .with_send_progress_observable(send_progress.clone()) + .await?; let url = response.content_uri; use ruma::events::room::ThumbnailInfo; @@ -416,7 +421,8 @@ impl Media { (None, None) }; - let response = self.upload(content_type, data).await?; + let response = + self.upload(content_type, data).with_send_progress_observable(send_progress).await?; let url = response.content_uri; diff --git a/crates/matrix-sdk/src/room/joined/futures.rs b/crates/matrix-sdk/src/room/joined/futures.rs new file mode 100644 index 00000000000..ff10e2fa54b --- /dev/null +++ b/crates/matrix-sdk/src/room/joined/futures.rs @@ -0,0 +1,144 @@ +#[cfg(feature = "image-proc")] +use std::io::Cursor; +use std::{ + future::{Future, IntoFuture}, + pin::Pin, +}; + +use eyeball::shared::Observable as SharedObservable; +use mime::Mime; +use ruma::api::client::message::send_message_event; +use tracing::{Instrument, Span}; + +use super::Joined; +use crate::{attachment::AttachmentConfig, Result, TransmissionProgress}; +#[cfg(feature = "image-proc")] +use crate::{ + attachment::{generate_image_thumbnail, Thumbnail}, + error::ImageError, +}; + +#[allow(missing_debug_implementations)] +pub struct SendAttachment<'a> { + room: &'a Joined, + body: &'a str, + content_type: &'a Mime, + data: Vec, + config: AttachmentConfig, + tracing_span: Span, + send_progress: SharedObservable, +} + +impl<'a> SendAttachment<'a> { + pub(crate) fn new( + room: &'a Joined, + body: &'a str, + content_type: &'a Mime, + data: Vec, + config: AttachmentConfig, + ) -> Self { + Self { + room, + body, + content_type, + data, + config, + tracing_span: Span::current(), + send_progress: Default::default(), + } + } + + /// Replace the default `SharedObservable` used for tracking upload + /// progress. + /// + /// Note that any subscribers obtained from + /// [`subscribe_to_send_progress`][Self::subscribe_to_send_progress] + /// will be invalidated by this. + #[cfg(not(target_arch = "wasm32"))] + pub fn with_send_progress_observable( + mut self, + send_progress: SharedObservable, + ) -> Self { + self.send_progress = send_progress; + self + } +} + +impl<'a> IntoFuture for SendAttachment<'a> { + type Output = Result; + #[cfg(target_arch = "wasm32")] + type IntoFuture = Pin + 'a>>; + #[cfg(not(target_arch = "wasm32"))] + type IntoFuture = Pin + Send + 'a>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { room, body, content_type, data, config, tracing_span, send_progress } = self; + let fut = async move { + if config.thumbnail.is_some() { + room.prepare_and_send_attachment(body, content_type, data, config, send_progress) + .await + } else { + #[cfg(not(feature = "image-proc"))] + let thumbnail = None; + + #[cfg(feature = "image-proc")] + let data_slot; + #[cfg(feature = "image-proc")] + let (data, thumbnail) = if config.generate_thumbnail { + let content_type = content_type.clone(); + let make_thumbnail = move |data| { + let res = generate_image_thumbnail( + &content_type, + Cursor::new(&data), + config.thumbnail_size, + ); + (data, res) + }; + + #[cfg(not(target_arch = "wasm32"))] + let (data, res) = tokio::task::spawn_blocking(move || make_thumbnail(data)) + .await + .expect("Task join error"); + + #[cfg(target_arch = "wasm32")] + let (data, res) = make_thumbnail(data); + + let thumbnail = match res { + Ok((thumbnail_data, thumbnail_info)) => { + data_slot = thumbnail_data; + Some(Thumbnail { + data: data_slot, + content_type: mime::IMAGE_JPEG, + info: Some(thumbnail_info), + }) + } + Err( + ImageError::ThumbnailBiggerThanOriginal + | ImageError::FormatNotSupported, + ) => None, + Err(error) => return Err(error.into()), + }; + + (data, thumbnail) + } else { + (data, None) + }; + + let config = AttachmentConfig { + txn_id: config.txn_id, + info: config.info, + thumbnail, + #[cfg(feature = "image-proc")] + generate_thumbnail: false, + #[cfg(feature = "image-proc")] + thumbnail_size: None, + }; + + room.prepare_and_send_attachment(body, content_type, data, config, send_progress) + .await + } + }; + + Box::pin(fut.instrument(tracing_span)) + } +} diff --git a/crates/matrix-sdk/src/room/joined.rs b/crates/matrix-sdk/src/room/joined/mod.rs similarity index 93% rename from crates/matrix-sdk/src/room/joined.rs rename to crates/matrix-sdk/src/room/joined/mod.rs index 8fe6f034696..372b8b360b8 100644 --- a/crates/matrix-sdk/src/room/joined.rs +++ b/crates/matrix-sdk/src/room/joined/mod.rs @@ -1,9 +1,8 @@ -#[cfg(feature = "image-proc")] -use std::io::Cursor; #[cfg(feature = "e2e-encryption")] use std::sync::Arc; use std::{borrow::Borrow, ops::Deref}; +use eyeball::shared::Observable as SharedObservable; #[cfg(feature = "e2e-encryption")] use matrix_sdk_base::RoomMemberships; use matrix_sdk_common::instant::{Duration, Instant}; @@ -47,14 +46,13 @@ use crate::{ attachment::AttachmentConfig, error::{Error, HttpResult}, room::Common, - BaseRoom, Client, Result, RoomState, -}; -#[cfg(feature = "image-proc")] -use crate::{ - attachment::{generate_image_thumbnail, Thumbnail}, - error::ImageError, + BaseRoom, Client, Result, RoomState, TransmissionProgress, }; +mod futures; + +pub use self::futures::SendAttachment; + const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4); const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3); @@ -711,73 +709,14 @@ impl Joined { /// [`upload()`]: crate::Media::upload /// [`send()`]: Joined::send #[instrument(skip_all)] - pub async fn send_attachment( - &self, - body: &str, - content_type: &Mime, + pub fn send_attachment<'a>( + &'a self, + body: &'a str, + content_type: &'a Mime, data: Vec, config: AttachmentConfig, - ) -> Result { - if config.thumbnail.is_some() { - self.prepare_and_send_attachment(body, content_type, data, config).await - } else { - #[cfg(not(feature = "image-proc"))] - let thumbnail = None; - - #[cfg(feature = "image-proc")] - let data_slot; - #[cfg(feature = "image-proc")] - let (data, thumbnail) = if config.generate_thumbnail { - let content_type = content_type.clone(); - let make_thumbnail = move |data| { - let res = generate_image_thumbnail( - &content_type, - Cursor::new(&data), - config.thumbnail_size, - ); - (data, res) - }; - - #[cfg(not(target_arch = "wasm32"))] - let (data, res) = tokio::task::spawn_blocking(move || make_thumbnail(data)) - .await - .expect("Task join error"); - - #[cfg(target_arch = "wasm32")] - let (data, res) = make_thumbnail(data); - - let thumbnail = match res { - Ok((thumbnail_data, thumbnail_info)) => { - data_slot = thumbnail_data; - Some(Thumbnail { - data: data_slot, - content_type: mime::IMAGE_JPEG, - info: Some(thumbnail_info), - }) - } - Err( - ImageError::ThumbnailBiggerThanOriginal | ImageError::FormatNotSupported, - ) => None, - Err(error) => return Err(error.into()), - }; - - (data, thumbnail) - } else { - (data, None) - }; - - let config = AttachmentConfig { - txn_id: config.txn_id, - info: config.info, - thumbnail, - #[cfg(feature = "image-proc")] - generate_thumbnail: false, - #[cfg(feature = "image-proc")] - thumbnail_size: None, - }; - - self.prepare_and_send_attachment(body, content_type, data, config).await - } + ) -> SendAttachment<'a> { + SendAttachment::new(self, body, content_type, data, config) } /// Prepare and send an attachment to this room. @@ -802,12 +741,13 @@ impl Joined { /// media. /// /// * `config` - Metadata and configuration for the attachment. - async fn prepare_and_send_attachment( - &self, - body: &str, - content_type: &Mime, + async fn prepare_and_send_attachment<'a>( + &'a self, + body: &'a str, + content_type: &'a Mime, data: Vec, config: AttachmentConfig, + send_progress: SharedObservable, ) -> Result { #[cfg(feature = "e2e-encryption")] let content = if self.is_encrypted().await? { @@ -818,12 +758,20 @@ impl Joined { data, config.info, config.thumbnail, + send_progress, ) .await? } else { self.client .media() - .prepare_attachment_message(body, content_type, data, config.info, config.thumbnail) + .prepare_attachment_message( + body, + content_type, + data, + config.info, + config.thumbnail, + send_progress, + ) .await? }; @@ -831,7 +779,14 @@ impl Joined { let content = self .client .media() - .prepare_attachment_message(body, content_type, data, config.info, config.thumbnail) + .prepare_attachment_message( + body, + content_type, + data, + config.info, + config.thumbnail, + send_progress, + ) .await?; self.send(RoomMessageEventContent::new(content), config.txn_id.as_deref()).await