From 1e40f58cb2bdb2bcb64c8bd94301a2c145d3f072 Mon Sep 17 00:00:00 2001 From: Abhinav Saxena Date: Mon, 2 Dec 2024 11:45:33 +0530 Subject: [PATCH] add checks for ML-KEM keys Signed-off-by: Abhinav Saxena --- tests/CMakeLists.txt | 9 ++++ tests/vectors_kem.c | 119 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 121 insertions(+), 7 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6d08516a8..fae241991 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS}) add_executable(vectors_kem vectors_kem.c) target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS}) +if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS) + # workaround for Windows .dll + if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING) + target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition) + else() + target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE") + endif() +endif() + # Enable Valgrind-based timing side-channel analysis for test_kem and test_sig if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD) message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.") diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index a7a1dc6a7..2b0ac88ff 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -11,15 +11,34 @@ #include #include - +#include #include "system_info.c" +#ifdef OQS_ENABLE_KEM_ML_KEM +/* macros for sanity checks for encaps and decaps key */ +#define ML_KEM_BLOCKSIZE 384 +#define ML_KEM_K_MAX 4 +#define ML_KEM_N 256 +#define ML_KEM_1024_PK_SIZE 1568 +#define ML_KEM_Q 3329 +#define SHA256_OP_LEN 32 +/* since x is 12 bits, max value could be 4095. the below macro uses this to implement a simple time constant mod 3329 */ +#define MOD_Q(x) ((x) - ((x >= ML_KEM_Q) * ML_KEM_Q)) +#endif //OQS_ENABLE_KEM_ML_KEM + struct { const uint8_t *pos; } prng_state = { .pos = 0 }; +/* MLKEM-specific functions */ +static inline bool is_ml_kem(const char *method_name) { + return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) + || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) + || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024)); +} + static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) { size_t i; fprintf(fp, "%s", S); @@ -58,13 +77,75 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) { } } -/* ML_KEM-specific functions */ -static inline bool is_ml_kem(const char *method_name) { - return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) - || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) - || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024)); +#ifdef OQS_ENABLE_KEM_ML_KEM +static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { + /* sanity checks */ + if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name); + return false; + } + /* buffer to hold public key hash */ + uint8_t pkdig[SHA256_OP_LEN] = {0}; + /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 + K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ + uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + /* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */ + OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32); + /* compare it with public key hash stored at 768k+32 offset */ + if (0 != memcmp(pkdig, sk + (ML_KEM_BLOCKSIZE * K * 2) + 32, SHA256_OP_LEN)) { + return false; + } + return true; } +static inline bool sanityCheckPK(const uint8_t *pk, uint32_t pkLen, const char *method_name) { + /* sanity checks */ + if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name); + return false; + } + unsigned int i, j; + /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 + K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ + uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + /* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions + encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */ + uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0}; + /* buffer to hold encoded value */ + uint8_t buffe[ML_KEM_1024_PK_SIZE] = {0}; + uint16_t *buff_dec; + /* perform byte decoding as per Algo 6 of FIPS 203 */ + for (i = 0; i < K; i++) { + buff_dec = &buffd[i * ML_KEM_N]; + const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE]; + for (j = 0; j < ML_KEM_N / 2; j++) { + buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; + buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]); + buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF; + buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]); + } + } + /* perform byte encoding as per Algo 5 of FIPS 203 */ + for (i = 0; i < K; i++) { + uint16_t t0, t1; + buff_dec = &buffd[i * ML_KEM_N]; + uint8_t *buff_enc = &buffe[i * ML_KEM_BLOCKSIZE]; + for (j = 0; j < ML_KEM_N / 2; j++) { + t0 = buff_dec[2 * j]; + t1 = buff_dec[2 * j + 1]; + buff_enc[3 * j + 0] = (t0 >> 0); + buff_enc[3 * j + 1] = (t0 >> 8) | (t1 << 4); + buff_enc[3 * j + 2] = (t1 >> 4); + } + } + /* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */ + if (0 != memcmp(buffe, pk, pkLen - 32)) { + return false; + } + return true; +} +#endif //OQS_ENABLE_KEM_ML_KEM + static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) { (void) personalization_string; prng_state.pos = entropy_input; @@ -134,6 +215,14 @@ static OQS_STATUS kem_kg_vector(const char *method_name, fprintBstr(fh, "ek: ", public_key, kem->length_public_key); fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key); +#ifdef OQS_ENABLE_KEM_ML_KEM + if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) { + ret = OQS_ERROR; + fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) { ret = OQS_SUCCESS; } else { @@ -208,6 +297,14 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name, goto err; } +#ifdef OQS_ENABLE_KEM_ML_KEM + if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) { + ret = OQS_ERROR; + fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk); if (rc != OQS_SUCCESS) { fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name); @@ -273,6 +370,14 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name, goto err; } +#ifdef OQS_ENABLE_KEM_ML_KEM + if (false == sanityCheckSK(encdec_sk, method_name)) { + ret = OQS_ERROR; + fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk); if (rc != OQS_SUCCESS) { fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name); @@ -469,4 +574,4 @@ int main(int argc, char **argv) { } else { return EXIT_SUCCESS; } -} +} \ No newline at end of file