From 8359aa02648f695d4a96ab2957be9ca24ea4f7a8 Mon Sep 17 00:00:00 2001 From: eschorn1 Date: Mon, 1 Jan 2024 15:18:39 -0600 Subject: [PATCH] partial Result --- Cargo.toml | 3 +- src/byte_fns.rs | 96 +++++++++++++++++++++++++++---------------------- src/helpers.rs | 15 ++++++-- src/k_pke.rs | 79 ++++++++++++++++++++-------------------- src/lib.rs | 6 ++-- src/ml_kem.rs | 38 ++++++++++---------- src/ntt.rs | 18 +++++----- src/sampling.rs | 18 +++++----- src/traits.rs | 12 +++++++ src/types.rs | 17 +++++---- 10 files changed, 170 insertions(+), 132 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 27ab889..b6a4040 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,7 @@ rust-version = "1.72" [dependencies] zeroize = { version = "1.6.0", features = ["zeroize_derive"] } rand_core = { version = "0.6.4", default-features = false } -sha3 = "0.10.8" -rand = "0.8.5" +sha3 = { version = "0.10.8", default-features = false } [features] diff --git a/src/byte_fns.rs b/src/byte_fns.rs index 07685dd..5ef772f 100644 --- a/src/byte_fns.rs +++ b/src/byte_fns.rs @@ -1,34 +1,41 @@ +use crate::helpers::ensure; use crate::Q; use crate::types::Z256; /// Algorithm 2 `BitsToBytes(b)` on page 17. /// Converts a bit string (of length a multiple of eight) into an array of bytes. -pub(crate) fn bits_to_bytes(bits: &[u8], bytes: &mut [u8]) { - // Input: bit array b ∈ {0,1}^{8·ℓ} - // Output: byte array B ∈ B^ℓ - debug_assert_eq!(bits.len() % 8, 0); // bit_array is multiple of 8 - debug_assert_eq!(bits.len(), 8 * bytes.len()); // bit_array length is 8ℓ +/// +/// Input: bit array b ∈ {0,1}^{8·ℓ}
+/// Output: byte array B ∈ B^ℓ +pub(crate) fn bits_to_bytes(bits: &[u8], bytes: &mut [u8]) -> Result<(), &'static str> { + ensure!(bits.len() % 8 == 0, "TKTK"); + // bit_array is multiple of 8 + ensure!(bits.len() == 8 * bytes.len(), "TKTK"); // bit_array length is 8ℓ // 1: B ← (0, . . . , 0) (returned mutable data struct is provided by the caller) - // (reconsider zeroing ... found one bug already) + bytes.iter_mut().for_each(|b| *b = 0); // 2: for (i ← 0; i < 8ℓ; i ++) for i in 0..bits.len() { // // 3: B [⌊i/8⌋] ← B [⌊i/8⌋] + b[i] · 2^{i mod 8} - bytes[i / 8] += bits[i] * 2u8.pow(u32::try_from(i).expect("too many bits") % 8); + bytes[i / 8] += bits[i] * 2u8.pow(u32::try_from(i).map_err(|_| "too many bits")? % 8); // } // 4: end for + + Ok(()) } // 5: return B /// Algorithm 3 `BytesToBits(B)` on page 18. /// Performs the inverse of `BitsToBytes`, converting a byte array into a bit array. -pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) { - // Input: byte array B ∈ B^ℓ - // Output: bit array b ∈ {0,1}^{8·ℓ} - debug_assert_eq!(bits.len() % 8, 0); // bit_array is multiple of 8 - debug_assert_eq!(bytes.len() * 8, bits.len()); // bit_array length is 8ℓ +/// +/// Input: byte array B ∈ B^ℓ
+/// Output: bit array b ∈ {0,1}^{8·ℓ} +pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) -> Result<(), &'static str> { + ensure!(bits.len() % 8 == 0, "TKTK"); + // bit_array is multiple of 8 + ensure!(bytes.len() * 8 == bits.len(), "TKTK"); // bit_array length is 8ℓ // 1: for (i ← 0; i < ℓ; i ++) for i in 0..bytes.len() { @@ -46,25 +53,27 @@ pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) { // } // 5: end for } // 6: end for + Ok(()) } // 7: return b /// Algorithm 4 `ByteEncode(F)` on page 19. /// Encodes an array of d-bit integers into a byte array, for 1 ≤ d ≤ 12. +/// +/// Input: integer array `F ∈ Z^256_m`, where `m = 2^d if d < 12` and `m = q if d = 12`
+/// Output: byte array B ∈ B^{32d} pub(crate) fn byte_encode( integers_f: &[Z256; 256], bytes_b: &mut [u8], -) { - // Input: integer array F ∈ Z^256_m, where m = 2^d if d < 12 and m = q if d = 12 - // Output: byte array B ∈ B^{32d} - debug_assert!((1 <= D) & (D <= 12)); - debug_assert_eq!(D * 256, D_256); - debug_assert_eq!(integers_f.len(), 256); - debug_assert_eq!(bytes_b.len(), 32 * D); +) -> Result<(), &'static str> { + ensure!((1 <= D) & (D <= 12), "TKTK"); + ensure!(D * 256 == D_256, "TKTK"); + ensure!(integers_f.len() == 256, "TKTK"); + ensure!(bytes_b.len() == 32 * D, "TKTK"); let m_mod = if D < 12 { - 2_u16.pow(u32::try_from(D).unwrap()) + 2_u16.pow(u32::try_from(D).map_err(|_| "impossible")?) } else { - u16::try_from(Q).unwrap() + u16::try_from(Q).map_err(|_| "impossible")? }; let mut bit_array = [0u8; D_256]; @@ -87,32 +96,34 @@ pub(crate) fn byte_encode( } // 7: end for // // 8: B ← BitsToBytes(b) - bits_to_bytes(&bit_array, bytes_b); + bits_to_bytes(&bit_array, bytes_b)?; // + Ok(()) } // 9: return B /// Algorithm 5 `ByteDecode(B)` on page 19. /// Decodes a byte array into an array of d-bit integers, for 1 ≤ d ≤ 12. +/// +/// Input: byte array B ∈ B^{32d}
+/// Output: integer array `F ∈ Z^256_m`, where `m = 2^d if d < 12` and `m = q if d = 12` pub(crate) fn byte_decode( bytes_b: &[u8], integers_f: &mut [Z256; 256], -) { - // Input: byte array B ∈ B^{32d} - // Output: integer array F ∈ Z^256_m, where m = 2^d if d < 12 and m = q if d = 12 - debug_assert!((1 <= D) & (D <= 12)); - debug_assert_eq!(D * 256, D_256); - debug_assert_eq!(bytes_b.len(), 32 * D); - debug_assert_eq!(integers_f.len(), 256); +) -> Result<(), &'static str> { + ensure!((1 <= D) & (D <= 12), "TKTK"); + ensure!(D * 256 == D_256, "TKTK"); + ensure!(bytes_b.len() == 32 * D ,"TKTK"); + ensure!(integers_f.len() == 256, "TKTKT"); let m_mod = if D < 12 { - 2_u16.pow(u32::try_from(D).unwrap()) + 2_u16.pow(u32::try_from(D).map_err(|_| "impossible")?) } else { - u16::try_from(Q).unwrap() + u16::try_from(Q).map_err(|_| "impossible")? }; let mut bit_array = [0u8; D_256]; // 1: b ← BytesToBits(B) - bytes_to_bits(bytes_b, &mut bit_array); + bytes_to_bits(bytes_b, &mut bit_array)?; // 2: for (i ← 0; i < 256; i ++) for i in 0..256 { @@ -127,6 +138,7 @@ pub(crate) fn byte_decode( }); // } // 4: end for + Ok(()) } // 5: return F @@ -150,9 +162,9 @@ mod tests { let num_bytes = rng.gen::(); let bytes1: Vec = (0..num_bytes).map(|_| rng.gen()).collect(); let mut bits = vec![0u8; num_bytes as usize * 8]; - bytes_to_bits(&bytes1, &mut bits[..]); + bytes_to_bits(&bytes1, &mut bits[..]).unwrap(); let mut bytes2 = vec![0u8; num_bytes as usize]; - bits_to_bytes(&bits, &mut bytes2[..]); + bits_to_bytes(&bits, &mut bytes2[..]).unwrap(); assert_eq!(bytes1, bytes2); } } @@ -165,29 +177,29 @@ mod tests { let num_bytes = 32 * 11; let mut bytes2 = vec![0u8; num_bytes]; let bytes1: Vec = (0..num_bytes).map(|_| rng.gen()).collect(); - byte_decode::<11, { 11 * 256 }>(&bytes1, &mut integer_array); - byte_encode::<11, 2816>(&integer_array, &mut bytes2); + byte_decode::<11, { 11 * 256 }>(&bytes1, &mut integer_array).unwrap(); + byte_encode::<11, 2816>(&integer_array, &mut bytes2).unwrap(); assert_eq!(bytes1, bytes2); let num_bytes = 32 * 10; let bytes1: Vec = (0..num_bytes).map(|_| rng.gen()).collect(); let mut bytes2 = vec![0u8; num_bytes]; - byte_decode::<10, 2560>(&bytes1, &mut integer_array); - byte_encode::<10, 2560>(&integer_array, &mut bytes2); + byte_decode::<10, 2560>(&bytes1, &mut integer_array).unwrap(); + byte_encode::<10, 2560>(&integer_array, &mut bytes2).unwrap(); assert_eq!(bytes1, bytes2); let num_bytes = 32 * 5; let bytes1: Vec = (0..num_bytes).map(|_| rng.gen()).collect(); let mut bytes2 = vec![0u8; num_bytes]; - byte_decode::<5, 1280>(&bytes1, &mut integer_array); - byte_encode::<5, 1280>(&integer_array, &mut bytes2); + byte_decode::<5, 1280>(&bytes1, &mut integer_array).unwrap(); + byte_encode::<5, 1280>(&integer_array, &mut bytes2).unwrap(); assert_eq!(bytes1, bytes2); let num_bytes = 32 * 4; let bytes1: Vec = (0..num_bytes).map(|_| rng.gen()).collect(); let mut bytes2 = vec![0u8; num_bytes]; - byte_decode::<4, 1024>(&bytes1, &mut integer_array); - byte_encode::<4, 1024>(&integer_array, &mut bytes2); + byte_decode::<4, 1024>(&bytes1, &mut integer_array).unwrap(); + byte_encode::<4, 1024>(&integer_array, &mut bytes2).unwrap(); assert_eq!(bytes1, bytes2); } } diff --git a/src/helpers.rs b/src/helpers.rs index 3dd004f..0e69d75 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -6,6 +6,17 @@ use crate::ntt::multiply_ntts; use crate::Q; use crate::types::Z256; +/// If the condition is not met, return an error message. Borrowed from the `anyhow` crate. +macro_rules! ensure { + ($cond:expr, $msg:literal $(,)?) => { + if !$cond { + return Err($msg); + } + }; +} + +pub(crate) use ensure; // make available throughout crate + /// Vector addition; See bottom of page 9, second row: `z_hat` = `u_hat` + `v_hat` #[must_use] pub(crate) fn vec_add( @@ -14,7 +25,7 @@ pub(crate) fn vec_add( let mut result = [[Z256(0); 256]; K]; for i in 0..vec_a.len() { for j in 0..vec_a[i].len() { - result[i][j] = vec_a[i][j].add(vec_b[i][j]); //.set_u16(vec_a[i][j].get_u32() + vec_b[i][j].get_u32()); + result[i][j] = vec_a[i][j].add(vec_b[i][j]); } } result @@ -33,7 +44,7 @@ pub(crate) fn mat_vec_mul( for j in 0..K { let tmp = multiply_ntts(&a_hat[i][j], &u_hat[j]); for k in 0..256 { - w_hat[i][k] = w_hat[i][k].add(tmp[k]); //.set_u16(w_hat[i][k].get_u32() + tmp[k].get_u32()); + w_hat[i][k] = w_hat[i][k].add(tmp[k]); } } } diff --git a/src/k_pke.rs b/src/k_pke.rs index 5e725b7..56f43a1 100644 --- a/src/k_pke.rs +++ b/src/k_pke.rs @@ -1,15 +1,16 @@ use rand_core::CryptoRngCore; use crate::byte_fns::{byte_decode, byte_encode}; -use crate::helpers::{ - compress, decompress, dot_t_prod, g, mat_t_vec_mul, mat_vec_mul, prf, vec_add, xof, -}; +use crate::helpers::{compress, decompress, dot_t_prod, ensure, g, mat_t_vec_mul, mat_vec_mul, prf, vec_add, xof}; use crate::ntt::{ntt, ntt_inv}; use crate::sampling::{sample_ntt, sample_poly_cbd}; use crate::types::Z256; /// Algorithm 12 `K-PKE.KeyGen()` on page 26. /// Generates an encryption key and a corresponding decryption key. +/// +/// Output: encryption key `ekPKE ∈ B^{384*k+32}`
+/// Output: decryption key `dkPKE ∈ B^{384*k}` #[allow(clippy::similar_names, clippy::module_name_repetitions)] pub fn k_pke_key_gen< const K: usize, @@ -18,11 +19,9 @@ pub fn k_pke_key_gen< const ETA1_512: usize, >( rng: &mut impl CryptoRngCore, ek_pke: &mut [u8], dk_pke: &mut [u8], -) { - // Output: encryption key ekPKE ∈ B^{384*k+32} - // Output: decryption key dkPKE ∈ B^{384*k} - debug_assert_eq!(ek_pke.len(), 384 * K + 32); - debug_assert_eq!(dk_pke.len(), 384 * K); +) -> Result<(), &'static str> { + ensure!(ek_pke.len() == 384 * K + 32, "TKTK"); + ensure!(dk_pke.len() == 384 * K, "TKTK"); // 1: d ←− B^{32} ▷ d is 32 random bytes (see Section 3.3) let mut d = [0u8; 32]; @@ -57,7 +56,7 @@ pub fn k_pke_key_gen< for i in 0..K { // // 10: s[i] ← SamplePolyCBDη1(PRFη1(σ, N)) ▷ s[i] ∈ Z^{256}_q sampled from CBD - s[i] = sample_poly_cbd::(&prf::(&sigma, n)); + s[i] = sample_poly_cbd::(&prf::(&sigma, n))?; // 11: N ← N +1 n += 1; @@ -71,7 +70,7 @@ pub fn k_pke_key_gen< for i in 0..K { // // 14: e[i] ← SamplePolyCBDη1(PRFη1(σ, N)) ▷ e[i] ∈ Z^{256}_q sampled from CBD - e[i] = sample_poly_cbd::(&prf::(&sigma, n)); + e[i] = sample_poly_cbd::(&prf::(&sigma, n))?; // 15: N ← N +1 n += 1; @@ -99,16 +98,17 @@ pub fn k_pke_key_gen< // 20: ek_{PKE} ← ByteEncode12(t̂)∥ρ ▷ ByteEncode12 is run k times; include seed for  for i in 0..K { - byte_encode::<12, 3072>(&t_hat[i], &mut ek_pke[i * 384..(i + 1) * 384]); + byte_encode::<12, 3072>(&t_hat[i], &mut ek_pke[i * 384..(i + 1) * 384])?; } ek_pke[K * 384..].copy_from_slice(&rho); // 21: dk_{PKE} ← ByteEncode12(ŝ) ▷ ByteEncode12 is run k times for i in 0..K { - byte_encode::<12, 3072>(&s_hat[i], &mut dk_pke[i * 384..(i + 1) * 384]); + byte_encode::<12, 3072>(&s_hat[i], &mut dk_pke[i * 384..(i + 1) * 384])?; } // 22: return (ekPKE , dkPKE ) + Ok(()) } @@ -129,20 +129,20 @@ pub(crate) fn k_pke_encrypt< const DV_256: usize, >( ek: &[u8], m: &[u8], randomness: &[u8; 32], ct: &mut [u8], -) { +) -> Result<(), &'static str> { // Input: encryption key ekPKE ∈ B^{384k+32} // Input: message m ∈ B^{32} // Input: encryption randomness r ∈ B^{32} // Output: ciphertext c ∈ B^{32(du k+dv )} - debug_assert_eq!(ek.len(), 384 * K + 32); - debug_assert_eq!(m.len(), 32); - debug_assert_eq!(randomness.len(), 32); - debug_assert_eq!(ETA1 * 64, ETA1_64); - debug_assert_eq!(ETA1 * 512, ETA1_512); - debug_assert_eq!(ETA2 * 64, ETA2_64); - debug_assert_eq!(ETA2 * 512, ETA2_512); - debug_assert_eq!(DU * 256, DU_256); - debug_assert_eq!(DV * 256, DV_256); + ensure!(ek.len() == 384 * K + 32, "TKTK"); + ensure!(m.len() == 32, "TKTK"); + ensure!(randomness.len() == 32, "TKTK"); + ensure!(ETA1 * 64 == ETA1_64, "TKTK"); + ensure!(ETA1 * 512 == ETA1_512, "TKTK"); + ensure!(ETA2 * 64 == ETA2_64, "TKTK"); + ensure!(ETA2 * 512 == ETA2_512, "TKTK"); + ensure!(DU * 256 == DU_256, "TKTK"); + ensure!(DV * 256 == DV_256, "TKTK"); // 1: N ← 0 let mut n = 0; @@ -150,7 +150,7 @@ pub(crate) fn k_pke_encrypt< // 2: t̂ ← ByteDecode12 (ekPKE [0 : 384k]) let mut t_hat = [[Z256(0); 256]; K]; for i in 0..K { - byte_decode::<12, { 12 * 256 }>(&ek[384 * i..384 * (i + 1)], &mut t_hat[i]); + byte_decode::<12, { 12 * 256 }>(&ek[384 * i..384 * (i + 1)], &mut t_hat[i])?; } // 3: 3: ρ ← ekPKE [384k : 384k + 32] ▷ extract 32-byte seed from ekPKE @@ -180,7 +180,7 @@ pub(crate) fn k_pke_encrypt< for i in 0..K { // // 10: r[i] ← SamplePolyCBDη 1 (PRFη 1 (r, N)) ▷ r[i] ∈ Z^{256}_q sampled from CBD - r[i] = sample_poly_cbd::(&prf::(randomness, n)); + r[i] = sample_poly_cbd::(&prf::(randomness, n))?; // 11: N ← N +1 n += 1; @@ -194,7 +194,7 @@ pub(crate) fn k_pke_encrypt< for i in 0..K { // // 14: e1 [i] ← SamplePolyCBDη2(PRFη2(r, N)) ▷ e1 [i] ∈ Z^{256}_q sampled from CBD - e1[i] = sample_poly_cbd::(&prf::(randomness, n)); + e1[i] = sample_poly_cbd::(&prf::(randomness, n))?; // 15: N ← N +1 n += 1; @@ -202,7 +202,7 @@ pub(crate) fn k_pke_encrypt< } // 16: end for // 17: 17: e2 ← SamplePolyCBDη(PRFη2(r, N)) ▷ sample e2 ∈ Z^{256}_q from CBD - let e2 = sample_poly_cbd::(&prf::(randomness, n)); + let e2 = sample_poly_cbd::(&prf::(randomness, n))?; // 18: 18: r̂ ← NTT(r) ▷ NTT is run k times let mut r_hat = [[Z256(0); 256]; K]; @@ -220,7 +220,7 @@ pub(crate) fn k_pke_encrypt< // 20: µ ← Decompress1(ByteDecode1(m))) let mut mu = [Z256(0); 256]; - byte_decode::<1, 256>(m, &mut mu); + byte_decode::<1, 256>(m, &mut mu)?; decompress::<1>(&mut mu); // 21: v ← NTT−1 (t̂⊺ ◦ r̂) + e2 + µ ▷ encode plaintext m into polynomial v. @@ -231,14 +231,15 @@ pub(crate) fn k_pke_encrypt< let step = 32 * DU; for i in 0..K { compress::(&mut u[i]); - byte_encode::(&u[i], &mut ct[i * step..(i + 1) * step]); + byte_encode::(&u[i], &mut ct[i * step..(i + 1) * step])?; } // 23: c2 ← ByteEncode_{dv}(Compress_{dv}(v)) compress::(&mut v); - byte_encode::(&v, &mut ct[K * step..(K * step + 32 * DV)]); + byte_encode::(&v, &mut ct[K * step..(K * step + 32 * DV)])?; // 24: return c ← (c1 ∥ c2 ) + Ok(()) } @@ -252,14 +253,14 @@ pub(crate) fn k_pke_decrypt< const DV_256: usize, >( dk: &[u8], ct: &[u8], -) -> [u8; 32] { +) -> Result<[u8; 32], &'static str> { // Input: decryption key dk_{PKE} ∈ B^{384*k} // Input: ciphertext c ∈ B^{32(du*k+dv)} // Output: message m ∈ B^{32} - debug_assert_eq!(dk.len(), 384 * K); - debug_assert_eq!(ct.len(), 32 * (DU * K + DV)); - debug_assert_eq!(DU * 256, DU_256); - debug_assert_eq!(DV * 256, DV_256); + ensure!(dk.len() == 384 * K, "TKTK"); + ensure!(ct.len() == 32 * (DU * K + DV), "TKTK"); + ensure!(DU * 256 == DU_256, "TKTK"); + ensure!(DV * 256 == DV_256, "TKTK"); // 1: c1 ← c[0 : 32du k] let c1 = &ct[0..32 * DU * K]; @@ -270,19 +271,19 @@ pub(crate) fn k_pke_decrypt< // 3: 3: u ← Decompress_{du}(ByteDecode_{du}(c_1)) ▷ ByteDecode_{du} invoked k times let mut u = [[Z256(0); 256]; K]; for i in 0..K { - byte_decode::(&c1[32 * DU * i..32 * DU * (i + 1)], &mut u[i]); + byte_decode::(&c1[32 * DU * i..32 * DU * (i + 1)], &mut u[i])?; decompress::(&mut u[i]); } // 4: v ← Decompress_{dv}(ByteDecode_{dv}(c_2)) let mut v = [Z256(0); 256]; - byte_decode::(c2, &mut v); + byte_decode::(c2, &mut v)?; decompress::(&mut v); // 5: s_hat ← ByteDecode_{12}(dk_{PKE{) let mut s_hat = [[Z256(0); 256]; K]; for i in 0..K { - byte_decode::<12, { 12 * 256 }>(&dk[384 * i..384 * (i + 1)], &mut s_hat[i]); + byte_decode::<12, { 12 * 256 }>(&dk[384 * i..384 * (i + 1)], &mut s_hat[i])?; } // 6: w ← v − NTT−1 (ŝ⊺ ◦ NTT(u)) ▷ NTT−1 and NTT invoked k times @@ -303,8 +304,8 @@ pub(crate) fn k_pke_decrypt< // 7: m ← ByteEncode1 (Compress1 (w)) ▷ decode plaintext m from polynomial v compress::<1>(&mut w); let mut m = [0u8; 32]; - byte_encode::<1, 256>(&w, &mut m); + byte_encode::<1, 256>(&w, &mut m)?; // 8: return m - m + Ok(m) } diff --git a/src/lib.rs b/src/lib.rs index 799a922..f01a1ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,7 +120,7 @@ macro_rules! functionality { rng: &mut impl CryptoRngCore, ) -> Result<(EncapsKey, DecapsKey), &'static str> { let (mut ek, mut dk) = ([0u8; EK_LEN], [0u8; DK_LEN]); - ml_kem::ml_kem_key_gen::(rng, &mut ek, &mut dk); // handle internal results + ml_kem::ml_kem_key_gen::(rng, &mut ek, &mut dk)?; // handle internal results Ok((EncapsKey(ek), DecapsKey(dk))) } } @@ -146,7 +146,7 @@ macro_rules! functionality { DU_256, DV, DV_256, - >(rng, &self.0, &mut ct); + >(rng, &self.0, &mut ct)?; Ok((ssk, CipherText(ct))) } } @@ -172,7 +172,7 @@ macro_rules! functionality { J_LEN, CT_LEN, >(&self.0, &ct.0); - Ok(ssk) + ssk } } diff --git a/src/ml_kem.rs b/src/ml_kem.rs index 864639c..6c9ceb4 100644 --- a/src/ml_kem.rs +++ b/src/ml_kem.rs @@ -1,7 +1,7 @@ use rand_core::CryptoRngCore; use crate::byte_fns::{byte_decode, byte_encode}; -use crate::helpers::{g, h, j}; +use crate::helpers::{ensure, g, h, j}; use crate::k_pke::k_pke_decrypt; use crate::SharedSecretKey; use crate::types::Z256; @@ -17,11 +17,11 @@ pub(crate) fn ml_kem_key_gen< const ETA1_512: usize, >( rng: &mut impl CryptoRngCore, ek: &mut [u8], dk: &mut [u8], -) { +) -> Result<(), &'static str> { // Output: Encapsulation key ek ∈ B^{384k+32} // Output: Decapsulation key dk ∈ B^{768k+96} - debug_assert_eq!(ek.len(), 384 * K + 32); - debug_assert_eq!(dk.len(), 768 * K + 96); + ensure!(ek.len() == 384 * K + 32, "TKTK"); + ensure!(dk.len() == 768 * K + 96, "TKTK"); // 1: z ←− B32 ▷ z is 32 random bytes (see Section 3.3) let mut z = [0u8; 32]; @@ -29,7 +29,7 @@ pub(crate) fn ml_kem_key_gen< // 2: (ek_{PKE}, dk_{PKE}) ← K-PKE.KeyGen() ▷ run key generation for K-PKE let p1 = 384 * K; - k_pke_key_gen::(rng, ek, &mut dk[..p1]); // 3: ek ← ekPKE + k_pke_key_gen::(rng, ek, &mut dk[..p1])?; // 3: ek ← ekPKE // 4: dk ← (dkPKE ∥ek∥H(ek)∥z) (first concat element is done above alongside ek) let h_ek = h(ek); @@ -40,6 +40,7 @@ pub(crate) fn ml_kem_key_gen< dk[p3..].copy_from_slice(&z); // 5: return (ek, dk) + Ok(()) } @@ -59,20 +60,20 @@ pub(crate) fn ml_kem_encaps< const DV_256: usize, >( rng: &mut impl CryptoRngCore, ek: &[u8], ct: &mut [u8], -) -> SharedSecretKey { +) -> Result { // Validated input: encapsulation key ek ∈ B^{384k+32} // Output: shared key K ∈ B^{32} // Output: ciphertext c ∈ B^{32(du k+dv)} - assert_eq!(ek.len(), 384 * K + 32); // type check: array of length 384k + 32 + ensure!(ek.len() == 384 * K + 32, "TKTK"); // type check: array of length 384k + 32 // modulus check: perform the computation ek ← ByteEncode12 (ByteDecode12(ek_tidle) // note: after checking, we run with the original input (due to const array allocation); the last 32 bytes is rho // TODO: revisit let mut ek_hat = [Z256(0); 256]; for i in 0..K { let mut ek_tilde = [0u8; 384]; - byte_decode::<12, { 12 * 256 }>(&ek[384 * i..384 * (i + 1)], &mut ek_hat); - byte_encode::<12, { 384 * 8 }>(&ek_hat, &mut ek_tilde); - assert_eq!(ek_tilde, ek[384 * i..384 * (i + 1)]); + byte_decode::<12, { 12 * 256 }>(&ek[384 * i..384 * (i + 1)], &mut ek_hat)?; + byte_encode::<12, { 384 * 8 }>(&ek_hat, &mut ek_tilde)?; + ensure!(ek_tilde == ek[384 * i..384 * (i + 1)], "TKTK"); } // 1: m ←− B32 ▷ m is 32 random bytes (see Section 3.3) @@ -89,10 +90,10 @@ pub(crate) fn ml_kem_encaps< // 3: 3: c ← K-PKE.Encrypt(ek, m, r) ▷ encrypt m using K-PKE with randomness r k_pke_encrypt::( ek, &m, &r, ct, - ); + )?; // 4: return (K, c) (note: ct is mutable input) - SharedSecretKey(k) + Ok(SharedSecretKey(k)) } @@ -115,13 +116,14 @@ pub(crate) fn ml_kem_decaps< const CT_LEN: usize, >( dk: &[u8], ct: &[u8], -) -> SharedSecretKey { +) -> Result { // Validated input: ciphertext c ∈ B^{32(du k+dv )} // Validated input: decapsulation key dk ∈ B^{768k+96} // Output: shared key K ∈ B^{32} // These length checks are a bit redundant...but present for completeness and paranoia - assert_eq!(ct.len(), 32 * (DU * K + DV)); // Ciphertext type check - assert_eq!(dk.len(), 768 * K + 96); // Decapsulation key type check + ensure!(ct.len() == 32 * (DU * K + DV), "TKTK"); + // Ciphertext type check + ensure!(dk.len() == 768 * K + 96, "TKTK"); // Decapsulation key type check // 1019 For some applications, further validation of the decapsulation key dk_tilde may be appropriate. For // 1020 instance, in cases where dk_tilde was generated by a third party, users may want to ensure that the four @@ -144,7 +146,7 @@ pub(crate) fn ml_kem_decaps< let z = &dk[768 * K + 64..768 * K + 96]; // 5: m′ ← K-PKE.Decrypt(dkPKE,c) - let m_prime = k_pke_decrypt::(dk_pke, ct); + let m_prime = k_pke_decrypt::(dk_pke, ct)?; // 6: (K′, r′) ← G(m′ ∥ h) let mut g_input = [0u8; 32 + 32]; @@ -166,10 +168,10 @@ pub(crate) fn ml_kem_decaps< &m_prime, &r_prime, &mut c_prime[0..ct.len()], - ); + )?; if *ct != c_prime[0..ct.len()] { k_prime = k_bar; }; - SharedSecretKey(k_prime) + Ok(SharedSecretKey(k_prime)) } diff --git a/src/ntt.rs b/src/ntt.rs index 15113d0..e9656a3 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -22,7 +22,7 @@ pub fn ntt(array_f: &[Z256; 256]) -> [Z256; 256] { for start in (0..256).step_by(2 * len) { // // 5: zeta ← ζ^{BitRev7 (k)} mod q - let zeta = Z256(ZETA_TABLE[k << 1] as u16); + let zeta = Z256(ZETA_TABLE[k << 1]); // 6: k ← k+1 @@ -70,7 +70,7 @@ pub fn ntt_inv(f_hat: &[Z256; 256]) -> [Z256; 256] { for start in (0..256).step_by(2 * len) { // // 5: zeta ← ζ^{BitRev7(k)} mod q - let zeta = Z256(ZETA_TABLE[k << 1] as u16); + let zeta = Z256(ZETA_TABLE[k << 1]); // 6: k ← k − 1 k -= 1; @@ -118,7 +118,7 @@ pub fn multiply_ntts(f_hat: &[Z256; 256], g_hat: &[Z256; 256]) -> [Z256; 256] { f_hat[2 * i + 1], g_hat[2 * i], g_hat[2 * i + 1], - Z256(ZETA_TABLE[i ^ 0x80] as u16), + Z256(ZETA_TABLE[i ^ 0x80]), ); h_hat[2 * i] = h_hat_2i; h_hat[2 * i + 1] = h_hat_2ip1; @@ -148,7 +148,8 @@ pub fn base_case_multiply(a0: Z256, a1: Z256, b0: Z256, b1: Z256, gamma: Z256) - /// HAC Algorithm 14.76 Right-to-left binary exponentiation mod Q. #[must_use] -const fn pow_mod_q(g: u32, e: u8) -> u32 { +#[allow(clippy::cast_possible_truncation)] +const fn pow_mod_q(g: u32, e: u8) -> u16 { let g = g as u64; let mut result = 1; let mut s = g; @@ -163,12 +164,13 @@ const fn pow_mod_q(g: u32, e: u8) -> u32 { }; } //reduce_q64(result) - result as u32 + result as u16 } #[allow(dead_code)] -const fn gen_zeta_table() -> [u32; 256] { - let mut result = [0u32; 256]; +#[allow(clippy::cast_possible_truncation)] +const fn gen_zeta_table() -> [u16; 256] { + let mut result = [0u16; 256]; let mut i = 0; while i < 256u16 { result[i as usize] = pow_mod_q(ZETA, (i as u8).reverse_bits()); @@ -178,5 +180,5 @@ const fn gen_zeta_table() -> [u32; 256] { } #[allow(dead_code)] -pub(crate) static ZETA_TABLE: [u32; 256] = gen_zeta_table(); +pub(crate) static ZETA_TABLE: [u16; 256] = gen_zeta_table(); diff --git a/src/sampling.rs b/src/sampling.rs index 8b74401..4524d28 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -1,6 +1,7 @@ use sha3::digest::XofReader; use crate::byte_fns::bytes_to_bits; +use crate::helpers::ensure; use crate::Q; use crate::types::Z256; @@ -33,7 +34,7 @@ pub fn sample_ntt(mut byte_stream_b: impl XofReader) -> [Z256; 256] { if d1 < Q { // // 7: a_hat[j] ← d1 ▷ a_hat ∈ Z256 - array_a_hat[j] = Z256(d1 as u16); //.set_u16(d1); + array_a_hat[j] = Z256(u16::try_from(d1).unwrap()); // 8: j ← j+1 j += 1; @@ -44,7 +45,7 @@ pub fn sample_ntt(mut byte_stream_b: impl XofReader) -> [Z256; 256] { if (d2 < Q) & (j < 256) { // // 11: a_hat[j] ← d2 - array_a_hat[j] = Z256(d2 as u16); //.set_u16(d2); + array_a_hat[j] = Z256(u16::try_from(d2).unwrap()); //.set_u16(d2); // 12: j ← j+1 j += 1; @@ -60,18 +61,17 @@ pub fn sample_ntt(mut byte_stream_b: impl XofReader) -> [Z256; 256] { /// Algorithm 7 `SamplePolyCBDη(B)` on page 20. /// If the input is a stream of uniformly random bytes, outputs a sample from the distribution Dη (Rq ). -#[must_use] -pub fn sample_poly_cbd(byte_array_b: &[u8]) -> [Z256; 256] { +pub fn sample_poly_cbd(byte_array_b: &[u8]) -> Result<[Z256; 256], &'static str> { // Input: byte array B ∈ B^{64η} // Output: array f ∈ Z^{256}_q - debug_assert_eq!(ETA * 512, ETA_512); - debug_assert_eq!(byte_array_b.len(), ETA * 64); + ensure!(ETA * 512 == ETA_512, "TKTK"); + ensure!(byte_array_b.len() == ETA * 64, "TKTK"); let mut array_f: [Z256; 256] = [Z256(0); 256]; let mut bit_array = [0u8; ETA_512]; // 1: b ← BytesToBits(B) - bytes_to_bits(byte_array_b, &mut bit_array); + bytes_to_bits(byte_array_b, &mut bit_array)?; // 2: for (i ← 0; i < 256; i ++) for i in 0..256 { @@ -83,10 +83,10 @@ pub fn sample_poly_cbd(byte_array_b: &[u let y = (0..ETA).fold(0, |acc: u32, j| acc + u32::from(bit_array[2 * i * ETA + ETA + j])); // 5: f [i] ← x − y mod q - array_f[i] = Z256(x as u16).sub(Z256(y as u16)); + array_f[i] = Z256(u16::try_from(x).unwrap()).sub(Z256(u16::try_from(y).unwrap())); // } // 6: end for - array_f // 7: return f + Ok(array_f) // 7: return f } diff --git a/src/traits.rs b/src/traits.rs index 08a49e6..f8c977d 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -9,12 +9,16 @@ pub trait KeyGen { /// TKTK type DecapsKey; + /// TKTK + /// # Errors /// TKTK #[cfg(feature = "default-rng")] fn try_keygen_vt() -> Result<(Self::EncapsKey, Self::DecapsKey), &'static str> { Self::try_keygen_with_rng_vt(&mut OsRng) } + /// TKTK + /// # Errors /// TKTK fn try_keygen_with_rng_vt( rng: &mut impl CryptoRngCore, @@ -29,12 +33,16 @@ pub trait Encaps { /// TKTK type CipherText; + /// TKTK + /// # Errors /// TKTK #[cfg(feature = "default-rng")] fn try_encaps_vt(&self) -> Result<(Self::SharedSecretKey, Self::CipherText), &'static str> { self.try_encaps_with_rng_vt(&mut OsRng) } + /// TKTK + /// # Errors /// TKTK fn try_encaps_with_rng_vt( &self, rng: &mut impl CryptoRngCore, @@ -49,6 +57,8 @@ pub trait Decaps { /// TKTK type SharedSecretKey; + /// TKTK + /// # Errors /// TKTK fn try_decaps_vt(&self, ct: &Self::CipherText) -> Result; } @@ -62,6 +72,8 @@ pub trait SerDes { /// TKTK fn into_bytes(self) -> Self::ByteArray; + /// TKTK + /// # Errors /// TKTK fn try_from_bytes(ba: Self::ByteArray) -> Result where diff --git a/src/types.rs b/src/types.rs index 96eeba9..c48173c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -7,19 +7,18 @@ use crate::Q; #[derive(Clone, Copy)] pub struct Z256(pub u16); +#[allow(clippy::inline_always)] impl Z256 { pub fn get_u16(self) -> u16 { self.0 } - // pub fn set_u16(&mut self, a: u32) { - // //debug_assert!(a < Q); //u32::from(u16::MAX)); - // self.0 = u16::try_from(a % Q).unwrap(); // TODO: Revisit - // } + #[allow(clippy::cast_possible_truncation)] + const Q16: u16 = Q as u16; #[inline(always)] pub fn add(self, other: Self) -> Self { let sum = self.0.wrapping_add(other.0); - let (trial, borrow) = sum.overflowing_sub(Q as u16); + let (trial, borrow) = sum.overflowing_sub(Self::Q16); let result = if borrow { sum } else { trial }; // Not quite CT Self(result) } @@ -28,13 +27,13 @@ impl Z256 { #[inline(always)] pub fn sub(self, other: Self) -> Self { let (diff, borrow) = self.0.overflowing_sub(other.0); - let trial = diff.wrapping_add(Q as u16); + let trial = diff.wrapping_add(Self::Q16); let result = if borrow { trial } else { diff }; // Not quite CT - Self(result as u16) + Self(result) } - const M: u64 = 2u64.pow(32) / (Q as u64); + const M: u64 = 2u64.pow(32) / (Self::Q64); const Q64: u64 = Q as u64; #[inline(always)] pub fn mul(self, other: Self) -> Self { @@ -44,6 +43,6 @@ impl Z256 { let rem = prod - quot * Self::Q64; let (diff, borrow) = rem.overflowing_sub(Self::Q64); let result = if borrow { rem } else { diff }; // Not quite CT - Self(result as u16) + Self(u16::try_from(result).unwrap()) } }