Skip to content

Commit

Permalink
Removed explicit bits param for Sum and Average
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Rosenberg committed Dec 5, 2024
1 parent b215891 commit 828a4cb
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 92 deletions.
2 changes: 1 addition & 1 deletion benches/cycle_counts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn prio3_client_histogram_10() -> Vec<Prio3InputShare<Field128, 16>> {

fn prio3_client_sum_32() -> Vec<Prio3InputShare<Field128, 16>> {
let bits = 16;
let prio3 = Prio3::new_sum(2, bits, (1 << bits) - 1).unwrap();
let prio3 = Prio3::new_sum(2, (1 << bits) - 1).unwrap();
let measurement = 1337;
let nonce = [0; 16];
prio3
Expand Down
4 changes: 2 additions & 2 deletions benches/speed_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ fn prio3(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| {
// Doesn't matter for speed what we use for max measurement, or measurement
let max_measurement = (1 << bits) - 1;
let vdaf = Prio3::new_sum(num_shares, *bits, max_measurement).unwrap();
let vdaf = Prio3::new_sum(num_shares, max_measurement).unwrap();
let measurement = max_measurement;
let nonce = black_box([0u8; 16]);
b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap());
Expand All @@ -212,7 +212,7 @@ fn prio3(c: &mut Criterion) {
for bits in [8, 32] {
group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| {
let max_measurement = (1 << bits) - 1;
let vdaf = Prio3::new_sum(num_shares, *bits, max_measurement).unwrap();
let vdaf = Prio3::new_sum(num_shares, max_measurement).unwrap();
let measurement = max_measurement;
let nonce = black_box([0u8; 16]);
let verify_key = black_box([0u8; 16]);
Expand Down
7 changes: 3 additions & 4 deletions binaries/src/bin/vdaf_message_sizes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,12 @@ fn main() {
)
);

let bits = 32;
let max_measurement = 987_234; // arbitrary number
let prio3 = Prio3::new_sum(num_shares, bits, max_measurement).unwrap();
let max_measurement = 987_234;
let prio3 = Prio3::new_sum(num_shares, max_measurement).unwrap();
let measurement = 1337;
println!(
"prio3 sum ({} bits) share size = {}",
bits,
max_measurement.ilog2() + 1,
vdaf_input_share_size::<Prio3Sum, 16>(
prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap()
)
Expand Down
15 changes: 15 additions & 0 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ pub trait Integer:

/// Returns one.
fn one() -> Self;

/// Returns ⌊log₂(self)⌋, or `None` if `self == 0`
fn checked_ilog2(&self) -> Option<u32>;
}

/// Extension trait for field elements that can be converted back and forth to an integer type.
Expand Down Expand Up @@ -785,6 +788,10 @@ impl Integer for u32 {
fn one() -> Self {
1
}

fn checked_ilog2(&self) -> Option<u32> {
u32::checked_ilog2(*self)
}
}

impl Integer for u64 {
Expand All @@ -798,6 +805,10 @@ impl Integer for u64 {
fn one() -> Self {
1
}

fn checked_ilog2(&self) -> Option<u32> {
u64::checked_ilog2(*self)
}
}

impl Integer for u128 {
Expand All @@ -811,6 +822,10 @@ impl Integer for u128 {
fn one() -> Self {
1
}

fn checked_ilog2(&self) -> Option<u32> {
u128::checked_ilog2(*self)
}
}

make_field!(
Expand Down
6 changes: 6 additions & 0 deletions src/field/field255.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ mod tests {
fn one() -> Self {
Self::new(Vec::from([1]))
}

fn checked_ilog2(&self) -> Option<u32> {
// This is a test module, and this code is never used. If we need this in the future,
// use BigUint::bits()
unimplemented!()
}
}

impl TestFieldElementWithInteger for Field255 {
Expand Down
12 changes: 4 additions & 8 deletions src/flp/szk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,9 @@ mod tests {
#[test]
fn test_sum_proof_share_encode() {
let mut nonce = [0u8; 16];
let bits = 5;
let max_measurement = 13;
thread_rng().fill(&mut nonce[..]);
let sum = Sum::<Field128>::new(bits, max_measurement).unwrap();
let sum = Sum::<Field128>::new(max_measurement).unwrap();
let encoded_measurement = sum.encode_measurement(&9).unwrap();
let algorithm_id = 5;
let szk_typ = Szk::new_turboshake128(sum, algorithm_id);
Expand Down Expand Up @@ -898,11 +897,10 @@ mod tests {

#[test]
fn test_sum_leader_proof_share_roundtrip() {
let bits = 5;
let max_measurement = 13;
let mut nonce = [0u8; 16];
thread_rng().fill(&mut nonce[..]);
let sum = Sum::<Field128>::new(bits, max_measurement).unwrap();
let sum = Sum::<Field128>::new(max_measurement).unwrap();
let encoded_measurement = sum.encode_measurement(&9).unwrap();
let algorithm_id = 5;
let szk_typ = Szk::new_turboshake128(sum, algorithm_id);
Expand Down Expand Up @@ -940,11 +938,10 @@ mod tests {

#[test]
fn test_sum_helper_proof_share_roundtrip() {
let bits = 5;
let max_measurement = 13;
let mut nonce = [0u8; 16];
thread_rng().fill(&mut nonce[..]);
let sum = Sum::<Field128>::new(bits, max_measurement).unwrap();
let sum = Sum::<Field128>::new(max_measurement).unwrap();
let encoded_measurement = sum.encode_measurement(&9).unwrap();
let algorithm_id = 5;
let szk_typ = Szk::new_turboshake128(sum, algorithm_id);
Expand Down Expand Up @@ -1144,9 +1141,8 @@ mod tests {

#[test]
fn test_sum() {
let bits = 5;
let max_measurement = 13;
let sum = Sum::<Field128>::new(bits, max_measurement).unwrap();
let sum = Sum::<Field128>::new(max_measurement).unwrap();

let five = Field128::from(5);
let nine = sum.encode_measurement(&9).unwrap();
Expand Down
59 changes: 28 additions & 31 deletions src/flp/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

//! A collection of [`Type`] implementations.
use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt};
use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt, Integer};
use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
use crate::flp::{FlpError, Gadget, Type};
use crate::polynomial::poly_range_check;
Expand Down Expand Up @@ -121,41 +121,40 @@ impl<F: FftFriendlyFieldElement> Type for Count<F> {
/// [BBCG+19]: https://ia.cr/2019/188
#[derive(Clone, PartialEq, Eq)]
pub struct Sum<F: FftFriendlyFieldElement> {
bits: usize,
max_measurement: F::Integer,

// Computed from given parameters
// Computed from max_measurement
offset: F::Integer,
bits: usize,
// Constant
bit_range_checker: Vec<F>,
}

impl<F: FftFriendlyFieldElement> Debug for Sum<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Sum").field("bits", &self.bits).finish()
f.debug_struct("Sum")
.field("max_measurement", &self.max_measurement)
.field("bits", &self.bits)
.finish()
}
}

impl<F: FftFriendlyFieldElement> Sum<F> {
/// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0,
/// max_measurement]` where `max_measurement < 2^bits`.
/// max_measurement]` where `max_measurement > 0`.
///
/// # Panics
/// Panics if 2^bits <= max_measurement
pub fn new(bits: usize, max_measurement: F::Integer) -> Result<Self, FlpError> {
if !F::valid_integer_bitlength(bits) {
return Err(FlpError::Encode(
"invalid bits: number of bits exceeds maximum number of bits in this field"
.to_string(),
));
}
let one = F::Integer::try_from(1).unwrap();
/// Panics if `max_measurement == 0`.
pub fn new(max_measurement: F::Integer) -> Result<Self, FlpError> {
assert!(
(one << bits) > max_measurement,
"2^bits must be greater than max_measurement"
max_measurement > F::Integer::zero(),
"max_measurement must be nonzero"
);
// Number of bits needed to represent x is ⌊log₂(x)⌋ + 1
let bits = max_measurement.checked_ilog2().unwrap() as usize + 1;

// The offset we add to the summand for range-checking purposes
let one = F::Integer::try_from(1).unwrap();
let offset = (one << bits) - one - max_measurement;

// Construct a range checker to ensure encoded bits are in the range [0, 2)
Expand Down Expand Up @@ -277,19 +276,20 @@ pub struct Average<F: FftFriendlyFieldElement> {
impl<F: FftFriendlyFieldElement> Debug for Average<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Average")
.field("max_measurement", &self.summer.max_measurement)
.field("bits", &self.summer.bits)
.finish()
}
}

impl<F: FftFriendlyFieldElement> Average<F> {
/// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0,
/// max_measurement]` where `max_measurement < 2^bits`.
/// max_measurement]` where `max_measurement > 0`.
///
/// # Panics
/// Panics if 2^bits <= max_measurement
pub fn new(bits: usize, max_measurement: F::Integer) -> Result<Self, FlpError> {
let summer = Sum::new(bits, max_measurement)?;
/// Panics if `max_measurement == 0`
pub fn new(max_measurement: F::Integer) -> Result<Self, FlpError> {
let summer = Sum::new(max_measurement)?;
Ok(Average { summer })
}
}
Expand Down Expand Up @@ -1072,10 +1072,9 @@ mod tests {

#[test]
fn test_sum() {
let bits = 11;
let max_measurement = 1458; // arbitrary number < 2^11
let max_measurement = 1458;

let sum = Sum::new(bits, max_measurement).unwrap();
let sum = Sum::new(max_measurement).unwrap();
let zero = TestField::zero();
let one = TestField::one();
let nine = TestField::from(9);
Expand All @@ -1096,10 +1095,9 @@ mod tests {
&sum.encode_measurement(&1337).unwrap(),
&[TestField::from(1337)],
);
FlpTest::expect_valid::<3>(&Sum::new(0, 0).unwrap(), &[], &[zero]);

{
let sum = Sum::new(2, 3).unwrap();
let sum = Sum::new(3).unwrap();
let meas = 1;
FlpTest::expect_valid::<3>(
&sum,
Expand All @@ -1109,7 +1107,7 @@ mod tests {
}

{
let sum = Sum::new(9, 400).unwrap();
let sum = Sum::new(400).unwrap();
let meas = 237;
FlpTest::expect_valid::<3>(
&sum,
Expand All @@ -1120,7 +1118,7 @@ mod tests {

// Test FLP on invalid input, specifically on field elements outside of {0,1}
{
let sum = Sum::new(3, (1 << 3) - 1).unwrap();
let sum = Sum::new((1 << 3) - 1).unwrap();
// The sum+offset value can be whatever. The binariness test should fail first
let sum_plus_offset = vec![zero; 3];
FlpTest::expect_invalid::<3>(
Expand All @@ -1129,7 +1127,7 @@ mod tests {
);
}
{
let sum = Sum::new(5, (1 << 5) - 1).unwrap();
let sum = Sum::new((1 << 5) - 1).unwrap();
let sum_plus_offset = vec![zero; 5];
FlpTest::expect_invalid::<3>(
&sum,
Expand All @@ -1140,10 +1138,9 @@ mod tests {

#[test]
fn test_average() {
let max_measurement = 13;
let bits = 11;
let max_measurement = (1 << 11) - 13;

let average = Average::new(bits, max_measurement).unwrap();
let average = Average::new(max_measurement).unwrap();
let zero = TestField::zero();
let one = TestField::one();
let ten = TestField::from(10);
Expand Down
6 changes: 2 additions & 4 deletions src/vdaf/mastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,8 @@ mod tests {
#[test]
fn test_mastic_shard_sum() {
let algorithm_id = 6;
let bits = 5;
let max_measurement = 29;
let sum_typ = Sum::<Field128>::new(bits, max_measurement).unwrap();
let sum_typ = Sum::<Field128>::new(max_measurement).unwrap();
let encoded_meas_len = sum_typ.input_len();

let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id);
Expand All @@ -418,9 +417,8 @@ mod tests {
#[test]
fn test_input_share_encode_sum() {
let algorithm_id = 6;
let bits = 5;
let max_measurement = 29;
let sum_typ = Sum::<Field128>::new(bits, max_measurement).unwrap();
let sum_typ = Sum::<Field128>::new(max_measurement).unwrap();
let encoded_meas_len = sum_typ.input_len();

let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id);
Expand Down
Loading

0 comments on commit 828a4cb

Please sign in to comment.