diff --git a/Cargo.toml b/Cargo.toml index a88e6985b..3c288beeb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ num-bigint = { version = "0.4.6", optional = true, features = ["rand", "serde"] num-integer = { version = "0.1.46", optional = true } num-iter = { version = "0.1.45", optional = true } num-rational = { version = "0.4.2", optional = true, features = ["serde"] } -num-traits = { version = "0.2.19", optional = true } +num-traits = "0.2.19" rand = "0.8" rand_core = "0.6.4" rayon = { version = "1.10.0", optional = true } @@ -52,7 +52,7 @@ statrs = "0.17.1" [features] default = ["crypto-dependencies"] -experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational", "num-traits", "num-integer", "num-iter"] +experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational", "num-integer", "num-iter"] multithreaded = ["rayon"] crypto-dependencies = ["aes", "ctr", "hmac", "sha2"] test-util = ["hex", "serde_json", "zipf"] diff --git a/documentation/field_parameters.sage b/documentation/field_parameters.sage index 328984d57..377729bca 100755 --- a/documentation/field_parameters.sage +++ b/documentation/field_parameters.sage @@ -99,21 +99,27 @@ class Field: for i in range(min(self.num_roots, 20) + 1) ] + def log2_base(self): + """ + Returns log2(r), where r is the base used for multiprecision arithmetic. + """ + return log(self.r, 2) + + def log2_radix(self): + """ + Returns log2(R), where R is the machine word-friendly modulus + used in the Montgomery representation. 
+ """ + return log(self.R, 2) + FIELDS = [ Field( - "FieldPrio2, u128", + "FieldPrio2, u32", 2 ^ 20 * 4095 + 1, 3925978153, - 2 ^ 64, - 2 ^ 128, - ), - Field( - "Field64, u128", - 2 ^ 32 * 4294967295 + 1, - pow(7, 4294967295, 2 ^ 32 * 4294967295 + 1), - 2 ^ 64, - 2 ^ 128, + 2 ^ 32, + 2 ^ 32, ), Field( "Field64, u64", @@ -140,4 +146,6 @@ for field in FIELDS: print(f"bit_mask: {field.bit_mask()}") print("roots:") pprint.pprint(field.roots()) + print(f"log2_base: {field.log2_base()}") + print(f"log2_radix: {field.log2_radix()}") print() diff --git a/src/field.rs b/src/field.rs index 3c329a455..7b8460a4b 100644 --- a/src/field.rs +++ b/src/field.rs @@ -9,8 +9,7 @@ use crate::{ codec::{CodecError, Decode, Encode}, - fp::{FP128, FP32}, - fp64::FP64, + fp::{FieldOps, FieldParameters, FP128, FP32, FP64}, prng::{Prng, PrngError}, }; use rand::{ @@ -465,12 +464,12 @@ macro_rules! make_field { int &= mask; - if int >= $fp.p { + if int >= $fp::PRIME { return Err(FieldError::ModulusOverflow); } // FieldParameters::montgomery() will return a value that has been fully reduced // mod p, satisfying the invariant on Self. - Ok(Self($fp.montgomery(int))) + Ok(Self($fp::montgomery(int))) } } @@ -481,8 +480,8 @@ macro_rules! make_field { // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq // Check the invariant that the integer representation is fully reduced. - debug_assert!(self.0 < $fp.p); - debug_assert!(rhs.0 < $fp.p); + debug_assert!(self.0 < $fp::PRIME); + debug_assert!(rhs.0 < $fp::PRIME); self.0 == rhs.0 } @@ -507,7 +506,7 @@ macro_rules! make_field { // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq // Check the invariant that the integer representation is fully reduced. - debug_assert!(self.0 < $fp.p); + debug_assert!(self.0 < $fp::PRIME); self.0.hash(state); } @@ -520,7 +519,7 @@ macro_rules! 
make_field { fn add(self, rhs: Self) -> Self { // FieldParameters::add() returns a value that has been fully reduced // mod p, satisfying the invariant on Self. - Self($fp.add(self.0, rhs.0)) + Self($fp::add(self.0, rhs.0)) } } @@ -542,7 +541,7 @@ macro_rules! make_field { fn sub(self, rhs: Self) -> Self { // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub() // returns a value less than p, satisfying the invariant on Self. - Self($fp.sub(self.0, rhs.0)) + Self($fp::sub(self.0, rhs.0)) } } @@ -564,7 +563,7 @@ macro_rules! make_field { fn mul(self, rhs: Self) -> Self { // FieldParameters::mul() always returns a value less than p, so the invariant on // Self is satisfied. - Self($fp.mul(self.0, rhs.0)) + Self($fp::mul(self.0, rhs.0)) } } @@ -607,7 +606,7 @@ macro_rules! make_field { fn neg(self) -> Self { // FieldParameters::neg() will return a value less than p because self.0 is less // than p, and neg() dispatches to sub(). - Self($fp.neg(self.0)) + Self($fp::neg(self.0)) } } @@ -622,19 +621,19 @@ macro_rules! make_field { fn from(x: $int_conversion) -> Self { // FieldParameters::montgomery() will return a value that has been fully reduced // mod p, satisfying the invariant on Self. - Self($fp.montgomery($int_internal::try_from(x).unwrap())) + Self($fp::montgomery($int_internal::try_from(x).unwrap())) } } impl From<$elem> for $int_conversion { fn from(x: $elem) -> Self { - $int_conversion::try_from($fp.residue(x.0)).unwrap() + $int_conversion::try_from($fp::residue(x.0)).unwrap() } } impl PartialEq<$int_conversion> for $elem { fn eq(&self, rhs: &$int_conversion) -> bool { - $fp.residue(self.0) == $int_internal::try_from(*rhs).unwrap() + $fp::residue(self.0) == $int_internal::try_from(*rhs).unwrap() } } @@ -648,7 +647,7 @@ macro_rules! 
make_field { impl From<$elem> for [u8; $elem::ENCODED_SIZE] { fn from(elem: $elem) -> Self { - let int = $fp.residue(elem.0); + let int = $fp::residue(elem.0); let mut slice = [0; $elem::ENCODED_SIZE]; for i in 0..$elem::ENCODED_SIZE { slice[i] = ((int >> (i << 3)) & 0xff) as u8; @@ -665,13 +664,13 @@ macro_rules! make_field { impl Display for $elem { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!(f, "{}", $fp.residue(self.0)) + write!(f, "{}", $fp::residue(self.0)) } } impl Debug for $elem { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", $fp.residue(self.0)) + write!(f, "{}", $fp::residue(self.0)) } } @@ -719,11 +718,11 @@ macro_rules! make_field { fn inv(&self) -> Self { // FieldParameters::inv() ultimately relies on mul(), and will always return a // value less than p. - Self($fp.inv(self.0)) + Self($fp::inv(self.0)) } fn try_from_random(bytes: &[u8]) -> Result { - $elem::try_from_bytes(bytes, $fp.bit_mask) + $elem::try_from_bytes(bytes, $fp::BIT_MASK) } fn zero() -> Self { @@ -731,7 +730,7 @@ macro_rules! make_field { } fn one() -> Self { - Self($fp.roots[0]) + Self($fp::ROOTS[0]) } } @@ -741,26 +740,26 @@ macro_rules! make_field { fn pow(&self, exp: Self::Integer) -> Self { // FieldParameters::pow() relies on mul(), and will always return a value less // than p. 
- Self($fp.pow(self.0, $int_internal::try_from(exp).unwrap())) + Self($fp::pow(self.0, $int_internal::try_from(exp).unwrap())) } fn modulus() -> Self::Integer { - $fp.p as $int_conversion + $fp::PRIME as $int_conversion } } impl FftFriendlyFieldElement for $elem { fn generator() -> Self { - Self($fp.g) + Self($fp::G) } fn generator_order() -> Self::Integer { - 1 << (Self::Integer::try_from($fp.num_roots).unwrap()) + 1 << (Self::Integer::try_from($fp::NUM_ROOTS).unwrap()) } fn root(l: usize) -> Option { - if l < min($fp.roots.len(), $fp.num_roots+1) { - Some(Self($fp.roots[l])) + if l < min($fp::ROOTS.len(), $fp::NUM_ROOTS+1) { + Some(Self($fp::ROOTS[l])) } else { None } @@ -815,9 +814,9 @@ impl Integer for u128 { } make_field!( - /// Same as Field32, but encoded in little endian for compatibility with Prio v2. + /// `GF(4293918721)`, a 32-bit field. FieldPrio2, - u128, + u32, u32, FP32, 4, diff --git a/src/fp.rs b/src/fp.rs index da9fe0fc9..956ba63c6 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -2,305 +2,96 @@ //! Finite field arithmetic for any field GF(p) for which p < 2^128. +#[macro_use] +mod ops; + +pub use ops::{FieldOps, FieldParameters}; + /// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots /// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This /// is the largest input size we would ever need for the cryptographic applications in this crate. pub(crate) const MAX_ROOTS: usize = 20; -/// This structure represents the parameters of a finite field GF(p) for which p < 2^128. -#[derive(Debug, PartialEq, Eq)] -pub(crate) struct FieldParameters { - /// The prime modulus `p`. - pub p: u128, - /// `mu = -p^(-1) mod 2^64`. - pub mu: u64, - /// `r2 = (2^128)^2 mod p`. - pub r2: u128, - /// The `2^num_roots`-th -principal root of unity. This element is used to generate the - /// elements of `roots`. - pub g: u128, - /// The number of principal roots of unity in `roots`. 
- pub num_roots: usize, - /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. - pub bit_mask: u128, - /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the - /// multiplicative group. `roots[0]` is equal to one by definition. - pub roots: [u128; MAX_ROOTS + 1], -} - -impl FieldParameters { - /// Addition. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn add(&self, x: u128, y: u128) -> u128 { - // 0,x - // + 0,y - // ===== - // c,z - let (z, carry) = x.overflowing_add(y); - // c, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = z.overflowing_sub(self.p); - let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128); - // if b1 == 1: return z - // else: return s0 - let m = 0u128.wrapping_sub(b1 as u128); - (z & m) | (s0 & !m) - } - - /// Subtraction. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn sub(&self, x: u128, y: u128) -> u128 { - // x - // - y - // ======== - // b0,z0 - let (z0, b0) = x.overflowing_sub(y); - let m = 0u128.wrapping_sub(b0 as u128); - // z0 - // + p - // ======== - // s1,s0 - z0.wrapping_add(m & self.p) - // if b1 == 1: return s0 - // else: return z0 - } - - /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm - /// described [here][montgomery]. The result will be in [0, p). 
- /// - /// # Example usage - /// ```text - /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46); - /// ``` - /// - /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf - #[inline(always)] - pub fn mul(&self, x: u128, y: u128) -> u128 { - let x = [lo64(x), hi64(x)]; - let y = [lo64(y), hi64(y)]; - let p = [lo64(self.p), hi64(self.p)]; - let mut zz = [0; 4]; - - // Integer multiplication - // z = x * y - - // x1,x0 - // * y1,y0 - // =========== - // z3,z2,z1,z0 - let mut result = x[0] * y[0]; - let mut carry = hi64(result); - zz[0] = lo64(result); - result = x[0] * y[1]; - let mut hi = hi64(result); - let mut lo = lo64(result); - result = lo + carry; - zz[1] = lo64(result); - let mut cc = hi64(result); - result = hi + cc; - zz[2] = lo64(result); - - result = x[1] * y[0]; - hi = hi64(result); - lo = lo64(result); - result = zz[1] + lo; - zz[1] = lo64(result); - cc = hi64(result); - result = hi + cc; - carry = lo64(result); - - result = x[1] * y[1]; - hi = hi64(result); - lo = lo64(result); - result = lo + carry; - lo = lo64(result); - cc = hi64(result); - result = hi + cc; - hi = lo64(result); - result = zz[2] + lo; - zz[2] = lo64(result); - cc = hi64(result); - result = hi + cc; - zz[3] = lo64(result); - - // Montgomery Reduction - // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. 
- - // z3,z2,z1,z0 - // + p1,p0 - // * w = mu*z0 - // =========== - // z3,z2,z1, 0 - let w = self.mu.wrapping_mul(zz[0] as u64); - result = p[0] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = zz[0] + lo; - zz[0] = lo64(result); - cc = hi64(result); - result = hi + cc; - carry = lo64(result); - - result = p[1] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = lo + carry; - lo = lo64(result); - cc = hi64(result); - result = hi + cc; - hi = lo64(result); - result = zz[1] + lo; - zz[1] = lo64(result); - cc = hi64(result); - result = zz[2] + hi + cc; - zz[2] = lo64(result); - cc = hi64(result); - result = zz[3] + cc; - zz[3] = lo64(result); - - // z3,z2,z1 - // + p1,p0 - // * w = mu*z1 - // =========== - // z3,z2, 0 - let w = self.mu.wrapping_mul(zz[1] as u64); - result = p[0] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = zz[1] + lo; - zz[1] = lo64(result); - cc = hi64(result); - result = hi + cc; - carry = lo64(result); - - result = p[1] * (w as u128); - hi = hi64(result); - lo = lo64(result); - result = lo + carry; - lo = lo64(result); - cc = hi64(result); - result = hi + cc; - hi = lo64(result); - result = zz[2] + lo; - zz[2] = lo64(result); - cc = hi64(result); - result = zz[3] + hi + cc; - zz[3] = lo64(result); - cc = hi64(result); - - // z = (z3,z2) - let prod = zz[2] | (zz[3] << 64); - - // Final subtraction - // If z >= p, then z = z - p - - // cc, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = prod.overflowing_sub(self.p); - let (_s1, b1) = cc.overflowing_sub(b0 as u128); - // if b1 == 1: return z - // else: return s0 - let mask = 0u128.wrapping_sub(b1 as u128); - (prod & mask) | (s0 & !mask) - } - - /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the - /// runtime of this algorithm is linear in the bit length of `exp`. 
- pub fn pow(&self, x: u128, exp: u128) -> u128 { - let mut t = self.montgomery(1); - for i in (0..128 - exp.leading_zeros()).rev() { - t = self.mul(t, t); - if (exp >> i) & 1 != 0 { - t = self.mul(t, x); - } - } - t - } - - /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of - /// this algorithm is linear in the bit length of `p`. - #[inline(always)] - pub fn inv(&self, x: u128) -> u128 { - self.pow(x, self.p - 2) - } - - /// Negation, i.e., `-x (mod p)` where `p` is the modulus. - #[inline(always)] - pub fn neg(&self, x: u128) -> u128 { - self.sub(0, x) - } - - /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery - /// domain in order to carry out field arithmetic. The result will be in [0, p). - /// - /// # Example usage - /// ```text - /// let integer = 1; // Standard integer representation - /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain - /// assert_eq!(elem, 2564090464); - /// ``` - #[inline(always)] - pub fn montgomery(&self, x: u128) -> u128 { - modp(self.mul(x, self.r2), self.p) - } - - /// Maps a field element to its representation as an integer. The result will be in [0, p). - /// - /// #Example usage - /// ```text - /// let elem = 2564090464; // Internal representation in the Montgomery domain - /// let integer = fp.residue(elem); // Standard integer representation - /// assert_eq!(integer, 1); - /// ``` - #[inline(always)] - pub fn residue(&self, x: u128) -> u128 { - modp(self.mul(x, 1), self.p) - } +/// FP32 implements operations over GF(p) for which the prime +/// modulus `p` fits in a u32 word. 
+pub(crate) struct FP32; + +impl_field_ops_single_word!(FP32, u32, u64); + +impl FieldParameters for FP32 { + const PRIME: u32 = 4293918721; + const MU: u32 = 4293918719; + const R2: u32 = 266338049; + const G: u32 = 3903828692; + const NUM_ROOTS: usize = 20; + const BIT_MASK: u32 = 4294967295; + const ROOTS: [u32; MAX_ROOTS + 1] = [ + 1048575, 4292870146, 1189722990, 3984864191, 2523259768, 2828840154, 1658715539, + 1534972560, 3732920810, 3229320047, 2836564014, 2170197442, 3760663902, 2144268387, + 3849278021, 1395394315, 574397626, 125025876, 3755041587, 2680072542, 3903828692, + ]; + #[cfg(test)] + const LOG2_BASE: usize = 32; + #[cfg(test)] + const LOG2_RADIX: usize = 32; } -#[inline(always)] -pub(crate) fn lo64(x: u128) -> u128 { - x & ((1 << 64) - 1) +/// FP64 implements operations over GF(p) for which the prime +/// modulus `p` fits in a u64 word. +pub(crate) struct FP64; + +impl_field_ops_single_word!(FP64, u64, u128); + +impl FieldParameters for FP64 { + const PRIME: u64 = 18446744069414584321; + const MU: u64 = 18446744069414584319; + const R2: u64 = 18446744065119617025; + const G: u64 = 15733474329512464024; + const NUM_ROOTS: usize = 32; + const BIT_MASK: u64 = 18446744073709551615; + const ROOTS: [u64; MAX_ROOTS + 1] = [ + 4294967295, + 18446744065119617026, + 18446744069414518785, + 18374686475393433601, + 268435456, + 18446673700670406657, + 18446744069414584193, + 576460752303421440, + 16576810576923738718, + 6647628942875889800, + 10087739294013848503, + 2135208489130820273, + 10781050935026037169, + 3878014442329970502, + 1205735313231991947, + 2523909884358325590, + 13797134855221748930, + 12267112747022536458, + 430584883067102937, + 10135969988448727187, + 6815045114074884550, + ]; + #[cfg(test)] + const LOG2_BASE: usize = 64; + #[cfg(test)] + const LOG2_RADIX: usize = 64; } -#[inline(always)] -pub(crate) fn hi64(x: u128) -> u128 { - x >> 64 -} +/// FP128 implements operations over GF(p) for which the prime +/// modulus `p` fits in a u128 
word. +pub(crate) struct FP128; -#[inline(always)] -fn modp(x: u128, p: u128) -> u128 { - let (z, carry) = x.overflowing_sub(p); - let m = 0u128.wrapping_sub(carry as u128); - z.wrapping_add(m & p) -} +impl_field_ops_split_word!(FP128, u128, u64); -pub(crate) const FP32: FieldParameters = FieldParameters { - p: 4293918721, // 32-bit prime - mu: 17302828673139736575, - r2: 1676699750, - g: 1074114499, - num_roots: 20, - bit_mask: 4294967295, - roots: [ - 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825, - 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415, - 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499, - ], -}; - -pub(crate) const FP128: FieldParameters = FieldParameters { - p: 340282366920938462946865773367900766209, // 128-bit prime - mu: 18446744073709551615, - r2: 403909908237944342183153, - g: 107630958476043550189608038630704257141, - num_roots: 66, - bit_mask: 340282366920938463463374607431768211455, - roots: [ +impl FieldParameters for FP128 { + const PRIME: u128 = 340282366920938462946865773367900766209; + const MU: u128 = 18446744073709551615; + const R2: u128 = 403909908237944342183153; + const G: u128 = 107630958476043550189608038630704257141; + const NUM_ROOTS: usize = 66; + const BIT_MASK: u128 = 340282366920938463463374607431768211455; + const ROOTS: [u128; MAX_ROOTS + 1] = [ 516508834063867445247, 340282366920938462430356939304033320962, 129526470195413442198896969089616959958, @@ -322,8 +113,12 @@ pub(crate) const FP128: FieldParameters = FieldParameters { 332677126194796691532164818746739771387, 258279638927684931537542082169183965856, 148221243758794364405224645520862378432, - ], -}; + ]; + #[cfg(test)] + const LOG2_BASE: usize = 64; + #[cfg(test)] + const LOG2_RADIX: usize = 128; +} /// Compute the ceiling of the base-2 logarithm of `x`. 
pub(crate) fn log2(x: u128) -> u128 { @@ -333,98 +128,14 @@ pub(crate) fn log2(x: u128) -> u128 { #[cfg(test)] pub(crate) mod tests { - use super::*; + use core::{cmp::max, fmt::Debug, marker::PhantomData}; use modinverse::modinverse; use num_bigint::{BigInt, ToBigInt}; + use num_traits::AsPrimitive; use rand::{distributions::Distribution, thread_rng, Rng}; - use std::cmp::max; - - /// This trait abstracts over the details of [`FieldParameters`] and - /// [`FieldParameters64`](crate::fp64::FieldParameters64) to allow reuse of test code. - pub(crate) trait TestFieldParameters { - fn p(&self) -> u128; - fn g(&self) -> u128; - fn r2(&self) -> u128; - fn mu(&self) -> u64; - fn bit_mask(&self) -> u128; - fn num_roots(&self) -> usize; - fn roots(&self) -> Vec; - fn montgomery(&self, x: u128) -> u128; - fn residue(&self, x: u128) -> u128; - fn add(&self, x: u128, y: u128) -> u128; - fn sub(&self, x: u128, y: u128) -> u128; - fn neg(&self, x: u128) -> u128; - fn mul(&self, x: u128, y: u128) -> u128; - fn pow(&self, x: u128, exp: u128) -> u128; - fn inv(&self, x: u128) -> u128; - fn radix(&self) -> BigInt; - } - - impl TestFieldParameters for FieldParameters { - fn p(&self) -> u128 { - self.p - } - - fn g(&self) -> u128 { - self.g - } - - fn r2(&self) -> u128 { - self.r2 - } - - fn mu(&self) -> u64 { - self.mu - } - - fn bit_mask(&self) -> u128 { - self.bit_mask - } - - fn num_roots(&self) -> usize { - self.num_roots - } - - fn roots(&self) -> Vec { - self.roots.to_vec() - } - - fn montgomery(&self, x: u128) -> u128 { - FieldParameters::montgomery(self, x) - } - - fn residue(&self, x: u128) -> u128 { - FieldParameters::residue(self, x) - } - - fn add(&self, x: u128, y: u128) -> u128 { - FieldParameters::add(self, x, y) - } - - fn sub(&self, x: u128, y: u128) -> u128 { - FieldParameters::sub(self, x, y) - } - - fn neg(&self, x: u128) -> u128 { - FieldParameters::neg(self, x) - } - - fn mul(&self, x: u128, y: u128) -> u128 { - FieldParameters::mul(self, x, y) - } - - fn 
pow(&self, x: u128, exp: u128) -> u128 { - FieldParameters::pow(self, x, exp) - } - - fn inv(&self, x: u128) -> u128 { - FieldParameters::inv(self, x) - } - fn radix(&self) -> BigInt { - BigInt::from(1) << 128 - } - } + use super::ops::Word; + use crate::fp::{log2, FieldOps, FP128, FP32, FP64, MAX_ROOTS}; #[test] fn test_log2() { @@ -440,165 +151,213 @@ pub(crate) mod tests { assert_eq!(log2((1 << 127) + 13), 128); } - pub(crate) struct TestFieldParametersData { - /// The paramters being tested - pub fp: Box, + struct TestFieldParametersData, W: Word> { /// Expected fp.p - pub expected_p: u128, + pub expected_p: W, /// Expected fp.residue(fp.g) - pub expected_g: u128, - /// Expect fp.residue(fp.pow(fp.g, expected_order)) == 1 - pub expected_order: u128, - } + pub expected_g: W, + /// Expect fp.residue(fp.pow(fp.g, 1 << expected_log2_order)) == 1 + pub expected_log2_order: usize, - #[test] - fn test_fp32_u128() { - all_field_parameters_tests(TestFieldParametersData { - fp: Box::new(FP32), - expected_p: 4293918721, - expected_g: 3925978153, - expected_order: 1 << 20, - }); + phantom: PhantomData, } - #[test] - fn test_fp128_u128() { - all_field_parameters_tests(TestFieldParametersData { - fp: Box::new(FP128), - expected_p: 340282366920938462946865773367900766209, - expected_g: 145091266659756586618791329697897684742, - expected_order: 1 << 66, - }); - } - - pub(crate) fn all_field_parameters_tests(t: TestFieldParametersData) { - // Check that the field parameters have been constructed properly. - check_consistency(t.fp.as_ref(), t.expected_p, t.expected_g, t.expected_order); + impl TestFieldParametersData + where + T: FieldOps, + W: Word + AsPrimitive + ToBigInt + for<'a> TryFrom<&'a BigInt> + Debug, + for<'a> >::Error: Debug, + { + fn all_field_parameters_tests(&self) { + self.check_generator(); + self.check_consistency(); + self.arithmetic_test(); + } // Check that the generator has the correct order. 
- assert_eq!(t.fp.residue(t.fp.pow(t.fp.g(), t.expected_order)), 1); - assert_ne!(t.fp.residue(t.fp.pow(t.fp.g(), t.expected_order / 2)), 1); - - // Test arithmetic using the field parameters. - arithmetic_test(t.fp.as_ref()); - } + fn check_generator(&self) { + assert_eq!( + T::residue(T::pow(T::G, W::ONE << self.expected_log2_order)), + W::ONE + ); + assert_ne!( + T::residue(T::pow(T::G, W::ONE << (self.expected_log2_order - 1))), + W::ONE + ); + } - fn check_consistency(fp: &dyn TestFieldParameters, p: u128, g: u128, order: u128) { - assert_eq!(fp.p(), p, "p mismatch"); - - let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) { - Some(mu) => mu as u64, - None => panic!("inverse of -p (mod 2^64) is undefined"), - }; - assert_eq!(fp.mu(), mu, "mu mismatch"); - - let big_p = &p.to_bigint().unwrap(); - let big_r: &BigInt = &(fp.radix() % big_p); - let big_r2: &BigInt = &(&(big_r * big_r) % big_p); - let mut it = big_r2.iter_u64_digits(); - let mut r2 = 0; - r2 |= it.next().unwrap() as u128; - if let Some(x) = it.next() { - r2 |= (x as u128) << 64; + // Check that the field parameters have been constructed properly.
+ fn check_consistency(&self) { + assert_eq!(T::PRIME, self.expected_p, "p mismatch"); + + let u128_p = T::PRIME.as_(); + let base = 1i128 << T::LOG2_BASE; + let mu = match modinverse((-(u128_p as i128)).rem_euclid(base), base) { + Some(mu) => mu as u128, + None => panic!("inverse of -p (mod base) is undefined"), + }; + assert_eq!(T::MU.as_(), mu, "mu mismatch"); + + let big_p = &u128_p.to_bigint().unwrap(); + let big_radix = BigInt::from(1) << T::LOG2_RADIX; + let big_r: &BigInt = &(big_radix % big_p); + let big_r2: &BigInt = &(&(big_r * big_r) % big_p); + let mut it = big_r2.iter_u64_digits(); + let mut r2 = 0; + r2 |= it.next().unwrap() as u128; + if let Some(x) = it.next() { + r2 |= (x as u128) << 64; + } + assert_eq!(T::R2.as_(), r2, "r2 mismatch"); + + assert_eq!(T::G, T::montgomery(self.expected_g), "g mismatch"); + assert_eq!( + T::residue(T::pow(T::G, W::ONE << self.expected_log2_order)), + W::ONE, + "g order incorrect" + ); + + let num_roots = self.expected_log2_order; + assert_eq!(T::NUM_ROOTS, num_roots, "num_roots mismatch"); + + let mut roots = vec![W::ZERO; max(num_roots, MAX_ROOTS) + 1]; + roots[num_roots] = T::montgomery(self.expected_g); + for i in (0..num_roots).rev() { + roots[i] = T::mul(roots[i + 1], roots[i + 1]); + } + assert_eq!(T::ROOTS, &roots[..MAX_ROOTS + 1], "roots mismatch"); + assert_eq!(T::residue(T::ROOTS[0]), W::ONE, "first root is not one"); + + let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1); + assert_eq!( + T::BIT_MASK.to_bigint().unwrap(), + bit_mask, + "bit_mask mismatch" + ); } - assert_eq!(fp.r2(), r2, "r2 mismatch"); - assert_eq!(fp.g(), fp.montgomery(g), "g mismatch"); - assert_eq!(fp.residue(fp.pow(fp.g(), order)), 1, "g order incorrect"); + // Test arithmetic using the field parameters. 
+ fn arithmetic_test(&self) { + let u128_p = T::PRIME.as_(); + let big_p = &u128_p.to_bigint().unwrap(); + let big_zero = &BigInt::from(0); + let uniform = rand::distributions::Uniform::from(0..u128_p); + let mut rng = thread_rng(); + + let mut weird_ints = Vec::from([ + 0, + 1, + T::BIT_MASK.as_() - u128_p, + T::BIT_MASK.as_() - u128_p + 1, + u128_p - 1, + ]); + if u128_p > u64::MAX as u128 { + weird_ints.extend_from_slice(&[ + u64::MAX as u128, + 1 << 64, + u128_p & u64::MAX as u128, + u128_p & !u64::MAX as u128, + u128_p & !u64::MAX as u128 | 1, + ]); + } - let num_roots = log2(order) as usize; - assert_eq!(order, 1 << num_roots, "order not a power of 2"); - assert_eq!(fp.num_roots(), num_roots, "num_roots mismatch"); + let mut generate_random = || -> (W, BigInt) { + // Add bias to random element generation, to explore "interesting" inputs. + let intu128 = if rng.gen_ratio(1, 4) { + weird_ints[rng.gen_range(0..weird_ints.len())] + } else { + uniform.sample(&mut rng) + }; + let bigint = intu128.to_bigint().unwrap(); + let int = W::try_from(&bigint).unwrap(); + let montgomery_domain = T::montgomery(int); + (montgomery_domain, bigint) + }; - let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1]; - roots[num_roots] = fp.montgomery(g); - for i in (0..num_roots).rev() { - roots[i] = fp.mul(roots[i + 1], roots[i + 1]); + for _ in 0..1000 { + let (x, ref big_x) = generate_random(); + let (y, ref big_y) = generate_random(); + + // Test addition. + let got = T::add(x, y); + let want = (big_x + big_y) % big_p; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + + // Test subtraction. + let got = T::sub(x, y); + let want = if big_x >= big_y { + big_x - big_y + } else { + big_p - big_y + big_x + }; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + + // Test multiplication. + let got = T::mul(x, y); + let want = (big_x * big_y) % big_p; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + + // Test inversion. 
+ let got = T::inv(x); + let want = big_x.modpow(&(big_p - 2), big_p); + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + if big_x == big_zero { + assert_eq!(T::residue(T::mul(got, x)), W::ZERO); + } else { + assert_eq!(T::residue(T::mul(got, x)), W::ONE); + } + + // Test negation. + let got = T::neg(x); + let want = (big_p - big_x) % big_p; + assert_eq!(T::residue(got).to_bigint().unwrap(), want); + assert_eq!(T::residue(T::add(got, x)), W::ZERO); + } } - assert_eq!(fp.roots(), &roots[..MAX_ROOTS + 1], "roots mismatch"); - assert_eq!(fp.residue(fp.roots()[0]), 1, "first root is not one"); - - let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1); - assert_eq!( - fp.bit_mask().to_bigint().unwrap(), - bit_mask, - "bit_mask mismatch" - ); } - fn arithmetic_test(fp: &dyn TestFieldParameters) { - let big_p = &fp.p().to_bigint().unwrap(); - let big_zero = &BigInt::from(0); - let uniform = rand::distributions::Uniform::from(0..fp.p()); - let mut rng = thread_rng(); - - let mut weird_ints = Vec::from([ - 0, - 1, - fp.bit_mask() - fp.p(), - fp.bit_mask() - fp.p() + 1, - fp.p() - 1, - ]); - if fp.p() > u64::MAX as u128 { - weird_ints.extend_from_slice(&[ - u64::MAX as u128, - 1 << 64, - fp.p() & u64::MAX as u128, - fp.p() & !u64::MAX as u128, - fp.p() & !u64::MAX as u128 | 1, - ]); + mod fp32 { + #[test] + fn check_field_parameters() { + use super::*; + + TestFieldParametersData:: { + expected_p: 4293918721, + expected_g: 3925978153, + expected_log2_order: 20, + phantom: PhantomData, + } + .all_field_parameters_tests(); } + } - let mut generate_random = || -> (u128, BigInt) { - // Add bias to random element generation, to explore "interesting" inputs. 
- let int = if rng.gen_ratio(1, 4) { - weird_ints[rng.gen_range(0..weird_ints.len())] - } else { - uniform.sample(&mut rng) - }; - let bigint = int.to_bigint().unwrap(); - let montgomery_domain = fp.montgomery(int); - (montgomery_domain, bigint) - }; - - for _ in 0..1000 { - let (x, ref big_x) = generate_random(); - let (y, ref big_y) = generate_random(); - - // Test addition. - let got = fp.add(x, y); - let want = (big_x + big_y) % big_p; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - - // Test subtraction. - let got = fp.sub(x, y); - let want = if big_x >= big_y { - big_x - big_y - } else { - big_p - big_y + big_x - }; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - - // Test multiplication. - let got = fp.mul(x, y); - let want = (big_x * big_y) % big_p; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - - // Test inversion. - let got = fp.inv(x); - let want = big_x.modpow(&(big_p - 2u128), big_p); - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - if big_x == big_zero { - assert_eq!(fp.residue(fp.mul(got, x)), 0); - } else { - assert_eq!(fp.residue(fp.mul(got, x)), 1); + mod fp64 { + #[test] + fn check_field_parameters() { + use super::*; + + TestFieldParametersData:: { + expected_p: 18446744069414584321, + expected_g: 1753635133440165772, + expected_log2_order: 32, + phantom: PhantomData, } + .all_field_parameters_tests(); + } + } - // Test negation. 
- let got = fp.neg(x); - let want = (big_p - big_x) % big_p; - assert_eq!(fp.residue(got).to_bigint().unwrap(), want); - assert_eq!(fp.residue(fp.add(got, x)), 0); + mod fp128 { + #[test] + fn check_field_parameters() { + use super::*; + + TestFieldParametersData:: { + expected_p: 340282366920938462946865773367900766209, + expected_g: 145091266659756586618791329697897684742, + expected_log2_order: 66, + phantom: PhantomData, + } + .all_field_parameters_tests(); } } } diff --git a/src/fp/ops.rs b/src/fp/ops.rs new file mode 100644 index 000000000..87aedc7ee --- /dev/null +++ b/src/fp/ops.rs @@ -0,0 +1,429 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Prime field arithmetic for any Galois field GF(p) for which `p < 2^W`, +//! where `W ≤ 128` is a specified word size. + +use num_traits::{ + ops::overflowing::{OverflowingAdd, OverflowingSub}, + AsPrimitive, ConstOne, ConstZero, PrimInt, Unsigned, WrappingAdd, WrappingMul, WrappingSub, +}; + +use crate::fp::MAX_ROOTS; + +/// `Word` is the datatype used for storing and operating with +/// field elements. +/// +/// The types `u32`, `u64`, and `u128` implement this trait. +pub trait Word: + 'static + + Unsigned + + PrimInt + + OverflowingAdd + + OverflowingSub + + WrappingAdd + + WrappingSub + + WrappingMul + + ConstZero + + ConstOne + + From +{ + /// Number of bits of word size. + const BITS: usize; +} + +impl Word for u32 { + const BITS: usize = Self::BITS as usize; +} + +impl Word for u64 { + const BITS: usize = Self::BITS as usize; +} + +impl Word for u128 { + const BITS: usize = Self::BITS as usize; +} + +/// `FieldParameters` sets the parameters to implement a prime field +/// GF(p) for which the prime modulus `p` fits in one word of +/// `W ≤ 128` bits. +pub trait FieldParameters { + /// The prime modulus `p`. + const PRIME: W; + /// `mu = -p^(-1) mod 2^LOG2_BASE`. + const MU: W; + /// `r2 = R^2 mod p`. + const R2: W; + /// The `2^NUM_ROOTS`-th principal root of unity.
This element + /// is used to generate the elements of `ROOTS`. + const G: W; + /// The number of principal roots of unity in `ROOTS`. + const NUM_ROOTS: usize; + /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. + const BIT_MASK: W; + /// `ROOTS[l]` is the `2^l`-th principal root of unity, i.e., + /// `ROOTS[l]` has order `2^l` in the multiplicative group. + /// `ROOTS[0]` is equal to one by definition. + const ROOTS: [W; MAX_ROOTS + 1]; + /// The log2(base) for the base used for multiprecision arithmetic. + /// So, `LOG2_BASE ≤ 64` as processors have at most a 64-bit + /// integer multiplier. + #[cfg(test)] + const LOG2_BASE: usize; + /// The log2(R) where R is the machine word-friendly modulus + /// used in the Montgomery representation. + #[cfg(test)] + const LOG2_RADIX: usize; +} + +/// `FieldOps` provides arithmetic operations over GF(p). +/// +/// The multiplication method is required as it admits different +/// implementations. +pub trait FieldOps: FieldParameters { + /// Addition. The result will be in [0, p), so long as both x + /// and y are as well. + #[inline(always)] + fn add(x: W, y: W) -> W { + // 0,x + // + 0,y + // ===== + // c,z + let (z, carry) = x.overflowing_add(&y); + + // c, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = z.overflowing_sub(&Self::PRIME); + let (_s1, b1) = + >::from(carry).overflowing_sub(&>::from(b0)); + // if b1 == 1: return z + // else: return s0 + let mask = W::ZERO.wrapping_sub(&>::from(b1)); + (z & mask) | (s0 & !mask) + } + + /// Subtraction. The result will be in [0, p), so long as both x + /// and y are as well. 
+ #[inline(always)] + fn sub(x: W, y: W) -> W { + // x + // - y + // ======== + // b0,z0 + let (z0, b0) = x.overflowing_sub(&y); + let mask = W::ZERO.wrapping_sub(&>::from(b0)); + // z0 + // + p + // ======== + // s1,s0 + z0.wrapping_add(&(mask & Self::PRIME)) + // z0 + (m & self.p) + // if b1 == 1: return s0 + // else: return z0 + } + + /// Negation, i.e., `-x (mod p)` where `p` is the modulus. + #[inline(always)] + fn neg(x: W) -> W { + Self::sub(W::ZERO, x) + } + + #[inline(always)] + fn modp(x: W) -> W { + Self::sub(x, Self::PRIME) + } + + /// Multiplication. The result will be in [0, p), so long as both x + /// and y are as well. + fn mul(x: W, y: W) -> W; + + /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is + /// the modulus. Note that the runtime of this algorithm is + /// linear in the bit length of `exp`. + fn pow(x: W, exp: W) -> W { + let mut t = Self::ROOTS[0]; + for i in (0..W::BITS - (exp.leading_zeros() as usize)).rev() { + t = Self::mul(t, t); + if (exp >> i) & W::ONE != W::ZERO { + t = Self::mul(t, x); + } + } + t + } + + /// Modular inversion, i.e., x^-1 (mod p) where `p` is the + /// modulus. Note that the runtime of this algorithm is + /// linear in the bit length of `p`. + #[inline(always)] + fn inv(x: W) -> W { + Self::pow(x, Self::PRIME - W::ONE - W::ONE) + } + + /// Maps an integer to its internal representation. Field + /// elements are mapped to the Montgomery domain in order to + /// carry out field arithmetic. The result will be in [0, p). + #[inline(always)] + fn montgomery(x: W) -> W { + Self::modp(Self::mul(x, Self::R2)) + } + + /// Maps a field element to its representation as an integer. + /// The result will be in [0, p). + #[inline(always)] + fn residue(x: W) -> W { + Self::modp(Self::mul(x, W::ONE)) + } +} + +/// `FieldMulOpsSingleWord` implements prime field multiplication. 
+/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, and that the product of two integers fits in a datatype +/// (`Self::DoubleWord`) of exactly `2*W` bits. +pub(crate) trait FieldMulOpsSingleWord: FieldParameters +where + W: Word + AsPrimitive, +{ + type DoubleWord: Word + AsPrimitive; + + /// Multiplication of field elements in the Montgomery domain. + /// This uses the Montgomery's [REDC algorithm][montgomery]. + /// The result will be in [0, p). + /// + /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf + fn mul(x: W, y: W) -> W { + let hi_lo = |v: Self::DoubleWord| -> (W, W) { ((v >> W::BITS).as_(), v.as_()) }; + + // Integer multiplication + // z = x * y + + // x + // * y + // ===== + // z1,z0 + let (z1, z0) = hi_lo(x.as_() * y.as_()); + + // Montgomery Reduction + // z = z + p * mu*(z mod 2^W), where mu = (-p)^(-1) mod 2^W. + // = z + p * w, where w = mu*z0 + let w = Self::MU.wrapping_mul(&z0); + let (r1, r0) = hi_lo(Self::PRIME.as_() * w.as_()); + + // z1,z0 + // + r1,r0 + // ===== + // cc, z, 0 + let (_zero, carry) = z0.overflowing_add(&r0); + let (cc, z) = hi_lo(z1.as_() + r1.as_() + >::from(carry)); + + // Final subtraction + // If z >= p, then z = z - p + + // cc, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = z.overflowing_sub(&Self::PRIME); + let (_s1, b1) = cc.overflowing_sub(&>::from(b0)); + // if b1 == 1: return z + // else: return s0 + let mask = W::ZERO.wrapping_sub(&>::from(b1)); + (z & mask) | (s0 & !mask) + } +} + +/// `FieldMulOpsSplitWord` implements prime field multiplication. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, but the product of two integers does not fit in any primitive +/// integer. Thus, multiplication is processed splitting integers in two +/// words. 
+pub(crate) trait FieldMulOpsSplitWord: FieldParameters +where + W: Word + AsPrimitive, +{ + type HalfWord: Word + AsPrimitive; + const MU: Self::HalfWord; + /// Multiplication of field elements in the Montgomery domain. + /// This uses the Montgomery's [REDC algorithm][montgomery]. + /// The result will be in [0, p). + /// + /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf + fn mul(x: W, y: W) -> W { + let high = |v: W| v >> (W::BITS / 2); + let low = |v: W| v & ((W::ONE << (W::BITS / 2)) - W::ONE); + + let (x1, x0) = (high(x), low(x)); + let (y1, y0) = (high(y), low(y)); + + // Integer multiplication + // z = x * y + + // x1,x0 + // * y1,y0 + // =========== + // z3,z2,z1,z0 + let mut result = x0 * y0; + let mut carry = high(result); + let z0 = low(result); + result = x0 * y1; + let mut hi = high(result); + let mut lo = low(result); + result = lo + carry; + let mut z1 = low(result); + let mut cc = high(result); + result = hi + cc; + let mut z2 = low(result); + + result = x1 * y0; + hi = high(result); + lo = low(result); + result = z1 + lo; + z1 = low(result); + cc = high(result); + result = hi + cc; + carry = low(result); + + result = x1 * y1; + hi = high(result); + lo = low(result); + result = lo + carry; + lo = low(result); + cc = high(result); + result = hi + cc; + hi = low(result); + result = z2 + lo; + z2 = low(result); + cc = high(result); + result = hi + cc; + let mut z3 = low(result); + + // Montgomery Reduction + // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. 
+ + // z3,z2,z1,z0 + // + p1,p0 + // * w = mu*z0 + // =========== + // z3,z2,z1, 0 + let mut w = >::MU.wrapping_mul(&z0.as_()); + let p0 = low(Self::PRIME); + result = p0 * w.as_(); + hi = high(result); + lo = low(result); + result = z0 + lo; + cc = high(result); + result = hi + cc; + carry = low(result); + + let p1 = high(Self::PRIME); + result = p1 * w.as_(); + hi = high(result); + lo = low(result); + result = lo + carry; + lo = low(result); + cc = high(result); + result = hi + cc; + hi = low(result); + result = z1 + lo; + z1 = low(result); + cc = high(result); + result = z2 + hi + cc; + z2 = low(result); + cc = high(result); + result = z3 + cc; + z3 = low(result); + + // z3,z2,z1 + // + p1,p0 + // * w = mu*z1 + // =========== + // z3,z2, 0 + w = >::MU.wrapping_mul(&z1.as_()); + result = p0 * w.as_(); + hi = high(result); + lo = low(result); + result = z1 + lo; + cc = high(result); + result = hi + cc; + carry = low(result); + + result = p1 * w.as_(); + hi = high(result); + lo = low(result); + result = lo + carry; + lo = low(result); + cc = high(result); + result = hi + cc; + hi = low(result); + result = z2 + lo; + z2 = low(result); + cc = high(result); + result = z3 + hi + cc; + z3 = low(result); + cc = high(result); + + // z = (z3,z2) + let prod = z2 | (z3 << (W::BITS / 2)); + + // Final subtraction + // If z >= p, then z = z - p + + // cc, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = prod.overflowing_sub(&Self::PRIME); + let (_s1, b1) = cc.overflowing_sub(&>::from(b0)); + // if b1 == 1: return z + // else: return s0 + let mask = W::ZERO.wrapping_sub(&>::from(b1)); + (prod & mask) | (s0 & !mask) + } +} + +/// `impl_field_ops_single_word` helper to implement prime field operations. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, and that the product of two integers fits in a datatype +/// (`Self::DoubleWord`) of exactly `2*W` bits. +macro_rules! 
impl_field_ops_single_word { + ($struct_name:ident, $W:ty, $W2:ty) => { + const _: () = assert!(<$W2>::BITS == 2 * <$W>::BITS); + impl $crate::fp::ops::FieldMulOpsSingleWord<$W> for $struct_name { + type DoubleWord = $W2; + } + impl $crate::fp::ops::FieldOps<$W> for $struct_name { + #[inline(always)] + fn mul(x: $W, y: $W) -> $W { + >::mul(x, y) + } + } + }; +} + +/// `impl_field_ops_split_word` helper to implement prime field operations. +/// +/// The implementation assumes that the modulus `p` fits in one word of +/// 'W' bits, but the product of two integers does not fit. Thus, +/// multiplication is processed splitting integers in two words. +macro_rules! impl_field_ops_split_word { + ($struct_name:ident, $W:ty, $W2:ty) => { + const _: () = assert!(2 * <$W2>::BITS == <$W>::BITS); + impl $crate::fp::ops::FieldMulOpsSplitWord<$W> for $struct_name { + type HalfWord = $W2; + const MU: Self::HalfWord = { + let mu = <$struct_name as FieldParameters<$W>>::MU; + assert!(mu <= (<$W2>::MAX as $W)); + mu as $W2 + }; + } + impl $crate::fp::ops::FieldOps<$W> for $struct_name { + #[inline(always)] + fn mul(x: $W, y: $W) -> $W { + >::mul(x, y) + } + } + }; +} diff --git a/src/fp64.rs b/src/fp64.rs deleted file mode 100644 index e2f2cf2b0..000000000 --- a/src/fp64.rs +++ /dev/null @@ -1,307 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -//! Finite field arithmetic for any field GF(p) for which p < 2^64. - -use crate::fp::{hi64, lo64, MAX_ROOTS}; - -/// This structure represents the parameters of a finite field GF(p) for which p < 2^64. -/// -/// See also [`FieldParameters`](crate::fp::FieldParameters). -#[derive(Debug, PartialEq, Eq)] -pub(crate) struct FieldParameters64 { - /// The prime modulus `p`. - pub p: u64, - /// `mu = -p^(-1) mod 2^64`. - pub mu: u64, - /// `r2 = (2^64)^2 mod p`. - pub r2: u64, - /// The `2^num_roots`-th -principal root of unity. This element is used to generate the - /// elements of `roots`. 
- pub g: u64, - /// The number of principal roots of unity in `roots`. - pub num_roots: usize, - /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. - pub bit_mask: u64, - /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the - /// multiplicative group. `roots[0]` is equal to one by definition. - pub roots: [u64; MAX_ROOTS + 1], -} - -impl FieldParameters64 { - /// Addition. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn add(&self, x: u64, y: u64) -> u64 { - // 0,x - // + 0,y - // ===== - // c,z - let (z, carry) = x.overflowing_add(y); - // c, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = z.overflowing_sub(self.p); - let (_s1, b1) = (carry as u64).overflowing_sub(b0 as u64); - // if b1 == 1: return z - // else: return s0 - let m = 0u64.wrapping_sub(b1 as u64); - (z & m) | (s0 & !m) - } - - /// Subtraction. The result will be in [0, p), so long as both x and y are as well. - #[inline(always)] - pub fn sub(&self, x: u64, y: u64) -> u64 { - // x - // - y - // ======== - // b0,z0 - let (z0, b0) = x.overflowing_sub(y); - let m = 0u64.wrapping_sub(b0 as u64); - // z0 - // + p - // ======== - // s1,s0 - z0.wrapping_add(m & self.p) - // if b1 == 1: return s0 - // else: return z0 - } - - /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm - /// described [here][montgomery]. The result will be in [0, p). 
- /// - /// # Example usage - /// ```text - /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46); - /// ``` - /// - /// [montgomery]: https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf - #[inline(always)] - pub fn mul(&self, x: u64, y: u64) -> u64 { - let mut zz = [0; 2]; - - // Integer multiplication - // z = x * y - - // x - // * y - // ===== - // z1,z0 - let result = (x as u128) * (y as u128); - zz[0] = lo64(result) as u64; - zz[1] = hi64(result) as u64; - - // Montgomery Reduction - // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. - - // z1,z0 - // + p - // * w = mu*z0 - // ===== - // z1, 0 - let w = self.mu.wrapping_mul(zz[0]); - let result = (self.p as u128) * (w as u128); - let hi = hi64(result); - let lo = lo64(result) as u64; - let (result, carry) = zz[0].overflowing_add(lo); - zz[0] = result; - let result = zz[1] as u128 + hi + carry as u128; - zz[1] = lo64(result) as u64; - let cc = hi64(result) as u64; - - // z = (z1) - let prod = zz[1]; - - // Final subtraction - // If z >= p, then z = z - p - - // cc, z - // - 0, p - // ======== - // b1,s1,s0 - let (s0, b0) = prod.overflowing_sub(self.p); - let (_s1, b1) = cc.overflowing_sub(b0 as u64); - // if b1 == 1: return z - // else: return s0 - let mask = 0u64.wrapping_sub(b1 as u64); - (prod & mask) | (s0 & !mask) - } - - /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the - /// runtime of this algorithm is linear in the bit length of `exp`. - pub fn pow(&self, x: u64, exp: u64) -> u64 { - let mut t = self.montgomery(1); - for i in (0..64 - exp.leading_zeros()).rev() { - t = self.mul(t, t); - if (exp >> i) & 1 != 0 { - t = self.mul(t, x); - } - } - t - } - - /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of - /// this algorithm is linear in the bit length of `p`. 
- #[inline(always)] - pub fn inv(&self, x: u64) -> u64 { - self.pow(x, self.p - 2) - } - - /// Negation, i.e., `-x (mod p)` where `p` is the modulus. - #[inline(always)] - pub fn neg(&self, x: u64) -> u64 { - self.sub(0, x) - } - - /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery - /// domain in order to carry out field arithmetic. The result will be in [0, p). - /// - /// # Example usage - /// ```text - /// let integer = 1; // Standard integer representation - /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain - /// assert_eq!(elem, 2564090464); - /// ``` - #[inline(always)] - pub fn montgomery(&self, x: u64) -> u64 { - modp(self.mul(x, self.r2), self.p) - } - - /// Maps a field element to its representation as an integer. The result will be in [0, p). - /// - /// #Example usage - /// ```text - /// let elem = 2564090464; // Internal representation in the Montgomery domain - /// let integer = fp.residue(elem); // Standard integer representation - /// assert_eq!(integer, 1); - /// ``` - #[inline(always)] - pub fn residue(&self, x: u64) -> u64 { - modp(self.mul(x, 1), self.p) - } -} - -#[inline(always)] -fn modp(x: u64, p: u64) -> u64 { - let (z, carry) = x.overflowing_sub(p); - let m = 0u64.wrapping_sub(carry as u64); - z.wrapping_add(m & p) -} - -pub(crate) const FP64: FieldParameters64 = FieldParameters64 { - p: 18446744069414584321, // 64-bit prime - mu: 18446744069414584319, - r2: 18446744065119617025, - g: 15733474329512464024, - num_roots: 32, - bit_mask: 18446744073709551615, - roots: [ - 4294967295, - 18446744065119617026, - 18446744069414518785, - 18374686475393433601, - 268435456, - 18446673700670406657, - 18446744069414584193, - 576460752303421440, - 16576810576923738718, - 6647628942875889800, - 10087739294013848503, - 2135208489130820273, - 10781050935026037169, - 3878014442329970502, - 1205735313231991947, - 2523909884358325590, - 13797134855221748930, - 
12267112747022536458, - 430584883067102937, - 10135969988448727187, - 6815045114074884550, - ], -}; - -#[cfg(test)] -mod tests { - use num_bigint::BigInt; - - use crate::fp::tests::{ - all_field_parameters_tests, TestFieldParameters, TestFieldParametersData, - }; - - use super::*; - - impl TestFieldParameters for FieldParameters64 { - fn p(&self) -> u128 { - self.p.into() - } - - fn g(&self) -> u128 { - self.g as u128 - } - - fn r2(&self) -> u128 { - self.r2 as u128 - } - - fn mu(&self) -> u64 { - self.mu - } - - fn bit_mask(&self) -> u128 { - self.bit_mask as u128 - } - - fn num_roots(&self) -> usize { - self.num_roots - } - - fn roots(&self) -> Vec { - self.roots.iter().map(|x| *x as u128).collect() - } - - fn montgomery(&self, x: u128) -> u128 { - FieldParameters64::montgomery(self, x.try_into().unwrap()).into() - } - - fn residue(&self, x: u128) -> u128 { - FieldParameters64::residue(self, x.try_into().unwrap()).into() - } - - fn add(&self, x: u128, y: u128) -> u128 { - FieldParameters64::add(self, x.try_into().unwrap(), y.try_into().unwrap()).into() - } - - fn sub(&self, x: u128, y: u128) -> u128 { - FieldParameters64::sub(self, x.try_into().unwrap(), y.try_into().unwrap()).into() - } - - fn neg(&self, x: u128) -> u128 { - FieldParameters64::neg(self, x.try_into().unwrap()).into() - } - - fn mul(&self, x: u128, y: u128) -> u128 { - FieldParameters64::mul(self, x.try_into().unwrap(), y.try_into().unwrap()).into() - } - - fn pow(&self, x: u128, exp: u128) -> u128 { - FieldParameters64::pow(self, x.try_into().unwrap(), exp.try_into().unwrap()).into() - } - - fn inv(&self, x: u128) -> u128 { - FieldParameters64::inv(self, x.try_into().unwrap()).into() - } - - fn radix(&self) -> BigInt { - BigInt::from(1) << 64 - } - } - - #[test] - fn test_fp64_u64() { - all_field_parameters_tests(TestFieldParametersData { - fp: Box::new(FP64), - expected_p: 18446744069414584321, - expected_g: 1753635133440165772, - expected_order: 1 << 32, - }) - } -} diff --git a/src/lib.rs 
b/src/lib.rs index e5280d505..d306926ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,6 @@ mod fft; pub mod field; pub mod flp; mod fp; -mod fp64; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[cfg_attr( docsrs,