Skip to content

Commit

Permalink
partial Result
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Jan 1, 2024
1 parent c466bc4 commit 8359aa0
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 132 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ rust-version = "1.72"
[dependencies]
zeroize = { version = "1.6.0", features = ["zeroize_derive"] }
rand_core = { version = "0.6.4", default-features = false }
sha3 = "0.10.8"
rand = "0.8.5"
sha3 = { version = "0.10.8", default-features = false }


[features]
Expand Down
96 changes: 54 additions & 42 deletions src/byte_fns.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
use crate::helpers::ensure;
use crate::Q;
use crate::types::Z256;

/// Algorithm 2 `BitsToBytes(b)` on page 17.
/// Converts a bit string (of length a multiple of eight) into an array of bytes.
pub(crate) fn bits_to_bytes(bits: &[u8], bytes: &mut [u8]) {
// Input: bit array b ∈ {0,1}^{8·ℓ}
// Output: byte array B ∈ B^ℓ
debug_assert_eq!(bits.len() % 8, 0); // bit_array is multiple of 8
debug_assert_eq!(bits.len(), 8 * bytes.len()); // bit_array length is 8ℓ
///
/// Input: bit array b ∈ {0,1}^{8·ℓ} <br>
/// Output: byte array B ∈ B^ℓ
pub(crate) fn bits_to_bytes(bits: &[u8], bytes: &mut [u8]) -> Result<(), &'static str> {
ensure!(bits.len() % 8 == 0, "TKTK");
// bit_array is multiple of 8
ensure!(bits.len() == 8 * bytes.len(), "TKTK"); // bit_array length is 8ℓ

// 1: B ← (0, . . . , 0) (returned mutable data struct is provided by the caller)
// (reconsider zeroing ... found one bug already)
bytes.iter_mut().for_each(|b| *b = 0);

// 2: for (i ← 0; i < 8ℓ; i ++)
for i in 0..bits.len() {
//
// 3: B [⌊i/8⌋] ← B [⌊i/8⌋] + b[i] · 2^{i mod 8}
bytes[i / 8] += bits[i] * 2u8.pow(u32::try_from(i).expect("too many bits") % 8);
bytes[i / 8] += bits[i] * 2u8.pow(u32::try_from(i).map_err(|_| "too many bits")? % 8);
//
} // 4: end for

Ok(())
} // 5: return B


/// Algorithm 3 `BytesToBits(B)` on page 18.
/// Performs the inverse of `BitsToBytes`, converting a byte array into a bit array.
pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) {
// Input: byte array B ∈ B^ℓ
// Output: bit array b ∈ {0,1}^{8·ℓ}
debug_assert_eq!(bits.len() % 8, 0); // bit_array is multiple of 8
debug_assert_eq!(bytes.len() * 8, bits.len()); // bit_array length is 8ℓ
///
/// Input: byte array B ∈ B^ℓ <br>
/// Output: bit array b ∈ {0,1}^{8·ℓ}
pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) -> Result<(), &'static str> {
ensure!(bits.len() % 8 == 0, "TKTK");
// bit_array is multiple of 8
ensure!(bytes.len() * 8 == bits.len(), "TKTK"); // bit_array length is 8ℓ

// 1: for (i ← 0; i < ℓ; i ++)
for i in 0..bytes.len() {
Expand All @@ -46,25 +53,27 @@ pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) {
//
} // 5: end for
} // 6: end for
Ok(())
} // 7: return b


