Skip to content

Commit

Permalink
adding 192 key stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
shibammukherjee committed Feb 6, 2025
1 parent 62046ce commit 82f97b7
Showing 1 changed file with 326 additions and 8 deletions.
334 changes: 326 additions & 8 deletions faest_aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2353,7 +2353,92 @@ static void aes_128_keyexp_backward_verifier(bf128_t* y_key, const bf128_t* x_ke
}
}
}

// TODO: AES 192/256
static void aes_192_keyexp_backward_prover(uint8_t* y, bf192_t* y_tag, const uint8_t* x, const bf192_t* x_tag, const uint8_t* key, const bf192_t* key_tag, const faest_paramset_t* params) {

const unsigned int Ske = params->faest_param.Ske;
const unsigned int lambda = params->faest_param.lambda;

// ::2
uint8_t x_tilde[8];
bf192_t x_tilde_tag[8];
// ::3
unsigned int iwd = 0;
// ::4
bool rmvRcon = true;
// ::5-6
for (unsigned int j = 0; j < Ske; j++) {
// ::7-10
for (unsigned int bit_i = 0; bit_i < 8; bit_i++) {

x_tilde[bit_i] = x[j*8 + bit_i] ^ key[iwd + (j%4)*8 + bit_i]; // for the witness
x_tilde_tag[bit_i] = bf192_add(x_tag[j*8 + bit_i], key_tag[iwd + (j%4)*8 + bit_i]); // for the tags of each witness bit

if (rmvRcon == true && j % 4 == 0) {
// adding round constant to the witness
x_tilde[bit_i] = x_tilde[bit_i] ^ get_bit(Rcon[j / 4], bit_i);
}
}

// ::11
aes_192_inverse_affine_byte_prover(y + 8*j, y_tag + 8*j, x_tilde, x_tilde_tag); // working in bit per uint8

// ::12-16
if (j%4 == 3) {
if (lambda == 192) {
iwd += 192;
}
else {
iwd += 128;
if (lambda == 256) {
rmvRcon = !rmvRcon;
}
}
}
}
}
static void aes_192_keyexp_backward_verifier(bf192_t* y_key, const bf192_t* x_key, bf192_t* key_key, bf192_t delta, const faest_paramset_t* params) {

const unsigned int Ske = params->faest_param.Ske;
const unsigned int lambda = params->faest_param.lambda;

// ::2
bf192_t x_tilde_key[8];
// ::3
unsigned int iwd = 0;
// ::4
bool rmvRcon = true;
// ::5-6
for (unsigned int j = 0; j < Ske; j++) {
// ::7
for (unsigned int bit_i = 0; bit_i < 8; bit_i++) {
x_tilde_key[bit_i] = bf192_add(x_key[j*8 + bit_i], key_key[iwd + (j%4)*8 + bit_i]); // for the tags of each witness bit
// ::8-10
if (rmvRcon == true && j % 4 == 0) {
bf192_t rcon_key;
const uint8_t c = (Rcon[j / 4] >> bit_i) & 1;
constant_to_vole_192_verifier(&rcon_key, &c, delta, 1);
x_tilde_key[bit_i] = bf192_add(x_tilde_key[bit_i], rcon_key);
}
}
// ::11
aes_192_inverse_affine_byte_verifier(y_key + 8*j, x_tilde_key, delta);

// ::12-16
if (j%4 == 3) {
if (lambda == 192) {
iwd += 192;
}
else {
iwd += 128;
if (lambda == 256) {
rmvRcon = !rmvRcon;
}
}
}
}
}

// // KEY EXP FWD
static void aes_128_keyexp_forward_prover(uint8_t* y, bf128_t* y_tag, const uint8_t* w, const bf128_t* w_tag, const faest_paramset_t* params) {
Expand Down Expand Up @@ -2422,6 +2507,74 @@ static void aes_128_keyexp_forward_verifier(bf128_t* y_key, const bf128_t* w_key
}
}
}

