Skip to content

Commit

Permalink
Dynamic casting to COM implementation #3055
Browse files Browse the repository at this point in the history
This provides a new feature for COM developers using the windows-rs crate.
It allows for safe dynamic casting from IUnknown to an implementation object.
It is based on Rust's Any trait.

Any type that is marked with #[implement], except for those that contain
non-static lifetimes, can be used with dynamic casting.

Example:

```rust
struct MyApp { ... }

fn main() {
    let my_app = ComObject::new(MyApp { ... });
    let iunknown: IUnknown = my_app.to_interface();
    do_stuff(&iunknown);
}

fn do_stuff(unknown: &IUnknown) -> Result<()> {
    let my_app: ComObject<MyApp> = unknown.cast_object()?;
    my_app.internal_method();
    Ok(())
}
```
  • Loading branch information
Arlie Davis committed May 24, 2024
1 parent 40d35fa commit 692c4b6
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 7 deletions.
23 changes: 23 additions & 0 deletions crates/libs/core/src/com_object.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::imp::Box;
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef};
use core::any::Any;
use core::borrow::Borrow;
use core::ops::Deref;
use core::ptr::NonNull;
Expand Down Expand Up @@ -196,6 +197,28 @@ impl<T: ComObjectInner> ComObject<T> {
I::from_raw(raw)
}
}

/// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer"
/// object, e.g. `MyApp_Impl`, not the inner `MyApp` object.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// The returned value is an owned (counted) reference; this function calls `AddRef` on the
/// underlying COM object. If you do not need an owned reference, then you can use the
/// [`Interface::cast_object_ref`] method instead, and avoid the cost of `AddRef` / `Release`.
pub fn cast_from<I>(interface: &I) -> crate::Result<Self>
where
I: Interface,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
interface.cast_object()
}
}

impl<T: ComObjectInner + Default> Default for ComObject<T> {
Expand Down
128 changes: 126 additions & 2 deletions crates/libs/core/src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::*;
use core::any::Any;
use core::ffi::c_void;
use core::marker::PhantomData;
use core::mem::{forget, transmute_copy};
use core::mem::{forget, transmute_copy, MaybeUninit};
use core::ptr::NonNull;

/// Provides low-level access to an interface vtable.
Expand Down Expand Up @@ -97,7 +98,7 @@ pub unsafe trait Interface: Sized + Clone {
//
// This guards against implementations of COM interfaces which may store non-null values
// in 'result' but still return E_NOINTERFACE.
let mut result = core::mem::MaybeUninit::<Option<T>>::zeroed();
let mut result = MaybeUninit::<Option<T>>::zeroed();
self.query(&T::IID, result.as_mut_ptr() as _).ok()?;

// If we get here, then query() has succeeded, but we still need to double-check
Expand All @@ -110,6 +111,123 @@ pub unsafe trait Interface: Sized + Clone {
}
}

/// This casts the given COM interface to [`&dyn Any`].
///
/// Applications should generally _not_ call this method directly. Instead, use the
/// [`Interface::cast_object_ref`] or [`Interface::cast_object`] methods.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// # Safety
///
/// **IMPORTANT!!** This uses a non-standard protocol for QueryInterface! The `DYNAMIC_CAST_IID`
/// IID identifies this protocol, but there is no `IDynamicCast` interface. Instead, objects
/// that recognize `DYNAMIC_CAST_IID` simply store their `&dyn Any` directly at the interface
/// pointer that was passed to `QueryInterface. This means that the returned value has a
/// size that is twice as large (`size_of::<&dyn Any>() == 2 * size_of::<*const c_void>()`).
///
/// This means that callers that use this protocol cannot simply pass `&mut ptr` for
/// an ordinary single-pointer-sized pointer. Only this method understands this protocol.
///
/// Another part of this protocol is that the implementation of `QueryInterface` _does not_
/// AddRef the object. The caller must guarantee the liveness of the COM object. In Rust,
/// this means tying the lifetime of the IUnknown* that we used for the QueryInterface
/// call to the lifetime of the returned `&dyn Any` value.
///
/// This method preserves type safety and relies on these invariants:
///
/// * All `QueryInterface` implementations that recognize `DYNAMIC_CAST_IID` are generated by
/// the `#[implement]` macro and respect the rules described here.
#[inline(always)]
fn cast_to_any<T>(&self) -> Result<&dyn Any>
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
unsafe {
let mut any_ref_arg: MaybeUninit<&dyn Any> = MaybeUninit::zeroed();
self.query(&DYNAMIC_CAST_IID, any_ref_arg.as_mut_ptr() as *mut *mut c_void).ok()?;
Ok(any_ref_arg.assume_init())
}
}

