Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow JS values as ZkProgram inputs #1934

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions src/examples/zkprogram/program-with-input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ let { verificationKey } = await MyProgram.compile();
console.log('verification key', verificationKey.data.slice(0, 10) + '..');

console.log('proving base case...');
let { proof } = await MyProgram.baseCase(Field(0));
let { proof } = await MyProgram.baseCase(0);
proof = await testJsonRoundtrip(MyProgram.Proof, proof);

// type sanity check
Expand All @@ -57,7 +57,7 @@ ok = await MyProgram.verify(proof);
console.log('ok (alternative)?', ok);

console.log('proving step 1...');
let { proof: proof1 } = await MyProgram.inductiveCase(Field(1), proof);
let { proof: proof1 } = await MyProgram.inductiveCase(1, proof);
proof1 = await testJsonRoundtrip(MyProgram.Proof, proof1);

console.log('verify...');
Expand All @@ -69,7 +69,7 @@ ok = await MyProgram.verify(proof1);
console.log('ok (alternative)?', ok);

console.log('proving step 2...');
let { proof: proof2 } = await MyProgram.inductiveCase(Field(2), proof1);
let { proof: proof2 } = await MyProgram.inductiveCase(2, proof1);
proof2 = await testJsonRoundtrip(MyProgram.Proof, proof2);

console.log('verify...');
Expand Down
58 changes: 39 additions & 19 deletions src/lib/proof-system/zkprogram.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import {
featureFlagsToMlOption,
} from './feature-flags.js';
import { emptyWitness } from '../provable/types/util.js';
import { InferValue } from '../../bindings/lib/provable-generic.js';
import { From, InferValue } from '../../bindings/lib/provable-generic.js';
import { DeclaredProof, ZkProgramContext } from './zkprogram-context.js';
import { mapObject, mapToObject, zip } from '../util/arrays.js';

Expand Down Expand Up @@ -246,6 +246,7 @@ function ZkProgram<
proveRecursively: {
[I in keyof Config['methods']]: RecursiveProver<
InferProvableOrUndefined<Get<Config, 'publicInput'>>,
ProvableOrUndefined<Get<Config, 'publicInput'>>,
InferProvableOrVoid<Get<Config, 'publicOutput'>>,
PrivateInputs[I]
>;
Expand All @@ -261,11 +262,16 @@ function ZkProgram<
} & {
[I in keyof Config['methods']]: Prover<
InferProvableOrUndefined<Get<Config, 'publicInput'>>,
ProvableOrUndefined<Get<Config, 'publicInput'>>,
InferProvableOrVoid<Get<Config, 'publicOutput'>>,
PrivateInputs[I],
InferProvableOrUndefined<AuxiliaryOutputs[I]>
>;
} {
type PublicInputType = ProvableOrUndefined<Get<Config, 'publicInput'>>;
type PublicInput = InferProvableOrUndefined<Get<Config, 'publicInput'>>;
type PublicOutput = InferProvableOrVoid<Get<Config, 'publicOutput'>>;

let doProving = true;

let methods = config.methods;
Expand All @@ -279,8 +285,6 @@ function ZkProgram<
);

let selfTag = { name: config.name };
type PublicInput = InferProvableOrUndefined<Get<Config, 'publicInput'>>;
type PublicOutput = InferProvableOrVoid<Get<Config, 'publicOutput'>>;

class SelfProof extends Proof<PublicInput, PublicOutput> {
static publicInputType = publicInputType;
Expand All @@ -301,6 +305,7 @@ function ZkProgram<
)
);
let methodFunctions = methodKeys.map((key) => methods[key].method);
let privateInputTypes = methodIntfs.map((m) => m.args);
let maxProofsVerified: undefined | 0 | 1 | 2 = undefined;

async function getMaxProofsVerified() {
Expand Down Expand Up @@ -379,8 +384,8 @@ function ZkProgram<
}

