Skip to content

Commit

Permalink
Prio3: Decouple the algorithm ID from the FLP circuit
Browse files Browse the repository at this point in the history
Move `Type::ID` to `Prio3::new()`. This allows new Prio3 variants to be
befined for a standard circuit.
  • Loading branch information
cjpatton committed Nov 29, 2023
1 parent 5ba8b3b commit a6689de
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 23 deletions.
5 changes: 0 additions & 5 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ pub enum FlpError {
/// as a vector of field elements and how validity of the encoded measurement is determined.
/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
pub trait Type: Sized + Eq + Clone + Debug {
/// The Prio3 VDAF identifier corresponding to this type.
const ID: u32;

/// The type of raw measurement to be encoded.
type Measurement: Clone + Debug;

Expand Down Expand Up @@ -826,7 +823,6 @@ mod tests {
}

impl<F: FftFriendlyFieldElement> Type for TestType<F> {
const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
Expand Down Expand Up @@ -961,7 +957,6 @@ mod tests {
}

impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
Expand Down
5 changes: 0 additions & 5 deletions src/flp/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ impl<F: FftFriendlyFieldElement> Default for Count<F> {
}

impl<F: FftFriendlyFieldElement> Type for Count<F> {
const ID: u32 = 0x00000000;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
Expand Down Expand Up @@ -140,7 +139,6 @@ impl<F: FftFriendlyFieldElement> Sum<F> {
}

impl<F: FftFriendlyFieldElement> Type for Sum<F> {
const ID: u32 = 0x00000001;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
Expand Down Expand Up @@ -239,7 +237,6 @@ impl<F: FftFriendlyFieldElement> Average<F> {
}

impl<F: FftFriendlyFieldElement> Type for Average<F> {
const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = f64;
type Field = F;
Expand Down Expand Up @@ -380,7 +377,6 @@ where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
const ID: u32 = 0x00000003;
type Measurement = usize;
type AggregateResult = Vec<F::Integer>;
type Field = F;
Expand Down Expand Up @@ -574,7 +570,6 @@ where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
const ID: u32 = 0x00000002;
type Measurement = Vec<F::Integer>;
type AggregateResult = Vec<F::Integer>;
type Field = F;
Expand Down
1 change: 0 additions & 1 deletion src/flp/types/fixedpoint_l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ where
SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static,
SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static,
{
const ID: u32 = 0xFFFF0000;
type Measurement = Vec<T>;
type AggregateResult = Vec<f64>;
type Field = Field128;
Expand Down
1 change: 1 addition & 0 deletions src/vdaf/poplar1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> {
}

/// Evaluate the IDPF at the given prefixes and compute the Aggregator's share of the sketch.
#[allow(clippy::too_many_arguments)]
fn eval_and_sketch<F>(
&self,
verify_key: &[u8; SEED_SIZE],
Expand Down
62 changes: 50 additions & 12 deletions src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub type Prio3Count = Prio3<Count<Field64>, XofTurboShake128, 16>;
impl Prio3Count {
/// Construct an instance of Prio3Count with the given number of aggregators.
pub fn new_count(num_aggregators: u8) -> Result<Self, VdafError> {
Prio3::new(num_aggregators, 1, Count::new())
Prio3::new(num_aggregators, 1, 0x00000000, Count::new())
}
}

Expand All @@ -96,7 +96,12 @@ impl Prio3SumVec {
len: usize,
chunk_length: usize,
) -> Result<Self, VdafError> {
Prio3::new(num_aggregators, 1, SumVec::new(bits, len, chunk_length)?)
Prio3::new(
num_aggregators,
1,
0x00000002,
SumVec::new(bits, len, chunk_length)?,
)
}
}

Expand All @@ -121,7 +126,12 @@ impl Prio3SumVecMultithreaded {
len: usize,
chunk_length: usize,
) -> Result<Self, VdafError> {
Prio3::new(num_aggregators, 1, SumVec::new(bits, len, chunk_length)?)
Prio3::new(
num_aggregators,
1,
0x00000002,
SumVec::new(bits, len, chunk_length)?,
)
}
}

Expand All @@ -139,7 +149,7 @@ impl Prio3Sum {
)));
}

Prio3::new(num_aggregators, 1, Sum::new(bits)?)
Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(bits)?)
}
}