/// Returns `true` if the given COM interface refers to an implementation of `T`.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `false`.
#[inline(always)]
fn is_object<T>(&self) -> bool
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
if let Ok(any) = self.cast_to_any::<T>() {
any.is::<T::Outer>()
} else {
false
}
}

/// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer"
/// object, e.g. `&MyApp_Impl`, not the inner `&MyApp` object.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// The returned value is borrowed. If you need an owned (counted) reference, then use
/// [`Interface::cast_object`].
#[inline(always)]
fn cast_object_ref<T>(&self) -> Result<&T::Outer>
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
let any: &dyn Any = self.cast_to_any::<T>()?;
if let Some(outer) = any.downcast_ref::<T::Outer>() {
Ok(outer)
} else {
Err(imp::E_NOINTERFACE.into())
}
}

/// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer"
/// object, e.g. `MyApp_Impl`, not the inner `MyApp` object.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// The returned value is an owned (counted) reference; this function calls `AddRef` on the
/// underlying COM object. If you do not need an owned reference, then you can use the
/// [`Interface::cast_object_ref`] method instead, and avoid the cost of `AddRef` / `Release`.
#[inline(always)]
fn cast_object<T>(&self) -> Result<ComObject<T>>
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
let object_ref = self.cast_object_ref::<T>()?;
Ok(object_ref.to_object())
}

/// Attempts to create a [`Weak`] reference to this object.
fn downgrade(&self) -> Result<Weak<Self>> {
self.cast::<imp::IWeakReferenceSource>().and_then(|source| Weak::downgrade(&source))
Expand Down Expand Up @@ -210,3 +328,9 @@ impl<'a, I: Interface> core::ops::Deref for InterfaceRef<'a, I> {
unsafe { core::mem::transmute(self) }
}
}

/// This IID identifies a special protocol, used by [`Interface::cast_to_any`]. This is _not_
/// an ordinary COM interface; it uses special lifetime rules and a larger interface pointer.
/// See the comments on [`Interface::cast_to_any`].
#[doc(hidden)]
pub const DYNAMIC_CAST_IID: GUID = GUID::from_u128(0xae49d5cb_143f_431c_874c_2729336e4eca);
15 changes: 15 additions & 0 deletions crates/libs/core/src/unknown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ pub trait IUnknownImpl {
{
<Self as ComObjectInterface<I>>::as_interface_ref(self).to_owned()
}

/// Creates a new owned reference to this object.
///
/// # Safety
///
/// This function can only be safely called by `<Foo>_Impl` objects that are embedded in a
/// `ComObject`. Since we only allow safe Rust code to access these objects using a `ComObject`
/// or a `&<Foo>_Impl` that points within a `ComObject`, this is safe.
fn to_object(&self) -> ComObject<Self::Impl>
where
Self::Impl: ComObjectInner<Outer = Self>;

/// The distance from the start of `<Foo>_Impl` to the `this` field within it, measured in
/// pointer-sized elements. The `this` field contains the `MyApp` instance.
const INNER_OFFSET_IN_POINTERS: usize;
}

impl IUnknown_Vtbl {
Expand Down
41 changes: 37 additions & 4 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
let original_type2 = original_type.clone();
let original_type2 = syn::parse_macro_input!(original_type2 as syn::ItemStruct);
let vis = &original_type2.vis;
let original_ident = original_type2.ident;
let original_ident = &original_type2.ident;
let mut constraints = quote! {};

if let Some(where_clause) = original_type2.generics.where_clause {
if let Some(where_clause) = &original_type2.generics.where_clause {
where_clause.predicates.to_tokens(&mut constraints);
}

Expand Down Expand Up @@ -83,6 +83,25 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
});

// Dynamic casting requires that the object not contain non-static lifetimes.
let enable_dyn_casting = original_type2.generics.lifetimes().count() == 0;
let dynamic_cast_query = if enable_dyn_casting {
quote! {
else if *iid == ::windows_core::DYNAMIC_CAST_IID {
// DYNAMIC_CAST_IID is special. We _do not_ increase the reference count for this pseudo-interface.
// Also, instead of returning an interface pointer, we simply write the `&dyn Any` directly to the
// 'interface' pointer. Since the size of `&dyn Any` is 2 pointers, not one, the caller must be
// prepared for this. This is not a normal QueryInterface call.
//
// See the `Interface::cast_to_any` method, which is the only caller that should use DYNAMIC_CAST_ID.
(interface as *mut *const dyn core::any::Any).write(self as &dyn ::core::any::Any as *const dyn ::core::any::Any);
return ::windows_core::HRESULT(0);
}
}
} else {
quote!()
};

