From 50340018db447d71383fa3e82b7be91bcfdfa939 Mon Sep 17 00:00:00 2001 From: Casper Meijn Date: Thu, 29 Aug 2024 11:10:01 +0200 Subject: [PATCH] feat: derive Eq and Hash trait for messages where possible Integer and bytes types can be compared using trait Eq. Some generated Rust structs can also have this property by deriving the Eq trait. Automatically derive Eq and Hash for: - messages that only have fields with integer or bytes types - messages where all field types also implement Eq and Hash - the Rust enum for one-of fields, where all fields implement Eq and Hash Generated code for Protobuf enums already derives Eq and Hash. BREAKING CHANGE: `prost-build` will automatically derive `trait Eq` and `trait Hash` for types where all field support those as well. If you manually `impl Eq` and/or `impl Hash` for generated types, then you need to remove the manual implementation. If you use `type_attribute` to `derive(Eq)` and/or `derive(Hash)`, then you need to remove those. --- prost-build/src/code_generator.rs | 18 +++++++- .../_expected_field_attributes.rs | 10 ++--- .../_expected_field_attributes_formatted.rs | 10 ++--- .../helloworld/_expected_helloworld.rs | 4 +- .../_expected_helloworld_formatted.rs | 4 +- prost-build/src/message_graph.rs | 41 +++++++++++++++++++ prost-types/src/compiler.rs | 2 +- prost-types/src/duration.rs | 8 ---- prost-types/src/protobuf.rs | 24 +++++------ prost-types/src/timestamp.rs | 13 ------ tests/build.rs | 2 +- tests/single-include/src/outdir/outdir.rs | 2 +- 12 files changed, 86 insertions(+), 52 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 24b5194e5..22eb5ef20 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -229,12 +229,17 @@ impl<'a> CodeGenerator<'a> { self.append_message_attributes(&fq_message_name); self.push_indent(); self.buf.push_str(&format!( - "#[derive(Clone, {}PartialEq, {}::Message)]\n", + "#[derive(Clone, {}PartialEq, {}{}::Message)]\n", if self.message_graph.can_message_derive_copy(&fq_message_name) { "Copy, " } else { "" }, + if self.message_graph.can_message_derive_eq(&fq_message_name) { + "Eq, Hash, " + } else { + "" + }, prost_path(self.config) )); self.append_skip_debug(&fq_message_name); @@ -619,9 +624,18 @@ impl<'a> CodeGenerator<'a> { self.message_graph .can_field_derive_copy(fq_message_name, &field.descriptor) }); + let can_oneof_derive_eq = oneof.fields.iter().all(|field| { + self.message_graph + .can_field_derive_eq(fq_message_name, &field.descriptor) + }); self.buf.push_str(&format!( - "#[derive(Clone, {}PartialEq, {}::Oneof)]\n", + "#[derive(Clone, {}PartialEq, {}{}::Oneof)]\n", if can_oneof_derive_copy { "Copy, " } else { "" }, + if can_oneof_derive_eq { + "Eq, Hash, " + } else { + "" + }, prost_path(self.config) )); self.append_skip_debug(fq_message_name); diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index bf1e8c517..509e96bbe 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -1,12 +1,12 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Container { #[prost(oneof="container::Data", tags="1, 2")] pub data: ::core::option::Option, } /// Nested message and enum types in `Container`. pub mod container { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Data { #[prost(message, tag="1")] Foo(::prost::alloc::boxed::Box), @@ -14,16 +14,16 @@ pub mod container { Bar(super::Bar), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Foo { #[prost(string, tag="1")] pub foo: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Qux { } diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index c130aad2e..9f5b10cb1 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -1,12 +1,12 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Container { #[prost(oneof = "container::Data", tags = "1, 2")] pub data: ::core::option::Option, } /// Nested message and enum types in `Container`. pub mod container { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Data { #[prost(message, tag = "1")] Foo(::prost::alloc::boxed::Box), @@ -14,15 +14,15 @@ pub mod container { Bar(super::Bar), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Foo { #[prost(string, tag = "1")] pub foo: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Qux {} diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs index f39278358..ae65e24df 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs @@ -1,14 +1,14 @@ // This file is @generated by prost-build. #[derive(derive_builder::Builder)] #[derive(custom_proto::Input)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Message { #[prost(string, tag="1")] pub say: ::prost::alloc::string::String, } #[derive(derive_builder::Builder)] #[derive(custom_proto::Output)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Response { #[prost(string, tag="1")] pub say: ::prost::alloc::string::String, diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs index c75338e2b..49a1f1f9e 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs @@ -1,14 +1,14 @@ // This file is @generated by prost-build. #[derive(derive_builder::Builder)] #[derive(custom_proto::Input)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Message { #[prost(string, tag = "1")] pub say: ::prost::alloc::string::String, } #[derive(derive_builder::Builder)] #[derive(custom_proto::Output)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Response { #[prost(string, tag = "1")] pub say: ::prost::alloc::string::String, diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index e2bcad918..e9331ad53 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -153,4 +153,45 @@ impl MessageGraph { ) } } + + /// Returns `true` if this message can automatically derive Eq trait. + pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + let msg = self.messages.get(fq_message_name).unwrap(); + msg.field + .iter() + .all(|field| self.can_field_derive_eq(fq_message_name, field)) + } + + /// Returns `true` if the type of this field allows deriving the Eq trait. + pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + if field.r#type() == Type::Message { + if field.label() == Label::Repeated || self.is_nested(field.type_name(), fq_message_name) { + false + } else { + self.can_message_derive_eq(field.type_name()) + } + } else { + matches!( + field.r#type(), + Type::Int32 + | Type::Int64 + | Type::Uint32 + | Type::Uint64 + | Type::Sint32 + | Type::Sint64 + | Type::Fixed32 + | Type::Fixed64 + | Type::Sfixed32 + | Type::Sfixed64 + | Type::Bool + | Type::Enum + | Type::String + | Type::Bytes + ) + } + } } diff --git a/prost-types/src/compiler.rs b/prost-types/src/compiler.rs index 862a2a532..ccac39d4f 100644 --- a/prost-types/src/compiler.rs +++ b/prost-types/src/compiler.rs @@ -1,6 +1,6 @@ // This file is @generated by prost-build. /// The version number of protocol compiler. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Version { #[prost(int32, optional, tag = "1")] pub major: ::core::option::Option, diff --git a/prost-types/src/duration.rs b/prost-types/src/duration.rs index 3ce993ee5..187f04a43 100644 --- a/prost-types/src/duration.rs +++ b/prost-types/src/duration.rs @@ -1,13 +1,5 @@ use super::*; -#[cfg(feature = "std")] -impl std::hash::Hash for Duration { - fn hash(&self, state: &mut H) { - self.seconds.hash(state); - self.nanos.hash(state); - } -} - impl Duration { /// Normalizes the duration to a canonical format. /// diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 6f75dfc2b..cc4210a6d 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -89,7 +89,7 @@ pub mod descriptor_proto { /// Range of reserved tag numbers. Reserved tag numbers may not be used by /// fields or extension ranges in the same message. Reserved ranges may /// not overlap. - #[derive(Clone, Copy, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -350,7 +350,7 @@ pub mod enum_descriptor_proto { /// Note that this is distinct from DescriptorProto.ReservedRange in that it /// is inclusive such that it can appropriately represent the entire int32 /// domain. - #[derive(Clone, Copy, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct EnumReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -961,7 +961,7 @@ pub mod uninterpreted_option { /// extension (denoted with parentheses in options specs in .proto files). /// E.g.,{ \["foo", false\], \["bar.baz", true\], \["qux", false\] } represents /// "foo.(bar.baz).qux". - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct NamePart { #[prost(string, required, tag = "1")] pub name_part: ::prost::alloc::string::String, @@ -1022,7 +1022,7 @@ pub struct SourceCodeInfo { } /// Nested message and enum types in `SourceCodeInfo`. pub mod source_code_info { - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Location { /// Identifies which part of the FileDescriptorProto was defined at this /// location. @@ -1125,7 +1125,7 @@ pub struct GeneratedCodeInfo { } /// Nested message and enum types in `GeneratedCodeInfo`. pub mod generated_code_info { - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Annotation { /// Identifies the element in the original source .proto file. This field /// is formatted the same as SourceCodeInfo.Location.path. @@ -1238,7 +1238,7 @@ pub mod generated_code_info { /// "value": "1.212s" /// } /// ``` -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Any { /// A URL/resource name that uniquely identifies the type of the serialized /// protocol buffer message. This string must contain at least @@ -1275,7 +1275,7 @@ pub struct Any { } /// `SourceContext` represents information about the source of a /// protobuf element, like the file in which it is defined. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct SourceContext { /// The path-qualified name of the .proto file that contained the associated /// protobuf element. For example: `"google/protobuf/source_context.proto"`. @@ -1531,7 +1531,7 @@ pub struct EnumValue { } /// A protocol buffer option, which can be attached to a message, field, /// enumeration, etc. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Option { /// The option's name. For protobuf built-in options (options defined in /// descriptor.proto), this is the short name. For example, `"map_entry"`. @@ -1741,7 +1741,7 @@ pub struct Method { /// ... /// } /// ``` -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Mixin { /// The fully qualified name of the interface which is included. #[prost(string, tag = "1")] @@ -1815,7 +1815,7 @@ pub struct Mixin { /// encoded in JSON format as "3s", while 3 seconds and 1 nanosecond should /// be expressed in JSON format as "3.000000001s", and 3 seconds and 1 /// microsecond should be expressed in JSON format as "3.000001s". -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -2053,7 +2053,7 @@ pub struct Duration { /// The implementation of any API method which has a FieldMask type field in the /// request should verify the included field paths, and return an /// `INVALID_ARGUMENT` error if any path is unmappable. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FieldMask { /// The set of field mask paths. #[prost(string, repeated, tag = "1")] @@ -2249,7 +2249,7 @@ impl NullValue { /// [`strftime`]() with /// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use /// the Joda Time's [`ISODateTimeFormat.dateTime()`]() to obtain a formatter capable of generating timestamps in this format. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch /// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to diff --git a/prost-types/src/timestamp.rs b/prost-types/src/timestamp.rs index 1d7e609f4..12e37402c 100644 --- a/prost-types/src/timestamp.rs +++ b/prost-types/src/timestamp.rs @@ -123,19 +123,6 @@ impl Name for Timestamp { } } -/// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`. -/// This implies that `normalized_ts != non_normalized_ts` even if `normalized_ts == non_normalized_ts.normalized()`. -#[cfg(feature = "std")] -impl Eq for Timestamp {} - -#[cfg(feature = "std")] -impl std::hash::Hash for Timestamp { - fn hash(&self, state: &mut H) { - self.seconds.hash(state); - self.nanos.hash(state); - } -} - #[cfg(feature = "std")] impl From for Timestamp { fn from(system_time: std::time::SystemTime) -> Timestamp { diff --git a/tests/build.rs b/tests/build.rs index b707fb270..cf7cd17cf 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -38,7 +38,7 @@ fn main() { config.type_attribute("Foo.Custom.Attrs.AnotherEnum", "/// Oneof docs"); config.type_attribute( "Foo.Custom.OneOfAttrs.Msg.field", - "#[derive(Eq, PartialOrd, Ord)]", + "#[derive(PartialOrd, Ord)]", ); config.field_attribute("Foo.Custom.Attrs.AnotherEnum.C", "/// The C docs"); config.field_attribute("Foo.Custom.Attrs.AnotherEnum.D", "/// The D docs"); diff --git a/tests/single-include/src/outdir/outdir.rs b/tests/single-include/src/outdir/outdir.rs index 233028a04..c285a3875 100644 --- a/tests/single-include/src/outdir/outdir.rs +++ b/tests/single-include/src/outdir/outdir.rs @@ -1,5 +1,5 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct OutdirRequest { #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String,