diff --git a/crates/libs/bindgen/src/lib.rs b/crates/libs/bindgen/src/lib.rs index ea371d34d0..040efdc4d9 100644 --- a/crates/libs/bindgen/src/lib.rs +++ b/crates/libs/bindgen/src/lib.rs @@ -20,6 +20,7 @@ enum ArgKind { Output, Filter, Config, + Derives, } /// Windows metadata compiler. @@ -38,6 +39,7 @@ where let mut exclude = Vec::<&str>::new(); let mut config = std::collections::BTreeMap::<&str, &str>::new(); let mut format = false; + let mut derives = std::collections::BTreeMap::<(&str, &str), tokens::TokenStream>::new(); for arg in &args { if arg.starts_with('-') { @@ -50,6 +52,7 @@ where "-o" | "--out" => kind = ArgKind::Output, "-f" | "--filter" => kind = ArgKind::Filter, "--config" => kind = ArgKind::Config, + "--derives" => kind = ArgKind::Derives, "--format" => format = true, _ => return Err(Error::new(&format!("invalid option `{arg}`"))), }, @@ -75,6 +78,22 @@ where config.insert(arg, ""); } } + ArgKind::Derives => { + if let Some((ty, traits)) = arg.split_once('=') { + if let Some(last_dot) = ty.rfind('.') { + let name = &ty[last_dot + 1..]; + let namespace = &ty[..last_dot]; + let traits: tokens::TokenStream = traits.into(); + if derives.insert((namespace, name), quote! { #[derive(#traits)] }).is_some() { + return Err(Error::new(&format!("Duplicate entry for type `{ty}` in --derives"))); + } + } else { + return Err(Error::new(&format!("The type `{ty}` in --derives must be fully qualified"))); + } + } else { + return Err(Error::new(&format!("Invalid format for --derives, expected ty=traits, actual: `{arg}`"))); + } + } } } @@ -113,10 +132,21 @@ where winmd::verify(reader)?; + let unused_derives = derives.keys().filter(|(namespace, name)| reader.get_type_def(namespace, name).next().is_none()).collect::>(); + if !unused_derives.is_empty() { + let mut message = "unused derives".to_string(); + + for (namespace, name) in unused_derives { + message.push_str(&format!("\n {namespace}.{name}")); + } + + return Err(Error::new(&message)); + } + match extension(&output) { "rdl" => rdl::from_reader(reader, config, &output)?, "winmd" => winmd::from_reader(reader, config, &output)?, - "rs" => rust::from_reader(reader, config, &output)?, + "rs" => rust::from_reader(reader, config, &derives, &output)?, _ => return Err(Error::new("output extension must be one of winmd/rdl/rs")), } @@ -262,3 +292,18 @@ fn extension(path: &str) -> &str { fn directory(path: &str) -> &str { path.rsplit_once(&['/', '\\']).map_or("", |(directory, _)| directory) } + +#[test] +fn bad_derive_args() { + let result = bindgen(&["--derives", "Foo"]).unwrap_err().to_string(); + assert_eq!(result, "error: Invalid format for --derives, expected ty=traits, actual: `Foo`\n"); + + let result = bindgen(&["--derives", "Foo=bar"]).unwrap_err().to_string(); + assert_eq!(result, "error: The type `Foo` in --derives must be fully qualified\n"); + + let result = bindgen(&["--derives", "Foo.Bar=bar", "Foo.Bar=baz"]).unwrap_err().to_string(); + assert_eq!(result, "error: Duplicate entry for type `Foo.Bar` in --derives\n"); + + let result = bindgen(&["--out", "test.rs", "--filter", "Windows.Win32.System.Com.CoInitialize", "--derives", "Foo.Bar=bar"]).unwrap_err().to_string(); + assert_eq!(result, "error: unused derives\n Foo.Bar\n"); +} diff --git a/crates/libs/bindgen/src/rust/mod.rs b/crates/libs/bindgen/src/rust/mod.rs index 76c26680b3..943208c43d 100644 --- a/crates/libs/bindgen/src/rust/mod.rs +++ b/crates/libs/bindgen/src/rust/mod.rs @@ -23,7 +23,7 @@ use index::*; use rayon::prelude::*; use writer::*; -pub fn from_reader(reader: &'static metadata::Reader, mut config: std::collections::BTreeMap<&str, &str>, output: &str) -> Result<()> { +pub fn from_reader(reader: &'static metadata::Reader, mut config: std::collections::BTreeMap<&str, &str>, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>, output: &str) -> Result<()> { let mut writer = Writer::new(reader, output); writer.package = config.remove("package").is_some(); writer.flatten = config.remove("flatten").is_some(); @@ -47,32 +47,32 @@ pub fn from_reader(reader: &'static metadata::Reader, mut config: std::collectio } if writer.package { - gen_package(&writer) + gen_package(&writer, derives) } else { - gen_file(&writer) + gen_file(&writer, derives) } } -fn gen_file(writer: &Writer) -> Result<()> { +fn gen_file(writer: &Writer, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> Result<()> { // TODO: harmonize this output code so we don't need these two wildly differnt code paths // there should be a simple way to generate the with or without namespaces. if writer.flatten { - let tokens = standalone::standalone_imp(writer); + let tokens = standalone::standalone_imp(writer, derives); write_to_file(&writer.output, try_format(writer, &tokens)) } else { let mut tokens = String::new(); let root = Tree::new(writer.reader); for tree in root.nested.values() { - tokens.push_str(&namespace(writer, tree)); + tokens.push_str(&namespace(writer, tree, derives)); } write_to_file(&writer.output, try_format(writer, &tokens)) } } -fn gen_package(writer: &Writer) -> Result<()> { +fn gen_package(writer: &Writer, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> Result<()> { let directory = directory(&writer.output); let root = Tree::new(writer.reader); let mut root_len = 0; @@ -86,7 +86,7 @@ fn gen_package(writer: &Writer) -> Result<()> { trees.par_iter().try_for_each(|tree| { let directory = format!("{directory}/{}", tree.namespace.replace('.', "/")); - let mut tokens = namespace(writer, tree); + let mut tokens = namespace(writer, tree, derives); let tokens_impl = if !writer.sys { namespace_impl(writer, tree) } else { String::new() }; @@ -143,7 +143,7 @@ use std::fmt::Write; use tokens::*; use try_format::*; -fn namespace(writer: &Writer, tree: &Tree) -> String { +fn namespace(writer: &Writer, tree: &Tree, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> String { let writer = &mut writer.clone(); writer.namespace = tree.namespace; let mut tokens = TokenStream::new(); @@ -159,7 +159,7 @@ fn namespace(writer: &Writer, tree: &Tree) -> String { } else { tokens.combine("e! { pub mod #name }); tokens.push_str("{"); - tokens.push_str(&namespace(writer, tree)); + tokens.push_str(&namespace(writer, tree, derives)); tokens.push_str("}"); } } @@ -200,7 +200,7 @@ fn namespace(writer: &Writer, tree: &Tree) -> String { continue; } } - types.entry(kind).or_default().entry(name).or_default().combine(&structs::writer(writer, def)); + types.entry(kind).or_default().entry(name).or_default().combine(&structs::writer(writer, def, derives.get(&(def.namespace(), name)))); } metadata::TypeKind::Delegate => types.entry(kind).or_default().entry(name).or_default().combine(&delegates::writer(writer, def)), } diff --git a/crates/libs/bindgen/src/rust/standalone.rs b/crates/libs/bindgen/src/rust/standalone.rs index 5cf930d17d..780bd7604d 100644 --- a/crates/libs/bindgen/src/rust/standalone.rs +++ b/crates/libs/bindgen/src/rust/standalone.rs @@ -1,7 +1,7 @@ use super::*; use metadata::AsRow; -pub fn standalone_imp(writer: &Writer) -> String { +pub fn standalone_imp(writer: &Writer, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> String { let mut types = std::collections::BTreeSet::new(); let mut functions = std::collections::BTreeSet::new(); let mut constants = std::collections::BTreeSet::new(); @@ -112,7 +112,7 @@ pub fn standalone_imp(writer: &Writer) -> String { continue; } } - sorted.insert(name, structs::writer(writer, def)); + sorted.insert(name, structs::writer(writer, def, derives.get(&(def.namespace(), name)))); } metadata::TypeKind::Delegate => { sorted.insert(def.name(), delegates::writer(writer, def)); diff --git a/crates/libs/bindgen/src/rust/structs.rs b/crates/libs/bindgen/src/rust/structs.rs index f211b013bb..2a8ff45c5f 100644 --- a/crates/libs/bindgen/src/rust/structs.rs +++ b/crates/libs/bindgen/src/rust/structs.rs @@ -1,7 +1,7 @@ use super::*; use metadata::HasAttributes; -pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream { +pub fn writer(writer: &Writer, def: metadata::TypeDef, derives: Option<&TokenStream>) -> TokenStream { if def.has_attribute("ApiContractAttribute") { return quote! {}; } @@ -10,10 +10,10 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream { return handles::writer(writer, def); } - gen_struct_with_name(writer, def, def.name(), &cfg::Cfg::default()) + gen_struct_with_name(writer, def, def.name(), &cfg::Cfg::default(), derives) } -fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &str, cfg: &cfg::Cfg) -> TokenStream { +fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &str, cfg: &cfg::Cfg, derives: Option<&TokenStream>) -> TokenStream { let name = to_ident(struct_name); if def.fields().next().is_none() { @@ -81,6 +81,7 @@ fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &s let mut tokens = quote! { #repr #features + #derives pub #struct_or_union #name {#(#fields)*} }; @@ -103,7 +104,7 @@ fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &s for (index, nested_type) in writer.reader.nested_types(def).enumerate() { let nested_name = format!("{struct_name}_{index}"); - tokens.combine(&gen_struct_with_name(writer, nested_type, &nested_name, &cfg)); + tokens.combine(&gen_struct_with_name(writer, nested_type, &nested_name, &cfg, None)); } tokens diff --git a/crates/tests/standalone/build.rs b/crates/tests/standalone/build.rs index 9464470760..920e95f0a3 100644 --- a/crates/tests/standalone/build.rs +++ b/crates/tests/standalone/build.rs @@ -165,18 +165,25 @@ fn main() { "src/b_vtbl_4.rs", &["Windows.Win32.System.Com.IPersistFile"], ); + + // Ensure that derives adds the #[derive(...)] attribute. + write_derives( + "src/b_derives.rs", + &["Windows.Foundation.DateTime"], + &["Windows.Foundation.DateTime=::core::cmp::PartialOrd,::core::cmp::Ord"], + ); } fn write_sys(output: &str, filter: &[&str]) { - riddle(output, filter, &["flatten", "sys", "minimal"]); + riddle(output, filter, &["flatten", "sys", "minimal"], None); } fn write_win(output: &str, filter: &[&str]) { - riddle(output, filter, &["flatten", "minimal"]); + riddle(output, filter, &["flatten", "minimal"], None); } fn write_std(output: &str, filter: &[&str]) { - riddle(output, filter, &["flatten", "std", "minimal"]); + riddle(output, filter, &["flatten", "std", "minimal"], None); } fn write_no_inner_attr(output: &str, filter: &[&str]) { @@ -184,14 +191,19 @@ fn write_no_inner_attr(output: &str, filter: &[&str]) { output, filter, &["flatten", "no-inner-attributes", "minimal"], + None, ); } fn write_vtbl(output: &str, filter: &[&str]) { - riddle(output, filter, &["flatten", "sys", "minimal", "vtbl"]); + riddle(output, filter, &["flatten", "sys", "minimal", "vtbl"], None); +} + +fn write_derives(output: &str, filter: &[&str], derives: &[&str]) { + riddle(output, filter, &["flatten", "minimal"], Some(derives)); } -fn riddle(output: &str, filter: &[&str], config: &[&str]) { +fn riddle(output: &str, filter: &[&str], config: &[&str], derives: Option<&[&str]>) { // Rust-analyzer may re-run build scripts whenever a source file is deleted // which causes an endless loop if the file is deleted from a build script. // To workaround this, we truncate the file instead of deleting it. @@ -221,6 +233,11 @@ fn riddle(output: &str, filter: &[&str], config: &[&str]) { command.arg("--config"); command.args(config); + if let Some(derives) = derives { + command.arg("--derives"); + command.args(derives); + } + if !command.status().unwrap().success() { panic!("Failed to run riddle"); } diff --git a/crates/tests/standalone/src/b_derives.rs b/crates/tests/standalone/src/b_derives.rs new file mode 100644 index 0000000000..43671a0ab3 --- /dev/null +++ b/crates/tests/standalone/src/b_derives.rs @@ -0,0 +1,45 @@ +// Bindings generated by `windows-bindgen` 0.54.0 + +#![allow( + non_snake_case, + non_upper_case_globals, + non_camel_case_types, + dead_code, + clippy::all +)] +#[repr(C)] +#[derive(::core::cmp::PartialOrd, ::core::cmp::Ord)] +pub struct DateTime { + pub UniversalTime: i64, +} +impl ::core::marker::Copy for DateTime {} +impl ::core::clone::Clone for DateTime { + fn clone(&self) -> Self { + *self + } +} +impl ::core::fmt::Debug for DateTime { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + f.debug_struct("DateTime") + .field("UniversalTime", &self.UniversalTime) + .finish() + } +} +impl ::windows_core::TypeKind for DateTime { + type TypeKind = ::windows_core::CopyType; +} +impl ::windows_core::RuntimeType for DateTime { + const SIGNATURE: ::windows_core::imp::ConstBuffer = + ::windows_core::imp::ConstBuffer::from_slice(b"struct(Windows.Foundation.DateTime;i8)"); +} +impl ::core::cmp::PartialEq for DateTime { + fn eq(&self, other: &Self) -> bool { + self.UniversalTime == other.UniversalTime + } +} +impl ::core::cmp::Eq for DateTime {} +impl ::core::default::Default for DateTime { + fn default() -> Self { + unsafe { ::core::mem::zeroed() } + } +} diff --git a/crates/tests/standalone/src/lib.rs b/crates/tests/standalone/src/lib.rs index 5dedc55e7b..152282395d 100644 --- a/crates/tests/standalone/src/lib.rs +++ b/crates/tests/standalone/src/lib.rs @@ -7,6 +7,7 @@ mod b_bstr; mod b_calendar; mod b_constant_types; mod b_depends; +mod b_derives; mod b_enumeration; mod b_enumerator; mod b_guid; @@ -184,3 +185,20 @@ fn from_included() { included::GetVersion(); } } + +#[test] +fn derive_ord() { + use b_derives::*; + let mut dates = [ + DateTime { UniversalTime: 123 }, + DateTime { UniversalTime: 42 }, + ]; + dates.sort(); + assert_eq!( + &dates, + &[ + DateTime { UniversalTime: 42 }, + DateTime { UniversalTime: 123 } + ] + ); +} diff --git a/crates/tools/riddle/src/main.rs b/crates/tools/riddle/src/main.rs index 65185952a0..063e961c2d 100644 --- a/crates/tools/riddle/src/main.rs +++ b/crates/tools/riddle/src/main.rs @@ -12,6 +12,7 @@ Options: --config Override a configuration value --format Format .rdl files only --etc File containing command line options + --derives Emit a derive attribute for a type "# ); } else {