Skip to content

Commit

Permalink
Merge pull request #824 from cryspen/keks/test-bench-rsa
Browse files Browse the repository at this point in the history
RSA: Add Wycheproof Test
  • Loading branch information
jschneider-bensch authored Feb 13, 2025
2 parents 2611285 + e54bb04 commit 7ac47a8
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 4 deletions.
3 changes: 3 additions & 0 deletions rsa/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ libcrux-macros = { version = "=0.0.2-beta.2", path = "../macros" }
libcrux-sha2 = { version = "=0.0.2-beta.2", path = "../sha2", features = [
"expose-hacl",
] }

[dev-dependencies]
wycheproof = "0.6.0"
6 changes: 3 additions & 3 deletions rsa/src/impl_hacl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ macro_rules! impl_rsapss {
salt: &[u8],
sig: &mut [u8; $bytes],
) -> Result<(), Error> {
return sign_varlen(alg, &sk.into(), msg, salt, sig);
sign_varlen(alg, &sk.into(), msg, salt, sig)
}

/// Returns `Ok(())` if the provided signature is valid.
Expand Down Expand Up @@ -258,7 +258,7 @@ pub fn verify(
/// - checked explicitly
/// - `msgLen `less_than_max_input_length` a`
/// - follows from the check that messages are shorter than `u32::MAX`.
fn sign_varlen(
pub fn sign_varlen(
alg: crate::DigestAlgorithm,
sk: &VarLenPrivateKey<'_>,
msg: &[u8],
Expand Down Expand Up @@ -312,7 +312,7 @@ fn sign_varlen(
/// is less than `u32::MAX`.
/// - `msgLen less_than_max_input_length a`
/// - follows from the check that messages are shorter than `u32::MAX`.
fn verify_varlen(
pub fn verify_varlen(
alg: crate::DigestAlgorithm,
pk: &VarLenPublicKey<'_>,
msg: &[u8],
Expand Down
101 changes: 100 additions & 1 deletion rsa/tests/rsa_pss.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use libcrux_rsa::{
sign, sign_2048, verify, verify_2048, DigestAlgorithm, Error, PrivateKey, PublicKey,
VarLenPrivateKey,
VarLenPrivateKey, VarLenPublicKey,
};

const MODULUS: [u8; 256] = [
Expand Down Expand Up @@ -171,3 +171,102 @@ fn wycheproof_single_test() {
verify_2048(DigestAlgorithm::Sha2_256, &pk, &msg, 32, &signature)
.expect("Error verifying signature");
}

#[test]
fn run_wycheproof() {
for test_name in wycheproof::rsa_pss_verify::TestName::all() {
let test_set = wycheproof::rsa_pss_verify::TestSet::load(test_name)
.expect("error loading wycheproof test for name {test_name}");
println!("Test Set {test_name:?}");

let num_tests = test_set.number_of_tests;
let mut tests_run = 0;

for test_group in test_set.test_groups {
if !matches!(
test_group.hash,
wycheproof::HashFunction::Sha2_256
| wycheproof::HashFunction::Sha3_384
| wycheproof::HashFunction::Sha2_512
) {
println!(
"skipping due to unsupported hash function {:?}",
test_group.hash
);
continue;
}

if !matches!(
test_group.mgf_hash.unwrap(),
wycheproof::HashFunction::Sha2_256
| wycheproof::HashFunction::Sha3_384
| wycheproof::HashFunction::Sha2_512
) {
println!(
"skipping due to unsupported mgf hash function {:?}",
test_group.mgf_hash
);
continue;
}

if test_group.hash != test_group.mgf_hash.unwrap() {
println!(
"skipping because hash function doesn't match mgf hash function: {:?} != {:?}",
test_group.hash, test_group.mgf_hash
);
continue;
}

assert!(matches!(test_group.key.e.as_slice(), [1, 0, 1]));
assert_eq!(test_group.mgf, wycheproof::Mgf::Mgf1);

let salt_len = test_group.salt_size;
let mut n = test_group.key.n.as_slice();

assert!(salt_len < u32::MAX as usize);
let salt_len = salt_len as u32;

// strip leading 0s in n
let first_non_zero = n.iter().position(|&x| x != 0).unwrap();
for _ in 0..first_non_zero {
n = &n[1..];
}
let pk = VarLenPublicKey::try_from(n).unwrap();

let hash_algorithm = match &test_group.hash {
wycheproof::HashFunction::Sha2_256 => DigestAlgorithm::Sha2_256,
wycheproof::HashFunction::Sha2_384 => DigestAlgorithm::Sha2_384,
wycheproof::HashFunction::Sha2_512 => DigestAlgorithm::Sha2_512,
_ => panic!("Unknown hash algorithm {:?}", test_group.hash),
};
for (i, test) in test_group.tests.into_iter().enumerate() {
let comment = &test.comment;
println!("Test {i}: {comment}");

let valid = matches!(
test.result,
wycheproof::TestResult::Valid | wycheproof::TestResult::Acceptable
);

let msg = &test.msg;
let signature = &test.sig;

let result = libcrux_rsa::verify_varlen(
hash_algorithm,
&pk,
msg,
salt_len,
signature.as_slice(),
);
if valid {
assert!(result.is_ok());
} else {
assert!(result.is_err());
}
tests_run += 1;
}
}

println!("Ran {} out of {} tests.", tests_run, num_tests);
}
}

0 comments on commit 7ac47a8

Please sign in to comment.