Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
shibammukherjee committed Feb 6, 2025
1 parent 16d4787 commit 452b09a
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 34 deletions.
56 changes: 28 additions & 28 deletions faest.c
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,12 @@ void faest_sign(uint8_t* sig, const uint8_t* msg, size_t msg_len, const uint8_t*
// ::8
uint8_t chall_1[(5 * MAX_LAMBDA_BYTES) + 8];
hash_challenge_1(chall_1, mu, bavc.h, signature_c(sig, 0, params), iv, lambda, ell, tau);
debug_print_buf("P_chall_1", chall_1, lambda / 8);
// debug_print_buf("P_chall_1", chall_1, lambda / 8);

// ::9-10
vole_hash(signature_u_tilde(sig, params), chall_1, u, ell, lambda);
debug_print_buf("P_u", u, 16);
debug_print_buf("P_u_tilde", dsignature_u_tilde(sig, params), 16);
// debug_print_buf("P_u", u, 16);
// debug_print_buf("P_u_tilde", dsignature_u_tilde(sig, params), 16);

// ::11-12
// To save memory consumption, the chall_2 is computed in an
Expand Down Expand Up @@ -370,7 +370,7 @@ void faest_sign(uint8_t* sig, const uint8_t* msg, size_t msg_len, const uint8_t*
// :15
uint8_t chall_2[3 * MAX_LAMBDA_BYTES + 8];
hash_challenge_2_finalize(chall_2, &chall_2_ctx, signature_d(sig, params), lambda, ell);
debug_print_buf("P_chall_2", chall_2, lambda / 8);
// debug_print_buf("P_chall_2", chall_2, lambda / 8);

// ::16-20
uint8_t a0_tilde[MAX_LAMBDA_BYTES];
Expand All @@ -379,7 +379,7 @@ void faest_sign(uint8_t* sig, const uint8_t* msg, size_t msg_len, const uint8_t*
for (unsigned int bit_i = 0; bit_i < ell; bit_i++) {
w_bits[bit_i] = (w[bit_i/8] >> bit_i%8) & 1;
}
debug_print_buf("w", w, 32);
// debug_print_buf("w", w, 32);
uint8_t* u_bits = (uint8_t*)malloc(2*lambda); // 1 bit per uint8_t
for (unsigned int bit_i = 0; bit_i < 2*lambda; bit_i++) {
u_bits[bit_i] = (u[(ell + bit_i)/8] >> (ell + bit_i)%8) & 1;
Expand Down Expand Up @@ -430,16 +430,16 @@ void faest_sign(uint8_t* sig, const uint8_t* msg, size_t msg_len, const uint8_t*
ctr = htole32(ctr);
memcpy(signature_ctr(sig, params), &ctr, sizeof(ctr));

debug_print_buf("P_iv", iv, IV_SIZE);
debug_print_buf("P_u_tilde", dsignature_u_tilde(sig, params), 16);
debug_print_buf("P_c", dsignature_c(sig, 0, params), 16);
debug_print_buf("P_d", dsignature_d(sig, params), 16);
debug_print_buf("P_a1_tilde", dsignature_a1_tilde(sig, params), 16);
debug_print_buf("P_a2_tilde", dsignature_a2_tilde(sig, params), 16);
debug_print_buf("P_chall_3", dsignature_chall_3(sig, params), 16);
debug_print_buf("P_decom_i", dsignature_decom_i(sig, params), 16);
debug_print_buf("P_ctr", dsignature_ctr(sig, params), 4);
debug_print_buf("P_a0_tilde", a0_tilde, 16);
// debug_print_buf("P_iv", iv, IV_SIZE);
// debug_print_buf("P_u_tilde", dsignature_u_tilde(sig, params), 16);
// debug_print_buf("P_c", dsignature_c(sig, 0, params), 16);
// debug_print_buf("P_d", dsignature_d(sig, params), 16);
// debug_print_buf("P_a1_tilde", dsignature_a1_tilde(sig, params), 16);
// debug_print_buf("P_a2_tilde", dsignature_a2_tilde(sig, params), 16);
// debug_print_buf("P_chall_3", dsignature_chall_3(sig, params), 16);
// debug_print_buf("P_decom_i", dsignature_decom_i(sig, params), 16);
// debug_print_buf("P_ctr", dsignature_ctr(sig, params), 4);
// debug_print_buf("P_a0_tilde", a0_tilde, 16);
printf("======================= END PROVER ==========================\n");
}

Expand Down Expand Up @@ -468,15 +468,15 @@ int faest_verify(const uint8_t* msg, size_t msglen, const uint8_t* sig, const ui
uint8_t iv[IV_SIZE];
hash_iv(iv, dsignature_iv_pre(sig, params), lambda);

debug_print_buf("V_iv", iv, IV_SIZE);
debug_print_buf("V_u_tilde", dsignature_u_tilde(sig, params), 16);
debug_print_buf("V_c", dsignature_c(sig, 0, params), 16);
debug_print_buf("V_d", dsignature_d(sig, params), 16);
debug_print_buf("V_a1_tilde", dsignature_a1_tilde(sig, params), 16);
debug_print_buf("V_a2_tilde", dsignature_a2_tilde(sig, params), 16);
debug_print_buf("V_chall_3", dsignature_chall_3(sig, params), 16);
debug_print_buf("V_decom_i", dsignature_decom_i(sig, params), 16);
debug_print_buf("V_ctr", dsignature_ctr(sig, params), 4);
// debug_print_buf("V_iv", iv, IV_SIZE);
// debug_print_buf("V_u_tilde", dsignature_u_tilde(sig, params), 16);
// debug_print_buf("V_c", dsignature_c(sig, 0, params), 16);
// debug_print_buf("V_d", dsignature_d(sig, params), 16);
// debug_print_buf("V_a1_tilde", dsignature_a1_tilde(sig, params), 16);
// debug_print_buf("V_a2_tilde", dsignature_a2_tilde(sig, params), 16);
// debug_print_buf("V_chall_3", dsignature_chall_3(sig, params), 16);
// debug_print_buf("V_decom_i", dsignature_decom_i(sig, params), 16);
// debug_print_buf("V_ctr", dsignature_ctr(sig, params), 4);

printf("verify: step 5\n");
// Step: 6-7
Expand All @@ -498,7 +498,7 @@ int faest_verify(const uint8_t* msg, size_t msglen, const uint8_t* sig, const ui
// ::10
uint8_t chall_1[5 * MAX_LAMBDA_BYTES + 8];
hash_challenge_1(chall_1, mu, hcom, dsignature_c(sig, 0, params), iv, lambda, ell, tau);
debug_print_buf("V_chall_1", chall_1, lambda / 8);
// debug_print_buf("V_chall_1", chall_1, lambda / 8);

// Step 12, 14 and 15
H2_context_t chall_2_ctx;
Expand All @@ -525,7 +525,7 @@ int faest_verify(const uint8_t* msg, size_t msglen, const uint8_t* sig, const ui
// Step 15
uint8_t chall_2[3 * MAX_LAMBDA_BYTES + 8];
hash_challenge_2_finalize(chall_2, &chall_2_ctx, dsignature_d(sig, params), lambda, ell);
debug_print_buf("V_chall_2", chall_2, lambda / 8);
// debug_print_buf("V_chall_2", chall_2, lambda / 8);
printf("verify: step 16\n");
// Step 18

Expand All @@ -539,7 +539,7 @@ int faest_verify(const uint8_t* msg, size_t msglen, const uint8_t* sig, const ui
const uint8_t* a0_tilde = aes_verify(d_bits, q, chall_2, dsignature_chall_3(sig, params), dsignature_a1_tilde(sig, params), dsignature_a2_tilde(sig, params), owf_input,
owf_output, params);

debug_print_buf("V_a0_tilde", a0_tilde, 16);
// debug_print_buf("V_a0_tilde", a0_tilde, 16);


free_pointer_array(&q);
Expand All @@ -550,7 +550,7 @@ int faest_verify(const uint8_t* msg, size_t msglen, const uint8_t* sig, const ui
hash_challenge_3(chall_3, chall_2, a0_tilde, dsignature_a1_tilde(sig, params),
dsignature_a2_tilde(sig, params), dsignature_ctr(sig, params), lambda);
free((void*)a0_tilde);
debug_print_buf("V_chall_3_check", chall_3, lambda / 8);
// debug_print_buf("V_chall_3_check", chall_3, lambda / 8);

// Step 21
return memcmp(chall_3, dsignature_chall_3(sig, params), lambdaBytes) == 0 ? 0 : -1;
Expand Down
10 changes: 5 additions & 5 deletions faest_aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3361,7 +3361,7 @@ static void aes_128_constraints_verifier(bf128_t* z_key, const bf128_t* w_key, c
// ::4-5
z_key[0] = bf128_mul(delta, bf128_mul(w_key[0], w_key[1]));

debug_print_bf128("delta", &delta);
// debug_print_bf128("delta", &delta);

// ::7-8
bf128_t* rkeys_key = (bf128_t*)malloc(sizeof(bf128_t) * (R+1) * blocksize);
Expand Down Expand Up @@ -3461,7 +3461,7 @@ static void aes_192_constraints_verifier(bf192_t* z_key, const bf192_t* w_key, c
// ::4-5
z_key[0] = bf192_mul(delta, bf192_mul(w_key[0], w_key[1]));

debug_print_bf192("delta", &delta);
// debug_print_bf192("delta", &delta);

// ::7-8
bf192_t* rkeys_key = (bf192_t*)malloc(sizeof(bf192_t) * (R+1) * blocksize);
Expand Down Expand Up @@ -3785,7 +3785,7 @@ static void aes_192_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
memset(z1_val, 0, c * sizeof(bf192_t));
memset(z2_gamma, 0, c * sizeof(bf192_t));

/* aes_192_constraints_prover(z0_tag, z1_val, z2_gamma, w_bits, w_tag, owf_in, owf_out, params, isEM); */
aes_192_constraints_prover(z0_tag, z1_val, z2_gamma, w_bits, w_tag, owf_in, owf_out, params, isEM);

// Step: 13-18
zk_hash_192_ctx a0_ctx;
Expand Down Expand Up @@ -3825,7 +3825,7 @@ static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til

// ::6-7 embed VOLE masks
bf256_t* bf_u_bits = (bf256_t*) malloc(2*lambda * sizeof(bf256_t));
debug_print_buf_bits("u_bits", u_bits, 2*lambda);
// debug_print_buf_bits("u_bits", u_bits, 2*lambda);
for (unsigned int i = 0; i < 2*lambda; i++) {
bf_u_bits[i] = bf256_from_bit(u_bits[i]);
}
Expand Down Expand Up @@ -3960,7 +3960,7 @@ static uint8_t* aes_192_verifier(const uint8_t* d_bits, uint8_t** Q, const uint8
bf192_mul_bit(bf_delta, d_bits[i]));
}
memset(z2_key, 0, c * sizeof(bf192_t));
/* aes_192_constraints_verifier(z2_key, w_key, owf_in, owf_out, bf_delta, params, isEM); */
aes_192_constraints_verifier(z2_key, w_key, owf_in, owf_out, bf_delta, params, isEM);

// ::13-14
zk_hash_192_ctx b_ctx;
Expand Down
214 changes: 213 additions & 1 deletion tests/aes_prove.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,220 @@ static bf128_t* column_to_row_major_and_shrink_V_128(uint8_t** v, unsigned int e
return new_v;
}

static bf192_t* column_to_row_major_and_shrink_V_192(uint8_t** v, unsigned int ell) {
// V is \hat \ell times \lambda matrix over F_2
// v has \hat \ell rows, \lambda columns, storing in column-major order, new_v has \ell + \lambda
// rows and \lambda columns storing in row-major order
bf192_t* new_v = (bf192_t*) malloc((ell + FAEST_192F_LAMBDA*2) * sizeof(bf192_t));
for (unsigned int row = 0; row != ell + FAEST_192F_LAMBDA*2; ++row) {
uint8_t new_row[BF192_NUM_BYTES] = {0};
for (unsigned int column = 0; column != FAEST_192F_LAMBDA; ++column) {
ptr_set_bit(new_row, ptr_get_bit(v[column], row), column);
}
new_v[row] = bf192_load(new_row);
}

return new_v;
}

BOOST_AUTO_TEST_SUITE(test_aes_prove)


BOOST_DATA_TEST_CASE(aes_prove_verify, all_parameters, param_id) {
BOOST_TEST_CONTEXT("Parameter set: " << faest_get_param_name(param_id)) {
const faest_paramset_t* params = faest_get_paramset(param_id);
const bool is_em = faest_is_em(params);
const unsigned int lambda = params->faest_param.lambda;
const unsigned int lambdaBytes = lambda / 8;
const unsigned int ell = params->faest_param.l;
const unsigned int ell_hat =
params->faest_param.l + params->faest_param.lambda * 3 + UNIVERSAL_HASH_B_BITS;
const unsigned int ell_hat_bytes = (ell_hat + 7) / 8;
const unsigned int ell_bytes = (params->faest_param.l + 7) / 8;

// extended witness
//std::vector<uint8_t> w;
std::vector<uint8_t> in;
std::vector<uint8_t> out;

if (lambda == 192 && !is_em) {
for (const auto byte : aes_ctr_192_tv::in) {
for (size_t bit_i = 0; bit_i < ell; bit_i++) {
in.push_back((byte >> bit_i) & 1);
}
}
for (const auto byte : aes_ctr_192_tv::out) {
for (size_t bit_i = 0; bit_i < ell; bit_i++) {
out.push_back((byte >> bit_i) & 1);
}
}
} else if (lambda == 192 && is_em) {
for (const auto byte : rijndael_em_192_tv::in) {
for (size_t bit_i = 0; bit_i < ell; bit_i++) {
in.push_back((byte >> bit_i) & 1);
}
}
for (const auto byte : rijndael_em_192_tv::out) {
for (size_t bit_i = 0; bit_i < ell; bit_i++) {
out.push_back((byte >> bit_i) & 1);
}
}
}
else {
return;
}

uint8_t* w = aes_extend_witness(in.data(), out.data(), params);
std::vector<uint8_t> w_bits(ell, 0x00); // 1 bit in per uint8_t
for (unsigned int bit_i = 0; bit_i < ell; bit_i++) {
w_bits[bit_i] = (w[bit_i/8] >> bit_i%8) & 1;
}

// if (lambda == 192 && !is_em) {
// std::copy(aes_ctr_192_tv::in.begin(), aes_ctr_192_tv::in.end(),
// std::back_insert_iterator(in));
// std::copy(aes_ctr_192_tv::out.begin(), aes_ctr_192_tv::out.end(),
// std::back_insert_iterator(out));
// std::copy(aes_ctr_192_tv::expected_extended_witness.begin(),
// aes_ctr_192_tv::expected_extended_witness.end(), std::back_insert_iterator(w));
// } else if (lambda == 192 && is_em) {
// std::copy(rijndael_em_192_tv::in.begin(), rijndael_em_192_tv::in.end(),
// std::back_insert_iterator(in));
// std::copy(rijndael_em_192_tv::out.begin(), rijndael_em_192_tv::out.end(),
// std::back_insert_iterator(out));
// std::copy(rijndael_em_192_tv::expected_extended_witness.begin(),
// rijndael_em_192_tv::expected_extended_witness.end(), std::back_insert_iterator(w));
// } else if (lambda == 192 && !is_em) {
// std::copy(aes_ctr_192_tv::in.begin(), aes_ctr_192_tv::in.end(),
// std::back_insert_iterator(in));
// std::copy(aes_ctr_192_tv::out.begin(), aes_ctr_192_tv::out.end(),
// std::back_insert_iterator(out));
// std::copy(aes_ctr_192_tv::expected_extended_witness.begin(),
// aes_ctr_192_tv::expected_extended_witness.end(), std::back_insert_iterator(w));
// } else if (lambda == 192 && is_em) {
// std::copy(rijndael_em_192_tv::in.begin(), rijndael_em_192_tv::in.end(),
// std::back_insert_iterator(in));
// std::copy(rijndael_em_192_tv::out.begin(), rijndael_em_192_tv::out.end(),
// std::back_insert_iterator(out));
// std::copy(rijndael_em_192_tv::expected_extended_witness.begin(),
// rijndael_em_192_tv::expected_extended_witness.end(), std::back_insert_iterator(w));
// } else if (lambda == 256 && !is_em) {
// std::copy(aes_ctr_256_tv::in.begin(), aes_ctr_256_tv::in.end(),
// std::back_insert_iterator(in));
// std::copy(aes_ctr_256_tv::out.begin(), aes_ctr_256_tv::out.end(),
// std::back_insert_iterator(out));
// std::copy(aes_ctr_256_tv::expected_extended_witness.begin(),
// aes_ctr_256_tv::expected_extended_witness.end(), std::back_insert_iterator(w));
// } else if (lambda == 256 && is_em) {
// std::copy(rijndael_em_256_tv::in.begin(), rijndael_em_256_tv::in.end(),
// std::back_insert_iterator(in));
// std::copy(rijndael_em_256_tv::out.begin(), rijndael_em_256_tv::out.end(),
// std::back_insert_iterator(out));
// std::copy(rijndael_em_256_tv::expected_extended_witness.begin(),
// rijndael_em_256_tv::expected_extended_witness.end(), std::back_insert_iterator(w));
// }

// prepare vole correlation
std::vector<uint8_t> delta(lambda / 8, 0);
for (size_t i = 0; i < lambda / 8; ++i) {
delta[i] = (uint8_t) i;
}
std::vector<uint8_t> u(ell_hat_bytes, 0x13);
std::vector<uint8_t> vs(ell_hat_bytes * lambda, 0x37);
std::vector<uint8_t> qs = vs;
std::vector<uint8_t*> V(lambda, NULL);
std::vector<uint8_t*> Q(lambda, NULL);

for (size_t i = 0; i < lambda; ++i) {
V[i] = vs.data() + i * ell_hat_bytes;
Q[i] = qs.data() + i * ell_hat_bytes;
if ((delta[i / 8] >> (i % 8)) & 1) {
for (size_t j = 0; j < ell_hat_bytes; ++j) {
Q[i][j] ^= u[j];
}
}
}

std::vector<uint8_t> u_bits(2 * lambda, 0x00); // 1 bit in per uint8_t
for (unsigned int bit_i = 0; bit_i < 2 * lambda; bit_i++) {
u_bits[bit_i] = (u[(ell + bit_i) / 8] >> (ell + bit_i) % 8) & 1;
}
// masked witness d = u ^ w
std::vector<uint8_t> d(ell_bytes, 0x13);
for (size_t i = 0; i < ell_bytes; ++i) {
d[i] = u[i] ^ w[i];
}
std::vector<uint8_t> d_bits(ell, 0x00); // 1 bit in per uint8_t
for (unsigned int bit_i = 0; bit_i < ell; bit_i++) {
d_bits[bit_i] = (d[bit_i / 8] >> bit_i % 8) & 1;
}

std::vector<uint8_t> chall_2((3 * lambda + 64) / 8, 47);

std::vector<uint8_t> a0_tilde(lambda / 8, 0);
std::vector<uint8_t> a1_tilde(lambda / 8, 0);
std::vector<uint8_t> a2_tilde(lambda / 8, 0);

bf192_t bf_delta = bf192_load(delta.data());
bf192_t* q = column_to_row_major_and_shrink_V_192(Q.data(), ell);

bf192_t* w_tag = column_to_row_major_and_shrink_V_192(V.data(), ell);
bf192_t* bf_u_bits = (bf192_t*) malloc(2*lambda * sizeof(bf192_t));

for (unsigned int bit_i = 0; bit_i < 2*lambda; bit_i++) {
u_bits[bit_i] = (u[(ell + bit_i)/8] >> (ell + bit_i)%8) & 1;
bf_u_bits[bit_i] = bf192_from_bit(u_bits[bit_i]);
}
// verifier Delta, q_star
// printf("delta\n");
// print_array<uint8_t>((uint8_t*) &bf_delta, lambdaBytes);
bf192_t q_star_0 = bf192_sum_poly(q + ell);
// printf("q_star_0\n");
// print_array<uint8_t>((uint8_t*) &q_star_0, lambdaBytes);

// prover u,v (want: q = u*Delta + v)
bf192_t bf_u_star_0 = bf192_sum_poly(bf_u_bits);
bf192_t bf_v_star_0 = bf192_sum_poly(w_tag + ell);

//printf("u_star_0\n");
// print_array<uint8_t>((uint8_t*) &bf_u_star_0, lambdaBytes);
//printf("v_star_0\n");
// print_array<uint8_t>((uint8_t*) &bf_v_star_0, lambdaBytes);

bf192_t test_v0 = bf192_add(q_star_0, bf192_mul(bf_delta, bf_u_star_0));
//printf("q_star_0 + delta * u_star_0\n");
// print_array<uint8_t>((uint8_t*) &test_v0, lambdaBytes);

BOOST_TEST(memcmp(&test_v0, &bf_v_star_0, lambdaBytes) == 0);

printf("testing aes_prove\n");

aes_prove(a0_tilde.data(), a1_tilde.data(), a2_tilde.data(), w_bits.data(), u_bits.data(),
V.data(), in.data(), out.data(), chall_2.data(), params);

uint8_t* recomputed_a0_tilde =
aes_verify(d_bits.data(), Q.data(), chall_2.data(), delta.data(), a1_tilde.data(),
a2_tilde.data(), in.data(), out.data(), params);

// check that the proof verifies
printf("FAEST - %s\n", faest_get_param_name(param_id));
for (size_t i = 0; i < 24; i++) {
printf("%d-%d ", recomputed_a0_tilde[i], a0_tilde.data()[i]);
}
printf("\n");

BOOST_TEST(memcmp(recomputed_a0_tilde, a0_tilde.data(), lambdaBytes) == 0);
free(recomputed_a0_tilde);
free(w);
free(bf_u_bits);
free(w_tag);
free(q);
}
}



/*
BOOST_DATA_TEST_CASE(aes_prove_verify, all_parameters, param_id) {
BOOST_TEST_CONTEXT("Parameter set: " << faest_get_param_name(param_id)) {
const faest_paramset_t* params = faest_get_paramset(param_id);
Expand Down Expand Up @@ -317,5 +529,5 @@ BOOST_DATA_TEST_CASE(aes_prove_verify, all_parameters, param_id) {
free(q);
}
}

*/
BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 452b09a

Please sign in to comment.