Skip to content

Commit

Permalink
Update batch mode-specific messages to include a length prefix. (#3502)
Browse files Browse the repository at this point in the history
As before, we will (correctly) generate an invalidMessage error if we
receive a message with an unknown batch mode, starting from an error
generated at the point of parsing the batch mode discriminator byte.
  • Loading branch information
branlwyd authored Nov 22, 2024
1 parent 0ac1c9c commit 15b0542
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 49 deletions.
25 changes: 16 additions & 9 deletions messages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use self::batch_mode::{BatchMode, LeaderSelected, TimeInterval};
use anyhow::anyhow;
use base64::{display::Base64Display, engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use core::slice;
use derivative::Derivative;
use num_enum::{FromPrimitive, IntoPrimitive, TryFromPrimitive};
use prio::{
Expand Down Expand Up @@ -1483,18 +1484,20 @@ impl Query<LeaderSelected> {
impl<B: BatchMode> Encode for Query<B> {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
B::CODE.encode(bytes)?;
self.query_body.encode(bytes)
encode_u16_items(bytes, &(), slice::from_ref(&self.query_body))
}

fn encoded_len(&self) -> Option<usize> {
Some(1 + self.query_body.encoded_len()?)
Some(1 + 2 + self.query_body.encoded_len()?)
}
}

impl<B: BatchMode> Decode for Query<B> {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
batch_mode::Code::decode_expecting_value(bytes, B::CODE)?;
let query_body = B::QueryBody::decode(bytes)?;

let buf = decode_u16_items(&(), bytes)?;
let query_body = B::QueryBody::get_decoded(&buf)?;

Ok(Self { query_body })
}
Expand Down Expand Up @@ -1604,18 +1607,20 @@ impl PartialBatchSelector<LeaderSelected> {
impl<B: BatchMode> Encode for PartialBatchSelector<B> {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
B::CODE.encode(bytes)?;
self.batch_identifier.encode(bytes)
encode_u16_items(bytes, &(), slice::from_ref(&self.batch_identifier))
}

fn encoded_len(&self) -> Option<usize> {
Some(1 + self.batch_identifier.encoded_len()?)
Some(1 + 2 + self.batch_identifier.encoded_len()?)
}
}

impl<B: BatchMode> Decode for PartialBatchSelector<B> {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
batch_mode::Code::decode_expecting_value(bytes, B::CODE)?;
let batch_identifier = B::PartialBatchIdentifier::decode(bytes)?;

let buf = decode_u16_items(&(), bytes)?;
let batch_identifier = B::PartialBatchIdentifier::get_decoded(&buf)?;

Ok(Self { batch_identifier })
}
Expand Down Expand Up @@ -2558,18 +2563,20 @@ impl BatchSelector<LeaderSelected> {
impl<B: BatchMode> Encode for BatchSelector<B> {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
B::CODE.encode(bytes)?;
self.batch_identifier.encode(bytes)
encode_u16_items(bytes, &(), slice::from_ref(&self.batch_identifier))
}

fn encoded_len(&self) -> Option<usize> {
Some(1 + self.batch_identifier.encoded_len()?)
Some(1 + 2 + self.batch_identifier.encoded_len()?)
}
}

impl<B: BatchMode> Decode for BatchSelector<B> {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
batch_mode::Code::decode_expecting_value(bytes, B::CODE)?;
let batch_identifier = B::BatchIdentifier::decode(bytes)?;

let buf = decode_u16_items(&(), bytes)?;
let batch_identifier = B::BatchIdentifier::get_decoded(&buf)?;

Ok(Self { batch_identifier })
}
Expand Down
7 changes: 5 additions & 2 deletions messages/src/tests/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,9 @@ fn roundtrip_aggregation_job_initialize_req() {
),
concat!(
// partial_batch_selector
"01", // batch_mode
"01", // batch_mode
"0000", // length
"", // opaque data
),
concat!(
// prepare_inits
Expand Down Expand Up @@ -472,7 +474,8 @@ fn roundtrip_aggregation_job_initialize_req() {
concat!(
// partial_batch_selector
"02", // batch_mode
"0202020202020202020202020202020202020202020202020202020202020202", // batch_id
"0020", // length
"0202020202020202020202020202020202020202020202020202020202020202", // opaque data
),
concat!(
// prepare_inits
Expand Down
90 changes: 58 additions & 32 deletions messages/src/tests/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ fn roundtrip_collection_req() {
concat!(
concat!(
// query
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// query_body
// opaque data
"000000000000D431", // start
"0000000000003039", // duration
),
Expand All @@ -51,9 +52,10 @@ fn roundtrip_collection_req() {
concat!(
concat!(
// query
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// batch_interval
// query body
"000000000000BF11", // start
"000000000000AEB1", // duration
),
Expand All @@ -76,7 +78,9 @@ fn roundtrip_collection_req() {
},
concat!(
concat!(
"02", // batch_mode
"02", // batch_mode
"0000", // length
"", // opaque data
),
concat!(
// aggregation_parameter
Expand All @@ -92,7 +96,9 @@ fn roundtrip_collection_req() {
},
concat!(
concat!(
"02", // batch_mode
"02", // batch_mode
"0000", // length
"", // opaque data
),
concat!(
// aggregation_parameter
Expand All @@ -110,7 +116,9 @@ fn roundtrip_partial_batch_selector() {
roundtrip_encoding(&[(
PartialBatchSelector::new_time_interval(),
concat!(
"01", // batch_mode
"01", // batch_mode
"0000", // length
"", // opaque data
),
)]);

Expand All @@ -120,14 +128,16 @@ fn roundtrip_partial_batch_selector() {
PartialBatchSelector::new_leader_selected(BatchId::from([3u8; 32])),
concat!(
"02", // batch_mode
"0303030303030303030303030303030303030303030303030303030303030303", // batch_id
"0020", // length
"0303030303030303030303030303030303030303030303030303030303030303", // opaque data
),
),
(
PartialBatchSelector::new_leader_selected(BatchId::from([4u8; 32])),
concat!(
"02", // batch_mode
"0404040404040404040404040404040404040404040404040404040404040404", // batch_id
"0020", // length
"0404040404040404040404040404040404040404040404040404040404040404", // opaque data
),
),
])
Expand Down Expand Up @@ -160,7 +170,9 @@ fn roundtrip_collection() {
concat!(
concat!(
// partial_batch_selector
"01", // batch_mode
"01", // batch_mode
"0000", // length
"", // opaque data
),
"0000000000000000", // report_count
concat!(
Expand Down Expand Up @@ -217,7 +229,9 @@ fn roundtrip_collection() {
concat!(
concat!(
// partial_batch_selector
"01", // batch_mode
"01", // batch_mode
"0000", // length
"", // opaque data
),
"0000000000000017", // report_count
concat!(
Expand Down Expand Up @@ -280,8 +294,9 @@ fn roundtrip_collection() {
concat!(
concat!(
// partial_batch_selector
"02", // batch_mode
"0303030303030303030303030303030303030303030303030303030303030303", // batch_id
"02", // batch_mode
"0020", // length
"0303030303030303030303030303030303030303030303030303030303030303", // opaque data
),
"0000000000000000", // report_count
concat!(
Expand Down Expand Up @@ -340,8 +355,9 @@ fn roundtrip_collection() {
concat!(
concat!(
// partial_batch_selector
"02", // batch_mode
"0404040404040404040404040404040404040404040404040404040404040404", // batch_id
"02", // batch_mode
"0020", // length
"0404040404040404040404040404040404040404040404040404040404040404", // opaque data
),
"0000000000000017", // report_count
concat!(
Expand Down Expand Up @@ -395,9 +411,10 @@ fn roundtrip_batch_selector() {
.unwrap(),
},
concat!(
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// batch_interval
// opaque data
"000000000000D431", // start
"0000000000003039", // duration
),
Expand All @@ -412,9 +429,10 @@ fn roundtrip_batch_selector() {
.unwrap(),
},
concat!(
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// batch_interval
// opaque data
"000000000000C685", // start
"0000000000014982", // duration
),
Expand All @@ -431,7 +449,8 @@ fn roundtrip_batch_selector() {
concat!(
// batch_selector
"02", // batch_mode
"0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // batch_id
"0020", // length
"0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // opaque data
),
),
(
Expand All @@ -440,7 +459,8 @@ fn roundtrip_batch_selector() {
},
concat!(
"02", // batch_mode
"0707070707070707070707070707070707070707070707070707070707070707", // batch_id
"0020", // length
"0707070707070707070707070707070707070707070707070707070707070707", // opaque data
),
),
])
Expand All @@ -466,9 +486,10 @@ fn roundtrip_aggregate_share_req() {
concat!(
concat!(
// batch_selector
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// batch_interval
// opaque data
"000000000000D431", // start
"0000000000003039", // duration
),
Expand Down Expand Up @@ -498,9 +519,10 @@ fn roundtrip_aggregate_share_req() {
concat!(
concat!(
// batch_selector
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// batch_interval
// opaque data
"000000000000C685", // start
"0000000000014982", // duration
),
Expand Down Expand Up @@ -530,8 +552,9 @@ fn roundtrip_aggregate_share_req() {
concat!(
concat!(
// batch_selector
"02", // batch_mode
"0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // batch_id
"02", // batch_mode
"0020", // length
"0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // opaque data
),
concat!(
// aggregation_parameter
Expand All @@ -554,8 +577,9 @@ fn roundtrip_aggregate_share_req() {
concat!(
concat!(
// batch_selector
"02", // batch_mode
"0707070707070707070707070707070707070707070707070707070707070707", // batch_id
"02", // batch_mode
"0020", // length
"0707070707070707070707070707070707070707070707070707070707070707", // opaque data
),
concat!(
// aggregation_parameter
Expand Down Expand Up @@ -644,9 +668,10 @@ fn roundtrip_aggregate_share_aad() {
),
concat!(
// batch_selector
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// batch_interval
// opaque data
"000000000000D431", // start
"0000000000003039", // duration
),
Expand All @@ -673,7 +698,8 @@ fn roundtrip_aggregate_share_aad() {
concat!(
// batch_selector
"02", // batch_mode
"0707070707070707070707070707070707070707070707070707070707070707", // batch_id
"0020", // length
"0707070707070707070707070707070707070707070707070707070707070707", // opaque data
),
),
)])
Expand Down
16 changes: 10 additions & 6 deletions messages/src/tests/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ fn roundtrip_query() {
.unwrap(),
},
concat!(
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// query_body
// opaque data
"000000000000D431", // start
"0000000000003039", // duration
),
Expand All @@ -54,9 +55,10 @@ fn roundtrip_query() {
.unwrap(),
},
concat!(
"01", // batch_mode
"01", // batch_mode
"0010", // length
concat!(
// query_body
// opaque data
"000000000000BF11", // start
"000000000000AEB1", // duration
),
Expand All @@ -68,8 +70,10 @@ fn roundtrip_query() {
roundtrip_encoding(&[(
Query::<LeaderSelected> { query_body: () },
concat!(
"02", // batch_mode
),
"02", // batch_mode
"0000", // length
"", // opaque data
),
)])
}

Expand Down

0 comments on commit 15b0542

Please sign in to comment.