From b31a534c09d2719e1239dbb32bbb7554cb37a8a4 Mon Sep 17 00:00:00 2001
From: def <def@hai.li>
Date: Fri, 1 Dec 2023 13:13:08 +0800
Subject: [PATCH] feat: add host_object shared_buffer implement and example

---
 examples/host_object.rs   | 165 ++++++++++++++++++++++++++++++++++++++
 examples/shared_buffer.rs | 145 +++++++++++++++++++++++++++++++++
 src/lib.rs                |  53 +++++++++++-
 src/webview2/mod.rs       |  52 ++++++++++++
 4 files changed, 414 insertions(+), 1 deletion(-)
 create mode 100644 examples/host_object.rs
 create mode 100644 examples/shared_buffer.rs

diff --git a/examples/host_object.rs b/examples/host_object.rs
new file mode 100644
index 0000000000..c9c14a7833
--- /dev/null
+++ b/examples/host_object.rs
@@ -0,0 +1,165 @@
+// Copyright 2020-2023 Tauri Programme within The Commons Conservancy
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: MIT
+
+use tao::{
+  event::{Event, WindowEvent},
+  event_loop::{ControlFlow, EventLoop},
+  window::WindowBuilder,
+};
+use wry::WebViewBuilder;
+
+#[cfg(target_os = "windows")]
+fn main() -> wry::Result<()> {
+
+  let event_loop = EventLoop::new();
+  let window = WindowBuilder::new().build(&event_loop).unwrap();
+
+  use wry::WebViewExtWindows;
+  use windows::{
+    core::{w, BSTR},
+    Win32::{
+      System::Variant::{
+        VARIANT, VARIANT_0, VARIANT_0_0, VARIANT_0_0_0, VARENUM,
+        VT_BSTR, VT_I4, VT_DISPATCH,
+      },
+      System::Com::{
+        IDispatch, IDispatch_Impl, ITypeInfo,
+        DISPATCH_FLAGS, DISPPARAMS, EXCEPINFO,
+      }
+    }
+  };
+  use std::mem::ManuallyDrop;
+
+  // This is a simple usage example. add_host_object_to_script is a mapping of the native [addHostObjectToScript](https://learn.microsoft.com/en-us/microsoft-edge/webview2/reference/win32/icorewebview2#addhostobjecttoscript) method of webview2. It requires manual creation of hostobject and memory management. Please use it with caution.
+  struct Variant(VARIANT);
+  impl Variant {
+      pub fn new(num: VARENUM, contents: VARIANT_0_0_0) -> Variant {
+          Variant {
+              0: VARIANT {
+                  Anonymous: VARIANT_0 {
+                      Anonymous: ManuallyDrop::new(VARIANT_0_0 {
+                          vt: num,
+                          wReserved1: 0,
+                          wReserved2: 0,
+                          wReserved3: 0,
+                          Anonymous: contents,
+                      }),
+                  },
+              },
+          }
+      }   
+  }
+  impl From<String> for Variant {
+    fn from(value: String) -> Variant { Variant::new(
+      VT_BSTR,
+      VARIANT_0_0_0 { 
+        bstrVal: ManuallyDrop::new(BSTR::from(value)) 
+      }
+    ) }
+  }
+  impl From<&str> for Variant {
+    fn from(value: &str) -> Variant { Variant::from(value.to_string()) }
+  }
+  impl From<i32> for Variant {
+    fn from(value: i32) -> Variant { Variant::new(VT_I4, VARIANT_0_0_0 { lVal: value }) }
+  }
+  impl From<std::mem::ManuallyDrop<::core::option::Option<IDispatch>>> for Variant {
+    fn from(value: std::mem::ManuallyDrop<::core::option::Option<IDispatch>>) -> Variant { Variant::new(VT_DISPATCH, VARIANT_0_0_0 { pdispVal: value }) }
+  }
+  impl Drop for Variant {
+    fn drop(&mut self) {
+        match VARENUM(unsafe { self.0.Anonymous.Anonymous.vt.0 }) {
+            VT_BSTR => unsafe {
+                drop(&mut &self.0.Anonymous.Anonymous.Anonymous.bstrVal)
+            } 
+            _ => {}
+        }
+        unsafe { drop(&mut self.0.Anonymous.Anonymous) }
+    }
+  }
+  #[windows::core::implement(IDispatch)]
+  struct FunctionWithStringArgument;
+  impl IDispatch_Impl for FunctionWithStringArgument {
+    #![allow(non_snake_case)]
+    fn GetTypeInfoCount(&self) -> windows::core::Result<u32> {Ok(0)}
+    fn GetTypeInfo(&self, _itinfo: u32, _lcid: u32) -> windows::core::Result<ITypeInfo> {Err(windows::core::Error::new(windows::Win32::Foundation
+  ::E_FAIL, "GetTypeInfo Error \t\n\r".into()))}
+    fn GetIDsOfNames(&self, _riid: *const ::windows::core::GUID, _rgsznames: *const ::windows::core::PCWSTR, _cnames: u32, _lcid: u32, _rgdispid: *mut i32) -> windows::core::Result<()> {Ok(())}
+    fn Invoke(
+      &self,
+      _dispidmember: i32,
+      _riid: *const windows::core::GUID,
+      _lcid: u32,
+      _wflags: DISPATCH_FLAGS,
+      pdispparams: *const DISPPARAMS,
+      pvarresult: *mut VARIANT,
+      _pexcepinfo: *mut EXCEPINFO,
+      _puargerr: *mut u32
+    ) -> windows::core::Result<()> {
+      let pdispparams = unsafe { *pdispparams };
+      let rgvarg = unsafe { &*(pdispparams.rgvarg) };
+      let rgvarg_0_0 = unsafe { &rgvarg.Anonymous.Anonymous };
+      unsafe { dbg!(&rgvarg_0_0.Anonymous.bstrVal); }
+      let b_str_val = unsafe { &rgvarg_0_0.Anonymous.bstrVal.to_string() };
+      dbg!(b_str_val);
+
+      let pvarresult_0_0 = unsafe { &mut (*pvarresult).Anonymous.Anonymous };
+      pvarresult_0_0.vt = VT_BSTR;
+      pvarresult_0_0.Anonymous.bstrVal = ManuallyDrop::new(BSTR::from(format!(r#"Successful sync call functionWithStringArgument, and the argument is "{}"."#, b_str_val).to_string())) ;
+      Ok(())
+    }
+  }
+  let mut i32_variant =  Variant::from(1234);
+  let mut string_variant =  Variant::from("string variant");
+  let mut function_with_string_argument_variant =  Variant::from(ManuallyDrop::new(Some(IDispatch::from(FunctionWithStringArgument))));
+
+  let webview = WebViewBuilder::new(&window)
+    .with_url("https://tauri.app")?
+    .with_initialization_script(r#"
+      alert(chrome.webview.hostObjects.sync.i32)
+      alert(chrome.webview.hostObjects.sync.string)
+      alert(chrome.webview.hostObjects.sync.functionWithStringArgument("hi"))
+    "#)
+    .build()?;
+
+  unsafe {
+    let _ = webview.add_host_object_to_script(w!("i32"), &mut i32_variant.0);
+    let _ = webview.add_host_object_to_script(w!("string"), &mut string_variant.0);
+    let _ = webview.add_host_object_to_script(w!("functionWithStringArgument"), &mut function_with_string_argument_variant.0);
+  }
+
+  event_loop.run(move |event, _, control_flow| {
+    *control_flow = ControlFlow::Wait;
+
+    if let Event::WindowEvent {
+      event: WindowEvent::CloseRequested,
+      ..
+    } = event
+    {
+      *control_flow = ControlFlow::Exit
+    }
+  });
+}
+
+#[cfg(not(target_os = "windows"))]
+fn main() -> wry::Result<()> {
+  let event_loop = EventLoop::new();
+  let window = WindowBuilder::new().build(&event_loop).unwrap();
+
+  let webview = WebViewBuilder::new(&window)
+    .with_url("https://tauri.app")?
+    .build()?;
+
+  event_loop.run(move |event, _, control_flow| {
+    *control_flow = ControlFlow::Wait;
+
+    if let Event::WindowEvent {
+      event: WindowEvent::CloseRequested,
+      ..
+    } = event
+    {
+      *control_flow = ControlFlow::Exit
+    }
+  }); 
+}
\ No newline at end of file
diff --git a/examples/shared_buffer.rs b/examples/shared_buffer.rs
new file mode 100644
index 0000000000..1a3bebce0e
--- /dev/null
+++ b/examples/shared_buffer.rs
@@ -0,0 +1,145 @@
+// Copyright 2020-2023 Tauri Programme within The Commons Conservancy
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: MIT
+
+use tao::{
+  event::{Event, WindowEvent},
+  event_loop::ControlFlow,
+  window::WindowBuilder,
+};
+use wry::WebViewBuilder;
+
+// Currently, only Windows platforms support shared_buffer.
+#[cfg(target_os = "windows")]
+fn main() -> wry::Result<()> {
+  use wry::WebViewExtWindows;
+
+  enum UserEvent {
+    InitSharedBuffer,
+    PingSharedBuffer,
+  }
+
+  let event_loop = tao::event_loop::EventLoopBuilder::<UserEvent>::with_user_event().build();
+  let proxy = event_loop.create_proxy();
+  let window = WindowBuilder::new().build(&event_loop).unwrap();
+
+  let webview = WebViewBuilder::new(&window)
+    .with_url("https://tauri.app")?
+    .with_ipc_handler(move |req: String| match req.as_str() {
+      "initSharedBuffer" => { let _ = proxy.send_event(UserEvent::InitSharedBuffer); }
+      "pingSharedBuffer" => { let _ = proxy.send_event(UserEvent::PingSharedBuffer); }
+      _ => {}
+    })
+    .with_initialization_script(r#";(function() {
+      function writeStringIntoSharedBuffer(string, sharedBuffer, pathPtr) {
+        const path = new TextEncoder().encode(string)
+        const pathLen = path.length
+        const pathArray = new Uint8Array(sharedBuffer, pathPtr, pathLen*8)
+        for(let i = 0; i < pathLen; i++) {
+          pathArray[i] = path[i]
+        }
+        return [pathPtr, pathLen]
+      }
+
+      const sharedBufferReceivedHandler = e => {
+        window.chrome.webview.removeEventListener("sharedbufferreceived", sharedBufferReceivedHandler);
+
+        alert(JSON.stringify(e.additionalData))
+
+        var sharedBuffer = e.getBuffer()
+        console.log(sharedBuffer)
+        window.sharedBuffer = sharedBuffer
+
+        // JS write
+        writeStringIntoSharedBuffer("I'm JS.", sharedBuffer, 0)
+
+        window.ipc.postMessage('pingSharedBuffer');
+      }
+      window.chrome.webview.addEventListener("sharedbufferreceived", sharedBufferReceivedHandler);
+      window.ipc.postMessage('initSharedBuffer');
+    })();"#)
+    .build()?;
+
+  // The Webview2 developer tools include a memory inspector, which makes it easy to debug memory issues.
+  webview.open_devtools();
+
+  let mut shared_buffer: Option<webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2SharedBuffer> = None;
+
+  event_loop.run(move |event, _, control_flow| {
+    *control_flow = ControlFlow::Wait;
+    match event {
+      Event::WindowEvent {
+        event: WindowEvent::CloseRequested,
+        ..
+      } => {
+        *control_flow = ControlFlow::Exit
+      },
+
+      Event::UserEvent(e) => match e {
+        UserEvent::InitSharedBuffer => {
+          // Memory obtained through webview2 must be manually managed. Use it with care.
+          shared_buffer = Some(unsafe { webview.create_shared_buffer(1024) }.unwrap());
+          if let Some(shared_buffer) = &shared_buffer {
+            dbg!(shared_buffer);
+            let _ = unsafe {
+              webview.post_shared_buffer_to_script(
+                shared_buffer,
+                webview2_com::Microsoft::Web::WebView2::Win32::COREWEBVIEW2_SHARED_BUFFER_ACCESS_READ_WRITE,
+                windows::core::w!(r#"{"jsonkey":"jsonvalue"}"#)
+              )
+            };
+          }
+        },
+        UserEvent::PingSharedBuffer => {
+          if let Some(shared_buffer) = &shared_buffer {
+            let mut ptr: *mut u8 = &mut 0u8;
+            let _ = unsafe { shared_buffer.Buffer(&mut ptr) };
+
+            // Rust read
+            let len = 8; // align to 4
+            let read_string: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
+            let read_string = std::str::from_utf8(&read_string).unwrap();
+            dbg!(read_string);
+
+            // Rust write
+            let mut vec = String::from("I'm Rust.").into_bytes();
+            unsafe { std::ptr::copy((&mut vec).as_mut_ptr(), ptr.offset(len as isize), 9) };
+
+            let _ = webview.evaluate_script(r#";(function() {
+              // JS read
+              alert(
+                new TextDecoder()
+                .decode(new Uint8Array(window.sharedBuffer, 8, 9))
+              )
+            })()"#);
+          }
+        }
+      },
+
+      _ => (),
+    }
+  });
+}
+
+// Non-Windows systems do not yet support shared_buffer.
+#[cfg(not(target_os = "windows"))]
+fn main() -> wry::Result<()> {
+  let event_loop = tao::event_loop::EventLoop::new();
+  let window = WindowBuilder::new().build(&event_loop).unwrap();
+
+  let _ = WebViewBuilder::new(&window)
+    .with_url("https://tauri.app")?
+    .build()?;
+
+  event_loop.run(move |event, _, control_flow| {
+    *control_flow = ControlFlow::Wait;
+
+    if let Event::WindowEvent {
+      event: WindowEvent::CloseRequested,
+      ..
+    } = event
+    {
+      *control_flow = ControlFlow::Exit
+    }
+  }); 
+}
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index aa4297d93f..85d1a086cb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -210,7 +210,6 @@ pub(crate) mod webview2;
 use self::webview2::*;
 #[cfg(target_os = "windows")]
 use webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2Controller;
-
 use std::{borrow::Cow, path::PathBuf, rc::Rc};
 
 use http::{Request, Response};
@@ -1404,6 +1403,29 @@ pub trait WebViewExtWindows {
   /// [1]: https://learn.microsoft.com/en-us/dotnet/api/microsoft.web.webview2.core.corewebview2memoryusagetargetlevel
   /// [2]: https://learn.microsoft.com/en-us/dotnet/api/microsoft.web.webview2.core.corewebview2.memoryusagetargetlevel?view=webview2-dotnet-1.0.2088.41#remarks
   fn set_memory_usage_level(&self, level: MemoryUsageLevel);
+
+  unsafe fn add_host_object_to_script<P0>(
+    &self,
+    name: P0,
+    object: *mut ::windows::Win32::System::Variant::VARIANT,
+  ) -> ::windows::core::Result<()>
+  where
+      P0: ::windows::core::IntoParam<::windows::core::PCWSTR>;
+
+  unsafe fn create_shared_buffer(
+    &self,
+    size: u64,
+  ) -> ::windows::core::Result<webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2SharedBuffer>;
+
+  unsafe fn post_shared_buffer_to_script<P0, P1>(
+    &self,
+    sharedbuffer: P0,
+    access: webview2_com::Microsoft::Web::WebView2::Win32::COREWEBVIEW2_SHARED_BUFFER_ACCESS,
+    additionaldataasjson: P1, 
+  ) -> ::windows::core::Result<()>
+  where
+    P0: ::windows::core::IntoParam<webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2SharedBuffer>,
+    P1: ::windows::core::IntoParam<::windows::core::PCWSTR>;
 }
 
 #[cfg(target_os = "windows")]
@@ -1419,6 +1441,35 @@ impl WebViewExtWindows for WebView {
   fn set_memory_usage_level(&self, level: MemoryUsageLevel) {
     self.webview.set_memory_usage_level(level);
   }
+
+  unsafe fn add_host_object_to_script<P0>(
+    &self,
+    name: P0,
+    object: *mut ::windows::Win32::System::Variant::VARIANT,
+  ) -> ::windows::core::Result<()> where P0: ::windows::core::IntoParam<::windows::core::PCWSTR>,
+  {
+    self.webview.add_host_object_to_script(name, object)
+  }
+
+  unsafe fn create_shared_buffer(
+    &self,
+    size: u64,
+  ) -> ::windows::core::Result<webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2SharedBuffer> {
+    self.webview.create_shared_buffer(size)
+  }
+
+  unsafe fn post_shared_buffer_to_script<P0, P1>(
+    &self,
+    sharedbuffer: P0,
+    access: webview2_com::Microsoft::Web::WebView2::Win32::COREWEBVIEW2_SHARED_BUFFER_ACCESS,
+    additionaldataasjson: P1, 
+  ) -> ::windows::core::Result<()>
+  where
+    P0: ::windows::core::IntoParam<webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2SharedBuffer>,
+    P1: ::windows::core::IntoParam<::windows::core::PCWSTR>,
+  {
+    self.webview.post_shared_buffer_to_script(sharedbuffer, access, additionaldataasjson)
+  }
 }
 
 /// Additional methods on `WebView` that are specific to Linux.
diff --git a/src/webview2/mod.rs b/src/webview2/mod.rs
index cdb53bd4af..37cbe02955 100644
--- a/src/webview2/mod.rs
+++ b/src/webview2/mod.rs
@@ -1092,6 +1092,58 @@ impl InnerWebView {
     let level = COREWEBVIEW2_MEMORY_USAGE_TARGET_LEVEL(level);
     let _ = unsafe { webview.SetMemoryUsageTargetLevel(level) };
   }
+
+  pub unsafe fn add_host_object_to_script<P0>(
+    &self,
+    name: P0,
+    object: *mut ::windows::Win32::System::Variant::VARIANT,
+  ) -> ::windows::core::Result<()>
+  where
+      P0: ::windows::core::IntoParam<::windows::core::PCWSTR>,
+  {
+    match self.webview.cast::<ICoreWebView2_19>() {
+      Ok(webview) => {
+        webview.AddHostObjectToScript(name, object)
+      },
+      Err(error) => {
+        Err(error)
+      }
+    } 
+  }
+
+  pub unsafe fn create_shared_buffer(
+    &self,
+    size: u64,
+  ) -> ::windows::core::Result<ICoreWebView2SharedBuffer> {
+    match self.env.cast::<ICoreWebView2Environment12>() {
+      Ok(env) => {
+        env.CreateSharedBuffer(size)
+      },
+      Err(error) => {
+        Err(error)
+      }
+    } 
+  }
+
+  pub unsafe fn post_shared_buffer_to_script<P0, P1>(
+    &self,
+    sharedbuffer: P0,
+    access: webview2_com::Microsoft::Web::WebView2::Win32::COREWEBVIEW2_SHARED_BUFFER_ACCESS,
+    additionaldataasjson: P1, 
+  ) -> ::windows::core::Result<()>
+  where
+    P0: ::windows::core::IntoParam<webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2SharedBuffer>,
+    P1: ::windows::core::IntoParam<::windows::core::PCWSTR>,
+  {
+    match self.webview.cast::<ICoreWebView2_19>() {
+      Ok(webview) => {
+        webview.PostSharedBufferToScript(sharedbuffer, access, additionaldataasjson)
+      },
+      Err(error) => {
+        Err(error)
+      }
+    }
+  }
 }
 
 unsafe fn prepare_web_request_response(