From 2efe56235a9bafc278e537d511d2bfad751ab7cf Mon Sep 17 00:00:00 2001
From: Andrew Plaza <github@andrewplaza.dev>
Date: Wed, 4 Dec 2024 14:46:47 -0500
Subject: [PATCH] feat(wasm): unblock streams in the browser

---
 Cargo.lock                       |   1 +
 xmtp_api_http/Cargo.toml         |   1 +
 xmtp_api_http/src/http_stream.rs | 129 +++++++++++++++++++++++++++++++
 xmtp_api_http/src/lib.rs         |   4 +-
 xmtp_api_http/src/util.rs        |  83 +-------------------
 xmtp_mls/src/subscriptions.rs    |   1 +
 6 files changed, 136 insertions(+), 83 deletions(-)
 create mode 100644 xmtp_api_http/src/http_stream.rs

diff --git a/Cargo.lock b/Cargo.lock
index 8d5913404..fae157cfe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7248,6 +7248,7 @@ version = "0.1.0"
 dependencies = [
  "async-stream",
  "async-trait",
+ "bytes",
  "futures",
  "reqwest 0.12.9",
  "serde",
diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml
index b26a414a9..dae2a490c 100644
--- a/xmtp_api_http/Cargo.toml
+++ b/xmtp_api_http/Cargo.toml
@@ -18,6 +18,7 @@ thiserror = "2.0"
 tokio = { workspace = true, features = ["sync", "rt", "macros"] }
 xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
 async-trait = "0.1"
+bytes = "1.9"
 
 [dev-dependencies]
 xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] }
diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs
new file mode 100644
index 000000000..0a5f83014
--- /dev/null
+++ b/xmtp_api_http/src/http_stream.rs
@@ -0,0 +1,129 @@
+//! Streams that work with HTTP POST requests
+
+use crate::util::GrpcResponse;
+use futures::{
+    stream::{self, Stream, StreamExt},
+    Future,
+};
+use reqwest::Response;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use serde_json::Deserializer;
+use std::pin::Pin;
+use xmtp_proto::{Error, ErrorKind};
+
+#[derive(Deserialize, Serialize, Debug)]
+pub(crate) struct SubscriptionItem<T> {
+    pub result: T,
+}
+
+enum HttpPostStream<F>
+where
+    F: Future<Output = Result<Response, reqwest::Error>>,
+{
+    NotStarted(F),
+    // NotStarted(Box<dyn Future<Output = Result<Response, Error>>>),
+    Started(Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin + Send>>),
+}
+
+impl<F> Stream for HttpPostStream<F>
+where
+    F: Future<Output = Result<Response, reqwest::Error>> + Unpin,
+{
+    type Item = Result<bytes::Bytes, reqwest::Error>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        use futures::task::Poll::*;
+        use HttpPostStream::*;
+        match self.as_mut().get_mut() {
+            NotStarted(ref mut f) => {
+                tracing::info!("Polling");
+                let f = std::pin::pin!(f);
+                match f.poll(cx) {
+                    Ready(response) => {
+                        let s = response.unwrap().bytes_stream();
+                        self.set(Self::Started(Box::pin(s.boxed())));
+                        self.poll_next(cx)
+                    }
+                    Pending => {
+                        // cx.waker().wake_by_ref();
+                        Pending
+                    }
+                }
+            }
+            Started(s) => s.poll_next_unpin(cx),
+        }
+    }
+}
+
+#[cfg(target_arch = "wasm32")]
+pub fn create_grpc_stream<
+    T: Serialize + Send + 'static,
+    R: DeserializeOwned + Send + std::fmt::Debug + 'static,
+>(
+    request: T,
+    endpoint: String,
+    http_client: reqwest::Client,
+) -> stream::LocalBoxStream<'static, Result<R, Error>> {
+    create_grpc_stream_inner(request, endpoint, http_client).boxed_local()
+}
+
+#[cfg(not(target_arch = "wasm32"))]
+pub fn create_grpc_stream<
+    T: Serialize + Send + 'static,
+    R: DeserializeOwned + Send + std::fmt::Debug + 'static,
+>(
+    request: T,
+    endpoint: String,
+    http_client: reqwest::Client,
+) -> stream::BoxStream<'static, Result<R, Error>> {
+    create_grpc_stream_inner(request, endpoint, http_client).boxed()
+}
+
+pub fn create_grpc_stream_inner<
+    T: Serialize + Send + 'static,
+    R: DeserializeOwned + Send + std::fmt::Debug + 'static,
+>(
+    request: T,
+    endpoint: String,
+    http_client: reqwest::Client,
+) -> impl Stream<Item = Result<R, Error>> {
+    let request = http_client.post(endpoint).json(&request).send();
+    let http_stream = HttpPostStream::NotStarted(request);
+
+    async_stream::stream! {
+        tracing::info!("spawning grpc http stream");
+        let mut remaining = vec![];
+        for await bytes in http_stream {
+            let bytes = bytes
+                .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?;
+            let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
+            let de = Deserializer::from_slice(bytes);
+            let mut stream = de.into_iter::<GrpcResponse<R>>();
+            'messages: loop {
+                tracing::debug!("Waiting on next response ...");
+                let response = stream.next();
+                let res = match response {
+                    Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
+                    Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
+                    Some(Ok(GrpcResponse::Err(e))) => {
+                        Err(Error::new(ErrorKind::MlsError).with(e.message))
+                    }
+                    Some(Err(e)) => {
+                        if e.is_eof() {
+                            remaining = (&**bytes)[stream.byte_offset()..].to_vec();
+                            break 'messages;
+                        } else {
+                            Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
+                        }
+                    }
+                    Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
+                    None => break 'messages,
+                };
+                yield res;
+            }
+        }
+    }
+}
diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs
index 80489fb3c..8a3f972c4 100755
--- a/xmtp_api_http/src/lib.rs
+++ b/xmtp_api_http/src/lib.rs
@@ -1,11 +1,13 @@
 #![warn(clippy::unwrap_used)]
 
 pub mod constants;
