diff --git a/crates/libs/bindgen/src/rust/implements.rs b/crates/libs/bindgen/src/rust/implements.rs index cccf8503a2f..5d6a894dd53 100644 --- a/crates/libs/bindgen/src/rust/implements.rs +++ b/crates/libs/bindgen/src/rust/implements.rs @@ -97,10 +97,15 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream { if has_unknown_base { quote! { - unsafe extern "system" fn #name<#constraints Identity: windows_core::IUnknownImpl, Impl: #impl_ident<#generic_names>, const OFFSET: isize> #vtbl_signature { + unsafe extern "system" fn #name< + #constraints + Identity: windows_core::IUnknownImpl, + OuterToImpl: ::windows_core::ComGetImpl, + const OFFSET: isize + > #vtbl_signature where OuterToImpl::Impl: #impl_ident<#generic_names> { // offset the `this` pointer by `OFFSET` times the size of a pointer and cast it as an IUnknown implementation - let this = (this as *const *const ()).offset(OFFSET) as *const Identity; - let this = (*this).get_impl(); + let this_outer: &Identity = &*((this as *const *const ()).offset(OFFSET) as *const Identity); + let this = OuterToImpl::get_impl(this_outer); #invoke_upcall } } @@ -123,7 +128,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream { Some(metadata::Type::TypeDef(def, generics)) => { let name = writer.type_def_name_imp(*def, generics, "_Vtbl"); if has_unknown_base { - methods.combine("e! { base__: #name::new::(), }); + methods.combine("e! { base__: #name::new::(), }); } else { methods.combine("e! { base__: #name::new::(), }); } @@ -136,7 +141,8 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream { for method in def.methods() { let name = method_names.add(method); if has_unknown_base { - methods.combine("e! { #name: #name::<#generic_names Identity, Impl, OFFSET>, }); + methods + .combine("e! { #name: #name::<#generic_names Identity, OuterToImpl, OFFSET>, }); } else { methods.combine("e! { #name: #name::, }); } @@ -151,7 +157,12 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream { #runtime_name #features impl<#constraints> #vtbl_ident<#generic_names> { - pub const fn new, Impl: #impl_ident<#generic_names>, const OFFSET: isize>() -> #vtbl_ident<#generic_names> { + pub const fn new< + Identity: windows_core::IUnknownImpl, + OuterToImpl: ::windows_core::ComGetImpl, + const OFFSET: isize + >() -> #vtbl_ident<#generic_names> + where OuterToImpl::Impl : #impl_ident<#generic_names> { #(#method_impls)* Self{ #methods diff --git a/crates/libs/core/src/com_object.rs b/crates/libs/core/src/com_object.rs index ede313071e7..aa50ba7f54f 100644 --- a/crates/libs/core/src/com_object.rs +++ b/crates/libs/core/src/com_object.rs @@ -330,3 +330,57 @@ impl Borrow for ComObject { self.get() } } + +/// Allows a COM object implementation to implement COM interfaces either on the "outer" type or +/// the "inner" type. +/// +/// This trait is part of the implementation of `windows-rs` and is not meant to be used directly +/// by user code. This trait is not stable and may change at any time. +#[doc(hidden)] +pub trait ComGetImpl { + /// The type that implements the COM interface. + type Impl; + + /// At runtime, casts from the outer object type to the implementation type. + fn get_impl(outer: &Outer) -> &Self::Impl; +} + +/// Selects the "inner" type of a COM object implementation. This implementation uses the +/// `IUnknownImpl` trait both to specify the type that implements the COM interface and to +/// cast from `&Outer` to `&Inner` (i.e. from `&MyApp_Impl` to `&MyApp`). +/// +/// This struct is part of the implementation of `windows-rs` and is not meant to be used directly +/// by user code. This trait is not stable and may change at any time. +#[doc(hidden)] +pub struct ComGetImplInner { + _marker: core::marker::PhantomData, +} + +impl ComGetImpl for ComGetImplInner +where + Outer: IUnknownImpl, +{ + type Impl = ::Impl; + + fn get_impl(outer: &Outer) -> &Self::Impl { + ::get_impl(outer) + } +} + +/// Selects the "outer" type of a COM object implementation. This is basically an identify function, +/// over types. +/// +/// This struct is part of the implementation of `windows-rs` and is not meant to be used directly +/// by user code. This trait is not stable and may change at any time. +#[doc(hidden)] +pub struct ComGetImplOuter { + _marker: core::marker::PhantomData, +} + +impl ComGetImpl for ComGetImplOuter { + type Impl = Outer; + + fn get_impl(outer: &Outer) -> &Self::Impl { + outer + } +} diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 4fa01941a4f..e6053ac483d 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -77,8 +77,16 @@ pub fn implement( .enumerate() .map(|(enumerate, implement)| { let vtbl_ident = implement.to_vtbl_ident(); + let outer_to_impl = match implement.impl_location { + ImplLocation::Outer => { + quote!(::windows_core::ComGetImplOuter<#impl_ident #generics>) + } + ImplLocation::Inner => { + quote!(::windows_core::ComGetImplInner<#impl_ident #generics>) + } + }; let offset = proc_macro2::Literal::isize_unsuffixed(-1 - enumerate as isize); - quote! { #vtbl_ident::new::() } + quote! { #vtbl_ident::new::() } }); let offset = attributes @@ -389,6 +397,18 @@ pub fn implement( struct ImplementType { type_name: String, generics: Vec, + impl_location: ImplLocation, +} + +/// Specifies whether a COM object implements COM interfaces on its "inner" or "outer" object. +/// +/// The default, for backward compatibility, is inner. In the long-term, arguably all COM objects +/// should switch to defining interfaces on the outer object. +#[derive(Copy, Clone, Eq, PartialEq, Default)] +enum ImplLocation { + #[default] + Inner, + Outer, } impl ImplementType { @@ -415,9 +435,10 @@ struct ImplementAttributes { impl syn::parse::Parse for ImplementAttributes { fn parse(cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result { let mut input = Self::default(); + let mut current_impl_location = ImplLocation::Inner; while !cursor.is_empty() { - input.parse_implement(cursor)?; + input.parse_implement(&mut current_impl_location, cursor)?; } Ok(input) @@ -425,9 +446,13 @@ impl syn::parse::Parse for ImplementAttributes { } impl ImplementAttributes { - fn parse_implement(&mut self, cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result<()> { + fn parse_implement( + &mut self, + current_impl_location: &mut ImplLocation, + cursor: syn::parse::ParseStream<'_>, + ) -> syn::parse::Result<()> { let tree = cursor.parse::()?; - self.walk_implement(&tree, &mut String::new())?; + self.walk_implement(&tree, current_impl_location, &mut String::new())?; if !cursor.is_empty() { cursor.parse::()?; @@ -439,6 +464,7 @@ impl ImplementAttributes { fn walk_implement( &mut self, tree: &UseTree2, + current_impl_location: &mut ImplLocation, namespace: &mut String, ) -> syn::parse::Result<()> { match tree { @@ -448,17 +474,21 @@ impl ImplementAttributes { } namespace.push_str(&input.ident.to_string()); - self.walk_implement(&input.tree, namespace)?; + self.walk_implement(&input.tree, current_impl_location, namespace)?; } UseTree2::Name(_) => { - self.implement.push(tree.to_element_type(namespace)?); + self.implement + .push(tree.to_element_type(*current_impl_location, namespace)?); } UseTree2::Group(input) => { for tree in &input.items { - self.walk_implement(tree, namespace)?; + self.walk_implement(tree, current_impl_location, namespace)?; } } UseTree2::TrustLevel(input) => self.trust_level = *input, + UseTree2::ImplLocation(location) => { + *current_impl_location = *location; + } } Ok(()) @@ -470,10 +500,15 @@ enum UseTree2 { Name(UseName2), Group(UseGroup2), TrustLevel(usize), + ImplLocation(ImplLocation), } impl UseTree2 { - fn to_element_type(&self, namespace: &mut String) -> syn::parse::Result { + fn to_element_type( + &self, + impl_location: ImplLocation, + namespace: &mut String, + ) -> syn::parse::Result { match self { UseTree2::Path(input) => { if !namespace.is_empty() { @@ -481,7 +516,7 @@ impl UseTree2 { } namespace.push_str(&input.ident.to_string()); - input.tree.to_element_type(namespace) + input.tree.to_element_type(impl_location, namespace) } UseTree2::Name(input) => { let mut type_name = input.ident.to_string(); @@ -493,12 +528,13 @@ impl UseTree2 { let mut generics = vec![]; for g in &input.generics { - generics.push(g.to_element_type(&mut String::new())?); + generics.push(g.to_element_type(impl_location, &mut String::new())?); } Ok(ImplementType { type_name, generics, + impl_location, }) } UseTree2::Group(input) => Err(syn::parse::Error::new( @@ -538,23 +574,38 @@ impl syn::parse::Parse for UseTree2 { tree: Box::new(input.parse()?), })) } else if input.peek(syn::Token![=]) { - if ident != "TrustLevel" { + if ident == "TrustLevel" { + input.parse::()?; + let span = input.span(); + let value = input.call(syn::Ident::parse_any)?; + match value.to_string().as_str() { + "Partial" => Ok(UseTree2::TrustLevel(1)), + "Full" => Ok(UseTree2::TrustLevel(2)), + _ => Err(syn::parse::Error::new( + span, + "`TrustLevel` must be `Partial` or `Full`", + )), + } + } else if ident == "ImplLocation" { + input.parse::()?; + let span = input.span(); + let value = input.call(syn::Ident::parse_any)?; + Ok(UseTree2::ImplLocation(match value.to_string().as_str() { + "Outer" => ImplLocation::Outer, + "Inner" => ImplLocation::Inner, + _ => { + return Err(syn::parse::Error::new( + span, + "`ImplLocation` must be `Outer` or `Inner`", + )) + } + })) + } else { return Err(syn::parse::Error::new( ident.span(), "Unrecognized key-value pair", )); } - input.parse::()?; - let span = input.span(); - let value = input.call(syn::Ident::parse_any)?; - match value.to_string().as_str() { - "Partial" => Ok(UseTree2::TrustLevel(1)), - "Full" => Ok(UseTree2::TrustLevel(2)), - _ => Err(syn::parse::Error::new( - span, - "`TrustLevel` must be `Partial` or `Full`", - )), - } } else { let generics = if input.peek(syn::Token![<]) { input.parse::()?; diff --git a/crates/libs/interface/src/lib.rs b/crates/libs/interface/src/lib.rs index e33379a9213..680664359f6 100644 --- a/crates/libs/interface/src/lib.rs +++ b/crates/libs/interface/src/lib.rs @@ -135,8 +135,8 @@ impl Interface { if m.is_result() { quote! { - #[inline(always)] - #vis unsafe fn #name<#(#generics),*>(&self, #(#params),*) #ret { + #[inline(always)] + #vis unsafe fn #name<#(#generics),*>(&self, #(#params),*) #ret { (::windows_core::Interface::vtable(self).#name)(::windows_core::Interface::as_raw(self), #(#args),*).ok() } } @@ -214,7 +214,7 @@ impl Interface { let parent_vtable_generics = if self.parent_is_iunknown() { quote!(Identity, OFFSET) } else { - quote!(Identity, Impl, OFFSET) + quote!(Identity, OuterToImpl, OFFSET) }; let parent_vtable = self.parent_vtable(); @@ -253,13 +253,35 @@ impl Interface { if parent_vtable.is_some() { quote! { - unsafe extern "system" fn #name, Impl: #trait_name, const OFFSET: isize>(this: *mut ::core::ffi::c_void, #(#args),*) #ret { - let this = (this as *const *const ()).offset(OFFSET) as *const Identity; - let this_impl: &Impl = (*this).get_impl(); + unsafe extern "system" fn #name< + Identity: ::windows_core::IUnknownImpl, + OuterToImpl: ::windows_core::ComGetImpl, + const OFFSET: isize + >( + this: *mut ::core::ffi::c_void, // <-- This is the COM "this" pointer, which is not the same as &T or &T_Impl. + #(#args),* + ) #ret + where + OuterToImpl::Impl : #trait_name + { + // This step is essentially a virtual dispatch adjustor thunk. Its purpose is to adjust + // the "this" pointer from the address used by the COM interface to the root of the + // MyApp_Impl object. Since a given MyApp_Impl may implement more than one COM interface + // (and more than one COM interface chain), we need to know how to get from COM's "this" + // back to &MyApp_Impl. The OFFSET constant gives us the value (in pointer-sized units). + let this_outer: &Identity = &*((this as *const *const ()).offset(OFFSET) as *const Identity); + + // This step selects the part of the MyApp_Impl object which implements a given COM interface, + // i.e. IFoo_Impl trait. There are really only two possibilities: either MyApp_Impl or MyApp + // implements a given IFoo_Impl trait. The ComGetImplInner and ComGetImplOuter types + // allow the code that specialized this function to select which one is used. + let this_impl: &OuterToImpl::Impl = OuterToImpl::get_impl(this_outer); + + // Last, we invoke the implementation function. // We use explicit so that we can select the correct method // for situations where IFoo3 derives from IFoo2 and both declare a method with // the same name. - ::#name(this_impl, #(#params),*).into() + ::#name(this_impl, #(#params),*).into() } } } else { @@ -274,24 +296,16 @@ impl Interface { }) .collect::>(); - let entries = self - .methods - .iter() - .map(|m| { - let name = &m.name; - if parent_vtable.is_some() { - quote! { - #name: #name:: - } - } else { - quote! { - #name: #name:: - } - } - }) - .collect::>(); - if let Some(parent_vtable) = parent_vtable { + let entries = self + .methods + .iter() + .map(|m| { + let name = &m.name; + quote!(#name: #name::) + }) + .collect::>(); + quote! { #[repr(C)] #[doc(hidden)] @@ -300,7 +314,14 @@ impl Interface { #(#vtable_entries)* } impl #vtable_name { - pub const fn new, Impl: #trait_name, const OFFSET: isize>() -> Self { + pub const fn new< + Identity: ::windows_core::IUnknownImpl, + OuterToImpl: ::windows_core::ComGetImpl, + const OFFSET: isize, + >() -> Self + where + OuterToImpl::Impl : #trait_name + { #(#functions)* Self { base__: #parent_vtable::new::<#parent_vtable_generics>(), #(#entries),* } } @@ -313,6 +334,15 @@ impl Interface { } } } else { + let entries = self + .methods + .iter() + .map(|m| { + let name = &m.name; + quote!(#name: #name::) + }) + .collect::>(); + quote! { #[repr(C)] #[doc(hidden)] diff --git a/crates/tests/implement_core/src/impl_on_outer.rs b/crates/tests/implement_core/src/impl_on_outer.rs new file mode 100644 index 00000000000..8dd189415ac --- /dev/null +++ b/crates/tests/implement_core/src/impl_on_outer.rs @@ -0,0 +1,63 @@ +use windows_core::*; + +#[interface("cccccccc-0000-0000-0000-000000000001")] +unsafe trait IFoo: IUnknown { + fn hello(&self); +} + +#[interface("cccccccc-0000-0000-0000-000000000002")] +unsafe trait IFoo2: IFoo { + fn hello(&self); +} + +#[interface("cccccccc-0000-0000-0000-000000000003")] +unsafe trait IFoo3: IFoo2 { + fn hello(&self); +} + +#[interface("cccccccc-0000-0000-0000-000000000004")] +unsafe trait IBar: IUnknown { + fn goodbye(&self); +} + +// This tests that we can compile a COM object that has some COM interfaces implemented on the +// outer object and some on the inner object. +#[implement(ImplLocation = Outer, IFoo3, ImplLocation = Inner, IBar)] +struct MyApp {} + +impl IFoo_Impl for MyApp_Impl { + unsafe fn hello(&self) { + println!("MyApp as IFoo: hello"); + } +} +impl IFoo2_Impl for MyApp_Impl { + unsafe fn hello(&self) { + println!("MyApp as IFoo2: hello"); + } +} +impl IFoo3_Impl for MyApp_Impl { + unsafe fn hello(&self) { + println!("MyApp as IFoo3: hello"); + } +} + +impl IBar_Impl for MyApp { + unsafe fn goodbye(&self) { + println!("MyApp as IBar: goodbye"); + } +} + +#[test] +fn basic() { + let app = ComObject::new(MyApp {}); + let ifoo3: IFoo3 = app.cast().unwrap(); + let ifoo2: IFoo2 = app.cast().unwrap(); + let ifoo: IFoo = app.cast().unwrap(); + let ibar: IBar = app.cast().unwrap(); + unsafe { + ifoo.hello(); + ifoo2.hello(); + ifoo3.hello(); + ibar.goodbye(); + } +} diff --git a/crates/tests/implement_core/src/lib.rs b/crates/tests/implement_core/src/lib.rs index aa8f3bec531..39557c2b613 100644 --- a/crates/tests/implement_core/src/lib.rs +++ b/crates/tests/implement_core/src/lib.rs @@ -5,3 +5,4 @@ mod com_chain; mod com_object; +mod impl_on_outer;