Skip to content

Commit

Permalink
Expand AES prove test to check u_star, v_star etc
Browse files Browse the repository at this point in the history
  • Loading branch information
pascholl committed Feb 5, 2025
1 parent 099ec8e commit dac6dcf
Showing 1 changed file with 63 additions and 7 deletions.
70 changes: 63 additions & 7 deletions tests/aes_prove.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
*/

#include "faest_aes.h"
#include "fields.h"
#include "utils.hpp"
#include "utils.h"
#include "instances.hpp"
#include "randomness.h"

Expand Down Expand Up @@ -105,6 +108,22 @@ namespace rijndael_em_256_tv {
};
} // namespace rijndael_em_256_tv

static bf128_t* column_to_row_major_and_shrink_V_128(uint8_t** v, unsigned int ell) {
// V is \hat \ell times \lambda matrix over F_2
// v has \hat \ell rows, \lambda columns, storing in column-major order, new_v has \ell + 2 \lambda
// rows and \lambda columns storing in row-major order
bf128_t* new_v = (bf128_t*) malloc((ell + FAEST_128F_LAMBDA*2) * sizeof(bf128_t));
// faest_aligned_alloc(BF128_ALIGN, (ell + FAEST_128F_LAMBDA*2) * sizeof(bf128_t));
for (unsigned int row = 0; row != ell + FAEST_128F_LAMBDA*2; ++row) {
uint8_t new_row[BF128_NUM_BYTES] = {0};
for (unsigned int column = 0; column != FAEST_128F_LAMBDA; ++column) {
ptr_set_bit(new_row, ptr_get_bit(v[column], row), column);
}
new_v[row] = bf128_load(new_row);
}
return new_v;
}

BOOST_AUTO_TEST_SUITE(test_aes_prove)

BOOST_DATA_TEST_CASE(aes_prove_verify, all_parameters, param_id) {
Expand Down Expand Up @@ -203,8 +222,10 @@ BOOST_DATA_TEST_CASE(aes_prove_verify, all_parameters, param_id) {

// prepare vole correlation
std::vector<uint8_t> delta(lambda / 8, 0);
delta[0] = 42;
std::vector<uint8_t> u(ell_hat_bytes, 0x13);
for (size_t i = 0; i < lambda / 8; ++i) {
delta[i] = (uint8_t) i;
}
std::vector<uint8_t> u(ell_hat_bytes, 0x13); // 1 bit per byte
std::vector<uint8_t> vs(ell_hat_bytes * lambda, 0x37);
std::vector<uint8_t> qs = vs;
std::vector<uint8_t*> V(lambda, NULL);
Expand All @@ -229,30 +250,65 @@ BOOST_DATA_TEST_CASE(aes_prove_verify, all_parameters, param_id) {
for (size_t i = 0; i < ell_bytes; ++i) {
d[i] = u[i] ^ w[i];
}
std::vector<uint8_t> d_bits(ell, 0x00); // 1 bit in per uint8_t
for (unsigned int bit_i = 0; bit_i < ell; bit_i++) {
d_bits[bit_i] = (d[bit_i / 8] >> bit_i % 8) & 1;
}
// std::vector<uint8_t> d_bits(ell, 0x00); // 1 bit in per uint8_t
// for (unsigned int bit_i = 0; bit_i < ell; bit_i++) {
// d_bits[bit_i] = (d[bit_i / 8] >> bit_i % 8) & 1;
// }

std::vector<uint8_t> chall_2((3 * lambda + 64) / 8, 47);

std::vector<uint8_t> a0_tilde(lambda / 8, 0);
std::vector<uint8_t> a1_tilde(lambda / 8, 0);
std::vector<uint8_t> a2_tilde(lambda / 8, 0);

bf128_t bf_delta = bf128_load(delta.data());
bf128_t* q = column_to_row_major_and_shrink_V_128(Q.data(), ell);

bf128_t* w_tag = column_to_row_major_and_shrink_V_128(V.data(), ell);
bf128_t* bf_u_bits = (bf128_t*) malloc(2*lambda * sizeof(bf128_t));

for (unsigned int bit_i = 0; bit_i < 2*lambda; bit_i++) {
u_bits[bit_i] = (u[(ell + bit_i)/8] >> (ell + bit_i)%8) & 1;
bf_u_bits[bit_i] = bf128_from_bit(u_bits[bit_i]);
}
// verifier Delta, q_star
// printf("delta\n");
// print_array<uint8_t>((uint8_t*) &bf_delta, lambdaBytes);
bf128_t q_star_0 = bf128_sum_poly(q + ell);
// printf("q_star_0\n");
// print_array<uint8_t>((uint8_t*) &q_star_0, lambdaBytes);

// prover u,v (want: q = u*Delta + v)
bf128_t bf_u_star_0 = bf128_sum_poly(bf_u_bits);
bf128_t bf_v_star_0 = bf128_sum_poly(w_tag + ell);

//printf("u_star_0\n");
// print_array<uint8_t>((uint8_t*) &bf_u_star_0, lambdaBytes);
//printf("v_star_0\n");
// print_array<uint8_t>((uint8_t*) &bf_v_star_0, lambdaBytes);

bf128_t test_v0 = bf128_add(q_star_0, bf128_mul(bf_delta, bf_u_star_0));
//printf("q_star_0 + delta * u_star_0\n");
// print_array<uint8_t>((uint8_t*) &test_v0, lambdaBytes);

BOOST_TEST(memcmp(&test_v0, &bf_v_star_0, lambdaBytes) == 0);

printf("testing aes_prove\n");

aes_prove(a0_tilde.data(), a1_tilde.data(), a2_tilde.data(), w_bits.data(), u_bits.data(),
V.data(), in.data(), out.data(), chall_2.data(), params);

uint8_t* recomputed_a0_tilde =
aes_verify(d_bits.data(), Q.data(), chall_2.data(), delta.data(), a1_tilde.data(),
aes_verify(d.data(), Q.data(), chall_2.data(), delta.data(), a1_tilde.data(),
a2_tilde.data(), in.data(), out.data(), params);

// check that the proof verifies
BOOST_TEST(memcmp(recomputed_a0_tilde, a0_tilde.data(), lambdaBytes) == 0);
free(recomputed_a0_tilde);
free(w);
free(bf_u_bits);
free(w_tag);
free(q);
}
}

Expand Down

0 comments on commit dac6dcf

Please sign in to comment.