From 1e39d9d56c6d0c81cbec1d95c2e22aa504aa3a40 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 26 Sep 2024 18:12:14 +0800 Subject: [PATCH 1/2] Remove the input hash in the groth16 inputs (decrease 1 Uint256). --- gnark-utils/lib/circuit.go | 53 +++++++++---------- gnark-utils/lib/deserialize.go | 1 + gnark-utils/lib/lib.go | 11 ++-- groth16-framework/src/proof.rs | 2 +- groth16-framework/src/prover/groth16.rs | 4 +- .../test_data/Groth16VerifierExtensions.sol | 14 ++--- 6 files changed, 44 insertions(+), 41 deletions(-) diff --git a/gnark-utils/lib/circuit.go b/gnark-utils/lib/circuit.go index f15744e78..cb98e4960 100644 --- a/gnark-utils/lib/circuit.go +++ b/gnark-utils/lib/circuit.go @@ -8,10 +8,10 @@ import ( "math/big" "github.com/consensys/gnark/frontend" + gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks" "github.com/succinctlabs/gnark-plonky2-verifier/types" "github.com/succinctlabs/gnark-plonky2-verifier/variables" "github.com/succinctlabs/gnark-plonky2-verifier/verifier" - gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks" ) type VerifierCircuit struct { @@ -19,7 +19,7 @@ type VerifierCircuit struct { VerifierDigest frontend.Variable `gnark:"verifierDigest,public"` // The input hash is the hash of all onchain inputs into the function. - InputHash frontend.Variable `gnark:"inputHash,public"` + InputHash frontend.Variable `gnark:"inputHash"` // The output hash is the hash of all outputs from the function. OutputHash frontend.Variable `gnark:"outputHash,public"` @@ -72,35 +72,35 @@ func (c *VerifierCircuit) Define(api frontend.API) error { } // Build a `ProofWithPublicInputs` variable to be employed in `VerifierCircuit` to verify a proof for a circuit -// with `commonCircuitData` +// with `commonCircuitData` func NewProofWithPublicInputs(commonCircuitData *types.CommonCircuitData) variables.ProofWithPublicInputs { proof := newProof(commonCircuitData) public_inputs := make([]gl.Variable, commonCircuitData.NumPublicInputs) return variables.ProofWithPublicInputs{ - Proof: proof, + Proof: proof, PublicInputs: public_inputs, } } -const SALT_SIZE = 4; // same as SALT_SIZE constant in Plonky2 +const SALT_SIZE = 4 // same as SALT_SIZE constant in Plonky2 func newOpeningSet(commonCircuitData *types.CommonCircuitData) variables.OpeningSet { - constants := make([]gl.QuadraticExtensionVariable, commonCircuitData.NumConstants) - plonk_sigmas := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumRoutedWires) - wires := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumWires) - plonk_zs := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges) - plonk_zs_next := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges) - partial_products := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges * commonCircuitData.NumPartialProducts) - quotient_polys := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges * commonCircuitData.QuotientDegreeFactor) - return variables.OpeningSet{ - Constants: constants, - PlonkSigmas: plonk_sigmas, - Wires: wires, - PlonkZs: plonk_zs, - PlonkZsNext: plonk_zs_next, - PartialProducts: partial_products, - QuotientPolys: quotient_polys, - } + constants := make([]gl.QuadraticExtensionVariable, commonCircuitData.NumConstants) + plonk_sigmas := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumRoutedWires) + wires := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumWires) + plonk_zs := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges) + plonk_zs_next := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges) + partial_products := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges*commonCircuitData.NumPartialProducts) + quotient_polys := make([]gl.QuadraticExtensionVariable, commonCircuitData.Config.NumChallenges*commonCircuitData.QuotientDegreeFactor) + return variables.OpeningSet{ + Constants: constants, + PlonkSigmas: plonk_sigmas, + Wires: wires, + PlonkZs: plonk_zs, + PlonkZsNext: plonk_zs_next, + PartialProducts: partial_products, + QuotientPolys: quotient_polys, + } } func newFriQueryRound(commonCircuitData *types.CommonCircuitData) variables.FriQueryRound { @@ -116,8 +116,8 @@ func newFriQueryRound(commonCircuitData *types.CommonCircuitData) variables.FriQ num_leaves_per_oracle := [4]uint64{ commonCircuitData.NumConstants + commonCircuitData.Config.NumRoutedWires, commonCircuitData.Config.NumWires + salt_size(), - commonCircuitData.Config.NumChallenges*(1 + commonCircuitData.NumPartialProducts) + salt_size(), - commonCircuitData.QuotientDegreeFactor* commonCircuitData.Config.NumChallenges + salt_size(), + commonCircuitData.Config.NumChallenges*(1+commonCircuitData.NumPartialProducts) + salt_size(), + commonCircuitData.QuotientDegreeFactor*commonCircuitData.Config.NumChallenges + salt_size(), } merkle_proof_len := params.LdeBits() - int(cap_height) if merkle_proof_len < 0 { @@ -127,9 +127,9 @@ func newFriQueryRound(commonCircuitData *types.CommonCircuitData) variables.FriQ eval_proofs := make([]variables.FriEvalProof, len(num_leaves_per_oracle)) for j := 0; j < len(eval_proofs); j++ { eval_proofs[j] = variables.NewFriEvalProof( - make([]gl.Variable, num_leaves_per_oracle[j]), - variables.NewFriMerkleProof(uint64(merkle_proof_len)), - ) + make([]gl.Variable, num_leaves_per_oracle[j]), + variables.NewFriMerkleProof(uint64(merkle_proof_len)), + ) } initial_trees := variables.NewFriInitialTreeProof(eval_proofs) // build `FriQueryStep` @@ -186,4 +186,3 @@ func newProof(commonCircuitData *types.CommonCircuitData) variables.Proof { OpeningProof: fri_proof, } } - diff --git a/gnark-utils/lib/deserialize.go b/gnark-utils/lib/deserialize.go index 67b197e2f..46c798a4e 100644 --- a/gnark-utils/lib/deserialize.go +++ b/gnark-utils/lib/deserialize.go @@ -3,6 +3,7 @@ package main import "C" + import ( "encoding/json" diff --git a/gnark-utils/lib/lib.go b/gnark-utils/lib/lib.go index 43d766c19..fca341fc6 100644 --- a/gnark-utils/lib/lib.go +++ b/gnark-utils/lib/lib.go @@ -12,6 +12,7 @@ package main #include */ import "C" + import ( "bufio" "bytes" @@ -39,8 +40,10 @@ import ( // Global variables for the proving process are only necessary to initialize // once by InitProver function. -var R1CS constraint.ConstraintSystem -var PK groth16.ProvingKey +var ( + R1CS constraint.ConstraintSystem + PK groth16.ProvingKey +) // Global variables for the verifying process are only necessary to initialize // once by InitVerifier function. @@ -468,9 +471,9 @@ func ProveCircuit( // We cut off the first 12 bytes because they encode length information. publicWitnessBytes := rawPublicWitnessBytes[12:] - inputs := make([]string, 3) + inputs := make([]string, 2) // Print out the public witness bytes. - for i := 0; i < 3; i++ { + for i := 0; i < 2; i++ { inputs[i] = "0x" + hex.EncodeToString(publicWitnessBytes[i*fpSize:(i+1)*fpSize]) } diff --git a/groth16-framework/src/proof.rs b/groth16-framework/src/proof.rs index 49054e4dd..c50232495 100644 --- a/groth16-framework/src/proof.rs +++ b/groth16-framework/src/proof.rs @@ -8,7 +8,7 @@ pub struct Groth16Proof { /// The proofs item is an array of [U256; 8], which should be passed to the /// `verifyProof` function of the Solidity verifier contract. pub proofs: Vec, - /// The inputs item is an array of [U256; 3], which should be passed to the + /// The inputs item is an array of [U256; 2], which should be passed to the /// `verifyProof` function of the Solidity verifier contract. pub inputs: Vec, /// The original raw proof data is used to be verified off-chain. diff --git a/groth16-framework/src/prover/groth16.rs b/groth16-framework/src/prover/groth16.rs index 6aec74a55..f89790b84 100644 --- a/groth16-framework/src/prover/groth16.rs +++ b/groth16-framework/src/prover/groth16.rs @@ -62,7 +62,7 @@ impl Groth16Prover { /// `groth16_proof.proofs + groth16_proof.inputs + plonky2_proof.public_inputs`. /// In the combined bytes, each part has number as: /// - groth16_proof.proofs: 8 * U256 = 256 bytes - /// - groth16_proof.inputs: 3 * U256 = 96 bytes + /// - groth16_proof.inputs: 2 * U256 = 64 bytes /// - plonky2_proof.public_inputs: the encoded public inputs exported by user pub fn prove(&self, plonky2_proof: &[u8]) -> Result> { // Deserialize the plonky2 proof. @@ -112,7 +112,7 @@ fn load_circuit_data(asset_dir: &str) -> Result> { /// `groth16_proof.proofs + groth16_proof.inputs + plonky2_proof.public_inputs`. /// In the combined bytes, each part has number as: /// - groth16_proof.proofs: 8 * U256 = 256 bytes -/// - groth16_proof.inputs: 3 * U256 = 96 bytes +/// - groth16_proof.inputs: 2 * U256 = 64 bytes /// - plonky2_proof.public_inputs: the encoded public inputs exported by user, /// all fields must be in range of Uint32, it's restricted by `sha256` (in plonky2x). pub fn combine_proofs( diff --git a/groth16-framework/test_data/Groth16VerifierExtensions.sol b/groth16-framework/test_data/Groth16VerifierExtensions.sol index 7bc223d24..4cb78e195 100644 --- a/groth16-framework/test_data/Groth16VerifierExtensions.sol +++ b/groth16-framework/test_data/Groth16VerifierExtensions.sol @@ -94,10 +94,10 @@ contract Query is Verifier { } // The processQuery function does the followings: - // 1. Parse the Groth16 proofs (8 uint256) and inputs (3 uint256) from the `data` + // 1. Parse the Groth16 proofs (8 uint256) and inputs (2 uint256) from the `data` // argument, and call `verifyProof` function for Groth16 verification. // 2. Calculate sha256 on the public inputs, and set the top 3 bits of this hash to 0. - // Then ensure this hash value equals to the last Groth16 input (groth16_inputs[2]). + // Then ensure this hash value equals to the last Groth16 input (groth16_inputs[1]). // 3. Parse the items from public inputs, and check as expected for query. // 4. Parse and return the query output from public inputs. function processQuery( @@ -105,7 +105,7 @@ contract Query is Verifier { QueryInput memory query ) public view returns (QueryOutput memory) { // 1. Groth16 verification - uint256[3] memory groth16Inputs = verifyGroth16Proof(data); + uint256[2] memory groth16Inputs = verifyGroth16Proof(data); // 2. Ensure the sha256 of public inputs equals to the last Groth16 input. verifyPublicInputs(data, groth16Inputs); @@ -122,12 +122,12 @@ contract Query is Verifier { bytes32[] calldata data ) internal view returns (uint256[3] memory) { uint256[8] memory proofs; - uint256[3] memory inputs; + uint256[2] memory inputs; for (uint32 i = 0; i < 8; ++i) { proofs[i] = uint256(data[i]); } - for (uint32 i = 0; i < 3; ++i) { + for (uint32 i = 0; i < 2; ++i) { inputs[i] = uint256(data[i + 8]); } @@ -146,7 +146,7 @@ contract Query is Verifier { // Compute sha256 on the public inputs, and ensure it equals to the last Groth16 input. function verifyPublicInputs( bytes32[] calldata data, - uint256[3] memory groth16Inputs + uint256[2] memory groth16Inputs ) internal pure { // Parse the public inputs from calldata. bytes memory pi = parsePublicInputs(data); @@ -158,7 +158,7 @@ contract Query is Verifier { // Require the sha256 equals to the last Groth16 input. require( - hash == groth16Inputs[2], + hash == groth16Inputs[1], "The sha256 hash of public inputs must be equal to the last of the Groth16 inputs" ); } From 98a9405ff9ec6ea4aec110a5681ed811133996e6 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 26 Sep 2024 19:32:49 +0800 Subject: [PATCH 2/2] Make simple Groth16 test work. --- groth16-framework/src/lib.rs | 16 +++++++++++++--- groth16-framework/src/test_utils.rs | 25 +++++++++++++++++++------ groth16-framework/src/utils.rs | 13 ++++++++++++- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/groth16-framework/src/lib.rs b/groth16-framework/src/lib.rs index 59a58857f..79468421e 100644 --- a/groth16-framework/src/lib.rs +++ b/groth16-framework/src/lib.rs @@ -103,7 +103,9 @@ pub type C = PoseidonGoldilocksConfig; #[cfg(test)] mod tests { use super::*; - use crate::test_utils::test_groth16_proving_and_verification; + use crate::test_utils::{ + evm_verify_on_groth16_proof_file, test_groth16_proving_and_verification, + }; use mp2_common::{proof::serialize_proof, D, F}; use plonky2::{ field::types::Field, @@ -116,6 +118,16 @@ mod tests { use rand::{thread_rng, Rng}; use serial_test::serial; + const ASSET_DIR: &str = "groth16_simple"; + + /// Test the verification on a local file of generated Groth16 proof for the simple circuit. + #[ignore] // Ignore in CI, since it could only run for local test. + #[serial] + #[test] + fn test_groth16_simple_local_verification() { + evm_verify_on_groth16_proof_file(ASSET_DIR); + } + /// Test proving and verifying with a simple circuit. #[ignore] // Ignore for long running time in CI. #[serial] @@ -123,8 +135,6 @@ mod tests { fn test_groth16_proving_simple() { env_logger::init(); - const ASSET_DIR: &str = "groth16_simple"; - // Build for the simple circuit and generate the plonky2 proof. let (circuit_data, proof) = plonky2_build_and_prove(); diff --git a/groth16-framework/src/test_utils.rs b/groth16-framework/src/test_utils.rs index 6c75a6d2e..de47044b0 100644 --- a/groth16-framework/src/test_utils.rs +++ b/groth16-framework/src/test_utils.rs @@ -3,7 +3,7 @@ use crate::{ prover::groth16::combine_proofs, - utils::{hex_to_u256, read_file, write_file}, + utils::{deserialize_json_file, hex_to_u256, read_file, write_file}, EVMVerifier, Groth16Proof, Groth16Prover, Groth16Verifier, C, }; use alloy::{contract::Interface, dyn_abi::DynSolValue, json_abi::JsonAbi}; @@ -11,6 +11,11 @@ use mp2_common::{proof::deserialize_proof, D, F}; use plonky2::plonk::proof::ProofWithPublicInputs; use std::path::Path; +const R1CS_FILENAME: &str = "r1cs.bin"; +const PK_FILENAME: &str = "pk.bin"; +const CIRCUIT_FILENAME: &str = "circuit.bin"; +const GROTH16_PROOF_FILENAME: &str = "groth16_proof.json"; + /// Test Groth16 proving, verification and Solidity verification. pub fn test_groth16_proving_and_verification(asset_dir: &str, plonky2_proof: &[u8]) { // Generate the Groth16 proof. @@ -32,16 +37,16 @@ pub fn test_groth16_proving_and_verification(asset_dir: &str, plonky2_proof: &[u /// Test to generate the proof. fn groth16_prove(asset_dir: &str, plonky2_proof: &ProofWithPublicInputs) -> Groth16Proof { // Read r1cs, pk and circuit bytes from asset dir. - let r1cs = read_file(Path::new(asset_dir).join("r1cs.bin")).unwrap(); - let pk = read_file(Path::new(asset_dir).join("pk.bin")).unwrap(); - let circuit = read_file(Path::new(asset_dir).join("circuit.bin")).unwrap(); + let r1cs = read_file(Path::new(asset_dir).join(R1CS_FILENAME)).unwrap(); + let pk = read_file(Path::new(asset_dir).join(PK_FILENAME)).unwrap(); + let circuit = read_file(Path::new(asset_dir).join(CIRCUIT_FILENAME)).unwrap(); // Initialize the Groth16 prover. let prover = Groth16Prover::from_bytes(r1cs, pk, circuit).expect("Failed to initialize the prover"); // Construct the file paths to save the Groth16 and full proofs. - let groth16_proof_path = Path::new(asset_dir).join("groth16_proof.json"); + let groth16_proof_path = Path::new(asset_dir).join(GROTH16_PROOF_FILENAME); // Generate the Groth16 proof. let groth16_proof = prover @@ -63,6 +68,14 @@ fn groth16_verify(asset_dir: &str, proof: &Groth16Proof) { verifier.verify(proof).expect("Failed to verify the proof") } +/// Test the Solidity verification on a JSON file of Groth16 proof. +pub(crate) fn evm_verify_on_groth16_proof_file(asset_dir: &str) { + let groth16_proof_path = Path::new(asset_dir).join(GROTH16_PROOF_FILENAME); + let groth16_proof = deserialize_json_file(groth16_proof_path).unwrap(); + + evm_verify(asset_dir, &groth16_proof); +} + /// Test the Solidity verification. fn evm_verify(asset_dir: &str, proof: &Groth16Proof) { let solidity_file_path = Path::new(asset_dir) @@ -72,7 +85,7 @@ fn evm_verify(asset_dir: &str, proof: &Groth16Proof) { // Build the contract interface for encoding the arguments of verification function. let abi = JsonAbi::parse([ - "function verifyProof(uint256[8] calldata proof, uint256[3] calldata input)", + "function verifyProof(uint256[8] calldata proof, uint256[2] calldata input)", ]) .unwrap(); let contract = Interface::new(abi); diff --git a/groth16-framework/src/utils.rs b/groth16-framework/src/utils.rs index 6f6e629e6..0189a68f6 100644 --- a/groth16-framework/src/utils.rs +++ b/groth16-framework/src/utils.rs @@ -8,9 +8,10 @@ use mp2_common::{ D, F, }; use plonky2::plonk::circuit_data::CircuitData; +use serde::Deserialize; use std::{ fs::{create_dir_all, File}, - io::{Read, Write}, + io::{BufReader, Read, Write}, path::Path, }; @@ -31,6 +32,16 @@ pub fn hex_to_u256(s: &str) -> Result { Ok(u) } +/// Deserialize from a JSON file. +pub fn deserialize_json_file, T: for<'de> Deserialize<'de>>( + file_path: P, +) -> Result { + let file = File::open(file_path)?; + let reader = BufReader::new(file); + + Ok(serde_json::from_reader(reader)?) +} + /// Read the data from a file. pub fn read_file>(file_path: P) -> Result> { let mut data = vec![];