static void aes_192_keyexp_forward_prover(uint8_t* y, bf192_t* y_tag, const uint8_t* w, const bf192_t* w_tag, const faest_paramset_t* params) {

const unsigned int lambda = params->faest_param.lambda;
const unsigned int Nk = lambda/32;
const unsigned int R = params->faest_param.R;

// ::1-2
for (unsigned int i = 0; i < lambda; i++) {
y[i] = w[i];
y_tag[i] = w_tag[i];
}
// ::3
unsigned int i_wd = lambda;
// ::4-10
for (unsigned int j = Nk; j < 4*(R + 1); j++) {
// ::5
if ((j % Nk == 0) || ((Nk > 6) && (j % Nk == 4))) {
// ::6
for (unsigned int word_idx = 0; word_idx < 32; word_idx++) {
y[32*j + word_idx] = w[i_wd + word_idx]; // storing bit by bit
y_tag[32*j + word_idx] = w_tag[i_wd + word_idx]; // storing tags
}
// ::7
i_wd += 32;
// ::8
}
else {
// ::9-10
for (unsigned int word_idx = 0; word_idx < 32; word_idx++) {
y[32*j + word_idx] = y[32*(j - Nk) + word_idx] ^ y[32*(j - 1) + word_idx];
y_tag[32*j + word_idx] = bf192_add(y_tag[32*(j - Nk) + word_idx], y_tag[32*(j - 1) + word_idx]);
}
}
}
}
static void aes_192_keyexp_forward_verifier(bf192_t* y_key, const bf192_t* w_key, const faest_paramset_t* params) {

unsigned int lambda = params->faest_param.lambda;
unsigned int Nk = lambda/32;
unsigned int R = params->faest_param.R;

// ::1-2
for (unsigned int i = 0; i < lambda; i++) {
y_key[i] = w_key[i];
}
// ::3
unsigned int i_wd = lambda;
// ::4-10
for (unsigned int j = Nk; j < 4*(R + 1); j++) {
// ::5
if ((j % Nk == 0) || ((Nk > 6) && (j % Nk == 4))) {
// ::6
for (unsigned int word_idx = 0; word_idx < 32; word_idx++) {
y_key[32*j + word_idx] = w_key[i_wd + word_idx];
}
// ::7
i_wd += 32; // 32 bits -> 4 words
// ::8
} else {
// ::9-10
for (unsigned int word_idx = 0; word_idx < 32; word_idx++) {
y_key[32*j + word_idx] = bf192_add(y_key[32*(j - Nk) + word_idx], y_key[32*(j - 1) + word_idx]);
}
}
}
}

// TODO: AES 192/256