// The distance from the beginning of the generated type to the 'this' field, in units of pointers (not bytes).
let offset_of_this_in_pointers = 1 + attributes.implement.len();
let offset_of_this_in_pointers_token = proc_macro2::Literal::usize_unsuffixed(offset_of_this_in_pointers);
Expand Down Expand Up @@ -201,7 +220,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
|| iid == &<::windows_core::IInspectable as ::windows_core::Interface>::IID
|| iid == &<::windows_core::imp::IAgileObject as ::windows_core::Interface>::IID {
&self.identity as *const _ as *mut _
} #(#queries)* else {
}
#(#queries)*
#dynamic_cast_query
else {
::core::ptr::null_mut()
};

Expand Down Expand Up @@ -230,7 +252,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
unsafe fn Release(self_: *mut Self) -> u32 {
let remaining = (*self_).count.release();
if remaining == 0 {
_ = ::windows_core::imp::Box::from_raw(self_ as *const Self as *mut Self);
_ = ::windows_core::imp::Box::from_raw(self_);
}
remaining
}
Expand All @@ -247,6 +269,17 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
&*((inner as *const Self::Impl as *const *const ::core::ffi::c_void)
.sub(#offset_of_this_in_pointers_token) as *const Self)
}

fn to_object(&self) -> ::windows_core::ComObject<Self::Impl> {
self.count.add_ref();
unsafe {
::windows_core::ComObject::from_raw(
::core::ptr::NonNull::new_unchecked(self as *const Self as *mut Self)
)
}
}

const INNER_OFFSET_IN_POINTERS: usize = #offset_of_this_in_pointers_token;
}

impl #generics #original_ident::#generics where #constraints {
Expand Down
45 changes: 44 additions & 1 deletion crates/tests/implement_core/src/com_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::borrow::Borrow;
use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use std::sync::Arc;
use windows_core::{
implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, InterfaceRef,
implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, Interface, InterfaceRef,
};

#[interface("818f2fd1-d479-4398-b286-a93c4c7904d1")]
Expand All @@ -19,8 +19,12 @@ unsafe trait IBar: IUnknown {
fn say_hello(&self);
}

const APP_SIGNATURE: [u8; 8] = *b"cafef00d";

#[implement(IFoo, IBar)]
struct MyApp {
// We use signature to verify field offsets for dynamic casts
signature: [u8; 8],
x: u32,
tombstone: Arc<Tombstone>,
}
Expand Down Expand Up @@ -63,6 +67,7 @@ impl core::fmt::Display for MyApp {
impl Default for MyApp {
fn default() -> Self {
Self {
signature: APP_SIGNATURE,
x: 0,
tombstone: Arc::new(Tombstone::default()),
}
Expand Down Expand Up @@ -109,6 +114,7 @@ impl MyApp {
fn new(x: u32) -> ComObject<Self> {
ComObject::new(Self {
x,
signature: APP_SIGNATURE,
tombstone: Arc::new(Tombstone::default()),
})
}
Expand Down Expand Up @@ -333,6 +339,43 @@ fn from_inner_ref() {
unsafe { ibar.say_hello() };
}

#[test]
fn to_object() {
let app = MyApp::new(42);
let tombstone = app.tombstone.clone();
let app_outer: &MyApp_Impl = &app;

let second_app = app_outer.to_object();
assert!(!tombstone.is_dead());
assert_eq!(second_app.signature, APP_SIGNATURE);

println!("x = {}", unsafe { second_app.get_x() });

drop(second_app);
assert!(!tombstone.is_dead());

drop(app);
assert!(tombstone.is_dead());
}

#[test]
fn dynamic_cast() {
let app = MyApp::new(42);
let unknown = app.to_interface::<IUnknown>();

assert!(!unknown.is_object::<SendableThing>());
assert!(unknown.is_object::<MyApp>());

let dyn_app_ref: &MyApp_Impl = unknown.cast_object_ref::<MyApp>().unwrap();
assert_eq!(dyn_app_ref.signature, APP_SIGNATURE);

let dyn_app_owned: ComObject<MyApp> = unknown.cast_object().unwrap();
assert_eq!(dyn_app_owned.signature, APP_SIGNATURE);

let dyn_app_owned_2: ComObject<MyApp> = ComObject::cast_from(&unknown).unwrap();
assert_eq!(dyn_app_owned_2.signature, APP_SIGNATURE);
}

// This tests that we can place a type that is not Send in a ComObject.
// Compilation is sufficient to test.
#[implement(IBar)]
Expand Down

0 comments on commit 692c4b6

Please sign in to comment.