Skip to content

Commit

Permalink
flp: Split out FLP methods into their own trait
Browse files Browse the repository at this point in the history
Define a new trait, `Flp`, that implements the core FLP proof system.
The remaining methods on `Type` are related to encoding of the
measurement, truncating the output shares, and decoding the aggregate
result.
  • Loading branch information
cjpatton committed Jan 17, 2025
1 parent 9eac4c1 commit 4e2aa2a
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 382 deletions.
127 changes: 68 additions & 59 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! Implementation of the generic Fully Linear Proof (FLP) system specified in
//! [[draft-irtf-cfrg-vdaf-08]]. This is the main building block of [`Prio3`](crate::vdaf::prio3).
//!
//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation
//! The FLP is derived for any implementation of the [`Flp`] trait. Such an implementation
//! specifies a validity circuit that defines the set of valid measurements, as well as the finite
//! field in which the validity circuit is evaluated. It also determines how raw measurements are
//! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of
Expand All @@ -19,7 +19,7 @@
//!
//! ```
//! use prio::flp::types::Count;
//! use prio::flp::Type;
//! use prio::flp::{Type, Flp};
//! use prio::field::{random_vector, FieldElement, Field64};
//!
//! // The prover chooses a measurement.
Expand Down Expand Up @@ -63,15 +63,15 @@ pub mod types;
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum FlpError {
/// Calling [`Type::prove`] returned an error.
/// Calling [`Flp::prove`] returned an error.
#[error("prove error: {0}")]
Prove(String),

/// Calling [`Type::query`] returned an error.
/// Calling [`Flp::query`] returned an error.
#[error("query error: {0}")]
Query(String),

/// Calling [`Type::decide`] returned an error.
/// Calling [`Flp::decide`] returned an error.
#[error("decide error: {0}")]
Decide(String),

Expand Down Expand Up @@ -113,35 +113,11 @@ pub enum FlpError {
DifferentialPrivacy(#[from] crate::dp::DpError),
}

/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
/// as a vector of field elements and how validity of the encoded measurement is determined.
/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
pub trait Type: Sized + Eq + Clone + Debug {
/// The type of raw measurement to be encoded.
type Measurement: Clone + Debug;

/// The type of aggregate result for this type.
type AggregateResult: Clone + Debug;

/// The FLP proof system. The user specifies the validity circuit.
pub trait Flp: Sized + Eq + Clone + Debug {
/// The finite field used for this type.
type Field: FftFriendlyFieldElement;

/// Encodes a measurement as a vector of [`Self::input_len`] field elements.
fn encode_measurement(
&self,
measurement: &Self::Measurement,
) -> Result<Vec<Self::Field>, FlpError>;

/// Decodes an aggregate result.
///
/// This is NOT the inverse of `encode_measurement`. Rather, the input is an aggregation of
/// truncated measurements.
fn decode_result(
&self,
data: &[Self::Field],
num_measurements: usize,
) -> Result<Self::AggregateResult, FlpError>;

/// Returns the sequence of gadgets associated with the validity circuit.
///
/// # Notes
Expand Down Expand Up @@ -173,7 +149,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
///
/// ```
/// use prio::flp::types::Count;
/// use prio::flp::Type;
/// use prio::flp::{Flp, Type};
/// use prio::field::{random_vector, FieldElement, Field64};
///
/// let count = Count::new();
Expand All @@ -190,11 +166,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
num_shares: usize,
) -> Result<Vec<Self::Field>, FlpError>;

/// Constructs an aggregatable output from an encoded input. Calling this method is only safe
/// once `input` has been validated.
fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>;

/// The length in field elements of the encoded input returned by [`Self::encode_measurement`].
/// The length in field elements of the input to [`Self::valid`].
fn input_len(&self) -> usize;

/// The length in field elements of the proof generated for this type.
Expand All @@ -203,9 +175,6 @@ pub trait Type: Sized + Eq + Clone + Debug {
/// The length in field elements of the verifier message constructed by [`Self::query`].
fn verifier_len(&self) -> usize;

/// The length of the truncated output (i.e., the output of [`Type::truncate`]).
fn output_len(&self) -> usize;

/// The length of the joint random input.
fn joint_rand_len(&self) -> usize;

Expand Down Expand Up @@ -589,6 +558,40 @@ pub trait Type: Sized + Eq + Clone + Debug {
}
}

/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
/// as a vector of field elements and how validity of the encoded measurement is determined.
/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
pub trait Type: Flp {
/// The type of raw measurement to be encoded.
type Measurement: Clone + Debug;

/// The type of aggregate result for this type.
type AggregateResult: Clone + Debug;

/// Encodes a measurement as a vector of `self.input_len()` field elements.
fn encode_measurement(
&self,
measurement: &Self::Measurement,
) -> Result<Vec<Self::Field>, FlpError>;

/// Constructs an aggregatable output from an encoded input. Calling this method is only safe
/// once `input` has been validated.
fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>;

/// Decodes an aggregate result.
///
/// This is NOT the inverse of `encode_measurement`. Rather, the input is an aggregation of
/// truncated measurements.
fn decode_result(
&self,
data: &[Self::Field],
num_measurements: usize,
) -> Result<Self::AggregateResult, FlpError>;

/// The length of the truncated output (i.e., the output of [`Type::truncate`]).
fn output_len(&self) -> usize;
}

/// A type which supports adding noise to aggregate shares for Server Differential Privacy.
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
Expand Down Expand Up @@ -804,7 +807,7 @@ pub mod test_utils {

/// Various tests for an FLP.
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub struct FlpTest<'a, T: Type> {
pub struct TypeTest<'a, T: Type> {
/// The FLP.
pub flp: &'a T,

Expand All @@ -821,15 +824,15 @@ pub mod test_utils {
pub expect_valid: bool,
}

impl<T: Type> FlpTest<'_, T> {
impl<T: Type> TypeTest<'_, T> {
/// Construct a test and run it. Expect the input to be valid and compare the truncated
/// output to the provided value.
pub fn expect_valid<const SHARES: usize>(
flp: &T,
input: &[T::Field],
expected_output: &[T::Field],
) {
FlpTest {
TypeTest {
flp,
name: None,
input,
Expand All @@ -841,7 +844,7 @@ pub mod test_utils {

/// Construct a test and run it. Expect the input to be invalid.
pub fn expect_invalid<const SHARES: usize>(flp: &T, input: &[T::Field]) {
FlpTest {
TypeTest {
flp,
name: None,
input,
Expand All @@ -853,7 +856,7 @@ pub mod test_utils {

/// Construct a test and run it. Expect the input to be valid.
pub fn expect_valid_no_output<const SHARES: usize>(flp: &T, input: &[T::Field]) {
FlpTest {
TypeTest {
flp,
name: None,
input,
Expand Down Expand Up @@ -1077,9 +1080,7 @@ mod tests {
}
}

impl<F: FftFriendlyFieldElement> Type for TestType<F> {
type Measurement = F::Integer;
type AggregateResult = F::Integer;
impl<F: FftFriendlyFieldElement> Flp for TestType<F> {
type Field = F;

fn valid(
Expand Down Expand Up @@ -1132,10 +1133,6 @@ mod tests {
1 + mul + poly
}

fn output_len(&self) -> usize {
self.input_len()
}

fn joint_rand_len(&self) -> usize {
1
}
Expand All @@ -1158,6 +1155,11 @@ mod tests {
fn num_gadgets(&self) -> usize {
2
}
}

impl<F: FftFriendlyFieldElement> Type for TestType<F> {
type Measurement = F::Integer;
type AggregateResult = F::Integer;

fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
Ok(vec![
Expand All @@ -1177,6 +1179,10 @@ mod tests {
) -> Result<F::Integer, FlpError> {
panic!("not implemented");
}

fn output_len(&self) -> usize {
self.input_len()
}
}

// In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that
Expand Down Expand Up @@ -1215,9 +1221,7 @@ mod tests {
}
}

impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
type Measurement = F::Integer;
type AggregateResult = F::Integer;
impl<F: FftFriendlyFieldElement> Flp for Issue254Type<F> {
type Field = F;

fn valid(
Expand Down Expand Up @@ -1265,10 +1269,6 @@ mod tests {
1 + first + second
}

fn output_len(&self) -> usize {
self.input_len()
}

fn joint_rand_len(&self) -> usize {
0
}
Expand Down Expand Up @@ -1298,6 +1298,11 @@ mod tests {
fn num_gadgets(&self) -> usize {
2
}
}

impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
type Measurement = F::Integer;
type AggregateResult = F::Integer;

fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
Ok(vec![F::from(*measurement)])
Expand All @@ -1314,5 +1319,9 @@ mod tests {
) -> Result<F::Integer, FlpError> {
panic!("not implemented");
}

fn output_len(&self) -> usize {
self.input_len()
}
}
}
Loading

0 comments on commit 4e2aa2a

Please sign in to comment.