From c50bb9a47b396ad6a08a3fec36b98bcc2d9217a1 Mon Sep 17 00:00:00 2001
From: Christopher Patton <cpatton@cloudflare.com>
Date: Wed, 18 Dec 2024 11:20:54 -0800
Subject: [PATCH] Align IDPF public share encoding with VDAF-13 (#1168)

To simplify the spec, we've coalesced the seeds and payloads into
continuous chunks.
---
 src/idpf.rs | 88 ++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 56 insertions(+), 32 deletions(-)

diff --git a/src/idpf.rs b/src/idpf.rs
index 2a11535b..df0e1217 100644
--- a/src/idpf.rs
+++ b/src/idpf.rs
@@ -690,10 +690,16 @@ where
     VL: Encode,
 {
     fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
-        // Control bits need to be written within each byte in LSB-to-MSB order, and assigned into
-        // bytes in big-endian order. Thus, the first four levels will have their control bits
-        // encoded in the last byte, and the last levels will have their control bits encoded in the
-        // first byte.
+        // draft-irtf-cfrg-vdaf-13, Section 8.2.6.1:
+        //
+        // struct {
+        //     opaque packed_control_bits[packed_len];
+        //     opaque seed[poplar1.idpf.KEY_SIZE*B];
+        //     Poplar1FieldInner payload_inner[Fi*poplar1.idpf.VALUE_LEN*(B-1)];
+        //     Poplar1FieldLeaf payload_leaf[Fl*poplar1.idpf.VALUE_LEN];
+        // } Poplar1PublicShare;
+        //
+        // Control bits
         let mut control_bits: BitVec<u8, Lsb0> =
             BitVec::with_capacity(self.inner_correction_words.len() * 2 + 2);
         for correction_words in self.inner_correction_words.iter() {
@@ -709,11 +715,18 @@ where
         let mut packed_control = control_bits.into_vec();
         bytes.append(&mut packed_control);
 
+        // Seeds
         for correction_words in self.inner_correction_words.iter() {
             Seed(correction_words.seed).encode(bytes)?;
-            correction_words.value.encode(bytes)?;
         }
         Seed(self.leaf_correction_word.seed).encode(bytes)?;
+
+        // Inner payloads
+        for correction_words in self.inner_correction_words.iter() {
+            correction_words.value.encode(bytes)?;
+        }
+
+        // Leaf payload
         self.leaf_correction_word.value.encode(bytes)
     }
 
@@ -735,39 +748,50 @@ where
 {
     fn decode_with_param(bits: &usize, bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
         let packed_control_len = (bits + 3) / 4;
-        let mut packed = vec![0u8; packed_control_len];
-        bytes.read_exact(&mut packed)?;
-        let unpacked_control_bits: BitVec<u8, Lsb0> = BitVec::from_vec(packed);
+        let mut packed_control_bits = vec![0u8; packed_control_len];
+        bytes.read_exact(&mut packed_control_bits)?;
+        let unpacked_control_bits: BitVec<u8, Lsb0> = BitVec::from_vec(packed_control_bits);
 
-        let mut inner_correction_words = Vec::with_capacity(bits - 1);
-        for chunk in unpacked_control_bits[0..(bits - 1) * 2].chunks(2) {
-            let control_bits = [(chunk[0] as u8).into(), (chunk[1] as u8).into()];
-            let seed = Seed::decode(bytes)?.0;
-            let value = VI::decode(bytes)?;
-            inner_correction_words.push(IdpfCorrectionWord {
-                seed,
-                control_bits,
-                value,
-            })
+        // Control bits
+        let mut control_bits = Vec::with_capacity(*bits);
+        for chunk in unpacked_control_bits[0..bits * 2].chunks(2) {
+            control_bits.push([(chunk[0] as u8).into(), (chunk[1] as u8).into()]);
         }
 
-        let control_bits = [
-            (unpacked_control_bits[(bits - 1) * 2] as u8).into(),
-            (unpacked_control_bits[bits * 2 - 1] as u8).into(),
-        ];
-        let seed = Seed::decode(bytes)?.0;
-        let value = VL::decode(bytes)?;
-        let leaf_correction_word = IdpfCorrectionWord {
-            seed,
-            control_bits,
-            value,
-        };
-
         // Check that unused packed bits are zero.
         if unpacked_control_bits[bits * 2..].any() {
             return Err(CodecError::UnexpectedValue);
         }
 
+        // Seeds
+        let mut seeds = std::iter::repeat_with(|| Seed::decode(bytes).map(|seed| seed.0))
+            .take(*bits)
+            .collect::<Result<Vec<_>, _>>()?;
+
+        // Inner payloads
+        let inner_payloads = std::iter::repeat_with(|| VI::decode(bytes))
+            .take(bits - 1)
+            .collect::<Result<Vec<_>, _>>()?;
+
+        // Outer payload
+        let leaf_paylaod = VL::decode(bytes)?;
+
+        let leaf_correction_word = IdpfCorrectionWord {
+            seed: seeds.pop().unwrap(),                // *bits == 0
+            control_bits: control_bits.pop().unwrap(), // *bits == 0
+            value: leaf_paylaod,
+        };
+
+        let inner_correction_words = seeds
+            .into_iter()
+            .zip(control_bits.into_iter().zip(inner_payloads))
+            .map(|(seed, (control_bits, payload))| IdpfCorrectionWord {
+                seed,
+                control_bits,
+                value: payload,
+            })
+            .collect::<Vec<_>>();
+
         Ok(IdpfPublicShare {
             inner_correction_words,
             leaf_correction_word,
@@ -1748,12 +1772,12 @@ mod tests {
         let message = hex::decode(concat!(
             "39",                               // packed control bit correction words (0b00111001)
             "abababababababababababababababab", // seed correction word, first level
+            "cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd", // seed correction word, second level
+            "ffffffffffffffffffffffffffffffff", // seed correction word, third level
             "3d45010000000000",                 // field element correction word
             "e7e8010000000000",                 // field element correction word, continued
-            "cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd", // seed correction word, second level
             "28c50c0100000000",                 // field element correction word
             "c250000000000000",                 // field element correction word, continued
-            "ffffffffffffffffffffffffffffffff", // seed correction word, third level
             "0100000000000000000000000000000000000000000000000000000000000000", // field element correction word, leaf field
             "f0debc9a78563412f0debc9a78563412f0debc9a78563412f0debc9a78563412", // field element correction word, continued
         ))