diff --git a/circuits/circom/calculateTotal.circom b/circuits/circom/calculateTotal.circom new file mode 100644 index 0000000000..d3bf034a85 --- /dev/null +++ b/circuits/circom/calculateTotal.circom @@ -0,0 +1,22 @@ +pragma circom 2.0.0; + +/** + * Computes the cumulative sum of an array of n input signals. + * It iterates through each input, aggregating the sum up to that point, + * and outputs the total sum of all inputs. This template is useful for + * operations requiring the total sum of multiple signals, ensuring the + * final output reflects the cumulative total of the inputs provided. + */ +template CalculateTotal(n) { + signal input nums[n]; + signal output sum; + + signal sums[n]; + sums[0] <== nums[0]; + + for (var i=1; i < n; i++) { + sums[i] <== sums[i - 1] + nums[i]; + } + + sum <== sums[n - 1]; +} diff --git a/circuits/circom/ecdh.circom b/circuits/circom/ecdh.circom deleted file mode 100644 index f9c6451430..0000000000 --- a/circuits/circom/ecdh.circom +++ /dev/null @@ -1,26 +0,0 @@ -pragma circom 2.0.0; - -// circomlib imports -include "./bitify.circom"; -include "./escalarmulany.circom"; - -template Ecdh() { - // Note: the private key needs to be hashed and pruned first - signal input privKey; - signal input pubKey[2]; - signal output sharedKey[2]; - - component privBits = Num2Bits(253); - privBits.in <== privKey; - - component mulFix = EscalarMulAny(253); - mulFix.p[0] <== pubKey[0]; - mulFix.p[1] <== pubKey[1]; - - for (var i = 0; i < 253; i++) { - mulFix.e[i] <== privBits.out[i]; - } - - sharedKey[0] <== mulFix.out[0]; - sharedKey[1] <== mulFix.out[1]; -} diff --git a/circuits/circom/float.circom b/circuits/circom/float.circom deleted file mode 100644 index f2d479fde7..0000000000 --- a/circuits/circom/float.circom +++ /dev/null @@ -1,181 +0,0 @@ -pragma circom 2.0.0; - -// circomlib imports -include "./bitify.circom"; -include "./comparators.circom"; -include "./mux1.circom"; - -template msb(n) { - // require in < 2**n - signal input in; - signal output out; - component n2b = Num2Bits(n); - n2b.in <== in; - n2b.out[n-1] ==> out; -} - -template shift(n) { - // shift divident and partial rem together - // divident will reduce by 1 bit each call - // require divident < 2**n - signal input divident; - signal input rem; - signal output divident1; - signal output rem1; - - component lmsb = msb(n); - lmsb.in <== divident; - rem1 <== rem * 2 + lmsb.out; - divident1 <== divident - lmsb.out * 2**(n-1); -} - -template IntegerDivision(n) { - signal input a; - signal input b; - signal output c; - - component lta = LessThan(252); - lta.in[0] <== a; - lta.in[1] <== 2**n; - lta.out === 1; - component ltb = LessThan(252); - ltb.in[0] <== b; - ltb.in[1] <== 2**n; - ltb.out === 1; - - component isz = IsZero(); - isz.in <== b; - isz.out === 0; - - var divident = a; - var rem = 0; - - component b2n = Bits2Num(n); - component shf[n]; - component lt[n]; - component mux[n]; - - for (var i = n - 1; i >= 0; i--) { - shf[i] = shift(i+1); - lt[i] = LessEqThan(n); - mux[i] = Mux1(); - } - - for (var i = n-1; i >= 0; i--) { - shf[i].divident <== divident; - shf[i].rem <== rem; - divident = shf[i].divident1; - rem = shf[i].rem1; - - lt[i].in[0] <== b; - lt[i].in[1] <== rem; - - mux[i].s <== lt[i].out; - mux[i].c[0] <== 0; - mux[i].c[1] <== 1; - mux[i].out ==> b2n.in[i]; - - rem = rem - b * lt[i].out; - } - b2n.out ==> c; -} - -template ToFloat(W) { - // W is the number of digits in decimal part - // 10^75 < 2^252 - assert(W < 75); - - // in*10^W <= 10^75 - signal input in; - signal output out; - component lt = LessEqThan(252); - lt.in[0] <== in; - lt.in[1] <== 10**(75-W); - lt.out === 1; - out <== in * (10**W); -} - -template DivisionFromFloat(W, n) { - // W is the number of digits in decimal part - // n is maximum width of a and b in binary form - // assume a, b are both float representation - signal input a; - signal input b; - signal output c; - - assert(W < 75); - assert(n < 252); - - component lt = LessThan(252); - lt.in[0] <== a; - lt.in[1] <== 10 ** (75 - W); - lt.out === 1; - - component div = IntegerDivision(n); - div.a <== a * (10 ** W); - div.b <== b; - c <== div.c; -} - -template DivisionFromNormal(W, n) { - // W is the number of digits in decimal part - // n is maximum width of integer part of a and b in binary form - // assume a, b are both normal representation - signal input a; - signal input b; - signal output c; - component tfa = ToFloat(W); - component tfb = ToFloat(W); - component div = DivisionFromFloat(W, n); - tfa.in <== a; - tfb.in <== b; - div.a <== tfa.out; - div.b <== tfb.out; - c <== div.c; -} - -template MultiplicationFromFloat(W, n) { - // W is the number of digits in decimal part - // n is maximum width of integer part of a and b in binary form - // assume a, b are both float representation - signal input a; - signal input b; - signal output c; - - assert(W < 75); - assert(n < 252); - assert(10**W < 2**n); - - component div = IntegerDivision(n); - - // not the best way, but works in our case - component lta = LessThan(252); - lta.in[0] <== a; - lta.in[1] <== 2 ** 126; - lta.out === 1; - component ltb = LessThan(252); - ltb.in[0] <== b; - ltb.in[1] <== 2 ** 126; - ltb.out === 1; - - div.a <== a * b; - div.b <== 10**W; - c <== div.c; -} - -template MultiplicationFromNormal(W, n) { - // W is the number of digits in decimal part - // n is maximum width of integer part of a and b in binary form - // assume a, b are both float representation - signal input a; - signal input b; - signal output c; - component tfa = ToFloat(W); - component tfb = ToFloat(W); - component mul = MultiplicationFromFloat(W, n); - tfa.in <== a; - tfb.in <== b; - mul.a <== tfa.out; - mul.b <== tfb.out; - c <== mul.c; -} diff --git a/circuits/circom/hashers.circom b/circuits/circom/hashers.circom index 21e1938d55..0f60249a06 100644 --- a/circuits/circom/hashers.circom +++ b/circuits/circom/hashers.circom @@ -1,10 +1,17 @@ pragma circom 2.0.0; -// https://github.com/weijiekoh/circomlib/blob/feat/poseidon-encryption/circuits/poseidon.circom +// from @zk-kit/circuits package. include "./poseidon-cipher.circom"; +// from circomlib. include "./sha256/sha256.circom"; include "./bitify.circom"; +/** + * Computes the SHA-256 hash of an array of input signals. Each input is first + * converted to a 256-bit representation, then these are concatenated and passed + * to the SHA-256 hash function. The output is the 256 hash value of the inputs bits + * converted back to numbers. + */ template Sha256Hasher(length) { var inBits = 256 * length; @@ -28,16 +35,22 @@ template Sha256Hasher(length) { for (var i = 0; i < 256; i++) { shaOut.in[i] <== sha.out[255-i]; } + hash <== shaOut.out; } -// Template for computing the Poseidon hash of an array of 'n' inputs -// with a default zero state (not included in the 'n' inputs). +/** + * Computes the Poseidon hash for an array of n inputs, including a default initial state + * of zero not counted in n. First, extends the inputs by prepending a zero, creating an array [0, inputs]. + * Then, the Poseidon hash of the extended inputs is calculated, with the first element of the + * result assigned as the output. + */ template PoseidonHasher(n) { signal input inputs[n]; signal output out; - var extendedInputs[n + 1]; // [0, inputs]. + // [0, inputs]. + var extendedInputs[n + 1]; extendedInputs[0] = 0; for (var i = 0; i < n; i++) { @@ -51,14 +64,20 @@ template PoseidonHasher(n) { out <== perm[0]; } -// hash a MACI message together with the public key -// used to encrypt the message +/** + * Hashes a MACI message and the public key used for message encryption. + * This template processes 11 message inputs and a 2-element public key + * combining them using the Poseidon hash function. The hashing process involves two stages: + * 1. hashing message parts in groups of five and, + * 2. hashing the grouped results alongside the first message input and + * the encryption public key to produce a final hash output. + */ template MessageHasher() { - // 11 inputs are the MACI message + // 11 inputs are the MACI message. signal input in[11]; - // the public key used to encrypt the message + // the public key used to encrypt the message. signal input encPubKey[2]; - // we output an hash + // we output an hash. signal output hash; // Hasher5( diff --git a/circuits/circom/iqt.circom b/circuits/circom/iqt.circom new file mode 100644 index 0000000000..8c9ef265b6 --- /dev/null +++ b/circuits/circom/iqt.circom @@ -0,0 +1,280 @@ +pragma circom 2.0.0; + +// from @zk-kit/circuits package. +include "./safe-comparators.circom"; +// from circomlib. +include "./bitify.circom"; +include "./mux1.circom"; +// local. +include "./calculateTotal.circom"; +include "./hashers.circom"; + +// Incremental Quintary Merkle Tree (IQT) verification circuits. +// Since each node contains 5 leaves, we are using PoseidonT6 for hashing them. +// +// nb. circom has some particularities which limit the code patterns we can use: +// - You can only assign a value to a signal once. +// - A component's input signal must only be wired to another component's output signal. +// - Variables can store linear combinations, and can also be used for loops, +// declaring sizes of things, and anything that is not related to inputs of a circuit. +// - The compiler fails whenever you try to mix invalid elements. +// - You can't use a signal as a list index. + +/** + * Selects an item from a list based on the given index. + * It verifies the index is within the valid range and then iterates over the inputs to find the match. + * For each item, it checks if its position equals the given index and if so, multiplies the item + * by the result of the equality check, effectively selecting it. + * The sum of these results yields the selected item, ensuring only the item at the specified index be the output. + * + * nb. The number of items must be less than 8, and the index must be less than the number of items. + */ +template QuinSelector(choices) { + signal input in[choices]; + signal input index; + signal output out; + + // Ensure that index < choices. + var lessThan = SafeLessThan(3)([index, choices]); + lessThan === 1; + + // Initialize an array to hold the results of equality checks. + var results[choices]; + + // For each item, check whether its index equals the input index. + // The result is multiplied by the corresponding input value. + for (var i = 0; i < choices; i++) { + var isEq = IsEqual()([i, index]); + + results[i] = isEq * in[i]; + } + + // Calculate the total sum of the results array. + out <== CalculateTotal(choices)(results); +} + +/** + * The output array contains the input items, with the the leaf inserted at the + * specified index. For example, if input = [0, 20, 30, 40], index = 3, and + * leaf = 10, the output will be [0, 20, 30, 10, 40]. + */ +template Splicer(numItems) { + // The number of output items (because only one item is inserted). + var NUM_OUTPUT_ITEMS = numItems + 1; + + signal input in[numItems]; + signal input leaf; + signal input index; + signal output out[NUM_OUTPUT_ITEMS]; + + // There is a loop where the goal is to assign values to the output signal. + // + // | output[0] | output[1] | output[2] | ... + // + // We can either assign the leaf, or an item from the `items` signal, to the output, using Mux1(). + // The Mux1's selector is 0 or 1 depending on whether the index is equal to the loop counter. + // + // i --> [IsEqual] <-- index + // | + // v + // leaf --> [Mux1] <-- + // | + // v + // output[m] + // + // To obtain the value from , we need to compute an item + // index (let it be `s`). + // 1. if index = 2 and i = 0, then s = 0 + // 2. if index = 2 and i = 1, then s = 1 + // 3. if index = 2 and i = 2, then s = 2 + // 4. if index = 2 and i = 3, then s = 2 + // 5. if index = 2 and i = 4, then s = 3 + // We then wire `s`, as well as each item in `in` to a QuinSelector. + // The output signal from the QuinSelector is and gets + // wired to Mux1 (as above). + + for (var i = 0; i < NUM_OUTPUT_ITEMS; i++) { + // Determines if current index is greater than the insertion index. + var isAfterInsertPoint = SafeGreaterThan(3)([i, index]); + + // Calculates correct index for original items, adjusting for leaf insertion. + var adjustedIndex = i - isAfterInsertPoint; + + // Selects item from the original array or the leaf for insertion. + var selected = QuinSelector(NUM_OUTPUT_ITEMS)([in[0], in[1], in[2], in[3], 0], adjustedIndex); + var isEq = IsEqual()([index, i]); + var mux = Mux1()([selected, leaf], isEq); + + out[i] <== mux; + } +} + +/** + * Computes the root of an IQT given a leaf, its path, and sibling nodes at each level of the tree. + * It iteratively incorporates the leaf or the hash from the previous level with sibling nodes using + * the Splicer to place the leaf or hash at the correct position based on path_index. + * Then, it hashes these values together with PoseidonHasher to move up the tree. + * This process repeats for each level (levels) of the tree, culminating in the computation of the tree's root. + */ +template QuinTreeInclusionProof(levels) { + var LEAVES_PER_NODE = 5; + var LEAVES_PER_PATH_LEVEL = LEAVES_PER_NODE - 1; + + signal input leaf; + signal input path_index[levels]; + signal input path_elements[levels][LEAVES_PER_PATH_LEVEL]; + signal output root; + + var currentLeaf = leaf; + + // Iteratively hash each level of path_elements with the leaf or previous hash + for (var i = 0; i < levels; i++) { + var splicedLeaf[LEAVES_PER_NODE] = Splicer(LEAVES_PER_PATH_LEVEL)( + [path_elements[i][0], path_elements[i][1], path_elements[i][2], path_elements[i][3]], + currentLeaf, + path_index[i] + ); + + currentLeaf = PoseidonHasher(5)([ + splicedLeaf[0], + splicedLeaf[1], + splicedLeaf[2], + splicedLeaf[3], + splicedLeaf[4] + ]); + } + + root <== currentLeaf; +} + +/** + * Verifies if a given leaf exists within an IQT. + * Takes a leaf, its path to the root (specified by indices and path elements), + * and the root itself, to verify the leaf's inclusion within the tree. + */ +template QuinLeafExists(levels){ + var LEAVES_PER_NODE = 5; + var LEAVES_PER_PATH_LEVEL = LEAVES_PER_NODE - 1; + + signal input leaf; + signal input path_index[levels]; + signal input path_elements[levels][LEAVES_PER_PATH_LEVEL]; + signal input root; + + // Verify the Merkle path. + var verifier = QuinTreeInclusionProof(levels)(leaf, path_index, path_elements); + + root === verifier; +} + +/** + * Checks if a list of leaves exists within an IQT, leveraging the PoseidonT6 + * circuit for hashing. This can be used to verify the presence of multiple leaves. + */ +template QuinBatchLeavesExists(levels, batchLevels) { + var LEAVES_PER_NODE = 5; + var LEAVES_PER_PATH_LEVEL = LEAVES_PER_NODE - 1; + var LEAVES_PER_BATCH = LEAVES_PER_NODE ** batchLevels; + + signal input root; + signal input leaves[LEAVES_PER_BATCH]; + signal input path_index[levels - batchLevels]; + signal input path_elements[levels - batchLevels][LEAVES_PER_PATH_LEVEL]; + + // Compute the subroot (= leaf). + var subroot = QuinCheckRoot(batchLevels)(leaves); + + // Check if the Merkle path is valid + QuinLeafExists(levels - batchLevels)(subroot, path_index, path_elements, root); +} + +/** + * Calculates the path indices required for Merkle proof verifications (e.g., QuinTreeInclusionProof, QuinLeafExists). + * Given a node index within an IQT and the total tree levels, it outputs the path indices leading to that node. + * The template handles the modulo and division operations to break down the tree index into its constituent path indices. + * e.g., if the index is 30 and the number of levels is 4, the output should be [0, 1, 1, 0]. + */ +template QuinGeneratePathIndices(levels) { + var BASE = 5; + + signal input in; + signal output out[levels]; + signal n[levels + 1]; + + var m = in; + var results[levels]; + + for (var i = 0; i < levels; i++) { + // circom's best practices suggests to avoid using <-- unless you + // are aware of what's going on. This is the only way to do modulo operation. + n[i] <-- m; + out[i] <-- m % BASE; + m = m \ BASE; + } + + n[levels] <-- m; + + for (var i = 0; i < levels; i++) { + // Check that each output element is less than the base. + var lessThan = SafeLessThan(3)([out[i], BASE]); + lessThan === 1; + + // Re-compute the total sum. + results[i] = out[i] * (BASE ** i); + } + + // Check that the total sum matches the index. + var calculateTotal = CalculateTotal(levels)(results); + + calculateTotal === in; +} + +/** + * Computes the root of a quintary Merkle tree given a list of leaves. + * This template constructs a Merkle tree with each node having 5 children (quintary) + * and computes the root by hashing with Poseidon the leaves and intermediate nodes in the given order. + * The computation is performed by first hashing groups of 5 leaves to form the bottom layer of nodes, + * then recursively hashing groups of these nodes to form the next layer, and so on, until the root is computed. + */ +template QuinCheckRoot(levels) { + var LEAVES_PER_NODE = 5; + var totalLeaves = LEAVES_PER_NODE ** levels; + var numLeafHashers = LEAVES_PER_NODE ** (levels - 1); + + signal input leaves[totalLeaves]; + signal output root; + + // Determine the total number of hashers. + var numHashers = 0; + for (var i = 0; i < levels; i++) { + numHashers += LEAVES_PER_NODE ** i; + } + + var hashers[numHashers]; + + // Initialize hashers for the leaves. + for (var i = 0; i < numLeafHashers; i++) { + hashers[i] = PoseidonHasher(5)([ + leaves[i*LEAVES_PER_NODE+0], + leaves[i*LEAVES_PER_NODE+1], + leaves[i*LEAVES_PER_NODE+2], + leaves[i*LEAVES_PER_NODE+3], + leaves[i*LEAVES_PER_NODE+4] + ]); + } + + // Initialize hashers for intermediate nodes and compute the root. + var k = 0; + for (var i = numLeafHashers; i < numHashers; i++) { + hashers[i] = PoseidonHasher(5)([ + hashers[k*LEAVES_PER_NODE+0], + hashers[k*LEAVES_PER_NODE+1], + hashers[k*LEAVES_PER_NODE+2], + hashers[k*LEAVES_PER_NODE+3], + hashers[k*LEAVES_PER_NODE+4] + ]); + k++; + } + + root <== hashers[numHashers-1]; +} \ No newline at end of file diff --git a/circuits/circom/messageToCommand.circom b/circuits/circom/messageToCommand.circom index fac1bba455..9ff017e575 100644 --- a/circuits/circom/messageToCommand.circom +++ b/circuits/circom/messageToCommand.circom @@ -1,29 +1,34 @@ pragma circom 2.0.0; -// circomlib import +// from @zk-kit/circuits package. +include "./ecdh.circom"; +include "./unpack-element.circom"; +include "./poseidon-cipher.circom"; +// from circomlib. include "./bitify.circom"; -// @zk-kit import +// local. include "./hashers.circom"; -// local imports -include "./ecdh.circom"; -include "./unpackElement.circom"; - -// template that converts a MACI message -// to a command (decrypts it) +/** + * Converts a MACI message to a command by decrypting it. + * Processes encrypted MACI messages into structured MACI commands + * by decrypting using a shared key derived from ECDH. After decryption, + * unpacks and assigns decrypted values to specific command components. + */ template MessageToCommand() { var MSG_LENGTH = 7; var PACKED_CMD_LENGTH = 4; var UNPACKED_CMD_LENGTH = 8; + var UNPACK_ELEM_LENGTH = 5; + var DECRYPTED_LENGTH = 9; + var MESSAGE_PARTS = 11; - // the message is an array of 11 parts - signal input message[11]; - // we have the encryption private key + // The message is an array of 11 parts. + signal input message[MESSAGE_PARTS]; signal input encPrivKey; - // and the encryption public key signal input encPubKey[2]; - // we output all of the parts of the command + // Command parts. signal output stateIndex; signal output newPubKey[2]; signal output voteOptionIndex; @@ -33,54 +38,46 @@ template MessageToCommand() { signal output salt; signal output sigR8[2]; signal output sigS; - // and also the packed command + // Packed command. signal output packedCommandOut[PACKED_CMD_LENGTH]; - // generate the shared key so we can decrypt - // the message with it - component ecdh = Ecdh(); - ecdh.privKey <== encPrivKey; - ecdh.pubKey[0] <== encPubKey[0]; - ecdh.pubKey[1] <== encPubKey[1]; + // Generate the shared key for decrypting the message. + var ecdh[2] = Ecdh()(encPrivKey, encPubKey); - // decrypt the message using poseidon decryption - component decryptor = PoseidonDecryptWithoutCheck(MSG_LENGTH); - decryptor.key[0] <== ecdh.sharedKey[0]; - decryptor.key[1] <== ecdh.sharedKey[1]; - decryptor.nonce <== 0; - for (var i = 1; i < 11; i++) { - // the first one is msg type, skip - decryptor.ciphertext[i-1] <== message[i]; - } + // Decrypt the message using Poseidon decryption. + var decryptor[DECRYPTED_LENGTH] = PoseidonDecryptWithoutCheck(MSG_LENGTH)( + [ + // nb. the first one is the msg type => skip. + message[1], message[2], message[3], message[4], + message[5], message[6], message[7], message[8], + message[9], message[10] + ], + 0, + ecdh + ); - // save the decrypted message into a packed command signal + // Save the decrypted message into a packed command signal. signal packedCommand[PACKED_CMD_LENGTH]; for (var i = 0; i < PACKED_CMD_LENGTH; i++) { - packedCommand[i] <== decryptor.decrypted[i]; + packedCommand[i] <== decryptor[i]; } - component unpack = UnpackElement(5); - unpack.in <== packedCommand[0]; + var unpack[UNPACK_ELEM_LENGTH] = UnpackElement(UNPACK_ELEM_LENGTH)(packedCommand[0]); - // all of the below were packed - // into the first element - stateIndex <== unpack.out[4]; - voteOptionIndex <== unpack.out[3]; - newVoteWeight <== unpack.out[2]; - nonce <== unpack.out[1]; - pollId <== unpack.out[0]; + // Everything below were packed into the first element. + stateIndex <== unpack[4]; + voteOptionIndex <== unpack[3]; + newVoteWeight <== unpack[2]; + nonce <== unpack[1]; + pollId <== unpack[0]; newPubKey[0] <== packedCommand[1]; newPubKey[1] <== packedCommand[2]; salt <== packedCommand[3]; - sigR8[0] <== decryptor.decrypted[4]; - sigR8[1] <== decryptor.decrypted[5]; - sigS <== decryptor.decrypted[6]; + sigR8[0] <== decryptor[4]; + sigR8[1] <== decryptor[5]; + sigS <== decryptor[6]; - // this could be removed and instead - // use packedCommand as output - for (var i = 0; i < PACKED_CMD_LENGTH; i++) { - packedCommandOut[i] <== packedCommand[i]; - } + packedCommandOut <== packedCommand; } diff --git a/circuits/circom/messageValidator.circom b/circuits/circom/messageValidator.circom index fae47b6a87..017e53b71a 100644 --- a/circuits/circom/messageValidator.circom +++ b/circuits/circom/messageValidator.circom @@ -1,91 +1,81 @@ pragma circom 2.0.0; -// local imports +// from @zk-kit/circuits package. +include "./safe-comparators.circom"; +// local. include "./verifySignature.circom"; -include "./utils.circom"; -// template that validates whether a message -// is valid or not +/** + * Checks if a MACI message is valid or not. + * This template supports the Quadratic Voting (QV). + */ template MessageValidator() { - // a) Whether the state leaf index is valid + var PACKED_CMD_LENGTH = 4; + signal input stateTreeIndex; - // how many signups we have in the state tree + // Number of signups in the state tree. signal input numSignUps; - // we check that the state tree index is <= than the number of signups - // as first validation - // it is <= because the state tree index is 1-based - // 0 is for blank state leaf then 1 for the first actual user - // which is where the numSignUps starts - component validStateLeafIndex = SafeLessEqThan(252); - validStateLeafIndex.in[0] <== stateTreeIndex; - validStateLeafIndex.in[1] <== numSignUps; - - // b) Whether the max vote option tree index is correct signal input voteOptionIndex; signal input maxVoteOptions; - component validVoteOptionIndex = SafeLessThan(252); - validVoteOptionIndex.in[0] <== voteOptionIndex; - validVoteOptionIndex.in[1] <== maxVoteOptions; - - // c) Whether the nonce is correct signal input originalNonce; signal input nonce; - component validNonce = IsEqual(); - // the nonce should be previous nonce + 1 - validNonce.in[0] <== originalNonce + 1; - validNonce.in[1] <== nonce; - - var PACKED_CMD_LENGTH = 4; - // d) Whether the signature is correct signal input cmd[PACKED_CMD_LENGTH]; signal input pubKey[2]; signal input sigR8[2]; signal input sigS; - - component validSignature = VerifySignature(); - validSignature.pubKey[0] <== pubKey[0]; - validSignature.pubKey[1] <== pubKey[1]; - validSignature.R8[0] <== sigR8[0]; - validSignature.R8[1] <== sigR8[1]; - validSignature.S <== sigS; - for (var i = 0; i < PACKED_CMD_LENGTH; i++) { - validSignature.preimage[i] <== cmd[i]; - } - - // e) Whether the state leaf was inserted before the Poll period ended signal input slTimestamp; signal input pollEndTimestamp; - component validTimestamp = SafeLessEqThan(252); - validTimestamp.in[0] <== slTimestamp; - validTimestamp.in[1] <== pollEndTimestamp; - - // f) Whether there are sufficient voice credits signal input currentVoiceCreditBalance; signal input currentVotesForOption; signal input voteWeight; + signal output isValid; + + // Check (1) - The state leaf index must be valid. + // The check ensure that the stateTreeIndex <= numSignUps as first validation. + // Must be <= because the stateTreeIndex is 1-based. Zero is for blank state leaf + // while 1 is for the first actual user matching the numSignUps start. + var validStateLeafIndex = SafeLessEqThan(252)([stateTreeIndex, numSignUps]); + + // Check (2) - The max vote option tree index must be correct. + var validVoteOptionIndex = SafeLessThan(252)([voteOptionIndex, maxVoteOptions]); + + // Check (3) - The nonce must be correct. + var validNonce = IsEqual()([originalNonce + 1, nonce]); - // Check that voteWeight is < sqrt(field size), so voteWeight ^ 2 will not - // overflow - component validVoteWeight = SafeLessEqThan(252); - validVoteWeight.in[0] <== voteWeight; - validVoteWeight.in[1] <== 147946756881789319005730692170996259609; + // Check (4) - The signature must be correct. + var validSignature = VerifySignature()(pubKey, sigR8, sigS, cmd); - // Check that currentVoiceCreditBalance + (currentVotesForOption ** 2) >= (voteWeight ** 2) + // Check (5) - The state leaf must be inserted before the Poll period end. + var validTimestamp = SafeLessEqThan(252)([slTimestamp, pollEndTimestamp]); + + // Check (6) - There must be sufficient voice credits. + // The check ensure that the voteWeight is < sqrt(field size) + // so that voteWeight ^ 2 will not overflow. + var validVoteWeight = SafeLessEqThan(252)([voteWeight, 147946756881789319005730692170996259609]); + + // Check (7) - Check the current voice credit balance. + // The check ensure that currentVoiceCreditBalance + (currentVotesForOption ** 2) >= (voteWeight ** 2) // @note what is the difference between voteWeight and currentVotesForOption? - component sufficientVoiceCredits = SafeGreaterEqThan(252); - sufficientVoiceCredits.in[0] <== (currentVotesForOption * currentVotesForOption) + currentVoiceCreditBalance; - sufficientVoiceCredits.in[1] <== voteWeight * voteWeight; + var sufficientVoiceCredits = SafeGreaterEqThan(252)( + [ + (currentVotesForOption * currentVotesForOption) + currentVoiceCreditBalance, + voteWeight * voteWeight + ] + ); - // if all 7 checks are correct then is IsValid = 1 - component validUpdate = IsEqual(); - validUpdate.in[0] <== 7; - validUpdate.in[1] <== validSignature.valid + - sufficientVoiceCredits.out + - validVoteWeight.out + - validNonce.out + - validStateLeafIndex.out + - validTimestamp.out + - validVoteOptionIndex.out; - signal output isValid; - isValid <== validUpdate.out; + // When all seven checks are correct, then isValid = 1. + var validUpdate = IsEqual()( + [ + 7, + validSignature + + sufficientVoiceCredits + + validVoteWeight + + validNonce + + validStateLeafIndex + + validTimestamp + + validVoteOptionIndex + ] + ); + + isValid <== validUpdate; } diff --git a/circuits/circom/messageValidatorNonQv.circom b/circuits/circom/messageValidatorNonQv.circom index eedd317a74..7b1689477e 100644 --- a/circuits/circom/messageValidatorNonQv.circom +++ b/circuits/circom/messageValidatorNonQv.circom @@ -1,84 +1,74 @@ pragma circom 2.0.0; -// local imports +// from @zk-kit/circuits package. +include "./safe-comparators.circom"; +// local. include "./verifySignature.circom"; -include "./utils.circom"; -// template that validates whether a message -// is valid or not -// @note it does not do quadratic voting +/** + * Checks if a MACI message is valid or not. + * This template does not support the Quadratic Voting (QV). + */ template MessageValidatorNonQv() { - // a) Whether the state leaf index is valid + var PACKED_CMD_LENGTH = 4; + signal input stateTreeIndex; - // how many signups we have in the state tree + // Number of signups in the state tree. signal input numSignUps; - // we check that the state tree index is <= than the number of signups - // as first validation - // it is <= because the state tree index is 1-based - // 0 is for blank state leaf then 1 for the first actual user - // which is where the numSignUps starts - component validStateLeafIndex = SafeLessEqThan(252); - validStateLeafIndex.in[0] <== stateTreeIndex; - validStateLeafIndex.in[1] <== numSignUps; - - // b) Whether the max vote option tree index is correct signal input voteOptionIndex; signal input maxVoteOptions; - component validVoteOptionIndex = SafeLessThan(252); - validVoteOptionIndex.in[0] <== voteOptionIndex; - validVoteOptionIndex.in[1] <== maxVoteOptions; - - // c) Whether the nonce is correct signal input originalNonce; signal input nonce; - component validNonce = IsEqual(); - // the nonce should be previous nonce + 1 - validNonce.in[0] <== originalNonce + 1; - validNonce.in[1] <== nonce; - - var PACKED_CMD_LENGTH = 4; - // d) Whether the signature is correct signal input cmd[PACKED_CMD_LENGTH]; signal input pubKey[2]; signal input sigR8[2]; signal input sigS; - - component validSignature = VerifySignature(); - validSignature.pubKey[0] <== pubKey[0]; - validSignature.pubKey[1] <== pubKey[1]; - validSignature.R8[0] <== sigR8[0]; - validSignature.R8[1] <== sigR8[1]; - validSignature.S <== sigS; - for (var i = 0; i < PACKED_CMD_LENGTH; i++) { - validSignature.preimage[i] <== cmd[i]; - } - - // e) Whether the state leaf was inserted before the Poll period ended signal input slTimestamp; signal input pollEndTimestamp; - component validTimestamp = SafeLessEqThan(252); - validTimestamp.in[0] <== slTimestamp; - validTimestamp.in[1] <== pollEndTimestamp; - - // f) Whether there are sufficient voice credits signal input currentVoiceCreditBalance; signal input currentVotesForOption; signal input voteWeight; + signal output isValid; - // Check that currentVoiceCreditBalance + (currentVotesForOption) >= (voteWeight) - component sufficientVoiceCredits = SafeGreaterEqThan(252); - sufficientVoiceCredits.in[0] <== currentVotesForOption + currentVoiceCreditBalance; - sufficientVoiceCredits.in[1] <== voteWeight; + // Check (1) - The state leaf index must be valid. + // The check ensure that the stateTreeIndex <= numSignUps as first validation. + // Must be <= because the stateTreeIndex is 1-based. Zero is for blank state leaf + // while 1 is for the first actual user matching the numSignUps start. + var validStateLeafIndex = SafeLessEqThan(252)([stateTreeIndex, numSignUps]); - // if all 6 checks are correct then is IsValid = 1 - component validUpdate = IsEqual(); - validUpdate.in[0] <== 6; - validUpdate.in[1] <== validSignature.valid + - sufficientVoiceCredits.out + - validNonce.out + - validStateLeafIndex.out + - validTimestamp.out + - validVoteOptionIndex.out; - signal output isValid; - isValid <== validUpdate.out; + // Check (2) - The max vote option tree index must be correct. + var validVoteOptionIndex = SafeLessThan(252)([voteOptionIndex, maxVoteOptions]); + + // Check (3) - The nonce must be correct. + var validNonce = IsEqual()([originalNonce + 1, nonce]); + + // Check (4) - The signature must be correct. + var validSignature = VerifySignature()(pubKey, sigR8, sigS, cmd); + + // Check (5) - The state leaf must be inserted before the Poll period end. + var validTimestamp = SafeLessEqThan(252)([slTimestamp, pollEndTimestamp]); + + // Check (6) - There must be sufficient voice credits. + // The check ensure that currentVoiceCreditBalance + (currentVotesForOption) >= (voteWeight). + var sufficientVoiceCredits = SafeGreaterEqThan(252)( + [ + currentVotesForOption + currentVoiceCreditBalance, + voteWeight + ] + ); + + // When all six checks are correct, then isValid = 1. + var validUpdate = IsEqual()( + [ + 6, + validSignature + + sufficientVoiceCredits + + validNonce + + validStateLeafIndex + + validTimestamp + + validVoteOptionIndex + ] + ); + + isValid <== validUpdate; } diff --git a/circuits/circom/privToPubKey.circom b/circuits/circom/privToPubKey.circom index 4a8c3cade5..4ffeb14f61 100644 --- a/circuits/circom/privToPubKey.circom +++ b/circuits/circom/privToPubKey.circom @@ -1,31 +1,28 @@ pragma circom 2.0.0; -// circomlib imports +// from circomlib. include "./bitify.circom"; include "./escalarmulfix.circom"; -// convert a private key to a public key -// @note the basepoint is the base point of the baby jubjub curve +/** + * Converts a private key to a public key on the BabyJubJub curve. + * The input private key needs to be hashed and then pruned before. + */ template PrivToPubKey() { - // Needs to be hashed, and then pruned before supplying it to the circuit - signal input privKey; - signal output pubKey[2]; - - // convert the private key to bits - component privBits = Num2Bits(253); - privBits.in <== privKey; - + // The base point of the BabyJubJub curve. var BASE8[2] = [ 5299619240641551281634865583518297030282874472190772894086521144482721001553, 16950150798460657717958625567821834550301663161624707787222815936182638968203 ]; - // perform scalar multiplication with the basepoint - component mulFix = EscalarMulFix(253, BASE8); - for (var i = 0; i < 253; i++) { - mulFix.e[i] <== privBits.out[i]; - } + signal input privKey; + signal output pubKey[2]; + + // Convert the private key to bits. + var privBits[253] = Num2Bits(253)(privKey); + + // Perform scalar multiplication with the basepoint. + var mulFix[2] = EscalarMulFix(253, BASE8)(privBits); - pubKey[0] <== mulFix.out[0]; - pubKey[1] <== mulFix.out[1]; + pubKey <== mulFix; } diff --git a/circuits/circom/processMessages.circom b/circuits/circom/processMessages.circom index 4e463046bd..ca876d76eb 100644 --- a/circuits/circom/processMessages.circom +++ b/circuits/circom/processMessages.circom @@ -8,8 +8,10 @@ include "./hashers.circom"; include "./messageToCommand.circom"; include "./privToPubKey.circom"; include "./stateLeafAndBallotTransformer.circom"; -include "./trees/incrementalQuinTree.circom"; -include "./utils.circom"; +include "./iqt.circom"; +// zk-kit imports +include "./safe-comparators.circom"; +include "./processMessages.circom"; // Proves the correctness of processing a batch of messages. template ProcessMessages( diff --git a/circuits/circom/processMessagesNonQv.circom b/circuits/circom/processMessagesNonQv.circom index f34f28f8ab..609a9d2828 100644 --- a/circuits/circom/processMessagesNonQv.circom +++ b/circuits/circom/processMessagesNonQv.circom @@ -8,8 +8,9 @@ include "./hashers.circom"; include "./messageToCommand.circom"; include "./privToPubKey.circom"; include "./stateLeafAndBallotTransformerNonQv.circom"; -include "./trees/incrementalQuinTree.circom"; -include "./utils.circom"; +include "./iqt.circom"; +// zk-kit imports +include "./safe-comparators.circom"; include "./processMessages.circom"; // Proves the correctness of processing a batch of messages. diff --git a/circuits/circom/subsidy.circom b/circuits/circom/subsidy.circom index 096def57df..72575ac3cc 100644 --- a/circuits/circom/subsidy.circom +++ b/circuits/circom/subsidy.circom @@ -2,14 +2,15 @@ pragma circom 2.0.0; // circomlib import include "./comparators.circom"; +// zk-kit import +include "./float.circom"; +include "./unpack-element.circom"; // local imports -include "./trees/incrementalQuinTree.circom"; -include "./trees/calculateTotal.circom"; -include "./trees/checkRoot.circom"; +include "./iqt.circom"; +include "./calculateTotal.circom"; include "./hashers.circom"; -include "./unpackElement.circom"; -include "./float.circom"; +include "./unpack-element.circom"; /* * calculate subsidy, batch by batch. diff --git a/circuits/circom/tallyVotes.circom b/circuits/circom/tallyVotes.circom index 93789772d6..2ba0ae2466 100644 --- a/circuits/circom/tallyVotes.circom +++ b/circuits/circom/tallyVotes.circom @@ -4,11 +4,10 @@ pragma circom 2.0.0; include "./comparators.circom"; // local imports -include "./trees/incrementalQuinTree.circom"; -include "./trees/calculateTotal.circom"; -include "./trees/checkRoot.circom"; +include "./iqt.circom"; +include "./calculateTotal.circom"; include "./hashers.circom"; -include "./unpackElement.circom"; +include "./unpack-element.circom"; // Tally votes in the ballots, batch by batch. template TallyVotes( diff --git a/circuits/circom/tallyVotesNonQv.circom b/circuits/circom/tallyVotesNonQv.circom index 6d870fd66b..021afb1753 100644 --- a/circuits/circom/tallyVotesNonQv.circom +++ b/circuits/circom/tallyVotesNonQv.circom @@ -4,11 +4,10 @@ pragma circom 2.0.0; include "./comparators.circom"; // local imports -include "./trees/incrementalQuinTree.circom"; -include "./trees/calculateTotal.circom"; -include "./trees/checkRoot.circom"; +include "./iqt.circom"; +include "./calculateTotal.circom"; include "./hashers.circom"; -include "./unpackElement.circom"; +include "./unpack-element.circom"; include "./tallyVotes.circom"; // Tally votes in the ballots, batch by batch. diff --git a/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom b/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom index 66f0afb572..07fae34684 100644 --- a/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom +++ b/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom @@ -1,12 +1,6 @@ +// auto-generated by circomkit pragma circom 2.0.0; include "../processMessages.circom"; -/* -stateTreeDepth, -msgTreeDepth, -msgSubTreeDepth -voteOptionTreeDepth, -*/ - -component main {public [inputHash]} = ProcessMessages(10, 2, 1, 2); +component main {public[inputHash]} = ProcessMessages(10, 2, 1, 2); diff --git a/circuits/circom/test/SubsidyPerBatch_10-1-2_test.circom b/circuits/circom/test/SubsidyPerBatch_10-1-2_test.circom index e5f3123ea4..bdcbd1d332 100644 --- a/circuits/circom/test/SubsidyPerBatch_10-1-2_test.circom +++ b/circuits/circom/test/SubsidyPerBatch_10-1-2_test.circom @@ -1,5 +1,6 @@ +// auto-generated by circomkit pragma circom 2.0.0; include "../subsidy.circom"; -component main {public [inputHash]} = SubsidyPerBatch(10, 1, 2); +component main {public[inputHash]} = SubsidyPerBatch(10, 1, 2); diff --git a/circuits/circom/test/TallyVotes_10-1-2_test.circom b/circuits/circom/test/TallyVotes_10-1-2_test.circom index fdd5bfb555..db24eccfcd 100644 --- a/circuits/circom/test/TallyVotes_10-1-2_test.circom +++ b/circuits/circom/test/TallyVotes_10-1-2_test.circom @@ -1,5 +1,6 @@ +// auto-generated by circomkit pragma circom 2.0.0; include "../tallyVotes.circom"; -component main {public [inputHash]} = TallyVotes(10, 1, 2); +component main {public[inputHash]} = TallyVotes(10, 1, 2); diff --git a/circuits/circom/trees/calculateTotal.circom b/circuits/circom/trees/calculateTotal.circom deleted file mode 100644 index 8284db11aa..0000000000 --- a/circuits/circom/trees/calculateTotal.circom +++ /dev/null @@ -1,17 +0,0 @@ -pragma circom 2.0.0; - -// This circuit returns the sum of the inputs. -// n must be greater than 0. -template CalculateTotal(n) { - signal input nums[n]; - signal output sum; - - signal sums[n]; - sums[0] <== nums[0]; - - for (var i=1; i < n; i++) { - sums[i] <== sums[i - 1] + nums[i]; - } - - sum <== sums[n - 1]; -} diff --git a/circuits/circom/trees/checkRoot.circom b/circuits/circom/trees/checkRoot.circom deleted file mode 100644 index 64291656b9..0000000000 --- a/circuits/circom/trees/checkRoot.circom +++ /dev/null @@ -1,63 +0,0 @@ -pragma circom 2.0.0; - -// local import -include "../hashers.circom"; - -// Given a list of leaves, compute the root of the merkle tree -// by inserting all the leaves into the tree in the given -// order. -template QuinCheckRoot(levels) { - var LEAVES_PER_NODE = 5; - - // The total number of leaves - var totalLeaves = LEAVES_PER_NODE ** levels; - - // The number of Hasher5 components which will be used to hash the - // leaves - var numLeafHashers = LEAVES_PER_NODE ** (levels - 1); - - // Inputs to the snark - signal input leaves[totalLeaves]; - - // The output - signal output root; - - var i; - var j; - - // The total number of hashers - var numHashers = 0; - for (i = 0; i < levels; i++) { - numHashers += LEAVES_PER_NODE ** i; - } - - var hashers[numHashers]; - - // Wire the leaf values into the leaf hashers - for (i = 0; i < numLeafHashers; i++){ - hashers[i] = PoseidonHasher(5)([ - leaves[i*LEAVES_PER_NODE+0], - leaves[i*LEAVES_PER_NODE+1], - leaves[i*LEAVES_PER_NODE+2], - leaves[i*LEAVES_PER_NODE+3], - leaves[i*LEAVES_PER_NODE+4] - ]); - } - - // Wire the outputs of the leaf hashers to the intermediate hasher inputs - var k = 0; - for (i = numLeafHashers; i < numHashers; i++) { - hashers[i] = PoseidonHasher(5)([ - hashers[k*LEAVES_PER_NODE+0], - hashers[k*LEAVES_PER_NODE+1], - hashers[k*LEAVES_PER_NODE+2], - hashers[k*LEAVES_PER_NODE+3], - hashers[k*LEAVES_PER_NODE+4] - ]); - k++; - } - - // Wire the output of the final hash to this circuit's output - root <== hashers[numHashers-1]; -} - diff --git a/circuits/circom/trees/incrementalMerkleTree.circom b/circuits/circom/trees/incrementalMerkleTree.circom deleted file mode 100644 index 3e15fda9f5..0000000000 --- a/circuits/circom/trees/incrementalMerkleTree.circom +++ /dev/null @@ -1,122 +0,0 @@ -pragma circom 2.0.0; - -// Refer to: -// https://github.com/peppersec/tornado-mixer/blob/master/circuits/merkleTree.circom -// https://github.com/semaphore-protocol/semaphore/blob/audited/circuits/circom/semaphore-base.circom -include "./mux1.circom"; - -// local import -include "../hashers.circom"; - -// recompute a merkle root from a leaf and a path -template MerkleTreeInclusionProof(n_levels) { - signal input leaf; - signal input path_index[n_levels]; - signal input path_elements[n_levels][1]; - signal output root; - - component hashers[n_levels]; - component mux[n_levels]; - - signal levelHashes[n_levels + 1]; - levelHashes[0] <== leaf; - - for (var i = 0; i < n_levels; i++) { - // Should be 0 or 1 - path_index[i] * (1 - path_index[i]) === 0; - - mux[i] = MultiMux1(2); - - mux[i].c[0][0] <== levelHashes[i]; - mux[i].c[0][1] <== path_elements[i][0]; - - mux[i].c[1][0] <== path_elements[i][0]; - mux[i].c[1][1] <== levelHashes[i]; - - mux[i].s <== path_index[i]; - hashers[i] = PoseidonHasher(2)([mux[i].out[0], mux[i].out[1]]); - - levelHashes[i + 1] <== hashers[i].hash; - } - - root <== levelHashes[n_levels]; -} - -// Ensures that a leaf exists within a merkletree with given `root` -template LeafExists(levels){ - - // levels is depth of tree - signal input leaf; - - signal input path_elements[levels][1]; - signal input path_index[levels]; - - signal input root; - - component merkletree = MerkleTreeInclusionProof(levels); - merkletree.leaf <== leaf; - for (var i = 0; i < levels; i++) { - merkletree.path_index[i] <== path_index[i]; - merkletree.path_elements[i][0] <== path_elements[i][0]; - } - - root === merkletree.root; -} - -// Given a Merkle root and a list of leaves, check if the root is the -// correct result of inserting all the leaves into the tree (in the given -// order) -template CheckRoot(levels) { - // Circom has some perticularities which limit the code patterns we can - // use. - - // You can only assign a value to a signal once. - - // A component's input signal must only be wired to another component's output - // signal. - - // Variables are only used for loops, declaring sizes of things, and anything - // that is not related to inputs of a circuit. - - // The total number of leaves - var totalLeaves = 2 ** levels; - - // The number of HashLeftRight components which will be used to hash the - // leaves - var numLeafHashers = totalLeaves / 2; - - // The number of HashLeftRight components which will be used to hash the - // output of the leaf hasher components - var numIntermediateHashers = numLeafHashers - 1; - - // Inputs to the snark - signal input leaves[totalLeaves]; - - // The output - signal output root; - - // The total number of hashers - var numHashers = totalLeaves - 1; - var hashers[numHashers]; - - // Instantiate all hashers - var i; - for (i=0; i < numHashers; i++) { - hashers[i] = PoseidonHasher(2); - } - - // Wire the leaf values into the leaf hashers - for (i=0; i < numLeafHashers; i++){ - hasher[i] = PoseidonHasher(2)([leaves[i * 2], leaves[i * 2 + 1]]) - } - - // Wire the outputs of the leaf hashers to the intermediate hasher inputs - var k = 0; - for (i=numLeafHashers; i [IsEqual] <-- index - | - v - leaf ---> [Mux1] <--- - | - v - output[m] - - To obtain the value from , we need to compute an item - index (let it be `s`). - - 1. if index = 2 and i = 0, then s = 0 - 2. if index = 2 and i = 1, then s = 1 - 3. if index = 2 and i = 2, then s = 2 - 4. if index = 2 and i = 3, then s = 2 - 5. if index = 2 and i = 4, then s = 3 - - We then wire `s`, as well as each item in `in` to a QuinSelector. - The output signal from the QuinSelector is and gets - wired to Mux1 (as above). - */ - for (i = 0; i < numItems + 1; i++) { - // greaterThen[i].out will be 1 if the i is greater than the index - greaterThan[i] = SafeGreaterThan(3); - greaterThan[i].in[0] <== i; - greaterThan[i].in[1] <== index; - - quinSelectors[i] = QuinSelector(numItems + 1); - - // Select the value from `in` at index i - greaterThan[i].out. - // e.g. if index = 2 and i = 1, greaterThan[i].out = 0, so 1 - 0 = 1 - // but if index = 2 and i = 3, greaterThan[i].out = 1, so 3 - 1 = 2 - quinSelectors[i].index <== i - greaterThan[i].out; - - for (j = 0; j < numItems; j++) { - quinSelectors[i].in[j] <== in[j]; - } - quinSelectors[i].in[numItems] <== 0; - - isLeafIndex[i] = IsEqual(); - isLeafIndex[i].in[0] <== index; - isLeafIndex[i].in[1] <== i; - - muxes[i] = Mux1(); - muxes[i].s <== isLeafIndex[i].out; - muxes[i].c[0] <== quinSelectors[i].out; - muxes[i].c[1] <== leaf; - - out[i] <== muxes[i].out; - } -} - -// Given a list of leaves, as well as the path to the root, -// compute the root -template QuinTreeInclusionProof(levels) { - // Each node has 5 leaves - var LEAVES_PER_NODE = 5; - var LEAVES_PER_PATH_LEVEL = LEAVES_PER_NODE - 1; - - signal input leaf; - signal input path_index[levels]; - signal input path_elements[levels][LEAVES_PER_PATH_LEVEL]; - signal output root; - - var i; - var j; - - var hashers[levels]; - component splicers[levels]; - - // Hash the first level of path_elements - splicers[0] = Splicer(LEAVES_PER_PATH_LEVEL); - splicers[0].index <== path_index[0]; - splicers[0].leaf <== leaf; - - for (i = 0; i < LEAVES_PER_PATH_LEVEL; i++) { - splicers[0].in[i] <== path_elements[0][i]; - } - - hashers[0] = PoseidonHasher(5)([ - splicers[0].out[0], - splicers[0].out[1], - splicers[0].out[2], - splicers[0].out[3], - splicers[0].out[4] - ]); - - // Hash each level of path_elements - for (i = 1; i < levels; i++) { - splicers[i] = Splicer(LEAVES_PER_PATH_LEVEL); - splicers[i].index <== path_index[i]; - - splicers[i].leaf <== hashers[i - 1]; - - for (j = 0; j < LEAVES_PER_PATH_LEVEL; j++) { - splicers[i].in[j] <== path_elements[i][j]; - } - - hashers[i] = PoseidonHasher(5)([ - splicers[i].out[0], - splicers[i].out[1], - splicers[i].out[2], - splicers[i].out[3], - splicers[i].out[4] - ]); - } - - root <== hashers[levels - 1]; -} - -// Ensures that a leaf exists within a quintree with given `root` -template QuinLeafExists(levels){ - var LEAVES_PER_NODE = 5; - var LEAVES_PER_PATH_LEVEL = LEAVES_PER_NODE - 1; - - var i; - var j; - - signal input leaf; - signal input path_elements[levels][LEAVES_PER_PATH_LEVEL]; - signal input path_index[levels]; - signal input root; - - // Verify the Merkle path - component verifier = QuinTreeInclusionProof(levels); - verifier.leaf <== leaf; - for (i = 0; i < levels; i++) { - verifier.path_index[i] <== path_index[i]; - for (j = 0; j < LEAVES_PER_PATH_LEVEL; j++) { - verifier.path_elements[i][j] <== path_elements[i][j]; - } - } - - root === verifier.root; -} - -// Given a list of leaves, check whether they exist in -// a quinary merkle tree -template QuinBatchLeavesExists(levels, batchLevels) { - // Compute the root of a subtree of leaves, and then check whether the - // subroot exists in the main tree - - var LEAVES_PER_NODE = 5; - var LEAVES_PER_PATH_LEVEL = LEAVES_PER_NODE - 1; - var LEAVES_PER_BATCH = LEAVES_PER_NODE ** batchLevels; - - // The main root - signal input root; - - // The batch of leaves - signal input leaves[LEAVES_PER_BATCH]; - - // The Merkle path from the subroot to the main root - signal input path_index[levels - batchLevels]; - signal input path_elements[levels - batchLevels][LEAVES_PER_PATH_LEVEL]; - - // Compute the subroot - component qcr = QuinCheckRoot(batchLevels); - for (var i = 0; i < LEAVES_PER_BATCH; i++) { - qcr.leaves[i] <== leaves[i]; - } - - // Check if the Merkle path is valid - component qle = QuinLeafExists(levels - batchLevels); - - // The subroot is the leaf - qle.leaf <== qcr.root; - qle.root <== root; - for (var i = 0; i < levels - batchLevels; i++) { - qle.path_index[i] <== path_index[i]; - for (var j = 0; j < LEAVES_PER_PATH_LEVEL; j++) { - qle.path_elements[i][j] <== path_elements[i][j]; - } - } -} - - -// Given a tree index, generate the indices which QuinTreeInclusionProof and -// QuinLeafExists require. e.g. if the index is 30 and the number of levels is -// 4, the output should be [0, 1, 1, 0] -template QuinGeneratePathIndices(levels) { - var BASE = 5; - signal input in; - signal output out[levels]; - - var m = in; - signal n[levels + 1]; - for (var i = 0; i < levels; i++) { - // circom's best practices state that we should avoid using <-- unless - // we know what we are doing. But this is the only way to perform the - // modulo operation. - - n[i] <-- m; - - out[i] <-- m % BASE; - - m = m \ BASE; - } - - n[levels] <-- m; - - component leq[levels]; - component sum = CalculateTotal(levels); - for (var i = 0; i < levels; i++) { - // Check that each output element is less than the base - leq[i] = SafeLessThan(3); - leq[i].in[0] <== out[i]; - leq[i].in[1] <== BASE; - leq[i].out === 1; - - // Re-compute the total sum - sum.nums[i] <== out[i] * (BASE ** i); - } - - // Check that the total sum matches the index - sum.sum === in; -} diff --git a/circuits/circom/unpackElement.circom b/circuits/circom/unpackElement.circom deleted file mode 100644 index 9f3d41bd57..0000000000 --- a/circuits/circom/unpackElement.circom +++ /dev/null @@ -1,26 +0,0 @@ -pragma circom 2.0.0; - -// circomlib import -include "./bitify.circom"; - -// Converts a field element (253 bits) to n 50-bit output elements -// where n <= 5 and n > 1 -template UnpackElement(n) { - signal input in; - signal output out[n]; - assert(n > 1); - assert(n <= 5); - - // Convert input to bits - component inputBits = Num2Bits_strict(); - inputBits.in <== in; - - component outputElements[n]; - for (var i = 0; i < n; i++) { - outputElements[i] = Bits2Num(50); - for (var j = 0; j < 50; j++) { - outputElements[i].in[j] <== inputBits.out[((n - i - 1) * 50) + j]; - } - out[i] <== outputElements[i].out; - } -} diff --git a/circuits/circom/utils.circom b/circuits/circom/utils.circom deleted file mode 100644 index 98edd84f14..0000000000 --- a/circuits/circom/utils.circom +++ /dev/null @@ -1,62 +0,0 @@ -pragma circom 2.0.0; - -// circomlib import -include "./bitify.circom"; - -// the implicit assumption of LessThan is both inputs are at most n bits -// so we need add range check for both inputs -template SafeLessThan(n) { - assert(n <= 252); - signal input in[2]; - signal output out; - - component n2b1 = Num2Bits(n); - n2b1.in <== in[0]; - component n2b2 = Num2Bits(n); - n2b2.in <== in[1]; - - component n2b = Num2Bits(n+1); - - n2b.in <== in[0] + (1< out; -} - -// N is the number of bits the input have. -// The MSF is the sign bit. -template SafeGreaterThan(n) { - signal input in[2]; - signal output out; - - component lt = SafeLessThan(n); - - lt.in[0] <== in[1]; - lt.in[1] <== in[0]; - lt.out ==> out; -} - -// N is the number of bits the input have. -// The MSF is the sign bit. -template SafeGreaterEqThan(n) { - signal input in[2]; - signal output out; - - component lt = SafeLessThan(n); - - lt.in[0] <== in[1]; - lt.in[1] <== in[0]+1; - lt.out ==> out; -} diff --git a/circuits/package.json b/circuits/package.json index cf810252ab..ff125d3ad2 100644 --- a/circuits/package.json +++ b/circuits/package.json @@ -22,24 +22,21 @@ "types": "tsc -p tsconfig.json --noEmit", "test": "ts-mocha --exit ts/__tests__/*.test.ts", "test:hasher": "ts-mocha --exit ts/__tests__/Hasher.test.ts", - "test:unpackElement": "ts-mocha --exit ts/__tests__/UnpackElement.test.ts", "test:slAndBallotTransformer": "ts-mocha --exit ts/__tests__/StateLeafAndBallotTransformer.test.ts", "test:messageToCommand": "ts-mocha --exit ts/__tests__/MessageToCommand.test.ts", "test:messageValidator": "ts-mocha --exit ts/__tests__/MessageValidator.test.ts", "test:verifySignature": "ts-mocha --exit ts/__tests__/VerifySignature.test.ts", "test:splicer": "ts-mocha --exit ts/__tests__/Splicer.test.ts", - "test:ecdh": "ts-mocha --exit ts/__tests__/Ecdh.test.ts", "test:privToPubKey": "ts-mocha --exit ts/__tests__/PrivToPubKey.test.ts", "test:calculateTotal": "ts-mocha --exit ts/__tests__/CalculateTotal.test.ts", "test:processMessages": "NODE_OPTIONS=--max-old-space-size=4096 ts-mocha --exit ts/__tests__/ProcessMessages.test.ts", "test:tallyVotes": "NODE_OPTIONS=--max-old-space-size=4096 ts-mocha --exit ts/__tests__/TallyVotes.test.ts", "test:ceremonyParams": "ts-mocha --exit ts/__tests__/CeremonyParams.test.ts", - "test:quinCheckRoot": "ts-mocha --exit ts/__tests__/QuinCheckRoot.test.ts", - "test:incrementalQuinTree": "ts-mocha --exit ts/__tests__/IncrementalQuinTree.test.ts" + "test:iqt": "ts-mocha --exit ts/__tests__/IQT.test.ts" }, "dependencies": { - "@zk-kit/circuits": "^0.3.0", - "circomkit": "^0.0.22", + "@zk-kit/circuits": "^0.4.0", + "circomkit": "^0.0.24", "circomlib": "^2.0.5", "maci-core": "^1.1.2", "maci-crypto": "^1.1.2", diff --git a/circuits/ts/__tests__/CalculateTotal.test.ts b/circuits/ts/__tests__/CalculateTotal.test.ts index 7ed1beabbe..d5da5487cb 100644 --- a/circuits/ts/__tests__/CalculateTotal.test.ts +++ b/circuits/ts/__tests__/CalculateTotal.test.ts @@ -7,7 +7,7 @@ describe("CalculateTotal circuit", () => { before(async () => { circuit = await circomkitInstance.WitnessTester("calculateTotal", { - file: "trees/calculateTotal", + file: "calculateTotal", template: "CalculateTotal", params: [6], }); diff --git a/circuits/ts/__tests__/Ecdh.test.ts b/circuits/ts/__tests__/Ecdh.test.ts deleted file mode 100644 index 3baa883900..0000000000 --- a/circuits/ts/__tests__/Ecdh.test.ts +++ /dev/null @@ -1,67 +0,0 @@ -import chai, { expect } from "chai"; -import chaiAsPromised from "chai-as-promised"; -import { type WitnessTester } from "circomkit"; -import { Keypair } from "maci-domainobjs"; - -import { circomkitInstance } from "./utils/utils"; - -chai.use(chaiAsPromised); - -describe("Public key derivation circuit", () => { - let circuit: WitnessTester<["privKey", "pubKey"], ["sharedKey"]>; - - before(async () => { - circuit = await circomkitInstance.WitnessTester("ecdh", { - file: "ecdh", - template: "Ecdh", - }); - }); - - it("correctly computes a public key", async () => { - const keypair = new Keypair(); - const keypair2 = new Keypair(); - - const ecdhSharedKey = Keypair.genEcdhSharedKey(keypair.privKey, keypair2.pubKey); - - const circuitInputs = { - privKey: BigInt(keypair.privKey.asCircuitInputs()), - pubKey: keypair2.pubKey.rawPubKey as [bigint, bigint], - }; - - await circuit.expectPass(circuitInputs, { sharedKey: [ecdhSharedKey[0], ecdhSharedKey[1]] }); - }); - - it("should generate the same ECDH key given the same inputs", async () => { - const keypair = new Keypair(); - const keypair2 = new Keypair(); - - const circuitInputs = { - privKey: BigInt(keypair.privKey.asCircuitInputs()), - pubKey: keypair2.pubKey.asCircuitInputs() as unknown as bigint[], - }; - - // calculate first time witness and check contraints - const witness = await circuit.calculateWitness(circuitInputs); - await circuit.expectConstraintPass(witness); - - // read out - const out = await circuit.readWitnessSignals(witness, ["sharedKey"]); - - // calculate again - await circuit.expectPass(circuitInputs, { sharedKey: out.sharedKey }); - }); - - it("should throw when given invalid inputs (pubKey too short)", async () => { - const keypair = new Keypair(); - const keypair2 = new Keypair(); - - const circuitInputs = { - privKey: BigInt(keypair.privKey.asCircuitInputs()), - pubKey: keypair2.pubKey.asCircuitInputs().slice(0, 1) as unknown as [bigint, bigint], - }; - - await expect(circuit.calculateWitness(circuitInputs)).to.be.rejectedWith( - "Not enough values for input signal pubKey", - ); - }); -}); diff --git a/circuits/ts/__tests__/IncrementalQuinTree.test.ts b/circuits/ts/__tests__/IQT.test.ts similarity index 77% rename from circuits/ts/__tests__/IncrementalQuinTree.test.ts rename to circuits/ts/__tests__/IQT.test.ts index aa5e76b84a..d610adb77f 100644 --- a/circuits/ts/__tests__/IncrementalQuinTree.test.ts +++ b/circuits/ts/__tests__/IQT.test.ts @@ -7,7 +7,7 @@ import { getSignal, circomkitInstance } from "./utils/utils"; chai.use(chaiAsPromised); -describe("IncrementalQuinTree circuit", function test() { +describe("Incremental Quinary Tree (IQT)", function test() { this.timeout(50000); const leavesPerNode = 5; @@ -17,31 +17,38 @@ describe("IncrementalQuinTree circuit", function test() { let circuitGeneratePathIndices: WitnessTester<["in"], ["out"]>; let circuitQuinSelector: WitnessTester<["in", "index"], ["out"]>; let splicerCircuit: WitnessTester<["in", "leaf", "index"], ["out"]>; + let quinCheckRoot: WitnessTester<["leaves"], ["root"]>; before(async () => { circuitLeafExists = await circomkitInstance.WitnessTester("quinLeafExists", { - file: "./trees/incrementalQuinTree", + file: "./iqt", template: "QuinLeafExists", params: [3], }); circuitGeneratePathIndices = await circomkitInstance.WitnessTester("quinGeneratePathIndices", { - file: "./trees/incrementalQuinTree", + file: "./iqt", template: "QuinGeneratePathIndices", params: [4], }); circuitQuinSelector = await circomkitInstance.WitnessTester("quinSelector", { - file: "./trees/incrementalQuinTree", + file: "./iqt", template: "QuinSelector", params: [5], }); splicerCircuit = await circomkitInstance.WitnessTester("splicer", { - file: "./trees/incrementalQuinTree", + file: "./iqt", template: "Splicer", params: [4], }); + + quinCheckRoot = await circomkitInstance.WitnessTester("quinCheckRoot", { + file: "iqt", + template: "QuinCheckRoot", + params: [3], + }); }); describe("QuinSelector", () => { @@ -149,4 +156,37 @@ describe("IncrementalQuinTree circuit", function test() { await expect(circuitLeafExists.calculateWitness(circuitInputs)).to.be.rejectedWith("Assert Failed."); }); }); + + describe("QuinCheckRoot", () => { + it("should compute the correct merkle root", async () => { + const leaves = Array(leavesPerNode ** treeDepth).fill(5n); + + const circuitInputs = { + leaves, + }; + + const tree = new IncrementalQuinTree(3, 0n, 5, hash5); + leaves.forEach((leaf) => { + tree.insert(leaf); + }); + + const witness = await quinCheckRoot.calculateWitness(circuitInputs); + await quinCheckRoot.expectConstraintPass(witness); + + const circuitRoot = await getSignal(quinCheckRoot, witness, "root"); + expect(circuitRoot.toString()).to.be.eq(tree.root.toString()); + }); + + it("should not accept less leaves than a full tree", async () => { + const leaves = Array(leavesPerNode ** treeDepth - 1).fill(5n); + + const circuitInputs = { + leaves, + }; + + await expect(quinCheckRoot.calculateWitness(circuitInputs)).to.be.rejectedWith( + "Not enough values for input signal leaves", + ); + }); + }); }); diff --git a/circuits/ts/__tests__/MessageValidator.test.ts b/circuits/ts/__tests__/MessageValidator.test.ts index f678c14a4a..fff62aa8a0 100644 --- a/circuits/ts/__tests__/MessageValidator.test.ts +++ b/circuits/ts/__tests__/MessageValidator.test.ts @@ -10,147 +10,295 @@ import { getSignal, circomkitInstance } from "./utils/utils"; describe("MessageValidator circuit", function test() { this.timeout(90000); - let circuitInputs: IMessageValidatorCircuitInputs; - - let circuit: WitnessTester< - [ - "stateTreeIndex", - "numSignUps", - "voteOptionIndex", - "maxVoteOptions", - "originalNonce", - "nonce", - "cmd", - "pubKey", - "sigR8", - "sigS", - "currentVoiceCreditBalance", - "currentVotesForOption", - "voteWeight", - "slTimestamp", - "pollEndTimestamp", - ], - ["isValid"] - >; - - before(async () => { - circuit = await circomkitInstance.WitnessTester("messageValidator", { - file: "messageValidator", - template: "MessageValidator", + describe("MessageValidatorQV", () => { + let circuitInputs: IMessageValidatorCircuitInputs; + + let circuit: WitnessTester< + [ + "stateTreeIndex", + "numSignUps", + "voteOptionIndex", + "maxVoteOptions", + "originalNonce", + "nonce", + "cmd", + "pubKey", + "sigR8", + "sigS", + "currentVoiceCreditBalance", + "currentVotesForOption", + "voteWeight", + "slTimestamp", + "pollEndTimestamp", + ], + ["isValid"] + >; + + before(async () => { + circuit = await circomkitInstance.WitnessTester("messageValidator", { + file: "messageValidator", + template: "MessageValidator", + }); }); - }); - before(() => { - const { privKey, pubKey } = new Keypair(); - - // Note that the command fields don't matter in this test - const command: PCommand = new PCommand( - BigInt(1), - pubKey, - BigInt(2), - BigInt(3), - BigInt(4), - BigInt(5), - genRandomSalt(), - ); - - const signature = command.sign(privKey); - - circuitInputs = { - stateTreeIndex: 0n as SignalValueType, - numSignUps: 1n, - voteOptionIndex: 0n, - maxVoteOptions: 1n, - originalNonce: 1n, - nonce: 2n, - cmd: command.asCircuitInputs(), - pubKey: pubKey.asCircuitInputs() as unknown as [bigint, bigint], - sigR8: signature.R8 as unknown as bigint, - sigS: signature.S as bigint, - currentVoiceCreditBalance: 100n, - currentVotesForOption: 0n, - voteWeight: 9n, - slTimestamp: 1n, - pollEndTimestamp: 2n, - }; - }); + before(() => { + const { privKey, pubKey } = new Keypair(); - it("should pass if all inputs are valid", async () => { - const witness = await circuit.calculateWitness(circuitInputs); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("1"); - }); + // Note that the command fields don't matter in this test + const command: PCommand = new PCommand( + BigInt(1), + pubKey, + BigInt(2), + BigInt(3), + BigInt(4), + BigInt(5), + genRandomSalt(), + ); - it("should be invalid if the signature is invalid", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.sigS = 0n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); - }); + const signature = command.sign(privKey); - it("should be invalid if the pubkey is invalid", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.pubKey = [0n, 1n]; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); - }); + circuitInputs = { + stateTreeIndex: 0n as SignalValueType, + numSignUps: 1n, + voteOptionIndex: 0n, + maxVoteOptions: 1n, + originalNonce: 1n, + nonce: 2n, + cmd: command.asCircuitInputs(), + pubKey: pubKey.asCircuitInputs() as unknown as [bigint, bigint], + sigR8: signature.R8 as unknown as bigint, + sigS: signature.S as bigint, + currentVoiceCreditBalance: 100n, + currentVotesForOption: 0n, + voteWeight: 9n, + slTimestamp: 1n, + pollEndTimestamp: 2n, + }; + }); - it("should be invalid if there are insufficient voice credits", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.voteWeight = 11n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); - }); + it("should pass if all inputs are valid", async () => { + const witness = await circuit.calculateWitness(circuitInputs); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("1"); + }); - it("should be invalid if the nonce is invalid", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.nonce = 3n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); - }); + it("should be invalid if the signature is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.sigS = 0n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); - it("should be invalid if the state leaf index is invalid", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.stateTreeIndex = 2n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); - }); + it("should be invalid if the pubkey is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.pubKey = [0n, 1n]; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); - it("should be invalid if the vote option index is invalid", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.voteOptionIndex = 1n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); - }); + it("should be invalid if there are insufficient voice credits", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.voteWeight = 11n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the nonce is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.nonce = 3n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the state leaf index is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.stateTreeIndex = 2n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); - it("should be invalid if the vote option index is invalid", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.voteOptionIndex = 6049261729n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); + it("should be invalid if the vote option index is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.voteOptionIndex = 1n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the vote option index is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.voteOptionIndex = 6049261729n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the state leaf timestamp is too high", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.slTimestamp = 3n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); }); - it("should be invalid if the state leaf timestamp is too high", async () => { - const circuitInputs2 = circuitInputs; - circuitInputs2.slTimestamp = 3n; - const witness = await circuit.calculateWitness(circuitInputs2); - await circuit.expectConstraintPass(witness); - const isValid = await getSignal(circuit, witness, "isValid"); - expect(isValid.toString()).to.be.eq("0"); + describe("MessageValidatorNonQV", () => { + let circuitInputs: IMessageValidatorCircuitInputs; + + let circuit: WitnessTester< + [ + "stateTreeIndex", + "numSignUps", + "voteOptionIndex", + "maxVoteOptions", + "originalNonce", + "nonce", + "cmd", + "pubKey", + "sigR8", + "sigS", + "currentVoiceCreditBalance", + "currentVotesForOption", + "voteWeight", + "slTimestamp", + "pollEndTimestamp", + ], + ["isValid"] + >; + + before(async () => { + circuit = await circomkitInstance.WitnessTester("messageValidatorNonQv", { + file: "messageValidatorNonQv", + template: "MessageValidatorNonQv", + }); + }); + + before(() => { + const { privKey, pubKey } = new Keypair(); + + // Note that the command fields don't matter in this test + const command: PCommand = new PCommand( + BigInt(1), + pubKey, + BigInt(2), + BigInt(3), + BigInt(4), + BigInt(5), + genRandomSalt(), + ); + + const signature = command.sign(privKey); + + circuitInputs = { + stateTreeIndex: 0n as SignalValueType, + numSignUps: 1n, + voteOptionIndex: 0n, + maxVoteOptions: 1n, + originalNonce: 1n, + nonce: 2n, + cmd: command.asCircuitInputs(), + pubKey: pubKey.asCircuitInputs() as unknown as [bigint, bigint], + sigR8: signature.R8 as unknown as bigint, + sigS: signature.S as bigint, + currentVoiceCreditBalance: 100n, + currentVotesForOption: 0n, + voteWeight: 9n, + slTimestamp: 1n, + pollEndTimestamp: 2n, + }; + }); + + it("should pass if all inputs are valid", async () => { + const witness = await circuit.calculateWitness(circuitInputs); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("1"); + }); + + it("should be invalid if the signature is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.sigS = 0n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the pubkey is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.pubKey = [0n, 1n]; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if there are insufficient voice credits", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.voteWeight = 11n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the nonce is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.nonce = 3n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the state leaf index is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.stateTreeIndex = 2n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the vote option index is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.voteOptionIndex = 1n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the vote option index is invalid", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.voteOptionIndex = 6049261729n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); + + it("should be invalid if the state leaf timestamp is too high", async () => { + const circuitInputs2 = circuitInputs; + circuitInputs2.slTimestamp = 3n; + const witness = await circuit.calculateWitness(circuitInputs2); + await circuit.expectConstraintPass(witness); + const isValid = await getSignal(circuit, witness, "isValid"); + expect(isValid.toString()).to.be.eq("0"); + }); }); }); diff --git a/circuits/ts/__tests__/QuinCheckRoot.test.ts b/circuits/ts/__tests__/QuinCheckRoot.test.ts deleted file mode 100644 index 5ec5cd719f..0000000000 --- a/circuits/ts/__tests__/QuinCheckRoot.test.ts +++ /dev/null @@ -1,56 +0,0 @@ -import chai, { expect } from "chai"; -import chaiAsPromised from "chai-as-promised"; -import { type WitnessTester } from "circomkit"; -import { IncrementalQuinTree, hash5 } from "maci-crypto"; - -import { getSignal, circomkitInstance } from "./utils/utils"; - -chai.use(chaiAsPromised); - -describe("QuinCheckRoot circuit", function test() { - this.timeout(50000); - - const leavesPerNode = 5; - const treeDepth = 3; - - let circuit: WitnessTester<["leaves"], ["root"]>; - - before(async () => { - circuit = await circomkitInstance.WitnessTester("checkRoot", { - file: "trees/checkRoot", - template: "QuinCheckRoot", - params: [3], - }); - }); - - it("should compute the correct merkle root", async () => { - const leaves = Array(leavesPerNode ** treeDepth).fill(5n); - - const circuitInputs = { - leaves, - }; - - const tree = new IncrementalQuinTree(3, 0n, 5, hash5); - leaves.forEach((leaf) => { - tree.insert(leaf); - }); - - const witness = await circuit.calculateWitness(circuitInputs); - await circuit.expectConstraintPass(witness); - - const circuitRoot = await getSignal(circuit, witness, "root"); - expect(circuitRoot.toString()).to.be.eq(tree.root.toString()); - }); - - it("should not accept less leaves than a full tree", async () => { - const leaves = Array(leavesPerNode ** treeDepth - 1).fill(5n); - - const circuitInputs = { - leaves, - }; - - await expect(circuit.calculateWitness(circuitInputs)).to.be.rejectedWith( - "Not enough values for input signal leaves", - ); - }); -}); diff --git a/circuits/ts/__tests__/UnpackElement.test.ts b/circuits/ts/__tests__/UnpackElement.test.ts deleted file mode 100644 index dff9cc8c2a..0000000000 --- a/circuits/ts/__tests__/UnpackElement.test.ts +++ /dev/null @@ -1,75 +0,0 @@ -import { expect } from "chai"; -import { type WitnessTester } from "circomkit"; -import { genRandomSalt } from "maci-crypto"; - -import { getSignal, circomkitInstance } from "./utils/utils"; - -describe("UnpackElement circuit", () => { - let circuit: WitnessTester<["in"], ["out"]>; - - before(async () => { - circuit = await circomkitInstance.WitnessTester("unpackElement", { - file: "unpackElement", - template: "UnpackElement", - params: [5], - }); - }); - - it("should unpack a field element with 5 packed values correctly", async () => { - const elements: string[] = []; - for (let i = 0; i < 5; i += 1) { - let e = (BigInt(genRandomSalt().toString()) % BigInt(2 ** 50)).toString(2); - while (e.length < 50) { - e = `0${e}`; - } - elements.push(e); - } - - const circuitInputs = { - in: BigInt(`0b${elements.join("")}`), - }; - - const witness = await circuit.calculateWitness(circuitInputs); - await circuit.expectConstraintPass(witness); - - for (let i = 0; i < 5; i += 1) { - // eslint-disable-next-line no-await-in-loop - const out = await getSignal(circuit, witness, `out[${i}]`); - expect(BigInt(`0b${BigInt(out).toString(2)}`).toString()).to.be.eq(BigInt(`0b${elements[i]}`).toString()); - } - }); - - describe("unpackElement4", () => { - before(async () => { - circuit = await circomkitInstance.WitnessTester("unpackElement", { - file: "unpackElement", - template: "UnpackElement", - params: [4], - }); - }); - - it("should unpack a field element with 4 packed values correctly", async () => { - const elements: string[] = []; - for (let i = 0; i < 4; i += 1) { - let e = (BigInt(genRandomSalt().toString()) % BigInt(2 ** 50)).toString(2); - while (e.length < 50) { - e = `0${e}`; - } - elements.push(e); - } - - const circuitInputs = { - in: BigInt(`0b${elements.join("")}`), - }; - - const witness = await circuit.calculateWitness(circuitInputs); - await circuit.expectConstraintPass(witness); - - for (let i = 0; i < 4; i += 1) { - // eslint-disable-next-line no-await-in-loop - const out = await getSignal(circuit, witness, `out[${i}]`); - expect(BigInt(`0b${BigInt(out).toString(2)}`).toString()).to.be.eq(BigInt(`0b${elements[i]}`).toString()); - } - }); - }); -}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 033d142fbc..ff4365d321 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -93,11 +93,11 @@ importers: circuits: dependencies: '@zk-kit/circuits': - specifier: ^0.3.0 - version: 0.3.0 + specifier: ^0.4.0 + version: 0.4.0 circomkit: - specifier: ^0.0.22 - version: 0.0.22 + specifier: ^0.0.24 + version: 0.0.24 circomlib: specifier: ^2.0.5 version: 2.0.5 @@ -6095,8 +6095,8 @@ packages: '@zk-kit/utils': 0.1.0 dev: false - /@zk-kit/circuits@0.3.0: - resolution: {integrity: sha512-v46KHC3sBRXUJbYi8d5PTAm3zCdBeArvWw3de+A2LcW/C9beYqBo8QJ/h6NWKZWOgpwqvCHzJa5HvyG6x3lIZQ==} + /@zk-kit/circuits@0.4.0: + resolution: {integrity: sha512-Di7mokhwBS3qxVeCfHxGeNIpDg1kTnr1JXmsWiQMZLkRTn3Hugh6Tl07J394rWD0pIWRwPQsinaMVL2sB4F8yQ==} dependencies: circomlib: 2.0.5 dev: false @@ -7323,8 +7323,8 @@ packages: util: 0.12.5 dev: false - /circomkit@0.0.22: - resolution: {integrity: sha512-rfNiDCBrg/c9JI4nbvpKa123eFyrBFHBHUUeJFjHpzuG7Zw3KMbIUs4dw9xoIeqtzLYL4Hs4jXTMKluStEMxKA==} + /circomkit@0.0.24: + resolution: {integrity: sha512-lw5Kj6zAWS8NYZjlDCGEDeA1e0/Vpa6t6W3GT0AxfhswUoqK0Nu3sz5hu8ZQ+Efh0Ss3eLoD0y+9sOkySEwgEA==} engines: {node: '>=12.0.0'} hasBin: true dependencies: