Skip to content

Commit

Permalink
Align prover implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastinas committed Feb 6, 2025
1 parent cc928aa commit a530338
Showing 1 changed file with 44 additions and 22 deletions.
66 changes: 44 additions & 22 deletions faest_aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3745,30 +3745,40 @@ static void aes_128_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
free(bf_u_bits);
free(w_tag);
}
static void aes_192_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_tilde, const uint8_t* w, const uint8_t* u,

static void aes_192_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_tilde, const uint8_t* w_bits, const uint8_t* u_bits,
uint8_t** V, const uint8_t* owf_in, const uint8_t* owf_out, const uint8_t* chall_2, const faest_paramset_t* params, bool isEM) {

unsigned int lambda = params->faest_param.lambda;
unsigned int c = params->faest_param.C;
unsigned int ell = params->faest_param.l;

// ::1-5
// V becomes the w_tag
// V becomes the w_tag: ell + 2*lambda field elements
bf192_t* w_tag = column_to_row_major_and_shrink_V_192(V, ell); // This is the tag for w

// ::6-7 embed VOLE masks
bf192_t bf_u_star_0 = bf192_load_bits(u);
bf192_t bf_u_star_1 = bf192_load_bits(u + lambda);
bf192_t* bf_u_bits = (bf192_t*) malloc(2*lambda * sizeof(bf192_t));
for (unsigned int i = 0; i < 2*lambda; i++) {
bf_u_bits[i] = bf192_from_bit(u_bits[i]);
}

bf192_t bf_u_star_0 = bf192_sum_poly(bf_u_bits); // U IS 1 Byte per uint8 right??
bf192_t bf_u_star_1 = bf192_sum_poly(bf_u_bits + lambda);
// ::8-9
bf192_t bf_v_star_0 = bf192_sum_poly(w_tag);
bf192_t bf_v_star_1 = bf192_sum_poly(w_tag + lambda);
bf192_t bf_v_star_0 = bf192_sum_poly(w_tag + ell);
bf192_t bf_v_star_1 = bf192_sum_poly(w_tag + ell + lambda);

// ::10-12
bf192_t* z0_tag = (bf192_t*)malloc(c * sizeof(bf192_t)); // this contains the bf tag
bf192_t* z1_val = (bf192_t*)malloc(c * sizeof(bf192_t)); // this contains the bf val
bf192_t* z2_gamma = (bf192_t*)malloc(c * sizeof(bf192_t)); // this contains the bf gamma
aes_192_constraints_prover(z0_tag, z1_val, z2_gamma, w, w_tag, owf_in, owf_out, params, isEM);

memset(z0_tag, 0, c * sizeof(bf192_t));
memset(z1_val, 0, c * sizeof(bf192_t));
memset(z2_gamma, 0, c * sizeof(bf192_t));

aes_192_constraints_prover(z0_tag, z1_val, z2_gamma, w_bits, w_tag, owf_in, owf_out, params, isEM);

// Step: 13-18
zk_hash_192_ctx a0_ctx;
zk_hash_192_ctx a1_ctx;
Expand All @@ -3787,36 +3797,47 @@ static void aes_192_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
free(z1_val);
free(z2_gamma);

zk_hash_192_finalize(a0_tilde, &a0_ctx, bf_u_star_0);
zk_hash_192_finalize(a1_tilde, &a1_ctx, bf192_add(bf_v_star_0, bf_u_star_1));
zk_hash_192_finalize(a2_tilde, &a2_ctx, bf_v_star_1);
zk_hash_192_finalize(a0_tilde, &a0_ctx, bf_v_star_0);
zk_hash_192_finalize(a1_tilde, &a1_ctx, bf192_add(bf_u_star_0, bf_v_star_1));
zk_hash_192_finalize(a2_tilde, &a2_ctx, bf_u_star_1);

free(bf_u_bits);
free(w_tag);
}
static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_tilde, const uint8_t* w, const uint8_t* u,

static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_tilde, const uint8_t* w_bits, const uint8_t* u_bits,
uint8_t** V, const uint8_t* owf_in, const uint8_t* owf_out, const uint8_t* chall_2, const faest_paramset_t* params, bool isEM) {

unsigned int lambda = params->faest_param.lambda;
unsigned int c = params->faest_param.C;
unsigned int ell = params->faest_param.l;

// ::1-5
// V becomes the w_tag
// V becomes the w_tag: ell + 2*lambda field elements
bf256_t* w_tag = column_to_row_major_and_shrink_V_256(V, ell); // This is the tag for w

// ::6-7 embed VOLE masks
bf256_t bf_u_star_0 = bf256_load_bits(u);
bf256_t bf_u_star_1 = bf256_load_bits(u + lambda);
bf256_t* bf_u_bits = (bf256_t*) malloc(2*lambda * sizeof(bf256_t));
for (unsigned int i = 0; i < 2*lambda; i++) {
bf_u_bits[i] = bf256_from_bit(u_bits[i]);
}

bf256_t bf_u_star_0 = bf256_sum_poly(bf_u_bits); // U IS 1 Byte per uint8 right??
bf256_t bf_u_star_1 = bf256_sum_poly(bf_u_bits + lambda);
// ::8-9
bf256_t bf_v_star_0 = bf256_sum_poly(w_tag);
bf256_t bf_v_star_1 = bf256_sum_poly(w_tag + lambda);
bf256_t bf_v_star_0 = bf256_sum_poly(w_tag + ell);
bf256_t bf_v_star_1 = bf256_sum_poly(w_tag + ell + lambda);

// ::10-12
bf256_t* z0_tag = (bf256_t*)malloc(c * sizeof(bf256_t)); // this contains the bf tag
bf256_t* z1_val = (bf256_t*)malloc(c * sizeof(bf256_t)); // this contains the bf val
bf256_t* z2_gamma = (bf256_t*)malloc(c * sizeof(bf256_t)); // this contains the bf gamma
aes_256_constraints_prover(z0_tag, z1_val, z2_gamma, w, w_tag, owf_in, owf_out, params, isEM);

memset(z0_tag, 0, c * sizeof(bf256_t));
memset(z1_val, 0, c * sizeof(bf256_t));
memset(z2_gamma, 0, c * sizeof(bf256_t));

aes_256_constraints_prover(z0_tag, z1_val, z2_gamma, w_bits, w_tag, owf_in, owf_out, params, isEM);

// Step: 13-18
zk_hash_256_ctx a0_ctx;
zk_hash_256_ctx a1_ctx;
Expand All @@ -3835,10 +3856,11 @@ static void aes_256_prover(uint8_t* a0_tilde, uint8_t* a1_tilde, uint8_t* a2_til
free(z1_val);
free(z2_gamma);

zk_hash_256_finalize(a0_tilde, &a0_ctx, bf_u_star_0);
zk_hash_256_finalize(a1_tilde, &a1_ctx, bf256_add(bf_v_star_0, bf_u_star_1));
zk_hash_256_finalize(a2_tilde, &a2_ctx, bf_v_star_1);
zk_hash_256_finalize(a0_tilde, &a0_ctx, bf_v_star_0);
zk_hash_256_finalize(a1_tilde, &a1_ctx, bf256_add(bf_u_star_0, bf_v_star_1));
zk_hash_256_finalize(a2_tilde, &a2_ctx, bf_u_star_1);

free(bf_u_bits);
free(w_tag);
}

Expand Down

0 comments on commit a530338

Please sign in to comment.