diff --git a/src/exp/grpcroute.rs b/src/exp/grpcroute.rs index 0fb9c55..75236d0 100644 --- a/src/exp/grpcroute.rs +++ b/src/exp/grpcroute.rs @@ -247,9 +247,7 @@ pub type GrpcHeaderMatch = crate::httproute::HttpHeaderMatch; /// Method specifies a gRPC request service/method matcher. If this field is /// not specified, all services and methods will match. -#[derive( - Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize, schemars::JsonSchema, -)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, schemars::JsonSchema)] #[serde(tag = "type", rename_all = "PascalCase")] pub enum GrpcMethodMatch { #[serde(rename_all = "camelCase")] @@ -285,6 +283,77 @@ pub enum GrpcMethodMatch { }, } +impl<'de> serde::Deserialize<'de> for GrpcMethodMatch { + // NOTE: This custom deserialization exists to ensure the deserialization + // behavior matches the behavior prescribed by the gateway api docs + // for how the "type" field on `GRPCRouteMatch` is expected to work. + // + // ref: https://gateway-api.sigs.k8s.io/reference/spec/#gateway.networking.k8s.io%2fv1alpha2.GRPCMethodMatch + fn deserialize>(deserializer: D) -> Result { + #[derive(serde::Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { + Type, + Method, + Service, + } + + struct GrpcMethodMatchVisitor; + + impl<'de> serde::de::Visitor<'de> for GrpcMethodMatchVisitor { + type Value = GrpcMethodMatch; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + formatter.write_str("GrpcMethodMatch") + } + + fn visit_map(self, mut map: V) -> Result + where + V: serde::de::MapAccess<'de>, + { + let (mut r#type, mut method, mut service) = (None, None, None); + + while let Some(key) = map.next_key()? { + match key { + Field::Type => { + if r#type.is_some() { + return Err(serde::de::Error::duplicate_field("type")); + } + r#type = map.next_value::>()?; + } + Field::Method => { + if method.is_some() { + return Err(serde::de::Error::duplicate_field("method")); + } + method = map.next_value::>()?; + } + Field::Service => { + if service.is_some() { + return Err(serde::de::Error::duplicate_field("service")); + } + service = map.next_value::>()?; + } + } + } + + match r#type { + None | Some("Exact") => Ok(GrpcMethodMatch::Exact { method, service }), + Some("RegularExpression") => { + Ok(GrpcMethodMatch::RegularExpression { method, service }) + } + Some(value) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(value), + &r#"one of: {"Exact", "RegularExpression"}"#, + )), + } + } + } + + const FIELDS: &[&str] = &["type", "method", "service"]; + deserializer.deserialize_struct("GrpcMethodMatch", FIELDS, GrpcMethodMatchVisitor) + } +} + /// GrpcRouteFilter defines processing steps that must be completed during the /// request or response lifecycle. GrpcRouteFilters are meant as an extension /// point to express processing that may be done in Gateway implementations. Some @@ -394,3 +463,88 @@ pub struct GrpcRouteBackendRef { #[serde(default, skip_serializing_if = "Option::is_none")] pub weight: Option, } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_grpc_route_deserialization() { + // Test deserialization against upstream example + // ref: https://gateway-api.sigs.k8s.io/api-types/grpcroute/#backendrefs-optional + let data = r#"{ + "apiVersion": "gateway.networking.k8s.io/v1alpha2", + "kind": "GRPCRoute", + "metadata": { + "name": "grpc-app-1" + }, + "spec": { + "parentRefs": [ + { + "name": "my-gateway" + } + ], + "hostnames": [ + "example.com" + ], + "rules": [ + { + "matches": [ + { + "method": { + "service": "com.example.User", + "method": "Login" + } + }, + { + "method": { + "service": "com.example.User", + "method": "Logout", + "type": "Exact" + } + }, + { + "method": { + "service": "com.example.User", + "method": "UpdateProfile", + "type": "RegularExpression" + } + } + ], + "backendRefs": [ + { + "name": "my-service1", + "port": 50051 + } + ] + }, + { + "matches": [ + { + "headers": [ + { + "type": "Exact", + "name": "magic", + "value": "foo" + } + ], + "method": { + "service": "com.example.Things", + "method": "DoThing" + } + } + ], + "backendRefs": [ + { + "name": "my-service2", + "port": 50051 + } + ] + } + ] + } + }"#; + let route = serde_json::from_str::(data); + assert!(route.is_ok()); + } +}