+mod http_stream;
 mod util;
 
 use futures::stream;
+use http_stream::create_grpc_stream;
 use reqwest::header;
-use util::{create_grpc_stream, handle_error};
+use util::handle_error;
 use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient};
 use xmtp_proto::xmtp::identity::api::v1::{
     GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request,
diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs
index 8a839fc56..34c878c4a 100644
--- a/xmtp_api_http/src/util.rs
+++ b/xmtp_api_http/src/util.rs
@@ -1,9 +1,5 @@
-use futures::{
-    stream::{self, StreamExt},
-    Stream,
-};
+use crate::http_stream::SubscriptionItem;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use serde_json::Deserializer;
 use std::io::Read;
 use xmtp_proto::{Error, ErrorKind};
 
@@ -23,11 +19,6 @@ pub(crate) struct ErrorResponse {
     details: Vec<String>,
 }
 
-#[derive(Deserialize, Serialize, Debug)]
-pub(crate) struct SubscriptionItem<T> {
-    pub result: T,
-}
-
 /// handle JSON response from gRPC, returning either
 /// the expected deserialized response object or a gRPC [`Error`]
 pub fn handle_error<R: Read, T>(reader: R) -> Result<T, Error>
@@ -43,78 +34,6 @@ where
     }
 }
 
-#[cfg(target_arch = "wasm32")]
-pub fn create_grpc_stream<
-    T: Serialize + Send + 'static,
-    R: DeserializeOwned + Send + std::fmt::Debug + 'static,
->(
-    request: T,
-    endpoint: String,
-    http_client: reqwest::Client,
-) -> stream::LocalBoxStream<'static, Result<R, Error>> {
-    create_grpc_stream_inner(request, endpoint, http_client).boxed_local()
-}
-
-#[cfg(not(target_arch = "wasm32"))]
-pub fn create_grpc_stream<
-    T: Serialize + Send + 'static,
-    R: DeserializeOwned + Send + std::fmt::Debug + 'static,
->(
-    request: T,
-    endpoint: String,
-    http_client: reqwest::Client,
-) -> stream::BoxStream<'static, Result<R, Error>> {
-    create_grpc_stream_inner(request, endpoint, http_client).boxed()
-}
-
-pub fn create_grpc_stream_inner<
-    T: Serialize + Send + 'static,
-    R: DeserializeOwned + Send + std::fmt::Debug + 'static,
->(
-    request: T,
-    endpoint: String,
-    http_client: reqwest::Client,
-) -> impl Stream<Item = Result<R, Error>> {
-    async_stream::stream! {
-        let request = http_client
-                .post(endpoint)
-                .json(&request)
-                .send()
-                .await
-                .map_err(|e| Error::new(ErrorKind::MlsError).with(e))?;
-
-        let mut remaining = vec![];
-        for await bytes in request.bytes_stream() {
-            let bytes = bytes
-                .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?;
-            let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
-            let de = Deserializer::from_slice(bytes);
-            let mut stream = de.into_iter::<GrpcResponse<R>>();
-            'messages: loop {
-                let response = stream.next();
-                let res = match response {
-                    Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
-                    Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
-                    Some(Ok(GrpcResponse::Err(e))) => {
-                        Err(Error::new(ErrorKind::MlsError).with(e.message))
-                    }
-                    Some(Err(e)) => {
-                        if e.is_eof() {
-                            remaining = (&**bytes)[stream.byte_offset()..].to_vec();
-                            break 'messages;
-                        } else {
-                            Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
-                        }
-                    }
-                    Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
-                    None => break 'messages,
-                };
-                yield res;
-            }
-        }
-    }
-}
-
 #[cfg(feature = "test-utils")]
 #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
 #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs
index d74331ba2..fa8326832 100644
--- a/xmtp_mls/src/subscriptions.rs
+++ b/xmtp_mls/src/subscriptions.rs
@@ -592,6 +592,7 @@ pub(crate) mod tests {
         let alice_group = alice
             .create_group(None, GroupMetadataOptions::default())
             .unwrap();
+        tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id));
 
         alice_group
             .add_members_by_inbox_id(&[bob.inbox_id()])