diff --git a/src/flp.rs b/src/flp.rs index dabd41ab..a3643526 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -3,7 +3,7 @@ //! Implementation of the generic Fully Linear Proof (FLP) system specified in //! [[draft-irtf-cfrg-vdaf-08]]. This is the main building block of [`Prio3`](crate::vdaf::prio3). //! -//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation +//! The FLP is derived for any implementation of the [`Flp`] trait. Such an implementation //! specifies a validity circuit that defines the set of valid measurements, as well as the finite //! field in which the validity circuit is evaluated. It also determines how raw measurements are //! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of @@ -19,7 +19,7 @@ //! //! ``` //! use prio::flp::types::Count; -//! use prio::flp::Type; +//! use prio::flp::{Type, Flp}; //! use prio::field::{random_vector, FieldElement, Field64}; //! //! // The prover chooses a measurement. @@ -63,15 +63,15 @@ pub mod types; #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum FlpError { - /// Calling [`Type::prove`] returned an error. + /// Calling [`Flp::prove`] returned an error. #[error("prove error: {0}")] Prove(String), - /// Calling [`Type::query`] returned an error. + /// Calling [`Flp::query`] returned an error. #[error("query error: {0}")] Query(String), - /// Calling [`Type::decide`] returned an error. + /// Calling [`Flp::decide`] returned an error. #[error("decide error: {0}")] Decide(String), @@ -113,35 +113,11 @@ pub enum FlpError { DifferentialPrivacy(#[from] crate::dp::DpError), } -/// A type. Implementations of this trait specify how a particular kind of measurement is encoded -/// as a vector of field elements and how validity of the encoded measurement is determined. -/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement. -pub trait Type: Sized + Eq + Clone + Debug { - /// The type of raw measurement to be encoded. - type Measurement: Clone + Debug; - - /// The type of aggregate result for this type. - type AggregateResult: Clone + Debug; - +/// The FLP proof system. The user specifies the validity circuit. +pub trait Flp: Sized + Eq + Clone + Debug { /// The finite field used for this type. type Field: FftFriendlyFieldElement; - /// Encodes a measurement as a vector of [`Self::input_len`] field elements. - fn encode_measurement( - &self, - measurement: &Self::Measurement, - ) -> Result, FlpError>; - - /// Decodes an aggregate result. - /// - /// This is NOT the inverse of `encode_measurement`. Rather, the input is an aggregation of - /// truncated measurements. - fn decode_result( - &self, - data: &[Self::Field], - num_measurements: usize, - ) -> Result; - /// Returns the sequence of gadgets associated with the validity circuit. /// /// # Notes @@ -173,7 +149,7 @@ pub trait Type: Sized + Eq + Clone + Debug { /// /// ``` /// use prio::flp::types::Count; - /// use prio::flp::Type; + /// use prio::flp::{Flp, Type}; /// use prio::field::{random_vector, FieldElement, Field64}; /// /// let count = Count::new(); @@ -190,10 +166,6 @@ pub trait Type: Sized + Eq + Clone + Debug { num_shares: usize, ) -> Result, FlpError>; - /// Constructs an aggregatable output from an encoded input. Calling this method is only safe - /// once `input` has been validated. - fn truncate(&self, input: Vec) -> Result, FlpError>; - /// The length in field elements of the encoded input returned by [`Self::encode_measurement`]. fn input_len(&self) -> usize; @@ -203,9 +175,6 @@ pub trait Type: Sized + Eq + Clone + Debug { /// The length in field elements of the verifier message constructed by [`Self::query`]. fn verifier_len(&self) -> usize; - /// The length of the truncated output (i.e., the output of [`Type::truncate`]). - fn output_len(&self) -> usize; - /// The length of the joint random input. fn joint_rand_len(&self) -> usize; @@ -589,6 +558,40 @@ pub trait Type: Sized + Eq + Clone + Debug { } } +/// A type. Implementations of this trait specify how a particular kind of measurement is encoded +/// as a vector of field elements and how validity of the encoded measurement is determined. +/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement. +pub trait Type: Flp { + /// The type of raw measurement to be encoded. + type Measurement: Clone + Debug; + + /// The type of aggregate result for this type. + type AggregateResult: Clone + Debug; + + /// Encodes a measurement as a vector of [`Self::input_len`] field elements. + fn encode_measurement( + &self, + measurement: &Self::Measurement, + ) -> Result, FlpError>; + + /// Constructs an aggregatable output from an encoded input. Calling this method is only safe + /// once `input` has been validated. + fn truncate(&self, input: Vec) -> Result, FlpError>; + + /// Decodes an aggregate result. + /// + /// This is NOT the inverse of `encode_measurement`. Rather, the input is an aggregation of + /// truncated measurements. + fn decode_result( + &self, + data: &[Self::Field], + num_measurements: usize, + ) -> Result; + + /// The length of the truncated output (i.e., the output of [`Flp::truncate`]). + fn output_len(&self) -> usize; +} + /// A type which supports adding noise to aggregate shares for Server Differential Privacy. #[cfg(feature = "experimental")] #[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] @@ -804,7 +807,7 @@ pub mod test_utils { /// Various tests for an FLP. #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] - pub struct FlpTest<'a, T: Type> { + pub struct TypeTest<'a, T: Type> { /// The FLP. pub flp: &'a T, @@ -821,7 +824,7 @@ pub mod test_utils { pub expect_valid: bool, } - impl FlpTest<'_, T> { + impl TypeTest<'_, T> { /// Construct a test and run it. Expect the input to be valid and compare the truncated /// output to the provided value. pub fn expect_valid( @@ -829,7 +832,7 @@ pub mod test_utils { input: &[T::Field], expected_output: &[T::Field], ) { - FlpTest { + TypeTest { flp, name: None, input, @@ -841,7 +844,7 @@ pub mod test_utils { /// Construct a test and run it. Expect the input to be invalid. pub fn expect_invalid(flp: &T, input: &[T::Field]) { - FlpTest { + TypeTest { flp, name: None, input, @@ -853,7 +856,7 @@ pub mod test_utils { /// Construct a test and run it. Expect the input to be valid. pub fn expect_valid_no_output(flp: &T, input: &[T::Field]) { - FlpTest { + TypeTest { flp, name: None, input, @@ -1077,9 +1080,7 @@ mod tests { } } - impl Type for TestType { - type Measurement = F::Integer; - type AggregateResult = F::Integer; + impl Flp for TestType { type Field = F; fn valid( @@ -1132,10 +1133,6 @@ mod tests { 1 + mul + poly } - fn output_len(&self) -> usize { - self.input_len() - } - fn joint_rand_len(&self) -> usize { 1 } @@ -1158,6 +1155,11 @@ mod tests { fn num_gadgets(&self) -> usize { 2 } + } + + impl Type for TestType { + type Measurement = F::Integer; + type AggregateResult = F::Integer; fn encode_measurement(&self, measurement: &F::Integer) -> Result, FlpError> { Ok(vec![ @@ -1177,6 +1179,10 @@ mod tests { ) -> Result { panic!("not implemented"); } + + fn output_len(&self) -> usize { + self.input_len() + } } // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that @@ -1215,9 +1221,7 @@ mod tests { } } - impl Type for Issue254Type { - type Measurement = F::Integer; - type AggregateResult = F::Integer; + impl Flp for Issue254Type { type Field = F; fn valid( @@ -1265,10 +1269,6 @@ mod tests { 1 + first + second } - fn output_len(&self) -> usize { - self.input_len() - } - fn joint_rand_len(&self) -> usize { 0 } @@ -1298,6 +1298,11 @@ mod tests { fn num_gadgets(&self) -> usize { 2 } + } + + impl Type for Issue254Type { + type Measurement = F::Integer; + type AggregateResult = F::Integer; fn encode_measurement(&self, measurement: &F::Integer) -> Result, FlpError> { Ok(vec![F::from(*measurement)]) @@ -1314,5 +1319,9 @@ mod tests { ) -> Result { panic!("not implemented"); } + + fn output_len(&self) -> usize { + self.input_len() + } } } diff --git a/src/flp/types.rs b/src/flp/types.rs index 8f63af10..a8791195 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -4,7 +4,7 @@ use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt, Integer}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; -use crate::flp::{FlpError, Gadget, Type}; +use crate::flp::{Flp, FlpError, Gadget, Type}; use crate::polynomial::poly_range_check; use std::convert::TryInto; use std::fmt::{self, Debug}; @@ -42,23 +42,9 @@ impl Default for Count { } } -impl Type for Count { - type Measurement = bool; - type AggregateResult = F::Integer; +impl Flp for Count { type Field = F; - fn encode_measurement(&self, value: &bool) -> Result, FlpError> { - Ok(vec![F::conditional_select( - &F::zero(), - &F::one(), - Choice::from(u8::from(*value)), - )]) - } - - fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result { - decode_result(data) - } - fn gadget(&self) -> Vec>> { vec![Box::new(Mul::new(1))] } @@ -79,11 +65,6 @@ impl Type for Count { Ok(vec![out]) } - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - Ok(input) - } - fn input_len(&self) -> usize { 1 } @@ -96,10 +77,6 @@ impl Type for Count { 4 } - fn output_len(&self) -> usize { - self.input_len() - } - fn joint_rand_len(&self) -> usize { 0 } @@ -113,6 +90,33 @@ impl Type for Count { } } +impl Type for Count { + type Measurement = bool; + + type AggregateResult = F::Integer; + + fn encode_measurement(&self, value: &bool) -> Result, FlpError> { + Ok(vec![F::conditional_select( + &F::zero(), + &F::one(), + Choice::from(u8::from(*value)), + )]) + } + + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result { + decode_result(data) + } + + fn output_len(&self) -> usize { + self.input_len() + } +} + /// The sum type. Each measurement is a integer in `[0, max_measurement]` and the aggregate is the /// sum of the measurements. /// @@ -168,29 +172,9 @@ impl Sum { } } -impl Type for Sum { - type Measurement = F::Integer; - type AggregateResult = F::Integer; +impl Flp for Sum { type Field = F; - fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { - if summand > &self.max_measurement { - return Err(FlpError::Encode(format!( - "unexpected measurement: got {:?}; want ≤{:?}", - summand, self.max_measurement - ))); - } - - let enc_summand = F::encode_as_bitvector(*summand, self.bits)?; - let enc_summand_plus_offset = F::encode_as_bitvector(self.offset + *summand, self.bits)?; - - Ok(enc_summand.chain(enc_summand_plus_offset).collect()) - } - - fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result { - decode_result(data) - } - fn gadget(&self) -> Vec>> { vec![Box::new(PolyEval::new( self.bit_range_checker.clone(), @@ -227,12 +211,6 @@ impl Type for Sum { Ok([bit_checks.as_slice(), &[range_check]].concat()) } - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - let res = F::decode_bitvector(&input[..self.bits])?; - Ok(vec![res]) - } - fn input_len(&self) -> usize { 2 * self.bits } @@ -245,10 +223,6 @@ impl Type for Sum { 3 } - fn output_len(&self) -> usize { - 1 - } - fn joint_rand_len(&self) -> usize { 0 } @@ -262,6 +236,40 @@ impl Type for Sum { } } +// XXX type with encoding +impl Type for Sum { + type Measurement = F::Integer; + type AggregateResult = F::Integer; + + fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { + if summand > &self.max_measurement { + return Err(FlpError::Encode(format!( + "unexpected measurement: got {:?}; want ≤{:?}", + summand, self.max_measurement + ))); + } + + let enc_summand = F::encode_as_bitvector(*summand, self.bits)?; + let enc_summand_plus_offset = F::encode_as_bitvector(self.offset + *summand, self.bits)?; + + Ok(enc_summand.chain(enc_summand_plus_offset).collect()) + } + + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_bitvector(&input[..self.bits])?; + Ok(vec![res]) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result { + decode_result(data) + } + + fn output_len(&self) -> usize { + 1 + } +} + /// The average type. Each measurement is an integer in `[0, max_measurement]` and the aggregate is /// the arithmetic average of the measurements. // This is just a `Sum` object under the hood. The only difference is that the aggregate result is @@ -289,25 +297,9 @@ impl Average { } } -impl Type for Average { - type Measurement = F::Integer; - type AggregateResult = f64; +impl Flp for Average { type Field = F; - fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { - self.summer.encode_measurement(summand) - } - - fn decode_result(&self, data: &[F], num_measurements: usize) -> Result { - // Compute the average from the aggregated sum. - let sum = self.summer.decode_result(data, num_measurements)?; - let data: u64 = sum - .try_into() - .map_err(|err| FlpError::Decode(format!("failed to convert {sum:?} to u64: {err}",)))?; - let result = (data as f64) / (num_measurements as f64); - Ok(result) - } - fn gadget(&self) -> Vec>> { self.summer.gadget() } @@ -326,10 +318,6 @@ impl Type for Average { self.summer.valid(g, input, joint_rand, num_shares) } - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.summer.truncate(input) - } - fn input_len(&self) -> usize { self.summer.input_len() } @@ -342,10 +330,6 @@ impl Type for Average { self.summer.verifier_len() } - fn output_len(&self) -> usize { - self.summer.output_len() - } - fn joint_rand_len(&self) -> usize { self.summer.joint_rand_len() } @@ -359,6 +343,33 @@ impl Type for Average { } } +impl Type for Average { + type Measurement = F::Integer; + type AggregateResult = f64; + + fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { + self.summer.encode_measurement(summand) + } + + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.summer.truncate(input) + } + + fn decode_result(&self, data: &[F], num_measurements: usize) -> Result { + // Compute the average from the aggregated sum. + let sum = self.summer.decode_result(data, num_measurements)?; + let data: u64 = sum + .try_into() + .map_err(|err| FlpError::Decode(format!("failed to convert {sum:?} to u64: {err}",)))?; + let result = (data as f64) / (num_measurements as f64); + Ok(result) + } + + fn output_len(&self) -> usize { + self.summer.output_len() + } +} + /// The histogram type. Each measurement is an integer in `[0, length)` and the aggregate is a /// histogram counting the number of occurrences of each measurement. #[derive(PartialEq, Eq)] @@ -422,30 +433,13 @@ impl Clone for Histogram { } } -impl Type for Histogram +impl Flp for Histogram where F: FftFriendlyFieldElement, S: ParallelSumGadget> + Eq + 'static, { - type Measurement = usize; - type AggregateResult = Vec; type Field = F; - fn encode_measurement(&self, measurement: &usize) -> Result, FlpError> { - let mut data = vec![F::zero(); self.length]; - - data[*measurement] = F::one(); - Ok(data) - } - - fn decode_result( - &self, - data: &[F], - _num_measurements: usize, - ) -> Result, FlpError> { - decode_result_vec(data, self.length) - } - fn gadget(&self) -> Vec>> { vec![Box::new(S::new( Mul::new(self.gadget_calls), @@ -479,11 +473,6 @@ where Ok(vec![range_check, sum_check]) } - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - Ok(input) - } - fn input_len(&self) -> usize { self.length } @@ -496,10 +485,6 @@ where 2 + self.chunk_length * 2 } - fn output_len(&self) -> usize { - self.input_len() - } - fn joint_rand_len(&self) -> usize { self.gadget_calls } @@ -513,6 +498,39 @@ where } } +impl Type for Histogram +where + F: FftFriendlyFieldElement, + S: ParallelSumGadget> + Eq + 'static, +{ + type Measurement = usize; + type AggregateResult = Vec; + + fn encode_measurement(&self, measurement: &usize) -> Result, FlpError> { + let mut data = vec![F::zero(); self.length]; + + data[*measurement] = F::one(); + Ok(data) + } + + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result, FlpError> { + decode_result_vec(data, self.length) + } + + fn output_len(&self) -> usize { + self.input_len() + } +} + /// The multihot counter data type. Each measurement is a list of booleans of length `length`, with /// at most `max_weight` true values, and the aggregate is a histogram counting the number of true /// values at each position across all measurements. @@ -608,57 +626,13 @@ impl Clone for MultihotCountVec { } } -impl Type for MultihotCountVec +impl Flp for MultihotCountVec where F: FftFriendlyFieldElement, S: ParallelSumGadget> + Eq + 'static, { - type Measurement = Vec; - type AggregateResult = Vec; type Field = F; - fn encode_measurement(&self, measurement: &Vec) -> Result, FlpError> { - let weight_reported: usize = measurement.iter().filter(|bit| **bit).count(); - - if measurement.len() != self.length { - return Err(FlpError::Encode(format!( - "unexpected measurement length: got {}; want {}", - measurement.len(), - self.length - ))); - } - if weight_reported > self.max_weight { - return Err(FlpError::Encode(format!( - "unexpected measurement weight: got {}; want ≤{}", - weight_reported, self.max_weight - ))); - } - - // Convert bool vector to field elems - let multihot_vec = measurement - .iter() - // We can unwrap because any Integer type can cast from bool - .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())); - - // Encode the measurement weight in binary (actually, the weight plus some offset) - let offset_weight_bits = { - let offset_weight_reported = F::valid_integer_try_from(self.offset + weight_reported)?; - F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)? - }; - - // Report the concat of the two - Ok(multihot_vec.chain(offset_weight_bits).collect()) - } - - fn decode_result( - &self, - data: &[Self::Field], - _num_measurements: usize, - ) -> Result { - // The aggregate is the same as the decoded result. Just convert to integers - decode_result_vec(data, self.length) - } - fn gadget(&self) -> Vec>> { vec![Box::new(S::new( Mul::new(self.gadget_calls), @@ -698,14 +672,6 @@ where Ok(vec![range_check, weight_check]) } - // Truncates the measurement, removing extra data that was necessary for validity (here, the - // encoded weight), but not important for aggregation - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - // Cut off the encoded weight - Ok(input[..self.length].to_vec()) - } - // The length in field elements of the encoded input returned by [`Self::encode_measurement`]. fn input_len(&self) -> usize { self.length + self.bits_for_weight @@ -719,11 +685,6 @@ where 2 + self.chunk_length * 2 } - // The length of the truncated output (i.e., the output of [`Type::truncate`]). - fn output_len(&self) -> usize { - self.length - } - // The number of random values needed in the validity checks fn joint_rand_len(&self) -> usize { self.gadget_calls @@ -738,6 +699,70 @@ where } } +impl Type for MultihotCountVec +where + F: FftFriendlyFieldElement, + S: ParallelSumGadget> + Eq + 'static, +{ + type Measurement = Vec; + type AggregateResult = Vec; + + fn encode_measurement(&self, measurement: &Vec) -> Result, FlpError> { + let weight_reported: usize = measurement.iter().filter(|bit| **bit).count(); + + if measurement.len() != self.length { + return Err(FlpError::Encode(format!( + "unexpected measurement length: got {}; want {}", + measurement.len(), + self.length + ))); + } + if weight_reported > self.max_weight { + return Err(FlpError::Encode(format!( + "unexpected measurement weight: got {}; want ≤{}", + weight_reported, self.max_weight + ))); + } + + // Convert bool vector to field elems + let multihot_vec = measurement + .iter() + // We can unwrap because any Integer type can cast from bool + .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())); + + // Encode the measurement weight in binary (actually, the weight plus some offset) + let offset_weight_bits = { + let offset_weight_reported = F::valid_integer_try_from(self.offset + weight_reported)?; + F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)? + }; + + // Report the concat of the two + Ok(multihot_vec.chain(offset_weight_bits).collect()) + } + + fn decode_result( + &self, + data: &[Self::Field], + _num_measurements: usize, + ) -> Result { + // The aggregate is the same as the decoded result. Just convert to integers + decode_result_vec(data, self.length) + } + + // Truncates the measurement, removing extra data that was necessary for validity (here, the + // encoded weight), but not important for aggregation + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + // Cut off the encoded weight + Ok(input[..self.length].to_vec()) + } + + // The length of the truncated output (i.e., the output of [`Type::truncate`]). + fn output_len(&self) -> usize { + self.length + } +} + /// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19], /// Corollary 4.9] to reduce the proof size to roughly the square root of the input size. /// @@ -837,46 +862,13 @@ impl Clone for SumVec { } } -impl Type for SumVec +impl Flp for SumVec where F: FftFriendlyFieldElement, S: ParallelSumGadget> + Eq + 'static, { - type Measurement = Vec; - type AggregateResult = Vec; type Field = F; - fn encode_measurement(&self, measurement: &Vec) -> Result, FlpError> { - if measurement.len() != self.len { - return Err(FlpError::Encode(format!( - "unexpected measurement length: got {}; want {}", - measurement.len(), - self.len - ))); - } - - let mut flattened = Vec::with_capacity(self.flattened_len); - for summand in measurement.iter() { - if summand > &self.max { - return Err(FlpError::Encode(format!( - "summand exceeds maximum of 2^{}-1", - self.bits - ))); - } - flattened.extend(F::encode_as_bitvector(*summand, self.bits)?); - } - - Ok(flattened) - } - - fn decode_result( - &self, - data: &[F], - _num_measurements: usize, - ) -> Result, FlpError> { - decode_result_vec(data, self.len) - } - fn gadget(&self) -> Vec>> { vec![Box::new(S::new( Mul::new(self.gadget_calls), @@ -901,15 +893,6 @@ where .map(|out| vec![out]) } - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - let mut unflattened = Vec::with_capacity(self.len); - for chunk in input.chunks(self.bits) { - unflattened.push(F::decode_bitvector(chunk)?); - } - Ok(unflattened) - } - fn input_len(&self) -> usize { self.flattened_len } @@ -922,10 +905,6 @@ where 2 + self.chunk_length * 2 } - fn output_len(&self) -> usize { - self.len - } - fn joint_rand_len(&self) -> usize { self.gadget_calls } @@ -939,6 +918,59 @@ where } } +impl Type for SumVec +where + F: FftFriendlyFieldElement, + S: ParallelSumGadget> + Eq + 'static, +{ + type Measurement = Vec; + type AggregateResult = Vec; + + fn encode_measurement(&self, measurement: &Vec) -> Result, FlpError> { + if measurement.len() != self.len { + return Err(FlpError::Encode(format!( + "unexpected measurement length: got {}; want {}", + measurement.len(), + self.len + ))); + } + + let mut flattened = Vec::with_capacity(self.flattened_len); + for summand in measurement.iter() { + if summand > &self.max { + return Err(FlpError::Encode(format!( + "summand exceeds maximum of 2^{}-1", + self.bits + ))); + } + flattened.extend(F::encode_as_bitvector(*summand, self.bits)?); + } + + Ok(flattened) + } + + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + let mut unflattened = Vec::with_capacity(self.len); + for chunk in input.chunks(self.bits) { + unflattened.push(F::decode_bitvector(chunk)?); + } + Ok(unflattened) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result, FlpError> { + decode_result_vec(data, self.len) + } + + fn output_len(&self) -> usize { + self.len + } +} + /// Given a vector `data` of field elements which should contain exactly one entry, return the /// integer representation of that entry. pub(crate) fn decode_result( @@ -1028,7 +1060,7 @@ mod tests { use crate::flp::gadgets::ParallelSum; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; - use crate::flp::test_utils::FlpTest; + use crate::flp::test_utils::TypeTest; use std::cmp; #[test] @@ -1051,11 +1083,11 @@ mod tests { ); // Test FLP on valid input. - FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]); - FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]); + TypeTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]); + TypeTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]); // Test FLP on invalid input. - FlpTest::expect_invalid::<3>(&count, &[TestField::from(1337)]); + TypeTest::expect_invalid::<3>(&count, &[TestField::from(1337)]); // Try running the validity circuit on an input that's too short. count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); @@ -1084,7 +1116,7 @@ mod tests { ); // Test FLP on valid input. - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &sum, &sum.encode_measurement(&1337).unwrap(), &[TestField::from(1337)], @@ -1093,7 +1125,7 @@ mod tests { { let sum = Sum::new(3).unwrap(); let meas = 1; - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &sum, &sum.encode_measurement(&meas).unwrap(), &[TestField::from(meas)], @@ -1103,7 +1135,7 @@ mod tests { { let sum = Sum::new(400).unwrap(); let meas = 237; - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &sum, &sum.encode_measurement(&meas).unwrap(), &[TestField::from(meas)], @@ -1115,7 +1147,7 @@ mod tests { let sum = Sum::new((1 << 3) - 1).unwrap(); // The sum+offset value can be whatever. The binariness test should fail first let sum_plus_offset = vec![zero; 3]; - FlpTest::expect_invalid::<3>( + TypeTest::expect_invalid::<3>( &sum, &[&[one, nine, zero], sum_plus_offset.as_slice()].concat(), ); @@ -1123,7 +1155,7 @@ mod tests { { let sum = Sum::new((1 << 5) - 1).unwrap(); let sum_plus_offset = vec![zero; 5]; - FlpTest::expect_invalid::<3>( + TypeTest::expect_invalid::<3>( &sum, &[&[zero, zero, zero, zero, nine], sum_plus_offset.as_slice()].concat(), ); @@ -1199,29 +1231,29 @@ mod tests { ); // Test valid inputs. - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &hist, &hist.encode_measurement(&0).unwrap(), &[one, zero, zero], ); - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &hist, &hist.encode_measurement(&1).unwrap(), &[zero, one, zero], ); - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &hist, &hist.encode_measurement(&2).unwrap(), &[zero, zero, one], ); // Test invalid inputs. - FlpTest::expect_invalid::<3>(&hist, &[zero, zero, nine]); - FlpTest::expect_invalid::<3>(&hist, &[zero, one, one]); - FlpTest::expect_invalid::<3>(&hist, &[one, one, one]); - FlpTest::expect_invalid::<3>(&hist, &[zero, zero, zero]); + TypeTest::expect_invalid::<3>(&hist, &[zero, zero, nine]); + TypeTest::expect_invalid::<3>(&hist, &[zero, one, one]); + TypeTest::expect_invalid::<3>(&hist, &[one, one, one]); + TypeTest::expect_invalid::<3>(&hist, &[zero, zero, zero]); } #[test] @@ -1298,7 +1330,7 @@ mod tests { ); // Test valid inputs with weights 0, 1, and 2 - FlpTest::expect_valid::( + TypeTest::expect_valid::( &multihot_instance, &multihot_instance .encode_measurement(&vec![true, false, false]) @@ -1306,7 +1338,7 @@ mod tests { &[one, zero, zero], ); - FlpTest::expect_valid::( + TypeTest::expect_valid::( &multihot_instance, &multihot_instance .encode_measurement(&vec![false, true, true]) @@ -1314,7 +1346,7 @@ mod tests { &[zero, one, one], ); - FlpTest::expect_valid::( + TypeTest::expect_valid::( &multihot_instance, &multihot_instance .encode_measurement(&vec![false, false, false]) @@ -1325,12 +1357,12 @@ mod tests { // Test invalid inputs. // Not binary - FlpTest::expect_invalid::( + TypeTest::expect_invalid::( &multihot_instance, &[&[zero, zero, nine], &*encoded_weight_plus_offset(1)].concat(), ); // Wrong weight - FlpTest::expect_invalid::( + TypeTest::expect_invalid::( &multihot_instance, &[&[zero, zero, one], &*encoded_weight_plus_offset(2)].concat(), ); @@ -1356,7 +1388,7 @@ mod tests { for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - FlpTest::expect_valid_no_output::<3>( + TypeTest::expect_valid_no_output::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), ); @@ -1364,7 +1396,7 @@ mod tests { let len = 100; let sum_vec = f(1, len, 10).unwrap(); - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), &vec![one; len], @@ -1372,7 +1404,7 @@ mod tests { let len = 23; let sum_vec = f(4, len, 4).unwrap(); - FlpTest::expect_valid::<3>( + TypeTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![9; len]).unwrap(), &vec![nine; len], @@ -1382,12 +1414,12 @@ mod tests { for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]); + TypeTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]); } let len = 23; let sum_vec = f(2, len, 4).unwrap(); - FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]); + TypeTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]); // Round trip let want = vec![1; len]; diff --git a/src/flp/types/fixedpoint_l2.rs b/src/flp/types/fixedpoint_l2.rs index 6342c318..f4b4e477 100644 --- a/src/flp/types/fixedpoint_l2.rs +++ b/src/flp/types/fixedpoint_l2.rs @@ -188,7 +188,7 @@ use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::types::dp::add_iid_noise_to_field_vec; use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat; use crate::flp::types::parallel_sum_range_checks; -use crate::flp::{FlpError, Gadget, Type, TypeWithNoise}; +use crate::flp::{Flp, FlpError, Gadget, Type, TypeWithNoise}; use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::traits::Fixed; use num_bigint::BigUint; @@ -380,71 +380,14 @@ where } } -impl Type for FixedPointBoundedL2VecSum +impl Flp for FixedPointBoundedL2VecSum where T: Fixed + CompatibleFloat, SPoly: ParallelSumGadget> + Eq + Clone + 'static, SMul: ParallelSumGadget> + Eq + Clone + 'static, { - type Measurement = Vec; - type AggregateResult = Vec; type Field = Field128; - fn encode_measurement(&self, fp_entries: &Vec) -> Result, FlpError> { - if fp_entries.len() != self.entries { - return Err(FlpError::Encode("unexpected input length".into())); - } - - // Convert the fixed-point encoded input values to field integers. We do - // this once here because we need them for encoding but also for - // computing the norm. - let integer_entries = fp_entries.iter().map(|x| x.to_field_integer()); - - // (I) Vector entries. - // Encode the integer entries bitwise, and write them into the `encoded` - // vector. - let mut encoded: Vec = - Vec::with_capacity(self.bits_per_entry * self.entries + self.bits_for_norm); - for entry in integer_entries.clone() { - encoded.extend(Field128::encode_as_bitvector(entry, self.bits_per_entry)?); - } - - // (II) Vector norm. - // Compute the norm of the input vector. - let field_entries = integer_entries.map(Field128::from); - let norm = compute_norm_of_entries(field_entries, self.bits_per_entry)?; - let norm_int = u128::from(norm); - - // Write the norm into the `entries` vector. - encoded.extend(Field128::encode_as_bitvector(norm_int, self.bits_for_norm)?); - - Ok(encoded) - } - - fn decode_result( - &self, - data: &[Field128], - num_measurements: usize, - ) -> Result, FlpError> { - if data.len() != self.entries { - return Err(FlpError::Decode("unexpected input length".into())); - } - let num_measurements = match u128::try_from(num_measurements) { - Ok(m) => m, - Err(_) => { - return Err(FlpError::Decode( - "number of clients is too large to fit into u128".into(), - )) - } - }; - let mut res = Vec::with_capacity(data.len()); - for d in data { - let decoded = ::to_float(*d, num_measurements); - res.push(decoded); - } - Ok(res) - } - fn gadget(&self) -> Vec>> { // This gadget checks that a field element is zero or one. // It is called for all the "bits" of the encoded entries @@ -557,21 +500,6 @@ where Ok(vec![range_check, norm_check]) } - fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - - let mut decoded_vector = vec![]; - - for i_entry in 0..self.entries { - let start = i_entry * self.bits_per_entry; - let end = (i_entry + 1) * self.bits_per_entry; - - let decoded = Field128::decode_bitvector(&input[start..end])?; - decoded_vector.push(decoded); - } - Ok(decoded_vector) - } - fn input_len(&self) -> usize { self.bits_per_entry * self.entries + self.bits_for_norm } @@ -594,10 +522,6 @@ where self.gadget0_chunk_length * 2 + self.gadget1_chunk_length + 3 } - fn output_len(&self) -> usize { - self.entries - } - fn joint_rand_len(&self) -> usize { self.gadget0_calls } @@ -611,6 +535,90 @@ where } } +impl Type for FixedPointBoundedL2VecSum +where + T: Fixed + CompatibleFloat, + SPoly: ParallelSumGadget> + Eq + Clone + 'static, + SMul: ParallelSumGadget> + Eq + Clone + 'static, +{ + type Measurement = Vec; + type AggregateResult = Vec; + + fn encode_measurement(&self, fp_entries: &Vec) -> Result, FlpError> { + if fp_entries.len() != self.entries { + return Err(FlpError::Encode("unexpected input length".into())); + } + + // Convert the fixed-point encoded input values to field integers. We do + // this once here because we need them for encoding but also for + // computing the norm. + let integer_entries = fp_entries.iter().map(|x| x.to_field_integer()); + + // (I) Vector entries. + // Encode the integer entries bitwise, and write them into the `encoded` + // vector. + let mut encoded: Vec = + Vec::with_capacity(self.bits_per_entry * self.entries + self.bits_for_norm); + for entry in integer_entries.clone() { + encoded.extend(Field128::encode_as_bitvector(entry, self.bits_per_entry)?); + } + + // (II) Vector norm. + // Compute the norm of the input vector. + let field_entries = integer_entries.map(Field128::from); + let norm = compute_norm_of_entries(field_entries, self.bits_per_entry)?; + let norm_int = u128::from(norm); + + // Write the norm into the `entries` vector. + encoded.extend(Field128::encode_as_bitvector(norm_int, self.bits_for_norm)?); + + Ok(encoded) + } + + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + + let mut decoded_vector = vec![]; + + for i_entry in 0..self.entries { + let start = i_entry * self.bits_per_entry; + let end = (i_entry + 1) * self.bits_per_entry; + + let decoded = Field128::decode_bitvector(&input[start..end])?; + decoded_vector.push(decoded); + } + Ok(decoded_vector) + } + + fn decode_result( + &self, + data: &[Field128], + num_measurements: usize, + ) -> Result, FlpError> { + if data.len() != self.entries { + return Err(FlpError::Decode("unexpected input length".into())); + } + let num_measurements = match u128::try_from(num_measurements) { + Ok(m) => m, + Err(_) => { + return Err(FlpError::Decode( + "number of clients is too large to fit into u128".into(), + )) + } + }; + let mut res = Vec::with_capacity(data.len()); + for d in data { + let decoded = ::to_float(*d, num_measurements); + res.push(decoded); + } + Ok(res) + } + + fn output_len(&self) -> usize { + self.entries + } +} + impl TypeWithNoise for FixedPointBoundedL2VecSum where @@ -670,7 +678,7 @@ mod tests { use crate::dp::{Rational, ZCdpBudget}; use crate::field::{random_vector, Field128, FieldElement}; use crate::flp::gadgets::ParallelSum; - use crate::flp::test_utils::FlpTest; + use crate::flp::test_utils::TypeTest; use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::types::extra::{U127, U14, U63}; use fixed::types::{I1F15, I1F31, I1F63}; @@ -779,7 +787,7 @@ mod tests { let mut input: Vec = vsum.encode_measurement(&fp_vec).unwrap(); assert_eq!(input[0], Field128::zero()); input[0] = one; // it was zero - FlpTest { + TypeTest { name: None, flp: &vsum, input: &input, @@ -795,7 +803,7 @@ mod tests { // encoding contains entries that are not zero or one let mut input2: Vec = vsum.encode_measurement(&fp_vec).unwrap(); input2[0] = one + one; - FlpTest { + TypeTest { name: None, flp: &vsum, input: &input2, @@ -811,7 +819,7 @@ mod tests { // norm is too big // 2^n - 1, the field element encoded by the all-1 vector let one_enc = Field128::from(((2_u128) << (n - 1)) - 1); - FlpTest { + TypeTest { name: None, flp: &vsum, input: &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors diff --git a/src/vdaf/mastic/szk.rs b/src/vdaf/mastic/szk.rs index 03613b48..bf1c1494 100644 --- a/src/vdaf/mastic/szk.rs +++ b/src/vdaf/mastic/szk.rs @@ -11,6 +11,9 @@ //! here uses an [`Xof`] (to be modeled as a random oracle) to sample coins and the helper's proof share, //! following a strategy similar to [`Prio3`](crate::vdaf::prio3::Prio3). +// The compiler warns this is unused, but if we remove it, compilation fails. +#[allow(unused_imports)] +use crate::flp::Flp; use crate::{ codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::{decode_fieldvec, encode_fieldvec, FieldElement}, diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 94edc192..6bce9dbd 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -46,10 +46,15 @@ use crate::flp::gadgets::{Mul, ParallelSum}; use crate::flp::types::fixedpoint_l2::{ compatible_float::CompatibleFloat, FixedPointBoundedL2VecSum, }; -use crate::flp::types::{Average, Count, Histogram, MultihotCountVec, Sum, SumVec}; -use crate::flp::Type; +// The compiler warns this is unused, but if we remove it, compilation fails. +#[allow(unused_imports)] +use crate::flp::Flp; #[cfg(feature = "experimental")] use crate::flp::TypeWithNoise; +use crate::flp::{ + types::{Average, Count, Histogram, MultihotCountVec, Sum, SumVec}, + Type, +}; use crate::prng::Prng; use crate::vdaf::xof::{IntoFieldVec, Seed, Xof}; use crate::vdaf::{