From 7b5bdf53d1647a1d6e209e0b6bd77a995632b2f2 Mon Sep 17 00:00:00 2001 From: William Date: Fri, 30 Aug 2024 15:19:30 -0700 Subject: [PATCH 1/7] add prost-wkt-derive --- Cargo.toml | 6 +- README.md | 3 +- example/Cargo.toml | 8 ++- example/build.rs | 11 ++-- src/lib.rs | 12 +++- wkt-build/Cargo.toml | 18 ------ wkt-build/src/lib.rs | 101 ------------------------------ wkt-derive/Cargo.toml | 18 ++++++ wkt-derive/src/lib.rs | 60 ++++++++++++++++++ wkt-types/Cargo.toml | 3 +- wkt-types/build.rs | 27 ++++---- wkt-types/src/pbany.rs | 113 ++++++++++++++++++++-------------- wkt-types/tests/pbany_test.rs | 48 ++------------- 13 files changed, 186 insertions(+), 242 deletions(-) delete mode 100644 wkt-build/Cargo.toml delete mode 100644 wkt-build/src/lib.rs create mode 100644 wkt-derive/Cargo.toml create mode 100644 wkt-derive/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index fbf96fb..12868bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,15 @@ edition = "2021" rust-version = "1.70" [workspace] -members = [ "wkt-build", "wkt-types", "example" ] +members = ["wkt-types", "wkt-derive", "example"] [dependencies] prost = "0.13.1" +prost-wkt-derive = { path = "wkt-derive" } +erased-serde = "0.4" inventory = "0.3.0" serde = "1.0" serde_json = "1.0" serde_derive = "1.0" chrono = { version = "0.4.27", default-features = false, features = ["serde"] } -typetag = "0.2" +const_format = "0.2.32" diff --git a/README.md b/README.md index 5b31e4c..11f9eb2 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,6 @@ serde = { version = "1.0", features = ["derive"] } [build-dependencies] prost-build = "0.13" -prost-wkt-build = "0.6" ``` In your `build.rs`, make sure to add the following options: @@ -290,7 +289,7 @@ Contributions are welcome! When upgrading Prost to the latest version, make sure the latest changes from `prost-types` are incorporated into `prost-wkt-types` to ensure full compatibility. Currently the `Name` traits have specifically not been implemented until this implementation in Prost has fully -stabilized. +stabilized. ## MSRV ## diff --git a/example/Cargo.toml b/example/Cargo.toml index 07b16a6..02a4a89 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -9,11 +9,13 @@ rust-version = "1.70" prost = "0.13.1" prost-wkt = { path = ".." } prost-wkt-types = { path = "../wkt-types" } -serde = "1.0" +serde = { version = "1.0", features = ["derive"] } serde_derive = "1.0" serde_json = "1.0" -chrono = { version = "0.4.27", default-features = false, features = ["clock", "serde"] } +chrono = { version = "0.4.27", default-features = false, features = [ + "clock", + "serde", +] } [build-dependencies] prost-build = "0.13.1" -prost-wkt-build = { path = "../wkt-build" } diff --git a/example/build.rs b/example/build.rs index b8f9bf0..9564826 100644 --- a/example/build.rs +++ b/example/build.rs @@ -1,4 +1,3 @@ -use prost_wkt_build::*; use std::{env, path::PathBuf}; fn main() { @@ -6,27 +5,25 @@ fn main() { let descriptor_file = out.join("descriptors.bin"); let mut prost_build = prost_build::Config::new(); prost_build + .enable_type_names() .type_attribute( ".my.requests", - "#[derive(serde::Serialize, serde::Deserialize)] #[serde(default, rename_all=\"camelCase\")]", + "#[derive(serde::Serialize, serde::Deserialize, ::prost_wkt::MessageSerde)] #[serde(default, rename_all=\"camelCase\")]", ) .type_attribute( ".my.messages.Foo", "#[derive(serde::Serialize, serde::Deserialize)] #[serde(default, rename_all=\"camelCase\")]", ) + .message_attribute(".my.messages.Foo", "#[derive(::prost_wkt::MessageSerde)]") .type_attribute( ".my.messages.Content", "#[derive(serde::Serialize, serde::Deserialize)] #[serde(rename_all=\"camelCase\")]", ) + .message_attribute(".my.messages.Content", "#[derive(::prost_wkt::MessageSerde)]") .extern_path(".google.protobuf.Any", "::prost_wkt_types::Any") .extern_path(".google.protobuf.Timestamp", "::prost_wkt_types::Timestamp") .extern_path(".google.protobuf.Value", "::prost_wkt_types::Value") .file_descriptor_set_path(&descriptor_file) .compile_protos(&["proto/messages.proto", "proto/requests.proto"], &["proto/"]) .unwrap(); - - let descriptor_bytes = std::fs::read(descriptor_file).unwrap(); - let descriptor = FileDescriptorSet::decode(&descriptor_bytes[..]).unwrap(); - - prost_wkt_build::add_serde(out, descriptor); } diff --git a/src/lib.rs b/src/lib.rs index 0186a6e..b05c998 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,11 @@ pub use inventory; -pub use typetag; +pub use const_format; +pub use erased_serde; +pub use prost_wkt_derive::MessageSerde; /// Trait to support serialization and deserialization of `prost` messages. -#[typetag::serde(tag = "@type")] -pub trait MessageSerde: prost::Message + std::any::Any { +pub trait MessageSerde: prost::Message + std::any::Any + erased_serde::Serialize { /// message name as in proto file fn message_name(&self) -> &'static str; /// package name as in proto file @@ -15,6 +16,8 @@ pub trait MessageSerde: prost::Message + std::any::Any { fn new_instance(&self, data: Vec) -> Result, prost::DecodeError>; /// Returns the encoded protobuf message as bytes fn try_encoded(&self) -> Result, prost::EncodeError>; + /// Returns an erased serialize dynamic reference + fn as_erased_serialize(&self) -> &dyn erased_serde::Serialize; } /// The implementation here is a direct copy of the `impl dyn` of [`std::any::Any`]! @@ -86,10 +89,13 @@ impl dyn MessageSerde { } type MessageSerdeDecoderFn = fn(&[u8]) -> Result, ::prost::DecodeError>; +type MessageSerdeDeserializerFn = + fn(&mut dyn erased_serde::Deserializer) -> Result, erased_serde::Error>; pub struct MessageSerdeDecoderEntry { pub type_url: &'static str, pub decoder: MessageSerdeDecoderFn, + pub deserializer: MessageSerdeDeserializerFn, } inventory::collect!(MessageSerdeDecoderEntry); diff --git a/wkt-build/Cargo.toml b/wkt-build/Cargo.toml deleted file mode 100644 index 9982f37..0000000 --- a/wkt-build/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "prost-wkt-build" -version = "0.6.0" -authors = ["fdeantoni "] -license = "Apache-2.0" -repository = "https://github.com/fdeantoni/prost-wkt" -description = "Helper crate for prost to allow JSON serialization and deserialization of Well Known Types." -readme = "../README.md" -documentation = "https://docs.rs/prost-wkt-build" -edition = "2021" -rust-version = "1.70" - -[dependencies] -prost = "0.13.1" -prost-types = "0.13.1" -prost-build = "0.13.1" -quote = "1.0" -heck = { version = ">=0.4, <=0.5" } diff --git a/wkt-build/src/lib.rs b/wkt-build/src/lib.rs deleted file mode 100644 index 4b93223..0000000 --- a/wkt-build/src/lib.rs +++ /dev/null @@ -1,101 +0,0 @@ -use heck::ToUpperCamelCase; -use quote::{format_ident, quote}; -use std::fs::{File, OpenOptions}; -use std::io::Write; -use std::path::PathBuf; - -pub use prost::Message; -pub use prost_types::FileDescriptorSet; - -use prost_build::Module; - -pub fn add_serde(out: PathBuf, descriptor: FileDescriptorSet) { - for fd in &descriptor.file { - let package_name = match fd.package { - Some(ref pkg) => pkg, - None => continue, - }; - - let rust_path = out - .join(Module::from_protobuf_package_name(package_name).to_file_name_or(package_name)); - - // In some cases the generated file would be in empty. These files are no longer created by Prost, so - // we'll create here. Otherwise we append. - let mut rust_file = OpenOptions::new() - .create(true) - .append(true) - .open(rust_path) - .unwrap(); - - for msg in &fd.message_type { - let message_name = match msg.name { - Some(ref name) => name, - None => continue, - }; - - let type_url = format!("type.googleapis.com/{package_name}.{message_name}"); - - gen_trait_impl(&mut rust_file, package_name, message_name, &type_url); - } - } -} - -// This method uses the `heck` crate (the same that prost uses) to properly format the message name -// to UpperCamelCase as the prost_build::ident::{to_snake, to_upper_camel} methods -// in the `ident` module of prost_build is private. -fn gen_trait_impl(rust_file: &mut File, package_name: &str, message_name: &str, type_url: &str) { - let type_name = message_name.to_upper_camel_case(); - let type_name = format_ident!("{}", type_name); - - let tokens = quote! { - #[allow(dead_code)] - const _: () = { - use ::prost_wkt::typetag; - #[typetag::serde(name=#type_url)] - impl ::prost_wkt::MessageSerde for #type_name { - fn package_name(&self) -> &'static str { - #package_name - } - fn message_name(&self) -> &'static str { - #message_name - } - fn type_url(&self) -> &'static str { - #type_url - } - fn new_instance(&self, data: Vec) -> ::std::result::Result, ::prost::DecodeError> { - let mut target = Self::default(); - ::prost::Message::merge(&mut target, data.as_slice())?; - let erased: ::std::boxed::Box = ::std::boxed::Box::new(target); - Ok(erased) - } - fn try_encoded(&self) -> ::std::result::Result<::std::vec::Vec, ::prost::EncodeError> { - let mut buf = ::std::vec::Vec::with_capacity(::prost::Message::encoded_len(self)); - ::prost::Message::encode(self, &mut buf)?; - Ok(buf) - } - } - - ::prost_wkt::inventory::submit!{ - ::prost_wkt::MessageSerdeDecoderEntry { - type_url: #type_url, - decoder: |buf: &[u8]| { - let msg: #type_name = ::prost::Message::decode(buf)?; - Ok(::std::boxed::Box::new(msg)) - } - } - } - - impl ::prost::Name for #type_name { - const PACKAGE: &'static str = #package_name; - const NAME: &'static str = #message_name; - - fn type_url() -> String { - #type_url.to_string() - } - } - }; - }; - - writeln!(rust_file).unwrap(); - writeln!(rust_file, "{}", &tokens).unwrap(); -} diff --git a/wkt-derive/Cargo.toml b/wkt-derive/Cargo.toml new file mode 100644 index 0000000..b9f69d7 --- /dev/null +++ b/wkt-derive/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "prost-wkt-derive" +version = "0.6.0" +authors = ["fdeantoni "] +license = "Apache-2.0" +repository = "https://github.com/fdeantoni/prost-wkt" +description = "Derive traits to assist with JSON serialization and deserialization of Well Known Types." +readme = "../README.md" +documentation = "https://docs.rs/prost-wkt-derive" +edition = "2021" +rust-version = "1.70" + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0.37" +syn = "2.0.76" diff --git a/wkt-derive/src/lib.rs b/wkt-derive/src/lib.rs new file mode 100644 index 0000000..3f491ee --- /dev/null +++ b/wkt-derive/src/lib.rs @@ -0,0 +1,60 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::DeriveInput; + +#[proc_macro_derive(MessageSerde)] +pub fn message_serde_derive(input: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(input).unwrap(); + let name = &ast.ident; + let gen = quote! { + const _: () = { + const TYPE_URL: &str = ::prost_wkt::const_format::concatcp!( + "type.googleapis.com/", + <#name as ::prost::Name>::PACKAGE, + ".", + <#name as ::prost::Name>::NAME, + ); + + impl ::prost_wkt::MessageSerde for #name { + fn package_name(&self) -> &'static str { + ::PACKAGE + } + fn message_name(&self) -> &'static str { + ::NAME + } + fn type_url(&self) -> &'static str { + TYPE_URL + } + fn new_instance(&self, data: Vec) -> ::std::result::Result, ::prost::DecodeError> { + let mut target = Self::default(); + ::prost::Message::merge(&mut target, data.as_slice())?; + let erased: ::std::boxed::Box = ::std::boxed::Box::new(target); + Ok(erased) + } + fn try_encoded(&self) -> ::std::result::Result<::std::vec::Vec, ::prost::EncodeError> { + let mut buf = ::std::vec::Vec::with_capacity(::prost::Message::encoded_len(self)); + ::prost::Message::encode(self, &mut buf)?; + Ok(buf) + } + fn as_erased_serialize(&self) -> &dyn ::prost_wkt::erased_serde::Serialize { + self + } + } + + ::prost_wkt::inventory::submit!{ + ::prost_wkt::MessageSerdeDecoderEntry { + type_url: TYPE_URL, + decoder: |buf: &[u8]| { + let msg: #name = ::prost::Message::decode(buf)?; + Ok(::std::boxed::Box::new(msg)) + }, + deserializer: |de: &mut dyn ::prost_wkt::erased_serde::Deserializer| { + ::prost_wkt::erased_serde::deserialize::<#name>(de) + .map(|v| Box::new(v) as Box) + }, + } + } + }; + }; + gen.into() +} diff --git a/wkt-types/Cargo.toml b/wkt-types/Cargo.toml index 4a7b760..b0c8eef 100644 --- a/wkt-types/Cargo.toml +++ b/wkt-types/Cargo.toml @@ -27,13 +27,14 @@ prost = "0.13.1" serde = "1.0" serde_json = "1.0" serde_derive = "1.0" +serde-value = "0.7" +erased-serde = "0.4" chrono = { version = "0.4.27", default-features = false, features = ["serde"] } [build-dependencies] prost = "0.13.1" prost-types = "0.13.1" prost-build = "0.13.1" -prost-wkt-build = { version = "0.6.0", path = "../wkt-build" } regex = "1" protobuf-src = { version = "1.1.0", optional = true } protox = { version = "0.6.0", optional = true } diff --git a/wkt-types/build.rs b/wkt-types/build.rs index 687d0e5..5f1e98f 100644 --- a/wkt-types/build.rs +++ b/wkt-types/build.rs @@ -2,9 +2,6 @@ use std::env; use std::fs::create_dir_all; use std::path::{Path, PathBuf}; -use prost::Message; -use prost_types::FileDescriptorSet; - fn main() { #[cfg(feature = "vendored-protoc")] std::env::set_var("PROTOC", protobuf_src::protoc()); @@ -34,20 +31,18 @@ fn build(dir: &Path, proto: &str) { prost_build .compile_well_known_types() - .type_attribute("google.protobuf.Empty","#[derive(serde_derive::Serialize, serde_derive::Deserialize)]") - .type_attribute("google.protobuf.FieldMask","#[derive(serde_derive::Serialize, serde_derive::Deserialize)]") + .enable_type_names() + .type_attribute( + "google.protobuf.Empty", + "#[derive(serde_derive::Serialize, serde_derive::Deserialize)]", + ) + .type_attribute( + "google.protobuf.FieldMask", + "#[derive(serde_derive::Serialize, serde_derive::Deserialize)]", + ) + .message_attribute(".", "#[derive(::prost_wkt::MessageSerde)]") .file_descriptor_set_path(&descriptor_file) .out_dir(&out) - .compile_protos( - &[ - source - ], - &["proto/".to_string()], - ) + .compile_protos(&[source], &["proto/".to_string()]) .unwrap(); - - let descriptor_bytes = std::fs::read(descriptor_file).unwrap(); - let descriptor = FileDescriptorSet::decode(&descriptor_bytes[..]).unwrap(); - - prost_wkt_build::add_serde(out, descriptor); } diff --git a/wkt-types/src/pbany.rs b/wkt-types/src/pbany.rs index 1d278e8..e938192 100644 --- a/wkt-types/src/pbany.rs +++ b/wkt-types/src/pbany.rs @@ -1,12 +1,13 @@ use prost_wkt::MessageSerde; -use serde::de::{Deserialize, Deserializer}; +use serde::de::{Deserialize, Deserializer, Visitor}; use serde::ser::{Serialize, SerializeStruct, Serializer}; include!(concat!(env!("OUT_DIR"), "/pbany/google.protobuf.rs")); -use prost::{DecodeError, Message, EncodeError, Name}; +use prost::{DecodeError, EncodeError, Message, Name}; +use serde_value::ValueDeserializer; -use std::borrow::Cow; +use std::{borrow::Cow, fmt}; #[derive(Clone, Debug, PartialEq, Eq)] pub struct AnyError { @@ -30,8 +31,8 @@ impl std::error::Error for AnyError { } } -impl std::fmt::Display for AnyError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for AnyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("failed to convert Value: ")?; f.write_str(&self.description) } @@ -144,7 +145,6 @@ impl Any { err.push("unexpected type URL", "type_url"); Err(err) } - } impl Serialize for Any { @@ -152,15 +152,64 @@ impl Serialize for Any { where S: Serializer, { + let mut state = serializer.serialize_struct("Any", 3)?; + state.serialize_field("@type", &self.type_url)?; match self.clone().try_unpack() { - Ok(result) => serde::ser::Serialize::serialize(result.as_ref(), serializer), + Ok(result) => { + state.serialize_field("value", result.as_erased_serialize())?; + } Err(_) => { - let mut state = serializer.serialize_struct("Any", 3)?; - state.serialize_field("@type", &self.type_url)?; state.serialize_field("value", &self.value)?; - state.end() } } + state.end() + } +} + +struct AnyVisitor; + +impl<'de> Visitor<'de> for AnyVisitor { + type Value = Box; + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("type.googleapis.com/google.protobuf.any") + } + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut cached_type_url: Option = None; + let mut cached_value: Option = None; + while let Some(key) = map.next_key::()? { + match &*key { + "@type" => { + if cached_type_url.is_some() { + return Err(serde::de::Error::duplicate_field("@type")); + } + cached_type_url.replace(map.next_value()?); + } + "value" => { + if cached_value.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + cached_value.replace(map.next_value()?); + } + _ => return Err(serde::de::Error::unknown_field(&key, &["@type", "value"])), + }; + } + let type_url = cached_type_url.ok_or_else(|| serde::de::Error::missing_field("@type"))?; + let raw_value = cached_value.ok_or_else(|| serde::de::Error::missing_field("value"))?; + let entry = ::prost_wkt::inventory::iter::<::prost_wkt::MessageSerdeDecoderEntry> + .into_iter() + .find(|entry| type_url == entry.type_url) + .ok_or_else(|| { + serde::de::Error::invalid_type( + serde::de::Unexpected::Str(&type_url), + &"a typeurl registered by deriving SerdeMessage" as &dyn serde::de::Expected, + ) + })?; + let mut deserializer = + ::erase(ValueDeserializer::::new(raw_value)); + (entry.deserializer)(&mut deserializer).map_err(|err| serde::de::Error::custom(err)) } } @@ -169,8 +218,7 @@ impl<'de> Deserialize<'de> for Any { where D: Deserializer<'de>, { - let erased: Box = - serde::de::Deserialize::deserialize(deserializer)?; + let erased = deserializer.deserialize_struct("Any", &["@type", "value"], AnyVisitor)?; let type_url = erased.type_url().to_string(); let value = erased.try_encoded().map_err(|err| { serde::de::Error::custom(format!("Failed to encode message: {err:?}")) @@ -219,12 +267,12 @@ impl<'a> TypeUrl<'a> { #[cfg(test)] mod tests { use crate::pbany::*; - use prost::{DecodeError, EncodeError, Message}; - use prost_wkt::*; - use serde::*; + use serde_derive::*; use serde_json::json; - #[derive(Clone, Eq, PartialEq, ::prost::Message, Serialize, Deserialize)] + #[derive( + Clone, Eq, PartialEq, ::prost::Message, Serialize, Deserialize, ::prost_wkt::MessageSerde, + )] #[serde(default, rename_all = "camelCase")] pub struct Foo { #[prost(string, tag = "1")] @@ -236,33 +284,6 @@ mod tests { const PACKAGE: &'static str = "any.test"; } - #[typetag::serde(name = "type.googleapis.com/any.test.Foo")] - impl prost_wkt::MessageSerde for Foo { - fn message_name(&self) -> &'static str { - "Foo" - } - - fn package_name(&self) -> &'static str { - "any.test" - } - - fn type_url(&self) -> &'static str { - "type.googleapis.com/any.test.Foo" - } - fn new_instance(&self, data: Vec) -> Result, DecodeError> { - let mut target = Self::default(); - Message::merge(&mut target, data.as_slice())?; - let erased: Box = Box::new(target); - Ok(erased) - } - - fn try_encoded(&self) -> Result, EncodeError> { - let mut buf = Vec::with_capacity(Message::encoded_len(self)); - Message::encode(self, &mut buf)?; - Ok(buf) - } - } - #[test] fn pack_unpack_test() { let msg = Foo { @@ -294,10 +315,10 @@ mod tests { "@type": type_url, "value": {} }); - let erased: Box = serde_json::from_value(data).unwrap(); - let foo: &Foo = erased.downcast_ref::().unwrap(); + let any: Any = serde_json::from_value(data).unwrap(); + let foo: Foo = any.to_msg().unwrap(); println!("Deserialize default: {foo:?}"); - assert_eq!(foo, &Foo::default()) + assert_eq!(foo, Foo::default()) } #[test] diff --git a/wkt-types/tests/pbany_test.rs b/wkt-types/tests/pbany_test.rs index af4f92e..e5d59be 100644 --- a/wkt-types/tests/pbany_test.rs +++ b/wkt-types/tests/pbany_test.rs @@ -1,10 +1,10 @@ -use prost::{DecodeError, EncodeError, Message, Name}; +use prost::Name; use prost_wkt::*; use prost_wkt_types::*; -use serde::{Deserialize, Serialize}; +use serde_derive::{Deserialize, Serialize}; use std::collections::HashMap; -#[derive(Clone, PartialEq, ::prost::Message, Serialize, Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message, Serialize, Deserialize, MessageSerde)] #[prost(package = "any.test")] #[serde(rename_all = "camelCase")] pub struct Foo { @@ -31,44 +31,6 @@ impl Name for Foo { } } -#[typetag::serde(name = "type.googleapis.com/any.test.Foo")] -impl prost_wkt::MessageSerde for Foo { - fn message_name(&self) -> &'static str { - "Foo" - } - - fn package_name(&self) -> &'static str { - "any.test" - } - - fn type_url(&self) -> &'static str { - "type.googleapis.com/any.test.Foo" - } - - fn new_instance(&self, data: Vec) -> Result, DecodeError> { - let mut target = Self::default(); - Message::merge(&mut target, data.as_slice())?; - let erased: Box = Box::new(target); - Ok(erased) - } - - fn try_encoded(&self) -> Result, EncodeError> { - let mut buf = Vec::with_capacity(Message::encoded_len(self)); - Message::encode(self, &mut buf)?; - Ok(buf) - } -} - -::prost_wkt::inventory::submit! { - ::prost_wkt::MessageSerdeDecoderEntry { - type_url: "type.googleapis.com/any.test.Foo", - decoder: |buf: &[u8]| { - let msg: Foo = ::prost::Message::decode(buf)?; - Ok(Box::new(msg)) - } - } -} - fn create_struct() -> Value { let number: Value = Value::from(10.0); let null: Value = Value::null(); @@ -109,8 +71,8 @@ fn test_any_serialization() { "Serialized to string: {}", serde_json::to_string_pretty(&msg).unwrap() ); - let erased = &msg as &dyn MessageSerde; - let json = serde_json::to_string(erased).unwrap(); + let any = Any::from_msg(&msg).unwrap(); + let json = serde_json::to_string(&any).unwrap(); println!("Erased json: {json}"); } From 2f29c346d35c36a57c1d77a611bd23493794ae75 Mon Sep 17 00:00:00 2001 From: William Date: Fri, 30 Aug 2024 19:31:35 -0700 Subject: [PATCH 2/7] handle omitted url --- src/lib.rs | 27 +++++++---- wkt-derive/src/lib.rs | 35 +------------- wkt-types/Cargo.toml | 1 - wkt-types/src/pbany.rs | 105 +++++++++++++++++++++++++++++------------ 4 files changed, 96 insertions(+), 72 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b05c998..00f732f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,14 +6,8 @@ pub use prost_wkt_derive::MessageSerde; /// Trait to support serialization and deserialization of `prost` messages. pub trait MessageSerde: prost::Message + std::any::Any + erased_serde::Serialize { - /// message name as in proto file - fn message_name(&self) -> &'static str; - /// package name as in proto file - fn package_name(&self) -> &'static str; /// the message proto type url e.g. type.googleapis.com/my.package.MyMessage - fn type_url(&self) -> &'static str; - /// Creates a new instance of this message using the protobuf encoded data - fn new_instance(&self, data: Vec) -> Result, prost::DecodeError>; + fn type_url(&self) -> String; /// Returns the encoded protobuf message as bytes fn try_encoded(&self) -> Result, prost::EncodeError>; /// Returns an erased serialize dynamic reference @@ -88,14 +82,31 @@ impl dyn MessageSerde { } } +type MessageSerdeTypeUrlFn = fn() -> String; type MessageSerdeDecoderFn = fn(&[u8]) -> Result, ::prost::DecodeError>; type MessageSerdeDeserializerFn = fn(&mut dyn erased_serde::Deserializer) -> Result, erased_serde::Error>; pub struct MessageSerdeDecoderEntry { - pub type_url: &'static str, + pub type_url: MessageSerdeTypeUrlFn, pub decoder: MessageSerdeDecoderFn, pub deserializer: MessageSerdeDeserializerFn, } +impl MessageSerdeDecoderEntry { + pub const fn new() -> Self + where + for<'a> M: MessageSerde + prost::Name + Default + serde::Deserialize<'a>, + { + Self { + type_url: ::type_url, + decoder: |buf| { + let msg: M = prost::Message::decode(buf)?; + Ok(Box::new(msg)) + }, + deserializer: |de| erased_serde::deserialize::(de).map(|v| Box::new(v) as _), + } + } +} + inventory::collect!(MessageSerdeDecoderEntry); diff --git a/wkt-derive/src/lib.rs b/wkt-derive/src/lib.rs index 3f491ee..8fa7482 100644 --- a/wkt-derive/src/lib.rs +++ b/wkt-derive/src/lib.rs @@ -8,29 +8,8 @@ pub fn message_serde_derive(input: TokenStream) -> TokenStream { let name = &ast.ident; let gen = quote! { const _: () = { - const TYPE_URL: &str = ::prost_wkt::const_format::concatcp!( - "type.googleapis.com/", - <#name as ::prost::Name>::PACKAGE, - ".", - <#name as ::prost::Name>::NAME, - ); - impl ::prost_wkt::MessageSerde for #name { - fn package_name(&self) -> &'static str { - ::PACKAGE - } - fn message_name(&self) -> &'static str { - ::NAME - } - fn type_url(&self) -> &'static str { - TYPE_URL - } - fn new_instance(&self, data: Vec) -> ::std::result::Result, ::prost::DecodeError> { - let mut target = Self::default(); - ::prost::Message::merge(&mut target, data.as_slice())?; - let erased: ::std::boxed::Box = ::std::boxed::Box::new(target); - Ok(erased) - } + fn type_url(&self) -> String { <#name as ::prost::Name>::type_url() } fn try_encoded(&self) -> ::std::result::Result<::std::vec::Vec, ::prost::EncodeError> { let mut buf = ::std::vec::Vec::with_capacity(::prost::Message::encoded_len(self)); ::prost::Message::encode(self, &mut buf)?; @@ -42,17 +21,7 @@ pub fn message_serde_derive(input: TokenStream) -> TokenStream { } ::prost_wkt::inventory::submit!{ - ::prost_wkt::MessageSerdeDecoderEntry { - type_url: TYPE_URL, - decoder: |buf: &[u8]| { - let msg: #name = ::prost::Message::decode(buf)?; - Ok(::std::boxed::Box::new(msg)) - }, - deserializer: |de: &mut dyn ::prost_wkt::erased_serde::Deserializer| { - ::prost_wkt::erased_serde::deserialize::<#name>(de) - .map(|v| Box::new(v) as Box) - }, - } + ::prost_wkt::MessageSerdeDecoderEntry::new::<#name>() } }; }; diff --git a/wkt-types/Cargo.toml b/wkt-types/Cargo.toml index b0c8eef..8993103 100644 --- a/wkt-types/Cargo.toml +++ b/wkt-types/Cargo.toml @@ -28,7 +28,6 @@ serde = "1.0" serde_json = "1.0" serde_derive = "1.0" serde-value = "0.7" -erased-serde = "0.4" chrono = { version = "0.4.27", default-features = false, features = ["serde"] } [build-dependencies] diff --git a/wkt-types/src/pbany.rs b/wkt-types/src/pbany.rs index e938192..3b806bb 100644 --- a/wkt-types/src/pbany.rs +++ b/wkt-types/src/pbany.rs @@ -1,4 +1,4 @@ -use prost_wkt::MessageSerde; +use prost_wkt::{MessageSerde, MessageSerdeDecoderEntry}; use serde::de::{Deserialize, Deserializer, Visitor}; use serde::ser::{Serialize, SerializeStruct, Serializer}; @@ -59,7 +59,10 @@ impl Any { where T: Message + MessageSerde + Default, { - let type_url = MessageSerde::type_url(&message).to_string(); + let original_type_url = message.type_url(); + let type_url = TypeUrl::new(&message.type_url()) + .map(|s| s.to_string()) + .unwrap_or(original_type_url); // Serialize the message into a value let mut buf = Vec::with_capacity(message.encoded_len()); message.encode(&mut buf)?; @@ -88,10 +91,7 @@ impl Any { /// let back: Box = any.try_unpack()?; /// ``` pub fn try_unpack(self) -> Result, AnyError> { - ::prost_wkt::inventory::iter::<::prost_wkt::MessageSerdeDecoderEntry> - .into_iter() - .find(|entry| self.type_url == entry.type_url) - .ok_or_else(|| format!("Failed to deserialize {}. Make sure prost-wkt-build is executed.", self.type_url)) + find_entry(&self.type_url).ok_or_else(|| format!("Failed to deserialize {}. Make sure prost-wkt-build is executed.", self.type_url)) .and_then(|entry| { (entry.decoder)(&self.value).map_err(|error| { format!( @@ -110,7 +110,9 @@ impl Any { where M: Name, { - let type_url = M::type_url(); + let type_url = TypeUrl::new(&M::type_url()) + .map(|s| s.to_string()) + .unwrap_or_else(M::type_url); let mut value = Vec::new(); Message::encode(msg, &mut value)?; Ok(Any { type_url, value }) @@ -153,7 +155,11 @@ impl Serialize for Any { S: Serializer, { let mut state = serializer.serialize_struct("Any", 3)?; - state.serialize_field("@type", &self.type_url)?; + if let Some(type_url) = TypeUrl::new(&self.type_url) { + state.serialize_field("@type", &type_url)?; + } else { + state.serialize_field("@type", &self.type_url)?; + } match self.clone().try_unpack() { Ok(result) => { state.serialize_field("value", result.as_erased_serialize())?; @@ -169,7 +175,7 @@ impl Serialize for Any { struct AnyVisitor; impl<'de> Visitor<'de> for AnyVisitor { - type Value = Box; + type Value = (String, Box); fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("type.googleapis.com/google.protobuf.any") } @@ -198,18 +204,19 @@ impl<'de> Visitor<'de> for AnyVisitor { } let type_url = cached_type_url.ok_or_else(|| serde::de::Error::missing_field("@type"))?; let raw_value = cached_value.ok_or_else(|| serde::de::Error::missing_field("value"))?; - let entry = ::prost_wkt::inventory::iter::<::prost_wkt::MessageSerdeDecoderEntry> - .into_iter() - .find(|entry| type_url == entry.type_url) - .ok_or_else(|| { - serde::de::Error::invalid_type( - serde::de::Unexpected::Str(&type_url), - &"a typeurl registered by deriving SerdeMessage" as &dyn serde::de::Expected, - ) - })?; - let mut deserializer = - ::erase(ValueDeserializer::::new(raw_value)); - (entry.deserializer)(&mut deserializer).map_err(|err| serde::de::Error::custom(err)) + let entry = find_entry(&type_url).ok_or_else(|| { + serde::de::Error::invalid_type( + serde::de::Unexpected::Str(&type_url), + &"a typeurl registered by deriving SerdeMessage" as &dyn serde::de::Expected, + ) + })?; + let mut deserializer = ::erase( + ValueDeserializer::::new(raw_value), + ); + Ok(( + type_url, + (entry.deserializer)(&mut deserializer).map_err(|err| serde::de::Error::custom(err))?, + )) } } @@ -218,8 +225,8 @@ impl<'de> Deserialize<'de> for Any { where D: Deserializer<'de>, { - let erased = deserializer.deserialize_struct("Any", &["@type", "value"], AnyVisitor)?; - let type_url = erased.type_url().to_string(); + let (type_url, erased) = + deserializer.deserialize_struct("Any", &["@type", "value"], AnyVisitor)?; let value = erased.try_encoded().map_err(|err| { serde::de::Error::custom(format!("Failed to encode message: {err:?}")) })?; @@ -242,6 +249,8 @@ impl<'de> Deserialize<'de> for Any { /// specific semantics. #[derive(Debug, Eq, PartialEq)] struct TypeUrl<'a> { + /// The type's base url + base_url: &'a str, /// Fully qualified name of the type, e.g. `google.protobuf.Duration` full_name: &'a str, } @@ -249,21 +258,57 @@ struct TypeUrl<'a> { impl<'a> TypeUrl<'a> { fn new(s: &'a str) -> core::option::Option { // Must contain at least one "/" character. - let slash_pos = s.rfind('/')?; - // The last segment of the URL's path must represent the fully qualified name // of the type (as in `path/google.protobuf.Duration`) - let full_name = s.get((slash_pos + 1)..)?; + let (base_url, full_name) = s.rsplit_once("/")?; // The name should be in a canonical form (e.g., leading "." is not accepted). if full_name.starts_with('.') { return None; } - Some(Self { full_name }) + // Make the base url explicit + let base_url = if base_url == "" { + "type.googleapis.com" + } else { + base_url + }; + + Some(Self { + base_url, + full_name, + }) + } +} + +impl<'a> std::fmt::Display for TypeUrl<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}/{}", self.base_url, self.full_name) } } +impl<'a> Serialize for TypeUrl<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +fn find_entry(type_url: &str) -> Option<&'static MessageSerdeDecoderEntry> { + let to_search = TypeUrl::new(type_url)?; + prost_wkt::inventory::iter:: + .into_iter() + .find(|entry| { + let raw_entry_type_url = (entry.type_url)(); + let Some(entry_type_url) = TypeUrl::new(&raw_entry_type_url) else { + return false; + }; + entry_type_url == to_search + }) +} + #[cfg(test)] mod tests { use crate::pbany::*; @@ -285,13 +330,13 @@ mod tests { } #[test] - fn pack_unpack_test() { + fn to_from_msg_test() { let msg = Foo { string: "Hello World!".to_string(), }; - let any = Any::try_pack(msg.clone()).unwrap(); + let any = Any::from_msg(&msg).unwrap(); println!("{any:?}"); - let unpacked = any.unpack_as(Foo::default()).unwrap(); + let unpacked: Foo = any.to_msg().unwrap(); println!("{unpacked:?}"); assert_eq!(unpacked, msg) } From adc76b3fb6019d363df9c6d86c77c9da01a8748e Mon Sep 17 00:00:00 2001 From: William Date: Sat, 31 Aug 2024 10:33:34 -0700 Subject: [PATCH 3/7] don't handle empty domain, user should use prost build config --- wkt-types/build.rs | 1 + wkt-types/src/pbany.rs | 10 +++------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/wkt-types/build.rs b/wkt-types/build.rs index 5f1e98f..59004a4 100644 --- a/wkt-types/build.rs +++ b/wkt-types/build.rs @@ -32,6 +32,7 @@ fn build(dir: &Path, proto: &str) { prost_build .compile_well_known_types() .enable_type_names() + .type_name_domain(&["."], "type.googleapis.com") .type_attribute( "google.protobuf.Empty", "#[derive(serde_derive::Serialize, serde_derive::Deserialize)]", diff --git a/wkt-types/src/pbany.rs b/wkt-types/src/pbany.rs index 3b806bb..a2b6c84 100644 --- a/wkt-types/src/pbany.rs +++ b/wkt-types/src/pbany.rs @@ -267,13 +267,6 @@ impl<'a> TypeUrl<'a> { return None; } - // Make the base url explicit - let base_url = if base_url == "" { - "type.googleapis.com" - } else { - base_url - }; - Some(Self { base_url, full_name, @@ -327,6 +320,9 @@ mod tests { impl Name for Foo { const NAME: &'static str = "Foo"; const PACKAGE: &'static str = "any.test"; + fn type_url() -> String { + format!("type.googleapis.com/{}.{}", Self::PACKAGE, Self::NAME) + } } #[test] From 8e10a722d6a30f75d77a2980fbd5e016fcdce45a Mon Sep 17 00:00:00 2001 From: William Date: Sat, 31 Aug 2024 10:51:53 -0700 Subject: [PATCH 4/7] add type name domain in example --- example/build.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/example/build.rs b/example/build.rs index 9564826..3f27fd0 100644 --- a/example/build.rs +++ b/example/build.rs @@ -6,6 +6,7 @@ fn main() { let mut prost_build = prost_build::Config::new(); prost_build .enable_type_names() + .type_name_domain(&[".my.requests", ".my.messages"], "type.googleapis.com") .type_attribute( ".my.requests", "#[derive(serde::Serialize, serde::Deserialize, ::prost_wkt::MessageSerde)] #[serde(default, rename_all=\"camelCase\")]", From 3820e776f9e60e998ee0e3f5dcc3061945559f11 Mon Sep 17 00:00:00 2001 From: William Date: Thu, 12 Sep 2024 08:08:45 -0700 Subject: [PATCH 5/7] use function instead of lambda for the deserializer callback --- src/lib.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 00f732f..69ab520 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,9 +104,18 @@ impl MessageSerdeDecoderEntry { let msg: M = prost::Message::decode(buf)?; Ok(Box::new(msg)) }, - deserializer: |de| erased_serde::deserialize::(de).map(|v| Box::new(v) as _), + deserializer: deserialize_boxed::, } } } +fn deserialize_boxed( + de: &mut dyn erased_serde::Deserializer, +) -> Result, erased_serde::Error> +where + for<'a> M: MessageSerde + serde::Deserialize<'a>, +{ + erased_serde::deserialize::(de).map(|v| Box::new(v) as _) +} + inventory::collect!(MessageSerdeDecoderEntry); From 556822fda7f2af13ad89399710a1c216f537dc5e Mon Sep 17 00:00:00 2001 From: William Date: Thu, 12 Sep 2024 08:23:50 -0700 Subject: [PATCH 6/7] change error message in try_unpack when entry not found --- example/src/main.rs | 5 +++-- wkt-types/src/lib.rs | 1 - wkt-types/src/pbany.rs | 2 +- wkt-types/src/pbempty.rs | 3 ++- wkt-types/src/pbstruct.rs | 26 +++++++++----------------- wkt-types/src/pbtime/datetime.rs | 1 - wkt-types/src/pbtime/duration.rs | 8 +++----- wkt-types/src/pbtime/mod.rs | 12 ++++++++---- wkt-types/src/pbtime/timestamp.rs | 1 - wkt-types/tests/pbstruct_test.rs | 5 ++--- 10 files changed, 28 insertions(+), 36 deletions(-) diff --git a/example/src/main.rs b/example/src/main.rs index a29d30d..8c300d2 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -6,8 +6,9 @@ include!(concat!(env!("OUT_DIR"), "/my.messages.rs")); include!(concat!(env!("OUT_DIR"), "/my.requests.rs")); fn main() -> Result<(), AnyError> { - - let content: Content = Content { body: Some(content::Body::SomeBool(true)) }; + let content: Content = Content { + body: Some(content::Body::SomeBool(true)), + }; let foo_msg: Foo = Foo { data: "Hello World".to_string(), diff --git a/wkt-types/src/lib.rs b/wkt-types/src/lib.rs index 1d02c1c..940a3d4 100644 --- a/wkt-types/src/lib.rs +++ b/wkt-types/src/lib.rs @@ -16,4 +16,3 @@ mod pbmask; pub use crate::pbmask::*; pub use prost_wkt::MessageSerde; - diff --git a/wkt-types/src/pbany.rs b/wkt-types/src/pbany.rs index a2b6c84..59dbedf 100644 --- a/wkt-types/src/pbany.rs +++ b/wkt-types/src/pbany.rs @@ -91,7 +91,7 @@ impl Any { /// let back: Box = any.try_unpack()?; /// ``` pub fn try_unpack(self) -> Result, AnyError> { - find_entry(&self.type_url).ok_or_else(|| format!("Failed to deserialize {}. Make sure prost-wkt-build is executed.", self.type_url)) + find_entry(&self.type_url).ok_or_else(|| format!("Failed to deserialize {}. Make sure the MessageSerde trait is derived through its proc macro.", self.type_url)) .and_then(|entry| { (entry.decoder)(&self.value).map_err(|error| { format!( diff --git a/wkt-types/src/pbempty.rs b/wkt-types/src/pbempty.rs index 659d09e..12ab968 100644 --- a/wkt-types/src/pbempty.rs +++ b/wkt-types/src/pbempty.rs @@ -24,7 +24,8 @@ mod tests { #[test] fn deserialize_empty() { - let msg: Empty = serde_json::from_str("{}").expect("Could not deserialize `{}` to an Empty struct!"); + let msg: Empty = + serde_json::from_str("{}").expect("Could not deserialize `{}` to an Empty struct!"); assert_eq!(msg, EMPTY); } diff --git a/wkt-types/src/pbstruct.rs b/wkt-types/src/pbstruct.rs index 4b53097..26b8e82 100644 --- a/wkt-types/src/pbstruct.rs +++ b/wkt-types/src/pbstruct.rs @@ -201,7 +201,7 @@ impl Serialize for Struct { { let mut map = serializer.serialize_map(Some(self.fields.len()))?; for (k, v) in &self.fields { - map.serialize_entry( k, v)?; + map.serialize_entry(k, v)?; } map.end() } @@ -217,12 +217,8 @@ impl Serialize for Value { Some(value::Kind::StringValue(string)) => serializer.serialize_str(string), Some(value::Kind::BoolValue(boolean)) => serializer.serialize_bool(*boolean), Some(value::Kind::NullValue(_)) => serializer.serialize_none(), - Some(value::Kind::ListValue(list)) => { - list.serialize(serializer) - } - Some(value::Kind::StructValue(object)) => { - object.serialize(serializer) - } + Some(value::Kind::ListValue(list)) => list.serialize(serializer), + Some(value::Kind::StructValue(object)) => object.serialize(serializer), _ => serializer.serialize_none(), } } @@ -244,9 +240,7 @@ impl<'de> Visitor<'de> for ListValueVisitor { while let Some(el) = seq.next_element()? { values.push(el) } - Ok(ListValue { - values - }) + Ok(ListValue { values }) } } @@ -271,21 +265,19 @@ impl<'de> Visitor<'de> for StructVisitor { where A: MapAccess<'de>, { - let mut fields: std::collections::HashMap = - std::collections::HashMap::new(); + let mut fields: std::collections::HashMap = std::collections::HashMap::new(); while let Some((key, value)) = map.next_entry::()? { fields.insert(key, value); } - Ok(Struct { - fields - }) + Ok(Struct { fields }) } } impl<'de> Deserialize<'de> for Struct { fn deserialize(deserializer: D) -> Result>::Error> - where - D: Deserializer<'de> { + where + D: Deserializer<'de>, + { deserializer.deserialize_map(StructVisitor) } } diff --git a/wkt-types/src/pbtime/datetime.rs b/wkt-types/src/pbtime/datetime.rs index b2f1c07..369cb89 100644 --- a/wkt-types/src/pbtime/datetime.rs +++ b/wkt-types/src/pbtime/datetime.rs @@ -4,7 +4,6 @@ //////////////////////////////////////////////////////////////////////////////// /// FROM prost-types/src/datetime.rs //////////////////////////////////////////////////////////////////////////////// - use core::fmt; use crate::Duration; diff --git a/wkt-types/src/pbtime/duration.rs b/wkt-types/src/pbtime/duration.rs index fcad5ec..c5df49d 100644 --- a/wkt-types/src/pbtime/duration.rs +++ b/wkt-types/src/pbtime/duration.rs @@ -190,9 +190,10 @@ impl From for chrono::Duration { // A call to `normalize` should capture all out-of-bound sitations hopefully // ensuring a panic never happens! Ideally this implementation should be // deprecated in favour of TryFrom but unfortunately having `TryFrom` along with - // `From` causes a conflict. + // `From` causes a conflict. value.normalize(); - let s = chrono::TimeDelta::try_seconds(value.seconds).expect("invalid or out-of-range seconds"); + let s = + chrono::TimeDelta::try_seconds(value.seconds).expect("invalid or out-of-range seconds"); let ns = chrono::Duration::nanoseconds(value.nanos as i64); s + ns } @@ -276,6 +277,3 @@ impl<'de> Deserialize<'de> for Duration { deserializer.deserialize_str(DurationVisitor) } } - - - diff --git a/wkt-types/src/pbtime/mod.rs b/wkt-types/src/pbtime/mod.rs index 19266c7..e58df0d 100644 --- a/wkt-types/src/pbtime/mod.rs +++ b/wkt-types/src/pbtime/mod.rs @@ -8,8 +8,8 @@ pub use timestamp::TimestampError; use core::convert::TryFrom; use core::str::FromStr; -use core::*; use core::time; +use core::*; use std::convert::TryInto; use chrono::prelude::*; @@ -26,7 +26,6 @@ include!(concat!(env!("OUT_DIR"), "/pbtime/google.protobuf.rs")); const NANOS_PER_SECOND: i32 = 1_000_000_000; const NANOS_MAX: i32 = NANOS_PER_SECOND - 1; - #[cfg(test)] mod tests { @@ -89,12 +88,17 @@ mod tests { }; let chrono_duration: chrono::Duration = duration.into(); assert_eq!(chrono_duration.num_seconds(), 10); - assert_eq!((chrono_duration - chrono::Duration::try_seconds(10).expect("seconds")).num_nanoseconds(), Some(100)); + assert_eq!( + (chrono_duration - chrono::Duration::try_seconds(10).expect("seconds")) + .num_nanoseconds(), + Some(100) + ); } #[test] fn test_duration_conversion_chrono_to_pb() { - let chrono_duration = chrono::Duration::try_seconds(10).expect("seconds") + chrono::Duration::nanoseconds(100); + let chrono_duration = chrono::Duration::try_seconds(10).expect("seconds") + + chrono::Duration::nanoseconds(100); let duration: Duration = chrono_duration.into(); assert_eq!(duration.seconds, 10); assert_eq!(duration.nanos, 100); diff --git a/wkt-types/src/pbtime/timestamp.rs b/wkt-types/src/pbtime/timestamp.rs index 471a205..6e66724 100644 --- a/wkt-types/src/pbtime/timestamp.rs +++ b/wkt-types/src/pbtime/timestamp.rs @@ -320,4 +320,3 @@ impl<'de> Deserialize<'de> for Timestamp { deserializer.deserialize_str(TimestampVisitor) } } - diff --git a/wkt-types/tests/pbstruct_test.rs b/wkt-types/tests/pbstruct_test.rs index 6f59808..add9ef6 100644 --- a/wkt-types/tests/pbstruct_test.rs +++ b/wkt-types/tests/pbstruct_test.rs @@ -30,7 +30,7 @@ fn test_flatten_struct() { let mut fields: HashMap = HashMap::new(); fields.insert("test".to_string(), create_struct()); let strct = Struct { - fields: fields.clone() + fields: fields.clone(), }; let string_strct = serde_json::to_string_pretty(&strct).expect("Serialized struct"); println!("{string_strct}"); @@ -46,7 +46,7 @@ fn test_flatten_struct() { fn test_flatten_list() { let values: Vec = vec![Value::null(), Value::from(20.0), Value::from(true)]; let list: ListValue = ListValue { - values: values.clone() + values: values.clone(), }; let string_list = serde_json::to_string_pretty(&list).expect("Serialized list"); println!("{string_list}"); @@ -56,5 +56,4 @@ fn test_flatten_list() { println!("{string}"); assert_eq!(string_list, string); - } From dc3c1ae2ae57f854d745412ef47981ca273c7ba9 Mon Sep 17 00:00:00 2001 From: William Date: Thu, 12 Sep 2024 08:39:08 -0700 Subject: [PATCH 7/7] cleanup typeurl usage --- wkt-types/src/pbany.rs | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/wkt-types/src/pbany.rs b/wkt-types/src/pbany.rs index 59dbedf..f18b8e8 100644 --- a/wkt-types/src/pbany.rs +++ b/wkt-types/src/pbany.rs @@ -59,10 +59,7 @@ impl Any { where T: Message + MessageSerde + Default, { - let original_type_url = message.type_url(); - let type_url = TypeUrl::new(&message.type_url()) - .map(|s| s.to_string()) - .unwrap_or(original_type_url); + let type_url = message.type_url(); // Serialize the message into a value let mut buf = Vec::with_capacity(message.encoded_len()); message.encode(&mut buf)?; @@ -110,9 +107,7 @@ impl Any { where M: Name, { - let type_url = TypeUrl::new(&M::type_url()) - .map(|s| s.to_string()) - .unwrap_or_else(M::type_url); + let type_url = M::type_url(); let mut value = Vec::new(); Message::encode(msg, &mut value)?; Ok(Any { type_url, value }) @@ -155,11 +150,7 @@ impl Serialize for Any { S: Serializer, { let mut state = serializer.serialize_struct("Any", 3)?; - if let Some(type_url) = TypeUrl::new(&self.type_url) { - state.serialize_field("@type", &type_url)?; - } else { - state.serialize_field("@type", &self.type_url)?; - } + state.serialize_field("@type", &self.type_url)?; match self.clone().try_unpack() { Ok(result) => { state.serialize_field("value", result.as_erased_serialize())?; @@ -289,17 +280,10 @@ impl<'a> Serialize for TypeUrl<'a> { } } -fn find_entry(type_url: &str) -> Option<&'static MessageSerdeDecoderEntry> { - let to_search = TypeUrl::new(type_url)?; +fn find_entry(to_search: &str) -> Option<&'static MessageSerdeDecoderEntry> { prost_wkt::inventory::iter:: .into_iter() - .find(|entry| { - let raw_entry_type_url = (entry.type_url)(); - let Some(entry_type_url) = TypeUrl::new(&raw_entry_type_url) else { - return false; - }; - entry_type_url == to_search - }) + .find(|entry| (entry.type_url)() == to_search) } #[cfg(test)]