Skip to content

Commit

Permalink
Introduce new trait for integers in field module (#853)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Dec 4, 2023
1 parent 3bbeb52 commit 711aec2
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 84 deletions.
180 changes: 106 additions & 74 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,33 +166,44 @@ pub trait FieldElement:
}
}

/// An integer type that accompanies a finite field. Integers and field elements may be converted
/// back and forth via the natural map between residue classes modulo 'p' and integers between 0
/// and p - 1.
pub trait Integer:
Debug
+ Eq
+ Ord
+ BitAnd<Output = Self>
+ Div<Output = Self>
+ Shl<usize, Output = Self>
+ Shr<usize, Output = Self>
+ Add<Output = Self>
+ Sub<Output = Self>
+ TryFrom<usize, Error = Self::TryFromUsizeError>
+ TryInto<u64, Error = Self::TryIntoU64Error>
{
/// The error returned if converting `usize` to this integer type fails.
type TryFromUsizeError: std::error::Error;

/// The error returned if converting this integer type to a `u64` fails.
type TryIntoU64Error: std::error::Error;

/// Returns zero.
fn zero() -> Self;

/// Returns one.
fn one() -> Self;
}

/// Extension trait for field elements that can be converted back and forth to an integer type.
///
/// The `Integer` associated type is an integer (primitive or otherwise) that supports various
/// arithmetic operations. The order of the field is guaranteed to fit inside the range of the
/// integer type. This trait also defines methods on field elements, `pow` and `modulus`, that make
/// use of the associated integer type.
pub trait FieldElementWithInteger: FieldElement + From<Self::Integer> {
/// The error returned if converting `usize` to an `Integer` fails.
type IntegerTryFromError: std::error::Error;

/// The error returned if converting an `Integer` to a `u64` fails.
type TryIntoU64Error: std::error::Error;

/// The integer representation of a field element.
type Integer: Copy
+ Debug
+ Eq
+ Ord
+ BitAnd<Output = Self::Integer>
+ Div<Output = Self::Integer>
+ Shl<usize, Output = Self::Integer>
+ Shr<usize, Output = Self::Integer>
+ Add<Output = Self::Integer>
+ Sub<Output = Self::Integer>
+ From<Self>
+ TryFrom<usize, Error = Self::IntegerTryFromError>
+ TryInto<u64, Error = Self::TryIntoU64Error>;
type Integer: Integer + From<Self> + Copy;

/// Modular exponentation, i.e., `self^exp (mod p)`.
fn pow(&self, exp: Self::Integer) -> Self;
Expand All @@ -216,7 +227,7 @@ pub trait FieldElementWithInteger: FieldElement + From<Self::Integer> {
// Check if the input value can be represented in the requested number of bits by shifting
// it. The above check on `bits` ensures this shift won't panic due to the shift width
// being too large.
if input >> bits != Self::zero_integer() {
if input >> bits != Self::Integer::zero() {
return Err(FieldError::InputSizeMismatch);
}

Expand Down Expand Up @@ -274,7 +285,7 @@ where
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let bit_offset = self.inner.next()?;
Some(F::from((self.input >> bit_offset) & F::one_integer()))
Some(F::from((self.input >> bit_offset) & F::Integer::one()))
}
}

Expand All @@ -299,28 +310,14 @@ pub(crate) trait FieldElementWithIntegerExt: FieldElementWithInteger {
if bits >= 8 * Self::ENCODED_SIZE {
return false;
}
if Self::modulus() >> bits != Self::zero_integer() {
if Self::modulus() >> bits != Self::Integer::zero() {
return true;
}
false
}

/// Returns the integer representation of the additive identity.
fn zero_integer() -> Self::Integer;

/// Returns the integer representation of the multiplicative identity.
fn one_integer() -> Self::Integer;
}

impl<F: FieldElementWithInteger> FieldElementWithIntegerExt for F {
fn zero_integer() -> Self::Integer {
0usize.try_into().unwrap()
}

fn one_integer() -> Self::Integer {
1usize.try_into().unwrap()
}
}
impl<F: FieldElementWithInteger> FieldElementWithIntegerExt for F {}

/// Methods common to all `FieldElement` implementations that are private to the crate.
pub(crate) trait FieldElementExt: FieldElement {
Expand Down Expand Up @@ -711,8 +708,6 @@ macro_rules! make_field {

impl FieldElementWithInteger for $elem {
type Integer = $int;
type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;

fn pow(&self, exp: Self::Integer) -> Self {
// FieldParameters::pow() relies on mul(), and will always return a value less
Expand Down Expand Up @@ -745,6 +740,45 @@ macro_rules! make_field {
};
}

impl Integer for u32 {
type TryFromUsizeError = <Self as TryFrom<usize>>::Error;
type TryIntoU64Error = <Self as TryInto<u64>>::Error;

fn zero() -> Self {
0
}

fn one() -> Self {
1
}
}

impl Integer for u64 {
type TryFromUsizeError = <Self as TryFrom<usize>>::Error;
type TryIntoU64Error = <Self as TryInto<u64>>::Error;

fn zero() -> Self {
0
}

fn one() -> Self {
1
}
}

impl Integer for u128 {
type TryFromUsizeError = <Self as TryFrom<usize>>::Error;
type TryIntoU64Error = <Self as TryInto<u64>>::Error;

fn zero() -> Self {
0
}

fn one() -> Self {
1
}
}

make_field!(
/// Same as Field32, but encoded in little endian for compatibility with Prio v2.
FieldPrio2,
Expand Down Expand Up @@ -850,16 +884,14 @@ pub(crate) fn decode_fieldvec<F: FieldElement>(

#[cfg(test)]
pub(crate) mod test_utils {
use super::{FieldElement, FieldElementWithInteger};
use super::{FieldElement, FieldElementWithInteger, Integer};
use crate::{codec::CodecError, field::FieldError, prng::Prng};
use assert_matches::assert_matches;
use std::{
collections::hash_map::DefaultHasher,
convert::{TryFrom, TryInto},
fmt::Debug,
convert::TryFrom,
hash::{Hash, Hasher},
io::Cursor,
ops::{Add, BitAnd, Div, Shl, Shr, Sub},
};

/// A test-only copy of `FieldElementWithInteger`.
Expand All @@ -870,54 +902,42 @@ pub(crate) mod test_utils {
/// requires the `Integer` associated type satisfy `Clone`, not `Copy`, so that it may be used
/// with arbitrary precision integer implementations.
pub(crate) trait TestFieldElementWithInteger:
FieldElement + From<Self::Integer>
FieldElement + From<Self::TestInteger>
{
type IntegerTryFromError: std::error::Error;
type TryIntoU64Error: std::error::Error;
type Integer: Clone
+ Debug
+ Eq
+ Ord
+ BitAnd<Output = Self::Integer>
+ Div<Output = Self::Integer>
+ Shl<usize, Output = Self::Integer>
+ Shr<usize, Output = Self::Integer>
+ Add<Output = Self::Integer>
+ Sub<Output = Self::Integer>
+ From<Self>
+ TryFrom<usize, Error = Self::IntegerTryFromError>
+ TryInto<u64, Error = Self::TryIntoU64Error>;

fn pow(&self, exp: Self::Integer) -> Self;

fn modulus() -> Self::Integer;
type TestInteger: Integer + From<Self> + Clone;

fn pow(&self, exp: Self::TestInteger) -> Self;

fn modulus() -> Self::TestInteger;
}

impl<F> TestFieldElementWithInteger for F
where
F: FieldElementWithInteger,
{
type IntegerTryFromError = <F as FieldElementWithInteger>::IntegerTryFromError;
type TryIntoU64Error = <F as FieldElementWithInteger>::TryIntoU64Error;
type Integer = <F as FieldElementWithInteger>::Integer;
type IntegerTryFromError = <F::Integer as Integer>::TryFromUsizeError;
type TryIntoU64Error = <F::Integer as Integer>::TryIntoU64Error;
type TestInteger = F::Integer;

fn pow(&self, exp: Self::Integer) -> Self {
fn pow(&self, exp: Self::TestInteger) -> Self {
<F as FieldElementWithInteger>::pow(self, exp)
}

fn modulus() -> Self::Integer {
fn modulus() -> Self::TestInteger {
<F as FieldElementWithInteger>::modulus()
}
}

pub(crate) fn field_element_test_common<F: TestFieldElementWithInteger>() {
let mut prng: Prng<F, _> = Prng::new().unwrap();
let int_modulus = F::modulus();
let int_one = F::Integer::try_from(1).unwrap();
let int_one = F::TestInteger::try_from(1).unwrap();
let zero = F::zero();
let one = F::one();
let two = F::from(F::Integer::try_from(2).unwrap());
let four = F::from(F::Integer::try_from(4).unwrap());
let two = F::from(F::TestInteger::try_from(2).unwrap());
let four = F::from(F::TestInteger::try_from(4).unwrap());

// add
assert_eq!(F::from(int_modulus.clone() - int_one.clone()) + one, zero);
Expand Down Expand Up @@ -971,10 +991,22 @@ pub(crate) mod test_utils {
assert_eq!(a, c);

// integer conversion
assert_eq!(F::Integer::from(zero), F::Integer::try_from(0).unwrap());
assert_eq!(F::Integer::from(one), F::Integer::try_from(1).unwrap());
assert_eq!(F::Integer::from(two), F::Integer::try_from(2).unwrap());
assert_eq!(F::Integer::from(four), F::Integer::try_from(4).unwrap());
assert_eq!(
F::TestInteger::from(zero),
F::TestInteger::try_from(0).unwrap()
);
assert_eq!(
F::TestInteger::from(one),
F::TestInteger::try_from(1).unwrap()
);
assert_eq!(
F::TestInteger::from(two),
F::TestInteger::try_from(2).unwrap()
);
assert_eq!(
F::TestInteger::from(four),
F::TestInteger::try_from(4).unwrap()
);

// serialization
let test_inputs = vec![
Expand Down Expand Up @@ -1032,7 +1064,7 @@ pub(crate) mod test_utils {
// various products that should be equal have the same hash. Three is chosen as a generator
// here because it happens to generate fairly large subgroups of (Z/pZ)* for all four
// primes.
let three = F::from(F::Integer::try_from(3).unwrap());
let three = F::from(F::TestInteger::try_from(3).unwrap());
let mut powers_of_three = Vec::with_capacity(500);
let mut power = one;
for _ in 0..500 {
Expand Down
26 changes: 20 additions & 6 deletions src/field/field255.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ mod tests {
codec::Encode,
field::{
test_utils::{field_element_test_common, TestFieldElementWithInteger},
FieldElement, FieldError,
FieldElement, FieldError, Integer,
},
};
use assert_matches::assert_matches;
Expand Down Expand Up @@ -375,16 +375,30 @@ mod tests {
}
}

impl Integer for BigUint {
type TryFromUsizeError = <Self as TryFrom<usize>>::Error;

type TryIntoU64Error = <Self as TryInto<u64>>::Error;

fn zero() -> Self {
Self::new(Vec::new())
}

fn one() -> Self {
Self::new(Vec::from([1]))
}
}

impl TestFieldElementWithInteger for Field255 {
type Integer = BigUint;
type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;
type TestInteger = BigUint;
type IntegerTryFromError = <Self::TestInteger as TryFrom<usize>>::Error;
type TryIntoU64Error = <Self::TestInteger as TryInto<u64>>::Error;

fn pow(&self, _exp: Self::Integer) -> Self {
fn pow(&self, _exp: Self::TestInteger) -> Self {
unimplemented!("Field255::pow() is not implemented because it's not needed yet")
}

fn modulus() -> Self::Integer {
fn modulus() -> Self::TestInteger {
MODULUS.clone()
}
}
Expand Down
10 changes: 6 additions & 4 deletions src/flp/types/fixedpoint_l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,17 @@
pub mod compatible_float;

use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError};
use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt};
use crate::field::{
Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt, Integer,
};
use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat;
use crate::flp::types::parallel_sum_range_checks;
use crate::flp::{FlpError, Gadget, Type, TypeWithNoise};
use crate::vdaf::xof::SeedStreamTurboShake128;
use fixed::traits::Fixed;
use num_bigint::{BigInt, BigUint, TryFromBigIntError};
use num_integer::Integer;
use num_integer::Integer as _;
use num_rational::Ratio;
use rand::{distributions::Distribution, Rng};
use rand_core::SeedableRng;
Expand Down Expand Up @@ -250,7 +252,7 @@ where
/// fixed point vector with `entries` entries.
pub fn new(entries: usize) -> Result<Self, FlpError> {
// (0) initialize constants
let fi_one = Field128::one_integer();
let fi_one = <Field128 as FieldElementWithInteger>::Integer::one();

// (I) Check that the fixed type is compatible.
//
Expand Down Expand Up @@ -537,7 +539,7 @@ where

// Chunks which are too short need to be extended with a share of the
// encoded zero value, that is: 1/num_shares * (2^(n-1))
let fi_one = Field128::one_integer();
let fi_one = <Field128 as FieldElementWithInteger>::Integer::one();
let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1));
let zero_enc_share = zero_enc * num_shares_inverse;

Expand Down

0 comments on commit 711aec2

Please sign in to comment.