From 9c9fe40c569124839a442bb3eb87062cdfd37be3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 14:09:02 +0000 Subject: [PATCH 01/32] build(deps): Bump serde from 1.0.207 to 1.0.208 (#1110) --- Cargo.lock | 8 ++++---- supply-chain/audits.toml | 4 ++-- supply-chain/imports.lock | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1a4cf445..347875cbc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -957,18 +957,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.207" +version = "1.0.208" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2" +checksum = "cff085d2cb684faa248efb494c39b68e522822ac0de72ccf08109abde717cfb2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.207" +version = "1.0.208" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e" +checksum = "24008e81ff7613ed8e5ba0cfaf24e2c2f1e5b8a0495711e44fcd4882fca62bcf" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index fe2a16716..14fd85d45 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -942,13 +942,13 @@ end = "2024-06-08" criteria = "safe-to-deploy" user-id = 3618 # David Tolnay (dtolnay) start = "2019-03-01" -end = "2024-06-08" +end = "2025-06-08" [[trusted.serde_derive]] criteria = "safe-to-deploy" user-id = 3618 # David Tolnay (dtolnay) start = "2019-03-01" -end = "2024-06-08" +end = "2025-06-08" [[trusted.serde_json]] criteria = "safe-to-deploy" diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 1c3952b44..14f1be4d9 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -114,15 +114,15 @@ user-login = "Amanieu" user-name = "Amanieu d'Antras" [[publisher.serde]] -version = "1.0.203" -when = "2024-05-25" +version = "1.0.208" +when = "2024-08-15" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_derive]] -version = "1.0.203" -when = "2024-05-25" +version = "1.0.208" +when = "2024-08-15" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From 99eee6ede30cc1ec05ea078bd43207269635779d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 14:11:20 +0000 Subject: [PATCH 02/32] build(deps): Bump serde_json from 1.0.122 to 1.0.125 (#1109) --- Cargo.lock | 4 ++-- supply-chain/config.toml | 4 ---- supply-chain/imports.lock | 20 ++++++++++++++++++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 347875cbc..8ebdef81f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -977,9 +977,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.122" +version = "1.0.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" +checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" dependencies = [ "itoa", "memchr", diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 213f9df73..f03642d9a 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -45,10 +45,6 @@ criteria = "safe-to-run" version = "1.2.1" criteria = "safe-to-deploy" -[[exemptions.bitflags]] -version = "1.3.2" -criteria = "safe-to-run" - [[exemptions.bitvec]] version = "1.0.1" criteria = "safe-to-deploy" diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 14f1be4d9..d87b6ece9 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -128,8 +128,8 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_json]] -version = "1.0.122" -when = "2024-08-01" +version = "1.0.125" +when = "2024-08-15" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" @@ -367,6 +367,22 @@ who = "Radu Matei " criteria = "safe-to-run" version = "0.3.3" +[[audits.google.audits.bitflags]] +who = "Lukasz Anforowicz " +criteria = "safe-to-deploy" +version = "1.3.2" +notes = """ +Security review of earlier versions of the crate can be found at +(Google-internal, sorry): go/image-crate-chromium-security-review + +The crate exposes a function marked as `unsafe`, but doesn't use any +`unsafe` blocks (except for tests of the single `unsafe` function). I +think this justifies marking this crate as `ub-risk-1`. + +Additional review comments can be found at https://crrev.com/c/4723145/31 +""" +aggregated-from = "https://chromium.googlesource.com/chromium/src/+/main/third_party/rust/chromium_crates_io/supply-chain/audits.toml?format=TEXT" + [[audits.google.audits.cast]] who = "George Burgess IV " criteria = "safe-to-run" From c842f2d62f12db5c1be362335913b434ca07a72a Mon Sep 17 00:00:00 2001 From: Armando Faz Date: Tue, 27 Aug 2024 09:18:39 -0700 Subject: [PATCH 03/32] Use generic implementations for prime field operations. (#1099) This is a refactor of the FieldParameters struct to use generic datatypes providing implementations for primes that fit in one primitive word. Tests script and documentation parameters were updated to support the new structure. Performance for FieldPrio2 operation is twice faster. --- Cargo.toml | 4 +- documentation/field_parameters.sage | 28 +- src/field.rs | 55 +- src/fp.rs | 795 ++++++++++------------------ src/fp/ops.rs | 429 +++++++++++++++ src/fp64.rs | 307 ----------- src/lib.rs | 1 - 7 files changed, 753 insertions(+), 866 deletions(-) create mode 100644 src/fp/ops.rs delete mode 100644 src/fp64.rs diff --git a/Cargo.toml b/Cargo.toml index a88e6985b..3c288beeb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ num-bigint = { version = "0.4.6", optional = true, features = ["rand", "serde"] num-integer = { version = "0.1.46", optional = true } num-iter = { version = "0.1.45", optional = true } num-rational = { version = "0.4.2", optional = true, features = ["serde"] } -num-traits = { version = "0.2.19", optional = true } +num-traits = "0.2.19" rand = "0.8" rand_core = "0.6.4" rayon = { version = "1.10.0", optional = true } @@ -52,7 +52,7 @@ statrs = "0.17.1" [features] default = ["crypto-dependencies"] -experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational", "num-traits", "num-integer", "num-iter"] +experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational", "num-integer", "num-iter"] multithreaded = ["rayon"] crypto-dependencies = ["aes", "ctr", "hmac", "sha2"] test-util = ["hex", "serde_json", "zipf"] diff --git a/documentation/field_parameters.sage b/documentation/field_parameters.sage index 328984d57..377729bca 100755 --- a/documentation/field_parameters.sage +++ b/documentation/field_parameters.sage @@ -99,21 +99,27 @@ class Field: for i in range(min(self.num_roots, 20) + 1) ] + def log2_base(self): + """ + Returns log2(r), where r is the base used for multiprecision arithmetic. + """ + return log(self.r, 2) + + def log2_radix(self): + """ + Returns log2(R), where R is the machine word-friendly modulus + used in the Montgomery representation. + """ + return log(self.R, 2) + FIELDS = [ Field( - "FieldPrio2, u128", + "FieldPrio2, u32", 2 ^ 20 * 4095 + 1, 3925978153, - 2 ^ 64, - 2 ^ 128, - ), - Field( - "Field64, u128", - 2 ^ 32 * 4294967295 + 1, - pow(7, 4294967295, 2 ^ 32 * 4294967295 + 1), - 2 ^ 64, - 2 ^ 128, + 2 ^ 32, + 2 ^ 32, ), Field( "Field64, u64", @@ -140,4 +146,6 @@ for field in FIELDS: print(f"bit_mask: {field.bit_mask()}") print("roots:") pprint.pprint(field.roots()) + print(f"log2_base: {field.log2_base()}") + print(f"log2_radix: {field.log2_radix()}") print() diff --git a/src/field.rs b/src/field.rs index 3c329a455..7b8460a4b 100644 --- a/src/field.rs +++ b/src/field.rs @@ -9,8 +9,7 @@ use crate::{ codec::{CodecError, Decode, Encode}, - fp::{FP128, FP32}, - fp64::FP64, + fp::{FieldOps, FieldParameters, FP128, FP32, FP64}, prng::{Prng, PrngError}, }; use rand::{ @@ -465,12 +464,12 @@ macro_rules! make_field { int &= mask; - if int >= $fp.p { + if int >= $fp::PRIME { return Err(FieldError::ModulusOverflow); } // FieldParameters::montgomery() will return a value that has been fully reduced // mod p, satisfying the invariant on Self. - Ok(Self($fp.montgomery(int))) + Ok(Self($fp::montgomery(int))) } } @@ -481,8 +480,8 @@ macro_rules! make_field { // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq // Check the invariant that the integer representation is fully reduced. - debug_assert!(self.0 < $fp.p); - debug_assert!(rhs.0 < $fp.p); + debug_assert!(self.0 < $fp::PRIME); + debug_assert!(rhs.0 < $fp::PRIME); self.0 == rhs.0 } @@ -507,7 +506,7 @@ macro_rules! make_field { // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq // Check the invariant that the integer representation is fully reduced. - debug_assert!(self.0 < $fp.p); + debug_assert!(self.0 < $fp::PRIME); self.0.hash(state); } @@ -520,7 +519,7 @@ macro_rules! make_field { fn add(self, rhs: Self) -> Self { // FieldParameters::add() returns a value that has been fully reduced // mod p, satisfying the invariant on Self. - Self($fp.add(self.0, rhs.0)) + Self($fp::add(self.0, rhs.0)) } } @@ -542,7 +541,7 @@ macro_rules! make_field { fn sub(self, rhs: Self) -> Self { // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub() // returns a value less than p, satisfying the invariant on Self. - Self($fp.sub(self.0, rhs.0)) + Self($fp::sub(self.0, rhs.0)) } } @@ -564,7 +563,7 @@ macro_rules! make_field { fn mul(self, rhs: Self) -> Self { // FieldParameters::mul() always returns a value less than p, so the invariant on // Self is satisfied. - Self($fp.mul(self.0, rhs.0)) + Self($fp::mul(self.0, rhs.0)) } } @@ -607,7 +606,7 @@ macro_rules! make_field { fn neg(self) -> Self { // FieldParameters::neg() will return a value less than p because self.0 is less // than p, and neg() dispatches to sub(). - Self($fp.neg(self.0)) + Self($fp::neg(self.0)) } } @@ -622,19 +621,19 @@ macro_rules! make_field { fn from(x: $int_conversion) -> Self { // FieldParameters::montgomery() will return a value that has been fully reduced // mod p, satisfying the invariant on Self. - Self($fp.montgomery($int_internal::try_from(x).unwrap())) + Self($fp::montgomery($int_internal::try_from(x).unwrap())) } } impl From<$elem> for $int_conversion { fn from(x: $elem) -> Self { - $int_conversion::try_from($fp.residue(x.0)).unwrap() + $int_conversion::try_from($fp::residue(x.0)).unwrap() } } impl PartialEq<$int_conversion> for $elem { fn eq(&self, rhs: &$int_conversion) -> bool { - $fp.residue(self.0) == $int_internal::try_from(*rhs).unwrap() + $fp::residue(self.0) == $int_internal::try_from(*rhs).unwrap() } } @@ -648,7 +647,7 @@ macro_rules! make_field { impl From<$elem> for [u8; $elem::ENCODED_SIZE] { fn from(elem: $elem) -> Self { - let int = $fp.residue(elem.0); + let int = $fp::residue(elem.0); let mut slice = [0; $elem::ENCODED_SIZE]; for i in 0..$elem::ENCODED_SIZE { slice[i] = ((int >> (i << 3)) & 0xff) as u8; @@ -665,13 +664,13 @@ macro_rules! make_field { impl Display for $elem { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!(f, "{}", $fp.residue(self.0)) + write!(f, "{}", $fp::residue(self.0)) } } impl Debug for $elem { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", $fp.residue(self.0)) + write!(f, "{}", $fp::residue(self.0)) } } @@ -719,11 +718,11 @@ macro_rules! make_field { fn inv(&self) -> Self { // FieldParameters::inv() ultimately relies on mul(), and will always return a // value less than p. - Self($fp.inv(self.0)) + Self($fp::inv(self.0)) } fn try_from_random(bytes: &[u8]) -> Result { - $elem::try_from_bytes(bytes, $fp.bit_mask) + $elem::try_from_bytes(bytes, $fp::BIT_MASK) } fn zero() -> Self { @@ -731,7 +730,7 @@ macro_rules! make_field { } fn one() -> Self { - Self($fp.roots[0]) + Self($fp::ROOTS[0]) } } @@ -741,26 +740,26 @@ macro_rules! make_field { fn pow(&self, exp: Self::Integer) -> Self { // FieldParameters::pow() relies on mul(), and will always return a value less // than p. - Self($fp.pow(self.0, $int_internal::try_from(exp).unwrap())) + Self($fp::pow(self.0, $int_internal::try_from(exp).unwrap())) } fn modulus() -> Self::Integer { - $fp.p as $int_conversion + $fp::PRIME as $int_conversion } } impl FftFriendlyFieldElement for $elem { fn generator() -> Self { - Self($fp.g) + Self($fp::G) } fn generator_order() -> Self::Integer { - 1 << (Self::Integer::try_from($fp.num_roots).unwrap()) + 1 << (Self::Integer::try_from($fp::NUM_ROOTS).unwrap()) } fn root(l: usize) -> Option { - if l < min($fp.roots.len(), $fp.num_roots+1) { - Some(Self($fp.roots[l])) + if l < min($fp::ROOTS.len(), $fp::NUM_ROOTS+1) { + Some(Self($fp::ROOTS[l])) } else { None } @@ -815,9 +814,9 @@ impl Integer for u128 { } make_field!( - /// Same as Field32, but encoded in little endian for compatibility with Prio v2. + /// `GF(4293918721)`, a 32-bit field. FieldPrio2, - u128, + u32, u32, FP32, 4, diff --git a/src/fp.rs b/src/fp.rs index da9fe0fc9..956ba63c6 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -2,305 +2,96 @@ //! Finite field arithmetic for any field GF(p) for which p < 2^128. +#[macro_use] +mod ops; + +pub use ops::{FieldOps, FieldParameters}; + /// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots /// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This /// is the largest input size we would ever need for the cryptographic applications in this crate. pub(crate) const MAX_ROOTS: usize = 20; -/// This structure represents the parameters of a finite field GF(p) for which p < 2^128. -#[derive(Debug, PartialEq, Eq)] -pub(crate) struct FieldParameters { - /// The prime modulus `p`. - pub p: u128, - /// `mu = -p^(-1) mod 2^64`. - pub mu: u64, - /// `r2 = (2^128)^2 mod p`. - pub r2: u128, - /// The `2^num_roots`-th -principal root of unity. This element is used to generate the - /// elements of `roots`. - pub g: u128, - /// The number of principal roots of unity in `roots`. - pub num_roots: usize, - /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. - pub bit_mask: u128, - /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the - /// multiplicative group. `roots[0]` is equal to one by definition. - pub roots: [u128; MAX_ROOTS + 1], -} - -impl FieldParameters { - /// Addition. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn add(&self, x: u128, y: u128) -> u128 { - // 0,x - // + 0,y - // ===== - // c,z - let (z, carry) = x.overflowing_add(y); - // c, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = z.overflowing_sub(self.p); - let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128); - // if b1 == 1: return z - // else: return s0 - let m = 0u128.wrapping_sub(b1 as u128); - (z & m) | (s0 & !m) - } - - /// Subtraction. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn sub(&self, x: u128, y: u128) -> u128 { - // x - // - y - // ======== - // b0,z0 - let (z0, b0) = x.overflowing_sub(y); - let m = 0u128.wrapping_sub(b0 as u128); - // z0 - // + p - // ======== - // s1,s0 - z0.wrapping_add(m & self.p) - // if b1 == 1: return s0 - // else: return z0 - } - - /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm - /// described [here][montgomery]. The result will be in [0, p). - /// - /// # Example usage - /// ```text - /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46); - /// ``` - /// - /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf - #[inline(always)] - pub fn mul(&self, x: u128, y: u128) -> u128 { - let x = [lo64(x), hi64(x)]; - let y = [lo64(y), hi64(y)]; - let p = [lo64(self.p), hi64(self.p)]; - let mut zz = [0; 4]; - - // Integer multiplication - // z = x * y - - // x1,x0 - // * y1,y0 - // =========== - // z3,z2,z1,z0 - let mut result = x[0] * y[0]; - let mut carry = hi64(result); - zz[0] = lo64(result); - result = x[0] * y[1]; - let mut hi = hi64(result); - let mut lo = lo64(result); - result = lo + carry; - zz[1] = lo64(result); - let mut cc = hi64(result); - result = hi + cc; - zz[2] = lo64(result); - - result = x[1] * y[0]; - hi = hi64(result); - lo = lo64(result); - result = zz[1] + lo; - zz[1] = lo64(result); - cc = hi64(result); - result = hi + cc; - carry = lo64(result); - - result = x[1] * y[1]; - hi = hi64(result); - lo = lo64(result); - result = lo + carry; - lo = lo64(result); - cc = hi64(result); - result = hi + cc; - hi = lo64(result); - result = zz[2] + lo; - zz[2] = lo64(result); - cc = hi64(result); - result = hi + cc; - zz[3] = lo64(result); - - // Montgomery Reduction - // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. - - // z3,z2,z1,z0 - // + p1,p0 - // * w = mu*z0 - // =========== - // z3,z2,z1, 0 - let w = self.mu.wrapping_mul(zz[0] as u64); - result = p[0] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = zz[0] + lo; - zz[0] = lo64(result); - cc = hi64(result); - result = hi + cc; - carry = lo64(result); - - result = p[1] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = lo + carry; - lo = lo64(result); - cc = hi64(result); - result = hi + cc; - hi = lo64(result); - result = zz[1] + lo; - zz[1] = lo64(result); - cc = hi64(result); - result = zz[2] + hi + cc; - zz[2] = lo64(result); - cc = hi64(result); - result = zz[3] + cc; - zz[3] = lo64(result); - - // z3,z2,z1 - // + p1,p0 - // * w = mu*z1 - // =========== - // z3,z2, 0 - let w = self.mu.wrapping_mul(zz[1] as u64); - result = p[0] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = zz[1] + lo; - zz[1] = lo64(result); - cc = hi64(result); - result = hi + cc; - carry = lo64(result); - - result = p[1] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = lo + carry; - lo = lo64(result); - cc = hi64(result); - result = hi + cc; - hi = lo64(result); - result = zz[2] + lo; - zz[2] = lo64(result); - cc = hi64(result); - result = zz[3] + hi + cc; - zz[3] = lo64(result); - cc = hi64(result); - - // z = (z3,z2) - let prod = zz[2] | (zz[3] << 64); - - // Final subtraction - // If z >= p, then z = z - p - - // cc, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = prod.overflowing_sub(self.p); - let (_s1, b1) = cc.overflowing_sub(b0 as u128); - // if b1 == 1: return z - // else: return s0 - let mask = 0u128.wrapping_sub(b1 as u128); - (prod & mask) | (s0 & !mask) - } - - /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the - /// runtime of this algorithm is linear in the bit length of `exp`. - pub fn pow(&self, x: u128, exp: u128) -> u128 { - let mut t = self.montgomery(1); - for i in (0..128 - exp.leading_zeros()).rev() { - t = self.mul(t, t); - if (exp >> i) & 1 != 0 { - t = self.mul(t, x); - } - } - t - } - - /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of - /// this algorithm is linear in the bit length of `p`. - #[inline(always)] - pub fn inv(&self, x: u128) -> u128 { - self.pow(x, self.p - 2) - } - - /// Negation, i.e., `-x (mod p)` where `p` is the modulus. - #[inline(always)] - pub fn neg(&self, x: u128) -> u128 { - self.sub(0, x) - } - - /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery - /// domain in order to carry out field arithmetic. The result will be in [0, p). - /// - /// # Example usage - /// ```text - /// let integer = 1; // Standard integer representation - /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain - /// assert_eq!(elem, 2564090464); - /// ``` - #[inline(always)] - pub fn montgomery(&self, x: u128) -> u128 { - modp(self.mul(x, self.r2), self.p) - } - - /// Maps a field element to its representation as an integer. The result will be in [0, p). - /// - /// #Example usage - /// ```text - /// let elem = 2564090464; // Internal representation in the Montgomery domain - /// let integer = fp.residue(elem); // Standard integer representation - /// assert_eq!(integer, 1); - /// ``` - #[inline(always)] - pub fn residue(&self, x: u128) -> u128 { - modp(self.mul(x, 1), self.p) - } +/// FP32 implements operations over GF(p) for which the prime +/// modulus `p` fits in a u32 word. +pub(crate) struct FP32; + +impl_field_ops_single_word!(FP32, u32, u64); + +impl FieldParameters for FP32 { + const PRIME: u32 = 4293918721; + const MU: u32 = 4293918719; + const R2: u32 = 266338049; + const G: u32 = 3903828692; + const NUM_ROOTS: usize = 20; + const BIT_MASK: u32 = 4294967295; + const ROOTS: [u32; MAX_ROOTS + 1] = [ + 1048575, 4292870146, 1189722990, 3984864191, 2523259768, 2828840154, 1658715539, + 1534972560, 3732920810, 3229320047, 2836564014, 2170197442, 3760663902, 2144268387, + 3849278021, 1395394315, 574397626, 125025876, 3755041587, 2680072542, 3903828692, + ]; + #[cfg(test)] + const LOG2_BASE: usize = 32; + #[cfg(test)] + const LOG2_RADIX: usize = 32; } -#[inline(always)] -pub(crate) fn lo64(x: u128) -> u128 { - x & ((1 << 64) - 1) +/// FP64 implements operations over GF(p) for which the prime +/// modulus `p` fits in a u64 word. +pub(crate) struct FP64; + +impl_field_ops_single_word!(FP64, u64, u128); + +impl FieldParameters for FP64 { + const PRIME: u64 = 18446744069414584321; + const MU: u64 = 18446744069414584319; + const R2: u64 = 18446744065119617025; + const G: u64 = 15733474329512464024; + const NUM_ROOTS: usize = 32; + const BIT_MASK: u64 = 18446744073709551615; + const ROOTS: [u64; MAX_ROOTS + 1] = [ + 4294967295, + 18446744065119617026, + 18446744069414518785, + 18374686475393433601, + 268435456, + 18446673700670406657, + 18446744069414584193, + 576460752303421440, + 16576810576923738718, + 6647628942875889800, + 10087739294013848503, + 2135208489130820273, + 10781050935026037169, + 3878014442329970502, + 1205735313231991947, + 2523909884358325590, + 13797134855221748930, + 12267112747022536458, + 430584883067102937, + 10135969988448727187, + 6815045114074884550, + ]; + #[cfg(test)] + const LOG2_BASE: usize = 64; + #[cfg(test)] + const LOG2_RADIX: usize = 64; } -#[inline(always)] -pub(crate) fn hi64(x: u128) -> u128 { - x >> 64 -} +/// FP128 implements operations over GF(p) for which the prime +/// modulus `p` fits in a u128 word. +pub(crate) struct FP128; -#[inline(always)] -fn modp(x: u128, p: u128) -> u128 { - let (z, carry) = x.overflowing_sub(p); - let m = 0u128.wrapping_sub(carry as u128); - z.wrapping_add(m & p) -} +impl_field_ops_split_word!(FP128, u128, u64); -pub(crate) const FP32: FieldParameters = FieldParameters { - p: 4293918721, // 32-bit prime - mu: 17302828673139736575, - r2: 1676699750, - g: 1074114499, - num_roots: 20, - bit_mask: 4294967295, - roots: [ - 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825, - 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415, - 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499, - ], -}; - -pub(crate) const FP128: FieldParameters = FieldParameters { - p: 340282366920938462946865773367900766209, // 128-bit prime - mu: 18446744073709551615, - r2: 403909908237944342183153, - g: 107630958476043550189608038630704257141, - num_roots: 66, - bit_mask: 340282366920938463463374607431768211455, - roots: [ +impl FieldParameters for FP128 { + const PRIME: u128 = 340282366920938462946865773367900766209; + const MU: u128 = 18446744073709551615; + const R2: u128 = 403909908237944342183153; + const G: u128 = 107630958476043550189608038630704257141; + const NUM_ROOTS: usize = 66; + const BIT_MASK: u128 = 340282366920938463463374607431768211455; + const ROOTS: [u128; MAX_ROOTS + 1] = [ 516508834063867445247, 340282366920938462430356939304033320962, 129526470195413442198896969089616959958, @@ -322,8 +113,12 @@ pub(crate) const FP128: FieldParameters = FieldParameters { 332677126194796691532164818746739771387, 258279638927684931537542082169183965856, 148221243758794364405224645520862378432, - ], -}; + ]; + #[cfg(test)] + const LOG2_BASE: usize = 64; + #[cfg(test)] + const LOG2_RADIX: usize = 128; +} /// Compute the ceiling of the base-2 logarithm of `x`. pub(crate) fn log2(x: u128) -> u128 { @@ -333,98 +128,14 @@ pub(crate) fn log2(x: u128) -> u128 { #[cfg(test)] pub(crate) mod tests { - use super::*; + use core::{cmp::max, fmt::Debug, marker::PhantomData}; use modinverse::modinverse; use num_bigint::{BigInt, ToBigInt}; + use num_traits::AsPrimitive; use rand::{distributions::Distribution, thread_rng, Rng}; - use std::cmp::max; - - /// This trait abstracts over the details of [`FieldParameters`] and - /// [`FieldParameters64`](crate::fp64::FieldParameters64) to allow reuse of test code. - pub(crate) trait TestFieldParameters { - fn p(&self) -> u128; - fn g(&self) -> u128; - fn r2(&self) -> u128; - fn mu(&self) -> u64; - fn bit_mask(&self) -> u128; - fn num_roots(&self) -> usize; - fn roots(&self) -> Vec; - fn montgomery(&self, x: u128) -> u128; - fn residue(&self, x: u128) -> u128; - fn add(&self, x: u128, y: u128) -> u128; - fn sub(&self, x: u128, y: u128) -> u128; - fn neg(&self, x: u128) -> u128; - fn mul(&self, x: u128, y: u128) -> u128; - fn pow(&self, x: u128, exp: u128) -> u128; - fn inv(&self, x: u128) -> u128; - fn radix(&self) -> BigInt; - } - - impl TestFieldParameters for FieldParameters { - fn p(&self) -> u128 { - self.p - } - - fn g(&self) -> u128 { - self.g - } - - fn r2(&self) -> u128 { - self.r2 - } - - fn mu(&self) -> u64 { - self.mu - } - - fn bit_mask(&self) -> u128 { - self.bit_mask - } - - fn num_roots(&self) -> usize { - self.num_roots - } - - fn roots(&self) -> Vec { - self.roots.to_vec() - } - - fn montgomery(&self, x: u128) -> u128 { - FieldParameters::montgomery(self, x) - } - - fn residue(&self, x: u128) -> u128 { - FieldParameters::residue(self, x) - } - - fn add(&self, x: u128, y: u128) -> u128 { - FieldParameters::add(self, x, y) - } - - fn sub(&self, x: u128, y: u128) -> u128 { - FieldParameters::sub(self, x, y) - } - - fn neg(&self, x: u128) -> u128 { - FieldParameters::neg(self, x) - } - - fn mul(&self, x: u128, y: u128) -> u128 { - FieldParameters::mul(self, x, y) - } - - fn pow(&self, x: u128, exp: u128) -> u128 { - FieldParameters::pow(self, x, exp) - } - - fn inv(&self, x: u128) -> u128 { - FieldParameters::inv(self, x) - } - fn radix(&self) -> BigInt { - BigInt::from(1) << 128 - } - } + use super::ops::Word; + use crate::fp::{log2, FieldOps, FP128, FP32, FP64, MAX_ROOTS}; #[test] fn test_log2() { @@ -440,165 +151,213 @@ pub(crate) mod tests { assert_eq!(log2((1 << 127) + 13), 128); } - pub(crate) struct TestFieldParametersData { - /// The paramters being tested - pub fp: Box, + struct TestFieldParametersData, W: Word> { /// Expected fp.p - pub expected_p: u128, + pub expected_p: W, /// Expected fp.residue(fp.g) - pub expected_g: u128, - /// Expect fp.residue(fp.pow(fp.g, expected_order)) == 1 - pub expected_order: u128, - } + pub expected_g: W, + /// Expect fp.residue(fp.pow(fp.g, 1 << expected_log2_order)) == 1 + pub expected_log2_order: usize, - #[test] - fn test_fp32_u128() { - all_field_parameters_tests(TestFieldParametersData { - fp: Box::new(FP32), - expected_p: 4293918721, - expected_g: 3925978153, - expected_order: 1 << 20, - }); + phantom: PhantomData, } - #[test] - fn test_fp128_u128() { - all_field_parameters_tests(TestFieldParametersData { - fp: Box::new(FP128), - expected_p: 340282366920938462946865773367900766209, - expected_g: 145091266659756586618791329697897684742, - expected_order: 1 << 66, - }); - } - - pub(crate) fn all_field_parameters_tests(t: TestFieldParametersData) { - // Check that the field parameters have been constructed properly. - check_consistency(t.fp.as_ref(), t.expected_p, t.expected_g, t.expected_order); + impl TestFieldParametersData + where + T: FieldOps, + W: Word + AsPrimitive + ToBigInt + for<'a> TryFrom<&'a BigInt> + Debug, + for<'a> >::Error: Debug, + { + fn all_field_parameters_tests(&self) { + self.check_generator(); + self.check_consistency(); + self.arithmetic_test(); + } // Check that the generator has the correct order. - assert_eq!(t.fp.residue(t.fp.pow(t.fp.g(), t.expected_order)), 1); - assert_ne!(t.fp.residue(t.fp.pow(t.fp.g(), t.expected_order / 2)), 1); - - // Test arithmetic using the field parameters. - arithmetic_test(t.fp.as_ref()); - } + fn check_generator(&self) { + assert_eq!( + T::residue(T::pow(T::G, W::ONE << self.expected_log2_order)), + W::ONE + ); + assert_ne!( + T::residue(T::pow(T::G, W::ONE << (self.expected_log2_order / 2))), + W::ONE + ); + } - fn check_consistency(fp: &dyn TestFieldParameters, p: u128, g: u128, order: u128) { - assert_eq!(fp.p(), p, "p mismatch"); - - let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) { - Some(mu) => mu as u64, - None => panic!("inverse of -p (mod 2^64) is undefined"), - }; - assert_eq!(fp.mu(), mu, "mu mismatch"); - - let big_p = &p.to_bigint().unwrap(); - let big_r: &BigInt = &(fp.radix() % big_p); - let big_r2: &BigInt = &(&(big_r * big_r) % big_p); - let mut it = big_r2.iter_u64_digits(); - let mut r2 = 0; - r2 |= it.next().unwrap() as u128; - if let Some(x) = it.next() { - r2 |= (x as u128) << 64; + // Check that the field parameters have been constructed properly. + fn check_consistency(&self) { + assert_eq!(T::PRIME, self.expected_p, "p mismatch"); + + let u128_p = T::PRIME.as_(); + let base = 1i128 << T::LOG2_BASE; + let mu = match modinverse((-(u128_p as i128)).rem_euclid(base), base) { + Some(mu) => mu as u128, + None => panic!("inverse of -p (mod base) is undefined"), + }; + assert_eq!(T::MU.as_(), mu, "mu mismatch"); + + let big_p = &u128_p.to_bigint().unwrap(); + let big_radix = BigInt::from(1) << T::LOG2_RADIX; + let big_r: &BigInt = &(big_radix % big_p); + let big_r2: &BigInt = &(&(big_r * big_r) % big_p); + let mut it = big_r2.iter_u64_digits(); + let mut r2 = 0; + r2 |= it.next().unwrap() as u128; + if let Some(x) = it.next() { + r2 |= (x as u128) << 64; + } + assert_eq!(T::R2.as_(), r2, "r2 mismatch"); + + assert_eq!(T::G, T::montgomery(self.expected_g), "g mismatch"); + assert_eq!( + T::residue(T::pow(T::G, W::ONE << self.expected_log2_order)), + W::ONE, + "g order incorrect" + ); + + let num_roots = self.expected_log2_order; + assert_eq!(T::NUM_ROOTS, num_roots, "num_roots mismatch"); + + let mut roots = vec![W::ZERO; max(num_roots, MAX_ROOTS) + 1]; + roots[num_roots] = T::montgomery(self.expected_g); + for i in (0..num_roots).rev() { + roots[i] = T::mul(roots[i + 1], roots[i + 1]); + } + assert_eq!(T::ROOTS, &roots[..MAX_ROOTS + 1], "roots mismatch"); + assert_eq!(T::residue(T::ROOTS[0]), W::ONE, "first root is not one"); + + let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1); + assert_eq!( + T::BIT_MASK.to_bigint().unwrap(), + bit_mask, + "bit_mask mismatch" + ); } - assert_eq!(fp.r2(), r2, "r2 mismatch"); - assert_eq!(fp.g(), fp.montgomery(g), "g mismatch"); - assert_eq!(fp.residue(fp.pow(fp.g(), order)), 1, "g order incorrect"); + // Test arithmetic using the field parameters. + fn arithmetic_test(&self) { + let u128_p = T::PRIME.as_(); + let big_p = &u128_p.to_bigint().unwrap(); + let big_zero = &BigInt::from(0); + let uniform = rand::distributions::Uniform::from(0..u128_p); + let mut rng = thread_rng(); + + let mut weird_ints = Vec::from([ + 0, + 1, + T::BIT_MASK.as_() - u128_p, + T::BIT_MASK.as_() - u128_p + 1, + u128_p - 1, + ]); + if u128_p > u64::MAX as u128 { + weird_ints.extend_from_slice(&[ + u64::MAX as u128, + 1 << 64, + u128_p & u64::MAX as u128, + u128_p & !u64::MAX as u128, + u128_p & !u64::MAX as u128 | 1, + ]); + } - let num_roots = log2(order) as usize; - assert_eq!(order, 1 << num_roots, "order not a power of 2"); - assert_eq!(fp.num_roots(), num_roots, "num_roots mismatch"); + let mut generate_random = || -> (W, BigInt) { + // Add bias to random element generation, to explore "interesting" inputs. + let intu128 = if rng.gen_ratio(1, 4) { + weird_ints[rng.gen_range(0..weird_ints.len())] + } else { + uniform.sample(&mut rng) + }; + let bigint = intu128.to_bigint().unwrap(); + let int = W::try_from(&bigint).unwrap(); + let montgomery_domain = T::montgomery(int); + (montgomery_domain, bigint) + }; - let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1]; - roots[num_roots] = fp.montgomery(g); - for i in (0..num_roots).rev() { - roots[i] = fp.mul(roots[i + 1], roots[i + 1]); + for _ in 0..1000 { + let (x, ref big_x) = generate_random(); + let (y, ref big_y) = generate_random(); + + // Test addition. + let got = T::add(x, y); + let want = (big_x + big_y) % big_p; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + + // Test subtraction. + let got = T::sub(x, y); + let want = if big_x >= big_y { + big_x - big_y + } else { + big_p - big_y + big_x + }; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + + // Test multiplication. + let got = T::mul(x, y); + let want = (big_x * big_y) % big_p; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + + // Test inversion. + let got = T::inv(x); + let want = big_x.modpow(&(big_p - 2), big_p); + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + if big_x == big_zero { + assert_eq!(T::residue(T::mul(got, x)), W::ZERO); + } else { + assert_eq!(T::residue(T::mul(got, x)), W::ONE); + } + + // Test negation. + let got = T::neg(x); + let want = (big_p - big_x) % big_p; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + assert_eq!(T::residue(T::add(got, x)), W::ZERO); + } } - assert_eq!(fp.roots(), &roots[..MAX_ROOTS + 1], "roots mismatch"); - assert_eq!(fp.residue(fp.roots()[0]), 1, "first root is not one"); - - let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1); - assert_eq!( - fp.bit_mask().to_bigint().unwrap(), - bit_mask, - "bit_mask mismatch" - ); } - fn arithmetic_test(fp: &dyn TestFieldParameters) { - let big_p = &fp.p().to_bigint().unwrap(); - let big_zero = &BigInt::from(0); - let uniform = rand::distributions::Uniform::from(0..fp.p()); - let mut rng = thread_rng(); - - let mut weird_ints = Vec::from([ - 0, - 1, - fp.bit_mask() - fp.p(), - fp.bit_mask() - fp.p() + 1, - fp.p() - 1, - ]); - if fp.p() > u64::MAX as u128 { - weird_ints.extend_from_slice(&[ - u64::MAX as u128, - 1 << 64, - fp.p() & u64::MAX as u128, - fp.p() & !u64::MAX as u128, - fp.p() & !u64::MAX as u128 | 1, - ]); + mod fp32 { + #[test] + fn check_field_parameters() { + use super::*; + + TestFieldParametersData:: { + expected_p: 4293918721, + expected_g: 3925978153, + expected_log2_order: 20, + phantom: PhantomData, + } + .all_field_parameters_tests(); } + } - let mut generate_random = || -> (u128, BigInt) { - // Add bias to random element generation, to explore "interesting" inputs. - let int = if rng.gen_ratio(1, 4) { - weird_ints[rng.gen_range(0..weird_ints.len())] - } else { - uniform.sample(&mut rng) - }; - let bigint = int.to_bigint().unwrap(); - let montgomery_domain = fp.montgomery(int); - (montgomery_domain, bigint) - }; - - for _ in 0..1000 { - let (x, ref big_x) = generate_random(); - let (y, ref big_y) = generate_random(); - - // Test addition. - let got = fp.add(x, y); - let want = (big_x + big_y) % big_p; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - - // Test subtraction. - let got = fp.sub(x, y); - let want = if big_x >= big_y { - big_x - big_y - } else { - big_p - big_y + big_x - }; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - - // Test multiplication. - let got = fp.mul(x, y); - let want = (big_x * big_y) % big_p; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - - // Test inversion. - let got = fp.inv(x); - let want = big_x.modpow(&(big_p - 2u128), big_p); - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - if big_x == big_zero { - assert_eq!(fp.residue(fp.mul(got, x)), 0); - } else { - assert_eq!(fp.residue(fp.mul(got, x)), 1); + mod fp64 { + #[test] + fn check_field_parameters() { + use super::*; + + TestFieldParametersData:: { + expected_p: 18446744069414584321, + expected_g: 1753635133440165772, + expected_log2_order: 32, + phantom: PhantomData, } + .all_field_parameters_tests(); + } + } - // Test negation. - let got = fp.neg(x); - let want = (big_p - big_x) % big_p; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - assert_eq!(fp.residue(fp.add(got, x)), 0); + mod fp128 { + #[test] + fn check_field_parameters() { + use super::*; + + TestFieldParametersData:: { + expected_p: 340282366920938462946865773367900766209, + expected_g: 145091266659756586618791329697897684742, + expected_log2_order: 66, + phantom: PhantomData, + } + .all_field_parameters_tests(); } } } diff --git a/src/fp/ops.rs b/src/fp/ops.rs new file mode 100644 index 000000000..87aedc7ee --- /dev/null +++ b/src/fp/ops.rs @@ -0,0 +1,429 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Prime field arithmetic for any Galois field GF(p) for which `p < 2^W`, +//! where `W ≤ 128` is a specified word size. + +use num_traits::{ + ops::overflowing::{OverflowingAdd, OverflowingSub}, + AsPrimitive, ConstOne, ConstZero, PrimInt, Unsigned, WrappingAdd, WrappingMul, WrappingSub, +}; + +use crate::fp::MAX_ROOTS; + +/// `Word` is the datatype used for storing and operating with +/// field elements. +/// +/// The types `u32`, `u64`, and `u128` implement this trait. +pub trait Word: + 'static + + Unsigned + + PrimInt + + OverflowingAdd + + OverflowingSub + + WrappingAdd + + WrappingSub + + WrappingMul + + ConstZero + + ConstOne + + From +{ + /// Number of bits of word size. + const BITS: usize; +} + +impl Word for u32 { + const BITS: usize = Self::BITS as usize; +} + +impl Word for u64 { + const BITS: usize = Self::BITS as usize; +} + +impl Word for u128 { + const BITS: usize = Self::BITS as usize; +} + +/// `FieldParameters` sets the parameters to implement a prime field +/// GF(p) for which the prime modulus `p` fits in one word of +/// `W ≤ 128` bits. +pub trait FieldParameters { + /// The prime modulus `p`. + const PRIME: W; + /// `mu = -p^(-1) mod 2^LOG2_BASE`. + const MU: W; + /// `r2 = R^2 mod p`. + const R2: W; + /// The `2^num_roots`-th -principal root of unity. This element + /// is used to generate the elements of `ROOTS`. + const G: W; + /// The number of principal roots of unity in `ROOTS`. + const NUM_ROOTS: usize; + /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. + const BIT_MASK: W; + /// `ROOTS[l]` is the `2^l`-th principal root of unity, i.e., + /// `ROOTS[l]` has order `2^l` in the multiplicative group. + /// `ROOTS[0]` is equal to one by definition. + const ROOTS: [W; MAX_ROOTS + 1]; + /// The log2(base) for the base used for multiprecision arithmetic. + /// So, `LOG2_BASE ≤ 64` as processors have at most a 64-bit + /// integer multiplier. + #[cfg(test)] + const LOG2_BASE: usize; + /// The log2(R) where R is the machine word-friendly modulus + /// used in the Montgomery representation. + #[cfg(test)] + const LOG2_RADIX: usize; +} + +/// `FieldOps` provides arithmetic operations over GF(p). +/// +/// The multiplication method is required as it admits different +/// implementations. +pub trait FieldOps: FieldParameters { + /// Addition. The result will be in [0, p), so long as both x + /// and y are as well. + #[inline(always)] + fn add(x: W, y: W) -> W { + // 0,x + // + 0,y + // ===== + // c,z + let (z, carry) = x.overflowing_add(&y); + + // c, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = z.overflowing_sub(&Self::PRIME); + let (_s1, b1) = + >::from(carry).overflowing_sub(&>::from(b0)); + // if b1 == 1: return z + // else: return s0 + let mask = W::ZERO.wrapping_sub(&>::from(b1)); + (z & mask) | (s0 & !mask) + } + + /// Subtraction. The result will be in [0, p), so long as both x + /// and y are as well. + #[inline(always)] + fn sub(x: W, y: W) -> W { + // x + // - y + // ======== + // b0,z0 + let (z0, b0) = x.overflowing_sub(&y); + let mask = W::ZERO.wrapping_sub(&>::from(b0)); + // z0 + // + p + // ======== + // s1,s0 + z0.wrapping_add(&(mask & Self::PRIME)) + // z0 + (m & self.p) + // if b1 == 1: return s0 + // else: return z0 + } + + /// Negation, i.e., `-x (mod p)` where `p` is the modulus. + #[inline(always)] + fn neg(x: W) -> W { + Self::sub(W::ZERO, x) + } + + #[inline(always)] + fn modp(x: W) -> W { + Self::sub(x, Self::PRIME) + } + + /// Multiplication. The result will be in [0, p), so long as both x + /// and y are as well. + fn mul(x: W, y: W) -> W; + + /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is + /// the modulus. Note that the runtime of this algorithm is + /// linear in the bit length of `exp`. + fn pow(x: W, exp: W) -> W { + let mut t = Self::ROOTS[0]; + for i in (0..W::BITS - (exp.leading_zeros() as usize)).rev() { + t = Self::mul(t, t); + if (exp >> i) & W::ONE != W::ZERO { + t = Self::mul(t, x); + } + } + t + } + + /// Modular inversion, i.e., x^-1 (mod p) where `p` is the + /// modulus. Note that the runtime of this algorithm is + /// linear in the bit length of `p`. + #[inline(always)] + fn inv(x: W) -> W { + Self::pow(x, Self::PRIME - W::ONE - W::ONE) + } + + /// Maps an integer to its internal representation. Field + /// elements are mapped to the Montgomery domain in order to + /// carry out field arithmetic. The result will be in [0, p). + #[inline(always)] + fn montgomery(x: W) -> W { + Self::modp(Self::mul(x, Self::R2)) + } + + /// Maps a field element to its representation as an integer. + /// The result will be in [0, p). + #[inline(always)] + fn residue(x: W) -> W { + Self::modp(Self::mul(x, W::ONE)) + } +} + +/// `FieldMulOpsSingleWord` implements prime field multiplication. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, and that the product of two integers fits in a datatype +/// (`Self::DoubleWord`) of exactly `2*W` bits. +pub(crate) trait FieldMulOpsSingleWord: FieldParameters +where + W: Word + AsPrimitive, +{ + type DoubleWord: Word + AsPrimitive; + + /// Multiplication of field elements in the Montgomery domain. + /// This uses the Montgomery's [REDC algorithm][montgomery]. + /// The result will be in [0, p). + /// + /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf + fn mul(x: W, y: W) -> W { + let hi_lo = |v: Self::DoubleWord| -> (W, W) { ((v >> W::BITS).as_(), v.as_()) }; + + // Integer multiplication + // z = x * y + + // x + // * y + // ===== + // z1,z0 + let (z1, z0) = hi_lo(x.as_() * y.as_()); + + // Montgomery Reduction + // z = z + p * mu*(z mod 2^W), where mu = (-p)^(-1) mod 2^W. + // = z + p * w, where w = mu*z0 + let w = Self::MU.wrapping_mul(&z0); + let (r1, r0) = hi_lo(Self::PRIME.as_() * w.as_()); + + // z1,z0 + // + r1,r0 + // ===== + // cc, z, 0 + let (_zero, carry) = z0.overflowing_add(&r0); + let (cc, z) = hi_lo(z1.as_() + r1.as_() + >::from(carry)); + + // Final subtraction + // If z >= p, then z = z - p + + // cc, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = z.overflowing_sub(&Self::PRIME); + let (_s1, b1) = cc.overflowing_sub(&>::from(b0)); + // if b1 == 1: return z + // else: return s0 + let mask = W::ZERO.wrapping_sub(&>::from(b1)); + (z & mask) | (s0 & !mask) + } +} + +/// `FieldMulOpsSplitWord` implements prime field multiplication. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, but the product of two integers does not fit in any primitive +/// integer. Thus, multiplication is processed splitting integers in two +/// words. +pub(crate) trait FieldMulOpsSplitWord: FieldParameters +where + W: Word + AsPrimitive, +{ + type HalfWord: Word + AsPrimitive; + const MU: Self::HalfWord; + /// Multiplication of field elements in the Montgomery domain. + /// This uses the Montgomery's [REDC algorithm][montgomery]. + /// The result will be in [0, p). + /// + /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf + fn mul(x: W, y: W) -> W { + let high = |v: W| v >> (W::BITS / 2); + let low = |v: W| v & ((W::ONE << (W::BITS / 2)) - W::ONE); + + let (x1, x0) = (high(x), low(x)); + let (y1, y0) = (high(y), low(y)); + + // Integer multiplication + // z = x * y + + // x1,x0 + // * y1,y0 + // =========== + // z3,z2,z1,z0 + let mut result = x0 * y0; + let mut carry = high(result); + let z0 = low(result); + result = x0 * y1; + let mut hi = high(result); + let mut lo = low(result); + result = lo + carry; + let mut z1 = low(result); + let mut cc = high(result); + result = hi + cc; + let mut z2 = low(result); + + result = x1 * y0; + hi = high(result); + lo = low(result); + result = z1 + lo; + z1 = low(result); + cc = high(result); + result = hi + cc; + carry = low(result); + + result = x1 * y1; + hi = high(result); + lo = low(result); + result = lo + carry; + lo = low(result); + cc = high(result); + result = hi + cc; + hi = low(result); + result = z2 + lo; + z2 = low(result); + cc = high(result); + result = hi + cc; + let mut z3 = low(result); + + // Montgomery Reduction + // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. + + // z3,z2,z1,z0 + // + p1,p0 + // * w = mu*z0 + // =========== + // z3,z2,z1, 0 + let mut w = >::MU.wrapping_mul(&z0.as_()); + let p0 = low(Self::PRIME); + result = p0 * w.as_(); + hi = high(result); + lo = low(result); + result = z0 + lo; + cc = high(result); + result = hi + cc; + carry = low(result); + + let p1 = high(Self::PRIME); + result = p1 * w.as_(); + hi = high(result); + lo = low(result); + result = lo + carry; + lo = low(result); + cc = high(result); + result = hi + cc; + hi = low(result); + result = z1 + lo; + z1 = low(result); + cc = high(result); + result = z2 + hi + cc; + z2 = low(result); + cc = high(result); + result = z3 + cc; + z3 = low(result); + + // z3,z2,z1 + // + p1,p0 + // * w = mu*z1 + // =========== + // z3,z2, 0 + w = >::MU.wrapping_mul(&z1.as_()); + result = p0 * w.as_(); + hi = high(result); + lo = low(result); + result = z1 + lo; + cc = high(result); + result = hi + cc; + carry = low(result); + + result = p1 * w.as_(); + hi = high(result); + lo = low(result); + result = lo + carry; + lo = low(result); + cc = high(result); + result = hi + cc; + hi = low(result); + result = z2 + lo; + z2 = low(result); + cc = high(result); + result = z3 + hi + cc; + z3 = low(result); + cc = high(result); + + // z = (z3,z2) + let prod = z2 | (z3 << (W::BITS / 2)); + + // Final subtraction + // If z >= p, then z = z - p + + // cc, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = prod.overflowing_sub(&Self::PRIME); + let (_s1, b1) = cc.overflowing_sub(&>::from(b0)); + // if b1 == 1: return z + // else: return s0 + let mask = W::ZERO.wrapping_sub(&>::from(b1)); + (prod & mask) | (s0 & !mask) + } +} + +/// `impl_field_ops_single_word` helper to implement prime field operations. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, and that the product of two integers fits in a datatype +/// (`Self::DoubleWord`) of exactly `2*W` bits. +macro_rules! impl_field_ops_single_word { + ($struct_name:ident, $W:ty, $W2:ty) => { + const _: () = assert!(<$W2>::BITS == 2 * <$W>::BITS); + impl $crate::fp::ops::FieldMulOpsSingleWord<$W> for $struct_name { + type DoubleWord = $W2; + } + impl $crate::fp::ops::FieldOps<$W> for $struct_name { + #[inline(always)] + fn mul(x: $W, y: $W) -> $W { + >::mul(x, y) + } + } + }; +} + +/// `impl_field_ops_split_word` helper to implement prime field operations. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, but the product of two integers does not fit. Thus, +/// multiplication is processed splitting integers in two words. +macro_rules! impl_field_ops_split_word { + ($struct_name:ident, $W:ty, $W2:ty) => { + const _: () = assert!(2 * <$W2>::BITS == <$W>::BITS); + impl $crate::fp::ops::FieldMulOpsSplitWord<$W> for $struct_name { + type HalfWord = $W2; + const MU: Self::HalfWord = { + let mu = <$struct_name as FieldParameters<$W>>::MU; + assert!(mu <= (<$W2>::MAX as $W)); + mu as $W2 + }; + } + impl $crate::fp::ops::FieldOps<$W> for $struct_name { + #[inline(always)] + fn mul(x: $W, y: $W) -> $W { + >::mul(x, y) + } + } + }; +} diff --git a/src/fp64.rs b/src/fp64.rs deleted file mode 100644 index e2f2cf2b0..000000000 --- a/src/fp64.rs +++ /dev/null @@ -1,307 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -//! Finite field arithmetic for any field GF(p) for which p < 2^64. - -use crate::fp::{hi64, lo64, MAX_ROOTS}; - -/// This structure represents the parameters of a finite field GF(p) for which p < 2^64. -/// -/// See also [`FieldParameters`](crate::fp::FieldParameters). -#[derive(Debug, PartialEq, Eq)] -pub(crate) struct FieldParameters64 { - /// The prime modulus `p`. - pub p: u64, - /// `mu = -p^(-1) mod 2^64`. - pub mu: u64, - /// `r2 = (2^64)^2 mod p`. - pub r2: u64, - /// The `2^num_roots`-th -principal root of unity. This element is used to generate the - /// elements of `roots`. - pub g: u64, - /// The number of principal roots of unity in `roots`. - pub num_roots: usize, - /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. - pub bit_mask: u64, - /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the - /// multiplicative group. `roots[0]` is equal to one by definition. - pub roots: [u64; MAX_ROOTS + 1], -} - -impl FieldParameters64 { - /// Addition. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn add(&self, x: u64, y: u64) -> u64 { - // 0,x - // + 0,y - // ===== - // c,z - let (z, carry) = x.overflowing_add(y); - // c, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = z.overflowing_sub(self.p); - let (_s1, b1) = (carry as u64).overflowing_sub(b0 as u64); - // if b1 == 1: return z - // else: return s0 - let m = 0u64.wrapping_sub(b1 as u64); - (z & m) | (s0 & !m) - } - - /// Subtraction. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn sub(&self, x: u64, y: u64) -> u64 { - // x - // - y - // ======== - // b0,z0 - let (z0, b0) = x.overflowing_sub(y); - let m = 0u64.wrapping_sub(b0 as u64); - // z0 - // + p - // ======== - // s1,s0 - z0.wrapping_add(m & self.p) - // if b1 == 1: return s0 - // else: return z0 - } - - /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm - /// described [here][montgomery]. The result will be in [0, p). - /// - /// # Example usage - /// ```text - /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46); - /// ``` - /// - /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf - #[inline(always)] - pub fn mul(&self, x: u64, y: u64) -> u64 { - let mut zz = [0; 2]; - - // Integer multiplication - // z = x * y - - // x - // * y - // ===== - // z1,z0 - let result = (x as u128) * (y as u128); - zz[0] = lo64(result) as u64; - zz[1] = hi64(result) as u64; - - // Montgomery Reduction - // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. - - // z1,z0 - // + p - // * w = mu*z0 - // ===== - // z1, 0 - let w = self.mu.wrapping_mul(zz[0]); - let result = (self.p as u128) * (w as u128); - let hi = hi64(result); - let lo = lo64(result) as u64; - let (result, carry) = zz[0].overflowing_add(lo); - zz[0] = result; - let result = zz[1] as u128 + hi + carry as u128; - zz[1] = lo64(result) as u64; - let cc = hi64(result) as u64; - - // z = (z1) - let prod = zz[1]; - - // Final subtraction - // If z >= p, then z = z - p - - // cc, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = prod.overflowing_sub(self.p); - let (_s1, b1) = cc.overflowing_sub(b0 as u64); - // if b1 == 1: return z - // else: return s0 - let mask = 0u64.wrapping_sub(b1 as u64); - (prod & mask) | (s0 & !mask) - } - - /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the - /// runtime of this algorithm is linear in the bit length of `exp`. - pub fn pow(&self, x: u64, exp: u64) -> u64 { - let mut t = self.montgomery(1); - for i in (0..64 - exp.leading_zeros()).rev() { - t = self.mul(t, t); - if (exp >> i) & 1 != 0 { - t = self.mul(t, x); - } - } - t - } - - /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of - /// this algorithm is linear in the bit length of `p`. - #[inline(always)] - pub fn inv(&self, x: u64) -> u64 { - self.pow(x, self.p - 2) - } - - /// Negation, i.e., `-x (mod p)` where `p` is the modulus. - #[inline(always)] - pub fn neg(&self, x: u64) -> u64 { - self.sub(0, x) - } - - /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery - /// domain in order to carry out field arithmetic. The result will be in [0, p). - /// - /// # Example usage - /// ```text - /// let integer = 1; // Standard integer representation - /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain - /// assert_eq!(elem, 2564090464); - /// ``` - #[inline(always)] - pub fn montgomery(&self, x: u64) -> u64 { - modp(self.mul(x, self.r2), self.p) - } - - /// Maps a field element to its representation as an integer. The result will be in [0, p). - /// - /// #Example usage - /// ```text - /// let elem = 2564090464; // Internal representation in the Montgomery domain - /// let integer = fp.residue(elem); // Standard integer representation - /// assert_eq!(integer, 1); - /// ``` - #[inline(always)] - pub fn residue(&self, x: u64) -> u64 { - modp(self.mul(x, 1), self.p) - } -} - -#[inline(always)] -fn modp(x: u64, p: u64) -> u64 { - let (z, carry) = x.overflowing_sub(p); - let m = 0u64.wrapping_sub(carry as u64); - z.wrapping_add(m & p) -} - -pub(crate) const FP64: FieldParameters64 = FieldParameters64 { - p: 18446744069414584321, // 64-bit prime - mu: 18446744069414584319, - r2: 18446744065119617025, - g: 15733474329512464024, - num_roots: 32, - bit_mask: 18446744073709551615, - roots: [ - 4294967295, - 18446744065119617026, - 18446744069414518785, - 18374686475393433601, - 268435456, - 18446673700670406657, - 18446744069414584193, - 576460752303421440, - 16576810576923738718, - 6647628942875889800, - 10087739294013848503, - 2135208489130820273, - 10781050935026037169, - 3878014442329970502, - 1205735313231991947, - 2523909884358325590, - 13797134855221748930, - 12267112747022536458, - 430584883067102937, - 10135969988448727187, - 6815045114074884550, - ], -}; - -#[cfg(test)] -mod tests { - use num_bigint::BigInt; - - use crate::fp::tests::{ - all_field_parameters_tests, TestFieldParameters, TestFieldParametersData, - }; - - use super::*; - - impl TestFieldParameters for FieldParameters64 { - fn p(&self) -> u128 { - self.p.into() - } - - fn g(&self) -> u128 { - self.g as u128 - } - - fn r2(&self) -> u128 { - self.r2 as u128 - } - - fn mu(&self) -> u64 { - self.mu - } - - fn bit_mask(&self) -> u128 { - self.bit_mask as u128 - } - - fn num_roots(&self) -> usize { - self.num_roots - } - - fn roots(&self) -> Vec { - self.roots.iter().map(|x| *x as u128).collect() - } - - fn montgomery(&self, x: u128) -> u128 { - FieldParameters64::montgomery(self, x.try_into().unwrap()).into() - } - - fn residue(&self, x: u128) -> u128 { - FieldParameters64::residue(self, x.try_into().unwrap()).into() - } - - fn add(&self, x: u128, y: u128) -> u128 { - FieldParameters64::add(self, x.try_into().unwrap(), y.try_into().unwrap()).into() - } - - fn sub(&self, x: u128, y: u128) -> u128 { - FieldParameters64::sub(self, x.try_into().unwrap(), y.try_into().unwrap()).into() - } - - fn neg(&self, x: u128) -> u128 { - FieldParameters64::neg(self, x.try_into().unwrap()).into() - } - - fn mul(&self, x: u128, y: u128) -> u128 { - FieldParameters64::mul(self, x.try_into().unwrap(), y.try_into().unwrap()).into() - } - - fn pow(&self, x: u128, exp: u128) -> u128 { - FieldParameters64::pow(self, x.try_into().unwrap(), exp.try_into().unwrap()).into() - } - - fn inv(&self, x: u128) -> u128 { - FieldParameters64::inv(self, x.try_into().unwrap()).into() - } - - fn radix(&self) -> BigInt { - BigInt::from(1) << 64 - } - } - - #[test] - fn test_fp64_u64() { - all_field_parameters_tests(TestFieldParametersData { - fp: Box::new(FP64), - expected_p: 18446744069414584321, - expected_g: 1753635133440165772, - expected_order: 1 << 32, - }) - } -} diff --git a/src/lib.rs b/src/lib.rs index e5280d505..d306926ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,6 @@ mod fft; pub mod field; pub mod flp; mod fp; -mod fp64; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[cfg_attr( docsrs, From 14ba16d87f9c138794710a715c7846c8f3ed400b Mon Sep 17 00:00:00 2001 From: divviup-github-automation Date: Tue, 27 Aug 2024 16:24:43 +0000 Subject: [PATCH 04/32] Bump libprio-rs patch version, triggered by @divergentdave --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ebdef81f..1a188fd9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -732,7 +732,7 @@ checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" [[package]] name = "prio" -version = "0.16.6" +version = "0.16.7" dependencies = [ "aes", "assert_matches", diff --git a/Cargo.toml b/Cargo.toml index 3c288beeb..02f47baa0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "prio" -version = "0.16.6" +version = "0.16.7" authors = ["Josh Aas ", "Tim Geoghegan ", "Christopher Patton "] edition = "2021" exclude = ["/supply-chain"] From af81bb7c6a2e34dc4b0827e32ea4c76487f9d2ff Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 22:05:04 +0000 Subject: [PATCH 05/32] build(deps): Bump serde from 1.0.208 to 1.0.209 (#1112) --- Cargo.lock | 8 ++++---- supply-chain/imports.lock | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1a188fd9e..f759c4daa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -957,18 +957,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.208" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff085d2cb684faa248efb494c39b68e522822ac0de72ccf08109abde717cfb2" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.208" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24008e81ff7613ed8e5ba0cfaf24e2c2f1e5b8a0495711e44fcd4882fca62bcf" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index d87b6ece9..d53d4fbee 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -114,15 +114,15 @@ user-login = "Amanieu" user-name = "Amanieu d'Antras" [[publisher.serde]] -version = "1.0.208" -when = "2024-08-15" +version = "1.0.209" +when = "2024-08-24" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_derive]] -version = "1.0.208" -when = "2024-08-15" +version = "1.0.209" +when = "2024-08-24" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From db5b4827ffad167410ea2fd4e2ae0c7a725b33b5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 22:05:29 +0000 Subject: [PATCH 06/32] build(deps): Bump serde_json from 1.0.125 to 1.0.127 (#1113) --- Cargo.lock | 4 ++-- supply-chain/imports.lock | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f759c4daa..a79c070a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -977,9 +977,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.125" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "itoa", "memchr", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index d53d4fbee..47c997a9c 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -128,8 +128,8 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_json]] -version = "1.0.125" -when = "2024-08-15" +version = "1.0.127" +when = "2024-08-23" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From df9104347a2564a746a79ef93340397aa397de46 Mon Sep 17 00:00:00 2001 From: David Cook Date: Thu, 5 Sep 2024 14:41:56 -0500 Subject: [PATCH 07/32] Remove ignored ?Sized bound (#1114) --- src/codec.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codec.rs b/src/codec.rs index 98e6299ab..3b4086ff6 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -90,7 +90,7 @@ pub trait ParameterizedDecode

: Sized { /// Provide a blanket implementation so that any [`Decode`] can be used as a /// `ParameterizedDecode` for any `T`. -impl ParameterizedDecode for D { +impl ParameterizedDecode for D { fn decode_with_param( _decoding_parameter: &T, bytes: &mut Cursor<&[u8]>, From 7cdb688931154b0a540ed4982dde08816aa5f99d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:47:19 +0000 Subject: [PATCH 08/32] build(deps): Bump serde_json from 1.0.127 to 1.0.128 (#1116) --- Cargo.lock | 4 ++-- supply-chain/imports.lock | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a79c070a6..277b74481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -977,9 +977,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.127" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", "memchr", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 47c997a9c..e74a50f5c 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -128,8 +128,8 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_json]] -version = "1.0.127" -when = "2024-08-23" +version = "1.0.128" +when = "2024-09-04" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From 47e2ac583ecd1f38430047e006bfd59417c8d065 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:48:13 +0000 Subject: [PATCH 09/32] build(deps): Bump serde from 1.0.209 to 1.0.210 (#1115) --- Cargo.lock | 8 ++++---- supply-chain/imports.lock | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 277b74481..9221e9ebb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -957,18 +957,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index e74a50f5c..3106d8b1a 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -114,15 +114,15 @@ user-login = "Amanieu" user-name = "Amanieu d'Antras" [[publisher.serde]] -version = "1.0.209" -when = "2024-08-24" +version = "1.0.210" +when = "2024-09-06" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_derive]] -version = "1.0.209" -when = "2024-08-24" +version = "1.0.210" +when = "2024-09-06" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From a85d271ddee087f13dfd847a7170786f35abd0b9 Mon Sep 17 00:00:00 2001 From: David Cook Date: Tue, 10 Sep 2024 11:41:11 -0500 Subject: [PATCH 10/32] Replace uses of fixed-macro (#1117) * Replace uses of fixed-macro * Prune exemptions --- Cargo.lock | 61 -------------------------- Cargo.toml | 1 - benches/speed_tests.rs | 39 ++++++++-------- binaries/Cargo.toml | 1 - binaries/src/bin/vdaf_message_sizes.rs | 10 +++-- src/flp/types/fixedpoint_l2.rs | 30 ++++++------- src/vdaf/prio3.rs | 44 ++++++++++--------- supply-chain/config.toml | 8 ---- 8 files changed, 64 insertions(+), 130 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9221e9ebb..a0ce3e561 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -345,41 +345,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "fixed-macro" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0c48af8cb14e02868f449f8a2187bd78af7a08da201fdc78d518ecb1675bc" -dependencies = [ - "fixed", - "fixed-macro-impl", - "fixed-macro-types", -] - -[[package]] -name = "fixed-macro-impl" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c93086f471c0a1b9c5e300ea92f5cd990ac6d3f8edf27616ef624b8fa6402d4b" -dependencies = [ - "fixed", - "paste", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.104", -] - -[[package]] -name = "fixed-macro-types" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "044a61b034a2264a7f65aa0c3cd112a01b4d4ee58baace51fead3f21b993c7e4" -dependencies = [ - "fixed", - "fixed-macro-impl", -] - [[package]] name = "funty" version = "2.0.0" @@ -744,7 +709,6 @@ dependencies = [ "ctr", "fiat-crypto", "fixed", - "fixed-macro", "getrandom", "hex", "hex-literal", @@ -777,35 +741,10 @@ version = "0.5.0" dependencies = [ "base64", "fixed", - "fixed-macro", "prio", "rand", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.104", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" version = "1.0.74" diff --git a/Cargo.toml b/Cargo.toml index 02f47baa0..c0d119ebb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,6 @@ assert_matches = "1.5.0" base64 = "0.22.1" cfg-if = "1.0.0" criterion = "0.5" -fixed-macro = "1.2.0" hex-literal = "0.4.1" iai = "0.1" modinverse = "0.1.0" diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index bf7a66e9a..4718b182a 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -6,8 +6,6 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri #[cfg(feature = "experimental")] use fixed::types::{I1F15, I1F31}; #[cfg(feature = "experimental")] -use fixed_macro::fixed; -#[cfg(feature = "experimental")] use num_bigint::BigUint; #[cfg(feature = "experimental")] use num_rational::Ratio; @@ -434,6 +432,11 @@ fn prio3(c: &mut Criterion) { #[cfg(feature = "experimental")] { + const FP16_ZERO: I1F15 = I1F15::lit("0"); + const FP32_ZERO: I1F31 = I1F31::lit("0"); + const FP16_HALF: I1F15 = I1F15::lit("0.5"); + const FP32_HALF: I1F31 = I1F31::lit("0.5"); + let mut group = c.benchmark_group("prio3fixedpointboundedl2vecsum_i1f15_shard"); for dimension in [10, 100, 1_000] { group.bench_with_input( @@ -442,8 +445,8 @@ fn prio3(c: &mut Criterion) { |b, dimension| { let vdaf: Prio3, _, 16> = Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); - let mut measurement = vec![fixed!(0: I1F15); *dimension]; - measurement[0] = fixed!(0.5: I1F15); + let mut measurement = vec![FP16_ZERO; *dimension]; + measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); }, @@ -462,8 +465,8 @@ fn prio3(c: &mut Criterion) { num_shares, *dimension, ) .unwrap(); - let mut measurement = vec![fixed!(0: I1F15); *dimension]; - measurement[0] = fixed!(0.5: I1F15); + let mut measurement = vec![FP16_ZERO; *dimension]; + measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); }, @@ -480,8 +483,8 @@ fn prio3(c: &mut Criterion) { |b, dimension| { let vdaf: Prio3, _, 16> = Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); - let mut measurement = vec![fixed!(0: I1F15); *dimension]; - measurement[0] = fixed!(0.5: I1F15); + let mut measurement = vec![FP16_ZERO; *dimension]; + measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); @@ -512,8 +515,8 @@ fn prio3(c: &mut Criterion) { num_shares, *dimension, ) .unwrap(); - let mut measurement = vec![fixed!(0: I1F15); *dimension]; - measurement[0] = fixed!(0.5: I1F15); + let mut measurement = vec![FP16_ZERO; *dimension]; + measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = @@ -543,8 +546,8 @@ fn prio3(c: &mut Criterion) { |b, dimension| { let vdaf: Prio3, _, 16> = Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); - let mut measurement = vec![fixed!(0: I1F31); *dimension]; - measurement[0] = fixed!(0.5: I1F31); + let mut measurement = vec![FP32_ZERO; *dimension]; + measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); }, @@ -563,8 +566,8 @@ fn prio3(c: &mut Criterion) { num_shares, *dimension, ) .unwrap(); - let mut measurement = vec![fixed!(0: I1F31); *dimension]; - measurement[0] = fixed!(0.5: I1F31); + let mut measurement = vec![FP32_ZERO; *dimension]; + measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); }, @@ -581,8 +584,8 @@ fn prio3(c: &mut Criterion) { |b, dimension| { let vdaf: Prio3, _, 16> = Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); - let mut measurement = vec![fixed!(0: I1F31); *dimension]; - measurement[0] = fixed!(0.5: I1F31); + let mut measurement = vec![FP32_ZERO; *dimension]; + measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); @@ -613,8 +616,8 @@ fn prio3(c: &mut Criterion) { num_shares, *dimension, ) .unwrap(); - let mut measurement = vec![fixed!(0: I1F31); *dimension]; - measurement[0] = fixed!(0.5: I1F31); + let mut measurement = vec![FP32_ZERO; *dimension]; + measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = diff --git a/binaries/Cargo.toml b/binaries/Cargo.toml index 81fd8d3a4..cb57338c9 100644 --- a/binaries/Cargo.toml +++ b/binaries/Cargo.toml @@ -9,6 +9,5 @@ repository = "https://github.com/divviup/libprio-rs" [dependencies] base64 = "0.22.1" fixed = "1.27" -fixed-macro = "1.2.0" rand = "0.8" prio = { path = "..", features = ["experimental", "test-util"] } diff --git a/binaries/src/bin/vdaf_message_sizes.rs b/binaries/src/bin/vdaf_message_sizes.rs index a5c0035b5..dae75685a 100644 --- a/binaries/src/bin/vdaf_message_sizes.rs +++ b/binaries/src/bin/vdaf_message_sizes.rs @@ -1,5 +1,7 @@ -use fixed::{types::extra::U15, FixedI16}; -use fixed_macro::fixed; +use fixed::{ + types::{extra::U15, I1F15}, + FixedI16, +}; use prio::{ codec::Encode, @@ -53,8 +55,8 @@ fn main() { let len = 1000; let prio3 = Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, len).unwrap(); - let fp_num = fixed!(0.0001: I1F15); - let measurement = vec![fp_num; len]; + const FP_NUM: I1F15 = I1F15::lit("0.0001"); + let measurement = vec![FP_NUM; len]; println!( "prio3 fixedpoint16 boundedl2 vec ({} entries) size = {}", len, diff --git a/src/flp/types/fixedpoint_l2.rs b/src/flp/types/fixedpoint_l2.rs index 8f3a6321f..fbedb9321 100644 --- a/src/flp/types/fixedpoint_l2.rs +++ b/src/flp/types/fixedpoint_l2.rs @@ -672,17 +672,23 @@ mod tests { use crate::flp::test_utils::FlpTest; use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::types::extra::{U127, U14, U63}; + use fixed::types::{I1F15, I1F31, I1F63}; use fixed::{FixedI128, FixedI16, FixedI64}; - use fixed_macro::fixed; use rand::SeedableRng; + const FP16_4_INV: I1F15 = I1F15::lit("0.25"); + const FP16_8_INV: I1F15 = I1F15::lit("0.125"); + const FP16_16_INV: I1F15 = I1F15::lit("0.0625"); + const FP32_4_INV: I1F31 = I1F31::lit("0.25"); + const FP32_8_INV: I1F31 = I1F31::lit("0.125"); + const FP32_16_INV: I1F31 = I1F31::lit("0.0625"); + const FP64_4_INV: I1F63 = I1F63::lit("0.25"); + const FP64_8_INV: I1F63 = I1F63::lit("0.125"); + const FP64_16_INV: I1F63 = I1F63::lit("0.0625"); + #[test] fn test_bounded_fpvec_sum_parallel_fp16() { - let fp16_4_inv = fixed!(0.25: I1F15); - let fp16_8_inv = fixed!(0.125: I1F15); - let fp16_16_inv = fixed!(0.0625: I1F15); - - let fp16_vec = vec![fp16_4_inv, fp16_8_inv, fp16_16_inv]; + let fp16_vec = vec![FP16_4_INV, FP16_8_INV, FP16_16_INV]; // the encoded vector has the following entries: // enc(0.25) = 2^(n-1) * 0.25 + 2^(n-1) = 40960 @@ -693,22 +699,14 @@ mod tests { #[test] fn test_bounded_fpvec_sum_parallel_fp32() { - let fp32_4_inv = fixed!(0.25: I1F31); - let fp32_8_inv = fixed!(0.125: I1F31); - let fp32_16_inv = fixed!(0.0625: I1F31); - - let fp32_vec = vec![fp32_4_inv, fp32_8_inv, fp32_16_inv]; + let fp32_vec = vec![FP32_4_INV, FP32_8_INV, FP32_16_INV]; // computed as above but with n=32 test_fixed(fp32_vec, vec![2684354560, 2415919104, 2281701376]); } #[test] fn test_bounded_fpvec_sum_parallel_fp64() { - let fp64_4_inv = fixed!(0.25: I1F63); - let fp64_8_inv = fixed!(0.125: I1F63); - let fp64_16_inv = fixed!(0.0625: I1F63); - - let fp64_vec = vec![fp64_4_inv, fp64_8_inv, fp64_16_inv]; + let fp64_vec = vec![FP64_4_INV, FP64_8_INV, FP64_16_INV]; // computed as above but with n=64 test_fixed( fp64_vec, diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 635af4805..d3241603d 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1575,11 +1575,12 @@ mod tests { use assert_matches::assert_matches; #[cfg(feature = "experimental")] use fixed::{ - types::extra::{U15, U31, U63}, + types::{ + extra::{U15, U31, U63}, + I1F15, I1F31, I1F63, + }, FixedI16, FixedI32, FixedI64, }; - #[cfg(feature = "experimental")] - use fixed_macro::fixed; use rand::prelude::*; #[test] @@ -1719,19 +1720,19 @@ mod tests { { const SIZE: usize = 5; - let fp32_0 = fixed!(0: I1F31); + const FP32_0: I1F31 = I1F31::lit("0"); // 32 bit fixedpoint, non-power-of-2 vector, single-threaded { let prio3_32 = ctor_32(2, SIZE).unwrap(); - test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_32); + test_fixed_vec::<_, _, _, SIZE>(FP32_0, prio3_32); } // 32 bit fixedpoint, non-power-of-2 vector, multi-threaded #[cfg(feature = "multithreaded")] { let prio3_mt_32 = ctor_mt_32(2, SIZE).unwrap(); - test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_mt_32); + test_fixed_vec::<_, _, _, SIZE>(FP32_0, prio3_mt_32); } } @@ -1756,6 +1757,16 @@ mod tests { #[test] #[cfg(feature = "experimental")] fn test_prio3_bounded_fpvec_sum() { + const FP16_4_INV: I1F15 = I1F15::lit("0.25"); + const FP16_8_INV: I1F15 = I1F15::lit("0.125"); + const FP16_16_INV: I1F15 = I1F15::lit("0.0625"); + const FP32_4_INV: I1F31 = I1F31::lit("0.25"); + const FP32_8_INV: I1F31 = I1F31::lit("0.125"); + const FP32_16_INV: I1F31 = I1F31::lit("0.0625"); + const FP64_4_INV: I1F63 = I1F63::lit("0.25"); + const FP64_8_INV: I1F63 = I1F63::lit("0.125"); + const FP64_16_INV: I1F63 = I1F63::lit("0.0625"); + type P = Prio3FixedPointBoundedL2VecSum; let ctor_16 = P::>::new_fixedpoint_boundedl2_vec_sum; let ctor_32 = P::>::new_fixedpoint_boundedl2_vec_sum; @@ -1772,56 +1783,47 @@ mod tests { { // 16 bit fixedpoint - let fp16_4_inv = fixed!(0.25: I1F15); - let fp16_8_inv = fixed!(0.125: I1F15); - let fp16_16_inv = fixed!(0.0625: I1F15); // two aggregators, three entries per vector. { let prio3_16 = ctor_16(2, 3).unwrap(); - test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16); + test_fixed(FP16_4_INV, FP16_8_INV, FP16_16_INV, prio3_16); } #[cfg(feature = "multithreaded")] { let prio3_16_mt = ctor_mt_16(2, 3).unwrap(); - test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt); + test_fixed(FP16_4_INV, FP16_8_INV, FP16_16_INV, prio3_16_mt); } } { // 32 bit fixedpoint - let fp32_4_inv = fixed!(0.25: I1F31); - let fp32_8_inv = fixed!(0.125: I1F31); - let fp32_16_inv = fixed!(0.0625: I1F31); { let prio3_32 = ctor_32(2, 3).unwrap(); - test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32); + test_fixed(FP32_4_INV, FP32_8_INV, FP32_16_INV, prio3_32); } #[cfg(feature = "multithreaded")] { let prio3_32_mt = ctor_mt_32(2, 3).unwrap(); - test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt); + test_fixed(FP32_4_INV, FP32_8_INV, FP32_16_INV, prio3_32_mt); } } { // 64 bit fixedpoint - let fp64_4_inv = fixed!(0.25: I1F63); - let fp64_8_inv = fixed!(0.125: I1F63); - let fp64_16_inv = fixed!(0.0625: I1F63); { let prio3_64 = ctor_64(2, 3).unwrap(); - test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64); + test_fixed(FP64_4_INV, FP64_8_INV, FP64_16_INV, prio3_64); } #[cfg(feature = "multithreaded")] { let prio3_64_mt = ctor_mt_64(2, 3).unwrap(); - test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt); + test_fixed(FP64_4_INV, FP64_8_INV, FP64_16_INV, prio3_64_mt); } } diff --git a/supply-chain/config.toml b/supply-chain/config.toml index f03642d9a..1c12a8a12 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -100,10 +100,6 @@ notes = "This is only used when the \"crypto-dependencies\" feature is enabled." version = "1.20.0" criteria = "safe-to-deploy" -[[exemptions.fixed-macro-types]] -version = "1.2.0" -criteria = "safe-to-run" - [[exemptions.funty]] version = "2.0.0" criteria = "safe-to-deploy" @@ -162,10 +158,6 @@ criteria = "safe-to-run" version = "0.2.16" criteria = "safe-to-deploy" -[[exemptions.proc-macro-error]] -version = "1.0.4" -criteria = "safe-to-run" - [[exemptions.radium]] version = "0.7.0" criteria = "safe-to-deploy" From dea0456ec4e0af8a3c1fed26db6bf7d776f195ad Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:26:16 +0000 Subject: [PATCH 11/32] build(deps): Bump thiserror from 1.0.63 to 1.0.64 (#1118) --- Cargo.lock | 8 ++++---- supply-chain/audits.toml | 10 ++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a0ce3e561..34c578d16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1008,18 +1008,18 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index 14fd85d45..eec216a1b 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -764,6 +764,11 @@ who = "Brandon Pitman " criteria = "safe-to-deploy" delta = "1.0.40 -> 1.0.43" +[[audits.thiserror]] +who = "Brandon Pitman " +criteria = "safe-to-deploy" +delta = "1.0.63 -> 1.0.64" + [[audits.thiserror-impl]] who = "Brandon Pitman " criteria = "safe-to-deploy" @@ -779,6 +784,11 @@ who = "Brandon Pitman " criteria = "safe-to-deploy" delta = "1.0.40 -> 1.0.43" +[[audits.thiserror-impl]] +who = "Brandon Pitman " +criteria = "safe-to-deploy" +delta = "1.0.63 -> 1.0.64" + [[audits.unicode-ident]] who = "David Cook " criteria = "safe-to-deploy" From 2dcf5ca801c21ee1672bb448a2438aeeb4dadff4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 10:53:03 -0500 Subject: [PATCH 12/32] build(deps): Bump once_cell from 1.19.0 to 1.20.1 (#1119) * build(deps): Bump once_cell from 1.19.0 to 1.20.1 Bumps [once_cell](https://github.com/matklad/once_cell) from 1.19.0 to 1.20.1. - [Changelog](https://github.com/matklad/once_cell/blob/master/CHANGELOG.md) - [Commits](https://github.com/matklad/once_cell/compare/v1.19.0...v1.20.1) --- updated-dependencies: - dependency-name: once_cell dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Record audit, add exemption * Upgrade crossbeam dependencies --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: David Cook --- Cargo.lock | 50 ++++++++++----------------- Cargo.toml | 2 +- supply-chain/audits.toml | 5 +++ supply-chain/config.toml | 9 +++-- supply-chain/imports.lock | 71 +++++++++++++++++++++++++++++++++++---- 5 files changed, 91 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 34c578d16..d295b5671 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -231,38 +231,28 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.2" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.10" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "045ebe27666471bb549370b4b0b3e51b07f56325befa4284db65fc89c02511b1" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", - "memoffset", - "once_cell", - "scopeguard", ] [[package]] name = "crossbeam-utils" -version = "0.8.11" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51887d4adc7b564537b15adcfb307936f8075dfcd5f00dde9a9f1d29383682bc" -dependencies = [ - "cfg-if", - "once_cell", -] +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crunchy" @@ -533,15 +523,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memoffset" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" -dependencies = [ - "autocfg", -] - [[package]] name = "modinverse" version = "0.1.1" @@ -645,9 +626,12 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" +dependencies = [ + "portable-atomic", +] [[package]] name = "oorandom" @@ -689,6 +673,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" + [[package]] name = "ppv-lite86" version = "0.2.16" @@ -888,12 +878,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "scopeguard" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" - [[package]] name = "serde" version = "1.0.210" diff --git a/Cargo.toml b/Cargo.toml index c0d119ebb..c0d159545 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ hex-literal = "0.4.1" iai = "0.1" modinverse = "0.1.0" num-bigint = "0.4.6" -once_cell = "1.19.0" +once_cell = "1.20.1" prio = { path = ".", features = ["crypto-dependencies", "test-util"] } statrs = "0.17.1" diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index eec216a1b..dbaf88f1a 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -498,6 +498,11 @@ who = "Brandon Pitman " criteria = "safe-to-deploy" delta = "1.18.0 -> 1.19.0" +[[audits.once_cell]] +who = "David Cook " +criteria = "safe-to-deploy" +delta = "1.19.0 -> 1.20.1" + [[audits.opaque-debug]] who = "David Cook " criteria = "safe-to-deploy" diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 1c12a8a12..5de3ea665 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -129,11 +129,6 @@ criteria = "safe-to-run" version = "0.3.7" criteria = "safe-to-run" -[[exemptions.memoffset]] -version = "0.6.5" -criteria = "safe-to-deploy" -notes = "This is only used when the \"multithreaded\" feature is enabled." - [[exemptions.nalgebra]] version = "0.29.0" criteria = "safe-to-run" @@ -154,6 +149,10 @@ criteria = "safe-to-run" version = "0.3.4" criteria = "safe-to-run" +[[exemptions.portable-atomic]] +version = "1.9.0" +criteria = "safe-to-deploy" + [[exemptions.ppv-lite86]] version = "0.2.16" criteria = "safe-to-deploy" diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 3106d8b1a..b67e3bc75 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -106,13 +106,6 @@ user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" -[[publisher.scopeguard]] -version = "1.1.0" -when = "2020-02-16" -user-id = 2915 -user-login = "Amanieu" -user-name = "Amanieu d'Antras" - [[publisher.serde]] version = "1.0.210" when = "2024-09-06" @@ -319,6 +312,12 @@ who = "Pat Hickey " criteria = "safe-to-deploy" version = "0.2.0" +[[audits.bytecode-alliance.audits.crossbeam-epoch]] +who = "Alex Crichton " +criteria = "safe-to-deploy" +delta = "0.9.15 -> 0.9.18" +notes = "Nontrivial update but mostly around dependencies and how `unsafe` code is managed. Everything looks the same shape as before." + [[audits.bytecode-alliance.audits.crypto-common]] who = "Benjamin Bouvier " criteria = "safe-to-deploy" @@ -479,6 +478,37 @@ criteria = "safe-to-deploy" delta = "0.10.2 -> 0.10.3" aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" +[[audits.mozilla.audits.crossbeam-epoch]] +who = "Mike Hommey " +criteria = "safe-to-deploy" +delta = "0.9.10 -> 0.9.13" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + +[[audits.mozilla.audits.crossbeam-epoch]] +who = "Mike Hommey " +criteria = "safe-to-deploy" +delta = "0.9.13 -> 0.9.14" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + +[[audits.mozilla.audits.crossbeam-utils]] +who = "Mike Hommey " +criteria = "safe-to-deploy" +delta = "0.8.11 -> 0.8.14" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + +[[audits.mozilla.audits.crossbeam-utils]] +who = "Jan-Erik Rediger " +criteria = "safe-to-deploy" +delta = "0.8.14 -> 0.8.19" +aggregated-from = "https://raw.githubusercontent.com/mozilla/glean/main/supply-chain/audits.toml" + +[[audits.mozilla.audits.crossbeam-utils]] +who = "Alex Franchuk " +criteria = "safe-to-deploy" +delta = "0.8.19 -> 0.8.20" +notes = "Minor changes." +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + [[audits.mozilla.audits.crypto-common]] who = "Mike Hommey " criteria = "safe-to-deploy" @@ -586,6 +616,33 @@ version = "2.5.0" notes = "The goal is to provide some constant-time correctness for cryptographic implementations. The approach is reasonable, it is known to be insufficient but this is pointed out in the documentation." aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" +[[audits.zcash.audits.crossbeam-deque]] +who = "Jack Grigg " +criteria = "safe-to-deploy" +delta = "0.8.2 -> 0.8.3" +notes = "No new code." +aggregated-from = "https://raw.githubusercontent.com/zcash/zcash/master/qa/supply-chain/audits.toml" + +[[audits.zcash.audits.crossbeam-deque]] +who = "Jack Grigg " +criteria = "safe-to-deploy" +delta = "0.8.3 -> 0.8.4" +aggregated-from = "https://raw.githubusercontent.com/zcash/zcash/master/qa/supply-chain/audits.toml" + +[[audits.zcash.audits.crossbeam-deque]] +who = "Daira-Emma Hopwood " +criteria = "safe-to-deploy" +delta = "0.8.4 -> 0.8.5" +notes = "Changes to `unsafe` code look okay." +aggregated-from = "https://raw.githubusercontent.com/zcash/zcash/master/qa/supply-chain/audits.toml" + +[[audits.zcash.audits.crossbeam-epoch]] +who = "Jack Grigg " +criteria = "safe-to-deploy" +delta = "0.9.14 -> 0.9.15" +notes = "Bumps memoffset to 0.9, and unmarks some ARMv7r and Sony Vita targets as not having 64-bit atomics." +aggregated-from = "https://raw.githubusercontent.com/zcash/zcash/master/qa/supply-chain/audits.toml" + [[audits.zcash.audits.getrandom]] who = "Jack Grigg " criteria = "safe-to-deploy" From b6f1a415b117a95a1f9c69bca28a69451b6e1177 Mon Sep 17 00:00:00 2001 From: David Cook Date: Mon, 7 Oct 2024 12:11:14 -0500 Subject: [PATCH 13/32] Upgrade to cargo-vet 0.10.0 (#1120) --- .github/workflows/supply-chain.yml | 2 +- supply-chain/config.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/supply-chain.yml b/.github/workflows/supply-chain.yml index 7f30db31d..a582334d7 100644 --- a/.github/workflows/supply-chain.yml +++ b/.github/workflows/supply-chain.yml @@ -16,7 +16,7 @@ jobs: name: Vet Dependencies runs-on: ubuntu-latest env: - CARGO_VET_VERSION: 0.9.0 + CARGO_VET_VERSION: 0.10.0 steps: - uses: actions/checkout@v4 - name: Install Rust toolchain diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 5de3ea665..104176b2e 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -2,7 +2,7 @@ # cargo-vet config file [cargo-vet] -version = "0.9" +version = "0.10" [imports.bytecode-alliance] url = "https://raw.githubusercontent.com/bytecodealliance/wasmtime/main/supply-chain/audits.toml" From 45ce6c3576b9923c0463de527f61d79139d7515c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:32:23 +0000 Subject: [PATCH 14/32] build(deps): Bump once_cell from 1.20.1 to 1.20.2 (#1121) --- Cargo.lock | 13 ++----------- Cargo.toml | 2 +- supply-chain/config.toml | 4 ---- supply-chain/imports.lock | 23 +++++++++++++---------- 4 files changed, 16 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d295b5671..046995cb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -626,12 +626,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.1" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" -dependencies = [ - "portable-atomic", -] +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" @@ -673,12 +670,6 @@ dependencies = [ "plotters-backend", ] -[[package]] -name = "portable-atomic" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" - [[package]] name = "ppv-lite86" version = "0.2.16" diff --git a/Cargo.toml b/Cargo.toml index c0d159545..9863eb40d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ hex-literal = "0.4.1" iai = "0.1" modinverse = "0.1.0" num-bigint = "0.4.6" -once_cell = "1.20.1" +once_cell = "1.20.2" prio = { path = ".", features = ["crypto-dependencies", "test-util"] } statrs = "0.17.1" diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 104176b2e..bfc08a75a 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -149,10 +149,6 @@ criteria = "safe-to-run" version = "0.3.4" criteria = "safe-to-run" -[[exemptions.portable-atomic]] -version = "1.9.0" -criteria = "safe-to-deploy" - [[exemptions.ppv-lite86]] version = "0.2.16" criteria = "safe-to-deploy" diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index b67e3bc75..2d58fcb3a 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -590,6 +590,19 @@ version = "0.2.15" notes = "All code written or reviewed by Josh Stone." aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" +[[audits.mozilla.audits.once_cell]] +who = "Mike Hommey " +criteria = "safe-to-deploy" +delta = "1.16.0 -> 1.17.1" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + +[[audits.mozilla.audits.once_cell]] +who = "Erich Gubler " +criteria = "safe-to-deploy" +delta = "1.20.1 -> 1.20.2" +notes = "This update works around a Cargo bug that forces the addition of `portable-atomic` into a lockfile, which we have never needed to use." +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + [[audits.mozilla.audits.rand_core]] who = "Mike Hommey " criteria = "safe-to-deploy" @@ -700,16 +713,6 @@ criteria = "safe-to-deploy" delta = "2.7.2 -> 2.7.4" aggregated-from = "https://raw.githubusercontent.com/zcash/librustzcash/main/supply-chain/audits.toml" -[[audits.zcash.audits.once_cell]] -who = "Jack Grigg " -criteria = "safe-to-deploy" -delta = "1.17.0 -> 1.17.1" -notes = """ -Small refactor that reduces the overall amount of `unsafe` code. The new strict provenance -approach looks reasonable. -""" -aggregated-from = "https://raw.githubusercontent.com/zcash/zcash/master/qa/supply-chain/audits.toml" - [[audits.zcash.audits.unicode-ident]] who = "Daira Hopwood " criteria = "safe-to-deploy" From a2ffdcda7de09c5831e136e3c908948857b74574 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 18:07:10 +0000 Subject: [PATCH 15/32] build(deps): Bump serde_json from 1.0.128 to 1.0.132 (#1124) --- Cargo.lock | 4 ++-- supply-chain/imports.lock | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 046995cb5..ec392d51e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -891,9 +891,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 2d58fcb3a..c5de4353d 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -121,8 +121,8 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_json]] -version = "1.0.128" -when = "2024-09-04" +version = "1.0.132" +when = "2024-10-19" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From d2fe428cadaf690a56a74f4c682f62326f959f06 Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Thu, 31 Oct 2024 11:13:59 -0400 Subject: [PATCH 16/32] Implement `Prio3MutlihotCountVec` (#1123) --- src/flp.rs | 5 +- src/flp/types.rs | 349 +++++++++++++++++++++++++++++++++++++++++++++- src/vdaf/prio3.rs | 57 +++++++- 3 files changed, 404 insertions(+), 7 deletions(-) diff --git a/src/flp.rs b/src/flp.rs index 93822c913..ebba717ca 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -134,7 +134,10 @@ pub trait Type: Sized + Eq + Clone + Debug { measurement: &Self::Measurement, ) -> Result, FlpError>; - /// Decode an aggregate result. + /// Decodes an aggregate result. + /// + /// This is NOT the inverse of `encode_measurement`. Rather, the input is an aggregation of + /// truncated measurements. fn decode_result( &self, data: &[Self::Field], diff --git a/src/flp/types.rs b/src/flp/types.rs index a7b8f2da8..b45aba498 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -6,6 +6,7 @@ use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::{FlpError, Gadget, Type}; use crate::polynomial::poly_range_check; +use crate::vdaf::prio3::ilog2; use std::convert::TryInto; use std::fmt::{self, Debug}; use std::marker::PhantomData; @@ -471,6 +472,237 @@ where } } +/// The multihot counter data type. Each measurement is a list of booleans of length `length`, with +/// at most `max_weight` true values, and the aggregate is a histogram counting the number of true +/// values at each position across all measurements. +#[derive(PartialEq, Eq)] +pub struct MultihotCountVec { + // Parameters + /// The number of elements in the list of booleans + length: usize, + /// The max number of permissible `true` values in the list of booleans + max_weight: usize, + /// The size of the chunks fed into our gadget calls + chunk_length: usize, + + // Calculated from parameters + gadget_calls: usize, + bits_for_weight: usize, + offset: usize, + phantom: PhantomData<(F, S)>, +} + +impl Debug for MultihotCountVec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MultihotCountVec") + .field("length", &self.length) + .field("max_weight", &self.max_weight) + .field("chunk_length", &self.chunk_length) + .finish() + } +} + +impl>> MultihotCountVec { + /// Return a new [`MultihotCountVec`] type with the given number of buckets. + pub fn new( + num_buckets: usize, + max_weight: usize, + chunk_length: usize, + ) -> Result { + if num_buckets >= u32::MAX as usize { + return Err(FlpError::Encode( + "invalid num_buckets: exceeds maximum permitted".to_string(), + )); + } + if num_buckets == 0 { + return Err(FlpError::InvalidParameter( + "num_buckets cannot be zero".to_string(), + )); + } + if chunk_length == 0 { + return Err(FlpError::InvalidParameter( + "chunk_length cannot be zero".to_string(), + )); + } + if max_weight == 0 { + return Err(FlpError::InvalidParameter( + "max_weight cannot be zero".to_string(), + )); + } + + // The bitlength of a measurement is the number of buckets plus the bitlength of the max + // weight + let bits_for_weight = ilog2(max_weight) as usize + 1; + let meas_length = num_buckets + bits_for_weight; + + // Gadget calls is ⌈meas_length / chunk_length⌉ + let gadget_calls = (meas_length + chunk_length - 1) / chunk_length; + // Offset is 2^max_weight.bitlen() - 1 - max_weight + let offset = (1 << bits_for_weight) - 1 - max_weight; + + Ok(Self { + length: num_buckets, + max_weight, + chunk_length, + gadget_calls, + bits_for_weight, + offset, + phantom: PhantomData, + }) + } +} + +// Cannot autoderive clone because it requires F and S to be Clone, which they're not in general +impl Clone for MultihotCountVec { + fn clone(&self) -> Self { + Self { + length: self.length, + max_weight: self.max_weight, + chunk_length: self.chunk_length, + bits_for_weight: self.bits_for_weight, + offset: self.offset, + gadget_calls: self.gadget_calls, + phantom: self.phantom, + } + } +} + +impl Type for MultihotCountVec +where + F: FftFriendlyFieldElement, + S: ParallelSumGadget> + Eq + 'static, +{ + type Measurement = Vec; + type AggregateResult = Vec; + type Field = F; + + fn encode_measurement(&self, measurement: &Vec) -> Result, FlpError> { + let weight_reported: usize = measurement.iter().filter(|bit| **bit).count(); + + if measurement.len() != self.length { + return Err(FlpError::Encode(format!( + "unexpected measurement length: got {}; want {}", + measurement.len(), + self.length + ))); + } + if weight_reported > self.max_weight { + return Err(FlpError::Encode(format!( + "unexpected measurement weight: got {}; want ≤{}", + weight_reported, self.max_weight + ))); + } + + // Convert bool vector to field elems + let multihot_vec: Vec = measurement + .iter() + // We can unwrap because any Integer type can cast from bool + .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())) + .collect(); + + // Encode the measurement weight in binary (actually, the weight plus some offset) + let offset_weight_bits = { + let offset_weight_reported = F::valid_integer_try_from(self.offset + weight_reported)?; + F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)?.collect() + }; + + // Report the concat of the two + Ok([multihot_vec, offset_weight_bits].concat()) + } + + fn decode_result( + &self, + data: &[Self::Field], + _num_measurements: usize, + ) -> Result { + // The aggregate is the same as the decoded result. Just convert to integers + decode_result_vec(data, self.length) + } + + fn gadget(&self) -> Vec>> { + vec![Box::new(S::new( + Mul::new(self.gadget_calls), + self.chunk_length, + ))] + } + + fn valid( + &self, + g: &mut Vec>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result { + self.valid_call_check(input, joint_rand)?; + + // Check that each element of `input` is a 0 or 1. + let range_check = parallel_sum_range_checks( + &mut g[0], + input, + joint_rand[0], + self.chunk_length, + num_shares, + )?; + + // Check that the elements of `input` sum to at most `max_weight`. + let count_vec = &input[..self.length]; + let weight = count_vec.iter().fold(F::zero(), |a, b| a + *b); + let offset_weight_reported = F::decode_bitvector(&input[self.length..])?; + + // From spec: weight_check = self.offset*shares_inv + weight - weight_reported + let weight_check = { + let offset = F::from(F::valid_integer_try_from(self.offset)?); + let shares_inv = F::from(F::valid_integer_try_from(num_shares)?).inv(); + offset * shares_inv + weight - offset_weight_reported + }; + + // Take a random linear combination of both checks. + let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * weight_check; + Ok(out) + } + + // Truncates the measurement, removing extra data that was necessary for validity (here, the + // encoded weight), but not important for aggregation + fn truncate(&self, input: Vec) -> Result, FlpError> { + self.truncate_call_check(&input)?; + // Cut off the encoded weight + Ok(input[..self.length].to_vec()) + } + + // The length in field elements of the encoded input returned by [`Self::encode_measurement`]. + fn input_len(&self) -> usize { + self.length + self.bits_for_weight + } + + fn proof_len(&self) -> usize { + (self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1 + } + + fn verifier_len(&self) -> usize { + 2 + self.chunk_length * 2 + } + + // The length of the truncated output (i.e., the output of [`Type::truncate`]). + fn output_len(&self) -> usize { + self.length + } + + // The number of random values needed in the validity checks + fn joint_rand_len(&self) -> usize { + 2 + } + + fn prove_rand_len(&self) -> usize { + self.chunk_length * 2 + } + + fn query_rand_len(&self) -> usize { + // TODO: this will need to be increase once draft-10 is implemented and more randomness is + // necessary due to random linear combination computations + 1 + } +} + /// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19], /// Corollary 4.9] to reduce the proof size to roughly the square root of the input size. /// @@ -685,13 +917,13 @@ pub(crate) fn call_gadget_on_vec_entries( input: &[F], rnd: F, ) -> Result { - let mut range_check = F::zero(); + let mut comb = F::zero(); let mut r = rnd; for chunk in input.chunks(1) { - range_check += r * g.call(chunk)?; + comb += r * g.call(chunk)?; r *= rnd; } - Ok(range_check) + Ok(comb) } /// Given a vector `data` of field elements which should contain exactly one entry, return the @@ -776,7 +1008,9 @@ pub(crate) fn parallel_sum_range_checks( #[cfg(test)] mod tests { use super::*; - use crate::field::{random_vector, Field64 as TestField, FieldElement}; + use crate::field::{ + random_vector, Field64 as TestField, FieldElement, FieldElementWithInteger, + }; use crate::flp::gadgets::ParallelSum; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; @@ -957,6 +1191,113 @@ mod tests { ); } + fn test_multihot(constructor: F) + where + F: Fn(usize, usize, usize) -> Result, FlpError>, + S: ParallelSumGadget> + Eq + 'static, + { + const NUM_SHARES: usize = 3; + + // Chunk size for our range check gadget + let chunk_size = 2; + + // Our test is on multihot vecs of length 3, with max weight 2 + let num_buckets = 3; + let max_weight = 2; + + let multihot_instance = constructor(num_buckets, max_weight, chunk_size).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + let encoded_weight_plus_offset = |weight| { + let bits_for_weight = ilog2(max_weight) as usize + 1; + let offset = (1 << bits_for_weight) - 1 - max_weight; + TestField::encode_as_bitvector( + ::Integer::try_from(weight + offset).unwrap(), + bits_for_weight, + ) + .unwrap() + .collect::>() + }; + + assert_eq!( + multihot_instance + .encode_measurement(&vec![true, true, false]) + .unwrap(), + [&[one, one, zero], &*encoded_weight_plus_offset(2)].concat(), + ); + assert_eq!( + multihot_instance + .encode_measurement(&vec![false, true, true]) + .unwrap(), + [&[zero, one, one], &*encoded_weight_plus_offset(2)].concat(), + ); + + // Round trip + assert_eq!( + multihot_instance + .decode_result( + &multihot_instance + .truncate( + multihot_instance + .encode_measurement(&vec![false, true, true]) + .unwrap() + ) + .unwrap(), + 1 + ) + .unwrap(), + [0, 1, 1] + ); + + // Test valid inputs with weights 0, 1, and 2 + FlpTest::expect_valid::( + &multihot_instance, + &multihot_instance + .encode_measurement(&vec![true, false, false]) + .unwrap(), + &[one, zero, zero], + ); + + FlpTest::expect_valid::( + &multihot_instance, + &multihot_instance + .encode_measurement(&vec![false, true, true]) + .unwrap(), + &[zero, one, one], + ); + + FlpTest::expect_valid::( + &multihot_instance, + &multihot_instance + .encode_measurement(&vec![false, false, false]) + .unwrap(), + &[zero, zero, zero], + ); + + // Test invalid inputs. + + // Not binary + FlpTest::expect_invalid::( + &multihot_instance, + &[&[zero, zero, nine], &*encoded_weight_plus_offset(1)].concat(), + ); + // Wrong weight + FlpTest::expect_invalid::( + &multihot_instance, + &[&[zero, zero, one], &*encoded_weight_plus_offset(2)].concat(), + ); + // We cannot test the case where the weight is higher than max_weight. This is because + // weight + offset cannot fit into a bitvector of the correct length. In other words, being + // out-of-range requires the prover to lie about their weight, which is tested above + } + + #[test] + fn test_multihot_serial() { + test_multihot(MultihotCountVec::>>::new); + } + fn test_sum_vec(f: F) where F: Fn(usize, usize, usize) -> Result, FlpError>, diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index d3241603d..673d37f1d 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -44,7 +44,7 @@ use crate::flp::gadgets::{Mul, ParallelSum}; use crate::flp::types::fixedpoint_l2::{ compatible_float::CompatibleFloat, FixedPointBoundedL2VecSum, }; -use crate::flp::types::{Average, Count, Histogram, Sum, SumVec}; +use crate::flp::types::{Average, Count, Histogram, MultihotCountVec, Sum, SumVec}; use crate::flp::Type; #[cfg(feature = "experimental")] use crate::flp::TypeWithNoise; @@ -282,6 +282,59 @@ impl Prio3HistogramMultithreaded { } } +/// The multihot counter data type. Each measurement is a list of booleans of length `length`, with +/// at most `max_weight` true values, and the aggregate is a histogram counting the number of true +/// values at each position across all measurements. +pub type Prio3MultihotCountVec = + Prio3>>, XofTurboShake128, 16>; + +impl Prio3MultihotCountVec { + /// Constructs an instance of Prio3MultihotCountVec with the given number of aggregators, number + /// of buckets, max weight, and parallel sum gadget chunk length. + pub fn new_multihot_count_vec( + num_aggregators: u8, + num_buckets: usize, + max_weight: usize, + chunk_length: usize, + ) -> Result { + Prio3::new( + num_aggregators, + 1, + 0xFFFF0000, + MultihotCountVec::new(num_buckets, max_weight, chunk_length)?, + ) + } +} + +/// Like [`Prio3MultihotCountVec`] except this type uses multithreading to improve sharding and preparation +/// time. Note that this improvement is only noticeable for very large input lengths. +#[cfg(feature = "multithreaded")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +pub type Prio3MultihotCountVecMultithreaded = Prio3< + MultihotCountVec>>, + XofTurboShake128, + 16, +>; + +#[cfg(feature = "multithreaded")] +impl Prio3MultihotCountVecMultithreaded { + /// Constructs an instance of Prio3MultihotCountVecMultithreaded with the given number of + /// aggregators, number of buckets, max weight, and parallel sum gadget chunk length. + pub fn new_multihot_count_vec_multithreaded( + num_aggregators: u8, + num_buckets: usize, + max_weight: usize, + chunk_length: usize, + ) -> Result { + Prio3::new( + num_aggregators, + 1, + 0xFFFF0000, + MultihotCountVec::new(num_buckets, max_weight, chunk_length)?, + ) + } +} + /// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and /// the aggregate is the arithmetic average. pub type Prio3Average = Prio3, XofTurboShake128, 16>; @@ -1519,7 +1572,7 @@ where /// # Panics /// /// This function will panic if `input` is zero. -fn ilog2(input: usize) -> u32 { +pub(crate) fn ilog2(input: usize) -> u32 { if input == 0 { panic!("Tried to take the logarithm of zero"); } From 3035d8f5d0467f7340489f67eff6e67e7fb2b3e7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:50:45 +0000 Subject: [PATCH 17/32] build(deps): Bump thiserror from 1.0.64 to 1.0.67 (#1129) --- Cargo.lock | 22 +++++++++++----------- supply-chain/audits.toml | 6 +++--- supply-chain/imports.lock | 33 +++++++++++++++++++++++++-------- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ec392d51e..9ee42b2a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -558,7 +558,7 @@ checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.87", ] [[package]] @@ -728,9 +728,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.74" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2de98502f212cfcea8d0bb305bd0f49d7ebdd75b64ba0a68f937d888f4e0d6db" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -886,7 +886,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.87", ] [[package]] @@ -966,9 +966,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.46" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89456b690ff72fddcecf231caedbe615c59480c93358a93dfae7fc29e3ebbf0e" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -983,22 +983,22 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.87", ] [[package]] diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index dbaf88f1a..0e8db2215 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -975,19 +975,19 @@ end = "2025-07-01" criteria = "safe-to-deploy" user-id = 3618 # David Tolnay (dtolnay) start = "2019-03-01" -end = "2024-06-08" +end = "2025-11-04" [[trusted.thiserror]] criteria = "safe-to-deploy" user-id = 3618 # David Tolnay (dtolnay) start = "2019-10-09" -end = "2024-07-25" +end = "2025-11-04" [[trusted.thiserror-impl]] criteria = "safe-to-deploy" user-id = 3618 # David Tolnay (dtolnay) start = "2019-10-09" -end = "2024-07-25" +end = "2025-11-04" [[trusted.thread_local]] criteria = "safe-to-deploy" diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index c5de4353d..b69c55f73 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -65,8 +65,8 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.proc-macro2]] -version = "1.0.74" -when = "2024-01-02" +version = "1.0.86" +when = "2024-06-21" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" @@ -135,22 +135,22 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.syn]] -version = "2.0.46" -when = "2024-01-02" +version = "2.0.87" +when = "2024-11-02" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror]] -version = "1.0.63" -when = "2024-07-17" +version = "1.0.67" +when = "2024-11-03" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror-impl]] -version = "1.0.63" -when = "2024-07-17" +version = "1.0.67" +when = "2024-11-03" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" @@ -428,6 +428,23 @@ criteria = "safe-to-run" version = "0.10.5" aggregated-from = "https://chromium.googlesource.com/chromiumos/third_party/rust_crates/+/refs/heads/main/cargo-vet/audits.toml?format=TEXT" +[[audits.google.audits.proc-macro2]] +who = "danakj " +criteria = "safe-to-deploy" +delta = "1.0.86 -> 1.0.87" +notes = "No new unsafe interactions." +aggregated-from = "https://chromium.googlesource.com/chromium/src/+/main/third_party/rust/chromium_crates_io/supply-chain/audits.toml?format=TEXT" + +[[audits.google.audits.proc-macro2]] +who = "Liza Burakova Date: Mon, 4 Nov 2024 17:08:12 +0000 Subject: [PATCH 18/32] build(deps): Bump serde from 1.0.210 to 1.0.214 (#1126) --- Cargo.lock | 8 ++++---- supply-chain/imports.lock | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ee42b2a6..56e57f90e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -871,18 +871,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index b69c55f73..19aba3899 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -107,15 +107,15 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde]] -version = "1.0.210" -when = "2024-09-06" +version = "1.0.214" +when = "2024-10-28" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_derive]] -version = "1.0.210" -when = "2024-09-06" +version = "1.0.214" +when = "2024-10-28" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From 48a020852905829885d96a7f72ac96ce30c2581b Mon Sep 17 00:00:00 2001 From: David Cook Date: Mon, 4 Nov 2024 11:12:16 -0600 Subject: [PATCH 19/32] Prio3 cleanups (#1127) * Remove obsolete MSRV workaround * Correct documentation of multithreaded Prio3 types ParallelSumMultithreaded only accelerates calls to eval_poly(), which is only used during sharding. --- src/flp/types.rs | 5 ++--- src/vdaf/prio3.rs | 22 ++++------------------ 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/src/flp/types.rs b/src/flp/types.rs index b45aba498..b294e3d8d 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -6,7 +6,6 @@ use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::{FlpError, Gadget, Type}; use crate::polynomial::poly_range_check; -use crate::vdaf::prio3::ilog2; use std::convert::TryInto; use std::fmt::{self, Debug}; use std::marker::PhantomData; @@ -532,7 +531,7 @@ impl>> MultihotCountV // The bitlength of a measurement is the number of buckets plus the bitlength of the max // weight - let bits_for_weight = ilog2(max_weight) as usize + 1; + let bits_for_weight = max_weight.ilog2() as usize + 1; let meas_length = num_buckets + bits_for_weight; // Gadget calls is ⌈meas_length / chunk_length⌉ @@ -1211,7 +1210,7 @@ mod tests { let nine = TestField::from(9); let encoded_weight_plus_offset = |weight| { - let bits_for_weight = ilog2(max_weight) as usize + 1; + let bits_for_weight = max_weight.ilog2() as usize + 1; let offset = (1 << bits_for_weight) - 1 - max_weight; TestField::encode_as_bitvector( ::Integer::try_from(weight + offset).unwrap(), diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 673d37f1d..02e30e314 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -106,7 +106,7 @@ impl Prio3SumVec { } } -/// Like [`Prio3SumVec`] except this type uses multithreading to improve sharding and preparation +/// Like [`Prio3SumVec`] except this type uses multithreading to improve sharding /// time. Note that the improvement is only noticeable for very large input lengths. #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] @@ -254,7 +254,7 @@ impl Prio3Histogram { } } -/// Like [`Prio3Histogram`] except this type uses multithreading to improve sharding and preparation +/// Like [`Prio3Histogram`] except this type uses multithreading to improve sharding /// time. Note that this improvement is only noticeable for very large input lengths. #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] @@ -306,7 +306,7 @@ impl Prio3MultihotCountVec { } } -/// Like [`Prio3MultihotCountVec`] except this type uses multithreading to improve sharding and preparation +/// Like [`Prio3MultihotCountVec`] except this type uses multithreading to improve sharding /// time. Note that this improvement is only noticeable for very large input lengths. #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] @@ -1565,20 +1565,6 @@ where } } -/// This is a polyfill for `usize::ilog2()`, which is only available in Rust 1.67 and later. It is -/// based on the implementation in the standard library. It can be removed when the MSRV has been -/// advanced past 1.67. -/// -/// # Panics -/// -/// This function will panic if `input` is zero. -pub(crate) fn ilog2(input: usize) -> u32 { - if input == 0 { - panic!("Tried to take the logarithm of zero"); - } - (usize::BITS - 1) - input.leading_zeros() -} - /// Finds the optimal choice of chunk length for [`Prio3Histogram`] or [`Prio3SumVec`], given its /// encoded measurement length. For [`Prio3Histogram`], the measurement length is equal to the /// length parameter. For [`Prio3SumVec`], the measurement length is equal to the product of the @@ -1594,7 +1580,7 @@ pub fn optimal_chunk_length(measurement_length: usize) -> usize { chunk_length: usize, } - let max_log2 = ilog2(measurement_length + 1); + let max_log2 = (measurement_length + 1).ilog2(); let best_opt = (1..=max_log2) .rev() .map(|log2| { From 9245fec4d9cacc2adfe23ed71cb1449aff8a7971 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 19:04:42 +0000 Subject: [PATCH 20/32] build(deps): Bump thiserror from 1.0.67 to 1.0.68 (#1131) --- Cargo.lock | 8 ++++---- supply-chain/imports.lock | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 56e57f90e..c220cf35d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,18 +983,18 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "thiserror" -version = "1.0.67" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd" +checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.67" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6" +checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 19aba3899..4e5c065d8 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -142,15 +142,15 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror]] -version = "1.0.67" -when = "2024-11-03" +version = "1.0.68" +when = "2024-11-04" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror-impl]] -version = "1.0.67" -when = "2024-11-03" +version = "1.0.68" +when = "2024-11-04" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From 731884e5fd766d9bee62ee69c0809fb72a397d64 Mon Sep 17 00:00:00 2001 From: David Cook Date: Fri, 8 Nov 2024 17:12:57 -0600 Subject: [PATCH 21/32] Setup for new release branch (#1130) * Bump version number * Update version table * Revise README * Add new release branch to Dependabot --- .github/dependabot.yml | 12 ++++++++++++ Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 13 ++++++++----- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 5d0ffc612..0514d2c23 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,3 +9,15 @@ updates: directory: "/" schedule: interval: "weekly" + + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + target-branch: release/0.16 + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + target-branch: release/0.16 diff --git a/Cargo.lock b/Cargo.lock index c220cf35d..f74929d93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -678,7 +678,7 @@ checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" [[package]] name = "prio" -version = "0.16.7" +version = "0.17.0-alpha.0" dependencies = [ "aes", "assert_matches", diff --git a/Cargo.toml b/Cargo.toml index 9863eb40d..d244e6796 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "prio" -version = "0.16.7" +version = "0.17.0-alpha.0" authors = ["Josh Aas ", "Tim Geoghegan ", "Christopher Patton "] edition = "2021" exclude = ["/supply-chain"] diff --git a/README.md b/README.md index fbd85cdb4..5b6f84f58 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,14 @@ and Scalable Computation of Aggregate Statistics. ## Exposure Notifications Private Analytics -This crate was used in the [Exposure Notifications Private Analytics][enpa] -system. This is referred to in various places as Prio v2. See +Prior versions of this crate were used in the [Exposure Notifications Private +Analytics][enpa] system. This is referred to in various places as Prio v2. See [`prio-server`][prio-server] or the [ENPA whitepaper][enpa-whitepaper] for more details. ## Verifiable Distributed Aggregation Function -This crate also implements a [Verifiable Distributed Aggregation Function +This crate implements a [Verifiable Distributed Aggregation Function (VDAF)][vdaf] called "Prio3", implemented in the `vdaf` module, allowing Prio to be used in the [Distributed Aggregation Protocol][dap] protocol being developed in the PPM working group at the IETF. This support is still evolving along with @@ -26,7 +26,7 @@ the DAP and VDAF specifications. ### Draft versions and release branches -The `main` branch is under continuous development and will usually be partway between VDAF drafts. +The `main` branch is under continuous development and will sometimes be partway between VDAF drafts. libprio uses stable release branches to maintain implementations of different VDAF draft versions. Crate `prio` version `x.y.z` is released from a corresponding `release/x.y` branch. We try to maintain [Rust SemVer][semver] compatibility, meaning that API breaks only happen on minor version @@ -42,7 +42,8 @@ increases (e.g., 0.10 to 0.11). | 0.13 | `release/0.13` | [`draft-irtf-cfrg-vdaf-06`][vdaf-06] | [`draft-ietf-ppm-dap-05`][dap-05] | Yes | Unmaintained | | 0.14 | `release/0.14` | [`draft-irtf-cfrg-vdaf-06`][vdaf-06] | [`draft-ietf-ppm-dap-05`][dap-05] | Yes | Unmaintained | | 0.15 | `release/0.15` | [`draft-irtf-cfrg-vdaf-07`][vdaf-07] | [`draft-ietf-ppm-dap-07`][dap-07] | Yes | Unmaintained as of June 24, 2024 | -| 0.16 | `main` | [`draft-irtf-cfrg-vdaf-08`][vdaf-08] | [`draft-ietf-ppm-dap-09`][dap-09] | Yes | Supported | +| 0.16 | `release/0.16` | [`draft-irtf-cfrg-vdaf-08`][vdaf-08] | [`draft-ietf-ppm-dap-09`][dap-09] | Yes | Supported | +| 0.17 | `main` | [`draft-irtf-cfrg-vdaf-13`][vdaf-13] | [`draft-ietf-ppm-dap-13`][dap-13] | [No](https://github.com/divviup/libprio-rs/issues/1122) | Supported | [vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/ [vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ @@ -51,6 +52,7 @@ increases (e.g., 0.10 to 0.11). [vdaf-06]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/06/ [vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ [vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/ +[vdaf-13]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/13/ [dap-01]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/01/ [dap-02]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/02/ [dap-03]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/03/ @@ -58,6 +60,7 @@ increases (e.g., 0.10 to 0.11). [dap-05]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/05/ [dap-07]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/07/ [dap-09]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/09/ +[dap-13]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/13/ [enpa]: https://www.abetterinternet.org/post/prio-services-for-covid-en/ [enpa-whitepaper]: https://covid19-static.cdn-apple.com/applications/covid19/current/static/contact-tracing/pdf/ENPA_White_Paper.pdf [prio-server]: https://github.com/divviup/prio-server From 1e39ac2834ee5892e99fb5f8c24890e4f13c1805 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 23:23:55 +0000 Subject: [PATCH 22/32] build(deps): Bump thiserror from 1.0.68 to 2.0.1 (#1135) --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- supply-chain/imports.lock | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f74929d93..c14302e34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,18 +983,18 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "thiserror" -version = "1.0.68" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" +checksum = "07c1e40dd48a282ae8edc36c732cbc219144b87fb6a4c7316d611c6b1f06ec0c" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.68" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" +checksum = "874aa7e446f1da8d9c3a5c95b1c5eb41d800045252121dc7f8e0ba370cee55f5" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index d244e6796..cd5396a81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ serde_json = { version = "1.0", optional = true } sha2 = { version = "0.10.8", optional = true } sha3 = "0.10.8" subtle = "2.6.1" -thiserror = "1.0" +thiserror = "2.0" zipf = { version = "7.0.1", optional = true } [dev-dependencies] diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 4e5c065d8..bf9029188 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -142,15 +142,15 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror]] -version = "1.0.68" -when = "2024-11-04" +version = "2.0.1" +when = "2024-11-08" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror-impl]] -version = "1.0.68" -when = "2024-11-04" +version = "2.0.1" +when = "2024-11-08" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From baf3b8675433d241922fbf6a35dea95ddb95c599 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Nov 2024 15:52:10 +0000 Subject: [PATCH 23/32] build(deps): Bump thiserror from 2.0.1 to 2.0.3 (#1138) --- Cargo.lock | 8 ++++---- supply-chain/imports.lock | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c14302e34..9e599e8da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,18 +983,18 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "thiserror" -version = "2.0.1" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c1e40dd48a282ae8edc36c732cbc219144b87fb6a4c7316d611c6b1f06ec0c" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.1" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874aa7e446f1da8d9c3a5c95b1c5eb41d800045252121dc7f8e0ba370cee55f5" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index bf9029188..ead58bc02 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -142,15 +142,15 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror]] -version = "2.0.1" -when = "2024-11-08" +version = "2.0.3" +when = "2024-11-10" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.thiserror-impl]] -version = "2.0.1" -when = "2024-11-08" +version = "2.0.3" +when = "2024-11-10" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From a8e48e7548b4653ed85fab7bc4e51127611dc6ee Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Tue, 12 Nov 2024 11:06:42 -0500 Subject: [PATCH 24/32] Make FLP Eval Return `Vec` rather than a single `F` (#1132) * Simplified Average to use Sum internally * Made Type::valid return a Vec rather than F * Made Prio3Sum circuit output bits-many elements --------- Co-authored-by: Michael Rosenberg --- src/flp.rs | 108 ++++++++++++----- src/flp/szk.rs | 4 +- src/flp/types.rs | 213 ++++++++++++++------------------- src/flp/types/fixedpoint_l2.rs | 21 ++-- src/vdaf/prio3.rs | 7 -- src/vdaf/prio3_test.rs | 8 ++ 6 files changed, 191 insertions(+), 170 deletions(-) diff --git a/src/flp.rs b/src/flp.rs index ebba717ca..a34f3cf0a 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -156,6 +156,9 @@ pub trait Type: Sized + Eq + Clone + Debug { /// [BBCG+19]: https://ia.cr/2019/188 fn gadget(&self) -> Vec>>; + /// Returns the number of gadgets associated with this validity circuit. This MUST equal `self.gadget().len()`. + fn num_gadgets(&self) -> usize; + /// Evaluates the validity circuit on an input and returns the output. /// /// # Parameters @@ -179,7 +182,7 @@ pub trait Type: Sized + Eq + Clone + Debug { /// let input: Vec = count.encode_measurement(&true).unwrap(); /// let joint_rand = random_vector(count.joint_rand_len()).unwrap(); /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap(); - /// assert_eq!(v, Field64::zero()); + /// assert!(v.into_iter().all(|f| f == Field64::zero())); /// ``` fn valid( &self, @@ -187,7 +190,7 @@ pub trait Type: Sized + Eq + Clone + Debug { input: &[Self::Field], joint_rand: &[Self::Field], num_shares: usize, - ) -> Result; + ) -> Result, FlpError>; /// Constructs an aggregatable output from an encoded input. Calling this method is only safe /// once `input` has been validated. @@ -208,14 +211,25 @@ pub trait Type: Sized + Eq + Clone + Debug { /// The length of the joint random input. fn joint_rand_len(&self) -> usize; + /// The length of the circuit output + fn eval_output_len(&self) -> usize; + /// The length in field elements of the random input consumed by the prover to generate a /// proof. This is the same as the sum of the arity of each gadget in the validity circuit. fn prove_rand_len(&self) -> usize; /// The length in field elements of the random input consumed by the verifier to make queries /// against inputs and proofs. This is the same as the number of gadgets in the validity - /// circuit. - fn query_rand_len(&self) -> usize; + /// circuit, plus the number of elements output by the validity circuit (if >1). + fn query_rand_len(&self) -> usize { + let mut n = self.num_gadgets(); + let eval_elems = self.eval_output_len(); + if eval_elems > 1 { + n += eval_elems; + } + + n + } /// Generate a proof of an input's validity. The return value is a sequence of /// [`Self::proof_len`] field elements. @@ -388,6 +402,24 @@ pub trait Type: Sized + Eq + Clone + Debug { self.query_rand_len() ))); } + // We use query randomness to compress outputs from `valid()` (if size is > 1), as well as + // for gadget evaluations. Split these up + let (query_rand_for_validity, query_rand_for_gadgets) = if self.eval_output_len() > 1 { + query_rand.split_at(self.eval_output_len()) + } else { + query_rand.split_at(0) + }; + + // Another check that we have the right amount of randomness + let my_gadgets = self.gadget(); + if query_rand_for_gadgets.len() != my_gadgets.len() { + return Err(FlpError::Query(format!( + "length of query randomness for gadgets doesn't match number of gadgets: \ + got {}; want {}", + query_rand_for_gadgets.len(), + my_gadgets.len() + ))); + } if joint_rand.len() != self.joint_rand_len() { return Err(FlpError::Query(format!( @@ -398,15 +430,13 @@ pub trait Type: Sized + Eq + Clone + Debug { } let mut proof_len = 0; - let mut shims = self - .gadget() + let mut shims = my_gadgets .into_iter() - .enumerate() - .map(|(idx, gadget)| { + .zip(query_rand_for_gadgets) + .map(|(gadget, &r)| { let gadget_degree = gadget.degree(); let gadget_arity = gadget.arity(); let m = (1 + gadget.calls()).next_power_of_two(); - let r = query_rand[idx]; // Make sure the query randomness isn't a root of unity. Evaluating the gadget // polynomial at any of these points would be a privacy violation, since these points @@ -419,7 +449,7 @@ pub trait Type: Sized + Eq + Clone + Debug { ))); } - // Compute the length of the sub-proof corresponding to the `idx`-th gadget. + // Compute the length of the sub-proof corresponding to this gadget. let next_len = gadget_arity + gadget_degree * (m - 1) + 1; let proof_data = &proof[proof_len..proof_len + next_len]; proof_len += next_len; @@ -444,10 +474,23 @@ pub trait Type: Sized + Eq + Clone + Debug { // should be OK, since it's possible to transform any circuit into one for which this is true. // (Needs security analysis.) let validity = self.valid(&mut shims, input, joint_rand, num_shares)?; - verifier.push(validity); + assert_eq!(validity.len(), self.eval_output_len()); + // If `valid()` outputs multiple field elements, compress them into 1 field element using + // query randomness + let check = if validity.len() > 1 { + validity + .iter() + .zip(query_rand_for_validity) + .fold(Self::Field::zero(), |acc, (&val, &r)| acc + r * val) + } else { + // If `valid()` outputs one field element, just use that. If it outputs none, then it is + // trivially satisfied, so use 0 + validity.first().cloned().unwrap_or(Self::Field::zero()) + }; + verifier.push(check); // Fill the buffer with the verifier message. - for (query_rand_val, shim) in query_rand[..shims.len()].iter().zip(shims.iter_mut()) { + for (query_rand_val, shim) in query_rand_for_gadgets.iter().zip(shims.iter_mut()) { let gadget = shim .as_any() .downcast_ref::>() @@ -836,11 +879,6 @@ pub mod test_utils { let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap(); let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap(); let query_rand = random_vector(self.flp.query_rand_len()).unwrap(); - assert_eq!( - self.flp.query_rand_len(), - gadgets.len(), - "{name}: unexpected number of gadgets" - ); assert_eq!( self.flp.joint_rand_len(), joint_rand.len(), @@ -863,9 +901,9 @@ pub mod test_utils { .valid(&mut gadgets, self.input, &joint_rand, 1) .unwrap(); assert_eq!( - v == T::Field::zero(), + v.iter().all(|f| f == &T::Field::zero()), self.expect_valid, - "{name}: unexpected output of valid() returned {v}", + "{name}: unexpected output of valid() returned {v:?}", ); // Generate the proof. @@ -1056,7 +1094,7 @@ mod tests { input: &[F], joint_rand: &[F], _num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { let r = joint_rand[0]; let mut res = F::zero(); @@ -1071,7 +1109,7 @@ mod tests { let x_checked = g[1].call(&[input[0]])?; res += (r * r) * x_checked; - Ok(res) + Ok(vec![res]) } fn input_len(&self) -> usize { @@ -1108,12 +1146,12 @@ mod tests { 1 } - fn prove_rand_len(&self) -> usize { - 3 + fn eval_output_len(&self) -> usize { + 1 } - fn query_rand_len(&self) -> usize { - 2 + fn prove_rand_len(&self) -> usize { + 3 } fn gadget(&self) -> Vec>> { @@ -1123,6 +1161,10 @@ mod tests { ] } + fn num_gadgets(&self) -> usize { + 2 + } + fn encode_measurement(&self, measurement: &F::Integer) -> Result, FlpError> { Ok(vec![ F::from(*measurement), @@ -1190,7 +1232,7 @@ mod tests { input: &[F], _joint_rand: &[F], _num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { // This is a useless circuit, as it only accepts "0". Its purpose is to exercise the // use of multiple gadgets, each of which is called an arbitrary number of times. let mut res = F::zero(); @@ -1200,7 +1242,7 @@ mod tests { for _ in 0..self.num_gadget_calls[1] { res += g[1].call(&[input[0]])?; } - Ok(res) + Ok(vec![res]) } fn input_len(&self) -> usize { @@ -1237,6 +1279,10 @@ mod tests { 0 } + fn eval_output_len(&self) -> usize { + 1 + } + fn prove_rand_len(&self) -> usize { // First chunk let first = 1; // gadget arity @@ -1247,10 +1293,6 @@ mod tests { first + second } - fn query_rand_len(&self) -> usize { - 2 // number of gadgets - } - fn gadget(&self) -> Vec>> { let poly = poly_range_check(0, 2); // A polynomial with degree 2 vec![ @@ -1259,6 +1301,10 @@ mod tests { ] } + fn num_gadgets(&self) -> usize { + 2 + } + fn encode_measurement(&self, measurement: &F::Integer) -> Result, FlpError> { Ok(vec![F::from(*measurement)]) } diff --git a/src/flp/szk.rs b/src/flp/szk.rs index ef504204f..e25598d0f 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -904,7 +904,7 @@ mod tests { let szk_typ = Szk::new_turboshake128(sum, algorithm_id); let prove_rand_seed = Seed::<16>::generate().unwrap(); let helper_seed = Seed::<16>::generate().unwrap(); - let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let leader_seed_opt = None; let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); let mut leader_input_share = encoded_measurement.clone().to_owned(); for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { @@ -944,7 +944,7 @@ mod tests { let szk_typ = Szk::new_turboshake128(sum, algorithm_id); let prove_rand_seed = Seed::<16>::generate().unwrap(); let helper_seed = Seed::<16>::generate().unwrap(); - let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let leader_seed_opt = None; let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); let mut leader_input_share = encoded_measurement.clone().to_owned(); for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { diff --git a/src/flp/types.rs b/src/flp/types.rs index b294e3d8d..9403039ef 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -63,15 +63,20 @@ impl Type for Count { vec![Box::new(Mul::new(1))] } + fn num_gadgets(&self) -> usize { + 1 + } + fn valid( &self, g: &mut Vec>>, input: &[F], joint_rand: &[F], _num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; - Ok(g[0].call(&[input[0], input[0]])? - input[0]) + let out = g[0].call(&[input[0], input[0]])? - input[0]; + Ok(vec![out]) } fn truncate(&self, input: Vec) -> Result, FlpError> { @@ -99,12 +104,12 @@ impl Type for Count { 0 } - fn prove_rand_len(&self) -> usize { - 2 + fn eval_output_len(&self) -> usize { + 1 } - fn query_rand_len(&self) -> usize { - 1 + fn prove_rand_len(&self) -> usize { + 2 } } @@ -164,15 +169,20 @@ impl Type for Sum { ))] } + fn num_gadgets(&self) -> usize { + 1 + } + fn valid( &self, g: &mut Vec>>, input: &[F], joint_rand: &[F], _num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; - call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + let gadget = &mut g[0]; + input.iter().map(|&b| gadget.call(&[b])).collect() } fn truncate(&self, input: Vec) -> Result, FlpError> { @@ -198,29 +208,32 @@ impl Type for Sum { } fn joint_rand_len(&self) -> usize { - 1 + 0 } - fn prove_rand_len(&self) -> usize { - 1 + fn eval_output_len(&self) -> usize { + self.bits } - fn query_rand_len(&self) -> usize { + fn prove_rand_len(&self) -> usize { 1 } } /// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the /// aggregate is the arithmetic average. +// This is just a `Sum` object under the hood. The only difference is that the aggregate result is +// an f64, which we get by dividing by `num_measurements` #[derive(Clone, PartialEq, Eq)] pub struct Average { - bits: usize, - range_checker: Vec, + summer: Sum, } impl Debug for Average { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Average").field("bits", &self.bits).finish() + f.debug_struct("Average") + .field("bits", &self.summer.bits) + .finish() } } @@ -228,16 +241,8 @@ impl Average { /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, /// 2^bits)`. pub fn new(bits: usize) -> Result { - if !F::valid_integer_bitlength(bits) { - return Err(FlpError::Encode( - "invalid bits: number of bits exceeds maximum number of bits in this field" - .to_string(), - )); - } - Ok(Self { - bits, - range_checker: poly_range_check(0, 2), - }) + let summer = Sum::new(bits)?; + Ok(Average { summer }) } } @@ -247,25 +252,25 @@ impl Type for Average { type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { - let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); - Ok(v) + self.summer.encode_measurement(summand) } fn decode_result(&self, data: &[F], num_measurements: usize) -> Result { // Compute the average from the aggregated sum. - let data = decode_result(data)?; - let data: u64 = data.try_into().map_err(|err| { - FlpError::Decode(format!("failed to convert {data:?} to u64: {err}",)) - })?; + let sum = self.summer.decode_result(data, num_measurements)?; + let data: u64 = sum + .try_into() + .map_err(|err| FlpError::Decode(format!("failed to convert {sum:?} to u64: {err}",)))?; let result = (data as f64) / (num_measurements as f64); Ok(result) } fn gadget(&self) -> Vec>> { - vec![Box::new(PolyEval::new( - self.range_checker.clone(), - self.bits, - ))] + self.summer.gadget() + } + + fn num_gadgets(&self) -> usize { + self.summer.num_gadgets() } fn valid( @@ -273,44 +278,41 @@ impl Type for Average { g: &mut Vec>>, input: &[F], joint_rand: &[F], - _num_shares: usize, - ) -> Result { - self.valid_call_check(input, joint_rand)?; - call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + num_shares: usize, + ) -> Result, FlpError> { + self.summer.valid(g, input, joint_rand, num_shares) } fn truncate(&self, input: Vec) -> Result, FlpError> { - self.truncate_call_check(&input)?; - let res = F::decode_bitvector(&input)?; - Ok(vec![res]) + self.summer.truncate(input) } fn input_len(&self) -> usize { - self.bits + self.summer.bits } fn proof_len(&self) -> usize { - 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + self.summer.proof_len() } fn verifier_len(&self) -> usize { - 3 + self.summer.verifier_len() } fn output_len(&self) -> usize { - 1 + self.summer.output_len() } fn joint_rand_len(&self) -> usize { - 1 + self.summer.joint_rand_len() } - fn prove_rand_len(&self) -> usize { - 1 + fn eval_output_len(&self) -> usize { + self.summer.eval_output_len() } - fn query_rand_len(&self) -> usize { - 1 + fn prove_rand_len(&self) -> usize { + self.summer.prove_rand_len() } } @@ -408,23 +410,22 @@ where ))] } + fn num_gadgets(&self) -> usize { + 1 + } + fn valid( &self, g: &mut Vec>>, input: &[F], joint_rand: &[F], num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; // Check that each element of `input` is a 0 or 1. - let range_check = parallel_sum_range_checks( - &mut g[0], - input, - joint_rand[0], - self.chunk_length, - num_shares, - )?; + let range_check = + parallel_sum_range_checks(&mut g[0], input, joint_rand, self.chunk_length, num_shares)?; // Check that the elements of `input` sum to 1. let mut sum_check = -F::from(F::valid_integer_try_from(num_shares)?).inv(); @@ -432,9 +433,7 @@ where sum_check += *val; } - // Take a random linear combination of both checks. - let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check; - Ok(out) + Ok(vec![range_check, sum_check]) } fn truncate(&self, input: Vec) -> Result, FlpError> { @@ -459,16 +458,16 @@ where } fn joint_rand_len(&self) -> usize { + self.gadget_calls + } + + fn eval_output_len(&self) -> usize { 2 } fn prove_rand_len(&self) -> usize { self.chunk_length * 2 } - - fn query_rand_len(&self) -> usize { - 1 - } } /// The multihot counter data type. Each measurement is a list of booleans of length `length`, with @@ -625,23 +624,22 @@ where ))] } + fn num_gadgets(&self) -> usize { + 1 + } + fn valid( &self, g: &mut Vec>>, input: &[F], joint_rand: &[F], num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; // Check that each element of `input` is a 0 or 1. - let range_check = parallel_sum_range_checks( - &mut g[0], - input, - joint_rand[0], - self.chunk_length, - num_shares, - )?; + let range_check = + parallel_sum_range_checks(&mut g[0], input, joint_rand, self.chunk_length, num_shares)?; // Check that the elements of `input` sum to at most `max_weight`. let count_vec = &input[..self.length]; @@ -655,9 +653,7 @@ where offset * shares_inv + weight - offset_weight_reported }; - // Take a random linear combination of both checks. - let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * weight_check; - Ok(out) + Ok(vec![range_check, weight_check]) } // Truncates the measurement, removing extra data that was necessary for validity (here, the @@ -688,18 +684,16 @@ where // The number of random values needed in the validity checks fn joint_rand_len(&self) -> usize { + self.gadget_calls + } + + fn eval_output_len(&self) -> usize { 2 } fn prove_rand_len(&self) -> usize { self.chunk_length * 2 } - - fn query_rand_len(&self) -> usize { - // TODO: this will need to be increase once draft-10 is implemented and more randomness is - // necessary due to random linear combination computations - 1 - } } /// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19], @@ -848,22 +842,21 @@ where ))] } + fn num_gadgets(&self) -> usize { + 1 + } + fn valid( &self, g: &mut Vec>>, input: &[F], joint_rand: &[F], num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; - parallel_sum_range_checks( - &mut g[0], - input, - joint_rand[0], - self.chunk_length, - num_shares, - ) + parallel_sum_range_checks(&mut g[0], input, joint_rand, self.chunk_length, num_shares) + .map(|out| vec![out]) } fn truncate(&self, input: Vec) -> Result, FlpError> { @@ -892,37 +885,16 @@ where } fn joint_rand_len(&self) -> usize { - 1 - } - - fn prove_rand_len(&self) -> usize { - self.chunk_length * 2 + self.gadget_calls } - fn query_rand_len(&self) -> usize { + fn eval_output_len(&self) -> usize { 1 } -} -/// Compute a random linear combination of the result of calls of `g` on each element of `input`. -/// -/// # Arguments -/// -/// * `g` - The gadget to be applied elementwise -/// * `input` - The vector on whose elements to apply `g` -/// * `rnd` - The randomness used for the linear combination -pub(crate) fn call_gadget_on_vec_entries( - g: &mut Box>, - input: &[F], - rnd: F, -) -> Result { - let mut comb = F::zero(); - let mut r = rnd; - for chunk in input.chunks(1) { - comb += r * g.call(chunk)?; - r *= rnd; + fn prove_rand_len(&self) -> usize { + self.chunk_length * 2 } - Ok(comb) } /// Given a vector `data` of field elements which should contain exactly one entry, return the @@ -970,7 +942,7 @@ pub(crate) fn decode_result_vec( pub(crate) fn parallel_sum_range_checks( gadget: &mut Box>, input: &[F], - joint_randomness: F, + joint_randomness: &[F], chunk_length: usize, num_shares: usize, ) -> Result { @@ -978,15 +950,16 @@ pub(crate) fn parallel_sum_range_checks( let num_shares_inverse = f_num_shares.inv(); let mut output = F::zero(); - let mut r_power = joint_randomness; let mut padded_chunk = vec![F::zero(); 2 * chunk_length]; - for chunk in input.chunks(chunk_length) { + for (chunk, &r) in input.chunks(chunk_length).zip(joint_randomness) { + let mut r_power = r; + // Construct arguments for the Mul subcircuits. for (input, args) in chunk.iter().zip(padded_chunk.chunks_exact_mut(2)) { args[0] = r_power * *input; args[1] = *input - num_shares_inverse; - r_power *= joint_randomness; + r_power *= r; } // If the chunk of the input is smaller than chunk_length, use zeros instead of measurement // inputs for the remaining calls. diff --git a/src/flp/types/fixedpoint_l2.rs b/src/flp/types/fixedpoint_l2.rs index fbedb9321..f17559875 100644 --- a/src/flp/types/fixedpoint_l2.rs +++ b/src/flp/types/fixedpoint_l2.rs @@ -462,13 +462,17 @@ where vec![Box::new(gadget0), Box::new(gadget1)] } + fn num_gadgets(&self) -> usize { + 2 + } + fn valid( &self, g: &mut Vec>>, input: &[Field128], joint_rand: &[Field128], num_shares: usize, - ) -> Result { + ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; let f_num_shares = Field128::from(Field128::valid_integer_try_from::(num_shares)?); @@ -491,7 +495,7 @@ where let range_check = parallel_sum_range_checks( &mut g[0], &input[..self.range_norm_end], - joint_rand[0], + joint_rand, self.gadget0_chunk_length, num_shares, )?; @@ -550,10 +554,7 @@ where let norm_check = computed_norm - submitted_norm; - // Finally, we require both checks to be successful by computing a - // random linear combination of them. - let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * norm_check; - Ok(out) + Ok(vec![range_check, norm_check]) } fn truncate(&self, input: Vec) -> Result, FlpError> { @@ -598,16 +599,16 @@ where } fn joint_rand_len(&self) -> usize { + self.gadget0_calls + } + + fn eval_output_len(&self) -> usize { 2 } fn prove_rand_len(&self) -> usize { self.gadget0_chunk_length * 2 + self.gadget1_chunk_length } - - fn query_rand_len(&self) -> usize { - 2 - } } impl TypeWithNoise diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 02e30e314..f6e42d972 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1664,11 +1664,6 @@ mod tests { thread_rng().fill(&mut verify_key[..]); let nonce = [0; 16]; - let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); - input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255; - let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); - assert_matches!(result, Err(VdafError::Uncategorized(_))); - let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); @@ -2007,8 +2002,6 @@ mod tests { { assert_ne!(left, right); } - - assert_ne!(x.joint_rand_blind, y.joint_rand_blind); } } } diff --git a/src/vdaf/prio3_test.rs b/src/vdaf/prio3_test.rs index 479c6b4c3..e7627be72 100644 --- a/src/vdaf/prio3_test.rs +++ b/src/vdaf/prio3_test.rs @@ -263,6 +263,11 @@ mod tests { } } + // All the below tests are not passing. We ignore them until the rest of the repo is in a state + // where we can regenerate the JSON test vectors. + // Tracking issue https://github.com/divviup/libprio-rs/issues/1122 + + #[ignore] #[test] fn test_vec_prio3_sum() { for test_vector_str in [ @@ -276,6 +281,7 @@ mod tests { } } + #[ignore] #[test] fn test_vec_prio3_sum_vec() { for test_vector_str in [ @@ -291,6 +297,7 @@ mod tests { } } + #[ignore] #[test] fn test_vec_prio3_sum_vec_multiproof() { type Prio3SumVecField64Multiproof = @@ -314,6 +321,7 @@ mod tests { } } + #[ignore] #[test] fn test_vec_prio3_histogram() { for test_vector_str in [ From d5e9b15a004b485226195dea18099d64367b24c5 Mon Sep 17 00:00:00 2001 From: Ameer Ghani Date: Thu, 14 Nov 2024 11:12:14 -0600 Subject: [PATCH 25/32] cargo vet: zlib-rs audit (#1140) --- supply-chain/audits.toml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index 0e8db2215..cd8aecfac 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -398,6 +398,16 @@ who = "Brandon Pitman " criteria = "safe-to-deploy" delta = "0.2.149 -> 0.2.150" +[[audits.libz-rs-sys]] +who = "Ameer Ghani " +criteria = "safe-to-deploy" +version = "0.4.0" +notes = """ +This crate uses unsafe since it's for C to Rust FFI. I have reviewed and fuzzed it, and I believe it is free of any serious security problems. + +The only dependency is zlib-rs, which is maintained by the same maintainers as this crate. +""" + [[audits.linux-raw-sys]] who = "Brandon Pitman " criteria = "safe-to-run" @@ -839,6 +849,16 @@ who = "Tim Geoghegan " criteria = "safe-to-run" delta = "7.0.0 -> 7.0.1" +[[audits.zlib-rs]] +who = "Ameer Ghani " +criteria = "safe-to-deploy" +version = "0.4.0" +notes = """ +zlib-rs uses unsafe Rust for invoking compiler intrinsics (i.e. SIMD), eschewing bounds checks, along the FFI boundary, and for interacting with pointers sourced from C. I have extensively reviewed and fuzzed the unsafe code. All findings from that work have been resolved as of version 0.4.0. To the best of my ability, I believe it's free of any serious security problems. + +zlib-rs does not require any external dependencies. +""" + [[trusted.byteorder]] criteria = "safe-to-deploy" user-id = 189 # Andrew Gallant (BurntSushi) From d57731b5b3774e8af2b4a28ea5863eeb0ff02b22 Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Thu, 14 Nov 2024 19:39:29 -0500 Subject: [PATCH 26/32] Added `Aggregator::is_agg_param_valid` method (#1139) * Added Aggregator::is_agg_param_valid method, representing the is_valid method in the spec Co-authored-by: Michael Rosenberg Co-authored-by: David Cook --- src/vdaf.rs | 5 +++ src/vdaf/dummy.rs | 4 +++ src/vdaf/poplar1.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++ src/vdaf/prio2.rs | 5 +++ src/vdaf/prio3.rs | 5 +++ 5 files changed, 104 insertions(+) diff --git a/src/vdaf.rs b/src/vdaf.rs index 8dc8b2fe1..2c68f2e40 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -298,6 +298,11 @@ pub trait Aggregator: Vda agg_param: &Self::AggregationParam, output_shares: M, ) -> Result; + + /// Validates an aggregation parameter with respect to all previous aggregaiton parameters used + /// for the same input share. `prev` MUST be sorted from least to most recently used. + #[must_use] + fn is_agg_param_valid(cur: &Self::AggregationParam, prev: &[Self::AggregationParam]) -> bool; } /// Aggregator that implements differential privacy with Aggregator-side noise addition. diff --git a/src/vdaf/dummy.rs b/src/vdaf/dummy.rs index 6903cd547..5b969bc19 100644 --- a/src/vdaf/dummy.rs +++ b/src/vdaf/dummy.rs @@ -166,6 +166,10 @@ impl vdaf::Aggregator<0, 16> for Vdaf { } Ok(aggregate_share) } + + fn is_agg_param_valid(_cur: &Self::AggregationParam, _prev: &[Self::AggregationParam]) -> bool { + true + } } impl vdaf::Client<16> for Vdaf { diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 514ed3906..1d396e4da 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -17,6 +17,7 @@ use crate::{ use bitvec::{prelude::Lsb0, vec::BitVec}; use rand_core::RngCore; use std::{ + collections::BTreeSet, convert::TryFrom, fmt::Debug, io::{Cursor, Read}, @@ -1245,6 +1246,39 @@ impl, const SEED_SIZE: usize> Aggregator output_shares, ) } + + /// Validates that no aggregation parameter with the same level as `cur` has been used with the + /// same input share before. `prev` contains the aggregation parameters used for the same input. + /// `prev` MUST be sorted from least to most recently used. + fn is_agg_param_valid(cur: &Poplar1AggregationParam, prev: &[Poplar1AggregationParam]) -> bool { + // Exit early if there are no previous aggregation params to compare to, i.e., this is the + // first time the input share has been processed + if prev.is_empty() { + return true; + } + + // Unpack this agg param and the last one in the list + let Poplar1AggregationParam { + level: cur_level, + prefixes: cur_prefixes, + } = cur; + let Poplar1AggregationParam { + level: last_level, + prefixes: last_prefixes, + } = prev.last().as_ref().unwrap(); + let last_prefixes_set = BTreeSet::from_iter(last_prefixes); + + // Check that the level increased. + if cur_level <= last_level { + return false; + } + + // Check that current prefixes are extensions of the last level's prefixes. + cur_prefixes.iter().all(|cur_prefix| { + let last_prefix = cur_prefix.prefix(*last_level as usize); + last_prefixes_set.contains(&last_prefix) + }) + } } impl, const SEED_SIZE: usize> Collector for Poplar1 { @@ -1979,6 +2013,57 @@ mod tests { assert_matches!(err, CodecError::Other(_)); } + // Tests Poplar1::is_valid() functionality. This unit test is translated from + // https://github.com/cfrg/draft-irtf-cfrg-vdaf/blob/a4874547794818573acd8734874c9784043b1140/poc/tests/test_vdaf_poplar1.py#L187 + #[test] + fn agg_param_validity() { + // The actual Poplar instance doesn't matter for the parameter validity tests + type V = Poplar1; + + // Helper function for making aggregation params + fn make_agg_param(bitstrings: &[&[u8]]) -> Result { + Poplar1AggregationParam::try_from_prefixes( + bitstrings + .iter() + .map(|v| { + let bools = v.iter().map(|&b| b != 0).collect::>(); + IdpfInput::from_bools(&bools) + }) + .collect(), + ) + } + + // Test `is_valid` returns False on repeated levels, and True otherwise. + let agg_params = [ + make_agg_param(&[&[0], &[1]]).unwrap(), + make_agg_param(&[&[0, 0]]).unwrap(), + make_agg_param(&[&[0, 0], &[1, 0]]).unwrap(), + ]; + assert!(V::is_agg_param_valid(&agg_params[0], &[])); + assert!(V::is_agg_param_valid(&agg_params[1], &agg_params[..1])); + assert!(!V::is_agg_param_valid(&agg_params[2], &agg_params[..2])); + + // Test `is_valid` accepts level jumps. + let agg_params = [ + make_agg_param(&[&[0], &[1]]).unwrap(), + make_agg_param(&[&[0, 1, 0], &[0, 1, 1], &[1, 0, 1], &[1, 1, 1]]).unwrap(), + ]; + assert!(V::is_agg_param_valid(&agg_params[1], &agg_params[..1])); + + // Test `is_valid` rejects unconnected prefixes. + let agg_params = [ + make_agg_param(&[&[0]]).unwrap(), + make_agg_param(&[&[0, 1, 0], &[0, 1, 1], &[1, 0, 1], &[1, 1, 1]]).unwrap(), + ]; + assert!(!V::is_agg_param_valid(&agg_params[1], &agg_params[..1])); + + // Test that the `Poplar1AggregationParam` constructor rejects unsorted and duplicate + // prefixes. + assert!(make_agg_param(&[&[1], &[0]]).is_err()); + assert!(make_agg_param(&[&[1, 0, 0], &[0, 1, 1]]).is_err()); + assert!(make_agg_param(&[&[0, 0, 0], &[0, 1, 0], &[0, 1, 0]]).is_err()); + } + #[derive(Debug, Deserialize)] struct HexEncoded(#[serde(with = "hex")] Vec); diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index ba725d90d..680f09ea7 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -325,6 +325,11 @@ impl Aggregator<32, 16> for Prio2 { Ok(agg_share) } + + /// Returns `true` iff `prev.is_empty()` + fn is_agg_param_valid(_cur: &Self::AggregationParam, prev: &[Self::AggregationParam]) -> bool { + prev.is_empty() + } } impl Collector for Prio2 { diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index f6e42d972..f0d482dea 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1441,6 +1441,11 @@ where Ok(agg_share) } + + /// Returns `true` iff `prev.is_empty()` + fn is_agg_param_valid(_cur: &Self::AggregationParam, prev: &[Self::AggregationParam]) -> bool { + prev.is_empty() + } } #[cfg(feature = "experimental")] From c85e537682a7932edc0c44c80049df0429d6fa4c Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Wed, 20 Nov 2024 13:05:37 -0500 Subject: [PATCH 27/32] Add context string to VDAF (#1145) * Added ctx argument for sharding methods * Added ctx to domain_separation_tag, and propagated changes from there Co-authored-by: Michael Rosenberg --- benches/cycle_counts.rs | 29 ++- benches/speed_tests.rs | 119 +++++++---- binaries/src/bin/vdaf_message_sizes.rs | 29 ++- src/topology/ping_pong.rs | 57 ++++-- src/vdaf.rs | 35 +++- src/vdaf/dummy.rs | 8 +- src/vdaf/mastic.rs | 24 ++- src/vdaf/poplar1.rs | 52 ++++- src/vdaf/prio2.rs | 28 ++- src/vdaf/prio2/server.rs | 3 +- src/vdaf/prio3.rs | 270 ++++++++++++++++++------- src/vdaf/prio3_test.rs | 33 ++- 12 files changed, 501 insertions(+), 186 deletions(-) diff --git a/benches/cycle_counts.rs b/benches/cycle_counts.rs index 5ab704cd4..2f3ede7a8 100644 --- a/benches/cycle_counts.rs +++ b/benches/cycle_counts.rs @@ -47,7 +47,10 @@ fn prio2_client(size: usize) -> Vec> { let prio2 = Prio2::new(size).unwrap(); let input = vec![0u32; size]; let nonce = [0; 16]; - prio2.shard(&black_box(input), &black_box(nonce)).unwrap().1 + prio2 + .shard(b"", &black_box(input), &black_box(nonce)) + .unwrap() + .1 } #[cfg(feature = "experimental")] @@ -70,9 +73,19 @@ fn prio2_shard_and_prepare(size: usize) -> Prio2PrepareShare { let prio2 = Prio2::new(size).unwrap(); let input = vec![0u32; size]; let nonce = [0; 16]; - let (public_share, input_shares) = prio2.shard(&black_box(input), &black_box(nonce)).unwrap(); + let (public_share, input_shares) = prio2 + .shard(b"", &black_box(input), &black_box(nonce)) + .unwrap(); prio2 - .prepare_init(&[0; 32], 0, &(), &nonce, &public_share, &input_shares[0]) + .prepare_init( + &[0; 32], + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) .unwrap() .1 } @@ -97,7 +110,7 @@ fn prio3_client_count() -> Vec> { let measurement = true; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -107,7 +120,7 @@ fn prio3_client_histogram_10() -> Vec> { let measurement = 9; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -117,7 +130,7 @@ fn prio3_client_sum_32() -> Vec> { let measurement = 1337; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -128,7 +141,7 @@ fn prio3_client_count_vec_1000() -> Vec> { let measurement = vec![0; len]; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -140,7 +153,7 @@ fn prio3_client_count_vec_multithreaded_1000() -> Vec>(); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -145,10 +145,18 @@ fn prio2(c: &mut Criterion) { .collect::>(); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 32]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap(); + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap(); }); }, ); @@ -164,7 +172,7 @@ fn prio3(c: &mut Criterion) { let vdaf = Prio3::new_count(num_shares).unwrap(); let measurement = black_box(true); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }); c.bench_function("prio3count_prepare_init", |b| { @@ -172,10 +180,18 @@ fn prio3(c: &mut Criterion) { let measurement = black_box(true); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }); @@ -185,7 +201,7 @@ fn prio3(c: &mut Criterion) { let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); let measurement = (1 << bits) - 1; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }); } group.finish(); @@ -197,10 +213,18 @@ fn prio3(c: &mut Criterion) { let measurement = (1 << bits) - 1; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }); } @@ -217,7 +241,7 @@ fn prio3(c: &mut Criterion) { .map(|i| i & 1) .collect::>(); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -240,7 +264,7 @@ fn prio3(c: &mut Criterion) { .map(|i| i & 1) .collect::>(); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -259,10 +283,18 @@ fn prio3(c: &mut Criterion) { .collect::>(); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }, ); @@ -287,10 +319,12 @@ fn prio3(c: &mut Criterion) { .collect::>(); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -323,7 +357,7 @@ fn prio3(c: &mut Criterion) { let vdaf = Prio3::new_histogram(num_shares, *input_length, *chunk_length).unwrap(); let measurement = black_box(0); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -352,7 +386,7 @@ fn prio3(c: &mut Criterion) { .unwrap(); let measurement = black_box(0); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -378,10 +412,18 @@ fn prio3(c: &mut Criterion) { let measurement = black_box(0); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }, ); @@ -412,10 +454,12 @@ fn prio3(c: &mut Criterion) { let measurement = black_box(0); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -448,7 +492,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP16_ZERO; *dimension]; measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -468,7 +512,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP16_ZERO; *dimension]; measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -487,10 +531,12 @@ fn prio3(c: &mut Criterion) { measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -520,10 +566,11 @@ fn prio3(c: &mut Criterion) { let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = - vdaf.shard(&measurement, &nonce).unwrap(); + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -549,7 +596,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP32_ZERO; *dimension]; measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -569,7 +616,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP32_ZERO; *dimension]; measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -588,10 +635,12 @@ fn prio3(c: &mut Criterion) { measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -621,10 +670,11 @@ fn prio3(c: &mut Criterion) { let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = - vdaf.shard(&measurement, &nonce).unwrap(); + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -724,7 +774,7 @@ fn poplar1(c: &mut Criterion) { let measurement = IdpfInput::from_bools(&bits); b.iter(|| { - vdaf.shard(&measurement, &nonce).unwrap(); + vdaf.shard(b"", &measurement, &nonce).unwrap(); }); }); } @@ -753,7 +803,7 @@ fn poplar1(c: &mut Criterion) { // We are benchmarking preparation of a single report. For this test, it doesn't matter // which measurement we generate a report for, so pick the first measurement // arbitrarily. - let (public_share, input_shares) = vdaf.shard(&measurements[0], &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurements[0], &nonce).unwrap(); let input_share = input_shares.into_iter().next().unwrap(); // For the aggregation paramter, we use the candidate prefixes from the prefix tree for @@ -765,6 +815,7 @@ fn poplar1(c: &mut Criterion) { b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &agg_param, &nonce, diff --git a/binaries/src/bin/vdaf_message_sizes.rs b/binaries/src/bin/vdaf_message_sizes.rs index dae75685a..998f15722 100644 --- a/binaries/src/bin/vdaf_message_sizes.rs +++ b/binaries/src/bin/vdaf_message_sizes.rs @@ -15,6 +15,9 @@ use prio::{ }, }; +const PRIO2_CTX_STR: &[u8] = b"prio2 ctx"; +const PRIO3_CTX_STR: &[u8] = b"prio3 ctx"; + fn main() { let num_shares = 2; let nonce = [0; 16]; @@ -23,7 +26,9 @@ fn main() { let measurement = true; println!( "prio3 count share size = {}", - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let length = 10; @@ -32,7 +37,9 @@ fn main() { println!( "prio3 histogram ({} buckets) share size = {}", length, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let bits = 32; @@ -41,7 +48,9 @@ fn main() { println!( "prio3 sum ({} bits) share size = {}", bits, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let len = 1000; @@ -50,7 +59,9 @@ fn main() { println!( "prio3 sumvec ({} len) share size = {}", len, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let len = 1000; @@ -61,7 +72,7 @@ fn main() { "prio3 fixedpoint16 boundedl2 vec ({} entries) size = {}", len, vdaf_input_share_size::>, 16>( - prio3.shard(&measurement, &nonce).unwrap() + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() ) ); @@ -74,7 +85,9 @@ fn main() { println!( "prio2 ({} entries) size = {}", size, - vdaf_input_share_size::(prio2.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio2.shard(PRIO2_CTX_STR, &measurement, &nonce).unwrap() + ) ); // Prio3 @@ -83,7 +96,9 @@ fn main() { println!( "prio3 sumvec ({} entries) size = {}", size, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); } } diff --git a/src/topology/ping_pong.rs b/src/topology/ping_pong.rs index 646f18186..b3de2fe5d 100644 --- a/src/topology/ping_pong.rs +++ b/src/topology/ping_pong.rs @@ -206,6 +206,7 @@ impl< #[allow(clippy::type_complexity)] pub fn evaluate( &self, + ctx: &[u8], vdaf: &A, ) -> Result< ( @@ -220,6 +221,7 @@ impl< .map_err(PingPongError::CodecPrepMessage)?; vdaf.prepare_next( + ctx, self.previous_prepare_state.clone(), self.current_prepare_message.clone(), ) @@ -362,6 +364,7 @@ pub trait PingPongTopology Result<(Self::State, PingPongMessage), PingPongError> { self.prepare_init( verify_key, + ctx, /* Leader */ 0, agg_param, nonce, @@ -522,6 +532,7 @@ where fn helper_initialized( &self, verify_key: &[u8; VERIFY_KEY_SIZE], + ctx: &[u8], agg_param: &Self::AggregationParam, nonce: &[u8; NONCE_SIZE], public_share: &Self::PublicShare, @@ -531,6 +542,7 @@ where let (prep_state, prep_share) = self .prepare_init( verify_key, + ctx, /* Helper */ 1, agg_param, nonce, @@ -550,7 +562,7 @@ where }; let current_prepare_message = self - .prepare_shares_to_prepare_message(agg_param, [inbound_prep_share, prep_share]) + .prepare_shares_to_prepare_message(ctx, agg_param, [inbound_prep_share, prep_share]) .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?; Ok(PingPongTransition { @@ -561,20 +573,22 @@ where fn leader_continued( &self, + ctx: &[u8], leader_state: Self::State, agg_param: &Self::AggregationParam, inbound: &PingPongMessage, ) -> Result { - self.continued(true, leader_state, agg_param, inbound) + self.continued(ctx, true, leader_state, agg_param, inbound) } fn helper_continued( &self, + ctx: &[u8], helper_state: Self::State, agg_param: &Self::AggregationParam, inbound: &PingPongMessage, ) -> Result { - self.continued(false, helper_state, agg_param, inbound) + self.continued(ctx, false, helper_state, agg_param, inbound) } } @@ -585,6 +599,7 @@ where { fn continued( &self, + ctx: &[u8], is_leader: bool, host_state: Self::State, agg_param: &Self::AggregationParam, @@ -616,7 +631,7 @@ where let prep_msg = Self::PrepareMessage::get_decoded_with_param(&host_prep_state, prep_msg) .map_err(PingPongError::CodecPrepMessage)?; let host_prep_transition = self - .prepare_next(host_prep_state, prep_msg) + .prepare_next(ctx, host_prep_state, prep_msg) .map_err(PingPongError::VdafPrepareNext)?; match (host_prep_transition, next_peer_prep_share) { @@ -634,7 +649,7 @@ where prep_shares.reverse(); } let current_prepare_message = self - .prepare_shares_to_prepare_message(agg_param, prep_shares) + .prepare_shares_to_prepare_message(ctx, agg_param, prep_shares) .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?; Ok(PingPongContinuedValue::WithMessage { @@ -667,6 +682,8 @@ mod tests { use crate::vdaf::dummy; use assert_matches::assert_matches; + const CTX_STR: &[u8] = b"pingpong ctx"; + #[test] fn ping_pong_one_round() { let verify_key = []; @@ -683,6 +700,7 @@ mod tests { let (leader_state, leader_message) = leader .leader_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -694,6 +712,7 @@ mod tests { let (helper_state, helper_message) = helper .helper_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -701,14 +720,14 @@ mod tests { &leader_message, ) .unwrap() - .evaluate(&helper) + .evaluate(CTX_STR, &helper) .unwrap(); // 1 round VDAF: helper should finish immediately. assert_matches!(helper_state, PingPongState::Finished(_)); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 1 round VDAF: leader should finish when it gets helper message and emit no message. assert_matches!( @@ -733,6 +752,7 @@ mod tests { let (leader_state, leader_message) = leader .leader_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -744,6 +764,7 @@ mod tests { let (helper_state, helper_message) = helper .helper_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -751,26 +772,26 @@ mod tests { &leader_message, ) .unwrap() - .evaluate(&helper) + .evaluate(CTX_STR, &helper) .unwrap(); // 2 round VDAF, round 1: helper should continue. assert_matches!(helper_state, PingPongState::Continued(_)); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 2 round VDAF, round 1: leader should finish and emit a finish message. let leader_message = assert_matches!( leader_state, PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.evaluate(&leader).unwrap(); + let (state, message) = transition.evaluate(CTX_STR,&leader).unwrap(); assert_matches!(state, PingPongState::Finished(_)); message } ); let helper_state = helper - .helper_continued(helper_state, &aggregation_param, &leader_message) + .helper_continued(CTX_STR, helper_state, &aggregation_param, &leader_message) .unwrap(); // 2 round vdaf, round 1: helper should finish and emit no message. assert_matches!( @@ -795,6 +816,7 @@ mod tests { let (leader_state, leader_message) = leader .leader_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -806,6 +828,7 @@ mod tests { let (helper_state, helper_message) = helper .helper_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -813,38 +836,38 @@ mod tests { &leader_message, ) .unwrap() - .evaluate(&helper) + .evaluate(CTX_STR, &helper) .unwrap(); // 3 round VDAF, round 1: helper should continue. assert_matches!(helper_state, PingPongState::Continued(_)); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 3 round VDAF, round 1: leader should continue and emit a continue message. let (leader_state, leader_message) = assert_matches!( leader_state, PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.evaluate(&leader).unwrap(); + let (state, message) = transition.evaluate(CTX_STR,&leader).unwrap(); assert_matches!(state, PingPongState::Continued(_)); (state, message) } ); let helper_state = helper - .helper_continued(helper_state, &aggregation_param, &leader_message) + .helper_continued(CTX_STR, helper_state, &aggregation_param, &leader_message) .unwrap(); // 3 round vdaf, round 2: helper should finish and emit a finish message. let helper_message = assert_matches!( helper_state, PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.evaluate(&helper).unwrap(); + let (state, message) = transition.evaluate(CTX_STR,&helper).unwrap(); assert_matches!(state, PingPongState::Finished(_)); message } ); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 3 round VDAF, round 2: leader should finish and emit no message. assert_matches!( diff --git a/src/vdaf.rs b/src/vdaf.rs index 2c68f2e40..815836430 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -200,12 +200,16 @@ pub trait Vdaf: Clone + Debug { /// Generate the domain separation tag for this VDAF. The output is used for domain separation /// by the XOF. - fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { - let mut dst = [0_u8; 8]; - dst[0] = VERSION; - dst[1] = 0; // algorithm class - dst[2..6].copy_from_slice(&(self.algorithm_id()).to_be_bytes()); - dst[6..8].copy_from_slice(&usage.to_be_bytes()); + fn domain_separation_tag(&self, usage: u16, ctx: &[u8]) -> Vec { + // Prefix is 8 bytes and defined by the spec. Copy these values in + let mut dst = Vec::with_capacity(ctx.len() + 8); + dst.push(VERSION); + dst.push(0); // algorithm class + dst.extend_from_slice(self.algorithm_id().to_be_bytes().as_slice()); + dst.extend_from_slice(usage.to_be_bytes().as_slice()); + // Finally, append user-chosen `ctx` + dst.extend_from_slice(ctx); + dst } } @@ -217,9 +221,10 @@ pub trait Client: Vdaf { /// /// Implements `Vdaf::shard` from [VDAF]. /// - /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.1 + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-13#section-5.1 fn shard( &self, + ctx: &[u8], measurement: &Self::Measurement, nonce: &[u8; NONCE_SIZE], ) -> Result<(Self::PublicShare, Vec), VdafError>; @@ -254,9 +259,11 @@ pub trait Aggregator: Vda /// Implements `Vdaf.prep_init` from [VDAF]. /// /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 + #[allow(clippy::too_many_arguments)] fn prepare_init( &self, verify_key: &[u8; VERIFY_KEY_SIZE], + ctx: &[u8], agg_id: usize, agg_param: &Self::AggregationParam, nonce: &[u8; NONCE_SIZE], @@ -271,6 +278,7 @@ pub trait Aggregator: Vda /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 fn prepare_shares_to_prepare_message>( &self, + ctx: &[u8], agg_param: &Self::AggregationParam, inputs: M, ) -> Result; @@ -288,6 +296,7 @@ pub trait Aggregator: Vda /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 fn prepare_next( &self, + ctx: &[u8], state: Self::PrepareState, input: Self::PrepareMessage, ) -> Result, VdafError>; @@ -489,6 +498,7 @@ pub mod test_utils { /// Execute the VDAF end-to-end and return the aggregate result. pub fn run_vdaf( + ctx: &[u8], vdaf: &V, agg_param: &V::AggregationParam, measurements: M, @@ -500,16 +510,17 @@ pub mod test_utils { let mut sharded_measurements = Vec::new(); for measurement in measurements.into_iter() { let nonce = random(); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?; + let (public_share, input_shares) = vdaf.shard(ctx, &measurement, &nonce)?; sharded_measurements.push((public_share, nonce, input_shares)); } - run_vdaf_sharded(vdaf, agg_param, sharded_measurements) + run_vdaf_sharded(ctx, vdaf, agg_param, sharded_measurements) } /// Execute the VDAF on sharded measurements and return the aggregate result. pub fn run_vdaf_sharded( + ctx: &[u8], vdaf: &V, agg_param: &V::AggregationParam, sharded_measurements: M, @@ -530,6 +541,7 @@ pub mod test_utils { let out_shares = run_vdaf_prepare( vdaf, &verify_key, + ctx, agg_param, &nonce, public_share, @@ -579,6 +591,7 @@ pub mod test_utils { pub fn run_vdaf_prepare( vdaf: &V, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_param: &V::AggregationParam, nonce: &[u8; 16], public_share: V::PublicShare, @@ -600,6 +613,7 @@ pub mod test_utils { for (agg_id, input_share) in input_shares.enumerate() { let (state, msg) = vdaf.prepare_init( verify_key, + ctx, agg_id, agg_param, nonce, @@ -613,6 +627,7 @@ pub mod test_utils { let mut inbound = vdaf .prepare_shares_to_prepare_message( + ctx, agg_param, outbound.iter().map(|encoded| { V::PrepareShare::get_decoded_with_param(&states[0], encoded) @@ -627,6 +642,7 @@ pub mod test_utils { let mut outbound = Vec::new(); for state in states.iter_mut() { match vdaf.prepare_next( + ctx, state.clone(), V::PrepareMessage::get_decoded_with_param(state, &inbound) .expect("failed to decode prep message"), @@ -645,6 +661,7 @@ pub mod test_utils { // Another round is required before output shares are computed. inbound = vdaf .prepare_shares_to_prepare_message( + ctx, agg_param, outbound.iter().map(|encoded| { V::PrepareShare::get_decoded_with_param(&states[0], encoded) diff --git a/src/vdaf/dummy.rs b/src/vdaf/dummy.rs index 5b969bc19..1a78e3ee7 100644 --- a/src/vdaf/dummy.rs +++ b/src/vdaf/dummy.rs @@ -123,6 +123,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_init( &self, _verify_key: &[u8; 0], + _ctx: &[u8], _: usize, aggregation_param: &Self::AggregationParam, _nonce: &[u8; 16], @@ -141,6 +142,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_shares_to_prepare_message>( &self, + _ctx: &[u8], _: &Self::AggregationParam, _: M, ) -> Result { @@ -149,6 +151,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_next( &self, + _ctx: &[u8], state: Self::PrepareState, _: Self::PrepareMessage, ) -> Result, VdafError> { @@ -175,6 +178,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { impl vdaf::Client<16> for Vdaf { fn shard( &self, + _ctx: &[u8], measurement: &Self::Measurement, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { @@ -361,12 +365,14 @@ mod tests { let mut sharded_measurements = Vec::new(); for measurement in measurements { let nonce = thread_rng().gen(); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"dummy ctx", &measurement, &nonce).unwrap(); sharded_measurements.push((public_share, nonce, input_shares)); } let result = run_vdaf_sharded( + b"dummy ctx", &vdaf, &AggregationParam(aggregation_parameter), sharded_measurements.clone(), diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 7b8d63424..afbac9331 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -341,6 +341,7 @@ where { fn shard( &self, + _ctx: &[u8], (attribute, weight): &(VidpfInput, T::Measurement), nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { @@ -388,6 +389,7 @@ mod tests { use rand::{thread_rng, Rng}; const TEST_NONCE_SIZE: usize = 16; + const CTX_STR: &[u8] = b"mastic ctx"; #[test] fn test_mastic_shard_sum() { @@ -404,7 +406,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, 24u128), &nonce).unwrap(); + let (_public, _input_shares) = mastic + .shard(CTX_STR, &(first_input, 24u128), &nonce) + .unwrap(); } #[test] @@ -422,7 +426,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); - let (_, input_shares) = mastic.shard(&(first_input, 26u128), &nonce).unwrap(); + let (_, input_shares) = mastic + .shard(CTX_STR, &(first_input, 26u128), &nonce) + .unwrap(); let [leader_input_share, helper_input_share] = [&input_shares[0], &input_shares[1]]; assert_eq!( @@ -450,7 +456,7 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, true), &nonce).unwrap(); + let (_public, _input_shares) = mastic.shard(CTX_STR, &(first_input, true), &nonce).unwrap(); } #[test] @@ -470,7 +476,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let (_public, _input_shares) = mastic + .shard(CTX_STR, &(first_input, measurement), &nonce) + .unwrap(); } #[test] @@ -490,7 +498,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let (_public, input_shares) = mastic + .shard(CTX_STR, &(first_input, measurement), &nonce) + .unwrap(); let leader_input_share = &input_shares[0]; let helper_input_share = &input_shares[1]; @@ -521,7 +531,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let (_public, input_shares) = mastic + .shard(CTX_STR, &(first_input, measurement), &nonce) + .unwrap(); let leader_input_share = &input_shares[0]; let helper_input_share = &input_shares[1]; diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 1d396e4da..71bae8cc9 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -69,6 +69,7 @@ impl, const SEED_SIZE: usize> Poplar1 { &self, seed: &[u8; SEED_SIZE], usage: u16, + ctx: &[u8], binder_chunks: I, ) -> Prng where @@ -77,7 +78,7 @@ impl, const SEED_SIZE: usize> Poplar1 { P: Xof, F: FieldElement, { - let mut xof = P::init(seed, &self.domain_separation_tag(usage)); + let mut xof = P::init(seed, &self.domain_separation_tag(usage, ctx)); for binder_chunk in binder_chunks.into_iter() { xof.update(binder_chunk.as_ref()); } @@ -865,6 +866,7 @@ impl, const SEED_SIZE: usize> Vdaf for Poplar1 { impl, const SEED_SIZE: usize> Poplar1 { fn shard_with_random( &self, + ctx: &[u8], input: &IdpfInput, nonce: &[u8; 16], idpf_random: &[[u8; 16]; 2], @@ -879,7 +881,7 @@ impl, const SEED_SIZE: usize> Poplar1 { // Generate the authenticator for each inner level of the IDPF tree. let mut prng = - self.init_prng::<_, _, Field64>(&poplar_random[2], DST_SHARD_RANDOMNESS, [nonce]); + self.init_prng::<_, _, Field64>(&poplar_random[2], DST_SHARD_RANDOMNESS, ctx, [nonce]); let auth_inner: Vec = (0..self.bits - 1).map(|_| prng.get()).collect(); // Generate the authenticator for the last level of the IDPF tree (i.e., the leaves). @@ -912,11 +914,13 @@ impl, const SEED_SIZE: usize> Poplar1 { let mut corr_prng_0 = self.init_prng::<_, _, Field64>( corr_seed_0, DST_CORR_INNER, + ctx, [[0].as_slice(), nonce.as_slice()], ); let mut corr_prng_1 = self.init_prng::<_, _, Field64>( corr_seed_1, DST_CORR_INNER, + ctx, [[1].as_slice(), nonce.as_slice()], ); let mut corr_inner_0 = Vec::with_capacity(self.bits - 1); @@ -933,11 +937,13 @@ impl, const SEED_SIZE: usize> Poplar1 { let mut corr_prng_0 = self.init_prng::<_, _, Field255>( corr_seed_0, DST_CORR_LEAF, + ctx, [[0].as_slice(), nonce.as_slice()], ); let mut corr_prng_1 = self.init_prng::<_, _, Field255>( corr_seed_1, DST_CORR_LEAF, + ctx, [[1].as_slice(), nonce.as_slice()], ); let (corr_leaf_0, corr_leaf_1) = @@ -967,6 +973,7 @@ impl, const SEED_SIZE: usize> Poplar1 { fn eval_and_sketch( &self, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_id: usize, nonce: &[u8; 16], agg_param: &Poplar1AggregationParam, @@ -983,6 +990,7 @@ impl, const SEED_SIZE: usize> Poplar1 { let mut verify_prng = self.init_prng( verify_key, DST_VERIFY_RANDOMNESS, + ctx, [nonce.as_slice(), agg_param.level.to_be_bytes().as_slice()], ); @@ -1020,6 +1028,7 @@ impl, const SEED_SIZE: usize> Poplar1 { impl, const SEED_SIZE: usize> Client<16> for Poplar1 { fn shard( &self, + ctx: &[u8], input: &IdpfInput, nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { @@ -1031,7 +1040,7 @@ impl, const SEED_SIZE: usize> Client<16> for Poplar1, const SEED_SIZE: usize> Aggregator fn prepare_init( &self, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_id: usize, agg_param: &Poplar1AggregationParam, nonce: &[u8; 16], @@ -1066,6 +1076,7 @@ impl, const SEED_SIZE: usize> Aggregator let mut corr_prng = self.init_prng::<_, _, Field64>( input_share.corr_seed.as_ref(), DST_CORR_INNER, + ctx, [[agg_id as u8].as_slice(), nonce.as_slice()], ); // Fast-forward the correlated randomness XOF to the level of the tree that we are @@ -1076,6 +1087,7 @@ impl, const SEED_SIZE: usize> Aggregator let (output_share, sketch_share) = self.eval_and_sketch::( verify_key, + ctx, agg_id, nonce, agg_param, @@ -1099,11 +1111,13 @@ impl, const SEED_SIZE: usize> Aggregator let corr_prng = self.init_prng::<_, _, Field255>( input_share.corr_seed.as_ref(), DST_CORR_LEAF, + ctx, [[agg_id as u8].as_slice(), nonce.as_slice()], ); let (output_share, sketch_share) = self.eval_and_sketch::( verify_key, + ctx, agg_id, nonce, agg_param, @@ -1128,6 +1142,7 @@ impl, const SEED_SIZE: usize> Aggregator fn prepare_shares_to_prepare_message>( &self, + _ctx: &[u8], _: &Poplar1AggregationParam, inputs: M, ) -> Result { @@ -1167,6 +1182,7 @@ impl, const SEED_SIZE: usize> Aggregator fn prepare_next( &self, + _ctx: &[u8], state: Poplar1PrepareState, msg: Poplar1PrepareMessage, ) -> Result, VdafError> { @@ -1540,6 +1556,8 @@ mod tests { use serde::Deserialize; use std::collections::HashSet; + const CTX_STR: &[u8] = b"poplar1 ctx"; + fn test_prepare, const SEED_SIZE: usize>( vdaf: &Poplar1, verify_key: &[u8; SEED_SIZE], @@ -1552,6 +1570,7 @@ mod tests { let out_shares = run_vdaf_prepare( vdaf, verify_key, + CTX_STR, agg_param, nonce, public_share.clone(), @@ -1591,7 +1610,11 @@ mod tests { .map(|measurement| { let nonce = rng.gen(); let (public_share, input_shares) = vdaf - .shard(&IdpfInput::from_bytes(measurement.as_ref()), &nonce) + .shard( + CTX_STR, + &IdpfInput::from_bytes(measurement.as_ref()), + &nonce, + ) .unwrap(); (nonce, public_share, input_shares) }) @@ -1615,6 +1638,7 @@ mod tests { let out_shares = run_vdaf_prepare( vdaf, verify_key, + CTX_STR, &agg_param, nonce, public_share.clone(), @@ -1675,7 +1699,7 @@ mod tests { let verify_key = rng.gen(); let input = IdpfInput::from_bytes(b"12341324"); let nonce = rng.gen(); - let (public_share, input_shares) = vdaf.shard(&input, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(CTX_STR, &input, &nonce).unwrap(); test_prepare( &vdaf, @@ -2096,6 +2120,10 @@ mod tests { } fn check_test_vec(input: &str) { + // We need to use an empty context string for these test vectors to pass. + // TODO: update test vectors to ones that use a real context string + const CTX_STR: &[u8] = b""; + let test_vector: PoplarTestVector = serde_json::from_str(input).unwrap(); assert_eq!(test_vector.prep.len(), 1); let prep = &test_vector.prep[0]; @@ -2133,13 +2161,14 @@ mod tests { // Shard measurement. let poplar = Poplar1::new_turboshake128(test_vector.bits); let (public_share, input_shares) = poplar - .shard_with_random(&measurement, &nonce, &idpf_random, &poplar_random) + .shard_with_random(CTX_STR, &measurement, &nonce, &idpf_random, &poplar_random) .unwrap(); // Run aggregation. let (init_prep_state_0, init_prep_share_0) = poplar .prepare_init( &verify_key, + CTX_STR, 0, &agg_param, &nonce, @@ -2150,6 +2179,7 @@ mod tests { let (init_prep_state_1, init_prep_share_1) = poplar .prepare_init( &verify_key, + CTX_STR, 1, &agg_param, &nonce, @@ -2160,6 +2190,7 @@ mod tests { let r1_prep_msg = poplar .prepare_shares_to_prepare_message( + CTX_STR, &agg_param, [init_prep_share_0.clone(), init_prep_share_1.clone()], ) @@ -2167,19 +2198,20 @@ mod tests { let (r1_prep_state_0, r1_prep_share_0) = assert_matches!( poplar - .prepare_next(init_prep_state_0.clone(), r1_prep_msg.clone()) + .prepare_next(CTX_STR,init_prep_state_0.clone(), r1_prep_msg.clone()) .unwrap(), PrepareTransition::Continue(state, share) => (state, share) ); let (r1_prep_state_1, r1_prep_share_1) = assert_matches!( poplar - .prepare_next(init_prep_state_1.clone(), r1_prep_msg.clone()) + .prepare_next(CTX_STR,init_prep_state_1.clone(), r1_prep_msg.clone()) .unwrap(), PrepareTransition::Continue(state, share) => (state, share) ); let r2_prep_msg = poplar .prepare_shares_to_prepare_message( + CTX_STR, &agg_param, [r1_prep_share_0.clone(), r1_prep_share_1.clone()], ) @@ -2187,13 +2219,13 @@ mod tests { let out_share_0 = assert_matches!( poplar - .prepare_next(r1_prep_state_0.clone(), r2_prep_msg.clone()) + .prepare_next(CTX_STR, r1_prep_state_0.clone(), r2_prep_msg.clone()) .unwrap(), PrepareTransition::Finish(out) => out ); let out_share_1 = assert_matches!( poplar - .prepare_next(r1_prep_state_1, r2_prep_msg.clone()) + .prepare_next(CTX_STR,r1_prep_state_1, r2_prep_msg.clone()) .unwrap(), PrepareTransition::Finish(out) => out ); diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 680f09ea7..96a8f5a3a 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -143,6 +143,7 @@ impl Vdaf for Prio2 { impl Client<16> for Prio2 { fn shard( &self, + _ctx: &[u8], measurement: &Vec, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { @@ -253,6 +254,7 @@ impl Aggregator<32, 16> for Prio2 { fn prepare_init( &self, agg_key: &[u8; 32], + _ctx: &[u8], agg_id: usize, _agg_param: &Self::AggregationParam, nonce: &[u8; 16], @@ -278,6 +280,7 @@ impl Aggregator<32, 16> for Prio2 { fn prepare_shares_to_prepare_message>( &self, + _ctx: &[u8], _: &Self::AggregationParam, inputs: M, ) -> Result<(), VdafError> { @@ -300,6 +303,7 @@ impl Aggregator<32, 16> for Prio2 { fn prepare_next( &self, + _ctx: &[u8], state: Prio2PrepareState, _input: (), ) -> Result, VdafError> { @@ -406,12 +410,17 @@ mod tests { use assert_matches::assert_matches; use rand::prelude::*; + // The value of this string doesn't matter. Prio2 is not defined to use the context string for + // any computation + pub(crate) const CTX_STR: &[u8] = b"prio2 ctx"; + #[test] fn run_prio2() { let prio2 = Prio2::new(6).unwrap(); assert_eq!( run_vdaf( + CTX_STR, &prio2, &(), [ @@ -434,11 +443,12 @@ mod tests { let nonce = rng.gen::<[u8; 16]>(); let data = vec![0, 0, 1, 1, 0]; let prio2 = Prio2::new(data.len()).unwrap(); - let (public_share, input_shares) = prio2.shard(&data, &nonce).unwrap(); + let (public_share, input_shares) = prio2.shard(CTX_STR, &data, &nonce).unwrap(); for (agg_id, input_share) in input_shares.iter().enumerate() { let (prepare_state, prepare_share) = prio2 .prepare_init( &verify_key, + CTX_STR, agg_id, &(), &[0; 16], @@ -500,17 +510,21 @@ mod tests { let input_share_1 = Share::get_decoded_with_param(&(&vdaf, 0), server_1_share).unwrap(); let input_share_2 = Share::get_decoded_with_param(&(&vdaf, 1), server_2_share).unwrap(); let (prepare_state_1, prepare_share_1) = vdaf - .prepare_init(&[0; 32], 0, &(), &[0; 16], &(), &input_share_1) + .prepare_init(&[0; 32], CTX_STR, 0, &(), &[0; 16], &(), &input_share_1) .unwrap(); let (prepare_state_2, prepare_share_2) = vdaf - .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2) - .unwrap(); - vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2]) + .prepare_init(&[0; 32], CTX_STR, 1, &(), &[0; 16], &(), &input_share_2) .unwrap(); - let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap(); + vdaf.prepare_shares_to_prepare_message( + CTX_STR, + &(), + [prepare_share_1, prepare_share_2], + ) + .unwrap(); + let transition_1 = vdaf.prepare_next(CTX_STR, prepare_state_1, ()).unwrap(); let output_share_1 = assert_matches!(transition_1, PrepareTransition::Finish(out) => out); - let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap(); + let transition_2 = vdaf.prepare_next(CTX_STR, prepare_state_2, ()).unwrap(); let output_share_2 = assert_matches!(transition_2, PrepareTransition::Finish(out) => out); leader_output_shares.push(output_share_1); diff --git a/src/vdaf/prio2/server.rs b/src/vdaf/prio2/server.rs index 6e457e51d..26aa4a621 100644 --- a/src/vdaf/prio2/server.rs +++ b/src/vdaf/prio2/server.rs @@ -205,6 +205,7 @@ mod tests { prio2::{ client::{proof_length, unpack_proof_mut}, server::test_util::Server, + tests::CTX_STR, Prio2, }, Client, Share, ShareDecodingParameter, @@ -285,7 +286,7 @@ mod tests { } let vdaf = Prio2::new(dim).unwrap(); - let (_, shares) = vdaf.shard(&data, &[0; 16]).unwrap(); + let (_, shares) = vdaf.shard(CTX_STR, &data, &[0; 16]).unwrap(); let share1_original = shares[0].get_encoded().unwrap(); let share2 = shares[1].get_encoded().unwrap(); diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index f0d482dea..3936730ec 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -379,6 +379,7 @@ impl Prio3Average { /// use rand::prelude::*; /// /// let num_shares = 2; +/// let ctx = b"my context str"; /// let vdaf = Prio3::new_count(num_shares).unwrap(); /// /// let mut out_shares = vec![vec![]; num_shares.into()]; @@ -388,7 +389,7 @@ impl Prio3Average { /// for measurement in measurements { /// // Shard /// let nonce = rng.gen::<[u8; 16]>(); -/// let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); +/// let (public_share, input_shares) = vdaf.shard(ctx, &measurement, &nonce).unwrap(); /// /// // Prepare /// let mut prep_states = vec![]; @@ -396,6 +397,7 @@ impl Prio3Average { /// for (agg_id, input_share) in input_shares.iter().enumerate() { /// let (state, share) = vdaf.prepare_init( /// &verify_key, +/// ctx, /// agg_id, /// &(), /// &nonce, @@ -405,10 +407,10 @@ impl Prio3Average { /// prep_states.push(state); /// prep_shares.push(share); /// } -/// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap(); +/// let prep_msg = vdaf.prepare_shares_to_prepare_message(ctx, &(), prep_shares).unwrap(); /// /// for (agg_id, state) in prep_states.into_iter().enumerate() { -/// let out_share = match vdaf.prepare_next(state, prep_msg.clone()).unwrap() { +/// let out_share = match vdaf.prepare_next(ctx, state, prep_msg.clone()).unwrap() { /// PrepareTransition::Finish(out_share) => out_share, /// _ => panic!("unexpected transition"), /// }; @@ -481,10 +483,10 @@ where self.num_proofs.into() } - fn derive_prove_rands(&self, prove_rand_seed: &Seed) -> Vec { + fn derive_prove_rands(&self, ctx: &[u8], prove_rand_seed: &Seed) -> Vec { P::seed_stream( prove_rand_seed, - &self.domain_separation_tag(DST_PROVE_RANDOMNESS), + &self.domain_separation_tag(DST_PROVE_RANDOMNESS, ctx), &[self.num_proofs], ) .into_field_vec(self.typ.prove_rand_len() * self.num_proofs()) @@ -492,11 +494,12 @@ where fn derive_joint_rand_seed<'a>( &self, + ctx: &[u8], joint_rand_parts: impl Iterator>, ) -> Seed { let mut xof = P::init( &[0; SEED_SIZE], - &self.domain_separation_tag(DST_JOINT_RAND_SEED), + &self.domain_separation_tag(DST_JOINT_RAND_SEED, ctx), ); for part in joint_rand_parts { xof.update(part.as_ref()); @@ -506,12 +509,13 @@ where fn derive_joint_rands<'a>( &self, + ctx: &[u8], joint_rand_parts: impl Iterator>, ) -> (Seed, Vec) { - let joint_rand_seed = self.derive_joint_rand_seed(joint_rand_parts); + let joint_rand_seed = self.derive_joint_rand_seed(ctx, joint_rand_parts); let joint_rands = P::seed_stream( &joint_rand_seed, - &self.domain_separation_tag(DST_JOINT_RANDOMNESS), + &self.domain_separation_tag(DST_JOINT_RANDOMNESS, ctx), &[self.num_proofs], ) .into_field_vec(self.typ.joint_rand_len() * self.num_proofs()); @@ -521,20 +525,26 @@ where fn derive_helper_proofs_share( &self, + ctx: &[u8], proofs_share_seed: &Seed, agg_id: u8, ) -> Prng { Prng::from_seed_stream(P::seed_stream( proofs_share_seed, - &self.domain_separation_tag(DST_PROOF_SHARE), + &self.domain_separation_tag(DST_PROOF_SHARE, ctx), &[self.num_proofs, agg_id], )) } - fn derive_query_rands(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec { + fn derive_query_rands( + &self, + verify_key: &[u8; SEED_SIZE], + ctx: &[u8], + nonce: &[u8; 16], + ) -> Vec { let mut xof = P::init( verify_key, - &self.domain_separation_tag(DST_QUERY_RANDOMNESS), + &self.domain_separation_tag(DST_QUERY_RANDOMNESS, ctx), ); xof.update(&[self.num_proofs]); xof.update(nonce); @@ -562,6 +572,7 @@ where #[allow(clippy::type_complexity)] pub(crate) fn shard_with_random( &self, + ctx: &[u8], measurement: &T::Measurement, nonce: &[u8; N], random: &[u8], @@ -598,7 +609,7 @@ where let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap(); let measurement_share_prng: Prng = Prng::from_seed_stream(P::seed_stream( &Seed(measurement_share_seed), - &self.domain_separation_tag(DST_MEASUREMENT_SHARE), + &self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx), &[agg_id], )); let joint_rand_blind = if let Some(helper_joint_rand_parts) = @@ -607,7 +618,7 @@ where let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap(); let mut joint_rand_part_xof = P::init( &joint_rand_blind, - &self.domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART, ctx), ); joint_rand_part_xof.update(&[agg_id]); // Aggregator ID joint_rand_part_xof.update(nonce); @@ -653,7 +664,7 @@ where let mut joint_rand_part_xof = P::init( leader_blind.as_ref(), - &self.domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART, ctx), ); joint_rand_part_xof.update(&[0]); // Aggregator ID joint_rand_part_xof.update(nonce); @@ -684,13 +695,14 @@ where let joint_rands = public_share .joint_rand_parts .as_ref() - .map(|joint_rand_parts| self.derive_joint_rands(joint_rand_parts.iter()).1) + .map(|joint_rand_parts| self.derive_joint_rands(ctx, joint_rand_parts.iter()).1) .unwrap_or_default(); // Generate the proofs. - let prove_rands = self.derive_prove_rands(&Seed::from_bytes( - random_seeds.next().unwrap().try_into().unwrap(), - )); + let prove_rands = self.derive_prove_rands( + ctx, + &Seed::from_bytes(random_seeds.next().unwrap().try_into().unwrap()), + ); let mut leader_proofs_share = Vec::with_capacity(self.typ.proof_len() * self.num_proofs()); for p in 0..self.num_proofs() { let prove_rand = @@ -707,14 +719,14 @@ where // Generate the proof shares and distribute the joint randomness seed hints. for (j, helper) in helper_shares.iter_mut().enumerate() { - for (x, y) in - leader_proofs_share - .iter_mut() - .zip(self.derive_helper_proofs_share( - &helper.proofs_share, - u8::try_from(j).unwrap() + 1, - )) - .take(self.typ.proof_len() * self.num_proofs()) + for (x, y) in leader_proofs_share + .iter_mut() + .zip(self.derive_helper_proofs_share( + ctx, + &helper.proofs_share, + u8::try_from(j).unwrap() + 1, + )) + .take(self.typ.proof_len() * self.num_proofs()) { *x -= y; } @@ -1083,12 +1095,13 @@ where #[allow(clippy::type_complexity)] fn shard( &self, + ctx: &[u8], measurement: &T::Measurement, nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { let mut random = vec![0u8; self.random_size()]; getrandom::getrandom(&mut random)?; - self.shard_with_random(measurement, nonce, &random) + self.shard_with_random(ctx, measurement, nonce, &random) } } @@ -1213,6 +1226,7 @@ where fn prepare_init( &self, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_id: usize, _agg_param: &Self::AggregationParam, nonce: &[u8; 16], @@ -1232,7 +1246,7 @@ where Share::Helper(ref seed) => Cow::Owned( P::seed_stream( seed, - &self.domain_separation_tag(DST_MEASUREMENT_SHARE), + &self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx), &[agg_id], ) .into_field_vec(self.typ.input_len()), @@ -1242,7 +1256,7 @@ where let proofs_share = match msg.proofs_share { Share::Leader(ref data) => Cow::Borrowed(data), Share::Helper(ref seed) => Cow::Owned( - self.derive_helper_proofs_share(seed, agg_id) + self.derive_helper_proofs_share(ctx, seed, agg_id) .take(self.typ.proof_len() * self.num_proofs()) .collect::>(), ), @@ -1252,7 +1266,7 @@ where let (joint_rand_seed, joint_rand_part, joint_rands) = if self.typ.joint_rand_len() > 0 { let mut joint_rand_part_xof = P::init( msg.joint_rand_blind.as_ref().unwrap().as_ref(), - &self.domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART, ctx), ); joint_rand_part_xof.update(&[agg_id]); joint_rand_part_xof.update(nonce); @@ -1288,7 +1302,7 @@ where ); let (joint_rand_seed, joint_rands) = - self.derive_joint_rands(corrected_joint_rand_parts); + self.derive_joint_rands(ctx, corrected_joint_rand_parts); ( Some(joint_rand_seed), @@ -1300,7 +1314,7 @@ where }; // Run the query-generation algorithm. - let query_rands = self.derive_query_rands(verify_key, nonce); + let query_rands = self.derive_query_rands(verify_key, ctx, nonce); let mut verifiers_share = Vec::with_capacity(self.typ.verifier_len() * self.num_proofs()); for p in 0..self.num_proofs() { let query_rand = @@ -1337,6 +1351,7 @@ where M: IntoIterator>, >( &self, + ctx: &[u8], _: &Self::AggregationParam, inputs: M, ) -> Result, VdafError> { @@ -1381,7 +1396,7 @@ where } let joint_rand_seed = if self.typ.joint_rand_len() > 0 { - Some(self.derive_joint_rand_seed(joint_rand_parts.iter())) + Some(self.derive_joint_rand_seed(ctx, joint_rand_parts.iter())) } else { None }; @@ -1391,6 +1406,7 @@ where fn prepare_next( &self, + ctx: &[u8], step: Prio3PrepareState, msg: Prio3PrepareMessage, ) -> Result, VdafError> { @@ -1413,7 +1429,7 @@ where let measurement_share = match step.measurement_share { Share::Leader(data) => data, Share::Helper(seed) => { - let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE); + let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx); P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len()) } }; @@ -1627,12 +1643,14 @@ mod tests { }; use rand::prelude::*; + const CTX_STR: &[u8] = b"prio3 ctx"; + #[test] fn test_prio3_count() { let prio3 = Prio3::new_count(2).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [true, false, false, true, true]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [true, false, false, true, true]).unwrap(), 3 ); @@ -1641,17 +1659,41 @@ mod tests { thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); - let (public_share, input_shares) = prio3.shard(&false, &nonce).unwrap(); - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); + let (public_share, input_shares) = prio3.shard(CTX_STR, &false, &nonce).unwrap(); + run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ) + .unwrap(); - let (public_share, input_shares) = prio3.shard(&true, &nonce).unwrap(); - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); + let (public_share, input_shares) = prio3.shard(CTX_STR, &true, &nonce).unwrap(); + run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ) + .unwrap(); test_serialization(&prio3, &true, &nonce).unwrap(); let prio3_extra_helper = Prio3::new_count(3).unwrap(); assert_eq!( - run_vdaf(&prio3_extra_helper, &(), [true, false, false, true, true]).unwrap(), + run_vdaf( + CTX_STR, + &prio3_extra_helper, + &(), + [true, false, false, true, true] + ) + .unwrap(), 3, ); } @@ -1661,7 +1703,7 @@ mod tests { let prio3 = Prio3::new_sum(3, 16).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), (1 << 16) + 1 ); @@ -1669,18 +1711,34 @@ mod tests { thread_rng().fill(&mut verify_key[..]); let nonce = [0; 16]; - let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + let (public_share, mut input_shares) = prio3.shard(CTX_STR, &1, &nonce).unwrap(); assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); - let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + let (public_share, mut input_shares) = prio3.shard(CTX_STR, &1, &nonce).unwrap(); assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); test_serialization(&prio3, &1, &nonce).unwrap(); @@ -1691,6 +1749,7 @@ mod tests { let prio3 = Prio3::new_sum_vec(2, 2, 20, 4).unwrap(); assert_eq!( run_vdaf( + CTX_STR, &prio3, &(), [ @@ -1715,6 +1774,7 @@ mod tests { assert_eq!( run_vdaf( + CTX_STR, &prio3, &(), [ @@ -1734,6 +1794,7 @@ mod tests { let prio3 = Prio3::new_sum_vec_multithreaded(2, 2, 20, 4).unwrap(); assert_eq!( run_vdaf( + CTX_STR, &prio3, &(), [ @@ -1787,7 +1848,7 @@ mod tests { let measurements = [fp_vec.clone(), fp_vec]; assert_eq!( - run_vdaf(&prio3, &(), measurements).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), measurements).unwrap(), vec![0.0; SIZE] ); } @@ -1888,21 +1949,21 @@ mod tests { // positive entries let fp_list = [fp_vec1, fp_vec2]; assert_eq!( - run_vdaf(&prio3, &(), fp_list).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), fp_list).unwrap(), vec!(0.5, 0.25, 0.125), ); // negative entries let fp_list2 = [fp_vec3, fp_vec4]; assert_eq!( - run_vdaf(&prio3, &(), fp_list2).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), fp_list2).unwrap(), vec!(-0.5, -0.25, -0.125), ); // both let fp_list3 = [fp_vec5, fp_vec6]; assert_eq!( - run_vdaf(&prio3, &(), fp_list3).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), fp_list3).unwrap(), vec!(0.5, 0.0, 0.0), ); @@ -1912,31 +1973,52 @@ mod tests { thread_rng().fill(&mut nonce); let (public_share, mut input_shares) = prio3 - .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .shard(CTX_STR, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255; - let result = - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3 - .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .shard(CTX_STR, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3 - .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .shard(CTX_STR, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); test_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce).unwrap(); @@ -1948,13 +2030,25 @@ mod tests { let prio3 = Prio3::new_histogram(2, 4, 2).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [0, 1, 2, 3]).unwrap(), vec![1, 1, 1, 1] ); - assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]); - assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [0]).unwrap(), + vec![1, 0, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [1]).unwrap(), + vec![0, 1, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [2]).unwrap(), + vec![0, 0, 1, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [3]).unwrap(), + vec![0, 0, 0, 1] + ); test_serialization(&prio3, &3, &[0; 16]).unwrap(); } @@ -1964,13 +2058,25 @@ mod tests { let prio3 = Prio3::new_histogram_multithreaded(2, 4, 2).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [0, 1, 2, 3]).unwrap(), vec![1, 1, 1, 1] ); - assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]); - assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [0]).unwrap(), + vec![1, 0, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [1]).unwrap(), + vec![0, 1, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [2]).unwrap(), + vec![0, 0, 1, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [3]).unwrap(), + vec![0, 0, 0, 1] + ); test_serialization(&prio3, &3, &[0; 16]).unwrap(); } @@ -1978,11 +2084,14 @@ mod tests { fn test_prio3_average() { let prio3 = Prio3::new_average(2, 64).unwrap(); - assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64); - assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); - assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64); + assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [17, 8]).unwrap(), 12.5f64); + assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [0, 0, 0, 1]).unwrap(), + 0.25f64 + ); assert_eq!( - run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), 207.5f64 ); } @@ -1990,7 +2099,7 @@ mod tests { #[test] fn test_prio3_input_share() { let prio3 = Prio3::new_sum(5, 16).unwrap(); - let (_public_share, input_shares) = prio3.shard(&1, &[0; 16]).unwrap(); + let (_public_share, input_shares) = prio3.shard(CTX_STR, &1, &[0; 16]).unwrap(); // Check that seed shares are distinct. for (i, x) in input_shares.iter().enumerate() { @@ -2023,7 +2132,7 @@ mod tests { { let mut verify_key = [0; SEED_SIZE]; thread_rng().fill(&mut verify_key[..]); - let (public_share, input_shares) = prio3.shard(measurement, nonce)?; + let (public_share, input_shares) = prio3.shard(CTX_STR, measurement, nonce)?; let encoded_public_share = public_share.get_encoded().unwrap(); let decoded_public_share = @@ -2050,8 +2159,15 @@ mod tests { let mut prepare_shares = Vec::new(); let mut last_prepare_state = None; for (agg_id, input_share) in input_shares.iter().enumerate() { - let (prepare_state, prepare_share) = - prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?; + let (prepare_state, prepare_share) = prio3.prepare_init( + &verify_key, + CTX_STR, + agg_id, + &(), + nonce, + &public_share, + input_share, + )?; let encoded_prepare_state = prepare_state.get_encoded().unwrap(); let decoded_prepare_state = @@ -2078,7 +2194,7 @@ mod tests { } let prepare_message = prio3 - .prepare_shares_to_prepare_message(&(), prepare_shares) + .prepare_shares_to_prepare_message(CTX_STR, &(), prepare_shares) .unwrap(); let encoded_prepare_message = prepare_message.get_encoded().unwrap(); diff --git a/src/vdaf/prio3_test.rs b/src/vdaf/prio3_test.rs index e7627be72..10b72c739 100644 --- a/src/vdaf/prio3_test.rs +++ b/src/vdaf/prio3_test.rs @@ -39,6 +39,7 @@ struct TPrio3Prep { #[derive(Deserialize, Serialize)] struct TPrio3 { + ctx: TEncoded, verify_key: TEncoded, shares: u8, prep: Vec>, @@ -63,6 +64,7 @@ macro_rules! err { fn check_prep_test_vec( prio3: &Prio3, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], test_num: usize, t: &TPrio3Prep, ) -> Vec> @@ -74,7 +76,7 @@ where { let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap(); let (public_share, input_shares) = prio3 - .shard_with_random(&t.measurement.clone().into(), &nonce, &t.rand) + .shard_with_random(ctx, &t.measurement.clone().into(), &nonce, &t.rand) .expect("failed to generate input shares"); assert_eq!( @@ -100,7 +102,15 @@ where let mut prep_shares = Vec::new(); for (agg_id, input_share) in input_shares.iter().enumerate() { let (state, prep_share) = prio3 - .prepare_init(verify_key, agg_id, &(), &nonce, &public_share, input_share) + .prepare_init( + verify_key, + ctx, + agg_id, + &(), + &nonce, + &public_share, + input_share, + ) .unwrap_or_else(|e| err!(test_num, e, "prep state init")); states.push(state); prep_shares.push(prep_share); @@ -122,14 +132,17 @@ where } let inbound = prio3 - .prepare_shares_to_prepare_message(&(), prep_shares) + .prepare_shares_to_prepare_message(ctx, &(), prep_shares) .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); assert_eq!(t.prep_messages.len(), 1); assert_eq!(inbound.get_encoded().unwrap(), t.prep_messages[0].as_ref()); let mut out_shares = Vec::new(); for state in states.iter_mut() { - match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() { + match prio3 + .prepare_next(ctx, state.clone(), inbound.clone()) + .unwrap() + { PrepareTransition::Finish(out_share) => { out_shares.push(out_share); } @@ -164,10 +177,11 @@ where P: Xof, { let verify_key = t.verify_key.as_ref().try_into().unwrap(); + let ctx = t.ctx.as_ref(); let mut all_output_shares = vec![Vec::new(); prio3.num_aggregators()]; for (test_num, p) in t.prep.iter().enumerate() { - let output_shares = check_prep_test_vec(prio3, verify_key, test_num, p); + let output_shares = check_prep_test_vec(prio3, verify_key, ctx, test_num, p); for (aggregator_output_shares, output_share) in all_output_shares.iter_mut().zip(output_shares.into_iter()) { @@ -250,6 +264,11 @@ mod tests { use super::{check_test_vec, check_test_vec_custom_de, Prio3CountMeasurement}; + // All the below tests are not passing. We ignore them until the rest of the repo is in a state + // where we can regenerate the JSON test vectors. + // Tracking issue https://github.com/divviup/libprio-rs/issues/1122 + + #[ignore] #[test] fn test_vec_prio3_count() { for test_vector_str in [ @@ -263,10 +282,6 @@ mod tests { } } - // All the below tests are not passing. We ignore them until the rest of the repo is in a state - // where we can regenerate the JSON test vectors. - // Tracking issue https://github.com/divviup/libprio-rs/issues/1122 - #[ignore] #[test] fn test_vec_prio3_sum() { From 1ac481f5ce94dd4213b671fbaf2a361919b29dbc Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Wed, 27 Nov 2024 10:17:21 -0500 Subject: [PATCH 28/32] Added `ctx` string to DPF computation (#1146) --- benches/cycle_counts.rs | 5 +-- benches/speed_tests.rs | 16 ++++++--- src/idpf.rs | 75 ++++++++++++++++++++++++----------------- src/vdaf/poplar1.rs | 6 ++++ src/vdaf/xof.rs | 53 ++++++++++++++++++++++------- 5 files changed, 107 insertions(+), 48 deletions(-) diff --git a/benches/cycle_counts.rs b/benches/cycle_counts.rs index 2f3ede7a8..8b1b3c184 100644 --- a/benches/cycle_counts.rs +++ b/benches/cycle_counts.rs @@ -165,7 +165,8 @@ fn idpf_poplar_gen( leaf_value: Poplar1IdpfValue, ) { let idpf = Idpf::new((), ()); - idpf.gen(input, inner_values, leaf_value, &[0; 16]).unwrap(); + idpf.gen(input, inner_values, leaf_value, b"", &[0; 16]) + .unwrap(); } #[cfg(feature = "experimental")] @@ -209,7 +210,7 @@ fn idpf_poplar_eval( ) { let mut cache = RingBufferCache::new(1); let idpf = Idpf::new((), ()); - idpf.eval(0, public_share, key, input, &[0; 16], &mut cache) + idpf.eval(0, public_share, key, input, b"", &[0; 16], &mut cache) .unwrap(); } diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index 2957d8b5f..053cabb2e 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -712,7 +712,7 @@ fn idpf(c: &mut Criterion) { let idpf = Idpf::new((), ()); b.iter(|| { - idpf.gen(&input, inner_values.clone(), leaf_value, &[0; 16]) + idpf.gen(&input, inner_values.clone(), leaf_value, b"", &[0; 16]) .unwrap(); }); }); @@ -735,7 +735,7 @@ fn idpf(c: &mut Criterion) { let idpf = Idpf::new((), ()); let (public_share, keys) = idpf - .gen(&input, inner_values, leaf_value, &[0; 16]) + .gen(&input, inner_values, leaf_value, b"", &[0; 16]) .unwrap(); b.iter(|| { @@ -747,8 +747,16 @@ fn idpf(c: &mut Criterion) { for prefix_length in 1..=size { let prefix = input[..prefix_length].to_owned().into(); - idpf.eval(0, &public_share, &keys[0], &prefix, &[0; 16], &mut cache) - .unwrap(); + idpf.eval( + 0, + &public_share, + &keys[0], + &prefix, + b"", + &[0; 16], + &mut cache, + ) + .unwrap(); } }); }); diff --git a/src/idpf.rs b/src/idpf.rs index f13f37ff4..ac6fdb2db 100644 --- a/src/idpf.rs +++ b/src/idpf.rs @@ -30,6 +30,18 @@ use std::{ }; use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; +const EXTEND_DOMAIN_SEP: &[u8; 8] = &[ + VERSION, 1, /* algorithm class */ + 0, 0, 0, 0, /* algorithm ID */ + 0, 0, /* usage */ +]; + +const CONVERT_DOMAIN_SEP: &[u8; 8] = &[ + VERSION, 1, /* algorithm class */ + 0, 0, 0, 0, /* algorithm ID */ + 0, 1, /* usage */ +]; + /// IDPF-related errors. #[derive(Debug, thiserror::Error)] #[non_exhaustive] @@ -394,6 +406,7 @@ where input: &IdpfInput, inner_values: M, leaf_value: VL, + ctx: &[u8], binder: &[u8], random: &[[u8; 16]; 2], ) -> Result<(IdpfPublicShare, [Seed<16>; 2]), VdafError> { @@ -402,18 +415,8 @@ where let initial_keys: [Seed<16>; 2] = [Seed::from_bytes(random[0]), Seed::from_bytes(random[1])]; - let extend_dst = [ - VERSION, 1, /* algorithm class */ - 0, 0, 0, 0, /* algorithm ID */ - 0, 0, /* usage */ - ]; - let convert_dst = [ - VERSION, 1, /* algorithm class */ - 0, 0, 0, 0, /* algorithm ID */ - 0, 1, /* usage */ - ]; - let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&extend_dst, binder); - let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&convert_dst, binder); + let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&[EXTEND_DOMAIN_SEP, ctx], binder); + let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&[CONVERT_DOMAIN_SEP, ctx], binder); let mut keys = [initial_keys[0].0, initial_keys[1].0]; let mut control_bits = [Choice::from(0u8), Choice::from(1u8)]; @@ -467,6 +470,7 @@ where input: &IdpfInput, inner_values: M, leaf_value: VL, + ctx: &[u8], binder: &[u8], ) -> Result<(IdpfPublicShare, [Seed<16>; 2]), VdafError> where @@ -481,7 +485,7 @@ where for random_seed in random.iter_mut() { getrandom::getrandom(random_seed)?; } - self.gen_with_random(input, inner_values, leaf_value, binder, &random) + self.gen_with_random(input, inner_values, leaf_value, ctx, binder, &random) } /// Evaluate an IDPF share on `prefix`, starting from a particular tree level with known @@ -495,23 +499,14 @@ where mut key: [u8; 16], mut control_bit: Choice, prefix: &IdpfInput, + ctx: &[u8], binder: &[u8], cache: &mut dyn IdpfCache, ) -> Result, IdpfError> { let bits = public_share.inner_correction_words.len() + 1; - let extend_dst = [ - VERSION, 1, /* algorithm class */ - 0, 0, 0, 0, /* algorithm ID */ - 0, 0, /* usage */ - ]; - let convert_dst = [ - VERSION, 1, /* algorithm class */ - 0, 0, 0, 0, /* algorithm ID */ - 0, 1, /* usage */ - ]; - let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&extend_dst, binder); - let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&convert_dst, binder); + let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&[EXTEND_DOMAIN_SEP, ctx], binder); + let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&[CONVERT_DOMAIN_SEP, ctx], binder); let mut last_inner_output = None; for ((correction_word, input_bit), level) in public_share.inner_correction_words @@ -556,12 +551,14 @@ where /// The IDPF key evaluation algorithm. /// /// Evaluate an IDPF share on `prefix`. + #[allow(clippy::too_many_arguments)] pub fn eval( &self, agg_id: usize, public_share: &IdpfPublicShare, key: &Seed<16>, prefix: &IdpfInput, + ctx: &[u8], binder: &[u8], cache: &mut dyn IdpfCache, ) -> Result, IdpfError> { @@ -602,6 +599,7 @@ where key, Choice::from(control_bit), prefix, + ctx, binder, cache, ); @@ -617,6 +615,7 @@ where key.0, /* control_bit */ Choice::from((!is_leader) as u8), prefix, + ctx, binder, cache, ) @@ -1075,6 +1074,8 @@ mod tests { sync::Mutex, }; + const CTX_STR: &[u8] = b"idpf context"; + use assert_matches::assert_matches; use bitvec::{ bitbox, @@ -1190,6 +1191,7 @@ mod tests { &input, Vec::from([Poplar1IdpfValue::new([Field64::one(), Field64::one()]); 4]), Poplar1IdpfValue::new([Field255::one(), Field255::one()]), + CTX_STR, &nonce, ) .unwrap(); @@ -1306,10 +1308,10 @@ mod tests { ) { let idpf = Idpf::new((), ()); let share_0 = idpf - .eval(0, public_share, &keys[0], prefix, binder, cache_0) + .eval(0, public_share, &keys[0], prefix, CTX_STR, binder, cache_0) .unwrap(); let share_1 = idpf - .eval(1, public_share, &keys[1], prefix, binder, cache_1) + .eval(1, public_share, &keys[1], prefix, CTX_STR, binder, cache_1) .unwrap(); let output = share_0.merge(share_1).unwrap(); assert_eq!(&output, expected_output); @@ -1340,7 +1342,7 @@ mod tests { let nonce: [u8; 16] = random(); let idpf = Idpf::new((), ()); let (public_share, keys) = idpf - .gen(&input, inner_values.clone(), leaf_values, &nonce) + .gen(&input, inner_values.clone(), leaf_values, CTX_STR, &nonce) .unwrap(); let mut cache_0 = RingBufferCache::new(3); let mut cache_1 = RingBufferCache::new(3); @@ -1409,7 +1411,7 @@ mod tests { let nonce: [u8; 16] = random(); let idpf = Idpf::new((), ()); let (public_share, keys) = idpf - .gen(&input, inner_values.clone(), leaf_values, &nonce) + .gen(&input, inner_values.clone(), leaf_values, CTX_STR, &nonce) .unwrap(); let mut cache_0 = SnoopingCache::new(HashMapCache::new()); let mut cache_1 = HashMapCache::new(); @@ -1588,7 +1590,7 @@ mod tests { let nonce: [u8; 16] = random(); let idpf = Idpf::new((), ()); let (public_share, keys) = idpf - .gen(&input, inner_values.clone(), leaf_values, &nonce) + .gen(&input, inner_values.clone(), leaf_values, CTX_STR, &nonce) .unwrap(); let mut cache_0 = LossyCache::new(); let mut cache_1 = LossyCache::new(); @@ -1624,6 +1626,7 @@ mod tests { &bitbox![].into(), Vec::>::new(), Poplar1IdpfValue::new([Field255::zero(); 2]), + CTX_STR, &nonce, ) .unwrap_err(); @@ -1633,6 +1636,7 @@ mod tests { &bitbox![0;10].into(), Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 9]), Poplar1IdpfValue::new([Field255::zero(); 2]), + CTX_STR, &nonce, ) .unwrap(); @@ -1642,6 +1646,7 @@ mod tests { &bitbox![0; 10].into(), Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 8]), Poplar1IdpfValue::new([Field255::zero(); 2]), + CTX_STR, &nonce, ) .unwrap_err(); @@ -1649,6 +1654,7 @@ mod tests { &bitbox![0; 10].into(), Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 10]), Poplar1IdpfValue::new([Field255::zero(); 2]), + CTX_STR, &nonce, ) .unwrap_err(); @@ -1660,6 +1666,7 @@ mod tests { &public_share, &keys[0], &bitbox![].into(), + CTX_STR, &nonce, &mut NoCache::new(), ) @@ -1671,6 +1678,7 @@ mod tests { &public_share, &keys[0], &bitbox![0; 11].into(), + CTX_STR, &nonce, &mut NoCache::new(), ) @@ -2016,6 +2024,7 @@ mod tests { } } + #[ignore] #[test] fn idpf_poplar_generate_test_vector() { let test_vector = load_idpfpoplar_test_vector(); @@ -2025,6 +2034,7 @@ mod tests { &test_vector.alpha, test_vector.beta_inner, test_vector.beta_leaf, + b"WRONG CTX, REPLACE ME", // TODO: Update test vectors to ones that provide ctx str &test_vector.binder, &test_vector.keys, ) @@ -2256,6 +2266,7 @@ mod tests { Field128::from(2), Field128::from(3), ])), + CTX_STR, binder, ) .unwrap(); @@ -2266,6 +2277,7 @@ mod tests { &public_share, &key_0, &IdpfInput::from_bytes(b"ou"), + CTX_STR, binder, &mut NoCache::new(), ) @@ -2276,6 +2288,7 @@ mod tests { &public_share, &key_1, &IdpfInput::from_bytes(b"ou"), + CTX_STR, binder, &mut NoCache::new(), ) @@ -2294,6 +2307,7 @@ mod tests { &public_share, &key_0, &IdpfInput::from_bytes(b"ae"), + CTX_STR, binder, &mut NoCache::new(), ) @@ -2304,6 +2318,7 @@ mod tests { &public_share, &key_1, &IdpfInput::from_bytes(b"ae"), + CTX_STR, binder, &mut NoCache::new(), ) diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 71bae8cc9..70c572dbd 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -899,6 +899,7 @@ impl, const SEED_SIZE: usize> Poplar1 { .iter() .map(|auth| Poplar1IdpfValue([Field64::one(), *auth])), Poplar1IdpfValue([Field255::one(), auth_leaf]), + ctx, nonce, idpf_random, )?; @@ -1009,6 +1010,7 @@ impl, const SEED_SIZE: usize> Poplar1 { public_share, idpf_key, prefix, + ctx, nonce, &mut idpf_eval_cache, )?); @@ -2386,21 +2388,25 @@ mod tests { assert_eq!(agg_result, test_vector.agg_result); } + #[ignore] #[test] fn test_vec_poplar1_0() { check_test_vec(include_str!("test_vec/08/Poplar1_0.json")); } + #[ignore] #[test] fn test_vec_poplar1_1() { check_test_vec(include_str!("test_vec/08/Poplar1_1.json")); } + #[ignore] #[test] fn test_vec_poplar1_2() { check_test_vec(include_str!("test_vec/08/Poplar1_2.json")); } + #[ignore] #[test] fn test_vec_poplar1_3() { check_test_vec(include_str!("test_vec/08/Poplar1_3.json")); diff --git a/src/vdaf/xof.rs b/src/vdaf/xof.rs index eb1e8de19..9635784b0 100644 --- a/src/vdaf/xof.rs +++ b/src/vdaf/xof.rs @@ -282,19 +282,38 @@ pub struct XofFixedKeyAes128Key { #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] impl XofFixedKeyAes128Key { - /// Derive the fixed key from the domain separation tag and binder string. - pub fn new(dst: &[u8], binder: &[u8]) -> Self { + /// Derive the fixed key from the binder string and the domain separator, which is concatenation + /// of all the items in `dst`. + /// + /// # Panics + /// Panics if the total length of all elements of `dst` exceeds `u16::MAX`. + pub fn new(dst: &[&[u8]], binder: &[u8]) -> Self { let mut fixed_key_deriver = TurboShake128::from_core(TurboShake128Core::new( XOF_FIXED_KEY_AES_128_DOMAIN_SEPARATION, )); - Update::update( - &mut fixed_key_deriver, - &[dst.len().try_into().expect("dst must be at most 255 bytes")], + let tot_dst_len: usize = dst + .iter() + .map(|s| { + let len = s.len(); + assert!(len <= u16::MAX as usize, "dst must be at most 65535 bytes"); + len + }) + .sum(); + + // Feed the dst length, dst, and binder into the XOF + fixed_key_deriver.update( + u16::try_from(tot_dst_len) + .expect("dst must be at most 65535 bytes") + .to_le_bytes() + .as_slice(), ); - Update::update(&mut fixed_key_deriver, dst); - Update::update(&mut fixed_key_deriver, binder); + dst.iter().for_each(|s| fixed_key_deriver.update(s)); + fixed_key_deriver.update(binder); + + // Squeeze out the key let mut key = GenericArray::from([0; 16]); XofReader::read(&mut fixed_key_deriver.finalize_xof(), key.as_mut()); + Self { cipher: Aes128::new(&key), } @@ -330,6 +349,10 @@ pub struct XofFixedKeyAes128 { base_block: Block, } +// This impl is only used by Mastic right now. The XofFixedKeyAes128Key impl is used in cases where +// the base XOF can be reused with different contexts. This is the case in VDAF IDPF computation. +// TODO(#1147): try to remove the duplicated code below. init() It's mostly the same as +// XofFixedKeyAes128Key::new() above #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] impl Xof<16> for XofFixedKeyAes128 { type SeedStream = SeedStreamFixedKeyAes128; @@ -338,7 +361,10 @@ impl Xof<16> for XofFixedKeyAes128 { let mut fixed_key_deriver = TurboShake128::from_core(TurboShake128Core::new(2u8)); Update::update( &mut fixed_key_deriver, - &[dst.len().try_into().expect("dst must be at most 255 bytes")], + u16::try_from(dst.len()) + .expect("dst must be at most 65535 bytes") + .to_le_bytes() + .as_slice(), ); Update::update(&mut fixed_key_deriver, dst); Self { @@ -574,6 +600,7 @@ mod tests { test_xof::(); } + #[ignore] #[cfg(feature = "experimental")] #[test] fn xof_fixed_key_aes128() { @@ -615,19 +642,21 @@ mod tests { #[cfg(feature = "experimental")] #[test] fn xof_fixed_key_aes128_alternate_apis() { - let dst = b"domain separation tag"; + let fixed_dst = b"domain separation tag"; + let ctx = b"context string"; + let full_dst = [fixed_dst.as_slice(), ctx.as_slice()].concat(); let binder = b"AAAAAAAAAAAAAAAAAAAAAAAA"; let seed_1 = Seed::generate().unwrap(); let seed_2 = Seed::generate().unwrap(); - let mut stream_1_trait_api = XofFixedKeyAes128::seed_stream(&seed_1, dst, binder); + let mut stream_1_trait_api = XofFixedKeyAes128::seed_stream(&seed_1, &full_dst, binder); let mut output_1_trait_api = [0u8; 32]; stream_1_trait_api.fill(&mut output_1_trait_api); - let mut stream_2_trait_api = XofFixedKeyAes128::seed_stream(&seed_2, dst, binder); + let mut stream_2_trait_api = XofFixedKeyAes128::seed_stream(&seed_2, &full_dst, binder); let mut output_2_trait_api = [0u8; 32]; stream_2_trait_api.fill(&mut output_2_trait_api); - let fixed_key = XofFixedKeyAes128Key::new(dst, binder); + let fixed_key = XofFixedKeyAes128Key::new(&[fixed_dst, ctx], binder); let mut stream_1_alternate_api = fixed_key.with_seed(seed_1.as_ref()); let mut output_1_alternate_api = [0u8; 32]; stream_1_alternate_api.fill(&mut output_1_alternate_api); From 46f0a6b1f7e9e59ee0df2011bd35fd4d9f7d67b3 Mon Sep 17 00:00:00 2001 From: David Cook Date: Mon, 2 Dec 2024 11:25:04 -0600 Subject: [PATCH 29/32] Fix clippy lints (#1148) --- src/field/field255.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/field/field255.rs b/src/field/field255.rs index 07306400b..65fb443e2 100644 --- a/src/field/field255.rs +++ b/src/field/field255.rs @@ -195,7 +195,7 @@ impl Neg for Field255 { } } -impl<'a> Neg for &'a Field255 { +impl Neg for &Field255 { type Output = Field255; fn neg(self) -> Field255 { @@ -216,7 +216,7 @@ impl From for Field255 { } } -impl<'a> TryFrom<&'a [u8]> for Field255 { +impl TryFrom<&[u8]> for Field255 { type Error = FieldError; fn try_from(bytes: &[u8]) -> Result { From 4f151e246b297ff9bcd289d6b904f18344205ecd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:49:59 +0000 Subject: [PATCH 30/32] build(deps): Bump serde from 1.0.214 to 1.0.215 (#1144) --- Cargo.lock | 8 ++++---- supply-chain/config.toml | 9 --------- supply-chain/imports.lock | 35 +++++++++++++++++++++++++++++++---- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e599e8da..2fb40c08d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -871,18 +871,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", diff --git a/supply-chain/config.toml b/supply-chain/config.toml index bfc08a75a..10ff1741c 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -161,19 +161,10 @@ criteria = "safe-to-deploy" version = "0.8.5" criteria = "safe-to-deploy" -[[exemptions.rand_distr]] -version = "0.4.3" -criteria = "safe-to-run" - [[exemptions.safe_arch]] version = "0.7.0" criteria = "safe-to-run" -[[exemptions.sha2]] -version = "0.10.8" -criteria = "safe-to-deploy" -notes = "We do not use the new asm backend, either its feature or CPU architecture" - [[exemptions.simba]] version = "0.6.0" criteria = "safe-to-run" diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index ead58bc02..cf70d5ea4 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -107,15 +107,15 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde]] -version = "1.0.214" -when = "2024-10-28" +version = "1.0.215" +when = "2024-11-11" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_derive]] -version = "1.0.214" -when = "2024-10-28" +version = "1.0.215" +when = "2024-11-11" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" @@ -626,6 +626,16 @@ criteria = "safe-to-deploy" delta = "0.6.3 -> 0.6.4" aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" +[[audits.mozilla.audits.rand_distr]] +who = "Ben Dean-Kawamura " +criteria = "safe-to-deploy" +version = "0.4.3" +notes = """ +Simple crate that extends `rand`. It has little unsafe code and uses Miri to test it. +As far as I can tell, it does not have any file IO or network access. +""" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + [[audits.mozilla.audits.rayon]] who = "Josh Stone " criteria = "safe-to-deploy" @@ -639,6 +649,23 @@ criteria = "safe-to-deploy" delta = "1.5.3 -> 1.6.1" aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" +[[audits.mozilla.audits.sha2]] +who = "Mike Hommey " +criteria = "safe-to-deploy" +delta = "0.10.2 -> 0.10.6" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + +[[audits.mozilla.audits.sha2]] +who = "Jeff Muizelaar " +criteria = "safe-to-deploy" +delta = "0.10.6 -> 0.10.8" +notes = """ +The bulk of this is https://github.com/RustCrypto/hashes/pull/490 which adds aarch64 support along with another PR adding longson. +I didn't check the implementation thoroughly but there wasn't anything obviously nefarious. 0.10.8 has been out for more than a year +which suggests no one else has found anything either. +""" +aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml" + [[audits.mozilla.audits.subtle]] who = "Simon Friedberger " criteria = "safe-to-deploy" From db00be27a845d8f3c2f031c35dbf1c982804ad52 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:17:20 +0000 Subject: [PATCH 31/32] build(deps): Bump serde_json from 1.0.132 to 1.0.133 (#1143) --- Cargo.lock | 4 ++-- supply-chain/imports.lock | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2fb40c08d..2850a5d45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -891,9 +891,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index cf70d5ea4..36aaa0e2f 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -121,8 +121,8 @@ user-login = "dtolnay" user-name = "David Tolnay" [[publisher.serde_json]] -version = "1.0.132" -when = "2024-10-19" +version = "1.0.133" +when = "2024-11-17" user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" From 3c1aeb30c661d373566749a81589fc0a4045f89a Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Mon, 9 Dec 2024 16:09:40 +0100 Subject: [PATCH 32/32] Added `max_measurement` field to `Prio3Sum` type (#1150) --- benches/cycle_counts.rs | 3 +- benches/speed_tests.rs | 11 +- binaries/src/bin/vdaf_message_sizes.rs | 6 +- src/field.rs | 15 +++ src/field/field255.rs | 6 + src/flp/szk.rs | 12 +- src/flp/types.rs | 164 ++++++++++++++++++------- src/vdaf/mastic.rs | 14 ++- src/vdaf/prio3.rs | 58 ++++----- src/vdaf/prio3_test.rs | 5 +- 10 files changed, 203 insertions(+), 91 deletions(-) diff --git a/benches/cycle_counts.rs b/benches/cycle_counts.rs index 8b1b3c184..3e3ebdf57 100644 --- a/benches/cycle_counts.rs +++ b/benches/cycle_counts.rs @@ -126,7 +126,8 @@ fn prio3_client_histogram_10() -> Vec> { } fn prio3_client_sum_32() -> Vec> { - let prio3 = Prio3::new_sum(2, 16).unwrap(); + let bits = 16; + let prio3 = Prio3::new_sum(2, (1 << bits) - 1).unwrap(); let measurement = 1337; let nonce = [0; 16]; prio3 diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index 053cabb2e..94dd5d183 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -198,8 +198,10 @@ fn prio3(c: &mut Criterion) { let mut group = c.benchmark_group("prio3sum_shard"); for bits in [8, 32] { group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| { - let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); - let measurement = (1 << bits) - 1; + // Doesn't matter for speed what we use for max measurement, or measurement + let max_measurement = (1 << bits) - 1; + let vdaf = Prio3::new_sum(num_shares, max_measurement).unwrap(); + let measurement = max_measurement; let nonce = black_box([0u8; 16]); b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }); @@ -209,8 +211,9 @@ fn prio3(c: &mut Criterion) { let mut group = c.benchmark_group("prio3sum_prepare_init"); for bits in [8, 32] { group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| { - let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); - let measurement = (1 << bits) - 1; + let max_measurement = (1 << bits) - 1; + let vdaf = Prio3::new_sum(num_shares, max_measurement).unwrap(); + let measurement = max_measurement; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); diff --git a/binaries/src/bin/vdaf_message_sizes.rs b/binaries/src/bin/vdaf_message_sizes.rs index 998f15722..940be79a4 100644 --- a/binaries/src/bin/vdaf_message_sizes.rs +++ b/binaries/src/bin/vdaf_message_sizes.rs @@ -42,12 +42,12 @@ fn main() { ) ); - let bits = 32; - let prio3 = Prio3::new_sum(num_shares, bits).unwrap(); + let max_measurement = 0xffff_ffff; + let prio3 = Prio3::new_sum(num_shares, max_measurement).unwrap(); let measurement = 1337; println!( "prio3 sum ({} bits) share size = {}", - bits, + max_measurement.ilog2() + 1, vdaf_input_share_size::( prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() ) diff --git a/src/field.rs b/src/field.rs index 7b8460a4b..88bf40ff1 100644 --- a/src/field.rs +++ b/src/field.rs @@ -201,6 +201,9 @@ pub trait Integer: /// Returns one. fn one() -> Self; + + /// Returns ⌊log₂(self)⌋, or `None` if `self == 0` + fn checked_ilog2(&self) -> Option; } /// Extension trait for field elements that can be converted back and forth to an integer type. @@ -785,6 +788,10 @@ impl Integer for u32 { fn one() -> Self { 1 } + + fn checked_ilog2(&self) -> Option { + u32::checked_ilog2(*self) + } } impl Integer for u64 { @@ -798,6 +805,10 @@ impl Integer for u64 { fn one() -> Self { 1 } + + fn checked_ilog2(&self) -> Option { + u64::checked_ilog2(*self) + } } impl Integer for u128 { @@ -811,6 +822,10 @@ impl Integer for u128 { fn one() -> Self { 1 } + + fn checked_ilog2(&self) -> Option { + u128::checked_ilog2(*self) + } } make_field!( diff --git a/src/field/field255.rs b/src/field/field255.rs index 65fb443e2..8a3f74bda 100644 --- a/src/field/field255.rs +++ b/src/field/field255.rs @@ -388,6 +388,12 @@ mod tests { fn one() -> Self { Self::new(Vec::from([1])) } + + fn checked_ilog2(&self) -> Option { + // This is a test module, and this code is never used. If we need this in the future, + // use BigUint::bits() + unimplemented!() + } } impl TestFieldElementWithInteger for Field255 { diff --git a/src/flp/szk.rs b/src/flp/szk.rs index e25598d0f..4531d3bf9 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -794,8 +794,9 @@ mod tests { #[test] fn test_sum_proof_share_encode() { let mut nonce = [0u8; 16]; + let max_measurement = 13; thread_rng().fill(&mut nonce[..]); - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); let algorithm_id = 5; let szk_typ = Szk::new_turboshake128(sum, algorithm_id); @@ -896,9 +897,10 @@ mod tests { #[test] fn test_sum_leader_proof_share_roundtrip() { + let max_measurement = 13; let mut nonce = [0u8; 16]; thread_rng().fill(&mut nonce[..]); - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); let algorithm_id = 5; let szk_typ = Szk::new_turboshake128(sum, algorithm_id); @@ -936,9 +938,10 @@ mod tests { #[test] fn test_sum_helper_proof_share_roundtrip() { + let max_measurement = 13; let mut nonce = [0u8; 16]; thread_rng().fill(&mut nonce[..]); - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); let algorithm_id = 5; let szk_typ = Szk::new_turboshake128(sum, algorithm_id); @@ -1138,7 +1141,8 @@ mod tests { #[test] fn test_sum() { - let sum = Sum::::new(5).unwrap(); + let max_measurement = 13; + let sum = Sum::::new(max_measurement).unwrap(); let five = Field128::from(5); let nine = sum.encode_measurement(&9).unwrap(); diff --git a/src/flp/types.rs b/src/flp/types.rs index 9403039ef..2431af986 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -2,7 +2,7 @@ //! A collection of [`Type`] implementations. -use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt}; +use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt, Integer}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::{FlpError, Gadget, Type}; use crate::polynomial::poly_range_check; @@ -113,37 +113,57 @@ impl Type for Count { } } -/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of -/// the measurements. +/// The sum type. Each measurement is a integer in `[0, max_measurement]` and the aggregate is the +/// sum of the measurements. /// /// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3]. /// /// [BBCG+19]: https://ia.cr/2019/188 #[derive(Clone, PartialEq, Eq)] pub struct Sum { + max_measurement: F::Integer, + + // Computed from max_measurement + offset: F::Integer, bits: usize, - range_checker: Vec, + // Constant + bit_range_checker: Vec, } impl Debug for Sum { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Sum").field("bits", &self.bits).finish() + f.debug_struct("Sum") + .field("max_measurement", &self.max_measurement) + .field("bits", &self.bits) + .finish() } } impl Sum { /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0, - /// 2^bits)`. - pub fn new(bits: usize) -> Result { - if !F::valid_integer_bitlength(bits) { - return Err(FlpError::Encode( - "invalid bits: number of bits exceeds maximum number of bits in this field" - .to_string(), + /// max_measurement]` where `max_measurement > 0`. Errors if `max_measurement == 0`. + pub fn new(max_measurement: F::Integer) -> Result { + if max_measurement == F::Integer::zero() { + return Err(FlpError::InvalidParameter( + "max measurement cannot be zero".to_string(), )); } + + // Number of bits needed to represent x is ⌊log₂(x)⌋ + 1 + let bits = max_measurement.checked_ilog2().unwrap() as usize + 1; + + // The offset we add to the summand for range-checking purposes + let one = F::Integer::try_from(1).unwrap(); + let offset = (one << bits) - one - max_measurement; + + // Construct a range checker to ensure encoded bits are in the range [0, 2) + let bit_range_checker = poly_range_check(0, 2); + Ok(Self { bits, - range_checker: poly_range_check(0, 2), + max_measurement, + offset, + bit_range_checker, }) } } @@ -154,8 +174,17 @@ impl Type for Sum { type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { - let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); - Ok(v) + if summand > &self.max_measurement { + return Err(FlpError::Encode(format!( + "unexpected measurement: got {:?}; want ≤{:?}", + summand, self.max_measurement + ))); + } + + let enc_summand = F::encode_as_bitvector(*summand, self.bits)?; + let enc_summand_plus_offset = F::encode_as_bitvector(self.offset + *summand, self.bits)?; + + Ok(enc_summand.chain(enc_summand_plus_offset).collect()) } fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result { @@ -164,8 +193,8 @@ impl Type for Sum { fn gadget(&self) -> Vec>> { vec![Box::new(PolyEval::new( - self.range_checker.clone(), - self.bits, + self.bit_range_checker.clone(), + 2 * self.bits, ))] } @@ -178,25 +207,38 @@ impl Type for Sum { g: &mut Vec>>, input: &[F], joint_rand: &[F], - _num_shares: usize, + num_shares: usize, ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; let gadget = &mut g[0]; - input.iter().map(|&b| gadget.call(&[b])).collect() + let bit_checks = input + .iter() + .map(|&b| gadget.call(&[b])) + .collect::, _>>()?; + + let range_check = { + let offset = F::from(self.offset); + let shares_inv = F::from(F::valid_integer_try_from(num_shares)?).inv(); + let sum = F::decode_bitvector(&input[..self.bits])?; + let sum_plus_offset = F::decode_bitvector(&input[self.bits..])?; + offset * shares_inv + sum - sum_plus_offset + }; + + Ok([bit_checks.as_slice(), &[range_check]].concat()) } fn truncate(&self, input: Vec) -> Result, FlpError> { self.truncate_call_check(&input)?; - let res = F::decode_bitvector(&input)?; + let res = F::decode_bitvector(&input[..self.bits])?; Ok(vec![res]) } fn input_len(&self) -> usize { - self.bits + 2 * self.bits } fn proof_len(&self) -> usize { - 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + 2 * ((1 + 2 * self.bits).next_power_of_two() - 1) + 2 } fn verifier_len(&self) -> usize { @@ -212,7 +254,7 @@ impl Type for Sum { } fn eval_output_len(&self) -> usize { - self.bits + 2 * self.bits + 1 } fn prove_rand_len(&self) -> usize { @@ -220,8 +262,8 @@ impl Type for Sum { } } -/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the -/// aggregate is the arithmetic average. +/// The average type. Each measurement is an integer in `[0, max_measurement]` and the aggregate is +/// the arithmetic average of the measurements. // This is just a `Sum` object under the hood. The only difference is that the aggregate result is // an f64, which we get by dividing by `num_measurements` #[derive(Clone, PartialEq, Eq)] @@ -232,6 +274,7 @@ pub struct Average { impl Debug for Average { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Average") + .field("max_measurement", &self.summer.max_measurement) .field("bits", &self.summer.bits) .finish() } @@ -239,9 +282,9 @@ impl Debug for Average { impl Average { /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, - /// 2^bits)`. - pub fn new(bits: usize) -> Result { - let summer = Sum::new(bits)?; + /// max_measurement]` where `max_measurement > 0`. Errors if `max_measurement == 0`. + pub fn new(max_measurement: F::Integer) -> Result { + let summer = Sum::new(max_measurement)?; Ok(Average { summer }) } } @@ -288,7 +331,7 @@ impl Type for Average { } fn input_len(&self) -> usize { - self.summer.bits + self.summer.input_len() } fn proof_len(&self) -> usize { @@ -592,20 +635,19 @@ where } // Convert bool vector to field elems - let multihot_vec: Vec = measurement + let multihot_vec = measurement .iter() // We can unwrap because any Integer type can cast from bool - .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())) - .collect(); + .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())); // Encode the measurement weight in binary (actually, the weight plus some offset) let offset_weight_bits = { let offset_weight_reported = F::valid_integer_try_from(self.offset + weight_reported)?; - F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)?.collect() + F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)? }; // Report the concat of the two - Ok([multihot_vec, offset_weight_bits].concat()) + Ok(multihot_vec.chain(offset_weight_bits).collect()) } fn decode_result( @@ -1024,7 +1066,9 @@ mod tests { #[test] fn test_sum() { - let sum = Sum::new(11).unwrap(); + let max_measurement = 1458; + + let sum = Sum::new(max_measurement).unwrap(); let zero = TestField::zero(); let one = TestField::one(); let nine = TestField::from(9); @@ -1045,22 +1089,52 @@ mod tests { &sum.encode_measurement(&1337).unwrap(), &[TestField::from(1337)], ); - FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]); - FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]); - FlpTest::expect_valid::<3>( - &Sum::new(9).unwrap(), - &[one, zero, one, one, zero, one, one, one, zero], - &[TestField::from(237)], - ); - // Test FLP on invalid input. - FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]); - FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]); + { + let sum = Sum::new(3).unwrap(); + let meas = 1; + FlpTest::expect_valid::<3>( + &sum, + &sum.encode_measurement(&meas).unwrap(), + &[TestField::from(meas)], + ); + } + + { + let sum = Sum::new(400).unwrap(); + let meas = 237; + FlpTest::expect_valid::<3>( + &sum, + &sum.encode_measurement(&meas).unwrap(), + &[TestField::from(meas)], + ); + } + + // Test FLP on invalid input, specifically on field elements outside of {0,1} + { + let sum = Sum::new((1 << 3) - 1).unwrap(); + // The sum+offset value can be whatever. The binariness test should fail first + let sum_plus_offset = vec![zero; 3]; + FlpTest::expect_invalid::<3>( + &sum, + &[&[one, nine, zero], sum_plus_offset.as_slice()].concat(), + ); + } + { + let sum = Sum::new((1 << 5) - 1).unwrap(); + let sum_plus_offset = vec![zero; 5]; + FlpTest::expect_invalid::<3>( + &sum, + &[&[zero, zero, zero, zero, nine], sum_plus_offset.as_slice()].concat(), + ); + } } #[test] fn test_average() { - let average = Average::new(11).unwrap(); + let max_measurement = (1 << 11) - 13; + + let average = Average::new(max_measurement).unwrap(); let zero = TestField::zero(); let one = TestField::one(); let ten = TestField::from(10); diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index afbac9331..6e3426b5f 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -394,9 +394,12 @@ mod tests { #[test] fn test_mastic_shard_sum() { let algorithm_id = 6; - let sum_typ = Sum::::new(5).unwrap(); + let max_measurement = 29; + let sum_typ = Sum::::new(max_measurement).unwrap(); + let encoded_meas_len = sum_typ.input_len(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); - let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(encoded_meas_len); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -414,9 +417,12 @@ mod tests { #[test] fn test_input_share_encode_sum() { let algorithm_id = 6; - let sum_typ = Sum::::new(5).unwrap(); + let max_measurement = 29; + let sum_typ = Sum::::new(max_measurement).unwrap(); + let encoded_meas_len = sum_typ.input_len(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); - let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(encoded_meas_len); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 3936730ec..840872e76 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -33,7 +33,9 @@ use super::AggregatorWithNoise; use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; -use crate::field::{decode_fieldvec, FftFriendlyFieldElement, FieldElement}; +use crate::field::{ + decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, +}; use crate::field::{Field128, Field64}; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; @@ -141,16 +143,13 @@ impl Prio3SumVecMultithreaded { pub type Prio3Sum = Prio3, XofTurboShake128, 16>; impl Prio3Sum { - /// Construct an instance of Prio3Sum with the given number of aggregators and required bit - /// length. The bit length must not exceed 64. - pub fn new_sum(num_aggregators: u8, bits: usize) -> Result { - if bits > 64 { - return Err(VdafError::Uncategorized(format!( - "bit length ({bits}) exceeds limit for aggregate type (64)" - ))); - } - - Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(bits)?) + /// Construct an instance of `Prio3Sum` with the given number of aggregators, where each summand + /// must be in the range `[0, max_measurement]`. Errors if `max_measurement == 0`. + pub fn new_sum( + num_aggregators: u8, + max_measurement: ::Integer, + ) -> Result { + Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(max_measurement)?) } } @@ -340,22 +339,19 @@ impl Prio3MultihotCountVecMultithreaded { pub type Prio3Average = Prio3, XofTurboShake128, 16>; impl Prio3Average { - /// Construct an instance of Prio3Average with the given number of aggregators and required bit - /// length. The bit length must not exceed 64. - pub fn new_average(num_aggregators: u8, bits: usize) -> Result { + /// Construct an instance of `Prio3Average` with the given number of aggregators, where each + /// summand must be in the range `[0, max_measurement]`. Errors if `max_measurement == 0`. + pub fn new_average( + num_aggregators: u8, + max_measurement: ::Integer, + ) -> Result { check_num_aggregators(num_aggregators)?; - if bits > 64 { - return Err(VdafError::Uncategorized(format!( - "bit length ({bits}) exceeds limit for aggregate type (64)" - ))); - } - Ok(Prio3 { num_aggregators, num_proofs: 1, algorithm_id: 0xFFFF0000, - typ: Average::new(bits)?, + typ: Average::new(max_measurement)?, phantom: PhantomData, }) } @@ -1700,11 +1696,13 @@ mod tests { #[test] fn test_prio3_sum() { - let prio3 = Prio3::new_sum(3, 16).unwrap(); + let max_measurement = 35_891; + + let prio3 = Prio3::new_sum(3, max_measurement).unwrap(); assert_eq!( - run_vdaf(CTX_STR, &prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), - (1 << 16) + 1 + run_vdaf(CTX_STR, &prio3, &(), [0, max_measurement, 0, 1, 1]).unwrap(), + max_measurement + 2, ); let mut verify_key = [0; 16]; @@ -2082,7 +2080,8 @@ mod tests { #[test] fn test_prio3_average() { - let prio3 = Prio3::new_average(2, 64).unwrap(); + let max_measurement = 43_208; + let prio3 = Prio3::new_average(2, max_measurement).unwrap(); assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [17, 8]).unwrap(), 12.5f64); assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); @@ -2098,7 +2097,8 @@ mod tests { #[test] fn test_prio3_input_share() { - let prio3 = Prio3::new_sum(5, 16).unwrap(); + let max_measurement = 1; + let prio3 = Prio3::new_sum(5, max_measurement).unwrap(); let (_public_share, input_shares) = prio3.shard(CTX_STR, &1, &[0; 16]).unwrap(); // Check that seed shares are distinct. @@ -2217,7 +2217,8 @@ mod tests { let vdaf = Prio3::new_count(2).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); - let vdaf = Prio3::new_sum(2, 17).unwrap(); + let max_measurement = 13; + let vdaf = Prio3::new_sum(2, max_measurement).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); let vdaf = Prio3::new_histogram(2, 12, 3).unwrap(); @@ -2229,7 +2230,8 @@ mod tests { let vdaf = Prio3::new_count(2).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); - let vdaf = Prio3::new_sum(2, 17).unwrap(); + let max_measurement = 13; + let vdaf = Prio3::new_sum(2, max_measurement).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); let vdaf = Prio3::new_histogram(2, 12, 3).unwrap(); diff --git a/src/vdaf/prio3_test.rs b/src/vdaf/prio3_test.rs index 10b72c739..56223bcb9 100644 --- a/src/vdaf/prio3_test.rs +++ b/src/vdaf/prio3_test.rs @@ -285,13 +285,14 @@ mod tests { #[ignore] #[test] fn test_vec_prio3_sum() { + const FAKE_MAX_MEASUREMENT_UPDATE_ME: u128 = 0; for test_vector_str in [ include_str!("test_vec/08/Prio3Sum_0.json"), include_str!("test_vec/08/Prio3Sum_1.json"), ] { check_test_vec(test_vector_str, |json_params, num_shares| { - let bits = json_params["bits"].as_u64().unwrap() as usize; - Prio3::new_sum(num_shares, bits).unwrap() + let _bits = json_params["bits"].as_u64().unwrap() as usize; + Prio3::new_sum(num_shares, FAKE_MAX_MEASUREMENT_UPDATE_ME).unwrap() }); } }