Skip to content

Commit

Permalink
chore: check output public signals (#6)
Browse files Browse the repository at this point in the history
* chore: check output public signals

* chore(ffi, ios): verify output signals in ffi and swift tests

* chore: add new lines
  • Loading branch information
vivianjeng authored Oct 25, 2023
1 parent ad5d5f6 commit 29c98e9
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 11 deletions.
27 changes: 22 additions & 5 deletions mopro-core/src/middleware/circom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::MoproError;
use std::collections::HashMap;
use std::time::Instant;

use ark_bn254::Bn254;
use ark_bn254::{Bn254, Fr};
use ark_circom::{CircomBuilder, CircomCircuit, CircomConfig};
use ark_crypto_primitives::snark::SNARK;
use ark_groth16::{Groth16, ProvingKey};
Expand Down Expand Up @@ -181,6 +181,12 @@ pub fn bytes_to_circuit_inputs(bytes: &[u8]) -> CircuitInputs {
inputs
}

pub fn bytes_to_circuit_outputs(bytes: &[u8]) -> SerializableInputs {
let bits = bytes_to_bits(bytes);
let field_bits = bits.into_iter().map(|bit| Fr::from(bit as u8)).collect();
SerializableInputs(field_bits)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -203,8 +209,14 @@ mod tests {

// Prepare inputs
let mut inputs = HashMap::new();
inputs.insert("a".to_string(), vec![BigInt::from(3)]);
inputs.insert("b".to_string(), vec![BigInt::from(5)]);
let a = 3;
let b = 5;
let c = a * b;
inputs.insert("a".to_string(), vec![BigInt::from(a)]);
inputs.insert("b".to_string(), vec![BigInt::from(b)]);
// output = [public output c, public input a]
let expected_output = vec![Fr::from(c), Fr::from(a)];
let serialized_outputs = SerializableInputs(expected_output);

// Proof generation
let generate_proof_res = circom_state.generate_proof(inputs);
Expand All @@ -218,6 +230,9 @@ mod tests {

let (serialized_proof, serialized_inputs) = generate_proof_res.unwrap();

// Check output
assert_eq!(serialized_inputs, serialized_outputs);

// Proof verification
let verify_res = circom_state.verify_proof(serialized_proof, serialized_inputs);
assert!(verify_res.is_ok());
Expand Down Expand Up @@ -248,12 +263,13 @@ mod tests {
];

// Expected output
let _expected_output_vec = vec![
let expected_output_vec = vec![
37, 17, 98, 135, 161, 178, 88, 97, 125, 150, 143, 65, 228, 211, 170, 133, 153, 9, 88,
212, 4, 212, 175, 238, 249, 210, 214, 116, 170, 85, 45, 21,
];

let inputs = bytes_to_circuit_inputs(&input_vec);
let serialized_outputs = bytes_to_circuit_outputs(&expected_output_vec);

// Proof generation
let generate_proof_res = circom_state.generate_proof(inputs);
Expand All @@ -267,7 +283,8 @@ mod tests {

let (serialized_proof, serialized_inputs) = generate_proof_res.unwrap();

// TODO: Use expected_output_vec here when verifying proof
// Check output
assert_eq!(serialized_inputs, serialized_outputs);

// Proof verification
let verify_res = circom_state.verify_proof(serialized_proof, serialized_inputs);
Expand Down
2 changes: 1 addition & 1 deletion mopro-core/src/middleware/circom/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct SerializableProvingKey(pub ProvingKey<Bn254>);
#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug)]
pub struct SerializableProof(pub Proof<Bn254>);

#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug)]
#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug, PartialEq)]
pub struct SerializableInputs(pub Vec<<Bn254 as Pairing>::ScalarField>);

pub fn serialize_proof(proof: &SerializableProof) -> Vec<u8> {
Expand Down
1 change: 1 addition & 0 deletions mopro-ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ uniffi = { version = "0.24", features = ["build"] }

[dev-dependencies]
uniffi = { version = "0.24", features = ["bindgen-tests"] }
ark-bn254 = { version = "=0.4.0" }
31 changes: 27 additions & 4 deletions mopro-ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ uniffi::include_scaffolding!("mopro");
#[cfg(test)]
mod tests {
use super::*;
use ark_bn254::Fr;

fn bytes_to_circuit_inputs(input_vec: &Vec<u8>) -> HashMap<String, Vec<i32>> {
let bits = circom::utils::bytes_to_bits(&input_vec);
Expand All @@ -127,6 +128,13 @@ mod tests {
inputs
}

fn bytes_to_circuit_outputs(bytes: &[u8]) -> Vec<u8> {
let bits = circom::utils::bytes_to_bits(bytes);
let field_bits = bits.into_iter().map(|bit| Fr::from(bit as u8)).collect();
let circom_outputs = circom::serialization::SerializableInputs(field_bits);
circom::serialization::serialize_inputs(&circom_outputs)
}

#[test]
fn add_works() {
let result = add(2, 2);
Expand All @@ -147,18 +155,25 @@ mod tests {
assert!(setup_result.provingKey.len() > 0);

let mut inputs = HashMap::new();
inputs.insert("a".to_string(), vec![3_i32; 1]);
inputs.insert("b".to_string(), vec![5_i32; 1]);
let a = 3;
let b = 5;
let c = a * b;
inputs.insert("a".to_string(), vec![a]);
inputs.insert("b".to_string(), vec![b]);
// output = [public output c, public input a]
let expected_output = vec![Fr::from(c), Fr::from(a)];
let circom_outputs = circom::serialization::SerializableInputs(expected_output);
let serialized_outputs = circom::serialization::serialize_inputs(&circom_outputs);

// Step 2: Generate Proof
let generate_proof_result = mopro_circom.generate_proof(inputs)?;
let serialized_proof = generate_proof_result.proof;
let serialized_inputs = generate_proof_result.inputs;

assert!(serialized_proof.len() > 0);
assert_eq!(serialized_inputs, serialized_outputs);

// Step 3: Verify Proof
// TODO: This should also check inputs, make sure it does
let is_valid = mopro_circom.verify_proof(serialized_proof, serialized_inputs)?;
assert!(is_valid);

Expand All @@ -185,17 +200,25 @@ mod tests {
0, 0, 0, 0, 0, 0,
];

// Expected output
let expected_output_vec = vec![
37, 17, 98, 135, 161, 178, 88, 97, 125, 150, 143, 65, 228, 211, 170, 133, 153, 9, 88,
212, 4, 212, 175, 238, 249, 210, 214, 116, 170, 85, 45, 21,
];

let inputs = bytes_to_circuit_inputs(&input_vec);
let serialized_outputs = bytes_to_circuit_outputs(&expected_output_vec);

// Step 2: Generate Proof
let generate_proof_result = mopro_circom.generate_proof(inputs)?;
let serialized_proof = generate_proof_result.proof;
let serialized_inputs = generate_proof_result.inputs;

assert!(serialized_proof.len() > 0);
assert_eq!(serialized_inputs, serialized_outputs);

// Step 3: Verify Proof
// TODO: This should also check inputs, make sure it does

let is_valid = mopro_circom.verify_proof(serialized_proof, serialized_inputs)?;
assert!(is_valid);

Expand Down
32 changes: 32 additions & 0 deletions mopro-ffi/tests/bindings/test_mopro.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
import mopro
import Foundation

let moproCircom = MoproCircom()

let wasmPath = "./../../../../mopro-core/examples/circom/target/multiplier2_js/multiplier2.wasm"
let r1csPath = "./../../../../mopro-core/examples/circom/target/multiplier2.r1cs"

// TODO: should handle 254-bit input
func serializeOutputs(_ int32Array: [Int32]) -> [UInt8] {
var bytesArray: [UInt8] = []
let length = int32Array.count
var littleEndianLength = length.littleEndian
let targetLength = 32
withUnsafeBytes(of: &littleEndianLength) {
bytesArray.append(contentsOf: $0)
}
for value in int32Array {
var littleEndian = value.littleEndian
var byteLength = 0
withUnsafeBytes(of: &littleEndian) {
bytesArray.append(contentsOf: $0)
byteLength = byteLength + $0.count
}
if byteLength < targetLength {
let paddingCount = targetLength - byteLength
let paddingArray = [UInt8](repeating: 0, count: paddingCount)
bytesArray.append(contentsOf: paddingArray)
}
}
return bytesArray
}

do {
// Setup
let setupResult = try moproCircom.setup(wasmPath: wasmPath, r1csPath: r1csPath)
Expand All @@ -15,11 +41,17 @@ do {
inputs["a"] = [3]
inputs["b"] = [5]

// Expected outputs
let outputs: [Int32] = [15, 3]
let expectedOutput: [UInt8] = serializeOutputs(outputs)

// Generate Proof
let generateProofResult = try moproCircom.generateProof(circuitInputs: inputs)
assert(!generateProofResult.proof.isEmpty, "Proof should not be empty")

// Verify Proof
assert(Data(expectedOutput) == generateProofResult.inputs, "Circuit outputs mismatch the expected outputs")

let isValid = try moproCircom.verifyProof(proof: generateProofResult.proof, publicInput: generateProofResult.inputs)
assert(isValid, "Proof verification should succeed")

Expand Down
37 changes: 36 additions & 1 deletion mopro-ffi/tests/bindings/test_mopro_keccak.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import mopro
import Foundation

let moproCircom = MoproCircom()


let wasmPath = "./../../../../mopro-core/examples/circom/keccak256/target/keccak256_256_test_js/keccak256_256_test.wasm"
let r1csPath = "./../../../../mopro-core/examples/circom/keccak256/target/keccak256_256_test.r1cs"

Expand All @@ -18,6 +18,31 @@ func bytesToBits(bytes: [UInt8]) -> [Int32] {
return bits
}

// TODO: should handle 254-bit input
func serializeOutputs(_ int32Array: [Int32]) -> [UInt8] {
var bytesArray: [UInt8] = []
let length = int32Array.count
var littleEndianLength = length.littleEndian
let targetLength = 32
withUnsafeBytes(of: &littleEndianLength) {
bytesArray.append(contentsOf: $0)
}
for value in int32Array {
var littleEndian = value.littleEndian
var byteLength = 0
withUnsafeBytes(of: &littleEndian) {
bytesArray.append(contentsOf: $0)
byteLength = byteLength + $0.count
}
if byteLength < targetLength {
let paddingCount = targetLength - byteLength
let paddingArray = [UInt8](repeating: 0, count: paddingCount)
bytesArray.append(contentsOf: paddingArray)
}
}
return bytesArray
}

do {
// Setup
let setupResult = try moproCircom.setup(wasmPath: wasmPath, r1csPath: r1csPath)
Expand All @@ -32,11 +57,21 @@ do {
var inputs = [String: [Int32]]()
inputs["in"] = bits

// Expected outputs
let outputVec: [UInt8] = [
37, 17, 98, 135, 161, 178, 88, 97, 125, 150, 143, 65, 228, 211, 170, 133, 153, 9, 88,
212, 4, 212, 175, 238, 249, 210, 214, 116, 170, 85, 45, 21,
]
let outputBits: [Int32] = bytesToBits(bytes: outputVec)
let expectedOutput: [UInt8] = serializeOutputs(outputBits)

// Generate Proof
let generateProofResult = try moproCircom.generateProof(circuitInputs: inputs)
assert(!generateProofResult.proof.isEmpty, "Proof should not be empty")

// Verify Proof
assert(Data(expectedOutput) == generateProofResult.inputs, "Circuit outputs mismatch the expected outputs")

let isValid = try moproCircom.verifyProof(proof: generateProofResult.proof, publicInput: generateProofResult.inputs)
assert(isValid, "Proof verification should succeed")

Expand Down
36 changes: 36 additions & 0 deletions mopro-ios/MoproKit/Example/MoproKit/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,28 @@ class ViewController: UIViewController {
var inputs = [String: [Int32]]()
inputs["in"] = bits

// Expected outputs
let outputVec: [UInt8] = [
37, 17, 98, 135, 161, 178, 88, 97, 125, 150, 143, 65, 228, 211, 170, 133, 153, 9, 88,
212, 4, 212, 175, 238, 249, 210, 214, 116, 170, 85, 45, 21,
]
let outputBits: [Int32] = bytesToBits(bytes: outputVec)
let expectedOutput: [UInt8] = serializeOutputs(outputBits)

// Multiplier example
// var inputs = [String: [Int32]]()
// inputs["a"] = [3]
// inputs["b"] = [5]
// let outputs: [Int32] = [15, 3]
// let expectedOutput: [UInt8] = serializeOutputs(outputs)

// Record start time
let start = CFAbsoluteTimeGetCurrent()

// Generate Proof
let generateProofResult = try moproCircom.generateProof(circuitInputs: inputs)
assert(!generateProofResult.proof.isEmpty, "Proof should not be empty")
assert(Data(expectedOutput) == generateProofResult.inputs, "Circuit outputs mismatch the expected outputs")

// Record end time and compute duration
let end = CFAbsoluteTimeGetCurrent()
Expand Down Expand Up @@ -160,3 +171,28 @@ func bytesToBits(bytes: [UInt8]) -> [Int32] {
}
return bits
}

// TODO: should handle 254-bit input
func serializeOutputs(_ int32Array: [Int32]) -> [UInt8] {
var bytesArray: [UInt8] = []
let length = int32Array.count
var littleEndianLength = length.littleEndian
let targetLength = 32
withUnsafeBytes(of: &littleEndianLength) {
bytesArray.append(contentsOf: $0)
}
for value in int32Array {
var littleEndian = value.littleEndian
var byteLength = 0
withUnsafeBytes(of: &littleEndian) {
bytesArray.append(contentsOf: $0)
byteLength = byteLength + $0.count
}
if byteLength < targetLength {
let paddingCount = targetLength - byteLength
let paddingArray = [UInt8](repeating: 0, count: paddingCount)
bytesArray.append(contentsOf: paddingArray)
}
}
return bytesArray
}

0 comments on commit 29c98e9

Please sign in to comment.