Skip to content

Commit

Permalink
test-util: Add a method for running Prio3 test vectors (#862)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton authored Dec 2, 2023
1 parent 3ece13a commit 3bbeb52
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 45 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ctr = { version = "0.9.2", optional = true }
fiat-crypto = { version = "0.2.5", optional = true }
fixed = { version = "1.23", optional = true }
getrandom = { version = "0.2.11", features = ["std"] }
hex = { version = "0.4.3", features = ["serde"], optional = true }
hmac = { version = "0.12.1", optional = true }
num-bigint = { version = "0.4.4", optional = true, features = ["rand", "serde"] }
num-integer = { version = "0.1.45", optional = true }
Expand All @@ -28,6 +29,7 @@ rand = { version = "0.8", optional = true }
rand_core = "0.6.4"
rayon = { version = "1.8.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0", optional = true }
sha2 = { version = "0.10.8", optional = true }
sha3 = "0.10.8"
subtle = "2.5.0"
Expand All @@ -39,15 +41,13 @@ base64 = "0.21.5"
cfg-if = "1.0.0"
criterion = "0.5"
fixed-macro = "1.2.0"
hex = { version = "0.4.3", features = ["serde"] }
hex-literal = "0.4.1"
iai = "0.1"
modinverse = "0.1.0"
num-bigint = "0.4.4"
once_cell = "1.18.0"
prio = { path = ".", features = ["crypto-dependencies", "test-util"] }
rand = "0.8"
serde_json = "1.0"
statrs = "0.16.0"
zipf = "7.0.1"

Expand All @@ -57,7 +57,7 @@ experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational",
multithreaded = ["rayon"]
prio2 = ["crypto-dependencies", "hmac", "sha2"]
crypto-dependencies = ["aes", "ctr"]
test-util = ["rand"]
test-util = ["hex", "rand", "serde_json"]

[workspace]
members = [".", "binaries"]
Expand Down
4 changes: 2 additions & 2 deletions src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,6 @@ pub mod poplar1;
#[cfg_attr(docsrs, doc(cfg(feature = "prio2")))]
pub mod prio2;
pub mod prio3;
#[cfg(test)]
mod prio3_test;
#[cfg(any(test, feature = "test-util"))]
pub mod prio3_test;
pub mod xof;
79 changes: 39 additions & 40 deletions src/vdaf/prio3_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0

//! Tools for evaluating Prio3 test vectors.
use crate::{
codec::{Encode, ParameterizedDecode},
flp::Type,
Expand Down Expand Up @@ -174,17 +176,34 @@ where
prio3.unshard(&(), aggregate_shares, 1).unwrap()
}

/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf`
/// callback, which takes in the VDAF parameters encoded by the test vectors and the number of
/// shares.
#[cfg(feature = "test-util")]
pub fn check_test_vec<M, A, T, P, const SEED_SIZE: usize>(
test_vec_json_str: &str,
new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>,
) where
M: for<'de> Deserialize<'de>,
A: for<'de> Deserialize<'de> + Debug + Eq,
T: Type<Measurement = M, AggregateResult = A>,
P: Xof<SEED_SIZE>,
{
let t: TPrio3<M> = serde_json::from_str(test_vec_json_str).unwrap();
let vdaf = new_vdaf(&t.other_params, t.shares);
let agg_result = check_aggregate_test_vec(&vdaf, &t);
assert_eq!(agg_result, serde_json::from_value(t.agg_result).unwrap());
}

#[test]
fn test_vec_prio3_count() {
for test_vector_str in [
include_str!("test_vec/08/Prio3Count_0.json"),
include_str!("test_vec/08/Prio3Count_1.json"),
] {
let t: TPrio3<u64> = serde_json::from_str(test_vector_str).unwrap();
let prio3 = Prio3::new_count(t.shares).unwrap();

let aggregate_result = check_aggregate_test_vec(&prio3, &t);
assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap());
check_test_vec(test_vector_str, |_json_params, num_shares| {
Prio3::new_count(num_shares).unwrap()
});
}
}

Expand All @@ -194,12 +213,10 @@ fn test_vec_prio3_sum() {
include_str!("test_vec/08/Prio3Sum_0.json"),
include_str!("test_vec/08/Prio3Sum_1.json"),
] {
let t: TPrio3<u128> = serde_json::from_str(test_vector_str).unwrap();
let bits = t.other_params["bits"].as_u64().unwrap() as usize;
let prio3 = Prio3::new_sum(t.shares, bits).unwrap();

let aggregate_result = check_aggregate_test_vec(&prio3, &t);
assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap() as u128);
check_test_vec(test_vector_str, |json_params, num_shares| {
let bits = json_params["bits"].as_u64().unwrap() as usize;
Prio3::new_sum(num_shares, bits).unwrap()
});
}
}

Expand All @@ -209,21 +226,12 @@ fn test_vec_prio3_sum_vec() {
include_str!("test_vec/08/Prio3SumVec_0.json"),
include_str!("test_vec/08/Prio3SumVec_1.json"),
] {
let t: TPrio3<Vec<u128>> = serde_json::from_str(test_vector_str).unwrap();
let bits = t.other_params["bits"].as_u64().unwrap() as usize;
let length = t.other_params["length"].as_u64().unwrap() as usize;
let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize;
let prio3 = Prio3::new_sum_vec(t.shares, bits, length, chunk_length).unwrap();

let aggregate_result = check_aggregate_test_vec(&prio3, &t);
let expected_aggregate_result = t
.agg_result
.as_array()
.unwrap()
.iter()
.map(|val| val.as_u64().unwrap() as u128)
.collect::<Vec<u128>>();
assert_eq!(aggregate_result, expected_aggregate_result);
check_test_vec(test_vector_str, |json_params, num_shares| {
let bits = json_params["bits"].as_u64().unwrap() as usize;
let length = json_params["length"].as_u64().unwrap() as usize;
let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize;
Prio3::new_sum_vec(num_shares, bits, length, chunk_length).unwrap()
});
}
}

Expand All @@ -233,19 +241,10 @@ fn test_vec_prio3_histogram() {
include_str!("test_vec/08/Prio3Histogram_0.json"),
include_str!("test_vec/08/Prio3Histogram_1.json"),
] {
let t: TPrio3<usize> = serde_json::from_str(test_vector_str).unwrap();
let length = t.other_params["length"].as_u64().unwrap() as usize;
let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize;
let prio3 = Prio3::new_histogram(t.shares, length, chunk_length).unwrap();

let aggregate_result = check_aggregate_test_vec(&prio3, &t);
let expected_aggregate_result = t
.agg_result
.as_array()
.unwrap()
.iter()
.map(|val| val.as_u64().unwrap() as u128)
.collect::<Vec<u128>>();
assert_eq!(aggregate_result, expected_aggregate_result);
check_test_vec(test_vector_str, |json_params, num_shares| {
let length = json_params["length"].as_u64().unwrap() as usize;
let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize;
Prio3::new_histogram(num_shares, length, chunk_length).unwrap()
});
}
}
6 changes: 6 additions & 0 deletions supply-chain/imports.lock
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,12 @@ criteria = "safe-to-deploy"
delta = "0.1.19 -> 0.2.6"
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"

[[audits.mozilla.audits.hex]]
who = "Simon Friedberger <[email protected]>"
criteria = "safe-to-deploy"
version = "0.4.3"
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"

[[audits.mozilla.audits.libc]]
who = "Mike Hommey <[email protected]>"
criteria = "safe-to-deploy"
Expand Down

0 comments on commit 3bbeb52

Please sign in to comment.