Skip to content

Commit

Permalink
fix: pass thread handle instead of witnesses directly
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Jan 23, 2025
1 parent f533c8d commit b6d39a6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
7 changes: 6 additions & 1 deletion circom-prover/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use num::BigUint;
use std::thread::JoinHandle;

pub mod arkworks;
pub mod serialization;
Expand All @@ -14,7 +15,11 @@ pub enum ProofLib {
RapidSnark,
}

pub fn prove(lib: ProofLib, zkey_path: String, witnesses: Vec<BigUint>) -> Result<CircomProof> {
pub fn prove(
lib: ProofLib,
zkey_path: String,
witnesses: JoinHandle<Vec<BigUint>>,
) -> Result<CircomProof> {
match lib {
ProofLib::Arkworks => arkworks::generate_circom_proof(zkey_path, witnesses),
ProofLib::RapidSnark => panic!("Not supported yet."),
Expand Down
18 changes: 15 additions & 3 deletions circom-prover/src/prover/arkworks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,37 @@ use serialization::{SerializableInputs, SerializableProof};

use anyhow::{bail, Result};
use num_bigint::BigUint;
use std::fs::File;
use std::{fs::File, thread::JoinHandle};

use super::{serialization, CircomProof};

pub fn generate_circom_proof(zkey_path: String, witnesses: Vec<BigUint>) -> Result<CircomProof> {
pub fn generate_circom_proof(
zkey_path: String,
witness_thread: JoinHandle<Vec<BigUint>>,
) -> Result<CircomProof> {
// here we make a loader just to get the groth16 header
// this header tells us what curve the zkey was compiled for
// this loader will only load the first few bytes
let mut header_reader = ZkeyHeaderReader::new(&zkey_path);
header_reader.read();
let file = File::open(&zkey_path)?;
let mut reader = std::io::BufReader::new(file);

// check the prime in the header
// println!("{} {} {}", header.q, header.n8q, ark_bls12_381::Fq::MODULUS);
if header_reader.r == BigUint::from(ark_bn254::Fr::MODULUS) {
let (proving_key, matrices) = read_zkey::<_, Bn254>(&mut reader)?;
// Get the result witness from the background thread
let witnesses = witness_thread
.join()
.map_err(|_e| anyhow::anyhow!("witness thread panicked"))
.unwrap();
prove(proving_key, matrices, witnesses)
} else if header_reader.r == BigUint::from(ark_bls12_381::Fr::MODULUS) {
let (proving_key, matrices) = read_zkey::<_, Bls12_381>(&mut reader)?;
let witnesses = witness_thread
.join()
.map_err(|_e| anyhow::anyhow!("witness thread panicked"))
.unwrap();
prove(proving_key, matrices, witnesses)
} else {
panic!("unknown curve detected in zkey");
Expand Down
7 changes: 2 additions & 5 deletions circom-prover/src/witness.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use num::{BigInt, BigUint};
use std::{collections::HashMap, str::FromStr};
use std::{collections::HashMap, str::FromStr, thread::JoinHandle};

/// Witness function signature for rust_witness (inputs) -> witness
type RustWitnessWtnsFn = fn(HashMap<String, Vec<BigInt>>) -> Vec<BigInt>;
Expand Down Expand Up @@ -34,7 +34,7 @@ pub fn generate_witness(
witness_fn: WitnessFn,
inputs: HashMap<String, Vec<String>>,
dat_path: String,
) -> Vec<BigUint> {
) -> JoinHandle<Vec<BigUint>> {
std::thread::spawn(move || {
let bigint_inputs = inputs
.into_iter()
Expand All @@ -57,7 +57,4 @@ pub fn generate_witness(
.map(|w| w.to_biguint().unwrap())
.collect::<Vec<_>>()
})
.join()
.map_err(|_e| anyhow::anyhow!("witness thread panicked"))
.unwrap()
}
4 changes: 2 additions & 2 deletions mopro-ffi/src/circom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ pub fn generate_circom_proof_wtns(
// .dat file is supposed to be located next to the zkey file and have the same filename
let mut dat_file_path = zkey_path.clone();
dat_file_path = dat_file_path.replace(".zkey", ".dat");
let witnesses = generate_witness(witness_fn, inputs, dat_file_path);
let witness_thread = generate_witness(witness_fn, inputs, dat_file_path);

let ret = prove(proof_lib, zkey_path, witnesses).unwrap();
let ret = prove(proof_lib, zkey_path, witness_thread).unwrap();
Ok(GenerateProofResult {
proof: ret.proof,
inputs: ret.pub_inputs,
Expand Down

0 comments on commit b6d39a6

Please sign in to comment.