type RegularProver<K extends MethodKey> = (
publicInput: PublicInput,
...args: PrivateInputs[K]
publicInput: From<PublicInputType>,
...args: TupleFrom<PrivateInputs[K]>
) => Promise<{
proof: Proof<PublicInput, PublicOutput>;
auxiliaryOutput: InferProvableOrUndefined<AuxiliaryOutputs[K]>;
Expand All @@ -390,7 +395,11 @@ function ZkProgram<
key: K,
i: number
): RegularProver<K> {
return async function prove_(publicInput, ...args) {
return async function prove_(inputPublicInput, ...inputArgs) {
let publicInput = publicInputType.fromValue(inputPublicInput);
let args = zip(inputArgs, privateInputTypes[i]).map(([arg, type]) =>
ProvableType.get(type).fromValue(arg)
);
if (!doProving) {
let { publicOutput, auxiliaryOutput } =
(hasPublicInput
Expand Down Expand Up @@ -456,6 +465,7 @@ function ZkProgram<

type Prover_<K extends MethodKey = MethodKey> = Prover<
PublicInput,
PublicInputType,
PublicOutput,
PrivateInputs[K],
InferProvableOrUndefined<AuxiliaryOutputs[K]>
Expand Down Expand Up @@ -487,17 +497,20 @@ function ZkProgram<
return compileOutput.verify(statement, proof.proof);
}

let regularRecursiveProvers = mapObject(regularProvers, (prover, key) => {
let regularRecursiveProvers = mapObject(regularProvers, (prover, _key, i) => {
return async function proveRecursively_(
publicInput: PublicInput,
...args: TupleToInstances<PrivateInputs[MethodKey]>
...args: TupleFrom<PrivateInputs[MethodKey]>
) {
// create the base proof in a witness block
let proof = await Provable.witnessAsync(SelfProof, async () => {
// move method args to constants
let constInput = Provable.toConstant(publicInputType, publicInput);
let constArgs = zip(args, methods[key].privateInputs).map(
([arg, type]) => Provable.toConstant(type, arg)
let constInput = Provable.toConstant(
publicInputType,
publicInputType.fromValue(publicInput)
);
let constArgs = zip(args, privateInputTypes[i]).map(([arg, type]) =>
Provable.toConstant(type, ProvableType.get(type).fromValue(arg))
);
let { proof } = await prover(constInput, ...(constArgs as any));
return proof;
Expand All @@ -516,6 +529,7 @@ function ZkProgram<
});
type RecursiveProver_<K extends MethodKey> = RecursiveProver<
PublicInput,
PublicInputType,
PublicOutput,
PrivateInputs[K]
>;
Expand Down Expand Up @@ -555,8 +569,9 @@ function ZkProgram<
publicOutputType: publicOutputType as ProvableOrVoid<
Get<Config, 'publicOutput'>
>,
privateInputTypes: Object.fromEntries(
methodKeys.map((key) => [key, methods[key].privateInputs])
privateInputTypes: mapToObject(
methodKeys,
(_, i) => privateInputTypes[i]
) as any,
auxiliaryOutputTypes: Object.fromEntries(
methodKeys.map((key) => [key, methods[key].auxiliaryOutput])
Expand Down Expand Up @@ -1127,6 +1142,9 @@ type Infer<T> = T extends Subclass<typeof ProofBase>
type TupleToInstances<T> = {
[I in keyof T]: Infer<T[I]>;
};
type TupleFrom<T> = {
[I in keyof T]: From<T[I]>;
};

type PrivateInput = ProvableType | Subclass<typeof ProofBase>;

Expand Down Expand Up @@ -1177,31 +1195,33 @@ type Method<

type Prover<
PublicInput,
PublicInputType,
PublicOutput,
Args extends Tuple<PrivateInput>,
AuxiliaryOutput
> = PublicInput extends undefined
? (...args: TupleToInstances<Args>) => Promise<{
? (...args: TupleFrom<Args>) => Promise<{
proof: Proof<PublicInput, PublicOutput>;
auxiliaryOutput: AuxiliaryOutput;
}>
: (
publicInput: PublicInput,
...args: TupleToInstances<Args>
publicInput: From<PublicInputType>,
...args: TupleFrom<Args>
) => Promise<{
proof: Proof<PublicInput, PublicOutput>;
auxiliaryOutput: AuxiliaryOutput;
}>;

type RecursiveProver<
PublicInput,
PublicInputType,
PublicOutput,
Args extends Tuple<PrivateInput>
> = PublicInput extends undefined
? (...args: TupleToInstances<Args>) => Promise<PublicOutput>
? (...args: TupleFrom<Args>) => Promise<PublicOutput>
: (
publicInput: PublicInput,
...args: TupleToInstances<Args>
publicInput: From<PublicInputType>,
...args: TupleFrom<Args>
) => Promise<PublicOutput>;

type ProvableOrUndefined<A> = A extends undefined
Expand Down
7 changes: 7 additions & 0 deletions src/lib/provable/crypto/signature.ts
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,11 @@ class Signature extends CircuitValue {
let s = this.s.toBigInt();
return SignatureBigint.toBase58({ r, s });
}

static fromValue<T extends AnyConstructor>(
this: T,
{ r, s }: { r: Field | bigint; s: Scalar | bigint }
): InstanceType<T> {
return Signature.fromObject({ r: Field.from(r), s: Scalar.from(s) }) as any;
}
}
15 changes: 12 additions & 3 deletions src/lib/provable/int.ts
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ class UInt64 extends CircuitValue {
}

static fromValue<T extends AnyConstructor>(
x: bigint | UInt64
x: number | bigint | UInt64
): InstanceType<T> {
return UInt64.from(x) as any;
}
Expand Down Expand Up @@ -995,7 +995,7 @@ class UInt32 extends CircuitValue {
}

static fromValue<T extends AnyConstructor>(
x: bigint | UInt32
x: number | bigint | UInt32
): InstanceType<T> {
return UInt32.from(x) as any;
}
Expand Down Expand Up @@ -1081,7 +1081,7 @@ class Sign extends CircuitValue {
}

static fromValue<T extends AnyConstructor>(
x: bigint | Sign
x: number | bigint | Sign
): InstanceType<T> {
if (x instanceof Sign) return x as any;
return new Sign(Field(x)) as any;
Expand Down Expand Up @@ -1842,6 +1842,15 @@ class UInt8 extends Struct({
return new UInt8(x);
}

static fromValue(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of fromField(ff), isn't this similar to new UInt8(ff.value) or UInt8.from(ff), which are not recommended to use instead of UInt8.Unsafe.fromField(ff) because they don't clearly indicate the potential unsafety of wrapping a Field into a UInt8, making the code less readable.

Copy link
Collaborator Author

@mitschabaude mitschabaude Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah absolutely, and this method isn't really intended for developers to use (that's what .from() was designed for).

fromValue() is part of the Provable<T> API and what we use internally to convert JS values into the type -- it's needed here for that reason.

Note that this PR doesn't add UInt8.fromValue() -- it was already there, we just override it to also handle number inputs (which is not unsafe), because number is the most natural type that you'd want to convert a UInt8 from

Copy link
Collaborator Author

@mitschabaude mitschabaude Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are two things to improve IMO, both of which aren't related to this PR though:

  1. internal Provable methods should be on a separate namespace .provable instead of on the class itself, to mark them more clearly as not-developer facing Move Provable<T> methods to a separate .provable namespace, on each provable type #1248
  2. UInt8 shouldn't be a Struct, because while Struct is great for easily composing types, it's not great for defining the base primitives. In this case, the problem is that the fromValue() type of UInt8 is inferred from the Struct({ value: Field }) definition to include { value: Field }, which really should be UInt8 instead (but that doesn't work because the Struct definition doesn't know anything about UInt8)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this similar to new UInt8(ff.value) or UInt8.from(ff), which are not recommended to use

side note: UInt8.from(field) is perfectly fine, because it adds a range-check

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the insight!

UInt8 shouldn't be a Struct, because while Struct is great for easily composing types, it's not great for defining the base primitives.

I recall encountering an issue a while ago where I couldn't use Provable.witness for a UInt32 with proofsEnabled=true. Is this related to the same issue, or is it unrelated? I’ll need to revisit and check this again!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UInt32 is not implemented with Struct (but with the older CircuitValue) so that might be a different issue!

in the case of UInt8, one thing that bites me now and then is that UInt8 satisfies Provable<UInt8> is not true (and similar for other Structs), because e.g. the return type of UInt8.fromFields() is { value: Field } and not UInt8

image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While if you look at newer types like Bytes, you'll see that because of the way we made .provable aware of the class itself, this works:

Bytes(32) satisfies ProvableType<Bytes>;

// we need all the { value } inputs to correctly extend the Struct
x: number | UInt8 | { value: string | number | bigint | Field }
) {
if (typeof x === 'number') return UInt8.from(x);
if (x instanceof UInt8) return x;
return UInt8.Unsafe.fromField(Field(x.value));
}

private static checkConstant(x: Field) {
if (!x.isConstant()) return;
RangeCheck.rangeCheck8(x);
Expand Down
6 changes: 4 additions & 2 deletions src/lib/util/arrays.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ function pad<T>(array: T[], size: number, value: T): T[] {

function mapObject<
T extends Record<string, any>,
F extends <K extends keyof T>(value: T[K], key: K) => any
F extends <K extends keyof T>(value: T[K], key: K, i: number) => any
>(t: T, fn: F) {
let s = {} as { [K in keyof T]: ReturnType<F> };
let i = 0;
for (let key in t) {
s[key] = fn(t[key], key);
s[key] = fn(t[key], key, i);
i++;
}
return s;
}
Expand Down
2 changes: 1 addition & 1 deletion src/mina-signer/tests/verify-in-snark.unit-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const MyProgram = ZkProgram({
});

await MyProgram.compile();
let { proof } = await MyProgram.verifySignature(signature, fieldsSnarky);
let { proof } = await MyProgram.verifySignature(signature, fields);
ok = await MyProgram.verify(proof);
expect(ok).toEqual(true);

Expand Down
27 changes: 13 additions & 14 deletions src/tests/fake-proof.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,16 @@ const RecursiveProgram = ZkProgram({
},
},
verifyInternal: {
privateInputs: [Unconstrained<Proof<undefined, UInt64> | undefined>],
async method(
fakeProof: Unconstrained<Proof<undefined, UInt64> | undefined>
) {
privateInputs: [
Unconstrained.withEmpty<RealProof | undefined>(undefined),
],
async method(fakeProof: Unconstrained<RealProof | undefined>) {
// witness either fake proof from input, or real proof
let proof = await Provable.witnessAsync(RealProof, async () => {
let maybeFakeProof = fakeProof.get();
if (maybeFakeProof !== undefined) return maybeFakeProof;

let { proof } = await RealProgram.make(UInt64.from(34));
let { proof } = await RealProgram.make(34);
return proof;
});

Expand All @@ -98,7 +98,7 @@ let { verificationKey: contractVk } = await RecursiveContract.compile();
let { verificationKey: programVk } = await RecursiveProgram.compile();

// proof that should be rejected
const { proof: fakeProof } = await FakeProgram.make(UInt64.from(99999));
const { proof: fakeProof } = await FakeProgram.make(99999);
const dummyProof = await RealProof.dummy(undefined, UInt64.zero, 0);

for (let proof of [fakeProof, dummyProof]) {
Expand All @@ -115,7 +115,7 @@ for (let proof of [fakeProof, dummyProof]) {
}

// proof that should be accepted
const { proof: realProof } = await RealProgram.make(UInt64.from(34));
const { proof: realProof } = await RealProgram.make(34);

// zkprogram accepts proof
const { proof: recursiveProof } = await RecursiveProgram.verifyReal(realProof);
Expand All @@ -139,15 +139,14 @@ console.log('fake proof test passed 🎉');
for (let proof of [fakeProof, dummyProof]) {
// zkprogram rejects proof (nested)
await assert.rejects(async () => {
await RecursiveProgram.verifyNested(Field(0), { inner: proof });
await RecursiveProgram.verifyNested(0, { inner: proof });
}, 'recursive program rejects fake proof (nested)');
}

// zkprogram accepts proof (nested)
const { proof: recursiveProofNested } = await RecursiveProgram.verifyNested(
Field(0),
{ inner: realProof }
);
const { proof: recursiveProofNested } = await RecursiveProgram.verifyNested(0, {
inner: realProof,
});
assert(
await verify(recursiveProofNested, programVk),
'recursive program accepts real proof (nested)'
Expand All @@ -160,13 +159,13 @@ console.log('fake proof test passed for nested proofs 🎉');
for (let proof of [fakeProof, dummyProof]) {
// zkprogram rejects proof (internal)
await assert.rejects(async () => {
await RecursiveProgram.verifyInternal(Unconstrained.from(proof));
await RecursiveProgram.verifyInternal(proof);
}, 'recursive program rejects fake proof (internal)');
}

// zkprogram accepts proof (internal)
const { proof: internalProof } = await RecursiveProgram.verifyInternal(
Unconstrained.from(undefined)
undefined
);
assert(
await verify(internalProof, programVk),
Expand Down
2 changes: 1 addition & 1 deletion src/tests/inductive-proofs-internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ let Wrapper = ZkProgram({

methods: {
wrap: {
privateInputs: [ZkProgram.Proof(MaxProofsVerifiedTwo)],
privateInputs: [MaxProofsVerifiedTwo.Proof],

async method(proof: Proof<undefined, Field>) {
proof.verify();
Expand Down