Expand Down Expand Up @@ -176,7 +186,12 @@ impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSum<Fx> {
entries: usize,
) -> Result<Self, VdafError> {
check_num_aggregators(num_aggregators)?;
Prio3::new(num_aggregators, 1, FixedPointBoundedL2VecSum::new(entries)?)
Prio3::new(
num_aggregators,
1,
0xFFFF0000,
FixedPointBoundedL2VecSum::new(entries)?,
)
}
}

Expand Down Expand Up @@ -207,7 +222,12 @@ impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSumMultithreaded<Fx
entries: usize,
) -> Result<Self, VdafError> {
check_num_aggregators(num_aggregators)?;
Prio3::new(num_aggregators, 1, FixedPointBoundedL2VecSum::new(entries)?)
Prio3::new(
num_aggregators,
1,
0xFFFF0000,
FixedPointBoundedL2VecSum::new(entries)?,
)
}
}

Expand All @@ -224,7 +244,12 @@ impl Prio3Histogram {
length: usize,
chunk_length: usize,
) -> Result<Self, VdafError> {
Prio3::new(num_aggregators, 1, Histogram::new(length, chunk_length)?)
Prio3::new(
num_aggregators,
1,
0x00000003,
Histogram::new(length, chunk_length)?,
)
}
}

Expand All @@ -247,7 +272,12 @@ impl Prio3HistogramMultithreaded {
length: usize,
chunk_length: usize,
) -> Result<Self, VdafError> {
Prio3::new(num_aggregators, 1, Histogram::new(length, chunk_length)?)
Prio3::new(
num_aggregators,
1,
0x00000003,
Histogram::new(length, chunk_length)?,
)
}
}

Expand All @@ -270,6 +300,7 @@ impl Prio3Average {
Ok(Prio3 {
num_aggregators,
num_proofs: 1,
algorithm_id: 0xFFFF0000,
typ: Average::new(bits)?,
phantom: PhantomData,
})
Expand Down Expand Up @@ -347,6 +378,7 @@ where
{
num_aggregators: u8,
num_proofs: u8,
algorithm_id: u32,
typ: T,
phantom: PhantomData<P>,
}
Expand All @@ -357,8 +389,13 @@ where
P: Xof<SEED_SIZE>,
{
/// Construct an instance of this Prio3 VDAF with the given number of aggregators, number of
/// proofs to generate and verify, and the underlying type.
pub fn new(num_aggregators: u8, num_proofs: u8, typ: T) -> Result<Self, VdafError> {
/// proofs to generate and verify, the algorithm ID, and the underlying type.
pub fn new(
num_aggregators: u8,
num_proofs: u8,
algorithm_id: u32,
typ: T,
) -> Result<Self, VdafError> {
check_num_aggregators(num_aggregators)?;
if num_proofs == 0 {
return Err(VdafError::Uncategorized(
Expand All @@ -369,6 +406,7 @@ where
Ok(Self {
num_aggregators,
num_proofs,
algorithm_id,
typ,
phantom: PhantomData,
})
Expand Down Expand Up @@ -659,7 +697,7 @@ where
type AggregateShare = AggregateShare<T::Field>;

fn algorithm_id(&self) -> u32 {
T::ID
self.algorithm_id
}

fn num_aggregators(&self) -> usize {
Expand Down Expand Up @@ -1620,7 +1658,7 @@ mod tests {
SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>,
XofTurboShake128,
16,
>::new(2, 2, SumVec::new(2, 20, 4).unwrap())
>::new(2, 2, 0xFFFF0000, SumVec::new(2, 20, 4).unwrap())
.unwrap();

assert_eq!(
Expand Down

0 comments on commit a6689de

Please sign in to comment.