/// Algorithm 4 `ByteEncode<d>(F)` on page 19.
/// Encodes an array of d-bit integers into a byte array, for 1 ≤ d ≤ 12.
///
/// Input: integer array `F ∈ Z^256_m`, where `m = 2^d if d < 12` and `m = q if d = 12` <br>
/// Output: byte array B ∈ B^{32d}
pub(crate) fn byte_encode<const D: usize, const D_256: usize>(
integers_f: &[Z256; 256], bytes_b: &mut [u8],
) {
// Input: integer array F ∈ Z^256_m, where m = 2^d if d < 12 and m = q if d = 12
// Output: byte array B ∈ B^{32d}
debug_assert!((1 <= D) & (D <= 12));
debug_assert_eq!(D * 256, D_256);
debug_assert_eq!(integers_f.len(), 256);
debug_assert_eq!(bytes_b.len(), 32 * D);
) -> Result<(), &'static str> {
ensure!((1 <= D) & (D <= 12), "TKTK");
ensure!(D * 256 == D_256, "TKTK");
ensure!(integers_f.len() == 256, "TKTK");
ensure!(bytes_b.len() == 32 * D, "TKTK");

let m_mod = if D < 12 {
2_u16.pow(u32::try_from(D).unwrap())
2_u16.pow(u32::try_from(D).map_err(|_| "impossible")?)
} else {
u16::try_from(Q).unwrap()
u16::try_from(Q).map_err(|_| "impossible")?
};
let mut bit_array = [0u8; D_256];

Expand All @@ -87,32 +96,34 @@ pub(crate) fn byte_encode<const D: usize, const D_256: usize>(
} // 7: end for
//
// 8: B ← BitsToBytes(b)
bits_to_bytes(&bit_array, bytes_b);
bits_to_bytes(&bit_array, bytes_b)?;
//
Ok(())
} // 9: return B


/// Algorithm 5 `ByteDecode<d>(B)` on page 19.
/// Decodes a byte array into an array of d-bit integers, for 1 ≤ d ≤ 12.
///
/// Input: byte array B ∈ B^{32d} <br>
/// Output: integer array `F ∈ Z^256_m`, where `m = 2^d if d < 12` and `m = q if d = 12`
pub(crate) fn byte_decode<const D: usize, const D_256: usize>(
bytes_b: &[u8], integers_f: &mut [Z256; 256],
) {
// Input: byte array B ∈ B^{32d}
// Output: integer array F ∈ Z^256_m, where m = 2^d if d < 12 and m = q if d = 12
debug_assert!((1 <= D) & (D <= 12));
debug_assert_eq!(D * 256, D_256);
debug_assert_eq!(bytes_b.len(), 32 * D);
debug_assert_eq!(integers_f.len(), 256);
) -> Result<(), &'static str> {
ensure!((1 <= D) & (D <= 12), "TKTK");
ensure!(D * 256 == D_256, "TKTK");
ensure!(bytes_b.len() == 32 * D ,"TKTK");
ensure!(integers_f.len() == 256, "TKTKT");

let m_mod = if D < 12 {
2_u16.pow(u32::try_from(D).unwrap())
2_u16.pow(u32::try_from(D).map_err(|_| "impossible")?)
} else {
u16::try_from(Q).unwrap()
u16::try_from(Q).map_err(|_| "impossible")?
};
let mut bit_array = [0u8; D_256];

// 1: b ← BytesToBits(B)
bytes_to_bits(bytes_b, &mut bit_array);
bytes_to_bits(bytes_b, &mut bit_array)?;

// 2: for (i ← 0; i < 256; i ++)
for i in 0..256 {
Expand All @@ -127,6 +138,7 @@ pub(crate) fn byte_decode<const D: usize, const D_256: usize>(
});
//
} // 4: end for
Ok(())
} // 5: return F


