Skip to content

Commit

Permalink
add checks for ML-KEM keys
Browse files Browse the repository at this point in the history
Signed-off-by: Abhinav Saxena <[email protected]>
  • Loading branch information
abhinav-thales committed Dec 2, 2024
1 parent d0d0413 commit 1e40f58
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 7 deletions.
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS})
add_executable(vectors_kem vectors_kem.c)
target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS})

if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS)
# workaround for Windows .dll
if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING)
target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition)
else()
target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE")
endif()
endif()

# Enable Valgrind-based timing side-channel analysis for test_kem and test_sig
if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD)
message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.")
Expand Down
119 changes: 112 additions & 7 deletions tests/vectors_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,34 @@
#include <sys/stat.h>

#include <oqs/oqs.h>

#include <oqs/sha3.h>
#include "system_info.c"

#ifdef OQS_ENABLE_KEM_ML_KEM
/* macros for sanity checks for encaps and decaps key */
#define ML_KEM_BLOCKSIZE 384
#define ML_KEM_K_MAX 4
#define ML_KEM_N 256
#define ML_KEM_1024_PK_SIZE 1568
#define ML_KEM_Q 3329
#define SHA256_OP_LEN 32
/* since x is 12 bits, max value could be 4095. the below macro uses this to implement a simple time constant mod 3329 */
#define MOD_Q(x) ((x) - ((x >= ML_KEM_Q) * ML_KEM_Q))
#endif //OQS_ENABLE_KEM_ML_KEM

struct {
const uint8_t *pos;
} prng_state = {
.pos = 0
};

/* MLKEM-specific functions */
static inline bool is_ml_kem(const char *method_name) {
return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
}

static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) {
size_t i;
fprintf(fp, "%s", S);
Expand Down Expand Up @@ -58,13 +77,75 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) {
}
}

/* ML_KEM-specific functions */
static inline bool is_ml_kem(const char *method_name) {
return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
#ifdef OQS_ENABLE_KEM_ML_KEM
static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) {
/* sanity checks */
if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name);
return false;
}
/* buffer to hold public key hash */
uint8_t pkdig[SHA256_OP_LEN] = {0};
/* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4;
/* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */
OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32);
/* compare it with public key hash stored at 768k+32 offset */
if (0 != memcmp(pkdig, sk + (ML_KEM_BLOCKSIZE * K * 2) + 32, SHA256_OP_LEN)) {
return false;
}
return true;
}

static inline bool sanityCheckPK(const uint8_t *pk, uint32_t pkLen, const char *method_name) {
/* sanity checks */
if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name);
return false;
}
unsigned int i, j;
/* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4;
/* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions
encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */
uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0};
/* buffer to hold encoded value */
uint8_t buffe[ML_KEM_1024_PK_SIZE] = {0};
uint16_t *buff_dec;
/* perform byte decoding as per Algo 6 of FIPS 203 */
for (i = 0; i < K; i++) {
buff_dec = &buffd[i * ML_KEM_N];
const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE];
for (j = 0; j < ML_KEM_N / 2; j++) {
buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF;
buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]);
buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF;
buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]);
}
}
/* perform byte encoding as per Algo 5 of FIPS 203 */
for (i = 0; i < K; i++) {
uint16_t t0, t1;
buff_dec = &buffd[i * ML_KEM_N];
uint8_t *buff_enc = &buffe[i * ML_KEM_BLOCKSIZE];
for (j = 0; j < ML_KEM_N / 2; j++) {
t0 = buff_dec[2 * j];
t1 = buff_dec[2 * j + 1];
buff_enc[3 * j + 0] = (t0 >> 0);
buff_enc[3 * j + 1] = (t0 >> 8) | (t1 << 4);
buff_enc[3 * j + 2] = (t1 >> 4);
}
}
/* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */
if (0 != memcmp(buffe, pk, pkLen - 32)) {
return false;
}
return true;
}
#endif //OQS_ENABLE_KEM_ML_KEM

static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) {
(void) personalization_string;
prng_state.pos = entropy_input;
Expand Down Expand Up @@ -134,6 +215,14 @@ static OQS_STATUS kem_kg_vector(const char *method_name,
fprintBstr(fh, "ek: ", public_key, kem->length_public_key);
fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key);

#ifdef OQS_ENABLE_KEM_ML_KEM
if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) {
ret = OQS_ERROR;
fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) {
ret = OQS_SUCCESS;
} else {
Expand Down Expand Up @@ -208,6 +297,14 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name,
goto err;
}

#ifdef OQS_ENABLE_KEM_ML_KEM
if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) {
ret = OQS_ERROR;
fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
Expand Down Expand Up @@ -273,6 +370,14 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name,
goto err;
}

#ifdef OQS_ENABLE_KEM_ML_KEM
if (false == sanityCheckSK(encdec_sk, method_name)) {
ret = OQS_ERROR;
fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
Expand Down Expand Up @@ -469,4 +574,4 @@ int main(int argc, char **argv) {
} else {
return EXIT_SUCCESS;
}
}
}

0 comments on commit 1e40f58

Please sign in to comment.