From 15b0542870a594ce8ce4a69a3fc2ba7ea7e16479 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Fri, 22 Nov 2024 11:03:57 -0800 Subject: [PATCH] Update batch mode-specific messages to include a length prefix. (#3502) As before, we will (correctly) generate an invalidMessage error if we receive a message with an unknown batch mode, starting from an error generated at the point of parsing the batch mode discriminator byte. --- messages/src/lib.rs | 25 +++++---- messages/src/tests/aggregation.rs | 7 ++- messages/src/tests/collection.rs | 90 ++++++++++++++++++++----------- messages/src/tests/query.rs | 16 +++--- 4 files changed, 89 insertions(+), 49 deletions(-) diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 6f2ecc420..c114aca78 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -6,6 +6,7 @@ use self::batch_mode::{BatchMode, LeaderSelected, TimeInterval}; use anyhow::anyhow; use base64::{display::Base64Display, engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use core::slice; use derivative::Derivative; use num_enum::{FromPrimitive, IntoPrimitive, TryFromPrimitive}; use prio::{ @@ -1483,18 +1484,20 @@ impl Query { impl Encode for Query { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { B::CODE.encode(bytes)?; - self.query_body.encode(bytes) + encode_u16_items(bytes, &(), slice::from_ref(&self.query_body)) } fn encoded_len(&self) -> Option { - Some(1 + self.query_body.encoded_len()?) + Some(1 + 2 + self.query_body.encoded_len()?) } } impl Decode for Query { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { batch_mode::Code::decode_expecting_value(bytes, B::CODE)?; - let query_body = B::QueryBody::decode(bytes)?; + + let buf = decode_u16_items(&(), bytes)?; + let query_body = B::QueryBody::get_decoded(&buf)?; Ok(Self { query_body }) } @@ -1604,18 +1607,20 @@ impl PartialBatchSelector { impl Encode for PartialBatchSelector { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { B::CODE.encode(bytes)?; - self.batch_identifier.encode(bytes) + encode_u16_items(bytes, &(), slice::from_ref(&self.batch_identifier)) } fn encoded_len(&self) -> Option { - Some(1 + self.batch_identifier.encoded_len()?) + Some(1 + 2 + self.batch_identifier.encoded_len()?) } } impl Decode for PartialBatchSelector { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { batch_mode::Code::decode_expecting_value(bytes, B::CODE)?; - let batch_identifier = B::PartialBatchIdentifier::decode(bytes)?; + + let buf = decode_u16_items(&(), bytes)?; + let batch_identifier = B::PartialBatchIdentifier::get_decoded(&buf)?; Ok(Self { batch_identifier }) } @@ -2558,18 +2563,20 @@ impl BatchSelector { impl Encode for BatchSelector { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { B::CODE.encode(bytes)?; - self.batch_identifier.encode(bytes) + encode_u16_items(bytes, &(), slice::from_ref(&self.batch_identifier)) } fn encoded_len(&self) -> Option { - Some(1 + self.batch_identifier.encoded_len()?) + Some(1 + 2 + self.batch_identifier.encoded_len()?) } } impl Decode for BatchSelector { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { batch_mode::Code::decode_expecting_value(bytes, B::CODE)?; - let batch_identifier = B::BatchIdentifier::decode(bytes)?; + + let buf = decode_u16_items(&(), bytes)?; + let batch_identifier = B::BatchIdentifier::get_decoded(&buf)?; Ok(Self { batch_identifier }) } diff --git a/messages/src/tests/aggregation.rs b/messages/src/tests/aggregation.rs index 3ff461234..5cabe301e 100644 --- a/messages/src/tests/aggregation.rs +++ b/messages/src/tests/aggregation.rs @@ -335,7 +335,9 @@ fn roundtrip_aggregation_job_initialize_req() { ), concat!( // partial_batch_selector - "01", // batch_mode + "01", // batch_mode + "0000", // length + "", // opaque data ), concat!( // prepare_inits @@ -472,7 +474,8 @@ fn roundtrip_aggregation_job_initialize_req() { concat!( // partial_batch_selector "02", // batch_mode - "0202020202020202020202020202020202020202020202020202020202020202", // batch_id + "0020", // length + "0202020202020202020202020202020202020202020202020202020202020202", // opaque data ), concat!( // prepare_inits diff --git a/messages/src/tests/collection.rs b/messages/src/tests/collection.rs index 00db586ce..94a55dfb5 100644 --- a/messages/src/tests/collection.rs +++ b/messages/src/tests/collection.rs @@ -23,9 +23,10 @@ fn roundtrip_collection_req() { concat!( concat!( // query - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // query_body + // opaque data "000000000000D431", // start "0000000000003039", // duration ), @@ -51,9 +52,10 @@ fn roundtrip_collection_req() { concat!( concat!( // query - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // batch_interval + // query body "000000000000BF11", // start "000000000000AEB1", // duration ), @@ -76,7 +78,9 @@ fn roundtrip_collection_req() { }, concat!( concat!( - "02", // batch_mode + "02", // batch_mode + "0000", // length + "", // opaque data ), concat!( // aggregation_parameter @@ -92,7 +96,9 @@ fn roundtrip_collection_req() { }, concat!( concat!( - "02", // batch_mode + "02", // batch_mode + "0000", // length + "", // opaque data ), concat!( // aggregation_parameter @@ -110,7 +116,9 @@ fn roundtrip_partial_batch_selector() { roundtrip_encoding(&[( PartialBatchSelector::new_time_interval(), concat!( - "01", // batch_mode + "01", // batch_mode + "0000", // length + "", // opaque data ), )]); @@ -120,14 +128,16 @@ fn roundtrip_partial_batch_selector() { PartialBatchSelector::new_leader_selected(BatchId::from([3u8; 32])), concat!( "02", // batch_mode - "0303030303030303030303030303030303030303030303030303030303030303", // batch_id + "0020", // length + "0303030303030303030303030303030303030303030303030303030303030303", // opaque data ), ), ( PartialBatchSelector::new_leader_selected(BatchId::from([4u8; 32])), concat!( "02", // batch_mode - "0404040404040404040404040404040404040404040404040404040404040404", // batch_id + "0020", // length + "0404040404040404040404040404040404040404040404040404040404040404", // opaque data ), ), ]) @@ -160,7 +170,9 @@ fn roundtrip_collection() { concat!( concat!( // partial_batch_selector - "01", // batch_mode + "01", // batch_mode + "0000", // length + "", // opaque data ), "0000000000000000", // report_count concat!( @@ -217,7 +229,9 @@ fn roundtrip_collection() { concat!( concat!( // partial_batch_selector - "01", // batch_mode + "01", // batch_mode + "0000", // length + "", // opaque data ), "0000000000000017", // report_count concat!( @@ -280,8 +294,9 @@ fn roundtrip_collection() { concat!( concat!( // partial_batch_selector - "02", // batch_mode - "0303030303030303030303030303030303030303030303030303030303030303", // batch_id + "02", // batch_mode + "0020", // length + "0303030303030303030303030303030303030303030303030303030303030303", // opaque data ), "0000000000000000", // report_count concat!( @@ -340,8 +355,9 @@ fn roundtrip_collection() { concat!( concat!( // partial_batch_selector - "02", // batch_mode - "0404040404040404040404040404040404040404040404040404040404040404", // batch_id + "02", // batch_mode + "0020", // length + "0404040404040404040404040404040404040404040404040404040404040404", // opaque data ), "0000000000000017", // report_count concat!( @@ -395,9 +411,10 @@ fn roundtrip_batch_selector() { .unwrap(), }, concat!( - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // batch_interval + // opaque data "000000000000D431", // start "0000000000003039", // duration ), @@ -412,9 +429,10 @@ fn roundtrip_batch_selector() { .unwrap(), }, concat!( - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // batch_interval + // opaque data "000000000000C685", // start "0000000000014982", // duration ), @@ -431,7 +449,8 @@ fn roundtrip_batch_selector() { concat!( // batch_selector "02", // batch_mode - "0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // batch_id + "0020", // length + "0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // opaque data ), ), ( @@ -440,7 +459,8 @@ fn roundtrip_batch_selector() { }, concat!( "02", // batch_mode - "0707070707070707070707070707070707070707070707070707070707070707", // batch_id + "0020", // length + "0707070707070707070707070707070707070707070707070707070707070707", // opaque data ), ), ]) @@ -466,9 +486,10 @@ fn roundtrip_aggregate_share_req() { concat!( concat!( // batch_selector - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // batch_interval + // opaque data "000000000000D431", // start "0000000000003039", // duration ), @@ -498,9 +519,10 @@ fn roundtrip_aggregate_share_req() { concat!( concat!( // batch_selector - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // batch_interval + // opaque data "000000000000C685", // start "0000000000014982", // duration ), @@ -530,8 +552,9 @@ fn roundtrip_aggregate_share_req() { concat!( concat!( // batch_selector - "02", // batch_mode - "0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // batch_id + "02", // batch_mode + "0020", // length + "0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // opaque data ), concat!( // aggregation_parameter @@ -554,8 +577,9 @@ fn roundtrip_aggregate_share_req() { concat!( concat!( // batch_selector - "02", // batch_mode - "0707070707070707070707070707070707070707070707070707070707070707", // batch_id + "02", // batch_mode + "0020", // length + "0707070707070707070707070707070707070707070707070707070707070707", // opaque data ), concat!( // aggregation_parameter @@ -644,9 +668,10 @@ fn roundtrip_aggregate_share_aad() { ), concat!( // batch_selector - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // batch_interval + // opaque data "000000000000D431", // start "0000000000003039", // duration ), @@ -673,7 +698,8 @@ fn roundtrip_aggregate_share_aad() { concat!( // batch_selector "02", // batch_mode - "0707070707070707070707070707070707070707070707070707070707070707", // batch_id + "0020", // length + "0707070707070707070707070707070707070707070707070707070707070707", // opaque data ), ), )]) diff --git a/messages/src/tests/query.rs b/messages/src/tests/query.rs index 68e11a4a6..01448d44d 100644 --- a/messages/src/tests/query.rs +++ b/messages/src/tests/query.rs @@ -37,9 +37,10 @@ fn roundtrip_query() { .unwrap(), }, concat!( - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // query_body + // opaque data "000000000000D431", // start "0000000000003039", // duration ), @@ -54,9 +55,10 @@ fn roundtrip_query() { .unwrap(), }, concat!( - "01", // batch_mode + "01", // batch_mode + "0010", // length concat!( - // query_body + // opaque data "000000000000BF11", // start "000000000000AEB1", // duration ), @@ -68,8 +70,10 @@ fn roundtrip_query() { roundtrip_encoding(&[( Query:: { query_body: () }, concat!( - "02", // batch_mode - ), + "02", // batch_mode + "0000", // length + "", // opaque data + ), )]) }