Expand All @@ -150,9 +162,9 @@ mod tests {
let num_bytes = rng.gen::<u8>();
let bytes1: Vec<u8> = (0..num_bytes).map(|_| rng.gen()).collect();
let mut bits = vec![0u8; num_bytes as usize * 8];
bytes_to_bits(&bytes1, &mut bits[..]);
bytes_to_bits(&bytes1, &mut bits[..]).unwrap();
let mut bytes2 = vec![0u8; num_bytes as usize];
bits_to_bytes(&bits, &mut bytes2[..]);
bits_to_bytes(&bits, &mut bytes2[..]).unwrap();
assert_eq!(bytes1, bytes2);
}
}
Expand All @@ -165,29 +177,29 @@ mod tests {
let num_bytes = 32 * 11;
let mut bytes2 = vec![0u8; num_bytes];
let bytes1: Vec<u8> = (0..num_bytes).map(|_| rng.gen()).collect();
byte_decode::<11, { 11 * 256 }>(&bytes1, &mut integer_array);
byte_encode::<11, 2816>(&integer_array, &mut bytes2);
byte_decode::<11, { 11 * 256 }>(&bytes1, &mut integer_array).unwrap();
byte_encode::<11, 2816>(&integer_array, &mut bytes2).unwrap();
assert_eq!(bytes1, bytes2);

let num_bytes = 32 * 10;
let bytes1: Vec<u8> = (0..num_bytes).map(|_| rng.gen()).collect();
let mut bytes2 = vec![0u8; num_bytes];
byte_decode::<10, 2560>(&bytes1, &mut integer_array);
byte_encode::<10, 2560>(&integer_array, &mut bytes2);
byte_decode::<10, 2560>(&bytes1, &mut integer_array).unwrap();
byte_encode::<10, 2560>(&integer_array, &mut bytes2).unwrap();
assert_eq!(bytes1, bytes2);

let num_bytes = 32 * 5;
let bytes1: Vec<u8> = (0..num_bytes).map(|_| rng.gen()).collect();
let mut bytes2 = vec![0u8; num_bytes];
byte_decode::<5, 1280>(&bytes1, &mut integer_array);
byte_encode::<5, 1280>(&integer_array, &mut bytes2);
byte_decode::<5, 1280>(&bytes1, &mut integer_array).unwrap();
byte_encode::<5, 1280>(&integer_array, &mut bytes2).unwrap();
assert_eq!(bytes1, bytes2);

let num_bytes = 32 * 4;
let bytes1: Vec<u8> = (0..num_bytes).map(|_| rng.gen()).collect();
let mut bytes2 = vec![0u8; num_bytes];
byte_decode::<4, 1024>(&bytes1, &mut integer_array);
byte_encode::<4, 1024>(&integer_array, &mut bytes2);
byte_decode::<4, 1024>(&bytes1, &mut integer_array).unwrap();
byte_encode::<4, 1024>(&integer_array, &mut bytes2).unwrap();
assert_eq!(bytes1, bytes2);
}
}
Expand Down
15 changes: 13 additions & 2 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ use crate::ntt::multiply_ntts;
use crate::Q;
use crate::types::Z256;

/// If the condition is not met, return an error message. Borrowed from the `anyhow` crate.
macro_rules! ensure {
($cond:expr, $msg:literal $(,)?) => {
if !$cond {
return Err($msg);
}
};
}

pub(crate) use ensure; // make available throughout crate

/// Vector addition; See bottom of page 9, second row: `z_hat` = `u_hat` + `v_hat`
#[must_use]
pub(crate) fn vec_add<const K: usize>(
Expand All @@ -14,7 +25,7 @@ pub(crate) fn vec_add<const K: usize>(
let mut result = [[Z256(0); 256]; K];
for i in 0..vec_a.len() {
for j in 0..vec_a[i].len() {
result[i][j] = vec_a[i][j].add(vec_b[i][j]); //.set_u16(vec_a[i][j].get_u32() + vec_b[i][j].get_u32());
result[i][j] = vec_a[i][j].add(vec_b[i][j]);
}
}
result
Expand All @@ -33,7 +44,7 @@ pub(crate) fn mat_vec_mul<const K: usize>(
for j in 0..K {
let tmp = multiply_ntts(&a_hat[i][j], &u_hat[j]);
for k in 0..256 {
w_hat[i][k] = w_hat[i][k].add(tmp[k]); //.set_u16(w_hat[i][k].get_u32() + tmp[k].get_u32());
w_hat[i][k] = w_hat[i][k].add(tmp[k]);
}
}
}
Expand Down
Loading

0 comments on commit 8359aa0

Please sign in to comment.