diff --git a/faest_aes.c b/faest_aes.c index c1f2f4b..7ec64f7 100644 --- a/faest_aes.c +++ b/faest_aes.c @@ -2353,7 +2353,92 @@ static void aes_128_keyexp_backward_verifier(bf128_t* y_key, const bf128_t* x_ke } } } + // TODO: AES 192/256 +static void aes_192_keyexp_backward_prover(uint8_t* y, bf192_t* y_tag, const uint8_t* x, const bf192_t* x_tag, const uint8_t* key, const bf192_t* key_tag, const faest_paramset_t* params) { + + const unsigned int Ske = params->faest_param.Ske; + const unsigned int lambda = params->faest_param.lambda; + + // ::2 + uint8_t x_tilde[8]; + bf192_t x_tilde_tag[8]; + // ::3 + unsigned int iwd = 0; + // ::4 + bool rmvRcon = true; + // ::5-6 + for (unsigned int j = 0; j < Ske; j++) { + // ::7-10 + for (unsigned int bit_i = 0; bit_i < 8; bit_i++) { + + x_tilde[bit_i] = x[j*8 + bit_i] ^ key[iwd + (j%4)*8 + bit_i]; // for the witness + x_tilde_tag[bit_i] = bf192_add(x_tag[j*8 + bit_i], key_tag[iwd + (j%4)*8 + bit_i]); // for the tags of each witness bit + + if (rmvRcon == true && j % 4 == 0) { + // adding round constant to the witness + x_tilde[bit_i] = x_tilde[bit_i] ^ get_bit(Rcon[j / 4], bit_i); + } + } + + // ::11 + aes_192_inverse_affine_byte_prover(y + 8*j, y_tag + 8*j, x_tilde, x_tilde_tag); // working in bit per uint8 + + // ::12-16 + if (j%4 == 3) { + if (lambda == 192) { + iwd += 192; + } + else { + iwd += 128; + if (lambda == 256) { + rmvRcon = !rmvRcon; + } + } + } + } +} +static void aes_192_keyexp_backward_verifier(bf192_t* y_key, const bf192_t* x_key, bf192_t* key_key, bf192_t delta, const faest_paramset_t* params) { + + const unsigned int Ske = params->faest_param.Ske; + const unsigned int lambda = params->faest_param.lambda; + + // ::2 + bf192_t x_tilde_key[8]; + // ::3 + unsigned int iwd = 0; + // ::4 + bool rmvRcon = true; + // ::5-6 + for (unsigned int j = 0; j < Ske; j++) { + // ::7 + for (unsigned int bit_i = 0; bit_i < 8; bit_i++) { + x_tilde_key[bit_i] = bf192_add(x_key[j*8 + bit_i], key_key[iwd + (j%4)*8 + bit_i]); // for the tags of each witness bit + // ::8-10 + if (rmvRcon == true && j % 4 == 0) { + bf192_t rcon_key; + const uint8_t c = (Rcon[j / 4] >> bit_i) & 1; + constant_to_vole_192_verifier(&rcon_key, &c, delta, 1); + x_tilde_key[bit_i] = bf192_add(x_tilde_key[bit_i], rcon_key); + } + } + // ::11 + aes_192_inverse_affine_byte_verifier(y_key + 8*j, x_tilde_key, delta); + + // ::12-16 + if (j%4 == 3) { + if (lambda == 192) { + iwd += 192; + } + else { + iwd += 128; + if (lambda == 256) { + rmvRcon = !rmvRcon; + } + } + } + } +} // // KEY EXP FWD static void aes_128_keyexp_forward_prover(uint8_t* y, bf128_t* y_tag, const uint8_t* w, const bf128_t* w_tag, const faest_paramset_t* params) { @@ -2422,6 +2507,74 @@ static void aes_128_keyexp_forward_verifier(bf128_t* y_key, const bf128_t* w_key } } } + +static void aes_192_keyexp_forward_prover(uint8_t* y, bf192_t* y_tag, const uint8_t* w, const bf192_t* w_tag, const faest_paramset_t* params) { + + const unsigned int lambda = params->faest_param.lambda; + const unsigned int Nk = lambda/32; + const unsigned int R = params->faest_param.R; + + // ::1-2 + for (unsigned int i = 0; i < lambda; i++) { + y[i] = w[i]; + y_tag[i] = w_tag[i]; + } + // ::3 + unsigned int i_wd = lambda; + // ::4-10 + for (unsigned int j = Nk; j < 4*(R + 1); j++) { + // ::5 + if ((j % Nk == 0) || ((Nk > 6) && (j % Nk == 4))) { + // ::6 + for (unsigned int word_idx = 0; word_idx < 32; word_idx++) { + y[32*j + word_idx] = w[i_wd + word_idx]; // storing bit by bit + y_tag[32*j + word_idx] = w_tag[i_wd + word_idx]; // storing tags + } + // ::7 + i_wd += 32; + // ::8 + } + else { + // ::9-10 + for (unsigned int word_idx = 0; word_idx < 32; word_idx++) { + y[32*j + word_idx] = y[32*(j - Nk) + word_idx] ^ y[32*(j - 1) + word_idx]; + y_tag[32*j + word_idx] = bf192_add(y_tag[32*(j - Nk) + word_idx], y_tag[32*(j - 1) + word_idx]); + } + } + } +} +static void aes_192_keyexp_forward_verifier(bf192_t* y_key, const bf192_t* w_key, const faest_paramset_t* params) { + + unsigned int lambda = params->faest_param.lambda; + unsigned int Nk = lambda/32; + unsigned int R = params->faest_param.R; + + // ::1-2 + for (unsigned int i = 0; i < lambda; i++) { + y_key[i] = w_key[i]; + } + // ::3 + unsigned int i_wd = lambda; + // ::4-10 + for (unsigned int j = Nk; j < 4*(R + 1); j++) { + // ::5 + if ((j % Nk == 0) || ((Nk > 6) && (j % Nk == 4))) { + // ::6 + for (unsigned int word_idx = 0; word_idx < 32; word_idx++) { + y_key[32*j + word_idx] = w_key[i_wd + word_idx]; + } + // ::7 + i_wd += 32; // 32 bits -> 4 words + // ::8 + } else { + // ::9-10 + for (unsigned int word_idx = 0; word_idx < 32; word_idx++) { + y_key[32*j + word_idx] = bf192_add(y_key[32*(j - Nk) + word_idx], y_key[32*(j - 1) + word_idx]); + } + } + } +} + // TODO: AES 192/256 // // KEY EXP CSTRNTS @@ -2535,7 +2688,113 @@ static void aes_128_expkey_constraints_prover(bf128_t* z_deg0, bf128_t* z_deg1, free(w_flat_tag); } static void aes_192_expkey_constraints_prover(bf192_t* z_deg0, bf192_t* z_deg1, uint8_t* k, bf192_t* k_tag, const uint8_t* w, const bf192_t* w_tag, const faest_paramset_t* params) { - // TODO + + unsigned int Ske = params->faest_param.Ske; + unsigned int lambda = params->faest_param.lambda; + unsigned int Nk = lambda/32; + unsigned int r_prime; + + bool do_rot_word = true; + + // ::1 + aes_192_keyexp_forward_prover(k, k_tag, w, w_tag, params); + // ::2 + uint8_t* w_flat = (uint8_t*)malloc(8 * Ske * sizeof(uint8_t)); + bf192_t* w_flat_tag = faest_aligned_alloc(BF192_ALIGN, 8 * Ske * sizeof(bf192_t)); + aes_192_keyexp_backward_prover(w_flat, w_flat_tag, w + lambda, w_tag + lambda, k, k_tag, params); + + + + // ::3-5 + unsigned int iwd = 32*(Nk - 1); + // ::6 Used only on AES-256 + // ::7 + for (unsigned int j = 0; j < Ske / 4; j++) { + // ::8 + bf192_t k_hat[4]; // expnaded key witness + bf192_t w_hat[4]; // inverse output + bf192_t k_hat_sq[4]; // expanded key witness sq + bf192_t w_hat_sq[4]; // inverse output sq + + bf192_t k_hat_tag[4]; // expanded key witness tag + bf192_t w_hat_tag[4]; // inverse output tag + bf192_t k_hat_tag_sq[4]; // expanded key tag sq + bf192_t w_hat_tag_sq[4]; // inverser output tag sq + + // ::9 + for (unsigned int r = 0; r < 4; r++) { + // ::10 + r_prime = r; + // ::11 + if (do_rot_word) { + r_prime = (r + 3) % 4; + } + // ::12-15 + k_hat[r_prime] = bf192_byte_combine_bits(&k[(iwd + 8 * r)]); // lifted key witness + k_hat_sq[r_prime] = bf192_byte_combine_bits_sq(&k[(iwd + 8 * r)]); // lifted key witness sq + { + bf192_t debug = bf192_mul(k_hat[r_prime], k_hat[r_prime]); + assert(memcmp(&debug, &k_hat_sq[r_prime], sizeof(debug)) == 0); + } + + w_hat[r] = bf192_byte_combine_bits(&w_flat[(32 * j + 8 * r)]); // lifted output + w_hat_sq[r] = bf192_byte_combine_bits_sq(&w_flat[(32 * j + 8 * r)]); // lifted output sq + { + bf192_t debug = bf192_mul(w_hat[r], w_hat[r]); + assert(memcmp(&debug, &w_hat_sq[r], sizeof(debug)) == 0); + } + + // done by both prover and verifier + k_hat_tag[r_prime] = bf192_byte_combine(k_tag + (iwd + 8 * r)); // lifted key tag + k_hat_tag_sq[r_prime] = bf192_byte_combine_sq(k_tag + (iwd + 8 * r)); // lifted key tag sq + + w_hat_tag[r] = bf192_byte_combine(w_flat_tag + ((32 * j + 8 * r))); // lifted output tag + w_hat_tag_sq[r] = bf192_byte_combine_sq(w_flat_tag + (32 * j + 8 * r)); // lifted output tag sq + } + // ::16 used only for AES-256 + if (lambda == 256) { + do_rot_word = !do_rot_word; + } + // ::17 + for (unsigned int r = 0; r < 4; r++) { + { + bf192_t debug = bf192_mul(k_hat[r], w_hat[r]); + bf192_t one = bf192_one(); + bf192_t zero = bf192_zero(); + assert((memcmp(&debug, &one, sizeof(debug)) == 0) + || ((memcmp(&k_hat[r], &zero, sizeof(debug)) == 0) + && (memcmp(&w_hat[r], &zero, sizeof(debug)) == 0))); + } + + // ::18-19 + z_deg1[8*j + 2*r] = bf192_add( + bf192_add( + bf192_mul(k_hat_sq[r], w_hat_tag[r]), + bf192_mul(k_hat_tag_sq[r], w_hat[r])), + k_hat_tag[r]); + + z_deg1[8*j + 2*r + 1] = bf192_add( + bf192_add( + bf192_mul(k_hat[r], w_hat_tag_sq[r]), + bf192_mul(k_hat_tag[r], w_hat_sq[r])), + w_hat_tag[r]); + + z_deg0[8*j + 2*r] = bf192_mul(k_hat_tag_sq[r], w_hat_tag[r]); + z_deg0[8*j + 2*r + 1] = bf192_mul(k_hat_tag[r], w_hat_tag_sq[r]); + + //z_deg1[8*j + 2*r + 1] = bf192_add(bf192_mul(k_hat[r], w_hat_sq[r]), w_hat[r]); + //z_deg0[8*j + 2*r] = bf192_add(bf192_mul(k_hat_tag_sq[r], w_hat_tag[r]), k_hat_tag[r]); + //z_deg0[8*j + 2*r + 1] = bf192_add(bf192_mul(k_hat_tag[r], w_hat_tag_sq[r]), k_hat_tag[r]); + } + if (lambda == 192) { + iwd += 192; + } + else { + iwd += 128; + } + } + free(w_flat); + free(w_flat_tag); } static void aes_256_expkey_constraints_prover(bf256_t* z_deg0, bf256_t* z_deg1, uint8_t* k, bf256_t* k_tag, const uint8_t* w, const bf256_t* w_tag, const faest_paramset_t* params) { // TODO @@ -2605,7 +2864,66 @@ static void aes_128_expkey_constraints_verifier(bf128_t* z_deg1, bf128_t* k_key, free(w_flat_key); } static void aes_192_expkey_constraints_verifier(bf192_t* z_deg1, bf192_t* k_key, const bf192_t* w_key, bf192_t delta, const faest_paramset_t* params) { - // TODO + unsigned int Ske = params->faest_param.Ske; + unsigned int lambda = params->faest_param.lambda; + unsigned int Nk = lambda/32; + unsigned int r_prime; + + bool do_rot_word = true; + + // ::1 + aes_192_keyexp_forward_verifier(k_key, w_key, params); + // ::2 + bf192_t* w_flat_key = faest_aligned_alloc(BF192_ALIGN, 8 * Ske * sizeof(bf192_t)); + aes_192_keyexp_backward_verifier(w_flat_key, w_key + lambda, k_key, delta, params); + + // ::3-5 + unsigned int iwd = 32*(Nk - 1); // as 1 unit8 has 8 bits + // ::6 Used only on AES-256 + // ::7 + for (unsigned int j = 0; j < Ske / 4; j++) { + // ::8 + bf192_t k_hat_key[4]; // expanded key witness tag + bf192_t w_hat_key[4]; // inverse output tag + bf192_t k_hat_key_sq[4]; // expanded key tag sq + bf192_t w_hat_key_sq[4]; // inverser output tag sq + + // ::9 + for (unsigned int r = 0; r < 4; r++) { + // ::10 + r_prime = r; + // ::11 + if (do_rot_word) { + r_prime = (r + 3) % 4; + } + // ::12-15 + k_hat_key[r_prime] = bf192_byte_combine(k_key + (iwd + 8 * r)); // lifted key tag + k_hat_key_sq[r_prime] = bf192_byte_combine_sq(k_key + (iwd + 8 * r)); // lifted key tag sq + + w_hat_key[r] = bf192_byte_combine(w_flat_key + ((32 * j + 8 * r))); // lifted output tag + w_hat_key_sq[r] = bf192_byte_combine_sq(w_flat_key + (32 * j + 8 * r)); // lifted output tag sq + } + // ::16 used only for AES-256 + if (lambda == 256) { + do_rot_word = !do_rot_word; + } + // ::17-20 + for (unsigned int r = 0; r < 4; r++) { + z_deg1[8*j + 2*r] = bf192_add( + bf192_mul(k_hat_key_sq[r], w_hat_key[r]), + bf192_mul(delta, k_hat_key[r])); + z_deg1[8*j + 2*r + 1] = bf192_add( + bf192_mul(k_hat_key[r], w_hat_key_sq[r]), + bf192_mul(delta, w_hat_key[r])); + } + if (lambda == 192) { + iwd += 192; + } + else { + iwd += 128; + } + } + free(w_flat_key); } static void aes_256_expkey_constraints_verifier(bf256_t* z_deg1, bf256_t* k_key, const bf256_t* w_key, bf256_t delta, const faest_paramset_t* params) { // TODO @@ -3494,7 +3812,7 @@ static void aes_192_constraints_prover(bf192_t* z_deg0, bf192_t* z_deg1, bf192_t // ::16 bf192_t* z_tilde_deg0_tag = (bf192_t*)malloc(2*Ske * sizeof(bf192_t)); bf192_t* z_tilde_deg1_val = (bf192_t*)malloc(2*Ske * sizeof(bf192_t)); - aes_192_expkey_constraints_prover(z_tilde_deg0_tag, z_tilde_deg1_val, rkeys, rkeys_tag, w, w_tag, params); // w is bit per uint8 + // aes_192_expkey_constraints_prover(z_tilde_deg0_tag, z_tilde_deg1_val, rkeys, rkeys_tag, w, w_tag, params); // w is bit per uint8 // ::17 raise degree for (unsigned int i = 0; i < 2*Ske; i++) { @@ -4069,7 +4387,7 @@ static void aes_192_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til unsigned int ell = params->faest_param.l; // ::1-5 - // V becomes the w_tag: ell + 2*lambda field elements + // V becomes the w_tag bf192_t* w_tag = column_to_row_major_and_shrink_V_192(V, ell); // This is the tag for w // ::6-7 embed VOLE masks @@ -4126,7 +4444,7 @@ static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til unsigned int ell = params->faest_param.l; // ::1-5 - // V becomes the w_tag: ell + 2*lambda field elements + // V becomes the w_tag bf256_t* w_tag = column_to_row_major_and_shrink_V_256(V, ell); // This is the tag for w // ::6-7 embed VOLE masks @@ -4169,9 +4487,9 @@ static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til free(z1_val); free(z2_gamma); - zk_hash_256_finalize(a0_tilde, &a0_ctx, bf_v_star_0); - zk_hash_256_finalize(a1_tilde, &a1_ctx, bf256_add(bf_u_star_0, bf_v_star_1)); - zk_hash_256_finalize(a2_tilde, &a2_ctx, bf_u_star_1); + zk_hash_256_finalize(a0_tilde, &a0_ctx, bf_u_star_0); + zk_hash_256_finalize(a1_tilde, &a1_ctx, bf256_add(bf_v_star_0, bf_u_star_1)); + zk_hash_256_finalize(a2_tilde, &a2_ctx, bf_v_star_1); free(bf_u_bits); free(w_tag);