// // KEY EXP CSTRNTS
Expand Down Expand Up @@ -2535,7 +2688,113 @@ static void aes_128_expkey_constraints_prover(bf128_t* z_deg0, bf128_t* z_deg1,
free(w_flat_tag);
}
static void aes_192_expkey_constraints_prover(bf192_t* z_deg0, bf192_t* z_deg1, uint8_t* k, bf192_t* k_tag, const uint8_t* w, const bf192_t* w_tag, const faest_paramset_t* params) {
// TODO

unsigned int Ske = params->faest_param.Ske;
unsigned int lambda = params->faest_param.lambda;
unsigned int Nk = lambda/32;
unsigned int r_prime;

bool do_rot_word = true;

// ::1
aes_192_keyexp_forward_prover(k, k_tag, w, w_tag, params);
// ::2
uint8_t* w_flat = (uint8_t*)malloc(8 * Ske * sizeof(uint8_t));
bf192_t* w_flat_tag = faest_aligned_alloc(BF192_ALIGN, 8 * Ske * sizeof(bf192_t));
aes_192_keyexp_backward_prover(w_flat, w_flat_tag, w + lambda, w_tag + lambda, k, k_tag, params);



// ::3-5
unsigned int iwd = 32*(Nk - 1);
// ::6 Used only on AES-256
// ::7
for (unsigned int j = 0; j < Ske / 4; j++) {
// ::8
bf192_t k_hat[4]; // expnaded key witness
bf192_t w_hat[4]; // inverse output
bf192_t k_hat_sq[4]; // expanded key witness sq
bf192_t w_hat_sq[4]; // inverse output sq

bf192_t k_hat_tag[4]; // expanded key witness tag
bf192_t w_hat_tag[4]; // inverse output tag
bf192_t k_hat_tag_sq[4]; // expanded key tag sq
bf192_t w_hat_tag_sq[4]; // inverser output tag sq

// ::9
for (unsigned int r = 0; r < 4; r++) {
// ::10
r_prime = r;
// ::11
if (do_rot_word) {
r_prime = (r + 3) % 4;
}
// ::12-15
k_hat[r_prime] = bf192_byte_combine_bits(&k[(iwd + 8 * r)]); // lifted key witness
k_hat_sq[r_prime] = bf192_byte_combine_bits_sq(&k[(iwd + 8 * r)]); // lifted key witness sq
{
bf192_t debug = bf192_mul(k_hat[r_prime], k_hat[r_prime]);
assert(memcmp(&debug, &k_hat_sq[r_prime], sizeof(debug)) == 0);
}

w_hat[r] = bf192_byte_combine_bits(&w_flat[(32 * j + 8 * r)]); // lifted output
w_hat_sq[r] = bf192_byte_combine_bits_sq(&w_flat[(32 * j + 8 * r)]); // lifted output sq
{
bf192_t debug = bf192_mul(w_hat[r], w_hat[r]);
assert(memcmp(&debug, &w_hat_sq[r], sizeof(debug)) == 0);
}

// done by both prover and verifier
k_hat_tag[r_prime] = bf192_byte_combine(k_tag + (iwd + 8 * r)); // lifted key tag
k_hat_tag_sq[r_prime] = bf192_byte_combine_sq(k_tag + (iwd + 8 * r)); // lifted key tag sq

w_hat_tag[r] = bf192_byte_combine(w_flat_tag + ((32 * j + 8 * r))); // lifted output tag
w_hat_tag_sq[r] = bf192_byte_combine_sq(w_flat_tag + (32 * j + 8 * r)); // lifted output tag sq
}
// ::16 used only for AES-256
if (lambda == 256) {
do_rot_word = !do_rot_word;
}
// ::17
for (unsigned int r = 0; r < 4; r++) {
{
bf192_t debug = bf192_mul(k_hat[r], w_hat[r]);
bf192_t one = bf192_one();
bf192_t zero = bf192_zero();
assert((memcmp(&debug, &one, sizeof(debug)) == 0)
|| ((memcmp(&k_hat[r], &zero, sizeof(debug)) == 0)
&& (memcmp(&w_hat[r], &zero, sizeof(debug)) == 0)));
}

// ::18-19
z_deg1[8*j + 2*r] = bf192_add(
bf192_add(
bf192_mul(k_hat_sq[r], w_hat_tag[r]),
bf192_mul(k_hat_tag_sq[r], w_hat[r])),
k_hat_tag[r]);

z_deg1[8*j + 2*r + 1] = bf192_add(
bf192_add(
bf192_mul(k_hat[r], w_hat_tag_sq[r]),
bf192_mul(k_hat_tag[r], w_hat_sq[r])),
w_hat_tag[r]);

z_deg0[8*j + 2*r] = bf192_mul(k_hat_tag_sq[r], w_hat_tag[r]);
z_deg0[8*j + 2*r + 1] = bf192_mul(k_hat_tag[r], w_hat_tag_sq[r]);

//z_deg1[8*j + 2*r + 1] = bf192_add(bf192_mul(k_hat[r], w_hat_sq[r]), w_hat[r]);
//z_deg0[8*j + 2*r] = bf192_add(bf192_mul(k_hat_tag_sq[r], w_hat_tag[r]), k_hat_tag[r]);
//z_deg0[8*j + 2*r + 1] = bf192_add(bf192_mul(k_hat_tag[r], w_hat_tag_sq[r]), k_hat_tag[r]);
}
if (lambda == 192) {
iwd += 192;
}
else {
iwd += 128;
}
}
free(w_flat);
free(w_flat_tag);
}
static void aes_256_expkey_constraints_prover(bf256_t* z_deg0, bf256_t* z_deg1, uint8_t* k, bf256_t* k_tag, const uint8_t* w, const bf256_t* w_tag, const faest_paramset_t* params) {
// TODO
Expand Down Expand Up @@ -2605,7 +2864,66 @@ static void aes_128_expkey_constraints_verifier(bf128_t* z_deg1, bf128_t* k_key,
free(w_flat_key);
}
static void aes_192_expkey_constraints_verifier(bf192_t* z_deg1, bf192_t* k_key, const bf192_t* w_key, bf192_t delta, const faest_paramset_t* params) {
// TODO
unsigned int Ske = params->faest_param.Ske;
unsigned int lambda = params->faest_param.lambda;
unsigned int Nk = lambda/32;
unsigned int r_prime;

bool do_rot_word = true;

// ::1
aes_192_keyexp_forward_verifier(k_key, w_key, params);
// ::2
bf192_t* w_flat_key = faest_aligned_alloc(BF192_ALIGN, 8 * Ske * sizeof(bf192_t));
aes_192_keyexp_backward_verifier(w_flat_key, w_key + lambda, k_key, delta, params);

// ::3-5
unsigned int iwd = 32*(Nk - 1); // as 1 unit8 has 8 bits
// ::6 Used only on AES-256
// ::7
for (unsigned int j = 0; j < Ske / 4; j++) {
// ::8
bf192_t k_hat_key[4]; // expanded key witness tag
bf192_t w_hat_key[4]; // inverse output tag
bf192_t k_hat_key_sq[4]; // expanded key tag sq
bf192_t w_hat_key_sq[4]; // inverser output tag sq

// ::9
for (unsigned int r = 0; r < 4; r++) {
// ::10
r_prime = r;
// ::11
if (do_rot_word) {
r_prime = (r + 3) % 4;
}
// ::12-15
k_hat_key[r_prime] = bf192_byte_combine(k_key + (iwd + 8 * r)); // lifted key tag
k_hat_key_sq[r_prime] = bf192_byte_combine_sq(k_key + (iwd + 8 * r)); // lifted key tag sq

w_hat_key[r] = bf192_byte_combine(w_flat_key + ((32 * j + 8 * r))); // lifted output tag
w_hat_key_sq[r] = bf192_byte_combine_sq(w_flat_key + (32 * j + 8 * r)); // lifted output tag sq
}
// ::16 used only for AES-256
if (lambda == 256) {
do_rot_word = !do_rot_word;
}
// ::17-20
for (unsigned int r = 0; r < 4; r++) {
z_deg1[8*j + 2*r] = bf192_add(
bf192_mul(k_hat_key_sq[r], w_hat_key[r]),
bf192_mul(delta, k_hat_key[r]));
z_deg1[8*j + 2*r + 1] = bf192_add(
bf192_mul(k_hat_key[r], w_hat_key_sq[r]),
bf192_mul(delta, w_hat_key[r]));
}
if (lambda == 192) {
iwd += 192;
}
else {
iwd += 128;
}
}
free(w_flat_key);
}
static void aes_256_expkey_constraints_verifier(bf256_t* z_deg1, bf256_t* k_key, const bf256_t* w_key, bf256_t delta, const faest_paramset_t* params) {
// TODO
Expand Down Expand Up @@ -3494,7 +3812,7 @@ static void aes_192_constraints_prover(bf192_t* z_deg0, bf192_t* z_deg1, bf192_t
// ::16
bf192_t* z_tilde_deg0_tag = (bf192_t*)malloc(2*Ske * sizeof(bf192_t));
bf192_t* z_tilde_deg1_val = (bf192_t*)malloc(2*Ske * sizeof(bf192_t));
aes_192_expkey_constraints_prover(z_tilde_deg0_tag, z_tilde_deg1_val, rkeys, rkeys_tag, w, w_tag, params); // w is bit per uint8
// aes_192_expkey_constraints_prover(z_tilde_deg0_tag, z_tilde_deg1_val, rkeys, rkeys_tag, w, w_tag, params); // w is bit per uint8

// ::17 raise degree
for (unsigned int i = 0; i < 2*Ske; i++) {
Expand Down Expand Up @@ -4069,7 +4387,7 @@ static void aes_192_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
unsigned int ell = params->faest_param.l;

// ::1-5
// V becomes the w_tag: ell + 2*lambda field elements
// V becomes the w_tag
bf192_t* w_tag = column_to_row_major_and_shrink_V_192(V, ell); // This is the tag for w

// ::6-7 embed VOLE masks
Expand Down Expand Up @@ -4126,7 +4444,7 @@ static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
unsigned int ell = params->faest_param.l;

// ::1-5
// V becomes the w_tag: ell + 2*lambda field elements
// V becomes the w_tag
bf256_t* w_tag = column_to_row_major_and_shrink_V_256(V, ell); // This is the tag for w

// ::6-7 embed VOLE masks
Expand Down Expand Up @@ -4169,9 +4487,9 @@ static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
free(z1_val);
free(z2_gamma);

zk_hash_256_finalize(a0_tilde, &a0_ctx, bf_v_star_0);
zk_hash_256_finalize(a1_tilde, &a1_ctx, bf256_add(bf_u_star_0, bf_v_star_1));
zk_hash_256_finalize(a2_tilde, &a2_ctx, bf_u_star_1);
zk_hash_256_finalize(a0_tilde, &a0_ctx, bf_u_star_0);
zk_hash_256_finalize(a1_tilde, &a1_ctx, bf256_add(bf_v_star_0, bf_u_star_1));
zk_hash_256_finalize(a2_tilde, &a2_ctx, bf_v_star_1);

free(bf_u_bits);
free(w_tag);
Expand Down

0 comments on commit 82f97b7

Please sign in to comment.